dynamo_async_openai/
fine_tuning.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
5// Original Copyright (c) 2022 Himanshu Neema
6// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
7//
8// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
9// Licensed under Apache 2.0
10
11use serde::Serialize;
12
13use crate::{
14    Client,
15    config::Config,
16    error::OpenAIError,
17    types::{
18        CreateFineTuningJobRequest, FineTuningJob, ListFineTuningJobCheckpointsResponse,
19        ListFineTuningJobEventsResponse, ListPaginatedFineTuningJobsResponse,
20    },
21};
22
23/// Manage fine-tuning jobs to tailor a model to your specific training data.
24///
25/// Related guide: [Fine-tune models](https://platform.openai.com/docs/guides/fine-tuning)
26pub struct FineTuning<'c, C: Config> {
27    client: &'c Client<C>,
28}
29
30impl<'c, C: Config> FineTuning<'c, C> {
31    pub fn new(client: &'c Client<C>) -> Self {
32        Self { client }
33    }
34
35    /// Creates a job that fine-tunes a specified model from a given dataset.
36    ///
37    /// Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete.
38    ///
39    /// [Learn more about Fine-tuning](https://platform.openai.com/docs/guides/fine-tuning)
40    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
41    pub async fn create(
42        &self,
43        request: CreateFineTuningJobRequest,
44    ) -> Result<FineTuningJob, OpenAIError> {
45        self.client.post("/fine_tuning/jobs", request).await
46    }
47
48    /// List your organization's fine-tuning jobs
49    #[crate::byot(T0 = serde::Serialize, T1 = serde::Serialize, R = serde::de::DeserializeOwned)]
50    pub async fn list_paginated<Q>(
51        &self,
52        query: &Q,
53    ) -> Result<ListPaginatedFineTuningJobsResponse, OpenAIError>
54    where
55        Q: Serialize + ?Sized,
56    {
57        self.client
58            .get_with_query("/fine_tuning/jobs", &query)
59            .await
60    }
61
62    /// Gets info about the fine-tune job.
63    ///
64    /// [Learn more about Fine-tuning](https://platform.openai.com/docs/guides/fine-tuning)
65    #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
66    pub async fn retrieve(&self, fine_tuning_job_id: &str) -> Result<FineTuningJob, OpenAIError> {
67        self.client
68            .get(format!("/fine_tuning/jobs/{fine_tuning_job_id}").as_str())
69            .await
70    }
71
72    /// Immediately cancel a fine-tune job.
73    #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
74    pub async fn cancel(&self, fine_tuning_job_id: &str) -> Result<FineTuningJob, OpenAIError> {
75        self.client
76            .post(
77                format!("/fine_tuning/jobs/{fine_tuning_job_id}/cancel").as_str(),
78                (),
79            )
80            .await
81    }
82
83    /// Get fine-grained status updates for a fine-tune job.
84    #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)]
85    pub async fn list_events<Q>(
86        &self,
87        fine_tuning_job_id: &str,
88        query: &Q,
89    ) -> Result<ListFineTuningJobEventsResponse, OpenAIError>
90    where
91        Q: Serialize + ?Sized,
92    {
93        self.client
94            .get_with_query(
95                format!("/fine_tuning/jobs/{fine_tuning_job_id}/events").as_str(),
96                &query,
97            )
98            .await
99    }
100
101    #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)]
102    pub async fn list_checkpoints<Q>(
103        &self,
104        fine_tuning_job_id: &str,
105        query: &Q,
106    ) -> Result<ListFineTuningJobCheckpointsResponse, OpenAIError>
107    where
108        Q: Serialize + ?Sized,
109    {
110        self.client
111            .get_with_query(
112                format!("/fine_tuning/jobs/{fine_tuning_job_id}/checkpoints").as_str(),
113                &query,
114            )
115            .await
116    }
117}