use crate::common::auth::AuthProvider;
use crate::common::client::create_http_client;
use crate::common::errors::{OpenAIToolError, Result};
use crate::common::models::FineTuningModel;
use crate::fine_tuning::response::{
DpoConfig, FineTuningCheckpointListResponse, FineTuningEventListResponse, FineTuningJob, FineTuningJobListResponse, Hyperparameters, Integration,
MethodConfig, SupervisedConfig,
};
use serde::Serialize;
use std::time::Duration;
/// Path handed to `AuthProvider::endpoint` to build the fine-tuning jobs collection URL;
/// per-job URLs append `/{job_id}` (and optionally a sub-resource) to that result.
const FINE_TUNING_PATH: &str = "fine_tuning/jobs";
/// JSON body for creating a fine-tuning job (`POST fine_tuning/jobs`).
///
/// Every `Option` field is omitted from the serialized JSON when it is `None`.
/// Build instances with [`CreateFineTuningJobRequest::new`] and the `with_*` methods.
#[derive(Debug, Clone, Serialize)]
pub struct CreateFineTuningJobRequest {
/// Model to fine-tune.
pub model: FineTuningModel,
/// Training file reference (presumably an uploaded file ID — confirm against the API docs).
pub training_file: String,
/// Optional validation file ID.
#[serde(skip_serializing_if = "Option::is_none")]
pub validation_file: Option<String>,
/// Optional suffix sent as `suffix`.
#[serde(skip_serializing_if = "Option::is_none")]
pub suffix: Option<String>,
/// Optional seed sent as `seed`.
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<u64>,
/// Training method configuration; set by `with_supervised_method` / `with_dpo_method`.
#[serde(skip_serializing_if = "Option::is_none")]
pub method: Option<MethodConfig>,
/// Optional list of integrations attached to the job.
#[serde(skip_serializing_if = "Option::is_none")]
pub integrations: Option<Vec<Integration>>,
}
impl CreateFineTuningJobRequest {
    /// Creates a request for `model` trained on `training_file`, with every optional field unset.
    #[must_use]
    pub fn new(model: FineTuningModel, training_file: impl Into<String>) -> Self {
        Self { model, training_file: training_file.into(), validation_file: None, suffix: None, seed: None, method: None, integrations: None }
    }

    /// Sets the validation file ID.
    // `#[must_use]` on every by-value builder: discarding the returned value silently
    // drops the configuration (e.g. `req.with_seed(1);` without rebinding `req`).
    #[must_use]
    pub fn with_validation_file(mut self, file_id: impl Into<String>) -> Self {
        self.validation_file = Some(file_id.into());
        self
    }

    /// Sets the `suffix` field.
    #[must_use]
    pub fn with_suffix(mut self, suffix: impl Into<String>) -> Self {
        self.suffix = Some(suffix.into());
        self
    }

    /// Sets the `seed` field.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Selects the `"supervised"` method with optional hyperparameters.
    ///
    /// Replaces any previously configured method (including one set by
    /// [`with_dpo_method`](Self::with_dpo_method)).
    #[must_use]
    pub fn with_supervised_method(mut self, hyperparameters: Option<Hyperparameters>) -> Self {
        self.method = Some(MethodConfig { method_type: "supervised".to_string(), supervised: Some(SupervisedConfig { hyperparameters }), dpo: None });
        self
    }

    /// Selects the `"dpo"` method with optional hyperparameters.
    ///
    /// Replaces any previously configured method (including one set by
    /// [`with_supervised_method`](Self::with_supervised_method)).
    #[must_use]
    pub fn with_dpo_method(mut self, hyperparameters: Option<Hyperparameters>) -> Self {
        self.method = Some(MethodConfig { method_type: "dpo".to_string(), supervised: None, dpo: Some(DpoConfig { hyperparameters }) });
        self
    }

