//! `outfox_openai/spec/fine_tuning.rs` — request/response types for the
//! OpenAI fine-tuning API.

1use derive_builder::Builder;
2use serde::{Deserialize, Serialize};
3
4use crate::error::OpenAIError;
5
6#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
7#[serde(untagged)]
8pub enum NEpochs {
9    NEpochs(u8),
10    #[default]
11    #[serde(rename = "auto")]
12    Auto,
13}
14
15#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
16#[serde(untagged)]
17pub enum BatchSize {
18    BatchSize(u16),
19    #[default]
20    #[serde(rename = "auto")]
21    Auto,
22}
23
24#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
25#[serde(untagged)]
26pub enum LearningRateMultiplier {
27    LearningRateMultiplier(f32),
28    #[default]
29    #[serde(rename = "auto")]
30    Auto,
31}
32
/// Hyperparameters for a supervised fine-tuning job.
///
/// Every field defaults to its `Auto` variant, letting the API pick values
/// appropriate for the training dataset.
#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
pub struct Hyperparameters {
    /// Number of examples in each batch. A larger batch size means that model parameters
    /// are updated less frequently, but with lower variance.
    pub batch_size: BatchSize,
    /// Scaling factor for the learning rate. A smaller learning rate may be useful to avoid
    /// overfitting.
    pub learning_rate_multiplier: LearningRateMultiplier,
    /// The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset.
    pub n_epochs: NEpochs,
}
44
45#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
46#[serde(untagged)]
47pub enum Beta {
48    Beta(f32),
49    #[default]
50    #[serde(rename = "auto")]
51    Auto,
52}
53
/// Hyperparameters for a DPO (Direct Preference Optimization) fine-tuning
/// job: the supervised hyperparameters plus the DPO-specific `beta`.
///
/// Every field defaults to its `Auto` variant, letting the API pick values
/// appropriate for the training dataset.
#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
pub struct DPOHyperparameters {
    /// The beta value for the DPO method. A higher beta value will increase the weight of the penalty between the policy and reference model.
    pub beta: Beta,
    /// Number of examples in each batch. A larger batch size means that model parameters
    /// are updated less frequently, but with lower variance.
    pub batch_size: BatchSize,
    /// Scaling factor for the learning rate. A smaller learning rate may be useful to avoid
    /// overfitting.
    pub learning_rate_multiplier: LearningRateMultiplier,
    /// The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset.
    pub n_epochs: NEpochs,
}
67
/// Request body for `POST /fine_tuning/jobs` — creates a fine-tuning job.
///
/// Construct via [`CreateFineTuningJobRequestBuilder`] (generated by
/// `derive_builder`; mutable pattern, `into`/`strip_option` setters, all
/// fields defaulted, build errors mapped to [`OpenAIError`]).
#[derive(Debug, Serialize, Deserialize, Clone, Default, Builder, PartialEq)]
#[builder(name = "CreateFineTuningJobRequestBuilder")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct CreateFineTuningJobRequest {
    /// The name of the model to fine-tune. You can select one of the
    /// [supported models](https://platform.openai.com/docs/guides/fine-tuning#which-models-can-be-fine-tuned).
    pub model: String,

    /// The ID of an uploaded file that contains training data.
    ///
    /// See [upload file](https://platform.openai.com/docs/api-reference/files/create) for how to upload a file.
    ///
    /// Your dataset must be formatted as a JSONL file. Additionally, you must upload your file with the purpose `fine-tune`.
    ///
    /// The contents of the file should differ depending on if the model uses the [chat](https://platform.openai.com/docs/api-reference/fine-tuning/chat-input), [completions](https://platform.openai.com/docs/api-reference/fine-tuning/completions-input) format, or if the fine-tuning method uses the [preference](https://platform.openai.com/docs/api-reference/fine-tuning/preference-input) format.
    ///
    /// See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning) for more details.
    pub training_file: String,

    /// A string of up to 64 characters that will be added to your fine-tuned model name.
    ///
    /// For example, a `suffix` of "custom-model-name" would produce a model name like `ft:gpt-4o-mini:openai:custom-model-name:7p4lURel`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>, // default: null, minLength:1, maxLength:40
    // NOTE(review): the maxLength:40 above conflicts with the "up to 64
    // characters" doc comment — confirm the current API limit.

    /// The ID of an uploaded file that contains validation data.
    ///
    /// If you provide this file, the data is used to generate validation
    /// metrics periodically during fine-tuning. These metrics can be viewed in
    /// the fine-tuning results file.
    /// The same data should not be present in both train and validation files.
    ///
    /// Your dataset must be formatted as a JSONL file. You must upload your file with the purpose `fine-tune`.
    ///
    /// See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning) for more details.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub validation_file: Option<String>,

    /// A list of integrations to enable for your fine-tuning job.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub integrations: Option<Vec<FineTuningIntegration>>,

    /// The seed controls the reproducibility of the job. Passing in the same seed and job parameters should produce the same results, but may differ in rare cases.
    /// If a seed is not specified, one will be generated for you.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u32>, // min:0, max: 2147483647

    /// The method used for fine-tuning (supervised or DPO). Omitted from the
    /// request body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub method: Option<FineTuneMethod>,
}
121
/// The method used for fine-tuning.
///
/// Internally tagged on the JSON `"type"` field with lowercase variant
/// names, so the wire shape is e.g.
/// `{"type": "supervised", "supervised": {...}}` or
/// `{"type": "dpo", "dpo": {...}}`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum FineTuneMethod {
    /// Standard supervised fine-tuning.
    Supervised {
        supervised: FineTuneSupervisedMethod,
    },
    /// Direct Preference Optimization.
    DPO {
        dpo: FineTuneDPOMethod,
    },
}
133
/// Configuration for the supervised fine-tuning method.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct FineTuneSupervisedMethod {
    /// The hyperparameters used for the supervised fine-tuning job.
    pub hyperparameters: Hyperparameters,
}
138
/// Configuration for the DPO (Direct Preference Optimization) fine-tuning method.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct FineTuneDPOMethod {
    /// The hyperparameters used for the DPO fine-tuning job.
    pub hyperparameters: DPOHyperparameters,
}
143
/// The kind of third-party integration attached to a fine-tuning job.
/// Serialized lowercase (`"wandb"`); Weights & Biases is currently the only
/// supported integration, and also the default.
#[derive(Debug, Deserialize, Clone, PartialEq, Serialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum FineTuningJobIntegrationType {
    #[default]
    Wandb,
}
150
/// An integration enabled for a fine-tuning job. Serialized with a `"type"`
/// discriminator field plus the integration-specific settings.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct FineTuningIntegration {
    /// The type of integration to enable. Currently, only "wandb" (Weights and Biases) is supported.
    // Named `kind` in Rust because `type` is a keyword; renamed on the wire.
    #[serde(rename = "type")]
    pub kind: FineTuningJobIntegrationType,

    /// The settings for your integration with Weights and Biases. This payload specifies the project that
    /// metrics will be sent to. Optionally, you can set an explicit display name for your run, add tags
    /// to your run, and set a default entity (team, username, etc) to be associated with your run.
    pub wandb: WandB,
}
162
/// Settings for a Weights & Biases integration. Optional fields are omitted
/// from the serialized payload when `None`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct WandB {
    /// The name of the project that the new run will be created under.
    pub project: String,
    /// A display name to set for the run. If not set, we will use the Job ID as the name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// The entity to use for the run. This allows you to set the team or username of the WandB user that you would
    /// like associated with the run. If not set, the default entity for the registered WandB API key is used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub entity: Option<String>,
    /// A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some
    /// default tags are generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}".
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tags: Option<Vec<String>>,
}
179
/// For fine-tuning jobs that have `failed`, this will contain more information on the cause of the failure.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct FineTuneJobError {
    /// A machine-readable error code.
    pub code: String,
    /// A human-readable error message.
    pub message: String,
    /// The parameter that was invalid, usually `training_file` or `validation_file`.
    /// This field will be null if the failure was not parameter-specific.
    pub param: Option<String>, // nullable true
}
191
/// Lifecycle state of a fine-tuning job. Serialized in snake_case
/// (`"validating_files"`, `"queued"`, `"running"`, `"succeeded"`,
/// `"failed"`, `"cancelled"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FineTuningJobStatus {
    ValidatingFiles,
    Queued,
    Running,
    Succeeded,
    Failed,
    Cancelled,
}
202
/// The `fine_tuning.job` object represents a fine-tuning job that has been created through the API.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct FineTuningJob {
    /// The object identifier, which can be referenced in the API endpoints.
    pub id: String,
    /// The Unix timestamp (in seconds) for when the fine-tuning job was created.
    pub created_at: u32,
    /// For fine-tuning jobs that have `failed`, this will contain more information on the cause of the failure.
    pub error: Option<FineTuneJobError>,
    /// The name of the fine-tuned model that is being created.
    /// The value will be null if the fine-tuning job is still running.
    pub fine_tuned_model: Option<String>, // nullable: true
    /// The Unix timestamp (in seconds) for when the fine-tuning job was finished.
    /// The value will be null if the fine-tuning job is still running.
    pub finished_at: Option<u32>, // nullable true

    /// The hyperparameters used for the fine-tuning job.
    /// See the [fine-tuning guide](/docs/guides/fine-tuning) for more details.
    pub hyperparameters: Hyperparameters,

    /// The base model that is being fine-tuned.
    pub model: String,

    /// The object type, which is always "fine_tuning.job".
    pub object: String,
    /// The organization that owns the fine-tuning job.
    pub organization_id: String,

    /// The compiled results file ID(s) for the fine-tuning job.
    /// You can retrieve the results with the [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents).
    pub result_files: Vec<String>,

    /// The current status of the fine-tuning job, which can be either
    /// `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`.
    pub status: FineTuningJobStatus,

    /// The total number of billable tokens processed by this fine-tuning job. The value will be null if the fine-tuning job is still running.
    pub trained_tokens: Option<u32>,

    /// The file ID used for training. You can retrieve the training data with the [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents).
    pub training_file: String,

    /// The file ID used for validation. You can retrieve the validation results with the [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents).
    pub validation_file: Option<String>,

    /// A list of integrations to enable for this fine-tuning job.
    pub integrations: Option<Vec<FineTuningIntegration>>, // maxItems: 5

    /// The seed used for the fine-tuning job.
    pub seed: u32,

    /// The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is not running.
    pub estimated_finish: Option<u32>,

    /// The fine-tuning method used (supervised or DPO).
    // NOTE(review): `Option` here presumably tolerates older API responses
    // that omit this field — confirm against the deployed API version.
    pub method: Option<FineTuneMethod>,
}
259
/// Response for `GET /fine_tuning/jobs` — one page of fine-tuning jobs.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ListPaginatedFineTuningJobsResponse {
    /// The fine-tuning jobs on this page.
    pub data: Vec<FineTuningJob>,
    /// Whether further pages are available.
    pub has_more: bool,
    /// The object type (a list envelope).
    pub object: String,
}
266
/// Response for `GET /fine_tuning/jobs/{id}/events` — the job's event list.
// NOTE(review): unlike the jobs/checkpoints list responses this has no
// `has_more` field — confirm whether the events endpoint paginates.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ListFineTuningJobEventsResponse {
    /// The events for the fine-tuning job.
    pub data: Vec<FineTuningJobEvent>,
    /// The object type (a list envelope).
    pub object: String,
}
272
/// Response for `GET /fine_tuning/jobs/{id}/checkpoints` — one page of
/// checkpoints with cursor pagination.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ListFineTuningJobCheckpointsResponse {
    /// The checkpoints on this page.
    pub data: Vec<FineTuningJobCheckpoint>,
    /// The object type (a list envelope).
    pub object: String,
    /// ID of the first checkpoint on this page (pagination cursor); null when empty.
    pub first_id: Option<String>,
    /// ID of the last checkpoint on this page (pagination cursor); null when empty.
    pub last_id: Option<String>,
    /// Whether further pages are available.
    pub has_more: bool,
}
281
/// Log level of a fine-tuning job event. Serialized lowercase
/// (`"info"`, `"warn"`, `"error"`).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Level {
    Info,
    Warn,
    Error,
}
289
/// Fine-tuning job event object.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FineTuningJobEvent {
    /// The object identifier.
    pub id: String,
    /// The Unix timestamp (in seconds) for when the fine-tuning job event was created.
    pub created_at: u32,
    /// The log level of the event.
    pub level: Level,
    /// The message of the event.
    pub message: String,
    /// The object type, which is always "fine_tuning.job.event".
    pub object: String,
    /// The type of event.
    // Named `kind` in Rust because `type` is a keyword; renamed on the wire.
    #[serde(rename = "type")]
    pub kind: Option<FineTuningJobEventType>,
    /// The data associated with the event. Kept as raw JSON because its
    /// shape varies with the event type.
    pub data: Option<serde_json::Value>,
}
309
/// Category of a fine-tuning job event. Serialized lowercase
/// (`"message"`, `"metrics"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum FineTuningJobEventType {
    Message,
    Metrics,
}
316
/// The `fine_tuning.job.checkpoint` object represents a model checkpoint for a fine-tuning job that is ready to use.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FineTuningJobCheckpoint {
    /// The checkpoint identifier, which can be referenced in the API endpoints.
    pub id: String,
    /// The Unix timestamp (in seconds) for when the checkpoint was created.
    pub created_at: u32,
    /// The name of the fine-tuned checkpoint model that is created.
    pub fine_tuned_model_checkpoint: String,
    /// The step number that the checkpoint was created at.
    pub step_number: u32,
    /// Metrics at the step number during the fine-tuning job.
    pub metrics: FineTuningJobCheckpointMetrics,
    /// The name of the fine-tuning job that this checkpoint was created from.
    pub fine_tuning_job_id: String,
    /// The object type, which is always "fine_tuning.job.checkpoint".
    pub object: String,
}
335
/// Training/validation metrics recorded at a checkpoint's step.
// NOTE(review): all fields are mandatory here, but the API may omit some
// metrics (e.g. validation metrics when no validation file was supplied) —
// confirm whether these should be `Option<f32>`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FineTuningJobCheckpointMetrics {
    /// The step this set of metrics was recorded at.
    pub step: u32,
    /// Training loss at this step.
    pub train_loss: f32,
    /// Mean per-token accuracy on the training batch at this step.
    pub train_mean_token_accuracy: f32,
    /// Validation loss at this step.
    pub valid_loss: f32,
    /// Mean per-token accuracy on the validation batch at this step.
    pub valid_mean_token_accuracy: f32,
    /// Loss over the full validation set at this step.
    pub full_valid_loss: f32,
    /// Mean per-token accuracy over the full validation set at this step.
    pub full_valid_mean_token_accuracy: f32,
}
345}