// openai_oxide/types/moderation.rs
use serde::{Deserialize, Serialize};

5#[derive(Debug, Clone, Serialize)]
9#[serde(untagged)]
10#[non_exhaustive]
11pub enum ModerationInput {
12 String(String),
13 StringArray(Vec<String>),
14}
15
16impl From<&str> for ModerationInput {
17 fn from(s: &str) -> Self {
18 ModerationInput::String(s.to_string())
19 }
20}
21
22impl From<String> for ModerationInput {
23 fn from(s: String) -> Self {
24 ModerationInput::String(s)
25 }
26}
27
28impl From<Vec<String>> for ModerationInput {
29 fn from(v: Vec<String>) -> Self {
30 ModerationInput::StringArray(v)
31 }
32}
33
34#[derive(Debug, Clone, Serialize)]
36pub struct ModerationRequest {
37 pub input: ModerationInput,
39
40 #[serde(skip_serializing_if = "Option::is_none")]
42 pub model: Option<String>,
43}
44
45impl ModerationRequest {
46 pub fn new(input: impl Into<ModerationInput>) -> Self {
47 Self {
48 input: input.into(),
49 model: None,
50 }
51 }
52}
53
54#[derive(Debug, Clone, Deserialize)]
58pub struct Categories {
59 pub harassment: bool,
60 #[serde(rename = "harassment/threatening")]
61 pub harassment_threatening: bool,
62 pub hate: bool,
63 #[serde(rename = "hate/threatening")]
64 pub hate_threatening: bool,
65 #[serde(default, rename = "illicit")]
66 pub illicit: Option<bool>,
67 #[serde(default, rename = "illicit/violent")]
68 pub illicit_violent: Option<bool>,
69 #[serde(rename = "self-harm")]
70 pub self_harm: bool,
71 #[serde(rename = "self-harm/instructions")]
72 pub self_harm_instructions: bool,
73 #[serde(rename = "self-harm/intent")]
74 pub self_harm_intent: bool,
75 pub sexual: bool,
76 #[serde(rename = "sexual/minors")]
77 pub sexual_minors: bool,
78 pub violence: bool,
79 #[serde(rename = "violence/graphic")]
80 pub violence_graphic: bool,
81}
82
83#[derive(Debug, Clone, Deserialize)]
85pub struct CategoryScores {
86 pub harassment: f64,
87 #[serde(rename = "harassment/threatening")]
88 pub harassment_threatening: f64,
89 pub hate: f64,
90 #[serde(rename = "hate/threatening")]
91 pub hate_threatening: f64,
92 #[serde(default, rename = "illicit")]
93 pub illicit: Option<f64>,
94 #[serde(default, rename = "illicit/violent")]
95 pub illicit_violent: Option<f64>,
96 #[serde(rename = "self-harm")]
97 pub self_harm: f64,
98 #[serde(rename = "self-harm/instructions")]
99 pub self_harm_instructions: f64,
100 #[serde(rename = "self-harm/intent")]
101 pub self_harm_intent: f64,
102 pub sexual: f64,
103 #[serde(rename = "sexual/minors")]
104 pub sexual_minors: f64,
105 pub violence: f64,
106 #[serde(rename = "violence/graphic")]
107 pub violence_graphic: f64,
108}
109
110#[derive(Debug, Clone, Deserialize)]
112pub struct Moderation {
113 pub flagged: bool,
114 pub categories: Categories,
115 pub category_scores: CategoryScores,
116}
117
118#[derive(Debug, Clone, Deserialize)]
120pub struct ModerationCreateResponse {
121 pub id: String,
122 pub model: String,
123 pub results: Vec<Moderation>,
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// A single-string request serializes `input` as a bare string and omits
    /// the unset `model` key.
    #[test]
    fn test_serialize_moderation_request() {
        let req = ModerationRequest::new("I want to harm someone");
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["input"], "I want to harm someone");
        assert!(json.get("model").is_none());
    }

    /// A multi-string request serializes `input` as a JSON array (untagged
    /// enum), and a set `model` is emitted.
    #[test]
    fn test_serialize_moderation_request_with_array_and_model() {
        let mut req = ModerationRequest::new(vec!["first".to_string(), "second".to_string()]);
        req.model = Some("text-moderation-007".to_string());
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["input"], serde_json::json!(["first", "second"]));
        assert_eq!(json["model"], "text-moderation-007");
    }

    /// A payload without the `illicit*` keys still parses; those fields
    /// default to `None`.
    #[test]
    fn test_deserialize_moderation_response() {
        let json = r#"{
            "id": "modr-abc123",
            "model": "text-moderation-007",
            "results": [{
                "flagged": true,
                "categories": {
                    "harassment": true,
                    "harassment/threatening": false,
                    "hate": false,
                    "hate/threatening": false,
                    "self-harm": false,
                    "self-harm/instructions": false,
                    "self-harm/intent": false,
                    "sexual": false,
                    "sexual/minors": false,
                    "violence": true,
                    "violence/graphic": false
                },
                "category_scores": {
                    "harassment": 0.85,
                    "harassment/threatening": 0.02,
                    "hate": 0.001,
                    "hate/threatening": 0.0001,
                    "self-harm": 0.0001,
                    "self-harm/instructions": 0.0001,
                    "self-harm/intent": 0.0001,
                    "sexual": 0.0001,
                    "sexual/minors": 0.0001,
                    "violence": 0.75,
                    "violence/graphic": 0.001
                }
            }]
        }"#;

        let resp: ModerationCreateResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.id, "modr-abc123");
        assert_eq!(resp.model, "text-moderation-007");
        assert_eq!(resp.results.len(), 1);

        let result = &resp.results[0];
        assert!(result.flagged);
        assert!(result.categories.harassment);
        assert!(result.categories.violence);
        assert!(result.category_scores.harassment > 0.5);

        // Optional categories absent from the payload must come back as None.
        assert!(result.categories.illicit.is_none());
        assert!(result.categories.illicit_violent.is_none());
        assert!(result.category_scores.illicit.is_none());
        assert!(result.category_scores.illicit_violent.is_none());
    }

    /// A payload that includes the `illicit*` keys populates the optional fields.
    #[test]
    fn test_deserialize_moderation_response_with_illicit() {
        let json = r#"{
            "id": "modr-def456",
            "model": "omni-moderation-latest",
            "results": [{
                "flagged": true,
                "categories": {
                    "harassment": false,
                    "harassment/threatening": false,
                    "hate": false,
                    "hate/threatening": false,
                    "illicit": false,
                    "illicit/violent": true,
                    "self-harm": false,
                    "self-harm/instructions": false,
                    "self-harm/intent": false,
                    "sexual": false,
                    "sexual/minors": false,
                    "violence": false,
                    "violence/graphic": false
                },
                "category_scores": {
                    "harassment": 0.001,
                    "harassment/threatening": 0.001,
                    "hate": 0.001,
                    "hate/threatening": 0.001,
                    "illicit": 0.25,
                    "illicit/violent": 0.5,
                    "self-harm": 0.001,
                    "self-harm/instructions": 0.001,
                    "self-harm/intent": 0.001,
                    "sexual": 0.001,
                    "sexual/minors": 0.001,
                    "violence": 0.001,
                    "violence/graphic": 0.001
                }
            }]
        }"#;

        let resp: ModerationCreateResponse = serde_json::from_str(json).unwrap();
        let result = &resp.results[0];
        assert_eq!(result.categories.illicit, Some(false));
        assert_eq!(result.categories.illicit_violent, Some(true));

        let illicit = result.category_scores.illicit.expect("illicit score present");
        let illicit_violent = result
            .category_scores
            .illicit_violent
            .expect("illicit/violent score present");
        assert!((illicit - 0.25).abs() < f64::EPSILON);
        assert!((illicit_violent - 0.5).abs() < f64::EPSILON);
    }
}