
//! OpenAI Moderations API Request Module
//!
//! This module provides functionality for interacting with the OpenAI Moderations API.
//! It allows you to classify text inputs to determine whether they violate content policies.
//!
//! # Key Features
//!
//! - **Single Text Moderation**: Check a single text string
//! - **Batch Moderation**: Check multiple texts at once
//! - **Model Selection**: Choose between omni-moderation and text-moderation models
//!
//! # Quick Start
//!
//! ```rust,no_run
//! use openai_tools::moderations::request::Moderations;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let moderations = Moderations::new()?;
//!
//!     // Check a text for policy violations
//!     let response = moderations.moderate_text("Hello, world!", None).await?;
//!     if response.results[0].flagged {
//!         println!("Content was flagged!");
//!     } else {
//!         println!("Content is safe");
//!     }
//!
//!     Ok(())
//! }
//! ```

use crate::common::auth::AuthProvider;
use crate::common::client::create_http_client;
use crate::common::errors::{ErrorResponse, OpenAIToolError, Result};
use crate::moderations::response::ModerationResponse;
use serde::{Deserialize, Serialize};
use std::time::Duration;

/// Default API path for Moderations
const MODERATIONS_PATH: &str = "moderations";

/// Moderation model options.
///
/// The model to use for content moderation. Newer omni-moderation models
/// support more categorization options and multi-modal inputs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ModerationModel {
    /// Latest omni-moderation model with multi-modal support
    #[serde(rename = "omni-moderation-latest")]
    #[default]
    OmniModerationLatest,
    /// Legacy text-only moderation model
    #[serde(rename = "text-moderation-latest")]
    TextModerationLatest,
}

impl ModerationModel {
    /// Returns the model identifier string.
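    ///
    /// For example, the default variant maps to its wire identifier:
    ///
    /// ```rust
    /// use openai_tools::moderations::request::ModerationModel;
    ///
    /// assert_eq!(ModerationModel::OmniModerationLatest.as_str(), "omni-moderation-latest");
    /// ```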
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::OmniModerationLatest => "omni-moderation-latest",
            Self::TextModerationLatest => "text-moderation-latest",
        }
    }
}

impl std::fmt::Display for ModerationModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

/// Request payload for the moderation endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ModerationRequest {
    /// The input to classify
    input: ModerationInput,
    /// The model to use for classification
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
}

/// Input types for moderation.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum ModerationInput {
    /// Single text string
    Single(String),
    /// Multiple text strings
    Multiple(Vec<String>),
}

/// Client for interacting with the OpenAI Moderations API.
///
/// This struct provides methods to classify text content for potential
/// policy violations. Use [`Moderations::new()`] to create a new instance.
///
/// # Example
///
/// ```rust,no_run
/// use openai_tools::moderations::request::{Moderations, ModerationModel};
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let moderations = Moderations::new()?;
///
///     // Check content with a specific model
///     let response = moderations
///         .moderate_text("Some text to check", Some(ModerationModel::OmniModerationLatest))
///         .await?;
///
///     for result in &response.results {
///         println!("Flagged: {}", result.flagged);
///     }
///
///     Ok(())
/// }
/// ```
pub struct Moderations {
    /// Authentication provider (OpenAI or Azure)
    auth: AuthProvider,
    /// Optional request timeout duration
    timeout: Option<Duration>,
}

impl Moderations {
    /// Creates a new Moderations client for the OpenAI API.
    ///
    /// Initializes the client by loading the OpenAI API key from
    /// the environment variable `OPENAI_API_KEY`. Supports `.env` file loading
    /// via dotenvy.
    ///
    /// # Returns
    ///
    /// * `Ok(Moderations)` - A new Moderations client ready for use
    /// * `Err(OpenAIToolError)` - If the API key is not found in the environment
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::moderations::request::Moderations;
    ///
    /// let moderations = Moderations::new().expect("API key should be set");
    /// ```
    pub fn new() -> Result<Self> {
        let auth = AuthProvider::openai_from_env()?;
        Ok(Self { auth, timeout: None })
    }

    /// Creates a new Moderations client with a custom authentication provider.
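    ///
    /// # Example
    ///
    /// A minimal sketch reusing the environment-based provider that
    /// [`Moderations::new()`] also relies on (assuming `AuthProvider` is
    /// reachable at `openai_tools::common::auth`):
    ///
    /// ```rust,no_run
    /// use openai_tools::common::auth::AuthProvider;
    /// use openai_tools::moderations::request::Moderations;
    ///
    /// let auth = AuthProvider::openai_from_env().expect("API key should be set");
    /// let moderations = Moderations::with_auth(auth);
    /// ```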
    pub fn with_auth(auth: AuthProvider) -> Self {
        Self { auth, timeout: None }
    }

    /// Creates a new Moderations client for the Azure OpenAI API.
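    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the Azure credentials expected by
    /// `AuthProvider::azure_from_env` are present in the environment:
    ///
    /// ```rust,no_run
    /// use openai_tools::moderations::request::Moderations;
    ///
    /// let moderations = Moderations::azure().expect("Azure credentials should be set");
    /// ```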
    pub fn azure() -> Result<Self> {
        let auth = AuthProvider::azure_from_env()?;
        Ok(Self { auth, timeout: None })
    }

    /// Creates a new Moderations client by auto-detecting the provider.
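    ///
    /// # Example
    ///
    /// A minimal sketch; which provider is chosen depends on what
    /// `AuthProvider::from_env` finds in the environment:
    ///
    /// ```rust,no_run
    /// use openai_tools::moderations::request::Moderations;
    ///
    /// let moderations = Moderations::detect_provider().expect("credentials should be set");
    /// ```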
    pub fn detect_provider() -> Result<Self> {
        let auth = AuthProvider::from_env()?;
        Ok(Self { auth, timeout: None })
    }

    /// Creates a new Moderations client with URL-based provider detection.
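    ///
    /// # Example
    ///
    /// A minimal sketch; the base URL and key below are placeholders:
    ///
    /// ```rust,no_run
    /// use openai_tools::moderations::request::Moderations;
    ///
    /// let moderations = Moderations::with_url("https://api.openai.com/v1", "sk-...");
    /// ```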
    pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
        let auth = AuthProvider::from_url_with_key(base_url, api_key);
        Self { auth, timeout: None }
    }

    /// Creates a new Moderations client from a URL, reading credentials from environment variables.
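    ///
    /// # Example
    ///
    /// A minimal sketch; the URL is a placeholder, and the API key is expected
    /// to come from the environment:
    ///
    /// ```rust,no_run
    /// use openai_tools::moderations::request::Moderations;
    ///
    /// let moderations = Moderations::from_url("https://api.openai.com/v1")
    ///     .expect("API key should be set");
    /// ```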
    pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
        let auth = AuthProvider::from_url(url)?;
        Ok(Self { auth, timeout: None })
    }

    /// Returns the authentication provider.
    pub fn auth(&self) -> &AuthProvider {
        &self.auth
    }

    /// Sets the request timeout duration.
    ///
    /// # Arguments
    ///
    /// * `timeout` - The maximum time to wait for a response
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
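    ///
    /// # Example
    ///
    /// For example, capping each request at 30 seconds:
    ///
    /// ```rust,no_run
    /// use std::time::Duration;
    /// use openai_tools::moderations::request::Moderations;
    ///
    /// let mut moderations = Moderations::new().expect("API key should be set");
    /// moderations.timeout(Duration::from_secs(30));
    /// ```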
    pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
        self.timeout = Some(timeout);
        self
    }

    /// Creates the HTTP client with default headers.
    fn create_client(&self) -> Result<(reqwest::Client, reqwest::header::HeaderMap)> {
        let client = create_http_client(self.timeout)?;
        let mut headers = reqwest::header::HeaderMap::new();
        self.auth.apply_headers(&mut headers)?;
        headers.insert("Content-Type", reqwest::header::HeaderValue::from_static("application/json"));
        headers.insert("User-Agent", reqwest::header::HeaderValue::from_static("openai-tools-rust"));
        Ok((client, headers))
    }

    /// Moderates a single text string.
    ///
    /// Classifies the input text to determine if it violates OpenAI's content policy.
    ///
    /// # Arguments
    ///
    /// * `text` - The text content to moderate
    /// * `model` - Optional model to use (defaults to `omni-moderation-latest`)
    ///
    /// # Returns
    ///
    /// * `Ok(ModerationResponse)` - The moderation results
    /// * `Err(OpenAIToolError)` - If the request fails or response parsing fails
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::moderations::request::Moderations;
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let moderations = Moderations::new()?;
    ///     let response = moderations.moderate_text("Hello, world!", None).await?;
    ///
    ///     let result = &response.results[0];
    ///     if result.flagged {
    ///         println!("Content was flagged!");
    ///         println!("Hate score: {}", result.category_scores.hate);
    ///     }
    ///     Ok(())
    /// }
    /// ```
    pub async fn moderate_text(&self, text: &str, model: Option<ModerationModel>) -> Result<ModerationResponse> {
        let request_body = ModerationRequest {
            input: ModerationInput::Single(text.to_string()),
            model: model.map(|m| m.as_str().to_string()),
        };

        self.send_request(&request_body).await
    }

    /// Moderates multiple text strings.
    ///
    /// Classifies multiple input texts in a single request.
    ///
    /// # Arguments
    ///
    /// * `texts` - Vector of text strings to moderate
    /// * `model` - Optional model to use (defaults to `omni-moderation-latest`)
    ///
    /// # Returns
    ///
    /// * `Ok(ModerationResponse)` - The moderation results (one result per input)
    /// * `Err(OpenAIToolError)` - If the request fails or response parsing fails
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::moderations::request::Moderations;
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let moderations = Moderations::new()?;
    ///     let texts = vec![
    ///         "First text to check".to_string(),
    ///         "Second text to check".to_string(),
    ///     ];
    ///     let response = moderations.moderate_texts(texts, None).await?;
    ///
    ///     for (i, result) in response.results.iter().enumerate() {
    ///         println!("Text {}: flagged = {}", i + 1, result.flagged);
    ///     }
    ///     Ok(())
    /// }
    /// ```
    pub async fn moderate_texts(&self, texts: Vec<String>, model: Option<ModerationModel>) -> Result<ModerationResponse> {
        let request_body = ModerationRequest {
            input: ModerationInput::Multiple(texts),
            model: model.map(|m| m.as_str().to_string()),
        };

        self.send_request(&request_body).await
    }

    /// Sends the moderation request to the API.
    async fn send_request(&self, request_body: &ModerationRequest) -> Result<ModerationResponse> {
        let (client, headers) = self.create_client()?;

        let body = serde_json::to_string(request_body).map_err(OpenAIToolError::SerdeJsonError)?;

        let url = self.auth.endpoint(MODERATIONS_PATH);
        let response = client.post(&url).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;

        let status = response.status();
        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;

        if cfg!(test) {
            tracing::info!("Response content: {}", content);
        }

        if !status.is_success() {
            if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
                return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
            }
            return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
        }

        serde_json::from_str::<ModerationResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
    }
}