// openai_api_wrapper/fine_tuning.rs

use crate::client::OpenAIRequest;
use reqwest::Method;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

6#[derive(Debug, Serialize, Deserialize)]
7pub struct FineTuningRequest {
8    /// The ID of an uploaded file that contains training data.
9    /// Your dataset must be formatted as a JSONL file, where each training example is a JSON object with the keys "prompt" and "completion".
10    /// Additionally, you must upload your file with the purpose fine-tune.
11    pub training_file: String,
12
13    /// The ID of an uploaded file that contains validation data.
14    /// If you provide this file, the data is used to generate validation metrics periodically during fine-tuning.
15    /// These metrics can be viewed in the fine-tuning results file.
16    /// Your train and validation data should be mutually exclusive.
17    /// Your dataset must be formatted as a JSONL file, where each validation example is a JSON object with the keys "prompt" and "completion".
18    /// Additionally, you must upload your file with the purpose fine-tune.
19    pub validation_file: Option<String>,
20
21    /// The name of the base model to fine-tune.
22    /// You can select one of "ada", "babbage", "curie", "davinci", or a fine-tuned model created after 2022-04-21.
23    pub model: Option<String>,
24
25    /// The number of epochs to train the model for.
26    pub n_epochs: Option<u32>,
27
28    /// The batch size to use for training.
29    pub batch_size: Option<u32>,
30
31    /// The learning rate multiplier to use for training.
32    pub learning_rate_multiplier: Option<f64>,
33
34    /// The weight to use for loss on the prompt tokens.
35    pub prompt_loss_weight: Option<f64>,
36
37    /// If set, we calculate classification-specific metrics such as accuracy and F-1 score using the validation set at the end of every epoch.
38    /// These metrics can be viewed in the results file.
39    /// In order to compute classification metrics, you must provide a validation_file.
40    /// Additionally, you must specify classification_n_classes for multiclass classification or classification_positive_class for binary classification.
41    pub compute_classification_metrics: Option<bool>,
42
43    /// The number of classes in a classification task.
44    /// This parameter is required for multiclass classification.
45    pub classification_n_classes: Option<u32>,
46
47    /// The positive class in binary classification.
48    /// This parameter is needed to generate precision, recall, and F1 metrics when doing binary classification.
49    pub classification_positive_class: Option<String>,
50
51    /// If this is provided, we calculate F-beta scores at the specified beta values.
52    /// The F-beta score is a generalization of F-1 score.
53    /// This is only used for binary classification.
54    pub classification_betas: Option<Vec<f64>>,
55
56    /// A string of up to 40 characters that will be added to your fine-tuned model name.
57    pub suffix: Option<String>,
58}
59
60impl OpenAIRequest for FineTuningRequest {
61    type Response = FineTuneResponse;
62
63    fn method() -> Method {
64        Method::POST
65    }
66
67    fn url() -> &'static str {
68        "https://api.openai.com/v1/fine-tunes"
69    }
70}
71
72#[derive(Debug, Deserialize, Serialize)]
73pub struct FineTuneResponse {
74    id: String,
75    object: String,
76    model: String,
77    created_at: i64,
78    events: Vec<FineTuneEvent>,
79    fine_tuned_model: Option<String>,
80    hyperparams: Hyperparams,
81    organization_id: String,
82    result_files: Vec<String>,
83    status: String,
84    validation_files: Vec<String>,
85    training_files: Vec<TrainingFile>,
86    updated_at: i64,
87}
88
89#[derive(Debug, Deserialize, Serialize)]
90struct FineTuneEvent {
91    object: String,
92    created_at: i64,
93    level: String,
94    message: String,
95}
96
97#[derive(Debug, Deserialize, Serialize)]
98struct Hyperparams {
99    batch_size: i32,
100    learning_rate_multiplier: f64,
101    n_epochs: i32,
102    prompt_loss_weight: f64,
103}
104
105#[derive(Debug, Deserialize, Serialize)]
106struct TrainingFile {
107    id: String,
108    object: String,
109    bytes: i64,
110    created_at: i64,
111    filename: String,
112    purpose: String,
113}