openai_tools/fine_tuning/request.rs
//! OpenAI Fine-tuning API Request Module
//!
//! This module provides the functionality to interact with the OpenAI Fine-tuning API.
//! It allows you to create, list, retrieve, and cancel fine-tuning jobs, as well as
//! access training events and checkpoints.
//!
//! # Key Features
//!
//! - **Create Jobs**: Start a fine-tuning job with custom hyperparameters
//! - **Retrieve Jobs**: Get the status and details of a fine-tuning job
//! - **List Jobs**: List fine-tuning jobs, one page at a time
//! - **Cancel Jobs**: Cancel an in-progress job
//! - **List Events**: View training progress and events
//! - **List Checkpoints**: Access model checkpoints from training
//!
//! # Quick Start
//!
//! ```rust,no_run
//! use openai_tools::fine_tuning::request::{FineTuning, CreateFineTuningJobRequest};
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let fine_tuning = FineTuning::new()?;
//!
//!     // List the first page of fine-tuning jobs
//!     let response = fine_tuning.list(None, None).await?;
//!     for job in &response.data {
//!         println!("{}: {:?}", job.id, job.status);
//!     }
//!
//!     Ok(())
//! }
//! ```
33//! ```
34
35use crate::common::auth::AuthProvider;
36use crate::common::client::create_http_client;
37use crate::common::errors::{OpenAIToolError, Result};
38use crate::common::models::FineTuningModel;
39use crate::fine_tuning::response::{
40 DpoConfig, FineTuningCheckpointListResponse, FineTuningEventListResponse, FineTuningJob, FineTuningJobListResponse, Hyperparameters, Integration,
41 MethodConfig, SupervisedConfig,
42};
43use serde::Serialize;
44use std::time::Duration;
45
46/// Default API path for Fine-tuning
47const FINE_TUNING_PATH: &str = "fine_tuning/jobs";
48
49/// Request to create a new fine-tuning job.
50#[derive(Debug, Clone, Serialize)]
51pub struct CreateFineTuningJobRequest {
52 /// The base model to fine-tune.
53 pub model: FineTuningModel,
54
55 /// The ID of the uploaded training file.
56 pub training_file: String,
57
58 /// The ID of the uploaded validation file (optional).
59 #[serde(skip_serializing_if = "Option::is_none")]
60 pub validation_file: Option<String>,
61
62 /// A string suffix for the fine-tuned model name (max 64 chars).
63 #[serde(skip_serializing_if = "Option::is_none")]
64 pub suffix: Option<String>,
65
66 /// A seed for reproducibility.
67 #[serde(skip_serializing_if = "Option::is_none")]
68 pub seed: Option<u64>,
69
70 /// The fine-tuning method and hyperparameters.
71 #[serde(skip_serializing_if = "Option::is_none")]
72 pub method: Option<MethodConfig>,
73
74 /// Integrations to enable (e.g., Weights & Biases).
75 #[serde(skip_serializing_if = "Option::is_none")]
76 pub integrations: Option<Vec<Integration>>,
77}
78
79impl CreateFineTuningJobRequest {
80 /// Creates a new fine-tuning job request with the given model and training file.
81 ///
82 /// # Arguments
83 ///
84 /// * `model` - The base model to fine-tune
85 /// * `training_file` - The ID of the uploaded training file
86 ///
87 /// # Example
88 ///
89 /// ```rust
90 /// use openai_tools::fine_tuning::request::CreateFineTuningJobRequest;
91 /// use openai_tools::common::models::FineTuningModel;
92 ///
93 /// let request = CreateFineTuningJobRequest::new(
94 /// FineTuningModel::Gpt4oMini_2024_07_18,
95 /// "file-abc123"
96 /// );
97 /// ```
98 pub fn new(model: FineTuningModel, training_file: impl Into<String>) -> Self {
99 Self { model, training_file: training_file.into(), validation_file: None, suffix: None, seed: None, method: None, integrations: None }
100 }
101
102 /// Sets the validation file for the job.
103 pub fn with_validation_file(mut self, file_id: impl Into<String>) -> Self {
104 self.validation_file = Some(file_id.into());
105 self
106 }
107
108 /// Sets the suffix for the fine-tuned model name.
109 pub fn with_suffix(mut self, suffix: impl Into<String>) -> Self {
110 self.suffix = Some(suffix.into());
111 self
112 }
113
114 /// Sets the seed for reproducibility.
115 pub fn with_seed(mut self, seed: u64) -> Self {
116 self.seed = Some(seed);
117 self
118 }
119
120 /// Configures supervised fine-tuning with custom hyperparameters.
121 pub fn with_supervised_method(mut self, hyperparameters: Option<Hyperparameters>) -> Self {
122 self.method = Some(MethodConfig { method_type: "supervised".to_string(), supervised: Some(SupervisedConfig { hyperparameters }), dpo: None });
123 self
124 }
125
126 /// Configures DPO (Direct Preference Optimization) fine-tuning.
127 pub fn with_dpo_method(mut self, hyperparameters: Option<Hyperparameters>) -> Self {
128 self.method = Some(MethodConfig { method_type: "dpo".to_string(), supervised: None, dpo: Some(DpoConfig { hyperparameters }) });
129 self
130 }
131
132 /// Adds integrations to the job.
133 pub fn with_integrations(mut self, integrations: Vec<Integration>) -> Self {
134 self.integrations = Some(integrations);
135 self
136 }
137}
138
139/// Client for interacting with the OpenAI Fine-tuning API.
140///
141/// This struct provides methods to create, list, retrieve, and cancel fine-tuning jobs,
142/// as well as access training events and checkpoints.
143///
144/// # Example
145///
146/// ```rust,no_run
147/// use openai_tools::fine_tuning::request::{FineTuning, CreateFineTuningJobRequest};
148/// use openai_tools::fine_tuning::response::Hyperparameters;
149/// use openai_tools::common::models::FineTuningModel;
150///
151/// #[tokio::main]
152/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
153/// let fine_tuning = FineTuning::new()?;
154///
155/// // Create a fine-tuning job
156/// let hyperparams = Hyperparameters {
157/// n_epochs: Some(3),
158/// batch_size: None,
159/// learning_rate_multiplier: None,
160/// };
161///
162/// let request = CreateFineTuningJobRequest::new(
163/// FineTuningModel::Gpt4oMini_2024_07_18,
164/// "file-abc123"
165/// )
166/// .with_suffix("my-custom-model")
167/// .with_supervised_method(Some(hyperparams));
168///
169/// let job = fine_tuning.create(request).await?;
170/// println!("Created job: {} ({:?})", job.id, job.status);
171///
172/// Ok(())
173/// }
174/// ```
175pub struct FineTuning {
176 /// Authentication provider (OpenAI or Azure)
177 auth: AuthProvider,
178 /// Optional request timeout duration
179 timeout: Option<Duration>,
180}
181
182impl FineTuning {
183 /// Creates a new FineTuning client for OpenAI API.
184 ///
185 /// Initializes the client by loading the OpenAI API key from
186 /// the environment variable `OPENAI_API_KEY`. Supports `.env` file loading
187 /// via dotenvy.
188 ///
189 /// # Returns
190 ///
191 /// * `Ok(FineTuning)` - A new FineTuning client ready for use
192 /// * `Err(OpenAIToolError)` - If the API key is not found in the environment
193 ///
194 /// # Example
195 ///
196 /// ```rust,no_run
197 /// use openai_tools::fine_tuning::request::FineTuning;
198 ///
199 /// let fine_tuning = FineTuning::new().expect("API key should be set");
200 /// ```
201 pub fn new() -> Result<Self> {
202 let auth = AuthProvider::openai_from_env()?;
203 Ok(Self { auth, timeout: None })
204 }
205
206 /// Creates a new FineTuning client with a custom authentication provider
207 pub fn with_auth(auth: AuthProvider) -> Self {
208 Self { auth, timeout: None }
209 }
210
211 /// Creates a new FineTuning client for Azure OpenAI API
212 pub fn azure() -> Result<Self> {
213 let auth = AuthProvider::azure_from_env()?;
214 Ok(Self { auth, timeout: None })
215 }
216
217 /// Creates a new FineTuning client by auto-detecting the provider
218 pub fn detect_provider() -> Result<Self> {
219 let auth = AuthProvider::from_env()?;
220 Ok(Self { auth, timeout: None })
221 }
222
223 /// Creates a new FineTuning client with URL-based provider detection
224 pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
225 let auth = AuthProvider::from_url_with_key(base_url, api_key);
226 Self { auth, timeout: None }
227 }
228
229 /// Creates a new FineTuning client from URL using environment variables
230 pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
231 let auth = AuthProvider::from_url(url)?;
232 Ok(Self { auth, timeout: None })
233 }
234
235 /// Returns the authentication provider
236 pub fn auth(&self) -> &AuthProvider {
237 &self.auth
238 }
239
240 /// Sets the request timeout duration.
241 ///
242 /// # Arguments
243 ///
244 /// * `timeout` - The maximum time to wait for a response
245 ///
246 /// # Returns
247 ///
248 /// A mutable reference to self for method chaining
249 pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
250 self.timeout = Some(timeout);
251 self
252 }
253
254 /// Creates the HTTP client with default headers.
255 fn create_client(&self) -> Result<(request::Client, request::header::HeaderMap)> {
256 let client = create_http_client(self.timeout)?;
257 let mut headers = request::header::HeaderMap::new();
258 self.auth.apply_headers(&mut headers)?;
259 headers.insert("Content-Type", request::header::HeaderValue::from_static("application/json"));
260 headers.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust"));
261 Ok((client, headers))
262 }
263
264 /// Creates a new fine-tuning job.
265 ///
266 /// # Arguments
267 ///
268 /// * `request` - The fine-tuning job creation request
269 ///
270 /// # Returns
271 ///
272 /// * `Ok(FineTuningJob)` - The created job object
273 /// * `Err(OpenAIToolError)` - If the request fails
274 ///
275 /// # Example
276 ///
277 /// ```rust,no_run
278 /// use openai_tools::fine_tuning::request::{FineTuning, CreateFineTuningJobRequest};
279 /// use openai_tools::common::models::FineTuningModel;
280 ///
281 /// #[tokio::main]
282 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
283 /// let fine_tuning = FineTuning::new()?;
284 ///
285 /// let request = CreateFineTuningJobRequest::new(
286 /// FineTuningModel::Gpt4oMini_2024_07_18,
287 /// "file-abc123"
288 /// )
289 /// .with_suffix("my-model");
290 ///
291 /// let job = fine_tuning.create(request).await?;
292 /// println!("Created job: {}", job.id);
293 /// Ok(())
294 /// }
295 /// ```
296 pub async fn create(&self, request: CreateFineTuningJobRequest) -> Result<FineTuningJob> {
297 let (client, headers) = self.create_client()?;
298
299 let body = serde_json::to_string(&request).map_err(OpenAIToolError::SerdeJsonError)?;
300
301 let url = self.auth.endpoint(FINE_TUNING_PATH);
302 let response = client.post(&url).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
303
304 let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
305
306 if cfg!(test) {
307 tracing::info!("Response content: {}", content);
308 }
309
310 serde_json::from_str::<FineTuningJob>(&content).map_err(OpenAIToolError::SerdeJsonError)
311 }
312
313 /// Retrieves details of a specific fine-tuning job.
314 ///
315 /// # Arguments
316 ///
317 /// * `job_id` - The ID of the job to retrieve
318 ///
319 /// # Returns
320 ///
321 /// * `Ok(FineTuningJob)` - The job details
322 /// * `Err(OpenAIToolError)` - If the job is not found or the request fails
323 ///
324 /// # Example
325 ///
326 /// ```rust,no_run
327 /// use openai_tools::fine_tuning::request::FineTuning;
328 ///
329 /// #[tokio::main]
330 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
331 /// let fine_tuning = FineTuning::new()?;
332 /// let job = fine_tuning.retrieve("ftjob-abc123").await?;
333 ///
334 /// println!("Status: {:?}", job.status);
335 /// if let Some(model) = &job.fine_tuned_model {
336 /// println!("Fine-tuned model: {}", model);
337 /// }
338 /// Ok(())
339 /// }
340 /// ```
341 pub async fn retrieve(&self, job_id: &str) -> Result<FineTuningJob> {
342 let (client, headers) = self.create_client()?;
343 let url = format!("{}/{}", self.auth.endpoint(FINE_TUNING_PATH), job_id);
344
345 let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
346
347 let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
348
349 if cfg!(test) {
350 tracing::info!("Response content: {}", content);
351 }
352
353 serde_json::from_str::<FineTuningJob>(&content).map_err(OpenAIToolError::SerdeJsonError)
354 }
355
356 /// Cancels an in-progress fine-tuning job.
357 ///
358 /// # Arguments
359 ///
360 /// * `job_id` - The ID of the job to cancel
361 ///
362 /// # Returns
363 ///
364 /// * `Ok(FineTuningJob)` - The updated job object
365 /// * `Err(OpenAIToolError)` - If the job cannot be cancelled or the request fails
366 ///
367 /// # Example
368 ///
369 /// ```rust,no_run
370 /// use openai_tools::fine_tuning::request::FineTuning;
371 ///
372 /// #[tokio::main]
373 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
374 /// let fine_tuning = FineTuning::new()?;
375 /// let job = fine_tuning.cancel("ftjob-abc123").await?;
376 ///
377 /// println!("Job status: {:?}", job.status);
378 /// Ok(())
379 /// }
380 /// ```
381 pub async fn cancel(&self, job_id: &str) -> Result<FineTuningJob> {
382 let (client, headers) = self.create_client()?;
383 let url = format!("{}/{}/cancel", self.auth.endpoint(FINE_TUNING_PATH), job_id);
384
385 let response = client.post(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
386
387 let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
388
389 if cfg!(test) {
390 tracing::info!("Response content: {}", content);
391 }
392
393 serde_json::from_str::<FineTuningJob>(&content).map_err(OpenAIToolError::SerdeJsonError)
394 }
395
396 /// Lists all fine-tuning jobs.
397 ///
398 /// Supports pagination through `limit` and `after` parameters.
399 ///
400 /// # Arguments
401 ///
402 /// * `limit` - Maximum number of jobs to return (default: 20)
403 /// * `after` - Cursor for pagination (job ID to start after)
404 ///
405 /// # Returns
406 ///
407 /// * `Ok(FineTuningJobListResponse)` - The list of jobs
408 /// * `Err(OpenAIToolError)` - If the request fails
409 ///
410 /// # Example
411 ///
412 /// ```rust,no_run
413 /// use openai_tools::fine_tuning::request::FineTuning;
414 ///
415 /// #[tokio::main]
416 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
417 /// let fine_tuning = FineTuning::new()?;
418 ///
419 /// let response = fine_tuning.list(Some(10), None).await?;
420 /// for job in &response.data {
421 /// println!("{}: {:?}", job.id, job.status);
422 /// }
423 ///
424 /// Ok(())
425 /// }
426 /// ```
427 pub async fn list(&self, limit: Option<u32>, after: Option<&str>) -> Result<FineTuningJobListResponse> {
428 let (client, headers) = self.create_client()?;
429
430 let mut url = self.auth.endpoint(FINE_TUNING_PATH);
431 let mut params = Vec::new();
432
433 if let Some(l) = limit {
434 params.push(format!("limit={}", l));
435 }
436 if let Some(a) = after {
437 params.push(format!("after={}", a));
438 }
439
440 if !params.is_empty() {
441 url.push('?');
442 url.push_str(¶ms.join("&"));
443 }
444
445 let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
446
447 let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
448
449 if cfg!(test) {
450 tracing::info!("Response content: {}", content);
451 }
452
453 serde_json::from_str::<FineTuningJobListResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
454 }
455
456 /// Lists events for a fine-tuning job.
457 ///
458 /// Events provide insight into the training process.
459 ///
460 /// # Arguments
461 ///
462 /// * `job_id` - The ID of the fine-tuning job
463 /// * `limit` - Maximum number of events to return (default: 20)
464 /// * `after` - Cursor for pagination (event ID to start after)
465 ///
466 /// # Returns
467 ///
468 /// * `Ok(FineTuningEventListResponse)` - The list of events
469 /// * `Err(OpenAIToolError)` - If the request fails
470 ///
471 /// # Example
472 ///
473 /// ```rust,no_run
474 /// use openai_tools::fine_tuning::request::FineTuning;
475 ///
476 /// #[tokio::main]
477 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
478 /// let fine_tuning = FineTuning::new()?;
479 ///
480 /// let response = fine_tuning.list_events("ftjob-abc123", Some(10), None).await?;
481 /// for event in &response.data {
482 /// println!("[{}] {}: {}", event.level, event.event_type, event.message);
483 /// }
484 ///
485 /// Ok(())
486 /// }
487 /// ```
488 pub async fn list_events(&self, job_id: &str, limit: Option<u32>, after: Option<&str>) -> Result<FineTuningEventListResponse> {
489 let (client, headers) = self.create_client()?;
490
491 let mut url = format!("{}/{}/events", self.auth.endpoint(FINE_TUNING_PATH), job_id);
492 let mut params = Vec::new();
493
494 if let Some(l) = limit {
495 params.push(format!("limit={}", l));
496 }
497 if let Some(a) = after {
498 params.push(format!("after={}", a));
499 }
500
501 if !params.is_empty() {
502 url.push('?');
503 url.push_str(¶ms.join("&"));
504 }
505
506 let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
507
508 let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
509
510 if cfg!(test) {
511 tracing::info!("Response content: {}", content);
512 }
513
514 serde_json::from_str::<FineTuningEventListResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
515 }
516
517 /// Lists checkpoints for a fine-tuning job.
518 ///
519 /// Checkpoints are saved at the end of each training epoch.
520 /// Only the last 3 checkpoints are available.
521 ///
522 /// # Arguments
523 ///
524 /// * `job_id` - The ID of the fine-tuning job
525 /// * `limit` - Maximum number of checkpoints to return (default: 10)
526 /// * `after` - Cursor for pagination (checkpoint ID to start after)
527 ///
528 /// # Returns
529 ///
530 /// * `Ok(FineTuningCheckpointListResponse)` - The list of checkpoints
531 /// * `Err(OpenAIToolError)` - If the request fails
532 ///
533 /// # Example
534 ///
535 /// ```rust,no_run
536 /// use openai_tools::fine_tuning::request::FineTuning;
537 ///
538 /// #[tokio::main]
539 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
540 /// let fine_tuning = FineTuning::new()?;
541 ///
542 /// let response = fine_tuning.list_checkpoints("ftjob-abc123", None, None).await?;
543 /// for checkpoint in &response.data {
544 /// println!("Step {}: loss={}", checkpoint.step_number, checkpoint.metrics.train_loss);
545 /// }
546 ///
547 /// Ok(())
548 /// }
549 /// ```
550 pub async fn list_checkpoints(&self, job_id: &str, limit: Option<u32>, after: Option<&str>) -> Result<FineTuningCheckpointListResponse> {
551 let (client, headers) = self.create_client()?;
552
553 let mut url = format!("{}/{}/checkpoints", self.auth.endpoint(FINE_TUNING_PATH), job_id);
554 let mut params = Vec::new();
555
556 if let Some(l) = limit {
557 params.push(format!("limit={}", l));
558 }
559 if let Some(a) = after {
560 params.push(format!("after={}", a));
561 }
562
563 if !params.is_empty() {
564 url.push('?');
565 url.push_str(¶ms.join("&"));
566 }
567
568 let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
569
570 let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
571
572 if cfg!(test) {
573 tracing::info!("Response content: {}", content);
574 }
575
576 serde_json::from_str::<FineTuningCheckpointListResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
577 }
578}