// openai_ergonomic/builders/fine_tuning.rs

//! Fine-tuning API builders.
//!
//! This module provides ergonomic builders for `OpenAI` Fine-tuning API operations,
//! including creating fine-tuning jobs, listing and retrieving them to monitor
//! progress, and cancelling them.
//!
//! Fine-tuning allows you to customize models on your specific training data
//! to improve performance for your particular use case.

use std::collections::HashMap;

/// Builder for creating fine-tuning jobs.
///
/// Fine-tuning jobs train models on your specific data to improve performance
/// for your particular use case.
///
/// Every setter consumes and returns the builder so calls can be chained.
/// Setting the same option twice keeps the last value.
#[derive(Debug, Clone, PartialEq)]
pub struct FineTuningJobBuilder {
    /// Base model to fine-tune.
    model: String,
    /// Training file ID.
    training_file: String,
    /// Optional validation file ID.
    validation_file: Option<String>,
    /// Training hyperparameters; all unset until a setter is called.
    hyperparameters: FineTuningHyperparameters,
    /// Optional suffix for the fine-tuned model name.
    suffix: Option<String>,
    /// Integrations attached to the job, in insertion order.
    integrations: Vec<FineTuningIntegration>,
}

/// Hyperparameters for fine-tuning jobs.
///
/// Each field is optional; `None` means the caller did not set it.
/// Only `PartialEq` (not `Eq`) is derived because of the `f64` field.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct FineTuningHyperparameters {
    /// Number of epochs to train for
    pub n_epochs: Option<i32>,
    /// Batch size for training
    pub batch_size: Option<i32>,
    /// Learning rate multiplier
    pub learning_rate_multiplier: Option<f64>,
}

/// Integration for fine-tuning jobs (e.g., Weights & Biases).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FineTuningIntegration {
    /// Type of integration (for example `"wandb"`).
    pub integration_type: String,
    /// Integration settings as string key/value pairs.
    pub settings: HashMap<String, String>,
}

impl FineTuningJobBuilder {
    /// Create a new fine-tuning job builder.
    ///
    /// Only the base model and the training file ID are required. The
    /// validation file, hyperparameters and suffix start unset, and there are
    /// no integrations.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use openai_ergonomic::builders::fine_tuning::FineTuningJobBuilder;
    ///
    /// let builder = FineTuningJobBuilder::new("gpt-3.5-turbo", "file-training-data");
    /// ```
    #[must_use]
    pub fn new(model: impl Into<String>, training_file: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            training_file: training_file.into(),
            validation_file: None,
            hyperparameters: FineTuningHyperparameters::default(),
            suffix: None,
            integrations: Vec::new(),
        }
    }

    /// Set the validation file for the fine-tuning job.
    ///
    /// Replaces any previously set validation file ID.
    #[must_use]
    pub fn validation_file(mut self, file_id: impl Into<String>) -> Self {
        self.validation_file = Some(file_id.into());
        self
    }

    /// Set the number of epochs to train for.
    ///
    /// Stored in [`FineTuningHyperparameters::n_epochs`].
    #[must_use]
    pub fn epochs(mut self, epochs: i32) -> Self {
        self.hyperparameters.n_epochs = Some(epochs);
        self
    }

    /// Set the batch size for training.
    ///
    /// Stored in [`FineTuningHyperparameters::batch_size`].
    #[must_use]
    pub fn batch_size(mut self, batch_size: i32) -> Self {
        self.hyperparameters.batch_size = Some(batch_size);
        self
    }

    /// Set the learning rate multiplier.
    ///
    /// Stored in [`FineTuningHyperparameters::learning_rate_multiplier`].
    #[must_use]
    pub fn learning_rate_multiplier(mut self, multiplier: f64) -> Self {
        self.hyperparameters.learning_rate_multiplier = Some(multiplier);
        self
    }

    /// Set a suffix for the fine-tuned model name.
    ///
    /// Replaces any previously set suffix.
    #[must_use]
    pub fn suffix(mut self, suffix: impl Into<String>) -> Self {
        self.suffix = Some(suffix.into());
        self
    }

    /// Add a Weights & Biases integration.
    ///
    /// Appends an integration of type `"wandb"` whose only setting is
    /// `project`. Each call adds a new entry, and earlier integrations are kept.
    #[must_use]
    pub fn with_wandb(mut self, project: impl Into<String>) -> Self {
        let mut settings = HashMap::new();
        settings.insert("project".to_string(), project.into());

        self.integrations.push(FineTuningIntegration {
            integration_type: "wandb".to_string(),
            settings,
        });
        self
    }

    /// Get the base model for this fine-tuning job.
    #[must_use]
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Get the training file ID.
    #[must_use]
    pub fn training_file(&self) -> &str {
        &self.training_file
    }

    /// Get the validation file ID, or `None` if none was set.
    #[must_use]
    pub fn validation_file_ref(&self) -> Option<&str> {
        self.validation_file.as_deref()
    }

    /// Get the hyperparameters.
    #[must_use]
    pub fn hyperparameters(&self) -> &FineTuningHyperparameters {
        &self.hyperparameters
    }

    /// Get the model suffix, or `None` if none was set.
    #[must_use]
    pub fn suffix_ref(&self) -> Option<&str> {
        self.suffix.as_deref()
    }

    /// Get the integrations in the order they were added.
    #[must_use]
    pub fn integrations(&self) -> &[FineTuningIntegration] {
        &self.integrations
    }
}

