// File: openai_oxide/types/moderation.rs
1// Moderation types — mirrors openai-python types/moderation.py
2
3use serde::{Deserialize, Serialize};
4
5// ── Request types ──
6
/// Input for moderations: a single string or array of strings.
///
/// Serialized `#[serde(untagged)]`: a `String` variant becomes a bare JSON
/// string, a `StringArray` becomes a JSON array of strings — matching the
/// two input shapes the moderations endpoint accepts.
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
#[non_exhaustive]
pub enum ModerationInput {
    /// A single piece of text to classify.
    String(String),
    /// Multiple pieces of text classified in one request.
    StringArray(Vec<String>),
}
15
16impl From<&str> for ModerationInput {
17    fn from(s: &str) -> Self {
18        ModerationInput::String(s.to_string())
19    }
20}
21
22impl From<String> for ModerationInput {
23    fn from(s: String) -> Self {
24        ModerationInput::String(s)
25    }
26}
27
28impl From<Vec<String>> for ModerationInput {
29    fn from(v: Vec<String>) -> Self {
30        ModerationInput::StringArray(v)
31    }
32}
33
/// Request body for `POST /moderations`.
#[derive(Debug, Clone, Serialize)]
pub struct ModerationRequest {
    /// Input text to classify.
    pub input: ModerationInput,

    /// Model to use (e.g. "omni-moderation-latest").
    /// Omitted from the serialized body when `None`, so the server
    /// applies its own default model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
}
44
45impl ModerationRequest {
46    pub fn new(input: impl Into<ModerationInput>) -> Self {
47        Self {
48            input: input.into(),
49            model: None,
50        }
51    }
52}
53
54// ── Response types ──
55
56/// Category flags for moderation results.
57#[derive(Debug, Clone, Deserialize)]
58pub struct Categories {
59    pub harassment: bool,
60    #[serde(rename = "harassment/threatening")]
61    pub harassment_threatening: bool,
62    pub hate: bool,
63    #[serde(rename = "hate/threatening")]
64    pub hate_threatening: bool,
65    #[serde(default, rename = "illicit")]
66    pub illicit: Option<bool>,
67    #[serde(default, rename = "illicit/violent")]
68    pub illicit_violent: Option<bool>,
69    #[serde(rename = "self-harm")]
70    pub self_harm: bool,
71    #[serde(rename = "self-harm/instructions")]
72    pub self_harm_instructions: bool,
73    #[serde(rename = "self-harm/intent")]
74    pub self_harm_intent: bool,
75    pub sexual: bool,
76    #[serde(rename = "sexual/minors")]
77    pub sexual_minors: bool,
78    pub violence: bool,
79    #[serde(rename = "violence/graphic")]
80    pub violence_graphic: bool,
81}
82
83/// Category scores for moderation results.
84#[derive(Debug, Clone, Deserialize)]
85pub struct CategoryScores {
86    pub harassment: f64,
87    #[serde(rename = "harassment/threatening")]
88    pub harassment_threatening: f64,
89    pub hate: f64,
90    #[serde(rename = "hate/threatening")]
91    pub hate_threatening: f64,
92    #[serde(default, rename = "illicit")]
93    pub illicit: Option<f64>,
94    #[serde(default, rename = "illicit/violent")]
95    pub illicit_violent: Option<f64>,
96    #[serde(rename = "self-harm")]
97    pub self_harm: f64,
98    #[serde(rename = "self-harm/instructions")]
99    pub self_harm_instructions: f64,
100    #[serde(rename = "self-harm/intent")]
101    pub self_harm_intent: f64,
102    pub sexual: f64,
103    #[serde(rename = "sexual/minors")]
104    pub sexual_minors: f64,
105    pub violence: f64,
106    #[serde(rename = "violence/graphic")]
107    pub violence_graphic: f64,
108}
109
/// A single moderation result.
#[derive(Debug, Clone, Deserialize)]
pub struct Moderation {
    /// Whether the input was flagged by the model.
    pub flagged: bool,
    /// Per-category boolean flags.
    pub categories: Categories,
    /// Per-category numeric scores.
    pub category_scores: CategoryScores,
}
117
/// Response from `POST /moderations`.
#[derive(Debug, Clone, Deserialize)]
pub struct ModerationCreateResponse {
    /// Identifier of the moderation request (e.g. "modr-...").
    pub id: String,
    /// The model that produced these results.
    pub model: String,
    /// Classification results for the request input.
    pub results: Vec<Moderation>,
}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// A request built from a plain `&str` serializes `input` as a bare JSON
    /// string (untagged enum) and omits the absent `model` key entirely.
    #[test]
    fn test_serialize_moderation_request() {
        let req = ModerationRequest::new("I want to harm someone");
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["input"], "I want to harm someone");
        // skip_serializing_if must drop the key, not emit `"model": null`.
        assert!(json.get("model").is_none());
    }

    /// A `Vec<String>` input serializes as a JSON array, and an explicit
    /// model is written as a plain string field.
    #[test]
    fn test_serialize_array_input_and_model() {
        let mut req = ModerationRequest::new(vec!["one".to_string(), "two".to_string()]);
        req.model = Some("omni-moderation-latest".to_string());
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["input"], serde_json::json!(["one", "two"]));
        assert_eq!(json["model"], "omni-moderation-latest");
    }

    /// A response without the optional `illicit*` keys deserializes with
    /// those fields defaulting to `None`, and the slash/hyphen key renames
    /// map onto the snake_case fields.
    #[test]
    fn test_deserialize_moderation_response() {
        let json = r#"{
            "id": "modr-abc123",
            "model": "text-moderation-007",
            "results": [{
                "flagged": true,
                "categories": {
                    "harassment": true,
                    "harassment/threatening": false,
                    "hate": false,
                    "hate/threatening": false,
                    "self-harm": false,
                    "self-harm/instructions": false,
                    "self-harm/intent": false,
                    "sexual": false,
                    "sexual/minors": false,
                    "violence": true,
                    "violence/graphic": false
                },
                "category_scores": {
                    "harassment": 0.85,
                    "harassment/threatening": 0.02,
                    "hate": 0.001,
                    "hate/threatening": 0.0001,
                    "self-harm": 0.0001,
                    "self-harm/instructions": 0.0001,
                    "self-harm/intent": 0.0001,
                    "sexual": 0.0001,
                    "sexual/minors": 0.0001,
                    "violence": 0.75,
                    "violence/graphic": 0.001
                }
            }]
        }"#;

        let resp: ModerationCreateResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.id, "modr-abc123");
        assert_eq!(resp.model, "text-moderation-007");
        assert_eq!(resp.results.len(), 1);

        let result = &resp.results[0];
        assert!(result.flagged);
        assert!(result.categories.harassment);
        assert!(result.categories.violence);
        assert!(!result.categories.sexual);
        assert!(result.category_scores.harassment > 0.5);

        // The `illicit*` keys are absent from the payload above, so
        // `#[serde(default)]` must produce `None` rather than an error.
        assert_eq!(result.categories.illicit, None);
        assert_eq!(result.categories.illicit_violent, None);
        assert_eq!(result.category_scores.illicit, None);
        assert_eq!(result.category_scores.illicit_violent, None);
    }
}
182}