    /// Sets the integrations list, replacing any previously set list.
    #[must_use]
    pub fn with_integrations(mut self, integrations: Vec<Integration>) -> Self {
        self.integrations = Some(integrations);
        self
    }
}
/// Client for the fine-tuning jobs API.
///
/// Holds the authentication provider used to build endpoint URLs and headers,
/// plus an optional timeout handed to the HTTP client each time a request is made.
pub struct FineTuning {
    /// Timeout passed to `create_http_client`; `None` unless set via `timeout`.
    timeout: Option<Duration>,
    /// Provider that supplies endpoint URLs and authentication headers.
    auth: AuthProvider,
}
impl FineTuning {
pub fn new() -> Result<Self> {
let auth = AuthProvider::openai_from_env()?;
Ok(Self { auth, timeout: None })
}
pub fn with_auth(auth: AuthProvider) -> Self {
Self { auth, timeout: None }
}
pub fn azure() -> Result<Self> {
let auth = AuthProvider::azure_from_env()?;
Ok(Self { auth, timeout: None })
}
pub fn detect_provider() -> Result<Self> {
let auth = AuthProvider::from_env()?;
Ok(Self { auth, timeout: None })
}
pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
let auth = AuthProvider::from_url_with_key(base_url, api_key);
Self { auth, timeout: None }
}
pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
let auth = AuthProvider::from_url(url)?;
Ok(Self { auth, timeout: None })
}
pub fn auth(&self) -> &AuthProvider {
&self.auth
}
pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
self.timeout = Some(timeout);
self
}
fn create_client(&self) -> Result<(request::Client, request::header::HeaderMap)> {
let client = create_http_client(self.timeout)?;
let mut headers = request::header::HeaderMap::new();
self.auth.apply_headers(&mut headers)?;
headers.insert("Content-Type", request::header::HeaderValue::from_static("application/json"));
headers.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust"));
Ok((client, headers))
}
pub async fn create(&self, request: CreateFineTuningJobRequest) -> Result<FineTuningJob> {
let (client, headers) = self.create_client()?;
let body = serde_json::to_string(&request).map_err(OpenAIToolError::SerdeJsonError)?;
let url = self.auth.endpoint(FINE_TUNING_PATH);
let response = client.post(&url).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
serde_json::from_str::<FineTuningJob>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn retrieve(&self, job_id: &str) -> Result<FineTuningJob> {
let (client, headers) = self.create_client()?;
let url = format!("{}/{}", self.auth.endpoint(FINE_TUNING_PATH), job_id);
let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
serde_json::from_str::<FineTuningJob>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn cancel(&self, job_id: &str) -> Result<FineTuningJob> {
let (client, headers) = self.create_client()?;
let url = format!("{}/{}/cancel", self.auth.endpoint(FINE_TUNING_PATH), job_id);
let response = client.post(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
serde_json::from_str::<FineTuningJob>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn list(&self, limit: Option<u32>, after: Option<&str>) -> Result<FineTuningJobListResponse> {
let (client, headers) = self.create_client()?;
let mut url = self.auth.endpoint(FINE_TUNING_PATH);
let mut params = Vec::new();
if let Some(l) = limit {
params.push(format!("limit={}", l));
}
if let Some(a) = after {
params.push(format!("after={}", a));
}
if !params.is_empty() {
url.push('?');
url.push_str(¶ms.join("&"));
}
let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
serde_json::from_str::<FineTuningJobListResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn list_events(&self, job_id: &str, limit: Option<u32>, after: Option<&str>) -> Result<FineTuningEventListResponse> {
let (client, headers) = self.create_client()?;
let mut url = format!("{}/{}/events", self.auth.endpoint(FINE_TUNING_PATH), job_id);
let mut params = Vec::new();
if let Some(l) = limit {
params.push(format!("limit={}", l));
}
if let Some(a) = after {
params.push(format!("after={}", a));
}
if !params.is_empty() {
url.push('?');
url.push_str(¶ms.join("&"));
}
let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
serde_json::from_str::<FineTuningEventListResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn list_checkpoints(&self, job_id: &str, limit: Option<u32>, after: Option<&str>) -> Result<FineTuningCheckpointListResponse> {
let (client, headers) = self.create_client()?;
let mut url = format!("{}/{}/checkpoints", self.auth.endpoint(FINE_TUNING_PATH), job_id);
let mut params = Vec::new();
if let Some(l) = limit {
params.push(format!("limit={}", l));
}
if let Some(a) = after {
params.push(format!("after={}", a));
}
if !params.is_empty() {
url.push('?');
url.push_str(¶ms.join("&"));
}
let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
serde_json::from_str::<FineTuningCheckpointListResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
}