mlflow_client/
client.rs

1use std::sync::Arc;
2
3use reqwest::{
4    blocking::{Client, Response},
5    Url,
6};
7use serde::{de::DeserializeOwned, Serialize};
8use serde_json::{json, Value};
9
10use crate::{
11    data::{
12        CreateExperimentOptions, CreateRunOptions, DatasetInput, Metric, Param, RunTag,
13        SearchExperimentsOptions, SearchRunsOptions, Timestamp, UpdateRunOptions,
14    },
15    Error, Result,
16};
17
18pub mod response;
19
20use response::*;
21
22#[derive(Debug, Clone)]
23pub struct MlflowClient {
24    uri: Arc<Url>,
25}
26
27impl MlflowClient {
28    pub fn new(uri: &str) -> Result<MlflowClient> {
29        Ok(MlflowClient {
30            uri: Arc::new(Url::parse(uri)?),
31        })
32    }
33    /// <https://mlflow.org/docs/latest/rest-api.html#create-experiment>
34    pub fn create_experiment(
35        &self,
36        name: &str,
37        options: CreateExperimentOptions,
38    ) -> Result<CreateExperimentResponse> {
39        let body = build_body(json!({ "name": name }), options)?;
40        self.post("experiments/create", body)
41    }
42
43    pub const SEARCH_EXPERIMENTS_MAX_RESULTS_SUPPORTED: i64 = 1000;
44
45    /// <https://mlflow.org/docs/latest/rest-api.html#search-experiments>
46    pub fn search_experiments(
47        &self,
48        options: SearchExperimentsOptions,
49        max_results: i64,
50        page_token: Option<&str>,
51    ) -> Result<SearchExperimentsResponse> {
52        let body = build_body(
53            json!({ "max_results": max_results, "page_token": page_token}),
54            options,
55        )?;
56        self.post("experiments/search", body)
57    }
58
59    /// <https://mlflow.org/docs/latest/rest-api.html#get-experiment>
60    pub fn get_experiment(&self, experiment_id: &str) -> Result<GetExperimentResponse> {
61        self.get("experiments/get", &[("experiment_id", experiment_id)])
62    }
63
64    /// <https://mlflow.org/docs/latest/rest-api.html#get-experiment-by-name>
65    pub fn get_experiment_by_name(&self, experiment_name: &str) -> Result<GetExperimentResponse> {
66        self.get(
67            "experiments/get-by-name",
68            &[("experiment_name", experiment_name)],
69        )
70    }
71
72    /// <https://mlflow.org/docs/latest/rest-api.html#delete-experiment>
73    pub fn delete_experiment(&self, experiment_id: &str) -> Result<UnitResponse> {
74        self.post(
75            "experiments/delete",
76            json!({ "experiment_id": experiment_id }),
77        )
78    }
79
80    /// <https://mlflow.org/docs/latest/rest-api.html#restore-experiment>
81    pub fn restore_experiment(&self, experiment_id: &str) -> Result<UnitResponse> {
82        self.post(
83            "experiments/restore",
84            json!({ "experiment_id": experiment_id }),
85        )
86    }
87
88    /// <https://mlflow.org/docs/latest/rest-api.html#update-experiment>
89    pub fn update_experiment(&self, experiment_id: &str, new_name: &str) -> Result<UnitResponse> {
90        self.post(
91            "experiments/update",
92            json!({
93                "experiment_id": experiment_id,
94                "new_name": new_name,
95            }),
96        )
97    }
98
99    /// <https://mlflow.org/docs/latest/rest-api.html#create-run>
100    pub fn create_run(
101        &self,
102        experiment_id: &str,
103        run_name: &str,
104        options: CreateRunOptions,
105    ) -> Result<GetRunResponse> {
106        let body = build_body(
107            json!({ "experiment_id": experiment_id, "run_name": run_name }),
108            options,
109        )?;
110        self.post("runs/create", body)
111    }
112
113    /// <https://mlflow.org/docs/latest/rest-api.html#delete-run>
114    pub fn delete_run(&self, run_id: &str) -> Result<UnitResponse> {
115        self.post("runs/delete", json!({ "run_id": run_id }))
116    }
117
118    /// <https://mlflow.org/docs/latest/rest-api.html#restore-run>
119    pub fn restore_run(&self, run_id: &str) -> Result<UnitResponse> {
120        self.post("runs/restore", json!({ "run_id": run_id }))
121    }
122
123    /// <https://mlflow.org/docs/latest/rest-api.html#get-run>
124    pub fn get_run(&self, run_id: &str) -> Result<GetRunResponse> {
125        self.get("runs/get", &[("run_id", run_id)])
126    }
127
128    /// <https://mlflow.org/docs/latest/rest-api.html#log-metric>
129    pub fn log_metric(
130        &self,
131        run_id: &str,
132        key: &str,
133        value: f64,
134        timestamp: Timestamp,
135        step: Option<i64>,
136    ) -> Result<UnitResponse> {
137        self.post(
138            "runs/log-metric",
139            json!({
140                "run_id": run_id,
141                "key": key,
142                "value": value,
143                "timestamp": timestamp,
144                "step": step,
145            }),
146        )
147    }
148
149    pub const LOG_BATCH_MAX_TOTAL: usize = 1000;
150    pub const LOG_BATCH_MAX_METRICS: usize = 1000;
151    pub const LOG_BATCH_MAX_PARAMS: usize = 100;
152    pub const LOG_BATCH_MAX_TAGS: usize = 100;
153
154    /// <https://mlflow.org/docs/latest/rest-api.html#log-batch>
155    pub fn log_batch(
156        &self,
157        run_id: &str,
158        metrics: &[Metric],
159        params: &[Param],
160        tags: &[RunTag],
161    ) -> Result<UnitResponse> {
162        self.post(
163            "runs/log-batch",
164            json!({
165                "run_id": run_id,
166                "metrics": metrics,
167                "params": params,
168                "tags": tags,
169            }),
170        )
171    }
172
173    // /// <https://mlflow.org/docs/latest/rest-api.html#log-model>
174    // pub fn log_model(&self, run_id: &str, model_json: &str) -> Result<UnitResponse> {
175    //     self.post(
176    //         "runs/log-model",
177    //         json!({
178    //             "run_id": run_id,
179    //             "model_json": model_json,
180    //         }),
181    //     )
182    // }
183
184    /// <https://mlflow.org/docs/latest/rest-api.html#log-inputs>
185    pub fn log_inputs(&self, run_id: &str, datasets: &[DatasetInput]) -> Result<UnitResponse> {
186        self.post(
187            "runs/log-inputs",
188            json!({
189                "run_id": run_id,
190                "datasets": datasets,
191            }),
192        )
193    }
194
195    /// <https://mlflow.org/docs/latest/rest-api.html#set-experiment-tag>
196    pub fn set_experiment_tag(
197        &self,
198        experiment_id: &str,
199        key: &str,
200        value: &str,
201    ) -> Result<UnitResponse> {
202        self.post(
203            "experiments/set-experiment-tag",
204            json!({
205                "experiment_id": experiment_id,
206                "key": key,
207                "value": value,
208            }),
209        )
210    }
211
212    /// <https://mlflow.org/docs/latest/rest-api.html#set-tag>
213    pub fn set_tag(&self, run_id: &str, key: &str, value: &str) -> Result<UnitResponse> {
214        self.post(
215            "runs/set-tag",
216            json!({
217                "run_id": run_id,
218                "key": key,
219                "value": value,
220            }),
221        )
222    }
223
224    /// <https://mlflow.org/docs/latest/rest-api.html#delete-tag>
225    pub fn delete_tag(&self, run_id: &str, key: &str) -> Result<UnitResponse> {
226        self.post(
227            "runs/delete-tag",
228            json!({
229                "run_id": run_id,
230                "key": key,
231            }),
232        )
233    }
234
235    /// <https://mlflow.org/docs/latest/rest-api.html#log-param>
236    pub fn log_param(&self, run_id: &str, key: &str, value: &str) -> Result<UnitResponse> {
237        self.post(
238            "runs/log-parameter",
239            json!({
240                "run_id": run_id,
241                "key": key,
242                "value": value,
243            }),
244        )
245    }
246
247    /// <https://mlflow.org/docs/latest/rest-api.html#get-metric-history>
248    pub fn get_metric_history(
249        &self,
250        run_id: &str,
251        metric_key: &str,
252        max_results: i32,
253        page_token: Option<&str>,
254    ) -> Result<GetMetricHistoryResponse> {
255        self.get(
256            "metrics/get-history",
257            &[
258                ("run_id", run_id),
259                ("metric_key", metric_key),
260                ("max_results", &max_results.to_string()),
261                ("page_token", page_token.unwrap_or("")),
262            ],
263        )
264    }
265
266    pub const SEARCH_RUNS_MAX_RESULTS_SUPPORTED: i32 = 50000;
267
268    /// <https://mlflow.org/docs/latest/rest-api.html#search-runs>
269    pub fn search_runs(
270        &self,
271        experiment_ids: &[&str],
272        options: SearchRunsOptions,
273        max_results: i32,
274        page_token: Option<&str>,
275    ) -> Result<SearchRunsResponse> {
276        let body = build_body(
277            json!({ "experiment_ids": experiment_ids, "max_results" : max_results, "page_token": page_token }),
278            options,
279        )?;
280        self.post("runs/search", body)
281    }
282
283    // /// <https://mlflow.org/docs/latest/rest-api.html#list-artifacts>
284    // pub fn list_artifacts(
285    //     &self,
286    //     run_id: &str,
287    //     path: &str,
288    //     page_token: Option<&str>,
289    // ) -> Result<ListArtifactsResponse> {
290    //     self.get(
291    //         "artifacts/list",
292    //         &[
293    //             ("run_id", run_id),
294    //             ("path", path),
295    //             ("page_token", page_token.unwrap_or("")),
296    //         ],
297    //     )
298    // }
299
300    /// <https://mlflow.org/docs/latest/rest-api.html#update-run>
301    pub fn update_run(&self, run_id: &str, options: UpdateRunOptions) -> Result<UpdateRunResponse> {
302        let body = build_body(json!({ "run_id":run_id }), options)?;
303        self.post("runs/update", body)
304    }
305
306    fn post<T: DeserializeOwned>(&self, path: &str, body: impl Serialize) -> Result<T> {
307        to_result(Client::new().post(self.url(path)?).json(&body).send()?)
308    }
309    fn get<T: DeserializeOwned>(&self, path: &str, query: &[(&str, &str)]) -> Result<T> {
310        to_result(Client::new().get(self.url(path)?).query(query).send()?)
311    }
312
313    fn url(&self, path: &str) -> Result<Url> {
314        Ok(self.uri.join("/api/2.0/mlflow/")?.join(path)?)
315    }
316}
317impl Default for MlflowClient {
318    fn default() -> Self {
319        MlflowClient::new("http://localhost:5000").unwrap()
320    }
321}
322
323fn to_result<T: DeserializeOwned>(r: Response) -> Result<T> {
324    if r.status().is_success() {
325        Ok(r.json()?)
326    } else {
327        let e: ErrorResponse = r.json()?;
328        Err(Error::ApiError {
329            error_code: e.error_code,
330            message: e.message,
331        })
332    }
333}
334fn build_body(json: Value, options: impl Serialize) -> Result<Value> {
335    let Value::Object(mut l) = json else {
336        panic!("l: expected object");
337    };
338    let r = serde_json::to_value(options)?;
339    let Value::Object(r) = r else {
340        panic!("r: expected object");
341    };
342    for (k, v) in r {
343        l.insert(k, v);
344    }
345    Ok(Value::Object(l))
346}