border_mlflow_tracking/
recorder.rs

1use crate::{system_time_as_millis, Run};
2use anyhow::Result;
3use border_core::{
4    record::{RecordStorage, RecordValue, Recorder},
5    Agent, Env, ReplayBufferBase,
6};
7use chrono::{DateTime, Duration, Local, SecondsFormat};
8use reqwest::blocking::Client;
9use serde::Serialize;
10use serde_json::Value;
11use std::marker::PhantomData;
12use std::path::{Path, PathBuf};
13use tempdir::TempDir;
14
15#[derive(Debug, Serialize)]
16struct LogParamParams<'a> {
17    run_id: &'a String,
18    key: &'a String,
19    value: String,
20}
21
22#[derive(Debug, Serialize)]
23struct LogMetricParams<'a> {
24    run_id: &'a String,
25    key: &'a String,
26    value: f64,
27    timestamp: i64,
28    step: i64,
29}
30
31#[derive(Debug, Serialize)]
32struct UpdateRunParams<'a> {
33    run_id: &'a String,
34    status: String,
35    end_time: i64,
36    run_name: &'a String,
37}
38
39#[derive(Debug, Serialize)]
40struct SetTagParams<'a> {
41    run_id: &'a String,
42    key: &'a String,
43    value: &'a String,
44}
45
46#[allow(dead_code)]
47/// Record metrics to the MLflow tracking server during training.
48///
49/// Before training, you can use [`MlflowTrackingRecorder::log_params()`] to log parameters
50/// of the run like hyperparameters of the algorithm, the name of environment on which the
51/// agent is trained, etc.
52///
53/// [`MlflowTrackingRecorder::write()`] method logs [`RecordValue::Scalar`] values in the record
54/// as metrics. As an exception, `opt_steps` is treated as the `step` field of Mlflow's metric data
55/// (<https://mlflow.org/docs/latest/rest-api.html#metric>).
56///
57/// Other types of values like [`RecordValue::Array1`] will be ignored.
58///
59/// When dropped, this struct updates run's status to "FINISHED"
60/// (<https://mlflow.org/docs/latest/rest-api.html#mlflowrunstatus>).
61///
62/// [`RecordValue::Scalar`]: border_core::record::RecordValue::Scalar
63/// [`RecordValue::Array1`]: border_core::record::RecordValue::Array1
64pub struct MlflowTrackingRecorder<E, R>
65where
66    E: Env,
67    R: ReplayBufferBase,
68{
69    client: Client,
70    base_url: String,
71    experiment_id: String,
72    run: Run,
73    user_name: String,
74    storage: RecordStorage,
75    password: String,
76    start_time: DateTime<Local>,
77    artifact_base: PathBuf,
78    phantom: PhantomData<(E, R)>,
79}
80
81impl<E, R> MlflowTrackingRecorder<E, R>
82where
83    E: Env,
84    R: ReplayBufferBase,
85{
86    /// Create a new instance of `MlflowTrackingRecorder`.
87    ///
88    /// This method is used in [`MlflowTrackingClient::create_recorder()`].
89    ///
90    /// This method adds a tag "host_start_time" with the current time.
91    /// This tag is useful when using mlflow-export-import: it losts the original time.
92    /// See <https://github.com/mlflow/mlflow-export-import/issues/72>
93    ///
94    /// [`MlflowTrackingClient::create_recorder()`]: crate::MlflowTrackingClient::create_recorder
95    pub fn new(
96        base_url: &String,
97        experiment_id: &String,
98        run: Run,
99        artifact_base: PathBuf,
100    ) -> Result<Self> {
101        let client = Client::new();
102        let start_time = Local::now();
103        let recorder = Self {
104            client,
105            base_url: base_url.clone(),
106            experiment_id: experiment_id.to_string(),
107            run,
108            user_name: "".to_string(),
109            password: "".to_string(),
110            storage: RecordStorage::new(),
111            start_time: start_time.clone(),
112            artifact_base,
113            phantom: PhantomData,
114        };
115
116        // Record current time as tag "host_start_time"
117        recorder.set_tag(
118            "host_start_time",
119            start_time.to_rfc3339_opts(SecondsFormat::Secs, true),
120        )?;
121
122        Ok(recorder)
123    }
124
125    pub fn log_params(&self, params: impl Serialize) -> Result<()> {
126        let url = format!("{}/api/2.0/mlflow/runs/log-parameter", self.base_url);
127        let flatten_map = {
128            let map = match serde_json::to_value(params).unwrap() {
129                Value::Object(map) => map,
130                _ => panic!("Failed to parse object"),
131            };
132            flatten_serde_json::flatten(&map)
133        };
134        for (key, value) in flatten_map.iter() {
135            let params = LogParamParams {
136                run_id: &self.run.info.run_id,
137                key,
138                value: value.to_string(),
139            };
140            let _resp = self
141                .client
142                .post(&url)
143                .basic_auth(&self.user_name, Some(&self.password))
144                .json(&params) // auto serialize
145                .send()
146                .unwrap();
147            // TODO: error handling caused by API call
148        }
149
150        Ok(())
151    }
152
153    /// Set tag.
154    ///
155    /// This method does not overwrite tags.
156    pub fn set_tag(&self, key: impl AsRef<str>, value: impl AsRef<str>) -> Result<()> {
157        if self.run.exist_tag(key.as_ref()) {
158            log::warn!("Tag {} exists, so set_tag() was ignored.", key.as_ref());
159            return Ok(());
160        }
161
162        let url = format!("{}/api/2.0/mlflow/runs/set-tag", self.base_url);
163        let params = SetTagParams {
164            run_id: &self.run.info.run_id,
165            key: &key.as_ref().to_string(),
166            value: &value.as_ref().to_string(),
167        };
168        let _resp = self
169            .client
170            .post(&url)
171            .basic_auth(&self.user_name, Some(&self.password))
172            .json(&params)
173            .send()
174            .unwrap();
175
176        Ok(())
177    }
178
179    pub fn set_tags<'a, T: AsRef<str> + std::fmt::Debug + 'a>(
180        &self,
181        tags: impl Into<&'a [(T, T)]>,
182    ) -> Result<()> {
183        for tag in tags.into().iter() {
184            self.set_tag(&tag.0, &tag.1)?;
185        }
186        Ok(())
187    }
188}
189
190impl<E, R> Recorder<E, R> for MlflowTrackingRecorder<E, R>
191where
192    E: Env,
193    R: ReplayBufferBase,
194{
195    fn write(&mut self, record: border_core::record::Record) {
196        let url = format!("{}/api/2.0/mlflow/runs/log-metric", self.base_url);
197        let timestamp = system_time_as_millis() as i64;
198        let step = record.get_scalar("opt_steps").unwrap() as i64;
199
200        for (key, value) in record.iter() {
201            if *key != "opt_steps" {
202                match value {
203                    RecordValue::Scalar(v) => {
204                        let value = *v as f64;
205                        let params = LogMetricParams {
206                            run_id: &self.run.info.run_id,
207                            key,
208                            value,
209                            timestamp,
210                            step,
211                        };
212                        let _resp = self
213                            .client
214                            .post(&url)
215                            .basic_auth(&self.user_name, Some(&self.password))
216                            .json(&params) // auto serialize
217                            .send()
218                            .unwrap();
219                        // TODO: error handling caused by API call
220                    }
221                    _ => {} // ignore record value
222                }
223            }
224        }
225    }
226
227    fn flush(&mut self, step: i64) {
228        let mut record = self.storage.aggregate();
229        record.insert("opt_steps", RecordValue::Scalar(step as _));
230        self.write(record);
231    }
232
233    fn store(&mut self, record: border_core::record::Record) {
234        self.storage.store(record);
235    }
236
237    /// Save model parameters as MLflow artifacts.
238    ///
239    /// MLflow server is assumed to be running on the same host as the program using this struct.
240    /// Under this condition, this method saves model parameters under the `mlruns` directory managed by
241    /// the MLflow server. This method recognizes the environment variable `MLFLOW_DEFAULT_ARTIFACT_ROOT`
242    /// as the location of the `mlruns` directory.
243    fn save_model(&self, base: &Path, agent: &Box<dyn border_core::Agent<E, R>>) -> Result<()> {
244        // Saves the artifacts in the temporary directory
245        let tmp = TempDir::new("mlflow")?;
246        let srcs = agent.save_params(&tmp.path())?;
247
248        // Copies the artifacts
249        for src in srcs.iter() {
250            let dest = {
251                // Create subdirectory
252                let path = self.artifact_base.join(base);
253                if !path.exists() {
254                    let _ = std::fs::create_dir(&path)?;
255                }
256
257                // Create destination file path
258                let file = src.strip_prefix(tmp.path())?;
259                path.join(file)
260            };
261            let bytes = std::fs::copy(src, &dest)?;
262            log::info!("Save {:?}", &src);
263            log::info!("Copy {:?}, {:.2}MB", &dest, bytes as f32 / (1024. * 1024.));
264        }
265        Ok(())
266    }
267
268    /// Loads model parameters previously saved as MLflow artifacts.
269    ///
270    /// This method uses `MLFLOW_DEFAULT_ARTIFACT_ROOT` environment variable as the directory
271    /// where artifacts, like model parameters, will be saved. It is recommended to set this
272    /// environment variable `mlruns` directory to which the tracking server persists experiment
273    /// and run data.
274    fn load_model(&self, base: &Path, agent: &mut Box<dyn Agent<E, R>>) -> Result<()> {
275        // Get the directory to which artifacts will be saved
276        let artifact_base = crate::get_artifact_base(self.run.clone())?;
277
278        // Load model parameters from the artifact directory
279        let path = &artifact_base.join(base);
280        agent.load_params(path)
281    }
282}
283
284impl<E, R> Drop for MlflowTrackingRecorder<E, R>
285where
286    E: Env,
287    R: ReplayBufferBase,
288{
289    /// Update run's status to "FINISHED" when dropped.
290    ///
291    /// It also adds tags "host_end_time" and "host_duration" with the current time and duration.
292    fn drop(&mut self) {
293        let end_time = Local::now();
294        let duration = end_time.signed_duration_since(self.start_time);
295        self.set_tag(
296            "host_end_time",
297            end_time.to_rfc3339_opts(SecondsFormat::Secs, true),
298        )
299        .unwrap();
300        self.set_tag("host_duration", format_duration(&duration))
301            .unwrap();
302
303        let url = format!("{}/api/2.0/mlflow/runs/update", self.base_url);
304        let params = UpdateRunParams {
305            run_id: &self.run.info.run_id,
306            status: "FINISHED".to_string(),
307            end_time: end_time.timestamp_millis(),
308            run_name: &self.run.info.run_name,
309        };
310        let _resp = self
311            .client
312            .post(&url)
313            .basic_auth(&self.user_name, Some(&self.password))
314            .json(&params) // auto serialize
315            .send()
316            .unwrap();
317        // TODO: error handling caused by API call
318    }
319}
320
321fn format_duration(dt: &Duration) -> String {
322    let mut seconds = dt.num_seconds();
323    let mut minutes = seconds / 60;
324    seconds %= 60;
325    let hours = minutes / 60;
326    minutes %= 60;
327    format!("{:02}:{:02}:{:02}", hours, minutes, seconds)
328}