Crate border_mlflow_tracking

A logger for the border-core crate.

This crate records parameters and metrics to an MLflow tracking server through its REST API.

§Setup

To use this crate, you need to start an MLflow tracking server first. You can do this by running:

mlflow server --host 127.0.0.1 --port 8080
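Before training, it may be worth confirming that something is listening at the server address. The sketch below is a hypothetical helper, not part of border_mlflow_tracking. It only checks that a TCP connection to the assumed address 127.0.0.1:8080 can be opened, so it does not prove that the listener is a healthy MLflow server.

```rust
use std::net::{SocketAddr, TcpStream};
use std::time::Duration;

/// Hypothetical helper (not part of this crate): returns true if a TCP
/// connection to `addr` (e.g. "127.0.0.1:8080") succeeds within two seconds.
/// It checks only that something is listening, not that it is MLflow.
fn server_reachable(addr: &str) -> bool {
    match addr.parse::<SocketAddr>() {
        Ok(sock) => TcpStream::connect_timeout(&sock, Duration::from_secs(2)).is_ok(),
        // An address that cannot be parsed is treated as unreachable.
        Err(_) => false,
    }
}

fn main() {
    // Assumed to match the --host and --port passed to `mlflow server`.
    let addr = "127.0.0.1:8080";
    if server_reachable(addr) {
        println!("something is listening on {addr}");
    } else {
        eprintln!("nothing is listening on {addr}; start `mlflow server` first");
    }
}
```

A plain TCP probe keeps the sketch dependency-free. A stricter check would send an HTTP request to the server, which needs an HTTP client crate.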

Before running a program that uses this crate, set the MLFLOW_DEFAULT_ARTIFACT_ROOT environment variable to the directory where model parameters and artifacts will be saved during training. Typically, this is the mlruns directory of your MLflow installation.
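A missing or mistyped variable is easiest to catch before training starts. The following is a minimal pre-flight sketch and is not part of border_mlflow_tracking. The function name artifact_root is hypothetical, and requiring the directory to already exist is an extra assumption made here, not a documented requirement of the crate.

```rust
use std::env;
use std::path::PathBuf;

/// Hypothetical pre-flight check (not part of this crate): reads
/// MLFLOW_DEFAULT_ARTIFACT_ROOT and, as an added assumption, requires it
/// to name a directory that already exists.
fn artifact_root() -> Result<PathBuf, String> {
    let root = env::var("MLFLOW_DEFAULT_ARTIFACT_ROOT")
        .map_err(|_| "MLFLOW_DEFAULT_ARTIFACT_ROOT is not set".to_string())?;
    let path = PathBuf::from(root);
    if !path.is_dir() {
        return Err(format!("{} is not an existing directory", path.display()));
    }
    Ok(path)
}

fn main() {
    match artifact_root() {
        Ok(path) => println!("artifacts will be saved under {}", path.display()),
        Err(msg) => eprintln!("setup error: {msg}"),
    }
}
```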

§Example

The following example logs parameters and metrics for two runs. Nested configuration parameters are flattened and logged with dotted keys such as hyper_params.param1 and hyper_params.param2.

use anyhow::Result;
use border_core::record::{Record, RecordValue, Recorder};
use border_mlflow_tracking::MlflowTrackingClient;
use serde::Serialize;

// Nested configuration struct
#[derive(Debug, Serialize)]
struct Config {
    env_params: String,
    hyper_params: HyperParameters,
}

#[derive(Debug, Serialize)]
struct HyperParameters {
    param1: i64,
    param2: Param2,
    param3: Param3,
}

#[derive(Debug, Serialize)]
enum Param2 {
    Variant1,
    Variant2(f32),
}

#[derive(Debug, Serialize)]
struct Param3 {
    dataset_name: String,
}

fn main() -> Result<()> {
    env_logger::init();

    let config1 = Config {
        env_params: "env1".to_string(),
        hyper_params: HyperParameters {
            param1: 0,
            param2: Param2::Variant1,
            param3: Param3 {
                dataset_name: "a".to_string(),
            },
        },
    };
    let config2 = Config {
        env_params: "env2".to_string(),
        hyper_params: HyperParameters {
            param1: 0,
            param2: Param2::Variant2(3.0),
            param3: Param3 {
                dataset_name: "a".to_string(),
            },
        },
    };

    // Connect to the tracking server and set the experiment for the runs
    let client = MlflowTrackingClient::new("http://localhost:8080")
        .set_experiment("Default")?;

    // Create one recorder per run and log each run's configuration
    let mut recorder_run1 = client.create_recorder("")?;
    let mut recorder_run2 = client.create_recorder("")?;
    recorder_run1.log_params(&config1)?;
    recorder_run2.log_params(&config2)?;

    // Log metrics for run 1 during (simulated) training
    for opt_steps in 0..100 {
        let opt_steps = opt_steps as f32;

        // Create a record
        let mut record = Record::empty();
        record.insert("opt_steps", RecordValue::Scalar(opt_steps));
        record.insert("Loss", RecordValue::Scalar((-1f32 * opt_steps).exp()));

        // Log the metrics in the record
        recorder_run1.write(record);
    }

    // Log metrics for run 2 during (simulated) training
    for opt_steps in 0..100 {
        let opt_steps = opt_steps as f32;

        // Create a record
        let mut record = Record::empty();
        record.insert("opt_steps", RecordValue::Scalar(opt_steps));
        record.insert("Loss", RecordValue::Scalar((-0.5f32 * opt_steps).exp()));

        // Log the metrics in the record
        recorder_run2.write(record);
    }

    Ok(())
}
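The flattening of nested parameters into dotted keys can be illustrated with a small sketch that uses only the standard library. The Node type is a hypothetical stand-in for a serialized configuration tree, and the traversal shows the general idea rather than the crate's actual code. In particular, it does not show how enum values such as Param2::Variant2(3.0) are named after flattening.

```rust
use std::collections::BTreeMap;

/// Hypothetical stand-in for a serialized config tree: either a leaf value
/// (already rendered as a string) or a list of named children.
enum Node {
    Leaf(String),
    Map(Vec<(String, Node)>),
}

/// Recursively flattens `node` into `out`, joining nested keys with '.'.
fn flatten(prefix: &str, node: &Node, out: &mut BTreeMap<String, String>) {
    match node {
        Node::Leaf(value) => {
            out.insert(prefix.to_string(), value.clone());
        }
        Node::Map(children) => {
            for (key, child) in children {
                // Top-level keys have no prefix; nested keys become "parent.child".
                let full_key = if prefix.is_empty() {
                    key.clone()
                } else {
                    format!("{prefix}.{key}")
                };
                flatten(&full_key, child, out);
            }
        }
    }
}

fn main() {
    // A subset of config1 from the example above, written out by hand.
    let config = Node::Map(vec![
        ("env_params".into(), Node::Leaf("env1".into())),
        (
            "hyper_params".into(),
            Node::Map(vec![
                ("param1".into(), Node::Leaf("0".into())),
                (
                    "param3".into(),
                    Node::Map(vec![("dataset_name".into(), Node::Leaf("a".into()))]),
                ),
            ]),
        ),
    ]);

    let mut flat = BTreeMap::new();
    flatten("", &config, &mut flat);
    // Prints, in sorted key order:
    // env_params = env1
    // hyper_params.param1 = 0
    // hyper_params.param3.dataset_name = a
    for (key, value) in &flat {
        println!("{key} = {value}");
    }
}
```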

§Save model parameters during training

MlflowTrackingClient uses the MLFLOW_DEFAULT_ARTIFACT_ROOT environment variable to determine where model parameters are saved during training. Set this variable in the environment of the program that uses this crate, not in the environment of the tracking server. Currently, only saving to the local file system is supported.
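MLflow's local file store conventionally keeps each run's artifacts under <root>/<experiment_id>/<run_id>/artifacts. Assuming MlflowTrackingRecorder follows that convention (this is not confirmed here), saved model parameters can be located with a path computation like the hypothetical helper below. The experiment and run IDs in main are placeholders.

```rust
use std::path::{Path, PathBuf};

/// Hypothetical helper: builds <root>/<experiment_id>/<run_id>/artifacts,
/// following MLflow's local file-store layout. Whether this crate uses
/// exactly this layout is an assumption.
fn artifact_dir(root: &Path, experiment_id: &str, run_id: &str) -> PathBuf {
    root.join(experiment_id).join(run_id).join("artifacts")
}

fn main() {
    // Placeholder IDs for illustration only; real IDs come from the server.
    let dir = artifact_dir(Path::new("mlruns"), "1", "0123abcd");
    println!("{}", dir.display());
}
```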

Structs§

GetExperimentIdError
MlflowTrackingClient
Provides access to an MLflow tracking server via its REST API.
MlflowTrackingRecorder
Records metrics to the MLflow tracking server during training.
Run