openai_api_wrapper/fine_tuning.rs
1use crate::client::OpenAIRequest;
2use reqwest::Method;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6#[derive(Debug, Serialize, Deserialize)]
7pub struct FineTuningRequest {
8 /// The ID of an uploaded file that contains training data.
9 /// Your dataset must be formatted as a JSONL file, where each training example is a JSON object with the keys "prompt" and "completion".
10 /// Additionally, you must upload your file with the purpose fine-tune.
11 pub training_file: String,
12
13 /// The ID of an uploaded file that contains validation data.
14 /// If you provide this file, the data is used to generate validation metrics periodically during fine-tuning.
15 /// These metrics can be viewed in the fine-tuning results file.
16 /// Your train and validation data should be mutually exclusive.
17 /// Your dataset must be formatted as a JSONL file, where each validation example is a JSON object with the keys "prompt" and "completion".
18 /// Additionally, you must upload your file with the purpose fine-tune.
19 pub validation_file: Option<String>,
20
21 /// The name of the base model to fine-tune.
22 /// You can select one of "ada", "babbage", "curie", "davinci", or a fine-tuned model created after 2022-04-21.
23 pub model: Option<String>,
24
25 /// The number of epochs to train the model for.
26 pub n_epochs: Option<u32>,
27
28 /// The batch size to use for training.
29 pub batch_size: Option<u32>,
30
31 /// The learning rate multiplier to use for training.
32 pub learning_rate_multiplier: Option<f64>,
33
34 /// The weight to use for loss on the prompt tokens.
35 pub prompt_loss_weight: Option<f64>,
36
37 /// If set, we calculate classification-specific metrics such as accuracy and F-1 score using the validation set at the end of every epoch.
38 /// These metrics can be viewed in the results file.
39 /// In order to compute classification metrics, you must provide a validation_file.
40 /// Additionally, you must specify classification_n_classes for multiclass classification or classification_positive_class for binary classification.
41 pub compute_classification_metrics: Option<bool>,
42
43 /// The number of classes in a classification task.
44 /// This parameter is required for multiclass classification.
45 pub classification_n_classes: Option<u32>,
46
47 /// The positive class in binary classification.
48 /// This parameter is needed to generate precision, recall, and F1 metrics when doing binary classification.
49 pub classification_positive_class: Option<String>,
50
51 /// If this is provided, we calculate F-beta scores at the specified beta values.
52 /// The F-beta score is a generalization of F-1 score.
53 /// This is only used for binary classification.
54 pub classification_betas: Option<Vec<f64>>,
55
56 /// A string of up to 40 characters that will be added to your fine-tuned model name.
57 pub suffix: Option<String>,
58}
59
60impl OpenAIRequest for FineTuningRequest {
61 type Response = FineTuneResponse;
62
63 fn method() -> Method {
64 Method::POST
65 }
66
67 fn url() -> &'static str {
68 "https://api.openai.com/v1/fine-tunes"
69 }
70}
71
72#[derive(Debug, Deserialize, Serialize)]
73pub struct FineTuneResponse {
74 id: String,
75 object: String,
76 model: String,
77 created_at: i64,
78 events: Vec<FineTuneEvent>,
79 fine_tuned_model: Option<String>,
80 hyperparams: Hyperparams,
81 organization_id: String,
82 result_files: Vec<String>,
83 status: String,
84 validation_files: Vec<String>,
85 training_files: Vec<TrainingFile>,
86 updated_at: i64,
87}
88
89#[derive(Debug, Deserialize, Serialize)]
90struct FineTuneEvent {
91 object: String,
92 created_at: i64,
93 level: String,
94 message: String,
95}
96
97#[derive(Debug, Deserialize, Serialize)]
98struct Hyperparams {
99 batch_size: i32,
100 learning_rate_multiplier: f64,
101 n_epochs: i32,
102 prompt_loss_weight: f64,
103}
104
105#[derive(Debug, Deserialize, Serialize)]
106struct TrainingFile {
107 id: String,
108 object: String,
109 bytes: i64,
110 created_at: i64,
111 filename: String,
112 purpose: String,
113}