use crate::common::auth::AuthProvider;
use crate::common::client::create_http_client;
use crate::common::errors::{ErrorResponse, OpenAIToolError, Result};
use crate::models::response::{DeleteResponse, Model, ModelsListResponse};
use std::time::Duration;
/// Relative path handed to `AuthProvider::endpoint` to address the Models API.
const MODELS_PATH: &str = "models";

/// Client for the Models API (list, retrieve and delete models).
///
/// Bundles the provider credentials/endpoint configuration with an optional
/// timeout that is forwarded to every HTTP client this value builds.
pub struct Models {
    /// Credentials and endpoint resolution for the configured provider.
    auth: AuthProvider,
    /// Optional request timeout; `None` leaves the choice to `create_http_client`.
    timeout: Option<Duration>,
}
impl Models {
    /// Creates a Models client for the OpenAI API from environment configuration.
    ///
    /// # Errors
    /// Propagates any error from `AuthProvider::openai_from_env`.
    pub fn new() -> Result<Self> {
        let auth = AuthProvider::openai_from_env()?;
        Ok(Self { auth, timeout: None })
    }

    /// Creates a Models client from an explicit, already-built `AuthProvider`.
    pub fn with_auth(auth: AuthProvider) -> Self {
        Self { auth, timeout: None }
    }

    /// Creates a Models client for Azure from environment configuration.
    ///
    /// # Errors
    /// Propagates any error from `AuthProvider::azure_from_env`.
    pub fn azure() -> Result<Self> {
        let auth = AuthProvider::azure_from_env()?;
        Ok(Self { auth, timeout: None })
    }

    /// Creates a Models client, letting `AuthProvider::from_env` pick the provider.
    ///
    /// # Errors
    /// Propagates any error from `AuthProvider::from_env`.
    pub fn detect_provider() -> Result<Self> {
        let auth = AuthProvider::from_env()?;
        Ok(Self { auth, timeout: None })
    }

    /// Creates a Models client targeting `base_url` and authenticating with `api_key`.
    pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
        let auth = AuthProvider::from_url_with_key(base_url, api_key);
        Self { auth, timeout: None }
    }

    /// Creates a Models client from a URL, delegating key lookup to `AuthProvider::from_url`.
    ///
    /// # Errors
    /// Propagates any error from `AuthProvider::from_url`.
    pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
        let auth = AuthProvider::from_url(url)?;
        Ok(Self { auth, timeout: None })
    }

    /// Returns the authentication provider used by this client.
    pub fn auth(&self) -> &AuthProvider {
        &self.auth
    }

    /// Sets the timeout used for every subsequent request; returns `self` for chaining.
    pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
        self.timeout = Some(timeout);
        self
    }

    /// Builds a fresh HTTP client and the auth + User-Agent headers for one request.
    fn create_client(&self) -> Result<(request::Client, request::header::HeaderMap)> {
        let client = create_http_client(self.timeout)?;
        let mut headers = request::header::HeaderMap::new();
        self.auth.apply_headers(&mut headers)?;
        headers.insert(
            request::header::USER_AGENT,
            request::header::HeaderValue::from_static("openai-tools-rust"),
        );
        Ok((client, headers))
    }

    /// Builds the relative path `models/{model_id}`, rejecting an empty id.
    ///
    /// An empty id would otherwise hit `models/` and fail later with a confusing
    /// deserialization error instead of a clear argument error.
    fn model_path(model_id: &str) -> Result<String> {
        if model_id.is_empty() {
            return Err(OpenAIToolError::Error("model_id must not be empty".to_string()));
        }
        Ok(format!("{}/{}", MODELS_PATH, model_id))
    }

    /// Sends `request` and returns the raw response body when the status is 2xx.
    ///
    /// # Errors
    /// - `OpenAIToolError::RequestError` if sending or reading the body fails.
    /// - `OpenAIToolError::Error` for a non-2xx status. The message is the API's
    ///   `error.message` when the body parses as `ErrorResponse` and that message
    ///   is present and non-empty. Otherwise it is `"API error (<status>): <body>"`.
    async fn execute(request: request::RequestBuilder) -> Result<String> {
        let response = request.send().await.map_err(OpenAIToolError::RequestError)?;
        let status = response.status();
        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
        if cfg!(test) {
            tracing::info!("Response content: {}", content);
        }
        if status.is_success() {
            return Ok(content);
        }
        // A parsed error body without a message must not become an empty error
        // string; fall back to status + body so the caller still sees what failed.
        let api_message = serde_json::from_str::<ErrorResponse>(&content)
            .ok()
            .and_then(|resp| resp.error.message)
            .filter(|msg| !msg.is_empty());
        Err(OpenAIToolError::Error(match api_message {
            Some(msg) => msg,
            None => format!("API error ({}): {}", status, content),
        }))
    }

    /// Lists the models available to the configured account (`GET models`).
    ///
    /// # Errors
    /// Client/header construction errors, request errors, non-2xx API errors
    /// (see `execute`), or `SerdeJsonError` if the body is not a `ModelsListResponse`.
    pub async fn list(&self) -> Result<ModelsListResponse> {
        let (client, headers) = self.create_client()?;
        let url = self.auth.endpoint(MODELS_PATH);
        let body = Self::execute(client.get(&url).headers(headers)).await?;
        serde_json::from_str::<ModelsListResponse>(&body).map_err(OpenAIToolError::SerdeJsonError)
    }

    /// Retrieves a single model by id (`GET models/{model_id}`).
    ///
    /// # Errors
    /// `OpenAIToolError::Error` if `model_id` is empty, plus everything `list` can return
    /// (with `Model` as the expected body).
    pub async fn retrieve(&self, model_id: &str) -> Result<Model> {
        let path = Self::model_path(model_id)?;
        let (client, headers) = self.create_client()?;
        // Resolve the full relative path through the provider so any query string it
        // adds stays after the model id. NOTE(review): assumes `endpoint` treats its
        // argument as a relative path; confirm against `AuthProvider::endpoint`.
        let url = self.auth.endpoint(path.as_str());
        let body = Self::execute(client.get(&url).headers(headers)).await?;
        serde_json::from_str::<Model>(&body).map_err(OpenAIToolError::SerdeJsonError)
    }

    /// Deletes a model by id (`DELETE models/{model_id}`).
    ///
    /// # Errors
    /// `OpenAIToolError::Error` if `model_id` is empty, plus everything `list` can return
    /// (with `DeleteResponse` as the expected body).
    pub async fn delete(&self, model_id: &str) -> Result<DeleteResponse> {
        let path = Self::model_path(model_id)?;
        let (client, headers) = self.create_client()?;
        // Same path handling as `retrieve`: the provider builds the final URL.
        let url = self.auth.endpoint(path.as_str());
        let body = Self::execute(client.delete(&url).headers(headers)).await?;
        serde_json::from_str::<DeleteResponse>(&body).map_err(OpenAIToolError::SerdeJsonError)
    }
}