/// Builder for listing fine-tuning jobs.
///
/// The pagination cursor and the result limit are both optional and start unset.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FineTuningJobListBuilder {
    /// Pagination cursor.
    after: Option<String>,
    /// Maximum number of jobs to return.
    limit: Option<i32>,
}

impl FineTuningJobListBuilder {
    /// Create a new fine-tuning job list builder with no cursor and no limit.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the cursor for pagination, replacing any previous cursor.
    #[must_use]
    pub fn after(mut self, cursor: impl Into<String>) -> Self {
        self.after = Some(cursor.into());
        self
    }

    /// Set the maximum number of jobs to return, replacing any previous limit.
    #[must_use]
    pub fn limit(mut self, limit: i32) -> Self {
        self.limit = Some(limit);
        self
    }

    /// Get the pagination cursor, or `None` if none was set.
    #[must_use]
    pub fn after_ref(&self) -> Option<&str> {
        self.after.as_deref()
    }

    /// Get the limit, or `None` if none was set.
    #[must_use]
    pub fn limit_ref(&self) -> Option<i32> {
        self.limit
    }
}

/// Builder for retrieving fine-tuning job details.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FineTuningJobRetrievalBuilder {
    /// ID of the job to retrieve.
    job_id: String,
}

impl FineTuningJobRetrievalBuilder {
    /// Create a new fine-tuning job retrieval builder for `job_id`.
    #[must_use]
    pub fn new(job_id: impl Into<String>) -> Self {
        Self {
            job_id: job_id.into(),
        }
    }

    /// Get the job ID.
    #[must_use]
    pub fn job_id(&self) -> &str {
        &self.job_id
    }
}

/// Builder for cancelling fine-tuning jobs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FineTuningJobCancelBuilder {
    /// ID of the job to cancel.
    job_id: String,
}

impl FineTuningJobCancelBuilder {
    /// Create a new fine-tuning job cancel builder for `job_id`.
    #[must_use]
    pub fn new(job_id: impl Into<String>) -> Self {
        Self {
            job_id: job_id.into(),
        }
    }

    /// Get the job ID.
    #[must_use]
    pub fn job_id(&self) -> &str {
        &self.job_id
    }
}

