// openai_tools/fine_tuning/request.rs
//! OpenAI Fine-tuning API Request Module
//!
//! This module provides the functionality to interact with the OpenAI Fine-tuning API.
//! It allows you to create, list, retrieve, and cancel fine-tuning jobs, as well as
//! access training events and checkpoints.
//!
//! # Key Features
//!
//! - **Create Jobs**: Start a fine-tuning job with custom hyperparameters
//! - **Retrieve Jobs**: Get the status and details of a fine-tuning job
//! - **List Jobs**: List all fine-tuning jobs
//! - **Cancel Jobs**: Cancel an in-progress job
//! - **List Events**: View training progress and events
//! - **List Checkpoints**: Access model checkpoints from training
//!
//! # Quick Start
//!
//! ```rust,no_run
//! use openai_tools::fine_tuning::request::{FineTuning, CreateFineTuningJobRequest};
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let fine_tuning = FineTuning::new()?;
//!
//!     // List all fine-tuning jobs
//!     let response = fine_tuning.list(None, None).await?;
//!     for job in &response.data {
//!         println!("{}: {:?}", job.id, job.status);
//!     }
//!
//!     Ok(())
//! }
//! ```
35use crate::common::auth::AuthProvider;
36use crate::common::client::create_http_client;
37use crate::common::errors::{OpenAIToolError, Result};
38use crate::common::models::FineTuningModel;
39use crate::fine_tuning::response::{
40    DpoConfig, FineTuningCheckpointListResponse, FineTuningEventListResponse, FineTuningJob, FineTuningJobListResponse, Hyperparameters, Integration,
41    MethodConfig, SupervisedConfig,
42};
43use serde::Serialize;
44use std::time::Duration;
45
/// Default API path segment for the Fine-tuning jobs endpoint,
/// appended to the provider base URL by `AuthProvider::endpoint`.
const FINE_TUNING_PATH: &str = "fine_tuning/jobs";
48
/// Request payload for creating a new fine-tuning job.
///
/// Optional fields are omitted from the serialized JSON when unset
/// (`skip_serializing_if = "Option::is_none"`), so only explicitly
/// configured options are sent to the API.
#[derive(Debug, Clone, Serialize)]
pub struct CreateFineTuningJobRequest {
    /// The base model to fine-tune.
    pub model: FineTuningModel,

    /// The ID of the uploaded training file.
    pub training_file: String,

    /// The ID of the uploaded validation file (optional).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub validation_file: Option<String>,

    /// A string suffix for the fine-tuned model name (max 64 chars).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,

    /// A seed for reproducibility.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,

    /// The fine-tuning method and hyperparameters.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub method: Option<MethodConfig>,

