Skip to main content

openai_core/resources/
fine_tuning.rs

1//! Fine-tuning namespace implementations.
2
3use std::collections::BTreeMap;
4
5use http::Method;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8
9use crate::generated::endpoints;
10use crate::json_payload::JsonPayload;
11
12use super::{
13    DeleteResponse, FineTuningAlphaGradersResource, FineTuningAlphaResource, FineTuningCheckpoint,
14    FineTuningCheckpointPermission, FineTuningCheckpointPermissionsResource, FineTuningJob,
15    FineTuningJobCheckpointsResource, FineTuningJobCreateRequestBuilder, FineTuningJobEvent,
16    FineTuningJobsResource, FineTuningResource, GradersResource, JsonRequestBuilder,
17    ListRequestBuilder, encode_path_segment,
18};
19
20/// 表示 grader 执行结果。
21#[derive(Debug, Clone, Serialize, Deserialize, Default)]
22pub struct GraderRunResponse {
23    /// grader 元数据。
24    pub metadata: Option<GraderRunMetadata>,
25    /// 按模型拆分的 token 使用情况。
26    #[serde(default)]
27    pub model_grader_token_usage_per_model: BTreeMap<String, Value>,
28    /// 总 reward。
29    pub reward: Option<f64>,
30    /// 子 reward。
31    #[serde(default)]
32    pub sub_rewards: BTreeMap<String, Value>,
33    /// 额外字段。
34    #[serde(flatten)]
35    pub extra: BTreeMap<String, Value>,
36}
37
38/// 表示 grader 运行元数据。
39#[derive(Debug, Clone, Serialize, Deserialize, Default)]
40pub struct GraderRunMetadata {
41    /// 错误位图。
42    pub errors: Option<GraderRunErrors>,
43    /// 执行时间。
44    pub execution_time: Option<f64>,
45    /// grader 名称。
46    pub name: Option<String>,
47    /// 被评估模型名称。
48    pub sampled_model_name: Option<String>,
49    /// 分数字段。
50    #[serde(default)]
51    pub scores: BTreeMap<String, Value>,
52    /// token 使用量。
53    pub token_usage: Option<u64>,
54    /// grader 类型。
55    #[serde(rename = "type")]
56    pub grader_type: Option<String>,
57    /// 额外字段。
58    #[serde(flatten)]
59    pub extra: BTreeMap<String, Value>,
60}
61
62/// 表示 grader 执行错误标记。
63#[derive(Debug, Clone, Serialize, Deserialize, Default)]
64pub struct GraderRunErrors {
65    /// 公式解析错误。
66    #[serde(default)]
67    pub formula_parse_error: bool,
68    /// 非法变量错误。
69    #[serde(default)]
70    pub invalid_variable_error: bool,
71    /// 模型 grader 解析错误。
72    #[serde(default)]
73    pub model_grader_parse_error: bool,
74    /// 模型 grader 拒绝错误。
75    #[serde(default)]
76    pub model_grader_refusal_error: bool,
77    /// 模型 grader 服务端错误。
78    #[serde(default)]
79    pub model_grader_server_error: bool,
80    /// 模型 grader 服务端错误细节。
81    pub model_grader_server_error_details: Option<String>,
82    /// 其他错误。
83    #[serde(default)]
84    pub other_error: bool,
85    /// Python grader 运行时错误。
86    #[serde(default)]
87    pub python_grader_runtime_error: bool,
88    /// Python grader 运行时错误细节。
89    pub python_grader_runtime_error_details: Option<String>,
90    /// Python grader 服务端错误。
91    #[serde(default)]
92    pub python_grader_server_error: bool,
93    /// Python grader 服务端错误类型。
94    pub python_grader_server_error_type: Option<String>,
95    /// 样本解析错误。
96    #[serde(default)]
97    pub sample_parse_error: bool,
98    /// 截断观测错误。
99    #[serde(default)]
100    pub truncated_observation_error: bool,
101    /// 无响应 reward 错误。
102    #[serde(default)]
103    pub unresponsive_reward_error: bool,
104    /// 额外字段。
105    #[serde(flatten)]
106    pub extra: BTreeMap<String, Value>,
107}
108
109/// 表示 grader 校验结果。
110#[derive(Debug, Clone, Serialize, Deserialize, Default)]
111pub struct GraderValidateResponse {
112    /// 返回的 grader 定义。
113    pub grader: Option<JsonPayload>,
114    /// 额外字段。
115    #[serde(flatten)]
116    pub extra: BTreeMap<String, Value>,
117}
118
119/// 表示 grader model 列表。
120#[derive(Debug, Clone, Serialize, Deserialize, Default)]
121pub struct GraderModelCatalog {
122    /// 列表对象类型。
123    pub object: Option<String>,
124    /// grader models。
125    #[serde(default)]
126    pub data: Vec<GraderModel>,
127    /// 额外字段。
128    #[serde(flatten)]
129    pub extra: BTreeMap<String, Value>,
130}
131
132/// 表示单个 grader model。
133#[derive(Debug, Clone, Serialize, Deserialize, Default)]
134pub struct GraderModel {
135    /// grader model ID。
136    pub id: Option<String>,
137    /// grader 名称。
138    pub name: Option<String>,
139    /// grader 类型。
140    #[serde(rename = "type")]
141    pub grader_type: Option<String>,
142    /// 额外字段。
143    #[serde(flatten)]
144    pub extra: BTreeMap<String, Value>,
145}
146
147impl FineTuningResource {
148    /// 返回 jobs 子资源。
149    pub fn jobs(&self) -> FineTuningJobsResource {
150        FineTuningJobsResource::new(self.client.clone())
151    }
152
153    /// 返回 checkpoints permissions 子资源。
154    pub fn checkpoints(&self) -> FineTuningCheckpointPermissionsResource {
155        FineTuningCheckpointPermissionsResource::new(self.client.clone())
156    }
157
158    /// 返回 alpha 子资源。
159    pub fn alpha(&self) -> FineTuningAlphaResource {
160        FineTuningAlphaResource::new(self.client.clone())
161    }
162}
163
164impl FineTuningJobsResource {
165    /// 创建 fine-tuning job。
166    pub fn create(&self) -> FineTuningJobCreateRequestBuilder {
167        FineTuningJobCreateRequestBuilder::new(self.client.clone())
168    }
169
170    /// 获取 fine-tuning job。
171    pub fn retrieve(&self, job_id: impl Into<String>) -> JsonRequestBuilder<FineTuningJob> {
172        JsonRequestBuilder::new(
173            self.client.clone(),
174            "fine_tuning.jobs.retrieve",
175            Method::GET,
176            format!("/fine_tuning/jobs/{}", encode_path_segment(job_id.into())),
177        )
178    }
179
180    /// 列出 fine-tuning jobs。
181    pub fn list(&self) -> ListRequestBuilder<FineTuningJob> {
182        ListRequestBuilder::new(
183            self.client.clone(),
184            "fine_tuning.jobs.list",
185            "/fine_tuning/jobs",
186        )
187    }
188
189    /// 取消 fine-tuning job。
190    pub fn cancel(&self, job_id: impl Into<String>) -> JsonRequestBuilder<FineTuningJob> {
191        JsonRequestBuilder::new(
192            self.client.clone(),
193            "fine_tuning.jobs.cancel",
194            Method::POST,
195            format!(
196                "/fine_tuning/jobs/{}/cancel",
197                encode_path_segment(job_id.into())
198            ),
199        )
200    }
201
202    /// 暂停 fine-tuning job。
203    pub fn pause(&self, job_id: impl Into<String>) -> JsonRequestBuilder<FineTuningJob> {
204        JsonRequestBuilder::new(
205            self.client.clone(),
206            "fine_tuning.jobs.pause",
207            Method::POST,
208            format!(
209                "/fine_tuning/jobs/{}/pause",
210                encode_path_segment(job_id.into())
211            ),
212        )
213    }
214
215    /// 恢复 fine-tuning job。
216    pub fn resume(&self, job_id: impl Into<String>) -> JsonRequestBuilder<FineTuningJob> {
217        JsonRequestBuilder::new(
218            self.client.clone(),
219            "fine_tuning.jobs.resume",
220            Method::POST,
221            format!(
222                "/fine_tuning/jobs/{}/resume",
223                encode_path_segment(job_id.into())
224            ),
225        )
226    }
227
228    /// 列出事件。
229    pub fn list_events(&self, job_id: impl Into<String>) -> ListRequestBuilder<FineTuningJobEvent> {
230        ListRequestBuilder::new(
231            self.client.clone(),
232            "fine_tuning.jobs.list_events",
233            format!(
234                "/fine_tuning/jobs/{}/events",
235                encode_path_segment(job_id.into())
236            ),
237        )
238    }
239
240    /// 返回 checkpoints 子资源。
241    pub fn checkpoints(&self) -> FineTuningJobCheckpointsResource {
242        FineTuningJobCheckpointsResource::new(self.client.clone())
243    }
244}
245
246impl FineTuningJobCheckpointsResource {
247    /// 列出某个 job 的 checkpoints。
248    pub fn list(&self, job_id: impl Into<String>) -> ListRequestBuilder<FineTuningCheckpoint> {
249        ListRequestBuilder::new(
250            self.client.clone(),
251            "fine_tuning.jobs.checkpoints.list",
252            format!(
253                "/fine_tuning/jobs/{}/checkpoints",
254                encode_path_segment(job_id.into())
255            ),
256        )
257    }
258}
259
260impl FineTuningCheckpointPermissionsResource {
261    /// 创建 checkpoint permission。
262    pub fn create(
263        &self,
264        checkpoint_id: impl Into<String>,
265    ) -> JsonRequestBuilder<FineTuningCheckpointPermission> {
266        JsonRequestBuilder::new(
267            self.client.clone(),
268            "fine_tuning.checkpoints.permissions.create",
269            Method::POST,
270            format!(
271                "/fine_tuning/checkpoints/{}/permissions",
272                encode_path_segment(checkpoint_id.into())
273            ),
274        )
275    }
276
277    /// 获取 checkpoint permission。
278    pub fn retrieve(
279        &self,
280        checkpoint_id: impl Into<String>,
281        permission_id: impl Into<String>,
282    ) -> JsonRequestBuilder<FineTuningCheckpointPermission> {
283        JsonRequestBuilder::new(
284            self.client.clone(),
285            "fine_tuning.checkpoints.permissions.retrieve",
286            Method::GET,
287            format!(
288                "/fine_tuning/checkpoints/{}/permissions/{}",
289                encode_path_segment(checkpoint_id.into()),
290                encode_path_segment(permission_id.into())
291            ),
292        )
293    }
294
295    /// 列出 checkpoint permission。
296    pub fn list(
297        &self,
298        checkpoint_id: impl Into<String>,
299    ) -> ListRequestBuilder<FineTuningCheckpointPermission> {
300        ListRequestBuilder::new(
301            self.client.clone(),
302            "fine_tuning.checkpoints.permissions.list",
303            format!(
304                "/fine_tuning/checkpoints/{}/permissions",
305                encode_path_segment(checkpoint_id.into())
306            ),
307        )
308    }
309
310    /// 删除 checkpoint permission。
311    pub fn delete(
312        &self,
313        checkpoint_id: impl Into<String>,
314        permission_id: impl Into<String>,
315    ) -> JsonRequestBuilder<DeleteResponse> {
316        JsonRequestBuilder::new(
317            self.client.clone(),
318            "fine_tuning.checkpoints.permissions.delete",
319            Method::DELETE,
320            format!(
321                "/fine_tuning/checkpoints/{}/permissions/{}",
322                encode_path_segment(checkpoint_id.into()),
323                encode_path_segment(permission_id.into())
324            ),
325        )
326    }
327}
328
329impl FineTuningAlphaResource {
330    /// 返回 graders 子资源。
331    pub fn graders(&self) -> FineTuningAlphaGradersResource {
332        FineTuningAlphaGradersResource::new(self.client.clone())
333    }
334}
335
336impl FineTuningAlphaGradersResource {
337    /// 执行 grader。
338    pub fn run(&self) -> JsonRequestBuilder<GraderRunResponse> {
339        let endpoint = endpoints::fine_tuning::FINE_TUNING_ALPHA_GRADERS_RUN;
340        JsonRequestBuilder::new(
341            self.client.clone(),
342            endpoint.id,
343            Method::POST,
344            endpoint.template,
345        )
346    }
347
348    /// 校验 grader。
349    pub fn validate(&self) -> JsonRequestBuilder<GraderValidateResponse> {
350        let endpoint = endpoints::fine_tuning::FINE_TUNING_ALPHA_GRADERS_VALIDATE;
351        JsonRequestBuilder::new(
352            self.client.clone(),
353            endpoint.id,
354            Method::POST,
355            endpoint.template,
356        )
357    }
358}
359
360impl GradersResource {
361    /// 当前资源主要导出类型,暂不提供额外 HTTP 方法。
362    pub fn grader_models(&self) -> JsonRequestBuilder<GraderModelCatalog> {
363        let endpoint = endpoints::fine_tuning::GRADERS_GRADER_MODELS;
364        JsonRequestBuilder::new(
365            self.client.clone(),
366            endpoint.id,
367            Method::GET,
368            endpoint.template,
369        )
370    }
371}