use crate::common::auth::{AuthProvider, OpenAIAuth};
use crate::common::client::create_http_client;
use crate::common::errors::{ErrorResponse, OpenAIToolError, Result};
use crate::files::response::{DeleteResponse, File, FileListResponse};
use request::multipart::{Form, Part};
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::time::Duration;
/// Path segment handed to `AuthProvider::endpoint` to build the Files API URL.
const FILES_PATH: &str = "files";
/// Purpose tag attached to an uploaded file.
///
/// The serde representation must agree with [`FilePurpose::as_str`], which is the
/// string this client sends as the upload `purpose` field and the list `purpose`
/// query parameter. `rename_all = "snake_case"` covers most variants, but the
/// fine-tuning purposes use hyphens on the wire, so they are renamed explicitly.
/// The snake_case spellings are kept as deserialization aliases so data that was
/// serialized with the old (incorrect) names still parses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FilePurpose {
    /// Wire value `"assistants"`.
    Assistants,
    /// Wire value `"assistants_output"`.
    AssistantsOutput,
    /// Wire value `"batch"`.
    Batch,
    /// Wire value `"batch_output"`.
    BatchOutput,
    /// Wire value `"fine-tune"` (hyphenated, unlike the other variants).
    #[serde(rename = "fine-tune", alias = "fine_tune")]
    FineTune,
    /// Wire value `"fine-tune-results"` (hyphenated, unlike the other variants).
    #[serde(rename = "fine-tune-results", alias = "fine_tune_results")]
    FineTuneResults,
    /// Wire value `"vision"`.
    Vision,
    /// Wire value `"user_data"`.
    UserData,
}
impl FilePurpose {
pub fn as_str(&self) -> &'static str {
match self {
FilePurpose::Assistants => "assistants",
FilePurpose::AssistantsOutput => "assistants_output",
FilePurpose::Batch => "batch",
FilePurpose::BatchOutput => "batch_output",
FilePurpose::FineTune => "fine-tune",
FilePurpose::FineTuneResults => "fine-tune-results",
FilePurpose::Vision => "vision",
FilePurpose::UserData => "user_data",
}
}
}
/// Formats the purpose as its API string (see [`FilePurpose::as_str`]).
///
/// Delegates to [`std::fmt::Formatter::pad`] so that width, fill, alignment and
/// precision flags supplied by the caller (e.g. `format!("{:<12}", purpose)`)
/// are honoured; re-entering `write!(f, "{}", ..)` would silently drop them.
impl std::fmt::Display for FilePurpose {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
/// Client for the Files API: upload, list, retrieve, delete and download content.
///
/// Holds the authentication provider used to resolve endpoint URLs and request
/// headers, plus an optional timeout applied to the HTTP client built per request.
pub struct Files {
    /// Provider-specific credentials and endpoint resolution (OpenAI, Azure, custom URL).
    auth: AuthProvider,
    /// Timeout passed to `create_http_client` for every request; `None` until `timeout()` is called.
    timeout: Option<Duration>,
}
impl Files {
pub fn new() -> Result<Self> {
let auth = AuthProvider::openai_from_env()?;
Ok(Self { auth, timeout: None })
}
pub fn with_auth(auth: AuthProvider) -> Self {
Self { auth, timeout: None }
}
pub fn azure() -> Result<Self> {
let auth = AuthProvider::azure_from_env()?;
Ok(Self { auth, timeout: None })
}
pub fn detect_provider() -> Result<Self> {
let auth = AuthProvider::from_env()?;
Ok(Self { auth, timeout: None })
}
pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
let auth = AuthProvider::from_url_with_key(base_url, api_key);
Self { auth, timeout: None }
}
pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
let auth = AuthProvider::from_url(url)?;
Ok(Self { auth, timeout: None })
}
pub fn auth(&self) -> &AuthProvider {
&self.auth
}
pub fn base_url<T: AsRef<str>>(&mut self, url: T) -> &mut Self {
if let AuthProvider::OpenAI(ref openai_auth) = self.auth {
let new_auth = OpenAIAuth::new(openai_auth.api_key()).with_base_url(url.as_ref());
self.auth = AuthProvider::OpenAI(new_auth);
} else {
tracing::warn!("base_url() is only supported for OpenAI provider. Use azure() or with_auth() for Azure.");
}
self
}
pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
self.timeout = Some(timeout);
self
}
fn create_client(&self) -> Result<(request::Client, request::header::HeaderMap)> {
let client = create_http_client(self.timeout)?;
let mut headers = request::header::HeaderMap::new();
self.auth.apply_headers(&mut headers)?;
headers.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust"));
Ok((client, headers))
}
pub async fn upload_path(&self, file_path: &str, purpose: FilePurpose) -> Result<File> {
let path = Path::new(file_path);
let filename = path.file_name().and_then(|n| n.to_str()).unwrap_or("file").to_string();
let content = tokio::fs::read(file_path).await.map_err(|e| OpenAIToolError::Error(format!("Failed to read file: {}", e)))?;
self.upload_bytes(&content, &filename, purpose).await
}
pub async fn upload_bytes(&self, content: &[u8], filename: &str, purpose: FilePurpose) -> Result<File> {
let (client, headers) = self.create_client()?;
let file_part = Part::bytes(content.to_vec())
.file_name(filename.to_string())
.mime_str("application/octet-stream")
.map_err(|e| OpenAIToolError::Error(format!("Failed to set MIME type: {}", e)))?;
let form = Form::new().part("file", file_part).text("purpose", purpose.as_str().to_string());
let endpoint = self.auth.endpoint(FILES_PATH);
let response = client.post(&endpoint).headers(headers).multipart(form).send().await.map_err(OpenAIToolError::RequestError)?;
let status = response.status();
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
if !status.is_success() {
if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
}
return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
}
serde_json::from_str::<File>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn list(&self, purpose: Option<FilePurpose>) -> Result<FileListResponse> {
let (client, headers) = self.create_client()?;
let endpoint = self.auth.endpoint(FILES_PATH);
let url = match purpose {
Some(p) => format!("{}?purpose={}", endpoint, p.as_str()),
None => endpoint,
};
let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
let status = response.status();
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
if !status.is_success() {
if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
}
return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
}
serde_json::from_str::<FileListResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn retrieve(&self, file_id: &str) -> Result<File> {
let (client, headers) = self.create_client()?;
let url = format!("{}/{}", self.auth.endpoint(FILES_PATH), file_id);
let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
let status = response.status();
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
if !status.is_success() {
if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
}
return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
}
serde_json::from_str::<File>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn delete(&self, file_id: &str) -> Result<DeleteResponse> {
let (client, headers) = self.create_client()?;
let url = format!("{}/{}", self.auth.endpoint(FILES_PATH), file_id);
let response = client.delete(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
let status = response.status();
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
if !status.is_success() {
if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
}
return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
}
serde_json::from_str::<DeleteResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn content(&self, file_id: &str) -> Result<Vec<u8>> {
let (client, headers) = self.create_client()?;
let url = format!("{}/{}/content", self.auth.endpoint(FILES_PATH), file_id);
let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
let bytes = response.bytes().await.map_err(OpenAIToolError::RequestError)?;
Ok(bytes.to_vec())
}
}