openai_ergonomic/builders/moderations.rs

//! Moderations API builders.
//!
//! This module provides ergonomic builders for `OpenAI` Moderations API operations,
//! which help detect potentially harmful content across various categories.
//!
//! The Moderations API can identify content that may be:
//! - Hate speech
//! - Harassment
//! - Self-harm related
//! - Sexual content
//! - Violence
//! - And other harmful categories
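//!
//! # Example
//!
//! A minimal sketch of building a moderation request with the builder in this
//! module (the model name is one of the moderation models listed on [`ModerationBuilder::model`]):
//!
//! ```rust
//! use openai_ergonomic::builders::moderations::ModerationBuilder;
//!
//! let request = ModerationBuilder::new("Some user-generated text")
//!     .model("text-moderation-latest")
//!     .build()
//!     .expect("a single text input always builds");
//! ```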

/// Builder for content moderation requests.
///
/// This builder provides a fluent interface for creating moderation requests
/// to check if content violates `OpenAI`'s usage policies.
#[derive(Debug, Clone)]
pub struct ModerationBuilder {
    input: ModerationInput,
    model: Option<String>,
}

/// Input for moderation requests.
#[derive(Debug, Clone)]
pub enum ModerationInput {
    /// Single text input
    Text(String),
    /// Multiple text inputs
    TextArray(Vec<String>),
}

/// Result of a moderation check.
#[derive(Debug, Clone)]
pub struct ModerationResult {
    /// Whether the content was flagged
    pub flagged: bool,
    /// The flagged categories
    pub categories: ModerationCategories,
    /// The confidence scores for each category
    pub category_scores: ModerationCategoryScores,
}

/// Categories that can be flagged by moderation.
#[derive(Debug, Clone)]
#[allow(clippy::struct_excessive_bools)]
pub struct ModerationCategories {
    /// Hate speech
    pub hate: bool,
    /// Threatening hate speech
    pub hate_threatening: bool,
    /// Harassment
    pub harassment: bool,
    /// Threatening harassment
    pub harassment_threatening: bool,
    /// Self-harm content
    pub self_harm: bool,
    /// Intent to self-harm
    pub self_harm_intent: bool,
    /// Instructions for self-harm
    pub self_harm_instructions: bool,
    /// Sexual content
    pub sexual: bool,
    /// Sexual content involving minors
    pub sexual_minors: bool,
    /// Violence
    pub violence: bool,
    /// Graphic violence
    pub violence_graphic: bool,
}

/// Confidence scores for each moderation category.
#[derive(Debug, Clone)]
pub struct ModerationCategoryScores {
    /// Hate speech score
    pub hate: f64,
    /// Threatening hate speech score
    pub hate_threatening: f64,
    /// Harassment score
    pub harassment: f64,
    /// Threatening harassment score
    pub harassment_threatening: f64,
    /// Self-harm content score
    pub self_harm: f64,
    /// Intent to self-harm score
    pub self_harm_intent: f64,
    /// Instructions for self-harm score
    pub self_harm_instructions: f64,
    /// Sexual content score
    pub sexual: f64,
    /// Sexual content involving minors score
    pub sexual_minors: f64,
    /// Violence score
    pub violence: f64,
    /// Graphic violence score
    pub violence_graphic: f64,
}

impl ModerationBuilder {
    /// Create a new moderation builder with text input.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use openai_ergonomic::builders::moderations::ModerationBuilder;
    ///
    /// let builder = ModerationBuilder::new("Check this text for harmful content");
    /// ```
    #[must_use]
    pub fn new(input: impl Into<String>) -> Self {
        Self {
            input: ModerationInput::Text(input.into()),
            model: None,
        }
    }

    /// Create a moderation builder with multiple text inputs.
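    ///
    /// # Examples
    ///
    /// Each string is treated as a separate input to moderate:
    ///
    /// ```rust
    /// use openai_ergonomic::builders::moderations::ModerationBuilder;
    ///
    /// let builder = ModerationBuilder::new_array(vec![
    ///     "First comment".to_string(),
    ///     "Second comment".to_string(),
    /// ]);
    /// assert_eq!(builder.input_count(), 2);
    /// ```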
    #[must_use]
    pub fn new_array(inputs: Vec<String>) -> Self {
        Self {
            input: ModerationInput::TextArray(inputs),
            model: None,
        }
    }

    /// Set the moderation model to use.
    ///
    /// Common models include:
    /// - `text-moderation-latest` (default)
    /// - `text-moderation-stable`
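    ///
    /// # Examples
    ///
    /// Selecting a specific moderation model:
    ///
    /// ```rust
    /// use openai_ergonomic::builders::moderations::ModerationBuilder;
    ///
    /// let builder = ModerationBuilder::new("Check this").model("text-moderation-stable");
    /// assert_eq!(builder.model_ref(), Some("text-moderation-stable"));
    /// ```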
    #[must_use]
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Get the input for this moderation request.
    #[must_use]
    pub fn input(&self) -> &ModerationInput {
        &self.input
    }

    /// Get the model for this moderation request.
    #[must_use]
    pub fn model_ref(&self) -> Option<&str> {
        self.model.as_deref()
    }

    /// Check if this request has multiple inputs.
    #[must_use]
    pub fn has_multiple_inputs(&self) -> bool {
        matches!(self.input, ModerationInput::TextArray(_))
    }

    /// Get the number of inputs in this request.
    #[must_use]
    pub fn input_count(&self) -> usize {
        match &self.input {
            ModerationInput::Text(_) => 1,
            ModerationInput::TextArray(texts) => texts.len(),
        }
    }

    /// Get the first input text (useful for single input requests).
    #[must_use]
    pub fn first_input(&self) -> Option<&str> {
        match &self.input {
            ModerationInput::Text(text) => Some(text),
            ModerationInput::TextArray(texts) => texts.first().map(std::string::String::as_str),
        }
    }

    /// Get all input texts as a vector.
    #[must_use]
    pub fn all_inputs(&self) -> Vec<&str> {
        match &self.input {
            ModerationInput::Text(text) => vec![text],
            ModerationInput::TextArray(texts) => {
                texts.iter().map(std::string::String::as_str).collect()
            }
        }
    }

    /// Build the moderation request.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be built.
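    ///
    /// # Examples
    ///
    /// Building a request for a single text input:
    ///
    /// ```rust
    /// use openai_ergonomic::builders::moderations::ModerationBuilder;
    ///
    /// let request = ModerationBuilder::new("Check this text")
    ///     .model("text-moderation-latest")
    ///     .build()
    ///     .expect("a single text input always builds");
    /// assert_eq!(request.input, "Check this text");
    /// ```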
    pub fn build(self) -> crate::Result<openai_client_base::models::CreateModerationRequest> {
        let input_string = match self.input {
            ModerationInput::Text(text) => text,
            ModerationInput::TextArray(texts) => {
                // The generated `CreateModerationRequest` takes a single input string,
                // so array inputs are joined with newlines for now.
                texts.join("\n")
            }
        };

        Ok(openai_client_base::models::CreateModerationRequest {
            input: input_string,
            model: self.model,
        })
    }
}

impl crate::builders::Builder<openai_client_base::models::CreateModerationRequest>
    for ModerationBuilder
{
    fn build(self) -> crate::Result<openai_client_base::models::CreateModerationRequest> {
        self.build()
    }
}

impl ModerationCategories {
    /// Create a new `ModerationCategories` with all categories set to false.
    #[must_use]
    pub fn new_clean() -> Self {
        Self {
            hate: false,
            hate_threatening: false,
            harassment: false,
            harassment_threatening: false,
            self_harm: false,
            self_harm_intent: false,
            self_harm_instructions: false,
            sexual: false,
            sexual_minors: false,
            violence: false,
            violence_graphic: false,
        }
    }

    /// Check if any category is flagged.
    #[must_use]
    pub fn any_flagged(&self) -> bool {
        self.hate
            || self.hate_threatening
            || self.harassment
            || self.harassment_threatening
            || self.self_harm
            || self.self_harm_intent
            || self.self_harm_instructions
            || self.sexual
            || self.sexual_minors
            || self.violence
            || self.violence_graphic
    }

    /// Get all flagged categories as a vector of strings.
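    ///
    /// # Examples
    ///
    /// Only the categories set to `true` are reported:
    ///
    /// ```rust
    /// use openai_ergonomic::builders::moderations::ModerationCategories;
    ///
    /// let mut categories = ModerationCategories::new_clean();
    /// categories.harassment = true;
    /// assert_eq!(categories.flagged_categories(), vec!["harassment"]);
    /// ```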
    #[must_use]
    pub fn flagged_categories(&self) -> Vec<&'static str> {
        let mut flagged = Vec::new();
        if self.hate {
            flagged.push("hate");
        }
        if self.hate_threatening {
            flagged.push("hate/threatening");
        }
        if self.harassment {
            flagged.push("harassment");
        }
        if self.harassment_threatening {
            flagged.push("harassment/threatening");
        }
        if self.self_harm {
            flagged.push("self-harm");
        }
        if self.self_harm_intent {
            flagged.push("self-harm/intent");
        }
        if self.self_harm_instructions {
            flagged.push("self-harm/instructions");
        }
        if self.sexual {
            flagged.push("sexual");
        }
        if self.sexual_minors {
            flagged.push("sexual/minors");
        }
        if self.violence {
            flagged.push("violence");
        }
        if self.violence_graphic {
            flagged.push("violence/graphic");
        }
        flagged
    }
}

impl ModerationCategoryScores {
    /// Create a new `ModerationCategoryScores` with all scores set to 0.0.
    #[must_use]
    pub fn new_zero() -> Self {
        Self {
            hate: 0.0,
            hate_threatening: 0.0,
            harassment: 0.0,
            harassment_threatening: 0.0,
            self_harm: 0.0,
            self_harm_intent: 0.0,
            self_harm_instructions: 0.0,
            sexual: 0.0,
            sexual_minors: 0.0,
            violence: 0.0,
            violence_graphic: 0.0,
        }
    }

    /// Get the highest score across all categories.
    #[must_use]
    pub fn max_score(&self) -> f64 {
        [
            self.hate,
            self.hate_threatening,
            self.harassment,
            self.harassment_threatening,
            self.self_harm,
            self.self_harm_intent,
            self.self_harm_instructions,
            self.sexual,
            self.sexual_minors,
            self.violence,
            self.violence_graphic,
        ]
        .iter()
        .copied()
        .fold(0.0_f64, f64::max)
    }

    /// Get scores above a certain threshold.
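    ///
    /// # Examples
    ///
    /// Returns `(category, score)` pairs for every score strictly above the threshold:
    ///
    /// ```rust
    /// use openai_ergonomic::builders::moderations::ModerationCategoryScores;
    ///
    /// let mut scores = ModerationCategoryScores::new_zero();
    /// scores.violence = 0.92;
    /// assert_eq!(scores.scores_above_threshold(0.5), vec![("violence", 0.92)]);
    /// ```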
    #[must_use]
    pub fn scores_above_threshold(&self, threshold: f64) -> Vec<(&'static str, f64)> {
        let mut high_scores = Vec::new();
        if self.hate > threshold {
            high_scores.push(("hate", self.hate));
        }
        if self.hate_threatening > threshold {
            high_scores.push(("hate/threatening", self.hate_threatening));
        }
        if self.harassment > threshold {
            high_scores.push(("harassment", self.harassment));
        }
        if self.harassment_threatening > threshold {
            high_scores.push(("harassment/threatening", self.harassment_threatening));
        }
        if self.self_harm > threshold {
            high_scores.push(("self-harm", self.self_harm));
        }
        if self.self_harm_intent > threshold {
            high_scores.push(("self-harm/intent", self.self_harm_intent));
        }
        if self.self_harm_instructions > threshold {
            high_scores.push(("self-harm/instructions", self.self_harm_instructions));
        }
        if self.sexual > threshold {
            high_scores.push(("sexual", self.sexual));
        }
        if self.sexual_minors > threshold {
            high_scores.push(("sexual/minors", self.sexual_minors));
        }
        if self.violence > threshold {
            high_scores.push(("violence", self.violence));
        }
        if self.violence_graphic > threshold {
            high_scores.push(("violence/graphic", self.violence_graphic));
        }
        high_scores
    }
}

impl ModerationResult {
    /// Create a new clean moderation result (not flagged).
    #[must_use]
    pub fn new_clean() -> Self {
        Self {
            flagged: false,
            categories: ModerationCategories::new_clean(),
            category_scores: ModerationCategoryScores::new_zero(),
        }
    }

    /// Check if the content is safe (not flagged).
    #[must_use]
    pub fn is_safe(&self) -> bool {
        !self.flagged
    }

    /// Get a summary of why content was flagged (if it was).
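    ///
    /// # Examples
    ///
    /// Returns `None` for clean results and the flagged category names otherwise:
    ///
    /// ```rust
    /// use openai_ergonomic::builders::moderations::ModerationResult;
    ///
    /// let mut result = ModerationResult::new_clean();
    /// assert!(result.flagged_summary().is_none());
    ///
    /// result.flagged = true;
    /// result.categories.sexual = true;
    /// assert_eq!(result.flagged_summary(), Some(vec!["sexual"]));
    /// ```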
    #[must_use]
    pub fn flagged_summary(&self) -> Option<Vec<&'static str>> {
        if self.flagged {
            Some(self.categories.flagged_categories())
        } else {
            None
        }
    }
}

/// Helper function to create a simple moderation request.
#[must_use]
pub fn moderate_text(input: impl Into<String>) -> ModerationBuilder {
    ModerationBuilder::new(input)
}

/// Helper function to create a moderation request with multiple inputs.
#[must_use]
pub fn moderate_texts(inputs: Vec<String>) -> ModerationBuilder {
    ModerationBuilder::new_array(inputs)
}

/// Helper function to create a moderation request with a specific model.
#[must_use]
pub fn moderate_text_with_model(
    input: impl Into<String>,
    model: impl Into<String>,
) -> ModerationBuilder {
    ModerationBuilder::new(input).model(model)
}

/// Helper function to moderate a batch of messages.
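///
/// # Examples
///
/// Accepts any slice of string-like values:
///
/// ```rust
/// use openai_ergonomic::builders::moderations::moderate_messages;
///
/// let builder = moderate_messages(&["first message", "second message"]);
/// assert_eq!(builder.all_inputs(), vec!["first message", "second message"]);
/// ```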
#[must_use]
pub fn moderate_messages(messages: &[impl AsRef<str>]) -> ModerationBuilder {
    let inputs = messages
        .iter()
        .map(|msg| msg.as_ref().to_string())
        .collect();
    ModerationBuilder::new_array(inputs)
}

/// Check if a given text is likely to be flagged based on simple heuristics.
/// This is not a replacement for the actual API, just a helper for testing.
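///
/// # Examples
///
/// A quick sanity check of the keyword heuristic:
///
/// ```rust
/// use openai_ergonomic::builders::moderations::likely_flagged;
///
/// assert!(likely_flagged("this text mentions violence"));
/// assert!(!likely_flagged("a perfectly ordinary sentence"));
/// ```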
#[must_use]
pub fn likely_flagged(text: &str) -> bool {
    let lower = text.to_lowercase();
    // This is a very basic heuristic - the real API is much more sophisticated
    lower.contains("hate") || lower.contains("violence") || lower.contains("harmful")
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_moderation_builder_new() {
        let builder = ModerationBuilder::new("Test content");

        assert_eq!(builder.input_count(), 1);
        assert_eq!(builder.first_input(), Some("Test content"));
        assert!(!builder.has_multiple_inputs());
        assert!(builder.model_ref().is_none());
    }

    #[test]
    fn test_moderation_builder_new_array() {
        let inputs = vec!["First text".to_string(), "Second text".to_string()];
        let builder = ModerationBuilder::new_array(inputs);

        assert_eq!(builder.input_count(), 2);
        assert_eq!(builder.first_input(), Some("First text"));
        assert!(builder.has_multiple_inputs());
        assert_eq!(builder.all_inputs(), vec!["First text", "Second text"]);
    }

    #[test]
    fn test_moderation_builder_with_model() {
        let builder = ModerationBuilder::new("Test").model("text-moderation-stable");

        assert_eq!(builder.model_ref(), Some("text-moderation-stable"));
    }

    #[test]
    fn test_moderation_categories_new_clean() {
        let categories = ModerationCategories::new_clean();
        assert!(!categories.any_flagged());
        assert!(categories.flagged_categories().is_empty());
    }

    #[test]
    fn test_moderation_categories_flagged() {
        let mut categories = ModerationCategories::new_clean();
        categories.hate = true;
        categories.violence = true;

        assert!(categories.any_flagged());
        let flagged = categories.flagged_categories();
        assert_eq!(flagged.len(), 2);
        assert!(flagged.contains(&"hate"));
        assert!(flagged.contains(&"violence"));
    }

    #[test]
    fn test_moderation_category_scores_new_zero() {
        let scores = ModerationCategoryScores::new_zero();
        assert!((scores.max_score() - 0.0).abs() < f64::EPSILON);
        assert!(scores.scores_above_threshold(0.1).is_empty());
    }

    #[test]
    fn test_moderation_category_scores_max_and_threshold() {
        let mut scores = ModerationCategoryScores::new_zero();
        scores.hate = 0.8;
        scores.violence = 0.6;
        scores.sexual = 0.3;

        assert!((scores.max_score() - 0.8).abs() < f64::EPSILON);

        let high_scores = scores.scores_above_threshold(0.5);
        assert_eq!(high_scores.len(), 2);
        assert!(high_scores.contains(&("hate", 0.8)));
        assert!(high_scores.contains(&("violence", 0.6)));
    }

    #[test]
    fn test_moderation_result_new_clean() {
        let result = ModerationResult::new_clean();
        assert!(result.is_safe());
        assert!(result.flagged_summary().is_none());
    }

    #[test]
    fn test_moderation_result_flagged() {
        let mut result = ModerationResult::new_clean();
        result.flagged = true;
        result.categories.hate = true;

        assert!(!result.is_safe());
        let summary = result.flagged_summary().unwrap();
        assert_eq!(summary, vec!["hate"]);
    }

    #[test]
    fn test_moderate_text_helper() {
        let builder = moderate_text("Test content");
        assert_eq!(builder.first_input(), Some("Test content"));
        assert!(!builder.has_multiple_inputs());
    }

    #[test]
    fn test_moderate_texts_helper() {
        let inputs = vec!["Text 1".to_string(), "Text 2".to_string()];
        let builder = moderate_texts(inputs);
        assert_eq!(builder.input_count(), 2);
        assert!(builder.has_multiple_inputs());
    }

    #[test]
    fn test_moderate_text_with_model_helper() {
        let builder = moderate_text_with_model("Test", "text-moderation-latest");
        assert_eq!(builder.first_input(), Some("Test"));
        assert_eq!(builder.model_ref(), Some("text-moderation-latest"));
    }

    #[test]
    fn test_moderate_messages_helper() {
        let messages = ["Hello", "World"];
        let builder = moderate_messages(&messages);
        assert_eq!(builder.input_count(), 2);
        assert_eq!(builder.all_inputs(), vec!["Hello", "World"]);
    }

    #[test]
    fn test_likely_flagged_helper() {
        assert!(likely_flagged("This contains hate speech"));
        assert!(likely_flagged("Violence is not good"));
        assert!(likely_flagged("This is harmful content"));
        assert!(!likely_flagged("This is normal content"));
        assert!(!likely_flagged("Hello, how are you?"));
    }

    #[test]
    fn test_moderation_input_variants() {
        let single = ModerationInput::Text("single".to_string());
        let multiple = ModerationInput::TextArray(vec!["one".to_string(), "two".to_string()]);

        match single {
            ModerationInput::Text(text) => assert_eq!(text, "single"),
            ModerationInput::TextArray(_) => panic!("Expected Text variant"),
        }

        match multiple {
            ModerationInput::TextArray(texts) => assert_eq!(texts.len(), 2),
            ModerationInput::Text(_) => panic!("Expected TextArray variant"),
        }
    }
}