// openai_ergonomic/builders/fine_tuning.rs

use std::collections::HashMap;
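/// Builder for configuring an OpenAI fine-tuning job.
///
/// A minimal usage sketch (illustrative only; the model name and file IDs are
/// placeholders, and the builder merely collects parameters rather than
/// submitting anything to the API):
///
/// ```ignore
/// let job = FineTuningJobBuilder::new("gpt-3.5-turbo", "file-abc123")
///     .validation_file("file-def456")
///     .epochs(3)
///     .suffix("my-model-v1")
///     .with_wandb("my-project");
/// assert_eq!(job.model(), "gpt-3.5-turbo");
/// assert_eq!(job.integrations().len(), 1);
/// ```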
#[derive(Debug, Clone)]
pub struct FineTuningJobBuilder {
    model: String,
    training_file: String,
    validation_file: Option<String>,
    hyperparameters: FineTuningHyperparameters,
    suffix: Option<String>,
    integrations: Vec<FineTuningIntegration>,
}

/// Hyperparameters for a fine-tuning job.
#[derive(Debug, Clone, Default)]
pub struct FineTuningHyperparameters {
    /// Number of training epochs.
    pub n_epochs: Option<i32>,
    /// Batch size used during training.
    pub batch_size: Option<i32>,
    /// Multiplier applied to the base learning rate.
    pub learning_rate_multiplier: Option<f64>,
}

/// Third-party integration attached to a fine-tuning job (e.g. Weights & Biases).
#[derive(Debug, Clone)]
pub struct FineTuningIntegration {
    /// Integration identifier, such as `"wandb"`.
    pub integration_type: String,
    /// Integration-specific settings as key/value pairs.
    pub settings: HashMap<String, String>,
}

impl FineTuningJobBuilder {
    /// Creates a builder for fine-tuning `model` on the uploaded `training_file`.
    #[must_use]
    pub fn new(model: impl Into<String>, training_file: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            training_file: training_file.into(),
            validation_file: None,
            hyperparameters: FineTuningHyperparameters::default(),
            suffix: None,
            integrations: Vec::new(),
        }
    }

    /// Sets the validation file ID.
    #[must_use]
    pub fn validation_file(mut self, file_id: impl Into<String>) -> Self {
        self.validation_file = Some(file_id.into());
        self
    }

    /// Sets the number of training epochs.
    #[must_use]
    pub fn epochs(mut self, epochs: i32) -> Self {
        self.hyperparameters.n_epochs = Some(epochs);
        self
    }

    /// Sets the training batch size.
    #[must_use]
    pub fn batch_size(mut self, batch_size: i32) -> Self {
        self.hyperparameters.batch_size = Some(batch_size);
        self
    }

    /// Sets the learning rate multiplier.
    #[must_use]
    pub fn learning_rate_multiplier(mut self, multiplier: f64) -> Self {
        self.hyperparameters.learning_rate_multiplier = Some(multiplier);
        self
    }

    /// Sets a suffix appended to the fine-tuned model name.
    #[must_use]
    pub fn suffix(mut self, suffix: impl Into<String>) -> Self {
        self.suffix = Some(suffix.into());
        self
    }

    /// Adds a Weights & Biases integration reporting to the given project.
    #[must_use]
    pub fn with_wandb(mut self, project: impl Into<String>) -> Self {
        let mut settings = HashMap::new();
        settings.insert("project".to_string(), project.into());

        self.integrations.push(FineTuningIntegration {
            integration_type: "wandb".to_string(),
            settings,
        });
        self
    }

    /// Returns the base model name.
    #[must_use]
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Returns the training file ID.
    #[must_use]
    pub fn training_file(&self) -> &str {
        &self.training_file
    }

    /// Returns the validation file ID, if set.
    #[must_use]
    pub fn validation_file_ref(&self) -> Option<&str> {
        self.validation_file.as_deref()
    }

    /// Returns the configured hyperparameters.
    #[must_use]
    pub fn hyperparameters(&self) -> &FineTuningHyperparameters {
        &self.hyperparameters
    }

    /// Returns the model name suffix, if set.
    #[must_use]
    pub fn suffix_ref(&self) -> Option<&str> {
        self.suffix.as_deref()
    }

    /// Returns the configured integrations.
    #[must_use]
    pub fn integrations(&self) -> &[FineTuningIntegration] {
        &self.integrations
    }
}

/// Builder for listing fine-tuning jobs with cursor-based pagination.
#[derive(Debug, Clone, Default)]
pub struct FineTuningJobListBuilder {
    after: Option<String>,
    limit: Option<i32>,
}

impl FineTuningJobListBuilder {
    /// Creates a new list builder with no pagination parameters set.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns results after the given cursor.
    #[must_use]
    pub fn after(mut self, cursor: impl Into<String>) -> Self {
        self.after = Some(cursor.into());
        self
    }

    /// Limits the number of jobs returned.
    #[must_use]
    pub fn limit(mut self, limit: i32) -> Self {
        self.limit = Some(limit);
        self
    }

    /// Returns the pagination cursor, if set.
    #[must_use]
    pub fn after_ref(&self) -> Option<&str> {
        self.after.as_deref()
    }

    /// Returns the result limit, if set.
    #[must_use]
    pub fn limit_ref(&self) -> Option<i32> {
        self.limit
    }
}

/// Builder for retrieving a single fine-tuning job by ID.
#[derive(Debug, Clone)]
pub struct FineTuningJobRetrievalBuilder {
    job_id: String,
}

impl FineTuningJobRetrievalBuilder {
    /// Creates a retrieval builder for the given job ID.
    #[must_use]
    pub fn new(job_id: impl Into<String>) -> Self {
        Self {
            job_id: job_id.into(),
        }
    }

    /// Returns the job ID.
    #[must_use]
    pub fn job_id(&self) -> &str {
        &self.job_id
    }
}

/// Builder for cancelling a fine-tuning job by ID.
#[derive(Debug, Clone)]
pub struct FineTuningJobCancelBuilder {
    job_id: String,
}

impl FineTuningJobCancelBuilder {
    /// Creates a cancel builder for the given job ID.
    #[must_use]
    pub fn new(job_id: impl Into<String>) -> Self {
        Self {
            job_id: job_id.into(),
        }
    }

    /// Returns the job ID.
    #[must_use]
    pub fn job_id(&self) -> &str {
        &self.job_id
    }
}

/// Creates a fine-tuning job builder for `base_model` trained on `training_file`.
#[must_use]
pub fn fine_tune_model(
    base_model: impl Into<String>,
    training_file: impl Into<String>,
) -> FineTuningJobBuilder {
    FineTuningJobBuilder::new(base_model, training_file)
}

/// Creates a fine-tuning job builder with both training and validation files.
#[must_use]
pub fn fine_tune_with_validation(
    base_model: impl Into<String>,
    training_file: impl Into<String>,
    validation_file: impl Into<String>,
) -> FineTuningJobBuilder {
    FineTuningJobBuilder::new(base_model, training_file).validation_file(validation_file)
}

/// Creates a fine-tuning job builder with an explicit epoch count and learning rate multiplier.
#[must_use]
pub fn fine_tune_with_params(
    base_model: impl Into<String>,
    training_file: impl Into<String>,
    epochs: i32,
    learning_rate: f64,
) -> FineTuningJobBuilder {
    FineTuningJobBuilder::new(base_model, training_file)
        .epochs(epochs)
        .learning_rate_multiplier(learning_rate)
}

/// Creates a builder for listing fine-tuning jobs.
#[must_use]
pub fn list_fine_tuning_jobs() -> FineTuningJobListBuilder {
    FineTuningJobListBuilder::new()
}

/// Creates a builder for retrieving a fine-tuning job by ID.
#[must_use]
pub fn get_fine_tuning_job(job_id: impl Into<String>) -> FineTuningJobRetrievalBuilder {
    FineTuningJobRetrievalBuilder::new(job_id)
}

/// Creates a builder for cancelling a fine-tuning job by ID.
#[must_use]
pub fn cancel_fine_tuning_job(job_id: impl Into<String>) -> FineTuningJobCancelBuilder {
    FineTuningJobCancelBuilder::new(job_id)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fine_tuning_job_builder_new() {
        let builder = FineTuningJobBuilder::new("gpt-3.5-turbo", "file-training");

        assert_eq!(builder.model(), "gpt-3.5-turbo");
        assert_eq!(builder.training_file(), "file-training");
        assert!(builder.validation_file_ref().is_none());
        assert!(builder.suffix_ref().is_none());
        assert!(builder.integrations().is_empty());
    }

    #[test]
    fn test_fine_tuning_job_builder_with_validation() {
        let builder = FineTuningJobBuilder::new("gpt-3.5-turbo", "file-training")
            .validation_file("file-validation");

        assert_eq!(builder.validation_file_ref(), Some("file-validation"));
    }

    #[test]
    fn test_fine_tuning_job_builder_with_hyperparameters() {
        let builder = FineTuningJobBuilder::new("gpt-3.5-turbo", "file-training")
            .epochs(3)
            .batch_size(16)
            .learning_rate_multiplier(0.1);

        assert_eq!(builder.hyperparameters().n_epochs, Some(3));
        assert_eq!(builder.hyperparameters().batch_size, Some(16));
        assert_eq!(
            builder.hyperparameters().learning_rate_multiplier,
            Some(0.1)
        );
    }

    #[test]
    fn test_fine_tuning_job_builder_with_suffix() {
        let builder =
            FineTuningJobBuilder::new("gpt-3.5-turbo", "file-training").suffix("my-model-v1");

        assert_eq!(builder.suffix_ref(), Some("my-model-v1"));
    }

    #[test]
    fn test_fine_tuning_job_builder_with_wandb() {
        let builder =
            FineTuningJobBuilder::new("gpt-3.5-turbo", "file-training").with_wandb("my-project");

        assert_eq!(builder.integrations().len(), 1);
        assert_eq!(builder.integrations()[0].integration_type, "wandb");
        assert_eq!(
            builder.integrations()[0].settings.get("project"),
            Some(&"my-project".to_string())
        );
    }
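
    // Added illustrative test (a sketch, not part of the original suite): it
    // chains every setter on `FineTuningJobBuilder` in one fluent call to show
    // that the builder accumulates all parameters independently. The model
    // name and file IDs are placeholder values.
    #[test]
    fn test_fine_tuning_job_builder_full_chain_sketch() {
        let builder = FineTuningJobBuilder::new("gpt-3.5-turbo", "file-training")
            .validation_file("file-validation")
            .epochs(4)
            .batch_size(8)
            .learning_rate_multiplier(0.05)
            .suffix("full-chain")
            .with_wandb("chain-project");

        assert_eq!(builder.model(), "gpt-3.5-turbo");
        assert_eq!(builder.validation_file_ref(), Some("file-validation"));
        assert_eq!(builder.hyperparameters().n_epochs, Some(4));
        assert_eq!(builder.hyperparameters().batch_size, Some(8));
        assert_eq!(
            builder.hyperparameters().learning_rate_multiplier,
            Some(0.05)
        );
        assert_eq!(builder.suffix_ref(), Some("full-chain"));
        assert_eq!(builder.integrations().len(), 1);
    }
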
    #[test]
    fn test_fine_tuning_job_list_builder() {
        let builder = FineTuningJobListBuilder::new().after("job-123").limit(10);

        assert_eq!(builder.after_ref(), Some("job-123"));
        assert_eq!(builder.limit_ref(), Some(10));
    }

    #[test]
    fn test_fine_tuning_job_retrieval_builder() {
        let builder = FineTuningJobRetrievalBuilder::new("job-456");
        assert_eq!(builder.job_id(), "job-456");
    }

    #[test]
    fn test_fine_tuning_job_cancel_builder() {
        let builder = FineTuningJobCancelBuilder::new("job-789");
        assert_eq!(builder.job_id(), "job-789");
    }

    #[test]
    fn test_fine_tune_model_helper() {
        let builder = fine_tune_model("gpt-3.5-turbo", "file-training");
        assert_eq!(builder.model(), "gpt-3.5-turbo");
        assert_eq!(builder.training_file(), "file-training");
    }

    #[test]
    fn test_fine_tune_with_validation_helper() {
        let builder =
            fine_tune_with_validation("gpt-3.5-turbo", "file-training", "file-validation");
        assert_eq!(builder.validation_file_ref(), Some("file-validation"));
    }

    #[test]
    fn test_fine_tune_with_params_helper() {
        let builder = fine_tune_with_params("gpt-3.5-turbo", "file-training", 5, 0.2);
        assert_eq!(builder.hyperparameters().n_epochs, Some(5));
        assert_eq!(
            builder.hyperparameters().learning_rate_multiplier,
            Some(0.2)
        );
    }

    #[test]
    fn test_list_fine_tuning_jobs_helper() {
        let builder = list_fine_tuning_jobs();
        assert!(builder.after_ref().is_none());
        assert!(builder.limit_ref().is_none());
    }

    #[test]
    fn test_get_fine_tuning_job_helper() {
        let builder = get_fine_tuning_job("job-123");
        assert_eq!(builder.job_id(), "job-123");
    }

    #[test]
    fn test_cancel_fine_tuning_job_helper() {
        let builder = cancel_fine_tuning_job("job-456");
        assert_eq!(builder.job_id(), "job-456");
    }

    #[test]
    fn test_fine_tuning_hyperparameters_default() {
        let params = FineTuningHyperparameters::default();
        assert!(params.n_epochs.is_none());
        assert!(params.batch_size.is_none());
        assert!(params.learning_rate_multiplier.is_none());
    }
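
    // Added illustrative test (a sketch, not part of the original suite): it
    // shows that the free helper functions return ordinary builders that can
    // be further configured with the fluent setters. All IDs are placeholders.
    #[test]
    fn test_helper_then_fluent_configuration_sketch() {
        let builder =
            fine_tune_with_validation("gpt-3.5-turbo", "file-training", "file-validation")
                .epochs(2)
                .suffix("helper-configured");

        assert_eq!(builder.training_file(), "file-training");
        assert_eq!(builder.validation_file_ref(), Some("file-validation"));
        assert_eq!(builder.hyperparameters().n_epochs, Some(2));
        assert_eq!(builder.suffix_ref(), Some("helper-configured"));
    }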
}