border_mlflow_tracking/
lib.rs

//! A logger for the border-core crate.
2//!
3//! This crate is based on [MLflow](https://mlflow.org) tracking.
4//!
5//! # Setup
6//!
7//! To use this crate, you need to start an MLflow tracking server first. You can do this by running:
8//!
9//! ```bash
10//! mlflow server --host 127.0.0.1 --port 8080
11//! ```
12//!
//! Before running a program that uses this crate, you need to set the `MLFLOW_DEFAULT_ARTIFACT_ROOT`
//! environment variable to specify where model parameters and artifacts will be saved during training.
//! Typically, this should be the `mlruns` directory of your MLflow installation.
16//!
17//! # Example
18//!
//! The following code is an example. Nested configuration parameters are flattened and
//! logged with dotted names such as `hyper_params.param1` and `hyper_params.param2`.
21//!
22//! ```no_run
23//! use anyhow::Result;
24//! use border_core::record::{Record, RecordValue, Recorder};
25//! use border_mlflow_tracking::MlflowTrackingClient;
26//! use serde::Serialize;
27//!
//! // Nested configuration struct
29//! #[derive(Debug, Serialize)]
30//! struct Config {
31//!     env_params: String,
32//!     hyper_params: HyperParameters,
33//! }
34//!
35//! #[derive(Debug, Serialize)]
36//! struct HyperParameters {
37//!     param1: i64,
38//!     param2: Param2,
39//!     param3: Param3,
40//! }
41//!
42//! #[derive(Debug, Serialize)]
43//! enum Param2 {
44//!     Variant1,
45//!     Variant2(f32),
46//! }
47//!
48//! #[derive(Debug, Serialize)]
49//! struct Param3 {
50//!     dataset_name: String,
51//! }
52//!
53//! fn main() -> Result<()> {
54//!     env_logger::init();
55//!
56//!     let config1 = Config {
57//!         env_params: "env1".to_string(),
58//!         hyper_params: HyperParameters {
59//!             param1: 0,
60//!             param2: Param2::Variant1,
61//!             param3: Param3 {
62//!                 dataset_name: "a".to_string(),
63//!             },
64//!         },
65//!     };
66//!     let config2 = Config {
67//!         env_params: "env2".to_string(),
68//!         hyper_params: HyperParameters {
69//!             param1: 0,
70//!             param2: Param2::Variant2(3.0),
71//!             param3: Param3 {
72//!                 dataset_name: "a".to_string(),
73//!             },
74//!         },
75//!     };
76//!
77//!     // Set experiment for runs
78//!     let client = MlflowTrackingClient::new("http://localhost:8080")
79//!         .set_experiment("Default")?;
80//!
81//!     // Create recorders for logging
82//!     let mut recorder_run1 = client.create_recorder("")?;
83//!     let mut recorder_run2 = client.create_recorder("")?;
84//!     recorder_run1.log_params(&config1)?;
85//!     recorder_run2.log_params(&config2)?;
86//!
87//!     // Logging while training
88//!     for opt_steps in 0..100 {
89//!         let opt_steps = opt_steps as f32;
90//!
91//!         // Create a record
92//!         let mut record = Record::empty();
93//!         record.insert("opt_steps", RecordValue::Scalar(opt_steps));
94//!         record.insert("Loss", RecordValue::Scalar((-1f32 * opt_steps).exp()));
95//!
//!         // Log metrics in the record
97//!         recorder_run1.write(record);
98//!     }
99//!
100//!     // Logging while training
101//!     for opt_steps in 0..100 {
102//!         let opt_steps = opt_steps as f32;
103//!
104//!         // Create a record
105//!         let mut record = Record::empty();
106//!         record.insert("opt_steps", RecordValue::Scalar(opt_steps));
107//!         record.insert("Loss", RecordValue::Scalar((-0.5f32 * opt_steps).exp()));
108//!
//!         // Log metrics in the record
110//!         recorder_run2.write(record);
111//!     }
112//!
113//!     Ok(())
114//! }
115//! ```
116//!
117//! ## Save model parameters during training
118//!
//! [`MlflowTrackingClient`] relies on the `MLFLOW_DEFAULT_ARTIFACT_ROOT` environment variable
//! to locate where model parameters are saved during training. Note that this environment variable
//! must be set in the environment of the program that uses this crate, not in that of the tracking
//! server. Currently, only saving to the local file system is supported.
123//!
124mod client;
125mod experiment;
126mod recorder;
127mod run;
128use anyhow::Result;
129pub use client::{GetExperimentIdError, MlflowTrackingClient};
130use experiment::Experiment;
131pub use recorder::MlflowTrackingRecorder;
132pub use run::Run;
133use std::path::PathBuf;
134use std::time::{SystemTime, UNIX_EPOCH};
135
/// Returns the current wall-clock time as whole milliseconds since the Unix epoch.
///
/// Code adapted from <https://stackoverflow.com/questions/26593387>.
///
/// # Panics
///
/// Panics with `"Time went backwards"` if the system clock reports a time earlier than
/// the Unix epoch.
fn system_time_as_millis() -> u128 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|since_epoch| since_epoch.as_millis())
        .expect("Time went backwards")
}
143
144/// Get the directory to which artifacts will be saved.
145pub(crate) fn get_artifact_base(run: Run) -> Result<PathBuf> {
146    let artifact_uri: PathBuf = run
147        .clone()
148        .info
149        .artifact_uri
150        .expect("Failed to get artifact_uri")
151        .into();
152    let artifact_uri = artifact_uri.strip_prefix("mlflow-artifacts:/")?;
153    let path: PathBuf = std::env::var("MLFLOW_DEFAULT_ARTIFACT_ROOT")
154        .expect("MLFLOW_DEFAULT_ARTIFACT_ROOT must be set")
155        .into();
156    Ok(path.join(artifact_uri))
157}
158
159// /// https://stackoverflow.com/questions/26958489/how-to-copy-a-folder-recursively-in-rust
160// fn copy_dir_all(src: impl AsRef<Path>, dst: impl AsRef<Path>) -> Result<()> {
161//     fs::create_dir_all(&dst)?;
162//     for entry in fs::read_dir(src)? {
163//         let entry = entry?;
164//         let ty = entry.file_type()?;
165//         if ty.is_dir() {
166//             copy_dir_all(entry.path(), dst.as_ref().join(entry.file_name()))?;
167//         } else {
168//             fs::copy(entry.path(), dst.as_ref().join(entry.file_name()))?;
169//         }
170//     }
171//     Ok(())
172// }