use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::error::LlmError;
use crate::traits::FileManagementCapability;
use crate::types::{
FileDeleteResponse, FileListQuery, FileListResponse, FileObject, FileUploadRequest,
};
use super::config::OpenAiConfig;
/// Serializable form payload holding the `purpose` field of a file upload.
///
/// NOTE(review): nothing in the visible part of this file constructs this type
/// (`upload_file` assembles a `reqwest::multipart::Form` directly), which is
/// presumably why `dead_code` is allowed here — confirm before removing it.
#[derive(Debug, Clone, Serialize)]
#[allow(dead_code)]
struct OpenAiFileUploadForm {
/// Value of the form's `purpose` field.
purpose: String,
}
/// Wire format of a single file object as deserialized from a Files API response.
///
/// `convert_file_response` turns this into the crate-level `FileObject`:
/// `id`, `filename`, `bytes`, `created_at` and `purpose` are copied across,
/// while `object`, `status` and `status_details` are stored as metadata entries.
#[derive(Debug, Clone, Deserialize)]
struct OpenAiFileResponse {
/// Identifier of the file.
id: String,
/// Object type tag from the response; kept under the `"object"` metadata key.
object: String,
/// Size of the file in bytes.
bytes: u64,
/// Creation timestamp (presumably Unix seconds — confirm against the API docs).
created_at: u64,
/// Name of the file.
filename: String,
/// Purpose the file was uploaded for.
purpose: String,
/// Status reported by the API; kept under the `"status"` metadata key.
status: String,
/// Optional status detail text; kept under the `"status_details"` metadata key when present.
status_details: Option<String>,
}
/// Wire format of the response to a file-list request.
#[derive(Debug, Clone, Deserialize)]
struct OpenAiFileListResponse {
// Required for deserialization but never read by the code in this file,
// hence the `dead_code` allowance.
#[allow(dead_code)]
object: String,
/// The file objects contained in this page of results.
data: Vec<OpenAiFileResponse>,
/// Whether more results exist; `list_files` treats a missing value as `false`.
has_more: Option<bool>,
}
/// Wire format of the response to a file-delete request.
#[derive(Debug, Clone, Deserialize)]
struct OpenAiFileDeleteResponse {
/// Identifier of the file the response refers to.
id: String,
// Required for deserialization but never read by the code in this file,
// hence the `dead_code` allowance.
#[allow(dead_code)]
object: String,
/// Deletion flag reported by the API; forwarded unchanged by `delete_file`.
deleted: bool,
}
/// Client for the OpenAI Files API.
///
/// Implements `FileManagementCapability` (upload, list, retrieve, delete and
/// content download) on top of a shared `reqwest::Client`.
#[derive(Debug, Clone)]
pub struct OpenAiFiles {
/// Provider configuration; supplies the base URL and the request headers.
config: OpenAiConfig,
/// HTTP client used to issue every request.
http_client: reqwest::Client,
}
impl OpenAiFiles {
/// Creates a Files API client from a provider configuration and an HTTP client.
///
/// Both arguments are stored as given; no request is issued here.
pub const fn new(config: OpenAiConfig, http_client: reqwest::Client) -> Self {
    Self { config, http_client }
}
/// Returns the upload purposes accepted by this client's request validation.
///
/// A fresh `Vec` is allocated on every call.
pub fn get_supported_purposes(&self) -> Vec<String> {
    const PURPOSES: [&str; 4] = ["assistants", "batch", "fine-tune", "vision"];

    PURPOSES
        .iter()
        .map(|purpose| (*purpose).to_owned())
        .collect()
}
/// Returns the largest upload size, in bytes, accepted by request validation
/// (512 MiB = 536,870,912 bytes).
pub const fn get_max_file_size(&self) -> u64 {
    const MIB: u64 = 1024 * 1024;
    512 * MIB
}
/// Returns the lower-case file extensions accepted by request validation.
///
/// A fresh `Vec` is allocated on every call.
pub fn get_supported_formats(&self) -> Vec<String> {
    const FORMATS: [&str; 19] = [
        // Text and structured data
        "txt", "json", "jsonl", "csv", "tsv",
        // Documents
        "pdf", "docx",
        // Images
        "png", "jpg", "jpeg", "gif", "webp",
        // Audio and video
        "mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm",
    ];

    FORMATS
        .iter()
        .map(|format| (*format).to_owned())
        .collect()
}
/// Checks an upload request against this client's local limits before any
/// network traffic happens.
///
/// Checks run in this order, and the first failure is returned:
/// 1. content length must not exceed [`Self::get_max_file_size`];
/// 2. `purpose` must be one of [`Self::get_supported_purposes`];
/// 3. `filename` must not be empty;
/// 4. if the filename contains a `.`, the text after the last `.` (compared
///    case-insensitively) must be one of [`Self::get_supported_formats`].
///
/// A filename without any `.` has no extension and skips check 4, so the
/// whole name is no longer mistaken for an extension.
///
/// # Errors
///
/// Returns `LlmError::InvalidInput` describing the first failed check.
fn validate_upload_request(&self, request: &FileUploadRequest) -> Result<(), LlmError> {
    let max_size = self.get_max_file_size();
    let size = request.content.len();
    // `usize` -> `u64` is a lossless widening on every supported target.
    if size as u64 > max_size {
        return Err(LlmError::InvalidInput(format!(
            "File size {size} bytes exceeds maximum allowed size of {max_size} bytes"
        )));
    }

    // Build the list once; it is needed both for the check and the message.
    let supported_purposes = self.get_supported_purposes();
    if !supported_purposes.contains(&request.purpose) {
        return Err(LlmError::InvalidInput(format!(
            "Unsupported file purpose: {}. Supported purposes: {:?}",
            request.purpose, supported_purposes
        )));
    }

    if request.filename.is_empty() {
        return Err(LlmError::InvalidInput(
            "Filename cannot be empty".to_string(),
        ));
    }

    // `rsplit_once` yields `None` when there is no '.', unlike
    // `split('.').next_back()`, which would hand back the entire filename.
    if let Some((_, extension)) = request.filename.rsplit_once('.') {
        let supported_formats = self.get_supported_formats();
        if !supported_formats.contains(&extension.to_lowercase()) {
            return Err(LlmError::InvalidInput(format!(
                "Unsupported file format: {extension}. Supported formats: {supported_formats:?}"
            )));
        }
    }

    Ok(())
}
fn convert_file_response(&self, openai_file: OpenAiFileResponse) -> FileObject {
let mut metadata = HashMap::new();
metadata.insert(
"object".to_string(),
serde_json::Value::String(openai_file.object),
);
metadata.insert(
"status".to_string(),
serde_json::Value::String(openai_file.status),
);
if let Some(status_details) = openai_file.status_details {
metadata.insert(
"status_details".to_string(),
serde_json::Value::String(status_details),
);
}
FileObject {
id: openai_file.id,
filename: openai_file.filename,
bytes: openai_file.bytes,
created_at: openai_file.created_at,
purpose: openai_file.purpose,
status: "uploaded".to_string(), mime_type: None, metadata,
}
}
/// Prepares a request builder for `{base_url}/{endpoint}` carrying every
/// header returned by the configuration's `get_headers()`.
///
/// The request is only built, not sent. The function is `async` but awaits
/// nothing itself; callers still `.await` it.
///
/// # Errors
///
/// Returns `LlmError::HttpError` if a configured header name or value is not
/// valid for an HTTP header.
async fn make_request(
    &self,
    method: reqwest::Method,
    endpoint: &str,
) -> Result<reqwest::RequestBuilder, LlmError> {
    use reqwest::header::{HeaderMap, HeaderName, HeaderValue};

    let mut header_map = HeaderMap::new();
    for (name, raw_value) in self.config.get_headers() {
        let parsed_name = HeaderName::from_bytes(name.as_bytes())
            .map_err(|err| LlmError::HttpError(format!("Invalid header name: {err}")))?;
        let parsed_value = HeaderValue::from_str(&raw_value)
            .map_err(|err| LlmError::HttpError(format!("Invalid header value: {err}")))?;
        // `insert` replaces any earlier value stored under the same name.
        header_map.insert(parsed_name, parsed_value);
    }

    let target = format!("{}/{}", self.config.base_url, endpoint);
    let builder = self.http_client.request(method, &target).headers(header_map);
    Ok(builder)
}
/// Turns a non-success HTTP response into an `LlmError`.
///
/// The response body is read as text; if that fails, `"Unknown error"` is used
/// in its place. Mapping by status code:
/// - 404 -> `LlmError::NotFound`, including the body text;
/// - 413 -> `LlmError::InvalidInput("File too large")` (body discarded);
/// - 415 -> `LlmError::InvalidInput("Unsupported file type")` (body discarded);
/// - anything else -> `LlmError::ApiError` with the numeric code and a message
///   containing the status and body text.
async fn handle_response_error(&self, response: reqwest::Response) -> LlmError {
    // Capture the status first: reading the body consumes the response.
    let status = response.status();
    let code = status.as_u16();
    let body = match response.text().await {
        Ok(text) => text,
        Err(_) => "Unknown error".to_string(),
    };

    match code {
        413 => LlmError::InvalidInput("File too large".to_string()),
        415 => LlmError::InvalidInput("Unsupported file type".to_string()),
        404 => LlmError::NotFound(format!("File not found: {body}")),
        _ => LlmError::ApiError {
            code,
            message: format!("OpenAI Files API error {status}: {body}"),
            details: None,
        },
    }
}
}
#[async_trait]
impl FileManagementCapability for OpenAiFiles {
/// Uploads a file via a multipart `POST` to the `files` endpoint.
///
/// The request is validated locally first. The multipart form carries a
/// `purpose` text field and a `file` part; when the request has no MIME type,
/// `application/octet-stream` is used.
///
/// # Errors
///
/// - `LlmError::InvalidInput` if local validation fails;
/// - `LlmError::HttpError` if the MIME type is malformed, a configured header
///   is invalid, or the request cannot be sent;
/// - the error produced by `handle_response_error` for a non-success status;
/// - `LlmError::ParseError` if the response body is not the expected JSON.
async fn upload_file(&self, request: FileUploadRequest) -> Result<FileObject, LlmError> {
    self.validate_upload_request(&request)?;

    let mime_type = request
        .mime_type
        .as_deref()
        .unwrap_or("application/octet-stream");
    // `request` is owned, so its fields are moved into the form rather than
    // cloned; `mime_type` above borrows a different field.
    let file_part = reqwest::multipart::Part::bytes(request.content)
        .file_name(request.filename)
        .mime_str(mime_type)
        .map_err(|e| LlmError::HttpError(format!("Invalid MIME type: {e}")))?;
    let form = reqwest::multipart::Form::new()
        .text("purpose", request.purpose)
        .part("file", file_part);

    let request_builder = self.make_request(reqwest::Method::POST, "files").await?;
    let response = request_builder
        .multipart(form)
        .send()
        .await
        .map_err(|e| LlmError::HttpError(format!("Request failed: {e}")))?;
    if !response.status().is_success() {
        return Err(self.handle_response_error(response).await);
    }

    let openai_response: OpenAiFileResponse = response
        .json()
        .await
        .map_err(|e| LlmError::ParseError(format!("Failed to parse response: {e}")))?;
    Ok(self.convert_file_response(openai_response))
}
/// Lists files via `GET files`, optionally filtered and paginated.
///
/// When `query` is given, its `purpose`, `limit`, `after` and `order` values
/// (those that are set) are appended as a query string; string values are
/// percent-encoded. In the result, `has_more` defaults to `false` when the
/// API omits it, and `next_cursor` is always `None`.
///
/// # Errors
///
/// - `LlmError::HttpError` if a configured header is invalid or the request
///   cannot be sent;
/// - the error produced by `handle_response_error` for a non-success status;
/// - `LlmError::ParseError` if the response body is not the expected JSON.
async fn list_files(&self, query: Option<FileListQuery>) -> Result<FileListResponse, LlmError> {
    let mut endpoint = "files".to_string();
    if let Some(q) = query {
        let mut params = Vec::new();
        if let Some(purpose) = q.purpose {
            params.push(format!("purpose={}", urlencoding::encode(&purpose)));
        }
        if let Some(limit) = q.limit {
            params.push(format!("limit={limit}"));
        }
        if let Some(after) = q.after {
            params.push(format!("after={}", urlencoding::encode(&after)));
        }
        if let Some(order) = q.order {
            params.push(format!("order={}", urlencoding::encode(&order)));
        }
        if !params.is_empty() {
            endpoint.push('?');
            endpoint.push_str(&params.join("&"));
        }
    }

    let request_builder = self.make_request(reqwest::Method::GET, &endpoint).await?;
    let response = request_builder
        .send()
        .await
        .map_err(|e| LlmError::HttpError(format!("Request failed: {e}")))?;
    if !response.status().is_success() {
        return Err(self.handle_response_error(response).await);
    }

    let openai_response: OpenAiFileListResponse = response
        .json()
        .await
        .map_err(|e| LlmError::ParseError(format!("Failed to parse response: {e}")))?;
    let has_more = openai_response.has_more.unwrap_or(false);
    let files: Vec<FileObject> = openai_response
        .data
        .into_iter()
        .map(|f| self.convert_file_response(f))
        .collect();
    Ok(FileListResponse {
        files,
        has_more,
        next_cursor: None,
    })
}
async fn retrieve_file(&self, file_id: String) -> Result<FileObject, LlmError> {
let endpoint = format!("files/{file_id}");
let request_builder = self.make_request(reqwest::Method::GET, &endpoint).await?;
let response = request_builder
.send()
.await
.map_err(|e| LlmError::HttpError(format!("Request failed: {e}")))?;
if !response.status().is_success() {
return Err(self.handle_response_error(response).await);
}
let openai_response: OpenAiFileResponse = response
.json()
.await
.map_err(|e| LlmError::ParseError(format!("Failed to parse response: {e}")))?;
Ok(self.convert_file_response(openai_response))
}
async fn delete_file(&self, file_id: String) -> Result<FileDeleteResponse, LlmError> {
let endpoint = format!("files/{file_id}");
let request_builder = self
.make_request(reqwest::Method::DELETE, &endpoint)
.await?;
let response = request_builder
.send()
.await
.map_err(|e| LlmError::HttpError(format!("Request failed: {e}")))?;
if !response.status().is_success() {
return Err(self.handle_response_error(response).await);
}
let openai_response: OpenAiFileDeleteResponse = response
.json()
.await
.map_err(|e| LlmError::ParseError(format!("Failed to parse response: {e}")))?;
Ok(FileDeleteResponse {
id: openai_response.id,
deleted: openai_response.deleted,
})
}
async fn get_file_content(&self, file_id: String) -> Result<Vec<u8>, LlmError> {
let endpoint = format!("files/{file_id}/content");
let request_builder = self.make_request(reqwest::Method::GET, &endpoint).await?;
let response = request_builder
.send()
.await
.map_err(|e| LlmError::HttpError(format!("Request failed: {e}")))?;
if !response.status().is_success() {
return Err(self.handle_response_error(response).await);
}
let content = response
.bytes()
.await
.map_err(|e| LlmError::HttpError(format!("Failed to read response body: {e}")))?;
Ok(content.to_vec())
}
}