A logger for the border-core crate.
This crate is built on MLflow Tracking.
§Setup
To use this crate, you first need to start an MLflow tracking server. You can do this by running:

```bash
mlflow server --host 127.0.0.1 --port 8080
```

Before running a program that uses this crate, set the MLFLOW_DEFAULT_ARTIFACT_ROOT
environment variable to specify where model parameters and artifacts will be saved during training.
Typically, it should point to the mlruns directory of your MLflow installation.
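Because the variable must be visible to the training program itself, it can help to fail fast at startup when it is missing. The following is a minimal std-only sketch of such a check; `resolve_artifact_root` and `artifact_root` are hypothetical helpers, not part of this crate's API.

```rust
use std::env;
use std::path::PathBuf;

/// Hypothetical helper (not part of border-mlflow-tracking): turns the raw
/// value of MLFLOW_DEFAULT_ARTIFACT_ROOT into a path, rejecting missing or
/// blank values with a readable error message.
fn resolve_artifact_root(raw: Option<String>) -> Result<PathBuf, String> {
    match raw {
        Some(value) if !value.trim().is_empty() => Ok(PathBuf::from(value)),
        Some(_) => Err("MLFLOW_DEFAULT_ARTIFACT_ROOT is set but empty".to_string()),
        None => Err("MLFLOW_DEFAULT_ARTIFACT_ROOT is not set".to_string()),
    }
}

/// Hypothetical helper: reads the variable from the current process
/// environment. A value that is not valid Unicode is treated as unset.
fn artifact_root() -> Result<PathBuf, String> {
    resolve_artifact_root(env::var("MLFLOW_DEFAULT_ARTIFACT_ROOT").ok())
}
```

Calling `artifact_root()` at the top of `main` surfaces a misconfigured environment before any training time is spent, rather than when the first model is saved.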
§Example
The following code shows an example. Nested configuration parameters are flattened before logging,
so they appear under dotted names such as hyper_params.param1 and hyper_params.param2.
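To illustrate the flattening rule, here is a small std-only sketch that joins parent and child field names with a dot. It is not the crate's implementation: `ConfigValue` and `flatten` are hypothetical names, and the crate (which works from `serde::Serialize`) may treat cases such as enum variants like `Param2::Variant2` differently.

```rust
/// Hypothetical, simplified stand-in for a serialized configuration value.
enum ConfigValue {
    /// A scalar parameter, already rendered as a string.
    Leaf(String),
    /// A struct-like value: ordered (field name, child value) pairs.
    Nested(Vec<(String, ConfigValue)>),
}

/// Recursively flattens `value` into `(dotted.key, value)` pairs,
/// appending them to `out` in field order.
fn flatten(prefix: &str, value: &ConfigValue, out: &mut Vec<(String, String)>) {
    match value {
        ConfigValue::Leaf(v) => out.push((prefix.to_string(), v.clone())),
        ConfigValue::Nested(fields) => {
            for (name, child) in fields {
                // Top-level fields keep their bare name; deeper ones get a dotted prefix.
                let key = if prefix.is_empty() {
                    name.clone()
                } else {
                    format!("{prefix}.{name}")
                };
                flatten(&key, child, out);
            }
        }
    }
}
```

With this sketch, a `hyper_params` struct containing `param1` yields the key `hyper_params.param1`, matching the naming described above. The complete example using the crate follows.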
```rust
use anyhow::Result;
use border_core::record::{Record, RecordValue, Recorder};
use border_mlflow_tracking::MlflowTrackingClient;
use serde::Serialize;

// Nested configuration struct
#[derive(Debug, Serialize)]
struct Config {
    env_params: String,
    hyper_params: HyperParameters,
}

#[derive(Debug, Serialize)]
struct HyperParameters {
    param1: i64,
    param2: Param2,
    param3: Param3,
}

#[derive(Debug, Serialize)]
enum Param2 {
    Variant1,
    Variant2(f32),
}

#[derive(Debug, Serialize)]
struct Param3 {
    dataset_name: String,
}

fn main() -> Result<()> {
    env_logger::init();

    let config1 = Config {
        env_params: "env1".to_string(),
        hyper_params: HyperParameters {
            param1: 0,
            param2: Param2::Variant1,
            param3: Param3 {
                dataset_name: "a".to_string(),
            },
        },
    };
    let config2 = Config {
        env_params: "env2".to_string(),
        hyper_params: HyperParameters {
            param1: 0,
            param2: Param2::Variant2(3.0),
            param3: Param3 {
                dataset_name: "a".to_string(),
            },
        },
    };

    // Set the experiment for runs
    let client = MlflowTrackingClient::new("http://localhost:8080")
        .set_experiment("Default")?;

    // Create one recorder per run and log each run's parameters
    let mut recorder_run1 = client.create_recorder("")?;
    let mut recorder_run2 = client.create_recorder("")?;
    recorder_run1.log_params(&config1)?;
    recorder_run2.log_params(&config2)?;

    // Logging while training (run 1)
    for opt_steps in 0..100 {
        let opt_steps = opt_steps as f32;

        // Create a record
        let mut record = Record::empty();
        record.insert("opt_steps", RecordValue::Scalar(opt_steps));
        record.insert("Loss", RecordValue::Scalar((-1f32 * opt_steps).exp()));

        // Log the metrics in the record
        recorder_run1.write(record);
    }

    // Logging while training (run 2)
    for opt_steps in 0..100 {
        let opt_steps = opt_steps as f32;

        // Create a record
        let mut record = Record::empty();
        record.insert("opt_steps", RecordValue::Scalar(opt_steps));
        record.insert("Loss", RecordValue::Scalar((-0.5f32 * opt_steps).exp()));

        // Log the metrics in the record
        recorder_run2.write(record);
    }

    Ok(())
}
```

§Save model parameters during training
MlflowTrackingClient relies on the MLFLOW_DEFAULT_ARTIFACT_ROOT environment variable
to locate where model parameters are saved during training. Note that this environment variable
should be set for the program using this crate, not for the tracking server program.
Currently, only saving to the local file system is supported.
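Since only the local file system is supported, saved parameters end up in a directory below the artifact root. The sketch below shows one way to derive and create such a per-run directory. It assumes the conventional MLflow local layout `<root>/<experiment_id>/<run_id>/artifacts`; this crate is not confirmed to use that exact layout, and `run_artifact_dir` and `ensure_run_artifact_dir` are hypothetical helpers.

```rust
use std::fs;
use std::io;
use std::path::{Path, PathBuf};

/// Hypothetical helper: builds the artifact directory for one run,
/// ASSUMING the `<root>/<experiment_id>/<run_id>/artifacts` layout.
fn run_artifact_dir(root: &Path, experiment_id: &str, run_id: &str) -> PathBuf {
    root.join(experiment_id).join(run_id).join("artifacts")
}

/// Hypothetical helper: creates the run's artifact directory (and any
/// missing parents) on the local file system and returns its path.
fn ensure_run_artifact_dir(root: &Path, experiment_id: &str, run_id: &str) -> io::Result<PathBuf> {
    let dir = run_artifact_dir(root, experiment_id, run_id);
    fs::create_dir_all(&dir)?;
    Ok(dir)
}
```

In a real program, `root` would come from MLFLOW_DEFAULT_ARTIFACT_ROOT in the training process's environment, which is why the variable must be set there rather than for the tracking server.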
§Structs
- GetExperimentIdError
- MlflowTrackingClient - Provides access to an MLflow tracking server via its REST API.
- MlflowTrackingRecorder - Records metrics to the MLflow tracking server during training.
- Run