use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_with::skip_serializing_none;
use crate::{
api_resources::{Delete, Files, TokenUsage},
Client, Result,
};
/// Request body for creating a fine-tune job.
///
/// Passed to [`create`], which sends it to the `fine-tunes` endpoint.
/// `#[skip_serializing_none]` leaves every `Option` field that is `None` out of
/// the serialized output, so only explicitly set options are sent.
///
/// All fields are private; values are set through the generated
/// `CreateFineTuneParamBuilder`. Its setters accept `impl Into<T>` and take the
/// inner value of `Option` fields directly (`setter(into, strip_option)`), and
/// any field left unset falls back to its `Default` value (`builder(default)`).
#[skip_serializing_none]
#[derive(Builder, Debug, Default, Deserialize, Serialize)]
#[builder(default, setter(into, strip_option))]
pub struct CreateFineTuneParam {
/// Training file to fine-tune on; the only field set by
/// `CreateFineTuneParamBuilder::new`.
/// NOTE(review): presumably the ID of an uploaded file — confirm against the API reference.
training_file: String,
/// Optional validation file (presumably also an uploaded file ID — confirm).
validation_file: Option<String>,
/// Optional name of the base model to fine-tune.
model: Option<String>,
// Optional training hyperparameters; each is omitted from the request when unset.
n_epochs: Option<i32>,
batch_size: Option<i32>,
learning_rate_multiplier: Option<f32>,
prompt_loss_weight: Option<f32>,
// Optional classification-metric settings; each is omitted from the request when unset.
compute_classification_metrics: Option<bool>,
classification_n_classes: Option<i32>,
classification_positive_class: Option<String>,
classification_betas: Option<Vec<f32>>,
/// Optional suffix.
/// NOTE(review): presumably appended to the fine-tuned model's name — confirm.
suffix: Option<String>,
}
impl CreateFineTuneParamBuilder {
    /// Starts a builder with `training_file` already filled in.
    ///
    /// Every other field is left unset, exactly as in
    /// `CreateFineTuneParamBuilder::default()`.
    pub fn new(training_file: impl Into<String>) -> Self {
        let mut builder = Self::default();
        // The generated setter performs the `Into<String>` conversion and
        // stores the value as `Some(..)`.
        builder.training_file(training_file);
        builder
    }
}
/// A fine-tune job as returned by the `fine-tunes` endpoints.
///
/// `#[serde(default)]` makes deserialization tolerant of partial responses:
/// any key missing from the payload takes the field's `Default` value.
#[derive(Debug, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct FineTune {
    /// Identifier of the fine-tune job.
    pub id: String,
    /// Object type tag reported by the API (e.g. `"fine-tune"`).
    pub object: String,
    /// Name of the base model the job was started from.
    pub model: String,
    /// Creation time.
    /// NOTE(review): looks like a Unix timestamp in seconds — confirm.
    pub created_at: u64,
    /// Events recorded for the job so far.
    pub events: Events,
    /// Name of the model produced by the job.
    ///
    /// `None` when the key is absent or `null` in the response. Responses carry
    /// this key, and it is the value a caller needs in order to address the
    /// resulting model afterwards (for example when passing it to [`delete`]).
    pub fine_tuned_model: Option<String>,
    /// Hyperparameters the job runs with.
    pub hyperparams: HyperParams,
    /// Organization that owns the job.
    pub organization_id: String,
    /// Files produced by the job.
    pub result_files: Files,
    /// Files used for validation.
    pub validation_files: Files,
    /// Files used for training.
    pub training_files: Files,
    /// Current job status as reported by the API (e.g. `"pending"`).
    pub status: String,
    /// Time of the last update.
    /// NOTE(review): looks like a Unix timestamp in seconds — confirm.
    pub updated_at: u64,
    /// Token usage attached to the response, when present.
    pub token_usage: Option<TokenUsage>,
}
/// Hyperparameters reported for a fine-tune job (the `hyperparams` object of
/// [`FineTune`]).
///
/// Because of `#[serde(default)]`, keys missing from the response deserialize
/// to the type's default (`0`, `0.0`, `false`, empty string, empty vector), so
/// a default value here cannot be told apart from an omitted key.
#[derive(Debug, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct HyperParams {
pub n_epochs: u32,
pub batch_size: u32,
pub learning_rate_multiplier: f32,
pub prompt_loss_weight: f32,
// Classification-metric settings.
pub compute_classification_metrics: bool,
pub classification_n_classes: u32,
pub classification_positive_class: String,
pub classification_betas: Vec<f32>,
}
/// A single event recorded for a fine-tune job.
///
/// Keys missing from the payload fall back to their `Default` value
/// (`#[serde(default)]`).
#[derive(Debug, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct Event {
/// Object type tag reported by the API (e.g. `"fine-tune-event"`).
pub object: String,
/// Time the event was recorded.
/// NOTE(review): looks like a Unix timestamp in seconds — confirm.
pub created_at: u64,
/// Severity level of the event (e.g. `"info"`).
pub level: String,
/// Human-readable description of the event.
pub message: String,
}
/// Shorthand for the list of [`Event`]s embedded in a [`FineTune`].
type Events = Vec<Event>;
/// Response returned by [`list_events`]: the events of one fine-tune job.
///
/// Keys missing from the payload fall back to their `Default` value
/// (`#[serde(default)]`).
#[derive(Debug, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct ListEvents {
/// Object type tag reported by the API (e.g. `"list"`).
pub object: String,
/// The events, in the order the API returned them.
pub data: Vec<Event>,
/// Token usage attached to the response, when present.
pub token_usage: Option<TokenUsage>,
}
/// Response returned by [`list`]: a collection of fine-tune jobs.
///
/// Keys missing from the payload fall back to their `Default` value
/// (`#[serde(default)]`).
#[derive(Debug, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct ListFineTune {
/// Object type tag reported by the API.
pub object: String,
/// The fine-tune jobs, in the order the API returned them.
pub data: Vec<FineTune>,
/// Token usage attached to the response, when present.
pub token_usage: Option<TokenUsage>,
}
/// Creates a fine-tune job described by `param`.
///
/// Sends `param` to the `fine-tunes` endpoint through `client` and returns the
/// job deserialized as a [`FineTune`].
///
/// # Errors
///
/// Propagates any error produced by the underlying client request.
pub async fn create(client: &Client, param: &CreateFineTuneParam) -> Result<FineTune> {
    let request = client.create_fine_tune(param);
    request.await
}
/// Lists fine-tune jobs.
///
/// Queries the `fine-tunes` endpoint through `client` and returns the
/// response deserialized as a [`ListFineTune`].
///
/// # Errors
///
/// Propagates any error produced by the underlying client request.
pub async fn list(client: &Client) -> Result<ListFineTune> {
    let request = client.list_fine_tune();
    request.await
}
/// Fetches the fine-tune job identified by `fine_tune_id`.
///
/// Queries `fine-tunes/{fine_tune_id}` through `client`.
///
/// # Errors
///
/// Propagates any error produced by the underlying client request.
pub async fn retrieve(client: &Client, fine_tune_id: impl Into<String>) -> Result<FineTune> {
    let job_id: String = fine_tune_id.into();
    client.retrieve_fine_tune(job_id).await
}
/// Cancels the fine-tune job identified by `fine_tune_id`.
///
/// Posts to `fine-tunes/{fine_tune_id}/cancel` through `client` and returns
/// the job as reported in the response.
///
/// # Errors
///
/// Propagates any error produced by the underlying client request.
pub async fn cancel(client: &Client, fine_tune_id: impl Into<String>) -> Result<FineTune> {
    let job_id: String = fine_tune_id.into();
    client.cancel_fine_tune(job_id).await
}
/// Lists the events recorded for the fine-tune job `fine_tune_id`.
///
/// Queries `fine-tunes/{fine_tune_id}/events` through `client` and returns the
/// response deserialized as a [`ListEvents`].
///
/// # Errors
///
/// Propagates any error produced by the underlying client request.
pub async fn list_events(client: &Client, fine_tune_id: impl Into<String>) -> Result<ListEvents> {
    let job_id: String = fine_tune_id.into();
    client.list_fine_tune_events(job_id).await
}
/// Requests the events of the fine-tune job `fine_tune_id` in streaming mode.
///
/// Queries `fine-tunes/{fine_tune_id}/events` with `stream` set to `true` and
/// hands back the raw [`reqwest::Response`]; reading the body is left to the
/// caller.
///
/// # Errors
///
/// Propagates any error produced by the underlying client request.
pub async fn list_events_with_stream(
    client: &Client,
    fine_tune_id: impl Into<String>,
) -> Result<reqwest::Response> {
    let job_id: String = fine_tune_id.into();
    client.list_fine_tune_events_with_stream(job_id).await
}
/// Deletes the model named `model`.
///
/// Issues a delete request for `models/{model}` through `client` and returns
/// the response deserialized as a [`Delete`].
///
/// # Errors
///
/// Propagates any error produced by the underlying client request.
pub async fn delete<T: Into<String>>(client: &Client, model: T) -> Result<Delete> {
    let model_name: String = model.into();
    client.delete_fine_tune(model_name).await
}
/// Request helpers backing the public fine-tune functions of this module.
impl Client {
    /// Posts `param` to `fine-tunes`.
    async fn create_fine_tune(&self, param: &CreateFineTuneParam) -> Result<FineTune> {
        let body = Some(param);
        self.post::<CreateFineTuneParam, FineTune>("fine-tunes", body)
            .await
    }

