chat_gpt_lib_rs/api_resources/moderations.rs
//! This module provides functionality for classifying text against OpenAI's content moderation
//! policies using the [OpenAI Moderations API](https://platform.openai.com/docs/api-reference/moderations).
//!
//! The Moderations API takes text (or multiple pieces of text) and returns a set of boolean flags
//! indicating whether the content violates certain categories (e.g., hate, self-harm, sexual), along
//! with confidence scores for each category.
//!
//! # Overview
//!
//! You can call [`create_moderation`] with a [`CreateModerationRequest`], specifying your input text(s)
//! (and optionally, a specific model). The response ([`CreateModerationResponse`]) includes a list of
//! [`ModerationResult`] objects, one per input. Each result contains a set of [`ModerationCategories`],
//! matching confidence scores ([`ModerationCategoryScores`]), and a `flagged` boolean indicating whether
//! the text violates policy overall.
//!
//! # Example
//!
//! ```rust,no_run
//! use chat_gpt_lib_rs::api_resources::moderations::{create_moderation, CreateModerationRequest, ModerationsInput};
//! use chat_gpt_lib_rs::error::OpenAIError;
//! use chat_gpt_lib_rs::OpenAIClient;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), OpenAIError> {
//!     // Load environment variables from a .env file, if present (optional).
//!     dotenvy::dotenv().ok();
//!
//!     let client = OpenAIClient::new(None)?;
//!     let request = CreateModerationRequest {
//!         input: ModerationsInput::String("I hate you and want to harm you.".to_string()),
//!         model: None, // or Some("text-moderation-latest".into())
//!     };
//!
//!     let response = create_moderation(&client, &request).await?;
//!     for (i, result) in response.results.iter().enumerate() {
//!         println!("== Result {} ==", i);
//!         println!("Flagged: {}", result.flagged);
//!         println!("Hate category: {}", result.categories.hate);
//!         println!("Hate score: {}", result.category_scores.hate);
//!         // ...and so on for other categories
//!     }
//!
//!     Ok(())
//! }
//! ```

use serde::{Deserialize, Serialize};

use crate::api::post_json;
use crate::config::OpenAIClient;
use crate::error::OpenAIError;

use super::models::Model;

/// Represents the multiple ways the input can be supplied for moderations:
///
/// - A single string
/// - An array of strings
///
/// Other forms (such as token arrays) are not commonly used for this endpoint.
/// If you need a more advanced setup, you can adapt this or add variants as needed.
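///
/// For illustration, a minimal sketch of how the untagged representation
/// serializes (this assumes `serde_json` is available):
///
/// ```rust
/// use chat_gpt_lib_rs::api_resources::moderations::ModerationsInput;
///
/// // An untagged enum serializes to the bare JSON value of the active variant:
/// let single = ModerationsInput::String("check this text".to_string());
/// assert_eq!(serde_json::to_value(&single).unwrap(), serde_json::json!("check this text"));
///
/// let batch = ModerationsInput::Strings(vec!["one".into(), "two".into()]);
/// assert_eq!(serde_json::to_value(&batch).unwrap(), serde_json::json!(["one", "two"]));
/// ```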
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum ModerationsInput {
    /// A single string input
    String(String),
    /// Multiple string inputs
    Strings(Vec<String>),
}

/// A request struct for creating a moderation check using the OpenAI Moderations API.
///
/// For more details, see the [API documentation](https://platform.openai.com/docs/api-reference/moderations).
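///
/// For example, a batch request that relies on the API's default model; a
/// minimal sketch (it assumes `serde_json` is available to inspect the body):
///
/// ```rust
/// use chat_gpt_lib_rs::api_resources::moderations::{CreateModerationRequest, ModerationsInput};
///
/// let request = CreateModerationRequest {
///     input: ModerationsInput::Strings(vec!["first text".into(), "second text".into()]),
///     model: None, // fall back to the API's default model
/// };
///
/// // With `model: None`, the field is omitted entirely thanks to `skip_serializing_if`:
/// let body = serde_json::to_value(&request).unwrap();
/// assert!(body.get("model").is_none());
/// ```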
#[derive(Debug, Serialize, Clone)]
pub struct CreateModerationRequest {
    /// The input text(s) to classify.
    /// **Required** by the API.
    pub input: ModerationsInput,

    /// *Optional.* Two commonly used values are:
    /// - `"text-moderation-stable"`
    /// - `"text-moderation-latest"`
    ///
    /// If omitted, the default model is used.
    /// See [OpenAI's docs](https://platform.openai.com/docs/api-reference/moderations) for details.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<Model>,
}

/// The response returned by the OpenAI Moderations API.
///
/// Contains an identifier and a list of [`ModerationResult`] items corresponding to each input.
#[derive(Debug, Deserialize)]
pub struct CreateModerationResponse {
    /// An identifier for this moderation request (e.g., "modr-xxxxxx").
    pub id: String,
    /// The moderation model used.
    pub model: Model,
    /// A list of moderation results, one per input in `CreateModerationRequest.input`.
    pub results: Vec<ModerationResult>,
}

/// A single moderation result, indicating how the input text matches various policy categories.
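///
/// A common pattern is to combine the overall `flagged` bit with your own
/// thresholds on individual scores; a minimal sketch (the `0.5` threshold is an
/// arbitrary illustration, not an OpenAI recommendation):
///
/// ```rust
/// use chat_gpt_lib_rs::api_resources::moderations::ModerationResult;
///
/// fn needs_human_review(result: &ModerationResult) -> bool {
///     // Escalate anything the API flags outright, plus borderline violence.
///     result.flagged || result.category_scores.violence > 0.5
/// }
/// ```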
#[derive(Debug, Deserialize)]
pub struct ModerationResult {
    /// Boolean flags indicating which categories (hate, self-harm, sexual, etc.) are triggered.
    pub categories: ModerationCategories,
    /// Floating-point confidence scores for each category.
    pub category_scores: ModerationCategoryScores,
    /// Overall flag indicating if the content violates policy (i.e., if the text should be disallowed).
    pub flagged: bool,
}

/// A breakdown of the moderation categories.
///
/// Each field corresponds to a distinct policy category recognized by OpenAI's model.
/// If `true`, the text has been flagged under that category.
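///
/// For display or logging you might collect the names of the triggered
/// categories; a minimal sketch covering a few of the fields:
///
/// ```rust
/// use chat_gpt_lib_rs::api_resources::moderations::ModerationCategories;
///
/// fn triggered(categories: &ModerationCategories) -> Vec<&'static str> {
///     let mut names = Vec::new();
///     if categories.hate { names.push("hate"); }
///     if categories.self_harm { names.push("self-harm"); }
///     if categories.violence { names.push("violence"); }
///     names
/// }
/// ```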
#[derive(Debug, Deserialize)]
pub struct ModerationCategories {
    /// Hateful content directed towards a protected group or individual.
    pub hate: bool,
    /// Hateful content with threats.
    #[serde(rename = "hate/threatening")]
    pub hate_threatening: bool,
    /// Content about self-harm or suicide.
    #[serde(rename = "self-harm")]
    pub self_harm: bool,
    /// If `true`, the text includes sexual content or references.
    pub sexual: bool,
    /// If `true`, the text includes sexual content involving minors.
    #[serde(rename = "sexual/minors")]
    pub sexual_minors: bool,
    /// If `true`, the text includes violent content or context.
    pub violence: bool,
    /// If `true`, the text includes particularly graphic or gory violence.
    #[serde(rename = "violence/graphic")]
    pub violence_graphic: bool,
}

/// Floating-point confidence scores for each moderated category.
///
/// Higher values indicate higher model confidence that the content falls under that category.
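///
/// Since the scores are plain `f64` fields, you can aggregate them directly; a
/// minimal sketch taking the maximum over a few categories:
///
/// ```rust
/// use chat_gpt_lib_rs::api_resources::moderations::ModerationCategoryScores;
///
/// fn max_score(scores: &ModerationCategoryScores) -> f64 {
///     [scores.hate, scores.self_harm, scores.sexual, scores.violence]
///         .into_iter()
///         .fold(f64::MIN, f64::max)
/// }
/// ```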
#[derive(Debug, Deserialize)]
pub struct ModerationCategoryScores {
    /// The confidence score for hateful content.
    pub hate: f64,
    /// The confidence score for hateful content that includes threats.
    #[serde(rename = "hate/threatening")]
    pub hate_threatening: f64,
    /// The confidence score for self-harm or suicidal content.
    #[serde(rename = "self-harm")]
    pub self_harm: f64,
    /// The confidence score for sexual content or references.
    pub sexual: f64,
    /// The confidence score for sexual content involving minors.
    #[serde(rename = "sexual/minors")]
    pub sexual_minors: f64,
    /// The confidence score for violent content or context.
    pub violence: f64,
    /// The confidence score for particularly graphic or gory violence.
    #[serde(rename = "violence/graphic")]
    pub violence_graphic: f64,
}

/// Creates a moderation request using the [OpenAI Moderations API](https://platform.openai.com/docs/api-reference/moderations).
///
/// # Parameters
///
/// * `client` - The [`OpenAIClient`](crate::config::OpenAIClient) to use for the request.
/// * `request` - A [`CreateModerationRequest`] specifying the input text(s) and an optional model.
///
/// # Returns
///
/// A [`CreateModerationResponse`] containing moderation results for each input.
///
/// # Errors
///
/// - [`OpenAIError::HTTPError`]: if the request fails at the network layer.
/// - [`OpenAIError::DeserializeError`]: if the response fails to parse.
/// - [`OpenAIError::APIError`]: if OpenAI returns an error (e.g. invalid request).
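///
/// # Example
///
/// A minimal sketch of matching on the error variants (marked `no_run` since it
/// would otherwise issue a live request):
///
/// ```rust,no_run
/// # use chat_gpt_lib_rs::api_resources::moderations::{create_moderation, CreateModerationRequest, ModerationsInput};
/// # use chat_gpt_lib_rs::{error::OpenAIError, OpenAIClient};
/// # async fn run() -> Result<(), OpenAIError> {
/// let client = OpenAIClient::new(None)?;
/// let request = CreateModerationRequest {
///     input: ModerationsInput::String("text to check".into()),
///     model: None,
/// };
/// match create_moderation(&client, &request).await {
///     Ok(response) => println!("{} result(s)", response.results.len()),
///     Err(OpenAIError::APIError { .. }) => eprintln!("the API rejected the request"),
///     Err(e) => return Err(e),
/// }
/// # Ok(())
/// # }
/// ```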
pub async fn create_moderation(
    client: &OpenAIClient,
    request: &CreateModerationRequest,
) -> Result<CreateModerationResponse, OpenAIError> {
    // POST /v1/moderations
    let endpoint = "moderations";
    post_json(client, endpoint, request).await
}

#[cfg(test)]
mod tests {
    //! # Tests for the `moderations` module
    //!
    //! These tests use [`wiremock`](https://crates.io/crates/wiremock) to simulate the
    //! [OpenAI Moderations API](https://platform.openai.com/docs/api-reference/moderations).
    //! We specifically test [`create_moderation`] for:
    //! 1. A **success** scenario where it returns a valid [`CreateModerationResponse`].
    //! 2. An **API error** scenario with a non-2xx response.
    //! 3. A **deserialization** scenario where JSON is malformed.

    use super::*;
    use crate::config::OpenAIClient;
    use crate::error::OpenAIError;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    #[tokio::test]
    async fn test_create_moderation_success() {
        // Start a local mock server
        let mock_server = MockServer::start().await;

        // Example success response
        let success_body = json!({
            "id": "modr-abc123",
            "model": "text-moderation-latest",
            "results": [
                {
                    "flagged": true,
                    "categories": {
                        "hate": true,
                        "hate/threatening": false,
                        "self-harm": false,
                        "sexual": false,
                        "sexual/minors": false,
                        "violence": true,
                        "violence/graphic": false
                    },
                    "category_scores": {
                        "hate": 0.98,
                        "hate/threatening": 0.25,
                        "self-harm": 0.05,
                        "sexual": 0.0,
                        "sexual/minors": 0.0,
                        "violence": 0.85,
                        "violence/graphic": 0.1
                    }
                }
            ]
        });

        // Mock a 200 response on POST /moderations
        Mock::given(method("POST"))
            .and(path("/moderations"))
            .respond_with(ResponseTemplate::new(200).set_body_json(success_body))
            .mount(&mock_server)
            .await;

        // Build a test client
        let client = OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&mock_server.uri())
            .build()
            .unwrap();

        // Minimal request
        let req = CreateModerationRequest {
            input: ModerationsInput::String("some potentially hateful text".to_string()),
            model: Some("text-moderation-latest".into()),
        };

        let result = create_moderation(&client, &req).await;
        assert!(result.is_ok(), "Expected Ok, got: {:?}", result);

        let resp = result.unwrap();
        assert_eq!(resp.id, "modr-abc123");
        assert_eq!(resp.model, "text-moderation-latest".into());
        assert_eq!(resp.results.len(), 1);

        let first = &resp.results[0];
        assert!(first.flagged);
        assert!(first.categories.hate);
        assert!(first.categories.violence);
        assert!(!first.categories.hate_threatening);
        assert_eq!(first.category_scores.hate, 0.98);
        assert_eq!(first.category_scores.violence, 0.85);
    }

    #[tokio::test]
    async fn test_create_moderation_api_error() {
        let mock_server = MockServer::start().await;

        let error_body = json!({
            "error": {
                "message": "Invalid model",
                "type": "invalid_request_error",
                "code": null
            }
        });

        // Mock a 400 error on POST /moderations
        Mock::given(method("POST"))
            .and(path("/moderations"))
            .respond_with(ResponseTemplate::new(400).set_body_json(error_body))
            .mount(&mock_server)
            .await;

        let client = OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&mock_server.uri())
            .build()
            .unwrap();

        let req = CreateModerationRequest {
            input: ModerationsInput::Strings(vec!["test text".into()]),
            model: Some("text-moderation-unknown".into()),
        };

        let result = create_moderation(&client, &req).await;
        match result {
            Err(OpenAIError::APIError { message, .. }) => {
                assert!(message.contains("Invalid model"));
            }
            other => panic!("Expected APIError, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_create_moderation_deserialize_error() {
        let mock_server = MockServer::start().await;

        // Return 200 with malformed JSON
        let malformed_body = r#"{
            "id": "modr-12345",
            "model": "text-moderation-latest",
            "results": "should_be_array_not_string"
        }"#;

        Mock::given(method("POST"))
            .and(path("/moderations"))
            .respond_with(
                ResponseTemplate::new(200).set_body_raw(malformed_body, "application/json"),
            )
            .mount(&mock_server)
            .await;

        let client = OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&mock_server.uri())
            .build()
            .unwrap();

        let req = CreateModerationRequest {
            input: ModerationsInput::String("Another text".to_string()),
            model: None,
        };

        let result = create_moderation(&client, &req).await;
        match result {
            Err(OpenAIError::DeserializeError(_)) => {
                // success
            }
            other => panic!("Expected DeserializeError, got {:?}", other),
        }
    }
}