    /// Integrations to enable (e.g., Weights & Biases).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub integrations: Option<Vec<Integration>>,
}
78
79impl CreateFineTuningJobRequest {
80    /// Creates a new fine-tuning job request with the given model and training file.
81    ///
82    /// # Arguments
83    ///
84    /// * `model` - The base model to fine-tune
85    /// * `training_file` - The ID of the uploaded training file
86    ///
87    /// # Example
88    ///
89    /// ```rust
90    /// use openai_tools::fine_tuning::request::CreateFineTuningJobRequest;
91    /// use openai_tools::common::models::FineTuningModel;
92    ///
93    /// let request = CreateFineTuningJobRequest::new(
94    ///     FineTuningModel::Gpt4oMini_2024_07_18,
95    ///     "file-abc123"
96    /// );
97    /// ```
98    pub fn new(model: FineTuningModel, training_file: impl Into<String>) -> Self {
99        Self { model, training_file: training_file.into(), validation_file: None, suffix: None, seed: None, method: None, integrations: None }
100    }
101
102    /// Sets the validation file for the job.
103    pub fn with_validation_file(mut self, file_id: impl Into<String>) -> Self {
104        self.validation_file = Some(file_id.into());
105        self
106    }
107
108    /// Sets the suffix for the fine-tuned model name.
109    pub fn with_suffix(mut self, suffix: impl Into<String>) -> Self {
110        self.suffix = Some(suffix.into());
111        self
112    }
113
114    /// Sets the seed for reproducibility.
115    pub fn with_seed(mut self, seed: u64) -> Self {
116        self.seed = Some(seed);
117        self
118    }
119
120    /// Configures supervised fine-tuning with custom hyperparameters.
121    pub fn with_supervised_method(mut self, hyperparameters: Option<Hyperparameters>) -> Self {
122        self.method = Some(MethodConfig { method_type: "supervised".to_string(), supervised: Some(SupervisedConfig { hyperparameters }), dpo: None });
123        self
124    }
125
126    /// Configures DPO (Direct Preference Optimization) fine-tuning.
127    pub fn with_dpo_method(mut self, hyperparameters: Option<Hyperparameters>) -> Self {
128        self.method = Some(MethodConfig { method_type: "dpo".to_string(), supervised: None, dpo: Some(DpoConfig { hyperparameters }) });
129        self
130    }
131
132    /// Adds integrations to the job.
133    pub fn with_integrations(mut self, integrations: Vec<Integration>) -> Self {
134        self.integrations = Some(integrations);
135        self
136    }
137}
138
/// Client for interacting with the OpenAI Fine-tuning API.
///
/// This struct provides methods to create, list, retrieve, and cancel fine-tuning jobs,
/// as well as access training events and checkpoints.
///
/// # Example
///
/// ```rust,no_run
/// use openai_tools::fine_tuning::request::{FineTuning, CreateFineTuningJobRequest};
/// use openai_tools::fine_tuning::response::Hyperparameters;
/// use openai_tools::common::models::FineTuningModel;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let fine_tuning = FineTuning::new()?;
///
///     // Create a fine-tuning job
///     let hyperparams = Hyperparameters {
///         n_epochs: Some(3),
///         batch_size: None,
///         learning_rate_multiplier: None,
///     };
///
///     let request = CreateFineTuningJobRequest::new(
///             FineTuningModel::Gpt4oMini_2024_07_18,
///             "file-abc123"
///         )
///         .with_suffix("my-custom-model")
///         .with_supervised_method(Some(hyperparams));
///
///     let job = fine_tuning.create(request).await?;
///     println!("Created job: {} ({:?})", job.id, job.status);
///
///     Ok(())
/// }
/// ```
pub struct FineTuning {
    /// Authentication provider (OpenAI or Azure)
    auth: AuthProvider,
    /// Optional request timeout duration; `None` uses the HTTP client default
    timeout: Option<Duration>,
}
181
182impl FineTuning {
183    /// Creates a new FineTuning client for OpenAI API.
184    ///
185    /// Initializes the client by loading the OpenAI API key from
186    /// the environment variable `OPENAI_API_KEY`. Supports `.env` file loading
187    /// via dotenvy.
188    ///
189    /// # Returns
190    ///
191    /// * `Ok(FineTuning)` - A new FineTuning client ready for use
192    /// * `Err(OpenAIToolError)` - If the API key is not found in the environment
193    ///
194    /// # Example
195    ///
196    /// ```rust,no_run
197    /// use openai_tools::fine_tuning::request::FineTuning;
198    ///
199    /// let fine_tuning = FineTuning::new().expect("API key should be set");
200    /// ```
201    pub fn new() -> Result<Self> {
202        let auth = AuthProvider::openai_from_env()?;
203        Ok(Self { auth, timeout: None })
204    }
205
206    /// Creates a new FineTuning client with a custom authentication provider
207    pub fn with_auth(auth: AuthProvider) -> Self {
208        Self { auth, timeout: None }
209    }
210
211    /// Creates a new FineTuning client for Azure OpenAI API
212    pub fn azure() -> Result<Self> {
213        let auth = AuthProvider::azure_from_env()?;
214        Ok(Self { auth, timeout: None })
215    }
216
217    /// Creates a new FineTuning client by auto-detecting the provider
218    pub fn detect_provider() -> Result<Self> {
219        let auth = AuthProvider::from_env()?;
220        Ok(Self { auth, timeout: None })
221    }
222
223    /// Creates a new FineTuning client with URL-based provider detection
224    pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
225        let auth = AuthProvider::from_url_with_key(base_url, api_key);
226        Self { auth, timeout: None }
227    }
228
229    /// Creates a new FineTuning client from URL using environment variables
230    pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
231        let auth = AuthProvider::from_url(url)?;
232        Ok(Self { auth, timeout: None })
233    }
234
235    /// Returns the authentication provider
236    pub fn auth(&self) -> &AuthProvider {
237        &self.auth
238    }
239
240    /// Sets the request timeout duration.
241    ///
242    /// # Arguments
243    ///
244    /// * `timeout` - The maximum time to wait for a response
245    ///
246    /// # Returns
247    ///
248    /// A mutable reference to self for method chaining
249    pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
250        self.timeout = Some(timeout);
251        self
252    }
253
254    /// Creates the HTTP client with default headers.
255    fn create_client(&self) -> Result<(request::Client, request::header::HeaderMap)> {
256        let client = create_http_client(self.timeout)?;
257        let mut headers = request::header::HeaderMap::new();
258        self.auth.apply_headers(&mut headers)?;
259        headers.insert("Content-Type", request::header::HeaderValue::from_static("application/json"));
260        headers.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust"));
261        Ok((client, headers))
262    }
263
264    /// Creates a new fine-tuning job.
265    ///
266    /// # Arguments
267    ///
268    /// * `request` - The fine-tuning job creation request
269    ///
270    /// # Returns
271    ///
272    /// * `Ok(FineTuningJob)` - The created job object
273    /// * `Err(OpenAIToolError)` - If the request fails
274    ///
275    /// # Example
276    ///
277    /// ```rust,no_run
278    /// use openai_tools::fine_tuning::request::{FineTuning, CreateFineTuningJobRequest};
279    /// use openai_tools::common::models::FineTuningModel;
280    ///
281    /// #[tokio::main]
282    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
283    ///     let fine_tuning = FineTuning::new()?;
284    ///
285    ///     let request = CreateFineTuningJobRequest::new(
286    ///             FineTuningModel::Gpt4oMini_2024_07_18,
287    ///             "file-abc123"
288    ///         )
289    ///         .with_suffix("my-model");
290    ///
291    ///     let job = fine_tuning.create(request).await?;
292    ///     println!("Created job: {}", job.id);
293    ///     Ok(())
294    /// }
295    /// ```
296    pub async fn create(&self, request: CreateFineTuningJobRequest) -> Result<FineTuningJob> {
297        let (client, headers) = self.create_client()?;
298
299        let body = serde_json::to_string(&request).map_err(OpenAIToolError::SerdeJsonError)?;
300
301        let url = self.auth.endpoint(FINE_TUNING_PATH);
302        let response = client.post(&url).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
303
304        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
305
306        if cfg!(test) {
307            tracing::info!("Response content: {}", content);
308        }
309
310        serde_json::from_str::<FineTuningJob>(&content).map_err(OpenAIToolError::SerdeJsonError)
311    }
312
313    /// Retrieves details of a specific fine-tuning job.
314    ///
315    /// # Arguments
316    ///
317    /// * `job_id` - The ID of the job to retrieve
318    ///
319    /// # Returns
320    ///
321    /// * `Ok(FineTuningJob)` - The job details
322    /// * `Err(OpenAIToolError)` - If the job is not found or the request fails
323    ///
324    /// # Example
325    ///
326    /// ```rust,no_run
327    /// use openai_tools::fine_tuning::request::FineTuning;
328    ///
329    /// #[tokio::main]
330    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
331    ///     let fine_tuning = FineTuning::new()?;
332    ///     let job = fine_tuning.retrieve("ftjob-abc123").await?;
333    ///
334    ///     println!("Status: {:?}", job.status);
335    ///     if let Some(model) = &job.fine_tuned_model {
336    ///         println!("Fine-tuned model: {}", model);
337    ///     }
338    ///     Ok(())
339    /// }
340    /// ```
341    pub async fn retrieve(&self, job_id: &str) -> Result<FineTuningJob> {
342        let (client, headers) = self.create_client()?;
343        let url = format!("{}/{}", self.auth.endpoint(FINE_TUNING_PATH), job_id);
344
345        let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
346
347        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
348
349        if cfg!(test) {
350            tracing::info!("Response content: {}", content);
351        }
352
353        serde_json::from_str::<FineTuningJob>(&content).map_err(OpenAIToolError::SerdeJsonError)
354    }
355
356    /// Cancels an in-progress fine-tuning job.
357    ///
358    /// # Arguments
359    ///
360    /// * `job_id` - The ID of the job to cancel
361    ///
362    /// # Returns
363    ///
364    /// * `Ok(FineTuningJob)` - The updated job object
365    /// * `Err(OpenAIToolError)` - If the job cannot be cancelled or the request fails
366    ///
367    /// # Example
368    ///
369    /// ```rust,no_run
370    /// use openai_tools::fine_tuning::request::FineTuning;
371    ///
372    /// #[tokio::main]
373    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
374    ///     let fine_tuning = FineTuning::new()?;
375    ///     let job = fine_tuning.cancel("ftjob-abc123").await?;
376    ///
377    ///     println!("Job status: {:?}", job.status);
378    ///     Ok(())
379    /// }
380    /// ```
381    pub async fn cancel(&self, job_id: &str) -> Result<FineTuningJob> {
382        let (client, headers) = self.create_client()?;
383        let url = format!("{}/{}/cancel", self.auth.endpoint(FINE_TUNING_PATH), job_id);
384
385        let response = client.post(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
386
387        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
388
389        if cfg!(test) {
390            tracing::info!("Response content: {}", content);
391        }
392
393        serde_json::from_str::<FineTuningJob>(&content).map_err(OpenAIToolError::SerdeJsonError)
394    }
395
396    /// Lists all fine-tuning jobs.
397    ///
398    /// Supports pagination through `limit` and `after` parameters.
399    ///
400    /// # Arguments
401    ///
402    /// * `limit` - Maximum number of jobs to return (default: 20)
403    /// * `after` - Cursor for pagination (job ID to start after)
404    ///
405    /// # Returns
406    ///
407    /// * `Ok(FineTuningJobListResponse)` - The list of jobs
408    /// * `Err(OpenAIToolError)` - If the request fails
409    ///
410    /// # Example
411    ///
412    /// ```rust,no_run
413    /// use openai_tools::fine_tuning::request::FineTuning;
414    ///
415    /// #[tokio::main]
416    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
417    ///     let fine_tuning = FineTuning::new()?;
418    ///
419    ///     let response = fine_tuning.list(Some(10), None).await?;
420    ///     for job in &response.data {
421    ///         println!("{}: {:?}", job.id, job.status);
422    ///     }
423    ///
424    ///     Ok(())
425    /// }
426    /// ```
427    pub async fn list(&self, limit: Option<u32>, after: Option<&str>) -> Result<FineTuningJobListResponse> {
428        let (client, headers) = self.create_client()?;
429
430        let mut url = self.auth.endpoint(FINE_TUNING_PATH);
431        let mut params = Vec::new();
432
433        if let Some(l) = limit {
434            params.push(format!("limit={}", l));
435        }
436        if let Some(a) = after {
437            params.push(format!("after={}", a));
438        }
439
440        if !params.is_empty() {
441            url.push('?');
442            url.push_str(&params.join("&"));
443        }
444
445        let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
446
447        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
448
449        if cfg!(test) {
450            tracing::info!("Response content: {}", content);
451        }
452
453        serde_json::from_str::<FineTuningJobListResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
454    }
455
456    /// Lists events for a fine-tuning job.
457    ///
458    /// Events provide insight into the training process.
459    ///
460    /// # Arguments
461    ///
462    /// * `job_id` - The ID of the fine-tuning job
463    /// * `limit` - Maximum number of events to return (default: 20)
464    /// * `after` - Cursor for pagination (event ID to start after)
465    ///
466    /// # Returns
467    ///
468    /// * `Ok(FineTuningEventListResponse)` - The list of events
469    /// * `Err(OpenAIToolError)` - If the request fails
470    ///
471    /// # Example
472    ///
473    /// ```rust,no_run
474    /// use openai_tools::fine_tuning::request::FineTuning;
475    ///
476    /// #[tokio::main]
477    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
478    ///     let fine_tuning = FineTuning::new()?;
479    ///
480    ///     let response = fine_tuning.list_events("ftjob-abc123", Some(10), None).await?;
481    ///     for event in &response.data {
482    ///         println!("[{}] {}: {}", event.level, event.event_type, event.message);
483    ///     }
484    ///
485    ///     Ok(())
486    /// }
487    /// ```
488    pub async fn list_events(&self, job_id: &str, limit: Option<u32>, after: Option<&str>) -> Result<FineTuningEventListResponse> {
489        let (client, headers) = self.create_client()?;
490
491        let mut url = format!("{}/{}/events", self.auth.endpoint(FINE_TUNING_PATH), job_id);
492        let mut params = Vec::new();
493
494        if let Some(l) = limit {
495            params.push(format!("limit={}", l));
496        }
497        if let Some(a) = after {
498            params.push(format!("after={}", a));
499        }
500
501        if !params.is_empty() {
502            url.push('?');
503            url.push_str(&params.join("&"));
504        }
505
506        let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
507
508        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
509
510        if cfg!(test) {
511            tracing::info!("Response content: {}", content);
512        }
513
514        serde_json::from_str::<FineTuningEventListResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
515    }
516
517    /// Lists checkpoints for a fine-tuning job.
518    ///
519    /// Checkpoints are saved at the end of each training epoch.
520    /// Only the last 3 checkpoints are available.
521    ///
522    /// # Arguments
523    ///
524    /// * `job_id` - The ID of the fine-tuning job
525    /// * `limit` - Maximum number of checkpoints to return (default: 10)
526    /// * `after` - Cursor for pagination (checkpoint ID to start after)
527    ///
528    /// # Returns
529    ///
530    /// * `Ok(FineTuningCheckpointListResponse)` - The list of checkpoints
531    /// * `Err(OpenAIToolError)` - If the request fails
532    ///
533    /// # Example
534    ///
535    /// ```rust,no_run
536    /// use openai_tools::fine_tuning::request::FineTuning;
537    ///
538    /// #[tokio::main]
539    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
540    ///     let fine_tuning = FineTuning::new()?;
541    ///
542    ///     let response = fine_tuning.list_checkpoints("ftjob-abc123", None, None).await?;
543    ///     for checkpoint in &response.data {
544    ///         println!("Step {}: loss={}", checkpoint.step_number, checkpoint.metrics.train_loss);
545    ///     }
546    ///
547    ///     Ok(())
548    /// }
549    /// ```
550    pub async fn list_checkpoints(&self, job_id: &str, limit: Option<u32>, after: Option<&str>) -> Result<FineTuningCheckpointListResponse> {
551        let (client, headers) = self.create_client()?;
552
553        let mut url = format!("{}/{}/checkpoints", self.auth.endpoint(FINE_TUNING_PATH), job_id);
554        let mut params = Vec::new();
555
556        if let Some(l) = limit {
557            params.push(format!("limit={}", l));
558        }
559        if let Some(a) = after {
560            params.push(format!("after={}", a));
561        }
562
563        if !params.is_empty() {
564            url.push('?');
565            url.push_str(&params.join("&"));
566        }
567
568        let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
569
570        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
571
572        if cfg!(test) {
573            tracing::info!("Response content: {}", content);
574        }
575
576        serde_json::from_str::<FineTuningCheckpointListResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
577    }
578}