//! Builders for the OpenAI moderations endpoint (openai_ergonomic/builders/moderations.rs).
/// Builder for OpenAI moderation requests.
///
/// Holds the text input(s) to classify plus an optional model override.
/// Construct with [`ModerationBuilder::new`] (single text) or
/// [`ModerationBuilder::new_array`] (multiple texts), then call `build()`.
#[derive(Debug, Clone)]
pub struct ModerationBuilder {
    // Text(s) submitted for classification.
    input: ModerationInput,
    // Moderation model name; `None` sends the request without an explicit model.
    model: Option<String>,
}
23
/// Input accepted by a moderation request: one text or a batch of texts.
#[derive(Debug, Clone)]
pub enum ModerationInput {
    /// A single piece of text.
    Text(String),
    /// Several texts to classify in one request.
    TextArray(Vec<String>),
}
32
/// Outcome of moderating one input: an overall flag plus per-category detail.
#[derive(Debug, Clone)]
pub struct ModerationResult {
    /// `true` if the content was flagged by any category.
    pub flagged: bool,
    /// Boolean verdict per moderation category.
    pub categories: ModerationCategories,
    /// Numeric confidence score per moderation category.
    pub category_scores: ModerationCategoryScores,
}
43
/// Per-category boolean verdicts, mirroring the OpenAI moderation taxonomy.
///
/// One `bool` per category is the natural shape for this API payload, hence
/// the clippy allow below.
#[derive(Debug, Clone)]
#[allow(clippy::struct_excessive_bools)]
pub struct ModerationCategories {
    /// "hate"
    pub hate: bool,
    /// "hate/threatening"
    pub hate_threatening: bool,
    /// "harassment"
    pub harassment: bool,
    /// "harassment/threatening"
    pub harassment_threatening: bool,
    /// "self-harm"
    pub self_harm: bool,
    /// "self-harm/intent"
    pub self_harm_intent: bool,
    /// "self-harm/instructions"
    pub self_harm_instructions: bool,
    /// "sexual"
    pub sexual: bool,
    /// "sexual/minors"
    pub sexual_minors: bool,
    /// "violence"
    pub violence: bool,
    /// "violence/graphic"
    pub violence_graphic: bool,
}
71
/// Per-category confidence scores, one `f64` per moderation category.
///
/// NOTE(review): scores are presumably in `[0.0, 1.0]` as reported by the
/// OpenAI API, but nothing here enforces that — confirm against the client.
#[derive(Debug, Clone)]
pub struct ModerationCategoryScores {
    /// Score for "hate".
    pub hate: f64,
    /// Score for "hate/threatening".
    pub hate_threatening: f64,
    /// Score for "harassment".
    pub harassment: f64,
    /// Score for "harassment/threatening".
    pub harassment_threatening: f64,
    /// Score for "self-harm".
    pub self_harm: f64,
    /// Score for "self-harm/intent".
    pub self_harm_intent: f64,
    /// Score for "self-harm/instructions".
    pub self_harm_instructions: f64,
    /// Score for "sexual".
    pub sexual: f64,
    /// Score for "sexual/minors".
    pub sexual_minors: f64,
    /// Score for "violence".
    pub violence: f64,
    /// Score for "violence/graphic".
    pub violence_graphic: f64,
}
98
99impl ModerationBuilder {
100 #[must_use]
110 pub fn new(input: impl Into<String>) -> Self {
111 Self {
112 input: ModerationInput::Text(input.into()),
113 model: None,
114 }
115 }
116
117 #[must_use]
119 pub fn new_array(inputs: Vec<String>) -> Self {
120 Self {
121 input: ModerationInput::TextArray(inputs),
122 model: None,
123 }
124 }
125
126 #[must_use]
132 pub fn model(mut self, model: impl Into<String>) -> Self {
133 self.model = Some(model.into());
134 self
135 }
136
137 #[must_use]
139 pub fn input(&self) -> &ModerationInput {
140 &self.input
141 }
142
143 #[must_use]
145 pub fn model_ref(&self) -> Option<&str> {
146 self.model.as_deref()
147 }
148
149 #[must_use]
151 pub fn has_multiple_inputs(&self) -> bool {
152 matches!(self.input, ModerationInput::TextArray(_))
153 }
154
155 #[must_use]
157 pub fn input_count(&self) -> usize {
158 match &self.input {
159 ModerationInput::Text(_) => 1,
160 ModerationInput::TextArray(texts) => texts.len(),
161 }
162 }
163
164 #[must_use]
166 pub fn first_input(&self) -> Option<&str> {
167 match &self.input {
168 ModerationInput::Text(text) => Some(text),
169 ModerationInput::TextArray(texts) => texts.first().map(std::string::String::as_str),
170 }
171 }
172
173 #[must_use]
175 pub fn all_inputs(&self) -> Vec<&str> {
176 match &self.input {
177 ModerationInput::Text(text) => vec![text],
178 ModerationInput::TextArray(texts) => {
179 texts.iter().map(std::string::String::as_str).collect()
180 }
181 }
182 }
183
184 pub fn build(self) -> crate::Result<openai_client_base::models::CreateModerationRequest> {
190 let input_string = match self.input {
191 ModerationInput::Text(text) => text,
192 ModerationInput::TextArray(texts) => {
193 texts.join("\n")
196 }
197 };
198
199 Ok(openai_client_base::models::CreateModerationRequest {
200 input: input_string,
201 model: self.model,
202 })
203 }
204}
205
206impl crate::builders::Builder<openai_client_base::models::CreateModerationRequest>
207 for ModerationBuilder
208{
209 fn build(self) -> crate::Result<openai_client_base::models::CreateModerationRequest> {
210 self.build()
211 }
212}
213
214impl ModerationCategories {
215 #[must_use]
217 pub fn new_clean() -> Self {
218 Self {
219 hate: false,
220 hate_threatening: false,
221 harassment: false,
222 harassment_threatening: false,
223 self_harm: false,
224 self_harm_intent: false,
225 self_harm_instructions: false,
226 sexual: false,
227 sexual_minors: false,
228 violence: false,
229 violence_graphic: false,
230 }
231 }
232
233 #[must_use]
235 pub fn any_flagged(&self) -> bool {
236 self.hate
237 || self.hate_threatening
238 || self.harassment
239 || self.harassment_threatening
240 || self.self_harm
241 || self.self_harm_intent
242 || self.self_harm_instructions
243 || self.sexual
244 || self.sexual_minors
245 || self.violence
246 || self.violence_graphic
247 }
248
249 #[must_use]
251 pub fn flagged_categories(&self) -> Vec<&'static str> {
252 let mut flagged = Vec::new();
253 if self.hate {
254 flagged.push("hate");
255 }
256 if self.hate_threatening {
257 flagged.push("hate/threatening");
258 }
259 if self.harassment {
260 flagged.push("harassment");
261 }
262 if self.harassment_threatening {
263 flagged.push("harassment/threatening");
264 }
265 if self.self_harm {
266 flagged.push("self-harm");
267 }
268 if self.self_harm_intent {
269 flagged.push("self-harm/intent");
270 }
271 if self.self_harm_instructions {
272 flagged.push("self-harm/instructions");
273 }
274 if self.sexual {
275 flagged.push("sexual");
276 }
277 if self.sexual_minors {
278 flagged.push("sexual/minors");
279 }
280 if self.violence {
281 flagged.push("violence");
282 }
283 if self.violence_graphic {
284 flagged.push("violence/graphic");
285 }
286 flagged
287 }
288}
289
290impl ModerationCategoryScores {
291 #[must_use]
293 pub fn new_zero() -> Self {
294 Self {
295 hate: 0.0,
296 hate_threatening: 0.0,
297 harassment: 0.0,
298 harassment_threatening: 0.0,
299 self_harm: 0.0,
300 self_harm_intent: 0.0,
301 self_harm_instructions: 0.0,
302 sexual: 0.0,
303 sexual_minors: 0.0,
304 violence: 0.0,
305 violence_graphic: 0.0,
306 }
307 }
308
309 #[must_use]
311 pub fn max_score(&self) -> f64 {
312 [
313 self.hate,
314 self.hate_threatening,
315 self.harassment,
316 self.harassment_threatening,
317 self.self_harm,
318 self.self_harm_intent,
319 self.self_harm_instructions,
320 self.sexual,
321 self.sexual_minors,
322 self.violence,
323 self.violence_graphic,
324 ]
325 .iter()
326 .fold(0.0, |max, &score| if score > max { score } else { max })
327 }
328
329 #[must_use]
331 pub fn scores_above_threshold(&self, threshold: f64) -> Vec<(&'static str, f64)> {
332 let mut high_scores = Vec::new();
333 if self.hate > threshold {
334 high_scores.push(("hate", self.hate));
335 }
336 if self.hate_threatening > threshold {
337 high_scores.push(("hate/threatening", self.hate_threatening));
338 }
339 if self.harassment > threshold {
340 high_scores.push(("harassment", self.harassment));
341 }
342 if self.harassment_threatening > threshold {
343 high_scores.push(("harassment/threatening", self.harassment_threatening));
344 }
345 if self.self_harm > threshold {
346 high_scores.push(("self-harm", self.self_harm));
347 }
348 if self.self_harm_intent > threshold {
349 high_scores.push(("self-harm/intent", self.self_harm_intent));
350 }
351 if self.self_harm_instructions > threshold {
352 high_scores.push(("self-harm/instructions", self.self_harm_instructions));
353 }
354 if self.sexual > threshold {
355 high_scores.push(("sexual", self.sexual));
356 }
357 if self.sexual_minors > threshold {
358 high_scores.push(("sexual/minors", self.sexual_minors));
359 }
360 if self.violence > threshold {
361 high_scores.push(("violence", self.violence));
362 }
363 if self.violence_graphic > threshold {
364 high_scores.push(("violence/graphic", self.violence_graphic));
365 }
366 high_scores
367 }
368}
369
370impl ModerationResult {
371 #[must_use]
373 pub fn new_clean() -> Self {
374 Self {
375 flagged: false,
376 categories: ModerationCategories::new_clean(),
377 category_scores: ModerationCategoryScores::new_zero(),
378 }
379 }
380
381 #[must_use]
383 pub fn is_safe(&self) -> bool {
384 !self.flagged
385 }
386
387 #[must_use]
389 pub fn flagged_summary(&self) -> Option<Vec<&'static str>> {
390 if self.flagged {
391 Some(self.categories.flagged_categories())
392 } else {
393 None
394 }
395 }
396}
397
398#[must_use]
400pub fn moderate_text(input: impl Into<String>) -> ModerationBuilder {
401 ModerationBuilder::new(input)
402}
403
404#[must_use]
406pub fn moderate_texts(inputs: Vec<String>) -> ModerationBuilder {
407 ModerationBuilder::new_array(inputs)
408}
409
410#[must_use]
412pub fn moderate_text_with_model(
413 input: impl Into<String>,
414 model: impl Into<String>,
415) -> ModerationBuilder {
416 ModerationBuilder::new(input).model(model)
417}
418
419#[must_use]
421pub fn moderate_messages(messages: &[impl AsRef<str>]) -> ModerationBuilder {
422 let inputs = messages
423 .iter()
424 .map(|msg| msg.as_ref().to_string())
425 .collect();
426 ModerationBuilder::new_array(inputs)
427}
428
/// Cheap, purely local heuristic for content that is *likely* to be flagged.
///
/// Case-insensitively checks for a few keywords; this is not a substitute
/// for calling the moderation API.
#[must_use]
pub fn likely_flagged(text: &str) -> bool {
    let lowered = text.to_lowercase();
    ["hate", "violence", "harmful"]
        .iter()
        .any(|keyword| lowered.contains(keyword))
}
437
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_moderation_builder_new() {
        let b = ModerationBuilder::new("Test content");

        assert_eq!(b.input_count(), 1);
        assert_eq!(b.first_input(), Some("Test content"));
        assert!(!b.has_multiple_inputs());
        assert_eq!(b.model_ref(), None);
    }

    #[test]
    fn test_moderation_builder_new_array() {
        let b = ModerationBuilder::new_array(vec![
            "First text".to_string(),
            "Second text".to_string(),
        ]);

        assert_eq!(b.input_count(), 2);
        assert_eq!(b.first_input(), Some("First text"));
        assert!(b.has_multiple_inputs());
        assert_eq!(b.all_inputs(), vec!["First text", "Second text"]);
    }

    #[test]
    fn test_moderation_builder_with_model() {
        let b = ModerationBuilder::new("Test").model("text-moderation-stable");
        assert_eq!(b.model_ref(), Some("text-moderation-stable"));
    }

    #[test]
    fn test_moderation_categories_new_clean() {
        let clean = ModerationCategories::new_clean();
        assert!(!clean.any_flagged());
        assert!(clean.flagged_categories().is_empty());
    }

    #[test]
    fn test_moderation_categories_flagged() {
        let mut cats = ModerationCategories::new_clean();
        cats.hate = true;
        cats.violence = true;

        assert!(cats.any_flagged());
        let names = cats.flagged_categories();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"hate"));
        assert!(names.contains(&"violence"));
    }

    #[test]
    fn test_moderation_category_scores_new_zero() {
        let zero = ModerationCategoryScores::new_zero();
        assert!((zero.max_score() - 0.0).abs() < f64::EPSILON);
        assert!(zero.scores_above_threshold(0.1).is_empty());
    }

    #[test]
    fn test_moderation_category_scores_max_and_threshold() {
        let mut scores = ModerationCategoryScores::new_zero();
        scores.hate = 0.8;
        scores.violence = 0.6;
        scores.sexual = 0.3;

        assert!((scores.max_score() - 0.8).abs() < f64::EPSILON);

        let high = scores.scores_above_threshold(0.5);
        assert_eq!(high.len(), 2);
        assert!(high.contains(&("hate", 0.8)));
        assert!(high.contains(&("violence", 0.6)));
    }

    #[test]
    fn test_moderation_result_new_clean() {
        let clean = ModerationResult::new_clean();
        assert!(clean.is_safe());
        assert!(clean.flagged_summary().is_none());
    }

    #[test]
    fn test_moderation_result_flagged() {
        let mut res = ModerationResult::new_clean();
        res.flagged = true;
        res.categories.hate = true;

        assert!(!res.is_safe());
        assert_eq!(res.flagged_summary().unwrap(), vec!["hate"]);
    }

    #[test]
    fn test_moderate_text_helper() {
        let b = moderate_text("Test content");
        assert_eq!(b.first_input(), Some("Test content"));
        assert!(!b.has_multiple_inputs());
    }

    #[test]
    fn test_moderate_texts_helper() {
        let b = moderate_texts(vec!["Text 1".to_string(), "Text 2".to_string()]);
        assert_eq!(b.input_count(), 2);
        assert!(b.has_multiple_inputs());
    }

    #[test]
    fn test_moderate_text_with_model_helper() {
        let b = moderate_text_with_model("Test", "text-moderation-latest");
        assert_eq!(b.first_input(), Some("Test"));
        assert_eq!(b.model_ref(), Some("text-moderation-latest"));
    }

    #[test]
    fn test_moderate_messages_helper() {
        let b = moderate_messages(&["Hello", "World"]);
        assert_eq!(b.input_count(), 2);
        assert_eq!(b.all_inputs(), vec!["Hello", "World"]);
    }

    #[test]
    fn test_likely_flagged_helper() {
        assert!(likely_flagged("This contains hate speech"));
        assert!(likely_flagged("Violence is not good"));
        assert!(likely_flagged("This is harmful content"));
        assert!(!likely_flagged("This is normal content"));
        assert!(!likely_flagged("Hello, how are you?"));
    }

    #[test]
    fn test_moderation_input_variants() {
        let single = ModerationInput::Text("single".to_string());
        let multiple = ModerationInput::TextArray(vec!["one".to_string(), "two".to_string()]);

        match single {
            ModerationInput::Text(text) => assert_eq!(text, "single"),
            ModerationInput::TextArray(_) => panic!("Expected Text variant"),
        }

        match multiple {
            ModerationInput::TextArray(texts) => assert_eq!(texts.len(), 2),
            ModerationInput::Text(_) => panic!("Expected TextArray variant"),
        }
    }
}