chat_gpt_lib_rs/api_resources/
moderations.rs

1//! This module provides functionality for classifying text against OpenAI's content moderation
2//! policies using the [OpenAI Moderations API](https://platform.openai.com/docs/api-reference/moderations).
3//!
4//! The Moderations API takes text (or multiple pieces of text) and returns a set of boolean flags
5//! indicating whether the content violates certain categories (e.g., hate, self-harm, sexual), along
6//! with confidence scores for each category.
7//!
8//! # Overview
9//!
10//! You can call [`create_moderation`] with a [`CreateModerationRequest`], specifying your input text(s)
11//! (and optionally, a specific model). The response (`CreateModerationResponse`) includes a list of
12//! [`ModerationResult`] objects—one per input. Each result contains a set of [`ModerationCategories`],
13//! matching confidence scores ([`ModerationCategoryScores`]), and a `flagged` boolean indicating if the
14//! text violates policy overall.
15//!
16//! # Example
17//!
18//! ```rust
19//! use chat_gpt_lib_rs::api_resources::moderations::{create_moderation, CreateModerationRequest, ModerationsInput};
20//! use chat_gpt_lib_rs::error::OpenAIError;
21//! use chat_gpt_lib_rs::OpenAIClient;
22//!
23//! #[tokio::main]
24//! async fn main() -> Result<(), OpenAIError> {
25//!     // load environment variables from a .env file, if present (optional).
26//!     dotenvy::dotenv().ok();
27//!
28//!     let client = OpenAIClient::new(None)?;
29//!     let request = CreateModerationRequest {
30//!         input: ModerationsInput::String("I hate you and want to harm you.".to_string()),
31//!         model: None, // or Some("text-moderation-latest".into())
32//!     };
33//!
34//!     let response = create_moderation(&client, &request).await?;
35//!     for (i, result) in response.results.iter().enumerate() {
36//!         println!("== Result {} ==", i);
37//!         println!("Flagged: {}", result.flagged);
38//!         println!("Hate category: {}", result.categories.hate);
39//!         println!("Hate score: {}", result.category_scores.hate);
40//!         // ...and so on for other categories
41//!     }
42//!
43//!     Ok(())
44//! }
45//! ```
46
47use serde::{Deserialize, Serialize};
48
49use crate::api::post_json;
50use crate::config::OpenAIClient;
51use crate::error::OpenAIError;
52
53use super::models::Model;
54
/// Represents the multiple ways the input can be supplied for moderations:
///
/// - A single string
/// - An array of strings
///
/// Serialized with `#[serde(untagged)]`, so the JSON form is the bare string or
/// bare array (no enum-variant wrapper) — exactly what the API's `input` field expects.
///
/// Other forms (such as token arrays) are not commonly used for this endpoint.
/// If you need a more advanced setup, you can adapt this or add variants as needed.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum ModerationsInput {
    /// A single string input
    String(String),
    /// Multiple string inputs; the API returns one [`ModerationResult`] per element
    Strings(Vec<String>),
}
70
/// A request struct for creating a moderation check using the OpenAI Moderations API.
///
/// Serialized as the JSON body of `POST /v1/moderations` (see [`create_moderation`]).
///
/// For more details, see the [API documentation](https://platform.openai.com/docs/api-reference/moderations).
#[derive(Debug, Serialize, Clone)]
pub struct CreateModerationRequest {
    /// The input text(s) to classify.  
    /// **Required** by the API.
    pub input: ModerationsInput,

