use core::fmt;
use reqwest::{Client as ReqwestClient, Method, RequestBuilder, Url};
use serde::{Deserialize, Serialize};
use crate::error::Error;
/// HTTP client for the `moderations` endpoint.
///
/// Holds the base URL that request paths are joined onto and the underlying
/// `reqwest` client used to send them. Cloning is cheap with respect to the
/// HTTP client: `reqwest::Client` is internally reference-counted.
#[derive(Debug, Clone)]
pub struct ModerationClient {
    /// Base URL that endpoint paths are resolved against.
    base_url: Url,
    /// Underlying HTTP client used to build and send requests.
    http_client: ReqwestClient,
}
impl ModerationClient {
    /// Creates a client that resolves endpoint paths against `base_url` and
    /// sends requests through `http_client`.
    ///
    /// Paths are resolved with [`Url::join`], for which a trailing slash on
    /// the base is significant: without it the last path segment of
    /// `base_url` is replaced rather than extended.
    pub fn new(base_url: Url, http_client: ReqwestClient) -> Self {
        Self {
            base_url,
            http_client,
        }
    }

    /// Sends `payload` as a JSON `POST` to the `moderations` path and decodes
    /// the response body into a [`Moderation`].
    ///
    /// # Errors
    ///
    /// Returns an error if the request URL cannot be built, the request fails
    /// to send, the server answers with a 4xx/5xx status, or the response
    /// body cannot be deserialized as a [`Moderation`].
    pub async fn create_moderation(&self, payload: CreateModeration) -> Result<Moderation, Error> {
        let response = self
            .request(Method::POST, "moderations")?
            .json(&payload)
            .send()
            .await?
            // Surface HTTP error statuses directly instead of letting an
            // error body fail later as a confusing JSON decode error.
            .error_for_status()?;
        let moderation = response.json::<Moderation>().await?;
        Ok(moderation)
    }

    /// Builds a request for `method` against `path` resolved relative to the
    /// base URL.
    ///
    /// # Errors
    ///
    /// Returns [`Error::UrlParse`] carrying the parser's message when `path`
    /// cannot be joined onto the base URL.
    fn request(&self, method: Method, path: &str) -> Result<RequestBuilder, Error> {
        let url = self
            .base_url
            .join(path)
            .map_err(|err| Error::UrlParse(err.to_string()))?;
        Ok(self.http_client.request(method, url))
    }
}
/// Response body decoded by `ModerationClient::create_moderation`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Moderation {
    /// Identifier string carried in the `id` key of the response.
    pub id: String,
    /// Model name carried in the `model` key of the response.
    pub model: String,
    /// Moderation outcomes carried in the `results` array.
    // NOTE(review): presumably one entry per submitted input — confirm
    // against the API documentation.
    pub results: Vec<ModerationResult>,
}
/// Request body sent when creating a moderation.
#[derive(Debug, Serialize, Deserialize)]
pub struct CreateModeration {
    /// Text submitted under the `input` key.
    pub input: String,
    /// Model selection submitted under the `model` key.
    pub model: ModerationModel,
}
impl CreateModeration {
    /// Builds a request for `input` that uses [`ModerationModel::Stable`].
    pub fn new(input: impl Into<String>) -> Self {
        let input = input.into();
        Self {
            input,
            model: ModerationModel::Stable,
        }
    }

    /// Returns the request with its input text replaced by `input`.
    pub fn with_input(self, input: impl Into<String>) -> Self {
        Self {
            input: input.into(),
            ..self
        }
    }

