use crate::common::auth::AuthProvider;
use crate::common::client::create_http_client;
use crate::common::errors::{ErrorResponse, OpenAIToolError, Result};
use crate::moderations::response::ModerationResponse;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Path segment for the moderations endpoint; passed to `AuthProvider::endpoint`
/// to build the request URL.
const MODERATIONS_PATH: &str = "moderations";
/// Moderation model identifiers that can be sent in a moderation request.
///
/// Variants (de)serialize as their kebab-case names (`omni-moderation-latest`,
/// `text-moderation-latest`), the same strings returned by
/// [`ModerationModel::as_str`]. The default variant is
/// [`ModerationModel::OmniModerationLatest`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum ModerationModel {
    /// Wire name: `omni-moderation-latest`.
    #[default]
    OmniModerationLatest,
    /// Wire name: `text-moderation-latest`.
    TextModerationLatest,
}
impl ModerationModel {
pub fn as_str(&self) -> &'static str {
match self {
Self::OmniModerationLatest => "omni-moderation-latest",
Self::TextModerationLatest => "text-moderation-latest",
}
}
}
impl std::fmt::Display for ModerationModel {
    /// Writes exactly the string returned by [`ModerationModel::as_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// JSON body posted to the moderations endpoint.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
struct ModerationRequest {
    /// The text(s) to moderate (a single string or a list of strings).
    input: ModerationInput,
    /// Model wire name; the key is omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
}
/// Moderation input: either one string or a list of strings.
///
/// Because of `#[serde(untagged)]` no variant tag is written: `Single`
/// serializes as a bare JSON string and `Multiple` as a JSON array of strings.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
#[serde(untagged)]
enum ModerationInput {
    /// One piece of text.
    Single(String),
    /// Several pieces of text submitted in one request.
    Multiple(Vec<String>),
}
/// Client for the moderations endpoint.
///
/// Holds the [`AuthProvider`] used for request headers and URL construction,
/// plus an optional timeout forwarded to the HTTP client builder.
pub struct Moderations {
    /// Optional timeout, passed unchanged to `create_http_client`.
    timeout: Option<Duration>,
    /// Credentials and endpoint configuration.
    auth: AuthProvider,
}
impl Moderations {
    /// Creates a client whose credentials come from [`AuthProvider::openai_from_env`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `AuthProvider::openai_from_env`.
    pub fn new() -> Result<Self> {
        Ok(Self::with_auth(AuthProvider::openai_from_env()?))
    }

    /// Creates a client from an already-configured [`AuthProvider`], with no timeout set.
    pub fn with_auth(auth: AuthProvider) -> Self {
        Self { auth, timeout: None }
    }

    /// Creates a client whose credentials come from [`AuthProvider::azure_from_env`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `AuthProvider::azure_from_env`.
    pub fn azure() -> Result<Self> {
        Ok(Self::with_auth(AuthProvider::azure_from_env()?))
    }

    /// Creates a client whose credentials come from [`AuthProvider::from_env`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `AuthProvider::from_env`.
    pub fn detect_provider() -> Result<Self> {
        Ok(Self::with_auth(AuthProvider::from_env()?))
    }

    /// Creates a client for an explicit base URL and API key via
    /// [`AuthProvider::from_url_with_key`].
    pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
        Self::with_auth(AuthProvider::from_url_with_key(base_url, api_key))
    }

    /// Creates a client from a URL via [`AuthProvider::from_url`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `AuthProvider::from_url`.
    pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
        Ok(Self::with_auth(AuthProvider::from_url(url)?))
    }

    /// Returns the authentication provider this client uses.
    pub fn auth(&self) -> &AuthProvider {
        &self.auth
    }

    /// Sets the timeout that is passed to `create_http_client` for every request.
    ///
    /// Returns `&mut Self` so calls can be chained.
    pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
        self.timeout = Some(timeout);
        self
    }

    /// Builds an HTTP client plus the headers for one request.
    ///
    /// Headers are the auth headers from `AuthProvider::apply_headers`, a JSON
    /// `Content-Type`, and a fixed `User-Agent`.
    // NOTE(review): a new client is built per request; reusing one across calls
    // may allow connection reuse — confirm what `create_http_client` returns.
    fn create_client(&self) -> Result<(request::Client, request::header::HeaderMap)> {
        let client = create_http_client(self.timeout)?;
        let mut headers = request::header::HeaderMap::new();
        self.auth.apply_headers(&mut headers)?;
        headers.insert("Content-Type", request::header::HeaderValue::from_static("application/json"));
        headers.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust"));
        Ok((client, headers))
    }

    /// Submits a single `text` to the moderations endpoint.
    ///
    /// When `model` is `None`, the `model` key is omitted from the request body.
    ///
    /// # Errors
    ///
    /// See [`Moderations::moderate_texts`]; the same error cases apply.
    pub async fn moderate_text(&self, text: &str, model: Option<ModerationModel>) -> Result<ModerationResponse> {
        let request_body = ModerationRequest {
            input: ModerationInput::Single(text.to_string()),
            model: model.map(|m| m.as_str().to_string()),
        };
        self.send_request(&request_body).await
    }

    /// Submits several `texts` to the moderations endpoint in one request.
    ///
    /// When `model` is `None`, the `model` key is omitted from the request body.
    ///
    /// # Errors
    ///
    /// - Errors from building the client or applying auth headers are propagated.
    /// - `OpenAIToolError::SerdeJsonError` if the request cannot be serialized or
    ///   the success body cannot be parsed as a `ModerationResponse`.
    /// - `OpenAIToolError::RequestError` if sending or reading the response fails.
    /// - `OpenAIToolError::Error` for a non-success HTTP status.
    pub async fn moderate_texts(&self, texts: Vec<String>, model: Option<ModerationModel>) -> Result<ModerationResponse> {
        let request_body = ModerationRequest {
            input: ModerationInput::Multiple(texts),
            model: model.map(|m| m.as_str().to_string()),
        };
        self.send_request(&request_body).await
    }

    /// POSTs `request_body` as JSON to the moderations endpoint and decodes the reply.
    ///
    /// On a non-success status, the error message is taken from the API's
    /// `ErrorResponse` when it is present and non-empty. Otherwise the message
    /// is `"API error (<status>): <raw body>"`, so the error is never blank.
    async fn send_request(&self, request_body: &ModerationRequest) -> Result<ModerationResponse> {
        let (client, headers) = self.create_client()?;
        let body = serde_json::to_string(request_body).map_err(OpenAIToolError::SerdeJsonError)?;
        let url = self.auth.endpoint(MODERATIONS_PATH);
        let response = client.post(&url).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
        let status = response.status();
        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;

        // Only true when the crate is compiled under cfg(test).
        if cfg!(test) {
            tracing::info!("Response content: {}", content);
        }

        if !status.is_success() {
            // An ErrorResponse that parses but has no (or an empty) message used to
            // produce an empty error string, losing the status and body. Fall back
            // to the raw status/body in that case.
            let api_message = serde_json::from_str::<ErrorResponse>(&content)
                .ok()
                .and_then(|error_resp| error_resp.error.message)
                .filter(|message| !message.is_empty());
            return Err(OpenAIToolError::Error(
                api_message.unwrap_or_else(|| format!("API error ({}): {}", status, content)),
            ));
        }

        serde_json::from_str::<ModerationResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
    }
}