border_mlflow_tracking/lib.rs
//! A logger for the `border-core` crate.
2//!
3//! This crate is based on [MLflow](https://mlflow.org) tracking.
4//!
5//! # Setup
6//!
7//! To use this crate, you need to start an MLflow tracking server first. You can do this by running:
8//!
9//! ```bash
10//! mlflow server --host 127.0.0.1 --port 8080
11//! ```
12//!
//! Before running a program that uses this crate, set the `MLFLOW_DEFAULT_ARTIFACT_ROOT`
//! environment variable to the directory where model parameters and artifacts should be saved
//! during training. Typically, this is the `mlruns` directory of your MLflow installation.
16//!
17//! # Example
18//!
//! The following code is an example. Nested configuration parameters are flattened and
//! logged with dotted names such as `hyper_params.param1` and `hyper_params.param2`.
21//!
22//! ```no_run
23//! use anyhow::Result;
24//! use border_core::record::{Record, RecordValue, Recorder};
25//! use border_mlflow_tracking::MlflowTrackingClient;
26//! use serde::Serialize;
27//!
//! // Nested configuration struct
29//! #[derive(Debug, Serialize)]
30//! struct Config {
31//! env_params: String,
32//! hyper_params: HyperParameters,
33//! }
34//!
35//! #[derive(Debug, Serialize)]
36//! struct HyperParameters {
37//! param1: i64,
38//! param2: Param2,
39//! param3: Param3,
40//! }
41//!
42//! #[derive(Debug, Serialize)]
43//! enum Param2 {
44//! Variant1,
45//! Variant2(f32),
46//! }
47//!
48//! #[derive(Debug, Serialize)]
49//! struct Param3 {
50//! dataset_name: String,
51//! }
52//!
53//! fn main() -> Result<()> {
54//! env_logger::init();
55//!
56//! let config1 = Config {
57//! env_params: "env1".to_string(),
58//! hyper_params: HyperParameters {
59//! param1: 0,
60//! param2: Param2::Variant1,
61//! param3: Param3 {
62//! dataset_name: "a".to_string(),
63//! },
64//! },
65//! };
66//! let config2 = Config {
67//! env_params: "env2".to_string(),
68//! hyper_params: HyperParameters {
69//! param1: 0,
70//! param2: Param2::Variant2(3.0),
71//! param3: Param3 {
72//! dataset_name: "a".to_string(),
73//! },
74//! },
75//! };
76//!
77//! // Set experiment for runs
78//! let client = MlflowTrackingClient::new("http://localhost:8080")
79//! .set_experiment("Default")?;
80//!
81//! // Create recorders for logging
82//! let mut recorder_run1 = client.create_recorder("")?;
83//! let mut recorder_run2 = client.create_recorder("")?;
84//! recorder_run1.log_params(&config1)?;
85//! recorder_run2.log_params(&config2)?;
86//!
87//! // Logging while training
88//! for opt_steps in 0..100 {
89//! let opt_steps = opt_steps as f32;
90//!
91//! // Create a record
92//! let mut record = Record::empty();
93//! record.insert("opt_steps", RecordValue::Scalar(opt_steps));
94//! record.insert("Loss", RecordValue::Scalar((-1f32 * opt_steps).exp()));
95//!
//! // Log metrics in the record
97//! recorder_run1.write(record);
98//! }
99//!
100//! // Logging while training
101//! for opt_steps in 0..100 {
102//! let opt_steps = opt_steps as f32;
103//!
104//! // Create a record
105//! let mut record = Record::empty();
106//! record.insert("opt_steps", RecordValue::Scalar(opt_steps));
107//! record.insert("Loss", RecordValue::Scalar((-0.5f32 * opt_steps).exp()));
108//!
//! // Log metrics in the record
110//! recorder_run2.write(record);
111//! }
112//!
113//! Ok(())
114//! }
115//! ```
116//!
117//! ## Save model parameters during training
118//!
//! [`MlflowTrackingClient`] relies on the `MLFLOW_DEFAULT_ARTIFACT_ROOT` environment variable
//! to locate the directory where model parameters are saved during training. Note that this
//! environment variable must be set for the program that uses this crate, not for the tracking
//! server. Currently, only saving to the local file system is supported.
123//!
124mod client;
125mod experiment;
126mod recorder;
127mod run;
128use anyhow::Result;
129pub use client::{GetExperimentIdError, MlflowTrackingClient};
130use experiment::Experiment;
131pub use recorder::MlflowTrackingRecorder;
132pub use run::Run;
133use std::path::PathBuf;
134use std::time::{SystemTime, UNIX_EPOCH};
135
136/// Code adapted from <https://stackoverflow.com/questions/26593387>.
137fn system_time_as_millis() -> u128 {
138 let time = SystemTime::now();
139 time.duration_since(UNIX_EPOCH)
140 .expect("Time went backwards")
141 .as_millis()
142}
143
144/// Get the directory to which artifacts will be saved.
145pub(crate) fn get_artifact_base(run: Run) -> Result<PathBuf> {
146 let artifact_uri: PathBuf = run
147 .clone()
148 .info
149 .artifact_uri
150 .expect("Failed to get artifact_uri")
151 .into();
152 let artifact_uri = artifact_uri.strip_prefix("mlflow-artifacts:/")?;
153 let path: PathBuf = std::env::var("MLFLOW_DEFAULT_ARTIFACT_ROOT")
154 .expect("MLFLOW_DEFAULT_ARTIFACT_ROOT must be set")
155 .into();
156 Ok(path.join(artifact_uri))
157}
158
159// /// https://stackoverflow.com/questions/26958489/how-to-copy-a-folder-recursively-in-rust
160// fn copy_dir_all(src: impl AsRef<Path>, dst: impl AsRef<Path>) -> Result<()> {
161// fs::create_dir_all(&dst)?;
162// for entry in fs::read_dir(src)? {
163// let entry = entry?;
164// let ty = entry.file_type()?;
165// if ty.is_dir() {
166// copy_dir_all(entry.path(), dst.as_ref().join(entry.file_name()))?;
167// } else {
168// fs::copy(entry.path(), dst.as_ref().join(entry.file_name()))?;
169// }
170// }
171// Ok(())
172// }