    /// Returns the request with its model replaced by `model`.
    pub fn with_model(self, model: ModerationModel) -> Self {
        Self { model, ..self }
    }
}
/// Model selector for a moderation request.
///
/// Each variant (de)serializes as the string given in its `serde(rename)`
/// attribute. The enum carries no data, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModerationModel {
    /// Serialized as `"text-moderation-stable"`.
    #[serde(rename = "text-moderation-stable")]
    Stable,
    /// Serialized as `"text-moderation-latest"`.
    #[serde(rename = "text-moderation-latest")]
    Latest,
}
impl fmt::Display for ModerationModel {
    /// Writes the model identifier string for this variant, without applying
    /// any width or padding options from the formatter.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let identifier = match self {
            Self::Stable => "text-moderation-stable",
            Self::Latest => "text-moderation-latest",
        };
        f.write_str(identifier)
    }
}
/// One entry of the `results` array in a [`Moderation`] response.
#[derive(Debug, Serialize, Deserialize)]
pub struct ModerationResult {
    /// Value of the `flagged` key.
    // NOTE(review): presumably `true` when any category was flagged — confirm.
    pub flagged: bool,
    /// Per-category boolean flags from the `categories` object.
    pub categories: ModerationCategories,
    /// Per-category numeric scores from the `category_scores` object.
    pub category_scores: ModerationCategoryScores,
}
/// Per-category boolean flags of a [`ModerationResult`].
///
/// Every field is required when deserializing. Fields with a `serde(rename)`
/// attribute are read from and written to that JSON key; the others use the
/// field name as the key.
// NOTE(review): presumably `true` means the input was flagged for that
// category — confirm against the API documentation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModerationCategories {
    /// JSON key `sexual`.
    pub sexual: bool,
    /// JSON key `hate`.
    pub hate: bool,
    /// JSON key `harassment`.
    pub harassment: bool,
    /// JSON key `self-harm`.
    #[serde(rename = "self-harm")]
    pub self_harm: bool,
    /// JSON key `sexual/minors`.
    #[serde(rename = "sexual/minors")]
    pub sexual_minors: bool,
    /// JSON key `hate/threatening`.
    #[serde(rename = "hate/threatening")]
    pub hate_threatening: bool,
    /// JSON key `violence/graphic`.
    #[serde(rename = "violence/graphic")]
    pub violence_graphic: bool,
    /// JSON key `self-harm/intent`.
    #[serde(rename = "self-harm/intent")]
    pub self_harm_intent: bool,
    /// JSON key `self-harm/instructions`.
    #[serde(rename = "self-harm/instructions")]
    pub self_harm_instructions: bool,
    /// JSON key `harassment/threatening`.
    #[serde(rename = "harassment/threatening")]
    pub harassment_threatening: bool,
    /// JSON key `violence`.
    pub violence: bool,
}
/// Per-category numeric scores of a [`ModerationResult`].
///
/// Every field is required when deserializing. Fields with a `serde(rename)`
/// attribute are read from and written to that JSON key; the others use the
/// field name as the key.
// NOTE(review): the value range and meaning of each score are not visible
// here — confirm against the API documentation before relying on thresholds.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModerationCategoryScores {
    /// JSON key `sexual`.
    pub sexual: f64,
    /// JSON key `hate`.
    pub hate: f64,
    /// JSON key `harassment`.
    pub harassment: f64,
    /// JSON key `self-harm`.
    #[serde(rename = "self-harm")]
    pub self_harm: f64,
    /// JSON key `sexual/minors`.
    #[serde(rename = "sexual/minors")]
    pub sexual_minors: f64,
    /// JSON key `hate/threatening`.
    #[serde(rename = "hate/threatening")]
    pub hate_threatening: f64,
    /// JSON key `violence/graphic`.
    #[serde(rename = "violence/graphic")]
    pub violence_graphic: f64,
    /// JSON key `self-harm/intent`.
    #[serde(rename = "self-harm/intent")]
    pub self_harm_intent: f64,
    /// JSON key `self-harm/instructions`.
    #[serde(rename = "self-harm/instructions")]
    pub self_harm_instructions: f64,
    /// JSON key `harassment/threatening`.
    #[serde(rename = "harassment/threatening")]
    pub harassment_threatening: f64,
    /// JSON key `violence`.
    pub violence: f64,
}