    /// Gets `fine-tunes` without any parameters.
    async fn list_fine_tune(&self) -> Result<ListFineTune> {
        let path = "fine-tunes";
        self.get::<(), ListFineTune>(path, None).await
    }

    /// Gets `fine-tunes/{fine_tune_id}`.
    async fn retrieve_fine_tune(&self, fine_tune_id: String) -> Result<FineTune> {
        let path = format!("fine-tunes/{fine_tune_id}");
        self.get::<(), FineTune>(&path, None).await
    }

    /// Posts to `fine-tunes/{fine_tune_id}/cancel` without a body.
    async fn cancel_fine_tune(&self, fine_tune_id: String) -> Result<FineTune> {
        let path = format!("fine-tunes/{fine_tune_id}/cancel");
        self.post::<(), FineTune>(&path, None).await
    }

    /// Gets `fine-tunes/{fine_tune_id}/events` without any parameters.
    async fn list_fine_tune_events(&self, fine_tune_id: String) -> Result<ListEvents> {
        let path = format!("fine-tunes/{fine_tune_id}/events");
        self.get::<(), ListEvents>(&path, None).await
    }

    /// Gets `fine-tunes/{fine_tune_id}/events` with `{"stream": true}` and
    /// returns the raw response instead of a deserialized body.
    async fn list_fine_tune_events_with_stream(
        &self,
        fine_tune_id: String,
    ) -> Result<reqwest::Response> {
        let path = format!("fine-tunes/{fine_tune_id}/events");
        let query = json!({"stream": true});
        self.get_stream::<serde_json::Value>(&path, Some(&query))
            .await
    }

