1use std::sync::Arc;
2
3use reqwest::{
4 blocking::{Client, Response},
5 Url,
6};
7use serde::{de::DeserializeOwned, Serialize};
8use serde_json::{json, Value};
9
10use crate::{
11 data::{
12 CreateExperimentOptions, CreateRunOptions, DatasetInput, Metric, Param, RunTag,
13 SearchExperimentsOptions, SearchRunsOptions, Timestamp, UpdateRunOptions,
14 },
15 Error, Result,
16};
17
18pub mod response;
19
20use response::*;
21
22#[derive(Debug, Clone)]
23pub struct MlflowClient {
24 uri: Arc<Url>,
25}
26
27impl MlflowClient {
28 pub fn new(uri: &str) -> Result<MlflowClient> {
29 Ok(MlflowClient {
30 uri: Arc::new(Url::parse(uri)?),
31 })
32 }
33 pub fn create_experiment(
35 &self,
36 name: &str,
37 options: CreateExperimentOptions,
38 ) -> Result<CreateExperimentResponse> {
39 let body = build_body(json!({ "name": name }), options)?;
40 self.post("experiments/create", body)
41 }
42
43 pub const SEARCH_EXPERIMENTS_MAX_RESULTS_SUPPORTED: i64 = 1000;
44
45 pub fn search_experiments(
47 &self,
48 options: SearchExperimentsOptions,
49 max_results: i64,
50 page_token: Option<&str>,
51 ) -> Result<SearchExperimentsResponse> {
52 let body = build_body(
53 json!({ "max_results": max_results, "page_token": page_token}),
54 options,
55 )?;
56 self.post("experiments/search", body)
57 }
58
59 pub fn get_experiment(&self, experiment_id: &str) -> Result<GetExperimentResponse> {
61 self.get("experiments/get", &[("experiment_id", experiment_id)])
62 }
63
64 pub fn get_experiment_by_name(&self, experiment_name: &str) -> Result<GetExperimentResponse> {
66 self.get(
67 "experiments/get-by-name",
68 &[("experiment_name", experiment_name)],
69 )
70 }
71
72 pub fn delete_experiment(&self, experiment_id: &str) -> Result<UnitResponse> {
74 self.post(
75 "experiments/delete",
76 json!({ "experiment_id": experiment_id }),
77 )
78 }
79
80 pub fn restore_experiment(&self, experiment_id: &str) -> Result<UnitResponse> {
82 self.post(
83 "experiments/restore",
84 json!({ "experiment_id": experiment_id }),
85 )
86 }
87
88 pub fn update_experiment(&self, experiment_id: &str, new_name: &str) -> Result<UnitResponse> {
90 self.post(
91 "experiments/update",
92 json!({
93 "experiment_id": experiment_id,
94 "new_name": new_name,
95 }),
96 )
97 }
98
99 pub fn create_run(
101 &self,
102 experiment_id: &str,
103 run_name: &str,
104 options: CreateRunOptions,
105 ) -> Result<GetRunResponse> {
106 let body = build_body(
107 json!({ "experiment_id": experiment_id, "run_name": run_name }),
108 options,
109 )?;
110 self.post("runs/create", body)
111 }
112
113 pub fn delete_run(&self, run_id: &str) -> Result<UnitResponse> {
115 self.post("runs/delete", json!({ "run_id": run_id }))
116 }
117
118 pub fn restore_run(&self, run_id: &str) -> Result<UnitResponse> {
120 self.post("runs/restore", json!({ "run_id": run_id }))
121 }
122
123 pub fn get_run(&self, run_id: &str) -> Result<GetRunResponse> {
125 self.get("runs/get", &[("run_id", run_id)])
126 }
127
128 pub fn log_metric(
130 &self,
131 run_id: &str,
132 key: &str,
133 value: f64,
134 timestamp: Timestamp,
135 step: Option<i64>,
136 ) -> Result<UnitResponse> {
137 self.post(
138 "runs/log-metric",
139 json!({
140 "run_id": run_id,
141 "key": key,
142 "value": value,
143 "timestamp": timestamp,
144 "step": step,
145 }),
146 )
147 }
148
149 pub const LOG_BATCH_MAX_TOTAL: usize = 1000;
150 pub const LOG_BATCH_MAX_METRICS: usize = 1000;
151 pub const LOG_BATCH_MAX_PARAMS: usize = 100;
152 pub const LOG_BATCH_MAX_TAGS: usize = 100;
153
154 pub fn log_batch(
156 &self,
157 run_id: &str,
158 metrics: &[Metric],
159 params: &[Param],
160 tags: &[RunTag],
161 ) -> Result<UnitResponse> {
162 self.post(
163 "runs/log-batch",
164 json!({
165 "run_id": run_id,
166 "metrics": metrics,
167 "params": params,
168 "tags": tags,
169 }),
170 )
171 }
172
173 pub fn log_inputs(&self, run_id: &str, datasets: &[DatasetInput]) -> Result<UnitResponse> {
186 self.post(
187 "runs/log-inputs",
188 json!({
189 "run_id": run_id,
190 "datasets": datasets,
191 }),
192 )
193 }
194
195 pub fn set_experiment_tag(
197 &self,
198 experiment_id: &str,
199 key: &str,
200 value: &str,
201 ) -> Result<UnitResponse> {
202 self.post(
203 "experiments/set-experiment-tag",
204 json!({
205 "experiment_id": experiment_id,
206 "key": key,
207 "value": value,
208 }),
209 )
210 }
211
212 pub fn set_tag(&self, run_id: &str, key: &str, value: &str) -> Result<UnitResponse> {
214 self.post(
215 "runs/set-tag",
216 json!({
217 "run_id": run_id,
218 "key": key,
219 "value": value,
220 }),
221 )
222 }
223
224 pub fn delete_tag(&self, run_id: &str, key: &str) -> Result<UnitResponse> {
226 self.post(
227 "runs/delete-tag",
228 json!({
229 "run_id": run_id,
230 "key": key,
231 }),
232 )
233 }
234
235 pub fn log_param(&self, run_id: &str, key: &str, value: &str) -> Result<UnitResponse> {
237 self.post(
238 "runs/log-parameter",
239 json!({
240 "run_id": run_id,
241 "key": key,
242 "value": value,
243 }),
244 )
245 }
246
247 pub fn get_metric_history(
249 &self,
250 run_id: &str,
251 metric_key: &str,
252 max_results: i32,
253 page_token: Option<&str>,
254 ) -> Result<GetMetricHistoryResponse> {
255 self.get(
256 "metrics/get-history",
257 &[
258 ("run_id", run_id),
259 ("metric_key", metric_key),
260 ("max_results", &max_results.to_string()),
261 ("page_token", page_token.unwrap_or("")),
262 ],
263 )
264 }
265
266 pub const SEARCH_RUNS_MAX_RESULTS_SUPPORTED: i32 = 50000;
267
268 pub fn search_runs(
270 &self,
271 experiment_ids: &[&str],
272 options: SearchRunsOptions,
273 max_results: i32,
274 page_token: Option<&str>,
275 ) -> Result<SearchRunsResponse> {
276 let body = build_body(
277 json!({ "experiment_ids": experiment_ids, "max_results" : max_results, "page_token": page_token }),
278 options,
279 )?;
280 self.post("runs/search", body)
281 }
282
283 pub fn update_run(&self, run_id: &str, options: UpdateRunOptions) -> Result<UpdateRunResponse> {
302 let body = build_body(json!({ "run_id":run_id }), options)?;
303 self.post("runs/update", body)
304 }
305
306 fn post<T: DeserializeOwned>(&self, path: &str, body: impl Serialize) -> Result<T> {
307 to_result(Client::new().post(self.url(path)?).json(&body).send()?)
308 }
309 fn get<T: DeserializeOwned>(&self, path: &str, query: &[(&str, &str)]) -> Result<T> {
310 to_result(Client::new().get(self.url(path)?).query(query).send()?)
311 }
312
313 fn url(&self, path: &str) -> Result<Url> {
314 Ok(self.uri.join("/api/2.0/mlflow/")?.join(path)?)
315 }
316}
317impl Default for MlflowClient {
318 fn default() -> Self {
319 MlflowClient::new("http://localhost:5000").unwrap()
320 }
321}
322
323fn to_result<T: DeserializeOwned>(r: Response) -> Result<T> {
324 if r.status().is_success() {
325 Ok(r.json()?)
326 } else {
327 let e: ErrorResponse = r.json()?;
328 Err(Error::ApiError {
329 error_code: e.error_code,
330 message: e.message,
331 })
332 }
333}
334fn build_body(json: Value, options: impl Serialize) -> Result<Value> {
335 let Value::Object(mut l) = json else {
336 panic!("l: expected object");
337 };
338 let r = serde_json::to_value(options)?;
339 let Value::Object(r) = r else {
340 panic!("r: expected object");
341 };
342 for (k, v) in r {
343 l.insert(k, v);
344 }
345 Ok(Value::Object(l))
346}