use reqwest::{multipart, Client};
use serde::{Deserialize, Serialize};
use std::path::Path;
use thiserror::Error;
use tokio::fs::File;
use tokio_util::codec::{BytesCodec, FramedRead};
use crate::{
utils::{api_key, remove_trailing_slash, OpenAiApiKeyError},
OpenAiError,
};
/// Client for the files endpoints of the OpenAI API.
///
/// Every field is public, so a client can be assembled or adjusted by hand as well as
/// through `new`, `from_env` or the conversion from a `ChatClient`.
pub struct FilesClient {
/// API key; sent as `Authorization: Bearer <key>` with every request.
pub api_key: String,
/// Base URL that `files_path` is joined onto. `new` sets it to `https://api.openai.com/v1/`.
pub base_url: url::Url,
/// Path of the file collection relative to `base_url`. The provided constructors use `files/`.
pub files_path: String,
/// HTTP client used to send the requests.
pub http_client: Client,
}
impl From<&crate::chat_completions::ChatClient> for FilesClient {
fn from(client: &crate::chat_completions::ChatClient) -> Self {
Self {
api_key: client.api_key.clone(),
base_url: client.base_url.clone(),
files_path: "files/".to_string(),
http_client: client.http_client.clone(),
}
}
}
/// Purpose attached to an uploaded file.
///
/// Variants serialize to their snake_case name; `FineTune` is the one exception because its
/// serialized name is hyphenated.
#[derive(Serialize, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum FilePurpose {
    /// Serialized as `fine-tune`.
    #[serde(rename = "fine-tune")]
    FineTune,
    /// Serialized as `assistants`.
    Assistants,
    /// Serialized as `batch`.
    Batch,
    /// Serialized as `user_data`.
    UserData,
    /// Serialized as `vision`.
    Vision,
    /// Serialized as `evals`.
    Evals,
}
impl std::fmt::Debug for FilePurpose {
    /// Writes the lowercase name of the purpose (for example `fine-tune`) rather than the
    /// Rust variant name. Width, padding and the alternate flag are ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            FilePurpose::FineTune => "fine-tune",
            FilePurpose::Assistants => "assistants",
            FilePurpose::Batch => "batch",
            FilePurpose::UserData => "user_data",
            FilePurpose::Vision => "vision",
            FilePurpose::Evals => "evals",
        };
        f.write_str(name)
    }
}
impl std::fmt::Display for FilePurpose {
    /// Produces the same text as the `Debug` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(self, f)
    }
}
/// A file record as returned by the files endpoints.
///
/// Deserialized from the JSON response, so each field name is the key expected in the payload
/// and every field is required.
#[derive(Debug, Deserialize)]
pub struct FileObject {
/// Identifier of the file. Presumably the value the other endpoints take as `file_id`;
/// confirm against the API reference.
pub id: String,
/// Value of the payload's `object` key.
pub object: String,
/// Value of the payload's `bytes` key. Presumably the file size in bytes; confirm against
/// the API reference.
pub bytes: u64,
/// Value of the payload's `created_at` key. Presumably a Unix timestamp; confirm against
/// the API reference.
pub created_at: u64,
/// Name of the file as reported by the API.
pub filename: String,
/// Purpose of the file as reported by the API; kept as a plain string, not a `FilePurpose`.
pub purpose: String,
}
/// Response body that carries either an API error or a file record.
///
/// A JSON object of the form `{"error": ...}` deserializes into `Error`. Any other body is
/// tried as a bare `FileObject`, because the `File` variant is marked `untagged`.
#[derive(Debug, Deserialize)]
enum UploadFileResponse {
/// Error payload found under the `error` key.
#[serde(rename = "error")]
Error(OpenAiError),
/// File record found at the top level of the body.
#[serde(untagged)]
File(FileObject),
}
/// Response of the list endpoint.
///
/// Deserialized from JSON; both keys are required.
#[derive(Debug, Deserialize)]
pub struct FileList {
/// File records contained in the payload's `data` array.
pub data: Vec<FileObject>,
/// Value of the payload's `object` key.
pub object: String,
}
/// Errors returned by `FilesClient`.
#[derive(Error, Debug)]
pub enum FilesError {
/// Any error reported by `reqwest` while sending a request or reading its response.
#[error("Request error: {0}")]
RequestError(#[from] reqwest::Error),
/// The response body could not be deserialized into the expected shape.
#[error("API {url} returned an unknown response: {response}")]
ApiParseError {
/// URL the request was sent to.
url: String,
/// Raw response body that failed to deserialize.
response: String,
/// Underlying deserialization error.
#[source]
error: serde_json::Error,
},
/// The API answered with an error payload; the payload is available as the error source.
#[error("API returned an error response")]
ApiError(#[from] OpenAiError),
/// An I/O error, for example while opening the file to upload.
#[error("File error: {0}")]
IoError(#[from] std::io::Error),
/// The path given for upload has no final component, or that component is not valid UTF-8.
#[error("Invalid file path")]
InvalidFilePath,
}
impl FilesClient {
pub fn new(api_key: impl Into<String>) -> Self {
Self {
api_key: api_key.into(),
base_url: url::Url::parse("https://api.openai.com/v1/").unwrap(),
files_path: "files/".to_string(),
http_client: crate::utils::pooled_client(),
}
}
fn files_url(&self) -> url::Url {
self.base_url.join(&self.files_path).unwrap()
}
pub fn from_env() -> Result<Self, OpenAiApiKeyError> {
Ok(Self::new(api_key()?))
}
pub async fn upload_file(
&self,
file_path: impl AsRef<Path>,
purpose: FilePurpose,
) -> Result<FileObject, FilesError> {
let file_path = file_path.as_ref();
let file_name = file_path
.file_name()
.and_then(|name| name.to_str())
.ok_or(FilesError::InvalidFilePath)?;
let file = File::open(file_path).await?;
let stream = FramedRead::new(file, BytesCodec::new());
let file_part = multipart::Part::stream(reqwest::Body::wrap_stream(stream))
.file_name(file_name.to_string());
let form = multipart::Form::new()
.text("purpose", format!("{:?}", purpose).to_lowercase())
.part("file", file_part);
let url = remove_trailing_slash(self.files_url());
let response = self
.http_client
.post(url.clone())
.header("Authorization", format!("Bearer {}", self.api_key))
.multipart(form)
.send()
.await?;
let response_text = response.text().await?;
let file_object: UploadFileResponse =
serde_json::from_str(&response_text).map_err(|e| FilesError::ApiParseError {
url: url.to_string(),
response: response_text.clone(),
error: e,
})?;
match file_object {
UploadFileResponse::File(file) => Ok(file),
UploadFileResponse::Error(error) => Err(FilesError::ApiError(error)),
}
}
pub async fn upload_bytes(
&self,
filename: &str,
bytes: Vec<u8>,
purpose: FilePurpose,
) -> Result<FileObject, FilesError> {
let file_part = multipart::Part::bytes(bytes).file_name(filename.to_string());
let form = multipart::Form::new()
.text("purpose", format!("{:?}", purpose).to_lowercase())
.part("file", file_part);
let url = remove_trailing_slash(self.files_url());
let response = self
.http_client
.post(url.clone())
.header("Authorization", format!("Bearer {}", self.api_key))
.multipart(form)
.send()
.await?;
let response_text = response.text().await?;
let file_object: UploadFileResponse =
serde_json::from_str(&response_text).map_err(|e| FilesError::ApiParseError {
url: url.to_string(),
response: response_text.clone(),
error: e,
})?;
match file_object {
UploadFileResponse::File(file) => Ok(file),
UploadFileResponse::Error(error) => Err(FilesError::ApiError(error)),
}
}
pub async fn list_files(&self) -> Result<FileList, FilesError> {
let response = self
.http_client
.get(self.files_url())
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await?;
let file_list = response.json::<FileList>().await?;
Ok(file_list)
}
pub async fn retrieve_file(&self, file_id: &str) -> Result<FileObject, FilesError> {
let response = self
.http_client
.get(self.files_url().join(file_id).unwrap())
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await?;
let file_object = response.json::<FileObject>().await?;
Ok(file_object)
}
pub async fn delete_file(&self, file_id: &str) -> Result<DeletedFile, FilesError> {
let response = self
.http_client
.delete(self.files_url().join(file_id).unwrap())
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await?;
let deleted_file = response.json::<DeletedFile>().await?;
Ok(deleted_file)
}
pub async fn download_file(&self, file_id: &str) -> Result<String, FilesError> {
let url = self
.files_url()
.join(&format!("{file_id}/content"))
.unwrap();
let response = self
.http_client
.get(url)
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await?;
let content = response.text().await?;
Ok(content)
}
}
/// Response of the delete endpoint.
///
/// Deserialized from JSON; all three keys are required.
#[derive(Debug, Deserialize)]
pub struct DeletedFile {
/// Value of the payload's `id` key.
pub id: String,
/// Value of the payload's `object` key.
pub object: String,
/// Value of the payload's `deleted` key.
pub deleted: bool,
}