    /// Sends a delete request for `models/{model}`.
    async fn delete_fine_tune(&self, model: String) -> Result<Delete> {
        let path = format!("models/{model}");
        self.delete::<(), Delete>(&path, None).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fine-tune job payload deserializes into [`FineTune`].
    #[test]
    fn test_create_fine_tune() {
        let payload = r#"
{
"id": "ft-AF1WoRqd3aJAHsqc9NY7iL8F",
"object": "fine-tune",
"model": "curie",
"created_at": 1614807352,
"events": [
{
"object": "fine-tune-event",
"created_at": 1614807352,
"level": "info",
"message": "Job enqueued. Waiting for jobs ahead to complete. Queue number: 0."
}
],
"fine_tuned_model": null,
"hyperparams": {
"batch_size": 4,
"learning_rate_multiplier": 0.1,
"n_epochs": 4,
"prompt_loss_weight": 0.1
},
"organization_id": "org-...",
"result_files": [],
"status": "pending",
"validation_files": [],
"training_files": [
{
"id": "file-XGinujblHPwGLSztz8cPS8XY",
"object": "file",
"bytes": 1547276,
"created_at": 1610062281,
"filename": "my-data-train.jsonl",
"purpose": "fine-tune-train"
}
],
"updated_at": 1614807352
}
"#;
        let resp = serde_json::from_str::<FineTune>(payload).unwrap();

        assert_eq!(resp.id, "ft-AF1WoRqd3aJAHsqc9NY7iL8F");
        assert_eq!(resp.object, "fine-tune");
        assert_eq!(resp.events.len(), 1);
        assert_eq!(resp.training_files[0].filename, "my-data-train.jsonl");
    }

    /// An event listing deserializes into [`ListEvents`] with every entry kept.
    #[test]
    fn test_list_fine_tune_events() {
        let payload = r#"
{
"object": "list",
"data": [
{
"object": "fine-tune-event",
"created_at": 1614807352,
"level": "info",
"message": "Job enqueued. Waiting for jobs ahead to complete. Queue number: 0."
},
{
"object": "fine-tune-event",
"created_at": 1614807356,
"level": "info",
"message": "Job started."
},
{
"object": "fine-tune-event",
"created_at": 1614807861,
"level": "info",
"message": "Uploaded snapshot: curie:ft-acmeco-2021-03-03-21-44-20."
},
{
"object": "fine-tune-event",
"created_at": 1614807864,
"level": "info",
"message": "Uploaded result files: file-QQm6ZpqdNwAaVC3aSz5sWwLT."
},
{
"object": "fine-tune-event",
"created_at": 1614807864,
"level": "info",
"message": "Job succeeded."
}
]
}
"#;
        let resp = serde_json::from_str::<ListEvents>(payload).unwrap();

        assert_eq!(resp.data.len(), 5);
        assert_eq!(resp.data[0].level, "info");
    }

    /// A model deletion payload deserializes into [`Delete`].
    #[test]
    fn test_delete_fine_tune() {
        let payload = r#"
{
"id": "curie:ft-acmeco-2021-03-03-21-44-20",
"object": "model",
"deleted": true
}
"#;
        let resp = serde_json::from_str::<Delete>(payload).unwrap();

        assert_eq!(resp.id, "curie:ft-acmeco-2021-03-03-21-44-20");
    }
}