use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use crate::{api_resources::TokenUsage, Client, Result};
/// Request body for the moderations endpoint.
///
/// Build it with [`ModerationParamBuilder`], e.g.
/// `ModerationParamBuilder::new(text).build()`.
/// Fields that are `None` are omitted from the serialized JSON
/// (`#[skip_serializing_none]`).
///
/// NOTE(review): the struct-level `#[builder(default)]` means `build()` does not
/// fail when `input` was never set; it falls back to `String::default()` (an
/// empty string). Confirm an empty `input` is acceptable to the server, or make
/// `input` a required builder field.
#[skip_serializing_none]
#[derive(Builder, Debug, Default, Deserialize, Serialize)]
#[builder(default, setter(into, strip_option))]
pub struct ModerationParam {
    /// Name of the moderation model to use; left out of the request when `None`.
    model: Option<String>,
    /// Text submitted for moderation.
    input: String,
}
impl ModerationParamBuilder {
    /// Creates a builder whose `input` is already populated.
    ///
    /// Every other builder field keeps its default (unset) value, so callers can
    /// chain optional setters such as `model` before calling `build()`.
    ///
    /// * `input` - anything convertible into a `String`; it becomes the text
    ///   sent for moderation.
    pub fn new(input: impl Into<String>) -> Self {
        let input = Some(input.into());
        Self {
            input,
            ..Default::default()
        }
    }
}
/// Response returned by the moderations endpoint.
///
/// `#[serde(default)]` lets any missing key fall back to the field's `Default`
/// value, so partial responses still deserialize.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct Moderation {
    /// Identifier assigned to this moderation request.
    pub id: String,
    /// Name of the model that produced the classification.
    pub model: String,
    /// Legacy top-level flag.
    ///
    /// The moderation response reports `flagged` inside each element of
    /// `results`, not at the top level, so for such responses this field stays
    /// `false`. It is kept for backward compatibility; use
    /// [`Moderation::is_flagged`] or [`ModerationResult::flagged`] instead.
    pub flagged: bool,
    /// Classification results returned by the endpoint.
    pub results: Vec<ModerationResult>,
    /// Token accounting if the response carries a `token_usage` key;
    /// `None` when the key is absent.
    pub token_usage: Option<TokenUsage>,
}

impl Moderation {
    /// Returns `true` when any result was flagged by the model.
    ///
    /// The legacy top-level [`Moderation::flagged`] field is also honoured, so
    /// values constructed by hand with that field set still report correctly.
    pub fn is_flagged(&self) -> bool {
        self.flagged || self.results.iter().any(|result| result.flagged)
    }
}

