1use serde::{Deserialize, Serialize};
45
46use crate::api::{get_json, parse_error_response, post_json};
47use crate::config::OpenAIClient;
48use crate::error::OpenAIError;
49
50#[derive(Debug, Serialize, Default, Clone)]
57pub struct CreateFineTuneRequest {
58 pub training_file: String,
63
64 #[serde(skip_serializing_if = "Option::is_none")]
67 pub validation_file: Option<String>,
68
69 #[serde(skip_serializing_if = "Option::is_none")]
72 pub model: Option<String>,
73
74 #[serde(skip_serializing_if = "Option::is_none")]
76 pub n_epochs: Option<u32>,
77
78 #[serde(skip_serializing_if = "Option::is_none")]
80 pub batch_size: Option<u32>,
81
82 #[serde(skip_serializing_if = "Option::is_none")]
85 pub learning_rate_multiplier: Option<f64>,
86
87 #[serde(skip_serializing_if = "Option::is_none")]
89 pub prompt_loss_weight: Option<f64>,
90
91 #[serde(skip_serializing_if = "Option::is_none")]
94 pub compute_classification_metrics: Option<bool>,
95
96 #[serde(skip_serializing_if = "Option::is_none")]
98 pub classification_n_classes: Option<u32>,
99
100 #[serde(skip_serializing_if = "Option::is_none")]
102 pub classification_positive_class: Option<String>,
103
104 #[serde(skip_serializing_if = "Option::is_none")]
106 pub classification_betas: Option<Vec<f64>>,
107
108 #[serde(skip_serializing_if = "Option::is_none")]
110 pub suffix: Option<String>,
111}
112
113#[derive(Debug, Deserialize)]
115pub struct FineTune {
116 pub id: String,
118 pub object: String,
120 pub created_at: u64,
122 pub updated_at: u64,
124 pub model: String,
126 pub fine_tuned_model: Option<String>,
128 pub status: String,
130 #[serde(default)]
132 pub events: Vec<FineTuneEvent>,
133}
134
135#[derive(Debug, Deserialize)]
137pub struct FineTuneEvent {
138 pub object: String,
140 pub created_at: u64,
142 pub level: String,
144 pub message: String,
146}
147
148#[derive(Debug, Deserialize)]
150pub struct FineTuneList {
151 pub object: String,
153 pub data: Vec<FineTune>,
155}
156
157pub async fn create_fine_tune(
174 client: &OpenAIClient,
175 request: &CreateFineTuneRequest,
176) -> Result<FineTune, OpenAIError> {
177 let endpoint = "fine-tunes";
178 post_json(client, endpoint, request).await
179}
180
181pub async fn list_fine_tunes(client: &OpenAIClient) -> Result<FineTuneList, OpenAIError> {
193 let endpoint = "fine-tunes";
194 get_json(client, endpoint).await
195}
196
197pub async fn retrieve_fine_tune(
213 client: &OpenAIClient,
214 fine_tune_id: &str,
215) -> Result<FineTune, OpenAIError> {
216 let endpoint = format!("fine-tunes/{}", fine_tune_id);
217 get_json(client, &endpoint).await
218}
219
220pub async fn cancel_fine_tune(
236 client: &OpenAIClient,
237 fine_tune_id: &str,
238) -> Result<FineTune, OpenAIError> {
239 let endpoint = format!("fine-tunes/{}/cancel", fine_tune_id);
240 post_json::<(), FineTune>(client, &endpoint, &()).await
241}
242
243pub async fn list_fine_tune_events(
259 client: &OpenAIClient,
260 fine_tune_id: &str,
261) -> Result<FineTuneEventsList, OpenAIError> {
262 let endpoint = format!("fine-tunes/{}/events", fine_tune_id);
263 get_json(client, &endpoint).await
264}
265
266#[derive(Debug, Deserialize)]
268pub struct FineTuneEventsList {
269 pub object: String,
271 pub data: Vec<FineTuneEvent>,
273}
274
275pub async fn delete_fine_tune_model(
285 client: &OpenAIClient,
286 model: &str,
287) -> Result<DeleteFineTuneModelResponse, OpenAIError> {
288 let endpoint = format!("models/{}", model);
290 let url = format!("{}/{}", client.base_url().trim_end_matches('/'), endpoint);
291
292 let response = client
293 .http_client
294 .delete(&url)
295 .bearer_auth(client.api_key())
296 .send()
297 .await?; if !response.status().is_success() {
301 return Err(parse_error_response(response).await?);
303 }
304
305 let response_body = response.json::<DeleteFineTuneModelResponse>().await?;
307 Ok(response_body)
308}
309#[derive(Debug, Deserialize)]
311pub struct DeleteFineTuneModelResponse {
312 pub object: String,
314 pub id: String,
316 pub deleted: bool,
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::OpenAIClient;
    use crate::error::OpenAIError;
    use serde_json::json;
    use wiremock::matchers::{method, path, path_regex};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a client with a dummy API key whose base URL points at `server`.
    fn client_for(server: &MockServer) -> OpenAIClient {
        OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&server.uri())
            .build()
            .unwrap()
    }

    /// OpenAI-style error envelope with the given message and type and a `null` code.
    fn api_error_body(message: &str, error_type: &str) -> serde_json::Value {
        json!({
            "error": {
                "message": message,
                "type": error_type,
                "code": null
            }
        })
    }

    /// Fine-tune object fixture for a `curie` job with no events.
    /// A `None` value for `fine_tuned_model` serializes as JSON `null`.
    fn fine_tune_body(
        id: &str,
        created_at: u64,
        updated_at: u64,
        fine_tuned_model: Option<&str>,
        status: &str,
    ) -> serde_json::Value {
        json!({
            "id": id,
            "object": "fine-tune",
            "created_at": created_at,
            "updated_at": updated_at,
            "model": "curie",
            "fine_tuned_model": fine_tuned_model,
            "status": status,
            "events": []
        })
    }

    #[tokio::test]
    async fn test_create_fine_tune_success() {
        let server = MockServer::start().await;
        let body = fine_tune_body("ft-abcdefgh", 1673645000, 1673645200, None, "pending");
        Mock::given(method("POST"))
            .and(path("/fine-tunes"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(&server)
            .await;
        let client = client_for(&server);

        let request = CreateFineTuneRequest {
            training_file: "file-abc123".into(),
            model: Some("curie".into()),
            ..Default::default()
        };
        let outcome = create_fine_tune(&client, &request).await;
        assert!(outcome.is_ok(), "Expected Ok, got: {:?}", outcome);

        let job = outcome.unwrap();
        assert_eq!(job.id, "ft-abcdefgh");
        assert_eq!(job.status, "pending");
        assert_eq!(job.model, "curie");
        assert!(job.fine_tuned_model.is_none());
        assert_eq!(job.events.len(), 0);
    }

    #[tokio::test]
    async fn test_create_fine_tune_api_error() {
        let server = MockServer::start().await;
        let body = api_error_body("Invalid training file", "invalid_request_error");
        Mock::given(method("POST"))
            .and(path("/fine-tunes"))
            .respond_with(ResponseTemplate::new(400).set_body_json(body))
            .mount(&server)
            .await;
        let client = client_for(&server);

        let request = CreateFineTuneRequest {
            training_file: "file-nonexistent".into(),
            ..Default::default()
        };
        match create_fine_tune(&client, &request).await {
            Err(OpenAIError::APIError { message, .. }) => {
                assert!(message.contains("Invalid training file"));
            }
            other => panic!("Expected APIError, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_list_fine_tunes_success() {
        let server = MockServer::start().await;
        let body = json!({
            "object": "list",
            "data": [fine_tune_body(
                "ft-abc123",
                1673645000,
                1673645200,
                Some("curie:ft-yourorg-2023-01-01-xxxx"),
                "succeeded",
            )]
        });
        Mock::given(method("GET"))
            .and(path("/fine-tunes"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(&server)
            .await;
        let client = client_for(&server);

        let outcome = list_fine_tunes(&client).await;
        assert!(outcome.is_ok(), "Expected Ok, got: {:?}", outcome);

        let listing = outcome.unwrap();
        assert_eq!(listing.object, "list");
        assert_eq!(listing.data.len(), 1);
        let job = &listing.data[0];
        assert_eq!(job.id, "ft-abc123");
        assert_eq!(job.status, "succeeded");
    }

    #[tokio::test]
    async fn test_list_fine_tunes_api_error() {
        let server = MockServer::start().await;
        let body = api_error_body("Could not list fine-tunes", "internal_server_error");
        Mock::given(method("GET"))
            .and(path("/fine-tunes"))
            .respond_with(ResponseTemplate::new(500).set_body_json(body))
            .mount(&server)
            .await;
        let client = client_for(&server);

        match list_fine_tunes(&client).await {
            Err(OpenAIError::APIError { message, .. }) => {
                assert!(message.contains("Could not list fine-tunes"));
            }
            other => panic!("Expected APIError, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_retrieve_fine_tune_success() {
        let server = MockServer::start().await;
        let body = fine_tune_body("ft-xyz789", 1673646000, 1673646200, None, "running");
        Mock::given(method("GET"))
            .and(path_regex(r"^/fine-tunes/ft-xyz789$"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(&server)
            .await;
        let client = client_for(&server);

        let outcome = retrieve_fine_tune(&client, "ft-xyz789").await;
        assert!(outcome.is_ok(), "Expected Ok, got: {:?}", outcome);

        let job = outcome.unwrap();
        assert_eq!(job.id, "ft-xyz789");
        assert_eq!(job.status, "running");
    }

    #[tokio::test]
    async fn test_retrieve_fine_tune_api_error() {
        let server = MockServer::start().await;
        let body = api_error_body("Fine-tune not found", "invalid_request_error");
        Mock::given(method("GET"))
            .and(path_regex(r"^/fine-tunes/ft-000$"))
            .respond_with(ResponseTemplate::new(404).set_body_json(body))
            .mount(&server)
            .await;
        let client = client_for(&server);

        match retrieve_fine_tune(&client, "ft-000").await {
            Err(OpenAIError::APIError { message, .. }) => {
                assert!(message.contains("Fine-tune not found"));
            }
            other => panic!("Expected APIError, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_cancel_fine_tune_success() {
        let server = MockServer::start().await;
        let body = fine_tune_body("ft-abc123", 1673647000, 1673647200, None, "cancelled");
        Mock::given(method("POST"))
            .and(path_regex(r"^/fine-tunes/ft-abc123/cancel$"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(&server)
            .await;
        let client = client_for(&server);

        let outcome = cancel_fine_tune(&client, "ft-abc123").await;
        assert!(outcome.is_ok(), "Expected Ok, got {:?}", outcome);

        let job = outcome.unwrap();
        assert_eq!(job.id, "ft-abc123");
        assert_eq!(job.status, "cancelled");
    }

    #[tokio::test]
    async fn test_cancel_fine_tune_api_error() {
        let server = MockServer::start().await;
        let body = api_error_body("Cannot cancel a completed fine-tune", "invalid_request_error");
        Mock::given(method("POST"))
            .and(path_regex(r"^/fine-tunes/ft-zzz/cancel$"))
            .respond_with(ResponseTemplate::new(400).set_body_json(body))
            .mount(&server)
            .await;
        let client = client_for(&server);

        match cancel_fine_tune(&client, "ft-zzz").await {
            Err(OpenAIError::APIError { message, .. }) => {
                assert!(message.contains("Cannot cancel a completed fine-tune"));
            }
            other => panic!("Expected APIError, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_list_fine_tune_events_success() {
        let server = MockServer::start().await;
        let body = json!({
            "object": "list",
            "data": [
                {
                    "object": "fine-tune-event",
                    "created_at": 1673648000,
                    "level": "info",
                    "message": "Job enqueued"
                },
                {
                    "object": "fine-tune-event",
                    "created_at": 1673648100,
                    "level": "info",
                    "message": "Job started"
                }
            ]
        });
        Mock::given(method("GET"))
            .and(path_regex(r"^/fine-tunes/ft-abc/events$"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(&server)
            .await;
        let client = client_for(&server);

        let outcome = list_fine_tune_events(&client, "ft-abc").await;
        assert!(outcome.is_ok(), "Expected Ok, got {:?}", outcome);

        let events = outcome.unwrap();
        assert_eq!(events.object, "list");
        assert_eq!(events.data.len(), 2);
        assert_eq!(events.data[0].message, "Job enqueued");
    }

    #[tokio::test]
    async fn test_list_fine_tune_events_api_error() {
        let server = MockServer::start().await;
        let body = api_error_body("No events found", "invalid_request_error");
        Mock::given(method("GET"))
            .and(path_regex(r"^/fine-tunes/ft-xyz/events$"))
            .respond_with(ResponseTemplate::new(404).set_body_json(body))
            .mount(&server)
            .await;
        let client = client_for(&server);

        match list_fine_tune_events(&client, "ft-xyz").await {
            Err(OpenAIError::APIError { message, .. }) => {
                assert!(message.contains("No events found"));
            }
            other => panic!("Expected APIError, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_delete_fine_tune_model_success() {
        let server = MockServer::start().await;
        let body = json!({
            "object": "model",
            "id": "curie:ft-yourorg-2023-01-01-xxxx",
            "deleted": true
        });
        Mock::given(method("DELETE"))
            .and(path_regex(r"^/models/curie:ft-yourorg-2023-01-01-xxxx$"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(&server)
            .await;
        let client = client_for(&server);

        let outcome = delete_fine_tune_model(&client, "curie:ft-yourorg-2023-01-01-xxxx").await;
        assert!(outcome.is_ok(), "Expected Ok, got {:?}", outcome);

        let deletion = outcome.unwrap();
        assert_eq!(deletion.object, "model");
        assert_eq!(deletion.id, "curie:ft-yourorg-2023-01-01-xxxx");
        assert!(deletion.deleted);
    }

    #[tokio::test]
    async fn test_delete_fine_tune_model_api_error() {
        let server = MockServer::start().await;
        let body = api_error_body("Model not found", "invalid_request_error");
        Mock::given(method("DELETE"))
            .and(path_regex(r"^/models/doesnotexist$"))
            .respond_with(ResponseTemplate::new(404).set_body_json(body))
            .mount(&server)
            .await;
        let client = client_for(&server);

        match delete_fine_tune_model(&client, "doesnotexist").await {
            Err(OpenAIError::APIError { message, .. }) => {
                assert!(message.contains("Model not found"));
            }
            other => panic!("Expected APIError, got {:?}", other),
        }
    }
}