border_mlflow_tracking/
client.rs

1use crate::{system_time_as_millis, Experiment, MlflowTrackingRecorder, Run};
2use anyhow::Result;
3use border_core::{Env, ReplayBufferBase};
4use log::info;
5use reqwest::blocking::Client;
6use serde::{Deserialize, Serialize};
7use std::error::Error;
8use std::fmt::Display;
9
10#[derive(Debug, Deserialize)]
11/// Internally used.
12struct Experiment_ {
13    pub(crate) experiment: Experiment,
14}
15
16#[derive(Clone, Debug, Deserialize)]
17/// Internally used.
18struct Run_ {
19    run: Run,
20}
21
/// Error indicating that an experiment ID could not be obtained.
///
/// It is the error type in the signature of
/// [`MlflowTrackingClient::set_experiment`].
#[derive(Debug, Clone)]
pub struct GetExperimentIdError;

impl Error for GetExperimentIdError {}

impl Display for GetExperimentIdError {
    /// Writes the fixed message `Failed to get experiment ID`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Failed to get experiment ID")
    }
}
32
33#[derive(Debug, Serialize)]
34/// Request body of [Create Run](https://mlflow.org/docs/2.11.3/rest-api.html#id74).
35struct CreateRunParams {
36    experiment_id: String,
37    start_time: i64,
38    run_name: String,
39}
40
41#[derive(Debug, Serialize)]
42/// Request body of [Create Experiment](https://mlflow.org/docs/2.11.3/rest-api.html#id67).
43struct CreateExperimentParams {
44    name: String,
45}
46
47#[derive(Debug, Serialize)]
48/// Request body of [Search Runs](https://mlflow.org/docs/2.11.3/rest-api.html#id87).
49struct SearchRunsParams {
50    experiment_ids: Vec<String>,
51    filter: String,
52}
53
54#[derive(Clone, Debug, Deserialize)]
55pub struct SearchRunsResponse {
56    runs: Option<Vec<Run>>,
57    #[allow(dead_code)]
58    next_page_token: Option<String>,
59}
60
61#[derive(Debug)]
62/// Provides access to a MLflow tracking server via REST API.
63///
64/// Support Mlflow API version 2.0.
65pub struct MlflowTrackingClient {
66    client: Client,
67
68    /// Base URL.
69    base_url: String,
70
71    /// Current experiment ID.
72    experiment_id: Option<String>,
73
74    /// User name of the tracking server.
75    user_name: String,
76
77    /// Password.
78    password: String,
79}
80
81impl MlflowTrackingClient {
82    pub fn new(base_url: impl AsRef<str>) -> Self {
83        Self {
84            client: Client::new(),
85            base_url: base_url.as_ref().to_string(),
86            experiment_id: None,
87            user_name: "".to_string(),
88            password: "".to_string(),
89        }
90    }
91
92    /// Sets user name and password for basic authentication of the tracking server.
93    pub fn basic_auth(self, user_name: impl AsRef<str>, password: impl AsRef<str>) -> Self {
94        Self {
95            client: self.client,
96            base_url: self.base_url,
97            experiment_id: self.experiment_id,
98            user_name: user_name.as_ref().to_string(),
99            password: password.as_ref().to_string(),
100        }
101    }
102
103    /// Sets an experiment with the given name to this struct.
104    pub fn set_experiment(self, name: impl AsRef<str>) -> Result<Self, GetExperimentIdError> {
105        let experiment_id = {
106            self.get_experiment(name.as_ref())
107                .expect(format!("Failed to get experiment: {:?}", name.as_ref()).as_str())
108                .experiment_id
109        };
110
111        info!(
112            "For experiment '{}', id={} is set in MlflowTrackingClient",
113            name.as_ref(),
114            experiment_id
115        );
116
117        Ok(Self {
118            client: self.client,
119            base_url: self.base_url,
120            experiment_id: Some(experiment_id),
121            user_name: self.user_name,
122            password: self.password,
123        })
124    }
125
126    #[allow(rustdoc::private_intra_doc_links)]
127    /// Gets [`Experiment`] by name from the tracking server.
128    ///
129    /// If the experiment with given name does not exist in the trackingserver,
130    /// it will be created.
131    ///
132    /// TODO: Better error handling
133    pub fn get_experiment(&self, name: impl AsRef<str>) -> Option<Experiment> {
134        let resp = match self.get(
135            self.url("experiments/get-by-name"),
136            &[("experiment_name", name.as_ref())],
137        ) {
138            Ok(resp) => {
139                if resp.status().is_success() {
140                    resp
141                } else {
142                    // if the experiment does not exist, create it
143                    self.post(
144                        self.url("experiments/create"),
145                        &CreateExperimentParams {
146                            name: name.as_ref().into(),
147                        },
148                    )
149                    .unwrap();
150                    self.get(
151                        self.url("experiments/get-by-name"),
152                        &[("experiment_name", name.as_ref())],
153                    )
154                    .unwrap()
155                }
156            }
157            Err(_) => {
158                panic!();
159            }
160        };
161        let experiment: Experiment_ = serde_json::from_str(resp.text().unwrap().as_str()).unwrap();
162
163        Some(experiment.experiment)
164    }
165
166    fn url(&self, api: impl AsRef<str>) -> String {
167        format!("{}/api/2.0/mlflow/{}", self.base_url, api.as_ref())
168    }
169
170    fn get(
171        &self,
172        url: String,
173        query: &impl Serialize,
174    ) -> reqwest::Result<reqwest::blocking::Response> {
175        self.client
176            .get(url)
177            .basic_auth(&self.user_name, Some(&self.password))
178            .query(query)
179            .send()
180    }
181
182    fn post(
183        &self,
184        url: String,
185        params: &impl Serialize,
186    ) -> reqwest::Result<reqwest::blocking::Response> {
187        self.client
188            .post(url)
189            .basic_auth(&self.user_name, Some(&self.password))
190            .json(&params) // auto serialize
191            .send()
192    }
193
194    /// Create [`MlflowTrackingRecorder`] corresponding to a run.
195    ///
196    /// If `name` is empty (`""`), a run name is generated by the tracking server.
197    ///
198    /// If a Run with `name` exists in the tracking server, the run is used
199    /// to create the recorder. If two or more runs with `name` exists,
200    /// this method panics.
201    ///
202    /// You need to set an experiment using [`MlflowTrackingClient::set_experiment()`]
203    /// before calling this method.
204    ///
205    /// This method uses `MLFLOW_DEFAULT_ARTIFACT_ROOT` environment variable as the directory
206    /// where artifacts, like model parameters, will be saved. It is recommended to set this
207    /// environment variable `mlruns` directory to which the tracking server persists experiment
208    /// and run data.
209    pub fn create_recorder<E, R>(
210        &self,
211        run_name: impl AsRef<str>,
212    ) -> Result<MlflowTrackingRecorder<E, R>>
213    where
214        E: Env,
215        R: ReplayBufferBase,
216    {
217        let run_name = run_name.as_ref();
218        let run = {
219            let runs = self.get_runs_by_name(run_name)?;
220            if runs.len() >= 2 {
221                panic!("There are 2 or more runs with name '{}'", run_name);
222            } else if runs.len() == 1 {
223                runs[0].clone()
224            } else {
225                self.get_run_info(run_name)?
226            }
227        };
228        if run_name.len() == 0 {
229            info!(
230                "Run name '{}' has been automatically generated",
231                run.info.run_name
232            );
233        }
234
235        // Get the directory to which artifacts will be saved
236        let artifact_base = crate::get_artifact_base(run.clone())?;
237
238        // Return a recorder
239        let experiment_id = self.experiment_id.as_ref().expect("Needs experiment_id");
240        MlflowTrackingRecorder::new(&self.base_url, &experiment_id, run, artifact_base)
241    }
242
243    /// Get Run info.
244    fn get_run_info(&self, run_name: impl AsRef<str>) -> Result<Run> {
245        let experiment_id = self.experiment_id.as_ref().expect("Needs experiment_id");
246        let resp = self
247            .post(
248                self.url("runs/create"),
249                &CreateRunParams {
250                    experiment_id: experiment_id.to_string(),
251                    start_time: system_time_as_millis() as i64,
252                    run_name: run_name.as_ref().to_string(),
253                },
254            )
255            .unwrap();
256        // TODO: Check the response from the tracking server here
257
258        let run = {
259            let run: Run_ =
260                serde_json::from_str(&resp.text().unwrap()).expect("Failed to deserialize Run");
261            run.run
262        };
263        Ok(run)
264    }
265
266    /// Get runs by names.
267    ///
268    /// This method queries the tracking server and returns [`Run`]s.
269    pub fn get_runs_by_name(&self, name: impl AsRef<str>) -> Result<Vec<Run>> {
270        let experiment_id = self
271            .experiment_id
272            .clone()
273            .expect("Experiment id must be set before search runs");
274        let resp = self
275            .post(
276                self.url("runs/search"),
277                &SearchRunsParams {
278                    experiment_ids: vec![experiment_id],
279                    filter: format!("tags.mlflow.runName = '{}'", name.as_ref()),
280                },
281            )
282            .unwrap();
283
284        let resp: SearchRunsResponse =
285            serde_json::from_str(&resp.text().unwrap().as_str()).expect("Failed to deserialize");
286
287        Ok(resp.runs.unwrap_or(vec![]))
288    }
289}
290
291// // Used to test on vscode
292// #[test]
293// fn test() -> Result<()> {
294//     let client = MlflowTrackingClient::new("http://localhost:8080").set_experiment("Gym")?;
295//     // let runs = client.get_runs_by_name("dqn_cartpole_candle")?;
296//     let runs = client.get_runs_by_name("")?;
297//     println!("{:?}", runs.len());
298
299//     Ok(())
300// }