/// Classification of a single moderated input.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct ModerationResult {
    /// Per-category boolean verdicts.
    pub categories: Categories,
    /// Per-category scores.
    pub category_scores: CategoryScores,
    /// Whether the model flagged this input (the `flagged` key of each
    /// `results` element). Defaults to `false` when the key is absent.
    pub flagged: bool,
}
/// Per-category boolean verdicts from one moderation result.
///
/// Each field corresponds to one key of the `categories` JSON object. Keys that
/// contain `/` or `-` (e.g. `hate/threatening`, `self-harm`) are mapped with
/// `#[serde(rename)]`. Because of `#[serde(default)]`, a missing key
/// deserializes as `false`. Keys not listed here are ignored rather than
/// rejected, since the struct does not use `deny_unknown_fields`.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct Categories {
    pub hate: bool,
    #[serde(rename = "hate/threatening")]
    pub hate_threatening: bool,
    #[serde(rename = "self-harm")]
    pub self_harm: bool,
    pub sexual: bool,
    #[serde(rename = "sexual/minors")]
    pub sexual_minors: bool,
    pub violence: bool,
    #[serde(rename = "violence/graphic")]
    pub violence_graphic: bool,
}
/// Per-category numeric scores from one moderation result.
///
/// Field-to-key mapping mirrors [`Categories`]: keys containing `/` or `-` are
/// mapped with `#[serde(rename)]`. Because of `#[serde(default)]`, a missing
/// key deserializes as `0.0`, and unknown keys are ignored.
///
/// NOTE(review): the values seen in practice lie in `[0, 1]` and presumably act
/// as model confidence; nothing here enforces a range, so confirm against the
/// API documentation before relying on it.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct CategoryScores {
    pub hate: f64,
    #[serde(rename = "hate/threatening")]
    pub hate_threatening: f64,
    #[serde(rename = "self-harm")]
    pub self_harm: f64,
    pub sexual: f64,
    #[serde(rename = "sexual/minors")]
    pub sexual_minors: f64,
    pub violence: f64,
    #[serde(rename = "violence/graphic")]
    pub violence_graphic: f64,
}
pub async fn create(client: &Client, param: &ModerationParam) -> Result<Moderation> {
client.create_moderation(param).await
}
impl Client {
    /// Hands `param` to [`Client::post`] for the `"moderations"` path and
    /// decodes the reply as a [`Moderation`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `post`.
    async fn create_moderation(&self, param: &ModerationParam) -> Result<Moderation> {
        let endpoint = "moderations";
        let body = Some(param);
        self.post::<ModerationParam, Moderation>(endpoint, body).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for score comparisons. serde_json's default float parser is
    /// not guaranteed to be correctly rounded, so exact equality would be brittle.
    const EPSILON: f64 = 1e-6;

    /// Deserializes a sample request and response, and checks that every
    /// renamed category key (`hate/threatening`, `self-harm`, ...) lands in
    /// the matching struct field.
    #[test]
    fn test_create_moderation() {
        let param: ModerationParam = serde_json::from_str(
            r#"
            {
                "input": "I want to kill them."
            }
            "#,
        )
        .unwrap();
        let resp: Moderation = serde_json::from_str(
            r#"
            {
                "id": "modr-5MWoLO",
                "model": "text-moderation-001",
                "results": [
                    {
                        "categories": {
                            "hate": false,
                            "hate/threatening": true,
                            "self-harm": false,
                            "sexual": false,
                            "sexual/minors": false,
                            "violence": true,
                            "violence/graphic": false
                        },
                        "category_scores": {
                            "hate": 0.22714105248451233,
                            "hate/threatening": 0.4132447838783264,
                            "self-harm": 0.005232391878962517,
                            "sexual": 0.01407341007143259,
                            "sexual/minors": 0.0038522258400917053,
                            "violence": 0.9223177433013916,
                            "violence/graphic": 0.036865197122097015
                        },
                        "flagged": true
                    }
                ]
            }
            "#,
        )
        .unwrap();

        assert_eq!(param.input, "I want to kill them.");
        assert!(param.model.is_none());

        assert_eq!(resp.id, "modr-5MWoLO");
        assert_eq!(resp.model, "text-moderation-001");
        assert!(resp.token_usage.is_none());
        assert_eq!(resp.results.len(), 1);

        // Every boolean category, including the renamed keys.
        let categories = &resp.results[0].categories;
        assert!(!categories.hate);
        assert!(categories.hate_threatening);
        assert!(!categories.self_harm);
        assert!(!categories.sexual);
        assert!(!categories.sexual_minors);
        assert!(categories.violence);
        assert!(!categories.violence_graphic);

        // Every score, including the renamed keys.
        let scores = &resp.results[0].category_scores;
        let expected = [
            ("hate", scores.hate, 0.2271411),
            ("hate/threatening", scores.hate_threatening, 0.4132448),
            ("self-harm", scores.self_harm, 0.005232392),
            ("sexual", scores.sexual, 0.01407341),
            ("sexual/minors", scores.sexual_minors, 0.003852226),
            ("violence", scores.violence, 0.9223177),
            ("violence/graphic", scores.violence_graphic, 0.0368652),
        ];
        for &(key, actual, want) in expected.iter() {
            assert!(
                (actual - want).abs() < EPSILON,
                "score for {key}: got {actual}, expected about {want}"
            );
        }
    }
}