    /// *Optional.* Two possible values are often used:  
    /// - `"text-moderation-stable"`  
    /// - `"text-moderation-latest"`  
    ///
    /// If omitted, the default model is used.  
    /// See [OpenAI's docs](https://platform.openai.com/docs/api-reference/moderations) for details.
    ///
    /// When `None`, the `model` key is omitted from the request body entirely
    /// (rather than sent as JSON `null`), via `skip_serializing_if`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<Model>,
}
89
/// The response returned by the OpenAI Moderations API.
///
/// Contains an identifier and a list of [`ModerationResult`] items corresponding to each input.
#[derive(Debug, Deserialize)]
pub struct CreateModerationResponse {
    /// An identifier for this moderation request (e.g., "modr-xxxxxx").
    pub id: String,
    /// The moderation model used (may differ from the requested model if the
    /// server resolved an alias such as `text-moderation-latest`).
    pub model: Model,
    /// A list of moderation results—one per input in `CreateModerationRequest.input`,
    /// in the same order as the inputs were supplied.
    pub results: Vec<ModerationResult>,
}
102
/// A single moderation result, indicating how the input text matches various policy categories.
///
/// One of these is returned per input text; see [`CreateModerationResponse::results`].
#[derive(Debug, Deserialize)]
pub struct ModerationResult {
    /// Boolean flags indicating which categories (hate, self-harm, sexual, etc.) are triggered.
    pub categories: ModerationCategories,
    /// Floating-point confidence scores for each category, parallel to `categories`.
    pub category_scores: ModerationCategoryScores,
    /// Overall flag indicating if the content violates policy (i.e., if the text should be disallowed).
    pub flagged: bool,
}
113
114/// A breakdown of the moderation categories.
115///
116/// Each field corresponds to a distinct policy category recognized by OpenAI's model.
117/// If `true`, the text has been flagged under that category.
118#[derive(Debug, Deserialize)]
119pub struct ModerationCategories {
120    /// Hateful content directed towards a protected group or individual.
121    pub hate: bool,
122    #[serde(rename = "hate/threatening")]
123    /// Hateful content with threats.
124    pub hate_threatening: bool,
125    #[serde(rename = "self-harm")]
126    /// Content about self-harm or suicide.
127    pub self_harm: bool,
128    /// If `true`, the text includes sexual content or references.
129    pub sexual: bool,
130    #[serde(rename = "sexual/minors")]
131    /// If `true`, the text includes sexual content involving minors.
132    pub sexual_minors: bool,
133    /// If `true`, the text includes violent content or context.
134    pub violence: bool,
135    #[serde(rename = "violence/graphic")]
136    /// If `true`, the text includes particularly graphic or gory violence.
137    pub violence_graphic: bool,
138}
139
140/// Floating-point confidence scores for each moderated category.
141///
142/// Higher values indicate higher model confidence that the content falls under that category.
143#[derive(Debug, Deserialize)]
144pub struct ModerationCategoryScores {
145    /// The confidence score for hateful content.
146    pub hate: f64,
147    #[serde(rename = "hate/threatening")]
148    /// The confidence score for hateful content that includes threats.
149    pub hate_threatening: f64,
150    #[serde(rename = "self-harm")]
151    /// The confidence score for self-harm or suicidal content.
152    pub self_harm: f64,
153    /// The confidence score for sexual content or references.
154    pub sexual: f64,
155    #[serde(rename = "sexual/minors")]
156    /// The confidence score for sexual content involving minors.
157    pub sexual_minors: f64,
158    /// The confidence score for violent content or context.
159    pub violence: f64,
160    #[serde(rename = "violence/graphic")]
161    /// The confidence score for particularly graphic or gory violence.
162    pub violence_graphic: f64,
163}
164
165/// Creates a moderation request using the [OpenAI Moderations API](https://platform.openai.com/docs/api-reference/moderations).
166///
167/// # Parameters
168///
169/// * `client` - The [`OpenAIClient`](crate::config::OpenAIClient) to use for the request.
170/// * `request` - A [`CreateModerationRequest`] specifying the input text(s) and an optional model.
171///
172/// # Returns
173///
174/// A [`CreateModerationResponse`] containing moderation results for each input.
175///
176/// # Errors
177///
178/// - [`OpenAIError::HTTPError`]: if the request fails at the network layer.
179/// - [`OpenAIError::DeserializeError`]: if the response fails to parse.
180/// - [`OpenAIError::APIError`]: if OpenAI returns an error (e.g. invalid request).
181pub async fn create_moderation(
182    client: &OpenAIClient,
183    request: &CreateModerationRequest,
184) -> Result<CreateModerationResponse, OpenAIError> {
185    // POST /v1/moderations
186    let endpoint = "moderations";
187    post_json(client, endpoint, request).await
188}
189
#[cfg(test)]
mod tests {
    //! # Tests for the `moderations` module
    //!
    //! These tests rely on [`wiremock`](https://crates.io/crates/wiremock) to stand in for the
    //! [OpenAI Moderations API](https://platform.openai.com/docs/api-reference/moderations).
    //! [`create_moderation`] is exercised for:
    //! 1. A **success** scenario where it returns a valid [`CreateModerationResponse`].
    //! 2. An **API error** scenario with a non-2xx response.
    //! 3. A **deserialization** scenario where JSON is malformed.

