outfox_openai/spec/fine_tuning.rs
use derive_builder::Builder;
use serde::{Deserialize, Serialize};

use crate::error::OpenAIError;

/// Number of epochs to train for: either the literal string `"auto"` or an explicit count.
///
/// Variant-level `#[serde(untagged)]` (serde 1.0.181+) lets the explicit value serialize as a
/// bare number while `Auto` round-trips as the string `"auto"`; an enum-wide `untagged`
/// representation would serialize the unit variant as `null` instead.
#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum NEpochs {
    #[default]
    Auto,
    #[serde(untagged)]
    NEpochs(u8),
}

/// Number of examples in each batch: either `"auto"` or an explicit size.
#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum BatchSize {
    #[default]
    Auto,
    #[serde(untagged)]
    BatchSize(u16),
}

/// Scaling factor for the learning rate: either `"auto"` or an explicit multiplier.
#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum LearningRateMultiplier {
    #[default]
    Auto,
    #[serde(untagged)]
    LearningRateMultiplier(f32),
}

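// A minimal round-trip sketch for the auto-or-value representation above: `Auto`
// serializes as the literal string "auto", while an explicit value serializes as a
// bare number. `serde_json` is already a dependency of this module (it backs the
// event `data` field below).
#[cfg(test)]
mod auto_or_value_tests {
    use super::*;

    #[test]
    fn round_trips_auto_and_explicit_values() {
        assert_eq!(serde_json::to_string(&NEpochs::Auto).unwrap(), r#""auto""#);
        assert_eq!(serde_json::to_string(&NEpochs::NEpochs(3)).unwrap(), "3");
        assert_eq!(serde_json::from_str::<NEpochs>(r#""auto""#).unwrap(), NEpochs::Auto);
        assert_eq!(serde_json::from_str::<BatchSize>("16").unwrap(), BatchSize::BatchSize(16));
    }
}
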
#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
pub struct Hyperparameters {
    /// Number of examples in each batch. A larger batch size means that model parameters
    /// are updated less frequently, but with lower variance.
    pub batch_size: BatchSize,
    /// Scaling factor for the learning rate. A smaller learning rate may be useful to avoid
    /// overfitting.
    pub learning_rate_multiplier: LearningRateMultiplier,
    /// The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset.
    pub n_epochs: NEpochs,
}

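// A small sketch of the wire shape: `Hyperparameters::default()` falls back to
// "auto" for every field, mirroring the API's defaults.
#[cfg(test)]
mod hyperparameters_tests {
    use super::*;

    #[test]
    fn default_hyperparameters_are_all_auto() {
        let json = serde_json::to_value(Hyperparameters::default()).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "batch_size": "auto",
                "learning_rate_multiplier": "auto",
                "n_epochs": "auto"
            })
        );
    }
}
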
/// The beta value for the DPO method: either `"auto"` or an explicit value.
#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Beta {
    #[default]
    Auto,
    #[serde(untagged)]
    Beta(f32),
}

#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
pub struct DPOHyperparameters {
    /// The beta value for the DPO method. A higher beta value will increase the weight of the penalty between the policy and reference model.
    pub beta: Beta,
    /// Number of examples in each batch. A larger batch size means that model parameters
    /// are updated less frequently, but with lower variance.
    pub batch_size: BatchSize,
    /// Scaling factor for the learning rate. A smaller learning rate may be useful to avoid
    /// overfitting.
    pub learning_rate_multiplier: LearningRateMultiplier,
    /// The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset.
    pub n_epochs: NEpochs,
}

#[derive(Debug, Serialize, Deserialize, Clone, Default, Builder, PartialEq)]
#[builder(name = "CreateFineTuningJobRequestBuilder")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct CreateFineTuningJobRequest {
    /// The name of the model to fine-tune. You can select one of the
    /// [supported models](https://platform.openai.com/docs/guides/fine-tuning#which-models-can-be-fine-tuned).
    pub model: String,

    /// The ID of an uploaded file that contains training data.
    ///
    /// See [upload file](https://platform.openai.com/docs/api-reference/files/create) for how to upload a file.
    ///
    /// Your dataset must be formatted as a JSONL file. Additionally, you must upload your file with the purpose `fine-tune`.
    ///
    /// The contents of the file should differ depending on whether the model uses the [chat](https://platform.openai.com/docs/api-reference/fine-tuning/chat-input) format or the [completions](https://platform.openai.com/docs/api-reference/fine-tuning/completions-input) format, or whether the fine-tuning method uses the [preference](https://platform.openai.com/docs/api-reference/fine-tuning/preference-input) format.
    ///
    /// See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning) for more details.
    pub training_file: String,

    /// A string of up to 64 characters that will be added to your fine-tuned model name.
    ///
    /// For example, a `suffix` of "custom-model-name" would produce a model name like `ft:gpt-4o-mini:openai:custom-model-name:7p4lURel`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>, // default: null, minLength: 1, maxLength: 64

    /// The ID of an uploaded file that contains validation data.
    ///
    /// If you provide this file, the data is used to generate validation
    /// metrics periodically during fine-tuning. These metrics can be viewed in
    /// the fine-tuning results file.
    /// The same data should not be present in both train and validation files.
    ///
    /// Your dataset must be formatted as a JSONL file. You must upload your file with the purpose `fine-tune`.
    ///
    /// See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning) for more details.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub validation_file: Option<String>,

    /// A list of integrations to enable for your fine-tuning job.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub integrations: Option<Vec<FineTuningIntegration>>,

    /// The seed controls the reproducibility of the job. Passing in the same seed and job parameters should produce the same results, but may differ in rare cases.
    /// If a seed is not specified, one will be generated for you.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u32>, // min: 0, max: 2147483647

    /// The method used for fine-tuning.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub method: Option<FineTuneMethod>,
}

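// A usage sketch for the derived builder: `setter(into)` accepts `&str` directly,
// `strip_option` wraps optional fields in `Some`, and unset fields fall back to
// their defaults. The model name and file ID are placeholders, not real resources.
#[cfg(test)]
mod create_request_builder_tests {
    use super::*;

    #[test]
    fn builds_a_minimal_request() {
        let request = CreateFineTuningJobRequestBuilder::default()
            .model("gpt-4o-mini-2024-07-18")
            .training_file("file-abc123")
            .suffix("custom-model-name")
            .build()
            .unwrap();
        assert_eq!(request.model, "gpt-4o-mini-2024-07-18");
        assert_eq!(request.suffix.as_deref(), Some("custom-model-name"));
        // Unset optional fields stay `None` and are skipped during serialization.
        assert!(request.validation_file.is_none());
    }
}
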
/// The method used for fine-tuning.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum FineTuneMethod {
    /// Supervised fine-tuning.
    Supervised {
        supervised: FineTuneSupervisedMethod,
    },
    /// Direct Preference Optimization (DPO).
    DPO {
        dpo: FineTuneDPOMethod,
    },
}

/// Configuration for the supervised fine-tuning method.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct FineTuneSupervisedMethod {
    /// The hyperparameters used for the fine-tuning job.
    pub hyperparameters: Hyperparameters,
}

/// Configuration for the DPO fine-tuning method.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct FineTuneDPOMethod {
    /// The hyperparameters used for the DPO fine-tuning job.
    pub hyperparameters: DPOHyperparameters,
}

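// A sketch of the internally tagged wire shape for `method`: the `type` field names
// the variant and a matching key carries its configuration, here DPO with an
// explicit beta and everything else left on "auto".
#[cfg(test)]
mod fine_tune_method_tests {
    use super::*;

    #[test]
    fn dpo_method_is_tagged_with_type() {
        let method = FineTuneMethod::DPO {
            dpo: FineTuneDPOMethod {
                hyperparameters: DPOHyperparameters {
                    beta: Beta::Beta(0.5),
                    ..Default::default()
                },
            },
        };
        let json = serde_json::to_value(&method).unwrap();
        assert_eq!(json["type"], "dpo");
        assert_eq!(json["dpo"]["hyperparameters"]["beta"], 0.5);
        assert_eq!(json["dpo"]["hyperparameters"]["n_epochs"], "auto");
    }
}
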
/// The type of integration to enable for a fine-tuning job.
#[derive(Debug, Deserialize, Clone, PartialEq, Serialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum FineTuningJobIntegrationType {
    /// Weights and Biases.
    #[default]
    Wandb,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct FineTuningIntegration {
    /// The type of integration to enable. Currently, only "wandb" (Weights and Biases) is supported.
    #[serde(rename = "type")]
    pub kind: FineTuningJobIntegrationType,

    /// The settings for your integration with Weights and Biases. This payload specifies the project that
    /// metrics will be sent to. Optionally, you can set an explicit display name for your run, add tags
    /// to your run, and set a default entity (team, username, etc.) to be associated with your run.
    pub wandb: WandB,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct WandB {
    /// The name of the project that the new run will be created under.
    pub project: String,
    /// A display name to set for the run. If not set, we will use the Job ID as the name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// The entity to use for the run. This allows you to set the team or username of the WandB user that you would
    /// like associated with the run. If not set, the default entity for the registered WandB API key is used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub entity: Option<String>,
    /// A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some
    /// default tags are generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}".
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tags: Option<Vec<String>>,
}

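// A sketch of the integration payload: `kind` is renamed to `type` on the wire, so a
// Weights and Biases integration serializes as {"type": "wandb", "wandb": {...}}.
// The project name is a placeholder.
#[cfg(test)]
mod integration_tests {
    use super::*;

    #[test]
    fn wandb_integration_uses_a_type_key() {
        let integration = FineTuningIntegration {
            kind: FineTuningJobIntegrationType::Wandb,
            wandb: WandB {
                project: "my-wandb-project".to_string(),
                name: None,
                entity: None,
                tags: Some(vec!["openai/finetune".to_string()]),
            },
        };
        let json = serde_json::to_value(&integration).unwrap();
        assert_eq!(json["type"], "wandb");
        assert_eq!(json["wandb"]["project"], "my-wandb-project");
        // `None` fields are omitted rather than serialized as null.
        assert!(json["wandb"].get("name").is_none());
    }
}
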
/// For fine-tuning jobs that have `failed`, this will contain more information on the cause of the failure.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct FineTuneJobError {
    /// A machine-readable error code.
    pub code: String,
    /// A human-readable error message.
    pub message: String,
    /// The parameter that was invalid, usually `training_file` or `validation_file`.
    /// This field will be null if the failure was not parameter-specific.
    pub param: Option<String>, // nullable: true
}

/// The current status of a fine-tuning job.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FineTuningJobStatus {
    ValidatingFiles,
    Queued,
    Running,
    Succeeded,
    Failed,
    Cancelled,
}

/// The `fine_tuning.job` object represents a fine-tuning job that has been created through the API.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct FineTuningJob {
    /// The object identifier, which can be referenced in the API endpoints.
    pub id: String,
    /// The Unix timestamp (in seconds) for when the fine-tuning job was created.
    pub created_at: u32,
    /// For fine-tuning jobs that have `failed`, this will contain more information on the cause of the failure.
    pub error: Option<FineTuneJobError>,
    /// The name of the fine-tuned model that is being created.
    /// The value will be null if the fine-tuning job is still running.
    pub fine_tuned_model: Option<String>, // nullable: true
    /// The Unix timestamp (in seconds) for when the fine-tuning job was finished.
    /// The value will be null if the fine-tuning job is still running.
    pub finished_at: Option<u32>, // nullable: true

    /// The hyperparameters used for the fine-tuning job.
    /// See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning) for more details.
    pub hyperparameters: Hyperparameters,

    /// The base model that is being fine-tuned.
    pub model: String,

    /// The object type, which is always "fine_tuning.job".
    pub object: String,
    /// The organization that owns the fine-tuning job.
    pub organization_id: String,

    /// The compiled results file ID(s) for the fine-tuning job.
    /// You can retrieve the results with the [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents).
    pub result_files: Vec<String>,

    /// The current status of the fine-tuning job, which can be either
    /// `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`.
    pub status: FineTuningJobStatus,

    /// The total number of billable tokens processed by this fine-tuning job. The value will be null if the fine-tuning job is still running.
    pub trained_tokens: Option<u32>,

    /// The file ID used for training. You can retrieve the training data with the [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents).
    pub training_file: String,

    /// The file ID used for validation. You can retrieve the validation results with the [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents).
    pub validation_file: Option<String>,

    /// A list of integrations to enable for this fine-tuning job.
    pub integrations: Option<Vec<FineTuningIntegration>>, // maxItems: 5

    /// The seed used for the fine-tuning job.
    pub seed: u32,

    /// The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is not running.
    pub estimated_finish: Option<u32>,

    /// The method used for fine-tuning.
    pub method: Option<FineTuneMethod>,
}

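// A deserialization sketch against a trimmed-down job payload. `Option` fields that
// are omitted fall back to `None`, and all IDs here are placeholders.
#[cfg(test)]
mod fine_tuning_job_tests {
    use super::*;

    #[test]
    fn deserializes_a_running_job() {
        let payload = serde_json::json!({
            "id": "ftjob-abc123",
            "created_at": 1721764800,
            "fine_tuned_model": null,
            "hyperparameters": {
                "batch_size": "auto",
                "learning_rate_multiplier": "auto",
                "n_epochs": 4
            },
            "model": "gpt-4o-mini-2024-07-18",
            "object": "fine_tuning.job",
            "organization_id": "org-abc123",
            "result_files": [],
            "status": "running",
            "training_file": "file-abc123",
            "seed": 42
        });
        let job: FineTuningJob = serde_json::from_value(payload).unwrap();
        assert_eq!(job.status, FineTuningJobStatus::Running);
        assert_eq!(job.hyperparameters.n_epochs, NEpochs::NEpochs(4));
        assert!(job.finished_at.is_none());
    }
}
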
/// A paginated list of fine-tuning jobs.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ListPaginatedFineTuningJobsResponse {
    pub data: Vec<FineTuningJob>,
    pub has_more: bool,
    pub object: String,
}

/// A list of events for a fine-tuning job.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ListFineTuningJobEventsResponse {
    pub data: Vec<FineTuningJobEvent>,
    pub object: String,
}

/// A paginated list of checkpoints for a fine-tuning job.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ListFineTuningJobCheckpointsResponse {
    pub data: Vec<FineTuningJobCheckpoint>,
    pub object: String,
    pub first_id: Option<String>,
    pub last_id: Option<String>,
    pub has_more: bool,
}

/// The log level of a fine-tuning job event.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Level {
    Info,
    Warn,
    Error,
}

/// Fine-tuning job event object.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FineTuningJobEvent {
    /// The object identifier.
    pub id: String,
    /// The Unix timestamp (in seconds) for when the fine-tuning job event was created.
    pub created_at: u32,
    /// The log level of the event.
    pub level: Level,
    /// The message of the event.
    pub message: String,
    /// The object type, which is always "fine_tuning.job.event".
    pub object: String,
    /// The type of event.
    #[serde(rename = "type")]
    pub kind: Option<FineTuningJobEventType>,
    /// The data associated with the event.
    pub data: Option<serde_json::Value>,
}

/// The type of a fine-tuning job event.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum FineTuningJobEventType {
    Message,
    Metrics,
}

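// A sketch of event decoding: the wire `type` field maps onto the renamed `kind`
// field, and structured payloads such as metrics land in the free-form `data` value.
// The event ID and message are placeholders.
#[cfg(test)]
mod job_event_tests {
    use super::*;

    #[test]
    fn deserializes_a_metrics_event() {
        let payload = serde_json::json!({
            "id": "ftevent-abc123",
            "created_at": 1721764800,
            "level": "info",
            "message": "Step 10/100: training loss=0.42",
            "object": "fine_tuning.job.event",
            "type": "metrics",
            "data": { "step": 10, "train_loss": 0.42 }
        });
        let event: FineTuningJobEvent = serde_json::from_value(payload).unwrap();
        assert_eq!(event.level, Level::Info);
        assert_eq!(event.kind, Some(FineTuningJobEventType::Metrics));
        assert!(event.data.is_some());
    }
}
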
/// The `fine_tuning.job.checkpoint` object represents a model checkpoint for a fine-tuning job that is ready to use.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FineTuningJobCheckpoint {
    /// The checkpoint identifier, which can be referenced in the API endpoints.
    pub id: String,
    /// The Unix timestamp (in seconds) for when the checkpoint was created.
    pub created_at: u32,
    /// The name of the fine-tuned checkpoint model that is created.
    pub fine_tuned_model_checkpoint: String,
    /// The step number that the checkpoint was created at.
    pub step_number: u32,
    /// Metrics at the step number during the fine-tuning job.
    pub metrics: FineTuningJobCheckpointMetrics,
    /// The name of the fine-tuning job that this checkpoint was created from.
    pub fine_tuning_job_id: String,
    /// The object type, which is always "fine_tuning.job.checkpoint".
    pub object: String,
}

/// Metrics at the step number during the fine-tuning job.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FineTuningJobCheckpointMetrics {
    pub step: u32,
    pub train_loss: f32,
    pub train_mean_token_accuracy: f32,
    pub valid_loss: f32,
    pub valid_mean_token_accuracy: f32,
    pub full_valid_loss: f32,
    pub full_valid_mean_token_accuracy: f32,
}