openai_tools/moderations/request.rs
1//! OpenAI Moderations API Request Module
2//!
3//! This module provides the functionality to interact with the OpenAI Moderations API.
4//! It allows you to classify text inputs to determine if they violate content policies.
5//!
6//! # Key Features
7//!
8//! - **Single Text Moderation**: Check a single text string
9//! - **Batch Moderation**: Check multiple texts at once
10//! - **Model Selection**: Choose between omni-moderation and text-moderation models
11//!
12//! # Quick Start
13//!
14//! ```rust,no_run
15//! use openai_tools::moderations::request::Moderations;
16//!
17//! #[tokio::main]
18//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
19//! let moderations = Moderations::new()?;
20//!
21//! // Check a text for policy violations
22//! let response = moderations.moderate_text("Hello, world!", None).await?;
23//! if response.results[0].flagged {
24//! println!("Content was flagged!");
25//! } else {
26//! println!("Content is safe");
27//! }
28//!
29//! Ok(())
30//! }
31//! ```
32
33use crate::common::auth::AuthProvider;
34use crate::common::client::create_http_client;
35use crate::common::errors::{ErrorResponse, OpenAIToolError, Result};
36use crate::moderations::response::ModerationResponse;
37use serde::{Deserialize, Serialize};
38use std::time::Duration;
39
/// Relative path of the moderations endpoint.
///
/// Handed to the auth provider's `endpoint` method, which turns it into the
/// full request URL for the configured provider.
const MODERATIONS_PATH: &str = "moderations";
42
43/// Moderation model options.
44///
45/// The model to use for content moderation. Newer omni-moderation models
46/// support more categorization options and multi-modal inputs.
47#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
48pub enum ModerationModel {
49 /// Latest omni-moderation model with multi-modal support
50 #[serde(rename = "omni-moderation-latest")]
51 #[default]
52 OmniModerationLatest,
53 /// Legacy text-only moderation model
54 #[serde(rename = "text-moderation-latest")]
55 TextModerationLatest,
56}
57
58impl ModerationModel {
59 /// Returns the model identifier string.
60 pub fn as_str(&self) -> &'static str {
61 match self {
62 Self::OmniModerationLatest => "omni-moderation-latest",
63 Self::TextModerationLatest => "text-moderation-latest",
64 }
65 }
66}
67
68impl std::fmt::Display for ModerationModel {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 write!(f, "{}", self.as_str())
71 }
72}
73
74/// Request payload for moderation endpoint.
75#[derive(Debug, Clone, Serialize, Deserialize)]
76struct ModerationRequest {
77 /// The input to classify
78 input: ModerationInput,
79 /// The model to use for classification
80 #[serde(skip_serializing_if = "Option::is_none")]
81 model: Option<String>,
82}
83
84/// Input types for moderation.
85#[derive(Debug, Clone, Serialize, Deserialize)]
86#[serde(untagged)]
87enum ModerationInput {
88 /// Single text string
89 Single(String),
90 /// Multiple text strings
91 Multiple(Vec<String>),
92}
93
94/// Client for interacting with the OpenAI Moderations API.
95///
96/// This struct provides methods to classify text content for potential
97/// policy violations. Use [`Moderations::new()`] to create a new instance.
98///
99/// # Example
100///
101/// ```rust,no_run
102/// use openai_tools::moderations::request::{Moderations, ModerationModel};
103///
104/// #[tokio::main]
105/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
106/// let moderations = Moderations::new()?;
107///
108/// // Check content with a specific model
109/// let response = moderations
110/// .moderate_text("Some text to check", Some(ModerationModel::OmniModerationLatest))
111/// .await?;
112///
113/// for result in &response.results {
114/// println!("Flagged: {}", result.flagged);
115/// }
116///
117/// Ok(())
118/// }
119/// ```
120pub struct Moderations {
121 /// Authentication provider (OpenAI or Azure)
122 auth: AuthProvider,
123 /// Optional request timeout duration
124 timeout: Option<Duration>,
125}
126
127impl Moderations {
128 /// Creates a new Moderations client for OpenAI API.
129 ///
130 /// Initializes the client by loading the OpenAI API key from
131 /// the environment variable `OPENAI_API_KEY`. Supports `.env` file loading
132 /// via dotenvy.
133 ///
134 /// # Returns
135 ///
136 /// * `Ok(Moderations)` - A new Moderations client ready for use
137 /// * `Err(OpenAIToolError)` - If the API key is not found in the environment
138 ///
139 /// # Example
140 ///
141 /// ```rust,no_run
142 /// use openai_tools::moderations::request::Moderations;
143 ///
144 /// let moderations = Moderations::new().expect("API key should be set");
145 /// ```
146 pub fn new() -> Result<Self> {
147 let auth = AuthProvider::openai_from_env()?;
148 Ok(Self { auth, timeout: None })
149 }
150
151 /// Creates a new Moderations client with a custom authentication provider
152 pub fn with_auth(auth: AuthProvider) -> Self {
153 Self { auth, timeout: None }
154 }
155
156 /// Creates a new Moderations client for Azure OpenAI API
157 pub fn azure() -> Result<Self> {
158 let auth = AuthProvider::azure_from_env()?;
159 Ok(Self { auth, timeout: None })
160 }
161
162 /// Creates a new Moderations client by auto-detecting the provider
163 pub fn detect_provider() -> Result<Self> {
164 let auth = AuthProvider::from_env()?;
165 Ok(Self { auth, timeout: None })
166 }
167
168 /// Creates a new Moderations client with URL-based provider detection
169 pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
170 let auth = AuthProvider::from_url_with_key(base_url, api_key);
171 Self { auth, timeout: None }
172 }
173
174 /// Creates a new Moderations client from URL using environment variables
175 pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
176 let auth = AuthProvider::from_url(url)?;
177 Ok(Self { auth, timeout: None })
178 }
179
180 /// Returns the authentication provider
181 pub fn auth(&self) -> &AuthProvider {
182 &self.auth
183 }
184
185 /// Sets the request timeout duration.
186 ///
187 /// # Arguments
188 ///
189 /// * `timeout` - The maximum time to wait for a response
190 ///
191 /// # Returns
192 ///
193 /// A mutable reference to self for method chaining
194 pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
195 self.timeout = Some(timeout);
196 self
197 }
198
199 /// Creates the HTTP client with default headers.
200 fn create_client(&self) -> Result<(request::Client, request::header::HeaderMap)> {
201 let client = create_http_client(self.timeout)?;
202 let mut headers = request::header::HeaderMap::new();
203 self.auth.apply_headers(&mut headers)?;
204 headers.insert("Content-Type", request::header::HeaderValue::from_static("application/json"));
205 headers.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust"));
206 Ok((client, headers))
207 }
208
209 /// Moderates a single text string.
210 ///
211 /// Classifies the input text to determine if it violates OpenAI's content policy.
212 ///
213 /// # Arguments
214 ///
215 /// * `text` - The text content to moderate
216 /// * `model` - Optional model to use (defaults to `omni-moderation-latest`)
217 ///
218 /// # Returns
219 ///
220 /// * `Ok(ModerationResponse)` - The moderation results
221 /// * `Err(OpenAIToolError)` - If the request fails or response parsing fails
222 ///
223 /// # Example
224 ///
225 /// ```rust,no_run
226 /// use openai_tools::moderations::request::Moderations;
227 ///
228 /// #[tokio::main]
229 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
230 /// let moderations = Moderations::new()?;
231 /// let response = moderations.moderate_text("Hello, world!", None).await?;
232 ///
233 /// let result = &response.results[0];
234 /// if result.flagged {
235 /// println!("Content was flagged!");
236 /// println!("Hate score: {}", result.category_scores.hate);
237 /// }
238 /// Ok(())
239 /// }
240 /// ```
241 pub async fn moderate_text(&self, text: &str, model: Option<ModerationModel>) -> Result<ModerationResponse> {
242 let request_body = ModerationRequest { input: ModerationInput::Single(text.to_string()), model: model.map(|m| m.as_str().to_string()) };
243
244 self.send_request(&request_body).await
245 }
246
247 /// Moderates multiple text strings.
248 ///
249 /// Classifies multiple input texts in a single request.
250 ///
251 /// # Arguments
252 ///
253 /// * `texts` - Vector of text strings to moderate
254 /// * `model` - Optional model to use (defaults to `omni-moderation-latest`)
255 ///
256 /// # Returns
257 ///
258 /// * `Ok(ModerationResponse)` - The moderation results (one result per input)
259 /// * `Err(OpenAIToolError)` - If the request fails or response parsing fails
260 ///
261 /// # Example
262 ///
263 /// ```rust,no_run
264 /// use openai_tools::moderations::request::Moderations;
265 ///
266 /// #[tokio::main]
267 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
268 /// let moderations = Moderations::new()?;
269 /// let texts = vec![
270 /// "First text to check".to_string(),
271 /// "Second text to check".to_string(),
272 /// ];
273 /// let response = moderations.moderate_texts(texts, None).await?;
274 ///
275 /// for (i, result) in response.results.iter().enumerate() {
276 /// println!("Text {}: flagged = {}", i + 1, result.flagged);
277 /// }
278 /// Ok(())
279 /// }
280 /// ```
281 pub async fn moderate_texts(&self, texts: Vec<String>, model: Option<ModerationModel>) -> Result<ModerationResponse> {
282 let request_body = ModerationRequest { input: ModerationInput::Multiple(texts), model: model.map(|m| m.as_str().to_string()) };
283
284 self.send_request(&request_body).await
285 }
286
287 /// Sends the moderation request to the API.
288 async fn send_request(&self, request_body: &ModerationRequest) -> Result<ModerationResponse> {
289 let (client, headers) = self.create_client()?;
290
291 let body = serde_json::to_string(request_body).map_err(OpenAIToolError::SerdeJsonError)?;
292
293 let url = self.auth.endpoint(MODERATIONS_PATH);
294 let response = client.post(&url).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
295
296 let status = response.status();
297 let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
298
299 if cfg!(test) {
300 tracing::info!("Response content: {}", content);
301 }
302
303 if !status.is_success() {
304 if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
305 return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
306 }
307 return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
308 }
309
310 serde_json::from_str::<ModerationResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
311 }
312}