    use super::*;
    use crate::config::OpenAIClient;
    use crate::error::OpenAIError;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a test client whose base URL points at the given mock server.
    fn client_for(server: &MockServer) -> OpenAIClient {
        OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&server.uri())
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_create_moderation_success() {
        // Spin up a local mock server.
        let server = MockServer::start().await;

        // Canned 200 payload mirroring a real moderations response.
        let body = json!({
            "id": "modr-abc123",
            "model": "text-moderation-latest",
            "results": [
                {
                    "flagged": true,
                    "categories": {
                        "hate": true,
                        "hate/threatening": false,
                        "self-harm": false,
                        "sexual": false,
                        "sexual/minors": false,
                        "violence": true,
                        "violence/graphic": false
                    },
                    "category_scores": {
                        "hate": 0.98,
                        "hate/threatening": 0.25,
                        "self-harm": 0.05,
                        "sexual": 0.0,
                        "sexual/minors": 0.0,
                        "violence": 0.85,
                        "violence/graphic": 0.1
                    }
                }
            ]
        });

        // Serve the payload for POST /moderations.
        Mock::given(method("POST"))
            .and(path("/moderations"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(&server)
            .await;

        // Minimal request.
        let request = CreateModerationRequest {
            input: ModerationsInput::String("some potentially hateful text".to_string()),
            model: Some("text-moderation-latest".into()),
        };

        let outcome = create_moderation(&client_for(&server), &request).await;
        assert!(outcome.is_ok(), "Expected Ok, got: {:?}", outcome);

        let parsed = outcome.unwrap();
        assert_eq!(parsed.id, "modr-abc123");
        assert_eq!(parsed.model, "text-moderation-latest".into());
        assert_eq!(parsed.results.len(), 1);

        let entry = &parsed.results[0];
        assert!(entry.flagged);
        assert!(entry.categories.hate);
        assert!(entry.categories.violence);
        assert!(!entry.categories.hate_threatening);
        assert_eq!(entry.category_scores.hate, 0.98);
        assert_eq!(entry.category_scores.violence, 0.85);
    }

    #[tokio::test]
    async fn test_create_moderation_api_error() {
        let server = MockServer::start().await;

        // Standard OpenAI error envelope, served with a 400 status.
        let body = json!({
            "error": {
                "message": "Invalid model",
                "type": "invalid_request_error",
                "code": null
            }
        });

        Mock::given(method("POST"))
            .and(path("/moderations"))
            .respond_with(ResponseTemplate::new(400).set_body_json(body))
            .mount(&server)
            .await;

        let request = CreateModerationRequest {
            input: ModerationsInput::Strings(vec!["test text".into()]),
            model: Some("text-moderation-unknown".into()),
        };

        match create_moderation(&client_for(&server), &request).await {
            Err(OpenAIError::APIError { message, .. }) => {
                assert!(message.contains("Invalid model"));
            }
            other => panic!("Expected APIError, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_create_moderation_deserialize_error() {
        let server = MockServer::start().await;

        // 200 status but `results` carries the wrong JSON type.
        let body = r#"{
          "id": "modr-12345",
          "model": "text-moderation-latest",
          "results": "should_be_array_not_string"
        }"#;

        Mock::given(method("POST"))
            .and(path("/moderations"))
            .respond_with(
                ResponseTemplate::new(200).set_body_raw(body, "application/json"),
            )
            .mount(&server)
            .await;

        let request = CreateModerationRequest {
            input: ModerationsInput::String("Another text".to_string()),
            model: None,
        };

        match create_moderation(&client_for(&server), &request).await {
            Err(OpenAIError::DeserializeError(_)) => {
                // expected
            }
            other => panic!("Expected DeserializeError, got {:?}", other),
        }
    }
}