1use crate::{system_time_as_millis, Run};
2use anyhow::Result;
3use border_core::{
4 record::{RecordStorage, RecordValue, Recorder},
5 Agent, Env, ReplayBufferBase,
6};
7use chrono::{DateTime, Duration, Local, SecondsFormat};
8use reqwest::blocking::Client;
9use serde::Serialize;
10use serde_json::Value;
11use std::marker::PhantomData;
12use std::path::{Path, PathBuf};
13use tempdir::TempDir;
14
15#[derive(Debug, Serialize)]
16struct LogParamParams<'a> {
17 run_id: &'a String,
18 key: &'a String,
19 value: String,
20}
21
22#[derive(Debug, Serialize)]
23struct LogMetricParams<'a> {
24 run_id: &'a String,
25 key: &'a String,
26 value: f64,
27 timestamp: i64,
28 step: i64,
29}
30
31#[derive(Debug, Serialize)]
32struct UpdateRunParams<'a> {
33 run_id: &'a String,
34 status: String,
35 end_time: i64,
36 run_name: &'a String,
37}
38
39#[derive(Debug, Serialize)]
40struct SetTagParams<'a> {
41 run_id: &'a String,
42 key: &'a String,
43 value: &'a String,
44}
45
/// Recorder that sends parameters, metrics and tags of a run to an MLflow tracking server
/// through its REST API, and copies saved model files into a local artifact directory.
///
/// When dropped, it sets end-time/duration tags on the run and updates the run's status.
// NOTE(review): `experiment_id` is never read by the code in this file, which is presumably why
// `dead_code` is allowed on the whole struct — confirm, and consider narrowing the attribute.
#[allow(dead_code)]
pub struct MlflowTrackingRecorder<E, R>
where
    E: Env,
    R: ReplayBufferBase,
{
    /// HTTP client used for every request to the tracking server.
    client: Client,
    /// Root URL of the tracking server; REST paths such as `/api/2.0/mlflow/runs/set-tag` are
    /// appended to it.
    base_url: String,
    /// Experiment ID passed to `new` (not used by the requests in this file).
    experiment_id: String,
    /// The MLflow run that parameters, metrics and tags are written to.
    run: Run,
    /// User name sent as HTTP basic auth with every request.
    user_name: String,
    /// Records passed to `store`, aggregated and written on `flush`.
    storage: RecordStorage,
    /// Password sent as HTTP basic auth with every request.
    password: String,
    /// Local time at which the recorder was created; used for the duration tag on drop.
    start_time: DateTime<Local>,
    /// Local directory under which `save_model` copies model files.
    artifact_base: PathBuf,
    /// Ties the otherwise unused type parameters `E` and `R` to the struct.
    phantom: PhantomData<(E, R)>,
}
80
81impl<E, R> MlflowTrackingRecorder<E, R>
82where
83 E: Env,
84 R: ReplayBufferBase,
85{
86 pub fn new(
96 base_url: &String,
97 experiment_id: &String,
98 run: Run,
99 artifact_base: PathBuf,
100 ) -> Result<Self> {
101 let client = Client::new();
102 let start_time = Local::now();
103 let recorder = Self {
104 client,
105 base_url: base_url.clone(),
106 experiment_id: experiment_id.to_string(),
107 run,
108 user_name: "".to_string(),
109 password: "".to_string(),
110 storage: RecordStorage::new(),
111 start_time: start_time.clone(),
112 artifact_base,
113 phantom: PhantomData,
114 };
115
116 recorder.set_tag(
118 "host_start_time",
119 start_time.to_rfc3339_opts(SecondsFormat::Secs, true),
120 )?;
121
122 Ok(recorder)
123 }
124
125 pub fn log_params(&self, params: impl Serialize) -> Result<()> {
126 let url = format!("{}/api/2.0/mlflow/runs/log-parameter", self.base_url);
127 let flatten_map = {
128 let map = match serde_json::to_value(params).unwrap() {
129 Value::Object(map) => map,
130 _ => panic!("Failed to parse object"),
131 };
132 flatten_serde_json::flatten(&map)
133 };
134 for (key, value) in flatten_map.iter() {
135 let params = LogParamParams {
136 run_id: &self.run.info.run_id,
137 key,
138 value: value.to_string(),
139 };
140 let _resp = self
141 .client
142 .post(&url)
143 .basic_auth(&self.user_name, Some(&self.password))
144 .json(¶ms) .send()
146 .unwrap();
147 }
149
150 Ok(())
151 }
152
153 pub fn set_tag(&self, key: impl AsRef<str>, value: impl AsRef<str>) -> Result<()> {
157 if self.run.exist_tag(key.as_ref()) {
158 log::warn!("Tag {} exists, so set_tag() was ignored.", key.as_ref());
159 return Ok(());
160 }
161
162 let url = format!("{}/api/2.0/mlflow/runs/set-tag", self.base_url);
163 let params = SetTagParams {
164 run_id: &self.run.info.run_id,
165 key: &key.as_ref().to_string(),
166 value: &value.as_ref().to_string(),
167 };
168 let _resp = self
169 .client
170 .post(&url)
171 .basic_auth(&self.user_name, Some(&self.password))
172 .json(¶ms)
173 .send()
174 .unwrap();
175
176 Ok(())
177 }
178
179 pub fn set_tags<'a, T: AsRef<str> + std::fmt::Debug + 'a>(
180 &self,
181 tags: impl Into<&'a [(T, T)]>,
182 ) -> Result<()> {
183 for tag in tags.into().iter() {
184 self.set_tag(&tag.0, &tag.1)?;
185 }
186 Ok(())
187 }
188}
189
190impl<E, R> Recorder<E, R> for MlflowTrackingRecorder<E, R>
191where
192 E: Env,
193 R: ReplayBufferBase,
194{
195 fn write(&mut self, record: border_core::record::Record) {
196 let url = format!("{}/api/2.0/mlflow/runs/log-metric", self.base_url);
197 let timestamp = system_time_as_millis() as i64;
198 let step = record.get_scalar("opt_steps").unwrap() as i64;
199
200 for (key, value) in record.iter() {
201 if *key != "opt_steps" {
202 match value {
203 RecordValue::Scalar(v) => {
204 let value = *v as f64;
205 let params = LogMetricParams {
206 run_id: &self.run.info.run_id,
207 key,
208 value,
209 timestamp,
210 step,
211 };
212 let _resp = self
213 .client
214 .post(&url)
215 .basic_auth(&self.user_name, Some(&self.password))
216 .json(¶ms) .send()
218 .unwrap();
219 }
221 _ => {} }
223 }
224 }
225 }
226
227 fn flush(&mut self, step: i64) {
228 let mut record = self.storage.aggregate();
229 record.insert("opt_steps", RecordValue::Scalar(step as _));
230 self.write(record);
231 }
232
233 fn store(&mut self, record: border_core::record::Record) {
234 self.storage.store(record);
235 }
236
237 fn save_model(&self, base: &Path, agent: &Box<dyn border_core::Agent<E, R>>) -> Result<()> {
244 let tmp = TempDir::new("mlflow")?;
246 let srcs = agent.save_params(&tmp.path())?;
247
248 for src in srcs.iter() {
250 let dest = {
251 let path = self.artifact_base.join(base);
253 if !path.exists() {
254 let _ = std::fs::create_dir(&path)?;
255 }
256
257 let file = src.strip_prefix(tmp.path())?;
259 path.join(file)
260 };
261 let bytes = std::fs::copy(src, &dest)?;
262 log::info!("Save {:?}", &src);
263 log::info!("Copy {:?}, {:.2}MB", &dest, bytes as f32 / (1024. * 1024.));
264 }
265 Ok(())
266 }
267
268 fn load_model(&self, base: &Path, agent: &mut Box<dyn Agent<E, R>>) -> Result<()> {
275 let artifact_base = crate::get_artifact_base(self.run.clone())?;
277
278 let path = &artifact_base.join(base);
280 agent.load_params(path)
281 }
282}
283
284impl<E, R> Drop for MlflowTrackingRecorder<E, R>
285where
286 E: Env,
287 R: ReplayBufferBase,
288{
289 fn drop(&mut self) {
293 let end_time = Local::now();
294 let duration = end_time.signed_duration_since(self.start_time);
295 self.set_tag(
296 "host_end_time",
297 end_time.to_rfc3339_opts(SecondsFormat::Secs, true),
298 )
299 .unwrap();
300 self.set_tag("host_duration", format_duration(&duration))
301 .unwrap();
302
303 let url = format!("{}/api/2.0/mlflow/runs/update", self.base_url);
304 let params = UpdateRunParams {
305 run_id: &self.run.info.run_id,
306 status: "FINISHED".to_string(),
307 end_time: end_time.timestamp_millis(),
308 run_name: &self.run.info.run_name,
309 };
310 let _resp = self
311 .client
312 .post(&url)
313 .basic_auth(&self.user_name, Some(&self.password))
314 .json(¶ms) .send()
316 .unwrap();
317 }
319}
320
321fn format_duration(dt: &Duration) -> String {
322 let mut seconds = dt.num_seconds();
323 let mut minutes = seconds / 60;
324 seconds %= 60;
325 let hours = minutes / 60;
326 minutes %= 60;
327 format!("{:02}:{:02}:{:02}", hours, minutes, seconds)
328}