237/// Helper function to create a basic fine-tuning job.
238#[must_use]
239pub fn fine_tune_model(
240    base_model: impl Into<String>,
241    training_file: impl Into<String>,
242) -> FineTuningJobBuilder {
243    FineTuningJobBuilder::new(base_model, training_file)
244}
245
246/// Helper function to fine-tune with validation data.
247#[must_use]
248pub fn fine_tune_with_validation(
249    base_model: impl Into<String>,
250    training_file: impl Into<String>,
251    validation_file: impl Into<String>,
252) -> FineTuningJobBuilder {
253    FineTuningJobBuilder::new(base_model, training_file).validation_file(validation_file)
254}
255
256/// Helper function to create a fine-tuning job with custom hyperparameters.
257#[must_use]
258pub fn fine_tune_with_params(
259    base_model: impl Into<String>,
260    training_file: impl Into<String>,
261    epochs: i32,
262    learning_rate: f64,
263) -> FineTuningJobBuilder {
264    FineTuningJobBuilder::new(base_model, training_file)
265        .epochs(epochs)
266        .learning_rate_multiplier(learning_rate)
267}
268
269/// Helper function to list fine-tuning jobs.
270#[must_use]
271pub fn list_fine_tuning_jobs() -> FineTuningJobListBuilder {
272    FineTuningJobListBuilder::new()
273}
274
275/// Helper function to retrieve a specific fine-tuning job.
276#[must_use]
277pub fn get_fine_tuning_job(job_id: impl Into<String>) -> FineTuningJobRetrievalBuilder {
278    FineTuningJobRetrievalBuilder::new(job_id)
279}
280
281/// Helper function to cancel a fine-tuning job.
282#[must_use]
283pub fn cancel_fine_tuning_job(job_id: impl Into<String>) -> FineTuningJobCancelBuilder {
284    FineTuningJobCancelBuilder::new(job_id)
285}
286
#[cfg(test)]
mod tests {
    //! Unit tests for the fine-tuning builders and helper functions.

    use super::*;

    #[test]
    fn test_fine_tuning_job_builder_new() {
        let builder = FineTuningJobBuilder::new("gpt-3.5-turbo", "file-training");

        assert_eq!(builder.model(), "gpt-3.5-turbo");
        assert_eq!(builder.training_file(), "file-training");
        assert!(builder.validation_file_ref().is_none());
        assert!(builder.suffix_ref().is_none());
        assert!(builder.integrations().is_empty());
    }

    #[test]
    fn test_fine_tuning_job_builder_with_validation() {
        let builder = FineTuningJobBuilder::new("gpt-3.5-turbo", "file-training")
            .validation_file("file-validation");

        assert_eq!(builder.validation_file_ref(), Some("file-validation"));
    }

    #[test]
    fn test_fine_tuning_job_builder_with_hyperparameters() {
        let builder = FineTuningJobBuilder::new("gpt-3.5-turbo", "file-training")
            .epochs(3)
            .batch_size(16)
            .learning_rate_multiplier(0.1);

        assert_eq!(builder.hyperparameters().n_epochs, Some(3));
        assert_eq!(builder.hyperparameters().batch_size, Some(16));
        assert_eq!(
            builder.hyperparameters().learning_rate_multiplier,
            Some(0.1)
        );
    }

    #[test]
    fn test_fine_tuning_job_builder_with_suffix() {
        let builder =
            FineTuningJobBuilder::new("gpt-3.5-turbo", "file-training").suffix("my-model-v1");

        assert_eq!(builder.suffix_ref(), Some("my-model-v1"));
    }

    #[test]
    fn test_fine_tuning_job_builder_with_wandb() {
        let builder =
            FineTuningJobBuilder::new("gpt-3.5-turbo", "file-training").with_wandb("my-project");

        assert_eq!(builder.integrations().len(), 1);
        assert_eq!(builder.integrations()[0].integration_type, "wandb");
        assert_eq!(
            builder.integrations()[0].settings.get("project"),
            Some(&"my-project".to_string())
        );
    }

    // Setting an option twice must keep the most recent value.
    #[test]
    fn test_fine_tuning_job_builder_setters_override_previous_values() {
        let builder = FineTuningJobBuilder::new("gpt-3.5-turbo", "file-training")
            .epochs(1)
            .epochs(4)
            .suffix("old")
            .suffix("new")
            .validation_file("file-a")
            .validation_file("file-b");

        assert_eq!(builder.hyperparameters().n_epochs, Some(4));
        assert_eq!(builder.suffix_ref(), Some("new"));
        assert_eq!(builder.validation_file_ref(), Some("file-b"));
    }

