use crate::{system_time_as_millis, Run};
use anyhow::Result;
use border_core::{
record::{RecordStorage, RecordValue, Recorder},
Agent, Env, ReplayBufferBase,
};
use chrono::{DateTime, Duration, Local, SecondsFormat};
use reqwest::blocking::Client;
use serde::Serialize;
use serde_json::Value;
use std::marker::PhantomData;
use std::path::{Path, PathBuf};
use tempdir::TempDir;
/// JSON body of the MLflow `runs/log-parameter` request sent by
/// `MlflowTrackingRecorder::log_params`, one request per parameter.
///
/// Serialized with serde, so the field names below are the JSON keys.
#[derive(Debug, Serialize)]
struct LogParamParams<'a> {
    /// ID of the run the parameter is attached to.
    run_id: &'a String,
    /// Parameter name.
    key: &'a String,
    /// Parameter value, already converted to a string by the caller.
    value: String,
}
/// JSON body of the MLflow `runs/log-metric` request, sent once per scalar
/// value of a record written by the recorder.
///
/// Serialized with serde, so the field names below are the JSON keys.
#[derive(Debug, Serialize)]
struct LogMetricParams<'a> {
    /// ID of the run the metric is attached to.
    run_id: &'a String,
    /// Metric name (the record key).
    key: &'a String,
    /// Metric value.
    value: f64,
    /// Time of the write as returned by `system_time_as_millis()`
    /// (presumably milliseconds since the Unix epoch — confirm).
    timestamp: i64,
    /// Step index; the recorder fills this from the record's `opt_steps` scalar.
    step: i64,
}
/// JSON body of the MLflow `runs/update` request the recorder sends when it is
/// dropped, to close the run.
///
/// Serialized with serde, so the field names below are the JSON keys.
#[derive(Debug, Serialize)]
struct UpdateRunParams<'a> {
    /// ID of the run to update.
    run_id: &'a String,
    /// New run status; the recorder sends `"FINISHED"`.
    status: String,
    /// End time in milliseconds since the Unix epoch (from `DateTime::timestamp_millis`).
    end_time: i64,
    /// Run name, passed through unchanged from the run's info.
    run_name: &'a String,
}
/// JSON body of the MLflow `runs/set-tag` request sent by
/// `MlflowTrackingRecorder::set_tag`.
///
/// Serialized with serde, so the field names below are the JSON keys.
#[derive(Debug, Serialize)]
struct SetTagParams<'a> {
    /// ID of the run the tag is attached to.
    run_id: &'a String,
    /// Tag name.
    key: &'a String,
    /// Tag value.
    value: &'a String,
}
/// Records metrics, parameters, tags and model files for a single MLflow run
/// through the MLflow tracking-server REST API.
///
/// When dropped, the run is tagged with its end time and duration and its
/// status is set to `FINISHED`.
// `experiment_id` is stored but never read by the methods in this file; the
// allow keeps the dead-code lint quiet for it.
#[allow(dead_code)]
pub struct MlflowTrackingRecorder<E: Env, R: ReplayBufferBase> {
    /// Blocking HTTP client used for every request to the tracking server.
    client: Client,
    /// Base URL of the tracking server; `/api/2.0/mlflow/...` paths are appended to it.
    base_url: String,
    /// ID of the MLflow experiment (kept for reference, not read in this file).
    experiment_id: String,
    /// The run that metrics, parameters and tags are attached to.
    run: Run,
    /// Basic-auth user name sent with every request.
    user_name: String,
    /// Records buffered by `store` and aggregated by `flush`.
    storage: RecordStorage,
    /// Basic-auth password sent with every request.
    password: String,
    /// Local time at which the recorder was created.
    start_time: DateTime<Local>,
    /// Directory under which `save_model` places copied model files.
    artifact_base: PathBuf,
    /// Binds the `E` and `R` type parameters, which no field stores directly.
    phantom: PhantomData<(E, R)>,
}
impl<E, R> MlflowTrackingRecorder<E, R>
where
E: Env,
R: ReplayBufferBase,
{
pub fn new(
base_url: &String,
experiment_id: &String,
run: Run,
artifact_base: PathBuf,
) -> Result<Self> {
let client = Client::new();
let start_time = Local::now();
let recorder = Self {
client,
base_url: base_url.clone(),
experiment_id: experiment_id.to_string(),
run,
user_name: "".to_string(),
password: "".to_string(),
storage: RecordStorage::new(),
start_time: start_time.clone(),
artifact_base,
phantom: PhantomData,
};
recorder.set_tag(
"host_start_time",
start_time.to_rfc3339_opts(SecondsFormat::Secs, true),
)?;
Ok(recorder)
}
pub fn log_params(&self, params: impl Serialize) -> Result<()> {
let url = format!("{}/api/2.0/mlflow/runs/log-parameter", self.base_url);
let flatten_map = {
let map = match serde_json::to_value(params).unwrap() {
Value::Object(map) => map,
_ => panic!("Failed to parse object"),
};
flatten_serde_json::flatten(&map)
};
for (key, value) in flatten_map.iter() {
let params = LogParamParams {
run_id: &self.run.info.run_id,
key,
value: value.to_string(),
};
let _resp = self
.client
.post(&url)
.basic_auth(&self.user_name, Some(&self.password))
.json(¶ms) .send()
.unwrap();
}
Ok(())
}
pub fn set_tag(&self, key: impl AsRef<str>, value: impl AsRef<str>) -> Result<()> {
if self.run.exist_tag(key.as_ref()) {
log::warn!("Tag {} exists, so set_tag() was ignored.", key.as_ref());
return Ok(());
}
let url = format!("{}/api/2.0/mlflow/runs/set-tag", self.base_url);
let params = SetTagParams {
run_id: &self.run.info.run_id,
key: &key.as_ref().to_string(),
value: &value.as_ref().to_string(),
};
let _resp = self
.client
.post(&url)
.basic_auth(&self.user_name, Some(&self.password))
.json(¶ms)
.send()
.unwrap();
Ok(())
}
pub fn set_tags<'a, T: AsRef<str> + std::fmt::Debug + 'a>(
&self,
tags: impl Into<&'a [(T, T)]>,
) -> Result<()> {
for tag in tags.into().iter() {
self.set_tag(&tag.0, &tag.1)?;
}
Ok(())
}
}
impl<E, R> Recorder<E, R> for MlflowTrackingRecorder<E, R>
where
    E: Env,
    R: ReplayBufferBase,
{
    /// Sends every scalar value in `record` to the run as an MLflow metric.
    ///
    /// The record's `opt_steps` scalar is used as the `step` of every metric
    /// and is not logged itself. Non-scalar values are skipped. A metric that
    /// cannot be delivered is logged as a warning, so a transient network
    /// problem does not abort the caller (e.g. a training loop) with a panic.
    ///
    /// # Panics
    /// Panics if `record` has no `opt_steps` scalar.
    fn write(&mut self, record: border_core::record::Record) {
        let url = format!("{}/api/2.0/mlflow/runs/log-metric", self.base_url);
        let timestamp = system_time_as_millis() as i64;
        let step = record
            .get_scalar("opt_steps")
            .expect("record passed to write() must contain an `opt_steps` scalar")
            as i64;
        for (key, value) in record.iter() {
            if *key == "opt_steps" {
                continue;
            }
            if let RecordValue::Scalar(v) = value {
                let params = LogMetricParams {
                    run_id: &self.run.info.run_id,
                    key,
                    value: *v as f64,
                    timestamp,
                    step,
                };
                let result = self
                    .client
                    .post(&url)
                    .basic_auth(&self.user_name, Some(&self.password))
                    .json(&params)
                    .send();
                match result {
                    Ok(resp) if !resp.status().is_success() => {
                        log::warn!("Failed to log metric {}: HTTP {}", key, resp.status())
                    }
                    Ok(_) => {}
                    Err(e) => log::warn!("Failed to log metric {}: {}", key, e),
                }
            }
        }
    }

    /// Aggregates the stored records, attaches `step` as `opt_steps`, and
    /// writes the result with [`Recorder::write`].
    fn flush(&mut self, step: i64) {
        let mut record = self.storage.aggregate();
        record.insert("opt_steps", RecordValue::Scalar(step as _));
        self.write(record);
    }

    /// Buffers `record` until the next `flush`.
    fn store(&mut self, record: border_core::record::Record) {
        self.storage.store(record);
    }

    /// Saves the agent's parameters and copies them to `artifact_base/base`.
    ///
    /// The agent first writes its files into a temporary directory. Each file
    /// is then copied to the same relative path under `artifact_base/base`.
    /// Any missing directories, including nested ones, are created.
    ///
    /// # Errors
    /// Returns an error if the temporary directory cannot be created, the agent
    /// fails to save, a returned path is not inside the temporary directory, or
    /// a directory/file operation fails.
    fn save_model(&self, base: &Path, agent: &Box<dyn border_core::Agent<E, R>>) -> Result<()> {
        let tmp = TempDir::new("mlflow")?;
        let srcs = agent.save_params(tmp.path())?;
        let dest_dir = self.artifact_base.join(base);
        for src in srcs.iter() {
            let dest = dest_dir.join(src.strip_prefix(tmp.path())?);
            // `create_dir_all` handles nested `base` paths and files the agent
            // wrote into sub-directories; it is a no-op if the directory exists.
            if let Some(parent) = dest.parent() {
                std::fs::create_dir_all(parent)?;
            }
            let bytes = std::fs::copy(src, &dest)?;
            log::info!("Save {:?}", &src);
            log::info!("Copy {:?}, {:.2}MB", &dest, bytes as f32 / (1024. * 1024.));
        }
        Ok(())
    }

    /// Loads the agent's parameters from `base` under the run's artifact
    /// directory, as resolved by `crate::get_artifact_base`.
    fn load_model(&self, base: &Path, agent: &mut Box<dyn Agent<E, R>>) -> Result<()> {
        let artifact_base = crate::get_artifact_base(self.run.clone())?;
        let path = &artifact_base.join(base);
        agent.load_params(path)
    }
}
impl<E, R> Drop for MlflowTrackingRecorder<E, R>
where
E: Env,
R: ReplayBufferBase,
{
fn drop(&mut self) {
let end_time = Local::now();
let duration = end_time.signed_duration_since(self.start_time);
self.set_tag(
"host_end_time",
end_time.to_rfc3339_opts(SecondsFormat::Secs, true),
)
.unwrap();
self.set_tag("host_duration", format_duration(&duration))
.unwrap();
let url = format!("{}/api/2.0/mlflow/runs/update", self.base_url);
let params = UpdateRunParams {
run_id: &self.run.info.run_id,
status: "FINISHED".to_string(),
end_time: end_time.timestamp_millis(),
run_name: &self.run.info.run_name,
};
let _resp = self
.client
.post(&url)
.basic_auth(&self.user_name, Some(&self.password))
.json(¶ms) .send()
.unwrap();
}
}
/// Formats `dt` as `HH:MM:SS`, truncated to whole seconds.
///
/// Hours do not wrap at 24, so a long duration yields e.g. `123:04:05`.
/// A negative duration (possible if the wall clock moved backwards) gets a
/// single leading `-`, e.g. `-00:01:05`. The naive per-field division would
/// instead produce mixed-sign output such as `00:-1:-5`.
fn format_duration(dt: &Duration) -> String {
    let total = dt.num_seconds();
    let sign = if total < 0 { "-" } else { "" };
    // `unsigned_abs` cannot overflow, unlike `abs` on `i64::MIN`.
    let total = total.unsigned_abs();
    let hours = total / 3600;
    let minutes = (total / 60) % 60;
    let seconds = total % 60;
    format!("{}{:02}:{:02}:{:02}", sign, hours, minutes, seconds)
}