use crate::{system_time_as_millis, Experiment, MlflowTrackingRecorder, Run};
use anyhow::{Context, Result};
use border_core::{Env, ReplayBufferBase};
use log::info;
use reqwest::blocking::Client;
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fmt::Display;
9
10#[derive(Debug, Deserialize)]
11struct Experiment_ {
13 pub(crate) experiment: Experiment,
14}
15
16#[derive(Clone, Debug, Deserialize)]
17struct Run_ {
19 run: Run,
20}
21
/// Error reported when the ID of an MLflow experiment cannot be obtained.
///
/// This is the error type of [`MlflowTrackingClient::set_experiment`].
#[derive(Clone, Debug)]
pub struct GetExperimentIdError;

impl Display for GetExperimentIdError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fixed message: the error carries no payload.
        f.write_str("Failed to get experiment ID")
    }
}

/// Marker impl so the error composes with `?` and `Box<dyn Error>`; no `source`.
impl Error for GetExperimentIdError {}
32
33#[derive(Debug, Serialize)]
34struct CreateRunParams {
36 experiment_id: String,
37 start_time: i64,
38 run_name: String,
39}
40
41#[derive(Debug, Serialize)]
42struct CreateExperimentParams {
44 name: String,
45}
46
47#[derive(Debug, Serialize)]
48struct SearchRunsParams {
50 experiment_ids: Vec<String>,
51 filter: String,
52}
53
54#[derive(Clone, Debug, Deserialize)]
55pub struct SearchRunsResponse {
56 runs: Option<Vec<Run>>,
57 #[allow(dead_code)]
58 next_page_token: Option<String>,
59}
60
61#[derive(Debug)]
62pub struct MlflowTrackingClient {
66 client: Client,
67
68 base_url: String,
70
71 experiment_id: Option<String>,
73
74 user_name: String,
76
77 password: String,
79}
80
81impl MlflowTrackingClient {
82 pub fn new(base_url: impl AsRef<str>) -> Self {
83 Self {
84 client: Client::new(),
85 base_url: base_url.as_ref().to_string(),
86 experiment_id: None,
87 user_name: "".to_string(),
88 password: "".to_string(),
89 }
90 }
91
92 pub fn basic_auth(self, user_name: impl AsRef<str>, password: impl AsRef<str>) -> Self {
94 Self {
95 client: self.client,
96 base_url: self.base_url,
97 experiment_id: self.experiment_id,
98 user_name: user_name.as_ref().to_string(),
99 password: password.as_ref().to_string(),
100 }
101 }
102
103 pub fn set_experiment(self, name: impl AsRef<str>) -> Result<Self, GetExperimentIdError> {
105 let experiment_id = {
106 self.get_experiment(name.as_ref())
107 .expect(format!("Failed to get experiment: {:?}", name.as_ref()).as_str())
108 .experiment_id
109 };
110
111 info!(
112 "For experiment '{}', id={} is set in MlflowTrackingClient",
113 name.as_ref(),
114 experiment_id
115 );
116
117 Ok(Self {
118 client: self.client,
119 base_url: self.base_url,
120 experiment_id: Some(experiment_id),
121 user_name: self.user_name,
122 password: self.password,
123 })
124 }
125
126 #[allow(rustdoc::private_intra_doc_links)]
127 pub fn get_experiment(&self, name: impl AsRef<str>) -> Option<Experiment> {
134 let resp = match self.get(
135 self.url("experiments/get-by-name"),
136 &[("experiment_name", name.as_ref())],
137 ) {
138 Ok(resp) => {
139 if resp.status().is_success() {
140 resp
141 } else {
142 self.post(
144 self.url("experiments/create"),
145 &CreateExperimentParams {
146 name: name.as_ref().into(),
147 },
148 )
149 .unwrap();
150 self.get(
151 self.url("experiments/get-by-name"),
152 &[("experiment_name", name.as_ref())],
153 )
154 .unwrap()
155 }
156 }
157 Err(_) => {
158 panic!();
159 }
160 };
161 let experiment: Experiment_ = serde_json::from_str(resp.text().unwrap().as_str()).unwrap();
162
163 Some(experiment.experiment)
164 }
165
166 fn url(&self, api: impl AsRef<str>) -> String {
167 format!("{}/api/2.0/mlflow/{}", self.base_url, api.as_ref())
168 }
169
170 fn get(
171 &self,
172 url: String,
173 query: &impl Serialize,
174 ) -> reqwest::Result<reqwest::blocking::Response> {
175 self.client
176 .get(url)
177 .basic_auth(&self.user_name, Some(&self.password))
178 .query(query)
179 .send()
180 }
181
182 fn post(
183 &self,
184 url: String,
185 params: &impl Serialize,
186 ) -> reqwest::Result<reqwest::blocking::Response> {
187 self.client
188 .post(url)
189 .basic_auth(&self.user_name, Some(&self.password))
190 .json(¶ms) .send()
192 }
193
194 pub fn create_recorder<E, R>(
210 &self,
211 run_name: impl AsRef<str>,
212 ) -> Result<MlflowTrackingRecorder<E, R>>
213 where
214 E: Env,
215 R: ReplayBufferBase,
216 {
217 let run_name = run_name.as_ref();
218 let run = {
219 let runs = self.get_runs_by_name(run_name)?;
220 if runs.len() >= 2 {
221 panic!("There are 2 or more runs with name '{}'", run_name);
222 } else if runs.len() == 1 {
223 runs[0].clone()
224 } else {
225 self.get_run_info(run_name)?
226 }
227 };
228 if run_name.len() == 0 {
229 info!(
230 "Run name '{}' has been automatically generated",
231 run.info.run_name
232 );
233 }
234
235 let artifact_base = crate::get_artifact_base(run.clone())?;
237
238 let experiment_id = self.experiment_id.as_ref().expect("Needs experiment_id");
240 MlflowTrackingRecorder::new(&self.base_url, &experiment_id, run, artifact_base)
241 }
242
243 fn get_run_info(&self, run_name: impl AsRef<str>) -> Result<Run> {
245 let experiment_id = self.experiment_id.as_ref().expect("Needs experiment_id");
246 let resp = self
247 .post(
248 self.url("runs/create"),
249 &CreateRunParams {
250 experiment_id: experiment_id.to_string(),
251 start_time: system_time_as_millis() as i64,
252 run_name: run_name.as_ref().to_string(),
253 },
254 )
255 .unwrap();
256 let run = {
259 let run: Run_ =
260 serde_json::from_str(&resp.text().unwrap()).expect("Failed to deserialize Run");
261 run.run
262 };
263 Ok(run)
264 }
265
266 pub fn get_runs_by_name(&self, name: impl AsRef<str>) -> Result<Vec<Run>> {
270 let experiment_id = self
271 .experiment_id
272 .clone()
273 .expect("Experiment id must be set before search runs");
274 let resp = self
275 .post(
276 self.url("runs/search"),
277 &SearchRunsParams {
278 experiment_ids: vec![experiment_id],
279 filter: format!("tags.mlflow.runName = '{}'", name.as_ref()),
280 },
281 )
282 .unwrap();
283
284 let resp: SearchRunsResponse =
285 serde_json::from_str(&resp.text().unwrap().as_str()).expect("Failed to deserialize");
286
287 Ok(resp.runs.unwrap_or(vec![]))
288 }
289}
290
291