    // `with_wandb` pushes onto the integration list, so repeated calls accumulate.
    #[test]
    fn test_fine_tuning_job_builder_with_wandb_appends() {
        let builder = FineTuningJobBuilder::new("gpt-3.5-turbo", "file-training")
            .with_wandb("project-a")
            .with_wandb("project-b");

        let projects: Vec<Option<&str>> = builder
            .integrations()
            .iter()
            .map(|integration| integration.settings.get("project").map(String::as_str))
            .collect();
        assert_eq!(projects, vec![Some("project-a"), Some("project-b")]);
        assert!(builder
            .integrations()
            .iter()
            .all(|integration| integration.integration_type == "wandb"));
    }

    #[test]
    fn test_fine_tuning_job_list_builder() {
        let builder = FineTuningJobListBuilder::new().after("job-123").limit(10);

        assert_eq!(builder.after_ref(), Some("job-123"));
        assert_eq!(builder.limit_ref(), Some(10));
    }

    #[test]
    fn test_fine_tuning_job_retrieval_builder() {
        let builder = FineTuningJobRetrievalBuilder::new("job-456");
        assert_eq!(builder.job_id(), "job-456");
    }

    #[test]
    fn test_fine_tuning_job_cancel_builder() {
        let builder = FineTuningJobCancelBuilder::new("job-789");
        assert_eq!(builder.job_id(), "job-789");
    }

    #[test]
    fn test_fine_tune_model_helper() {
        let builder = fine_tune_model("gpt-3.5-turbo", "file-training");
        assert_eq!(builder.model(), "gpt-3.5-turbo");
        assert_eq!(builder.training_file(), "file-training");
    }

    // The basic helper must not set any optional field.
    #[test]
    fn test_fine_tune_model_helper_leaves_options_unset() {
        let builder = fine_tune_model("gpt-3.5-turbo", "file-training");
        let params = builder.hyperparameters();

        assert!(params.n_epochs.is_none());
        assert!(params.batch_size.is_none());
        assert!(params.learning_rate_multiplier.is_none());
        assert!(builder.validation_file_ref().is_none());
        assert!(builder.suffix_ref().is_none());
        assert!(builder.integrations().is_empty());
    }

    #[test]
    fn test_fine_tune_with_validation_helper() {
        let builder =
            fine_tune_with_validation("gpt-3.5-turbo", "file-training", "file-validation");
        assert_eq!(builder.validation_file_ref(), Some("file-validation"));
    }

    #[test]
    fn test_fine_tune_with_params_helper() {
        let builder = fine_tune_with_params("gpt-3.5-turbo", "file-training", 5, 0.2);
        assert_eq!(builder.hyperparameters().n_epochs, Some(5));
        assert_eq!(
            builder.hyperparameters().learning_rate_multiplier,
            Some(0.2)
        );
    }

    // Only epochs and the learning-rate multiplier are set by this helper.
    #[test]
    fn test_fine_tune_with_params_helper_leaves_batch_size_unset() {
        let builder = fine_tune_with_params("gpt-3.5-turbo", "file-training", 5, 0.2);
        assert!(builder.hyperparameters().batch_size.is_none());
        assert!(builder.validation_file_ref().is_none());
    }

    #[test]
    fn test_list_fine_tuning_jobs_helper() {
        let builder = list_fine_tuning_jobs();
        assert!(builder.after_ref().is_none());
        assert!(builder.limit_ref().is_none());
    }

    #[test]
    fn test_get_fine_tuning_job_helper() {
        let builder = get_fine_tuning_job("job-123");
        assert_eq!(builder.job_id(), "job-123");
    }

    #[test]
    fn test_cancel_fine_tuning_job_helper() {
        let builder = cancel_fine_tuning_job("job-456");
        assert_eq!(builder.job_id(), "job-456");
    }

    #[test]
    fn test_fine_tuning_hyperparameters_default() {
        let params = FineTuningHyperparameters::default();
        assert!(params.n_epochs.is_none());
        assert!(params.batch_size.is_none());
        assert!(params.learning_rate_multiplier.is_none());
    }
}