//! Rerun-based live logging for `burn_dragon_language` (v0.5.0) training runs.
//!
//! Language-modeling training components for `burn_dragon`: streams metrics
//! and GPU telemetry to an embedded Rerun gRPC server for live viewing.
use crate::train::prelude::*;
use burn_train::LearningComponentsTypes;
use burn_train::logger::{FileMetricLogger, MetricLogger};
use burn_train::metric::{
    MetricDefinition, MetricEntry, MetricId, NumericEntry,
    store::{EpochSummary, MetricsUpdate, Split},
};
use rerun::{MemoryLimit, RecordingStream, RecordingStreamBuilder, ServerOptions};
use std::collections::HashMap;
use std::process::Command;
use std::sync::Mutex;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread::{self, JoinHandle};
use std::time::Duration;

/// Rerun application id under which all training recordings are grouped.
const RERUN_APP_ID: &str = "burn_dragon_language_train";
/// Cap on the embedded Rerun gRPC server's in-memory buffer (64 MiB).
const RERUN_SERVER_MEMORY_LIMIT_BYTES: u64 = 64 * 1024 * 1024;

/// Process-wide slot holding the currently active Rerun session, if any.
static ACTIVE_SESSION: std::sync::OnceLock<Mutex<Option<TrainingRerunSession>>> =
    std::sync::OnceLock::new();

/// Settings for starting a Rerun logging session for one training run.
#[derive(Debug, Clone)]
pub struct TrainingRerunConfig {
    /// Display name for the recording (shown in the Rerun viewer).
    pub run_name: String,
    /// IP address the embedded gRPC server binds to.
    pub bind_ip: String,
    /// Port the embedded gRPC server listens on.
    pub port: u16,
    /// Polling period for GPU telemetry; a zero duration disables the
    /// telemetry thread entirely.
    pub telemetry_interval: Duration,
}

/// Connection details for a running Rerun session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrainingRerunServerInfo {
    /// `rerun+http://…/proxy` URL of the embedded gRPC server.
    pub server_url: String,
    /// Hosted Rerun web-viewer URL pre-pointed at `server_url`.
    pub viewer_url: String,
}

/// Live state of the active Rerun session, held in `ACTIVE_SESSION`.
struct TrainingRerunSession {
    /// Stream all metrics and telemetry are logged to.
    recording: RecordingStream,
    /// Flag polled by the telemetry thread; set to request shutdown.
    shutdown: Arc<AtomicBool>,
    /// Handle to the GPU telemetry thread; `None` when telemetry is disabled.
    telemetry_thread: Option<JoinHandle<()>>,
    /// URLs reported back to the caller of `initialize_training_rerun`.
    info: TrainingRerunServerInfo,
}

/// `MetricLogger` implementation that forwards burn-train metrics to Rerun.
#[derive(Clone)]
struct RerunMetricLogger {
    /// Destination recording stream (shared with the active session).
    recording: RecordingStream,
    /// Metric id → display name, filled in by `log_metric_definition`.
    metric_names: HashMap<MetricId, String>,
    /// Per-split 1-based sample counters driving the "sample" timeline.
    split_samples: HashMap<String, i64>,
}

/// Starts a Rerun gRPC server for this training run and registers it as the
/// process-wide active session, first shutting down any prior session.
///
/// Returns the server and hosted-viewer URLs on success.
///
/// # Errors
///
/// Fails when the Rerun gRPC server cannot be started on the configured
/// address/port.
pub fn initialize_training_rerun(config: &TrainingRerunConfig) -> Result<TrainingRerunServerInfo> {
    // Tear down any previous session so its server and telemetry thread stop.
    shutdown_training_rerun();

    let server_url = rerun_server_url(&config.bind_ip, config.port);
    let info = TrainingRerunServerInfo {
        viewer_url: rerun_viewer_url(&server_url),
        server_url,
    };

    let server_options = ServerOptions {
        memory_limit: MemoryLimit::from_bytes(RERUN_SERVER_MEMORY_LIMIT_BYTES),
        ..ServerOptions::default()
    };
    let recording = RecordingStreamBuilder::new(RERUN_APP_ID)
        .recording_name(config.run_name.clone())
        .serve_grpc_opts(&config.bind_ip, config.port, server_options)
        .map_err(|err| anyhow!("failed to start rerun gRPC server: {err}"))?;

    let shutdown_flag = Arc::new(AtomicBool::new(false));
    let session = TrainingRerunSession {
        telemetry_thread: spawn_gpu_telemetry_thread(
            recording.clone(),
            Arc::clone(&shutdown_flag),
            config.telemetry_interval,
        ),
        recording,
        shutdown: shutdown_flag,
        info: info.clone(),
    };

    // Publish the new session; the previous value (always None after the
    // shutdown above) is dropped.
    *active_session().lock().expect("rerun session lock") = Some(session);

    Ok(info)
}

/// Stops the active Rerun session, if any: signals the telemetry thread to
/// exit, joins it, and flushes pending data on the recording stream.
pub fn shutdown_training_rerun() {
    let taken = active_session().lock().expect("rerun session lock").take();
    if let Some(session) = taken {
        // Ask the telemetry loop to stop, then wait for it to finish.
        session.shutdown.store(true, Ordering::Relaxed);
        if let Some(handle) = session.telemetry_thread {
            let _ = handle.join();
        }
        // Best-effort flush; errors are ignored during teardown.
        let _ = session.recording.flush_blocking();
    }
}

/// Wires metric loggers into `builder`: when a Rerun session is active, adds
/// both a file logger under `run_dir` and a Rerun logger; otherwise returns
/// the builder unchanged.
pub(crate) fn attach_metric_loggers<LC>(
    builder: SupervisedTraining<LC>,
    run_dir: &Path,
) -> SupervisedTraining<LC>
where
    LC: LearningComponentsTypes,
{
    match active_recording() {
        None => builder,
        Some(recording) => builder
            .with_metric_logger(FileMetricLogger::new(run_dir))
            .with_metric_logger(RerunMetricLogger::new(recording)),
    }
}

/// Builds the `rerun+http://…/proxy` gRPC URL for the given bind address.
pub fn rerun_server_url(bind_ip: &str, port: u16) -> String {
    format!("rerun+http://{}:{}/proxy", bind_ip, port)
}

/// Builds the hosted Rerun web-viewer URL that connects back to `server_url`.
pub fn rerun_viewer_url(server_url: &str) -> String {
    let encoded = urlencoding::encode(server_url);
    format!("https://rerun.io/viewer?url={encoded}")
}

impl RerunMetricLogger {
    /// Wraps a recording stream with empty name and sample-counter maps.
    fn new(recording: RecordingStream) -> Self {
        Self {
            recording,
            metric_names: HashMap::default(),
            split_samples: HashMap::default(),
        }
    }

    /// Bumps and returns the 1-based sample counter for `split`.
    fn next_sample(&mut self, split: &Split) -> i64 {
        let key = split_label(split).to_string();
        let counter = self.split_samples.entry(key).or_default();
        *counter += 1;
        *counter
    }

    /// Looks up the display name registered for this entry's metric id.
    fn metric_name(&self, entry: &MetricEntry) -> Option<&str> {
        self.metric_names
            .get(&entry.metric_id)
            .map(|name| name.as_str())
    }

    /// Logs a single scalar at `path`, ignoring transport errors.
    fn log_numeric_value(&self, path: &str, value: &NumericEntry) {
        let scalar = rerun::Scalars::single(value.current());
        let _ = self.recording.log(path, &scalar);
    }
}

impl MetricLogger for RerunMetricLogger {
    /// Streams one batch of metric updates to Rerun on the "epoch" and
    /// per-split "sample" timelines.
    fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split) {
        let sample = self.next_sample(split);
        let label = split_label(split);
        self.recording.set_time_sequence("epoch", epoch as i64);
        self.recording.set_time_sequence("sample", sample);

        // Numeric metrics: log both the instantaneous and running values.
        // Entries whose metric id was never registered are skipped.
        for numeric in update.entries_numeric {
            if let Some(name) = self.metric_name(&numeric.entry) {
                let name = sanitize_metric_name(name);
                self.log_numeric_value(
                    &format!("metrics/{label}/{name}/value"),
                    &numeric.numeric_entry,
                );
                self.log_numeric_value(
                    &format!("metrics/{label}/{name}/running"),
                    &numeric.running_entry,
                );
            }
        }

        // Non-numeric metrics: forward their formatted text representation.
        for entry in update.entries {
            if let Some(name) = self.metric_name(&entry) {
                let path = format!("metrics/{label}/{}/text", sanitize_metric_name(name));
                let text = rerun::TextLog::new(entry.serialized_entry.formatted.clone());
                let _ = self.recording.log(path.as_str(), &text);
            }
        }
    }

    /// Reading history back is unsupported for the Rerun sink; always empty.
    fn read_numeric(
        &mut self,
        _name: &str,
        _epoch: usize,
        _split: &Split,
    ) -> std::result::Result<Vec<NumericEntry>, String> {
        Ok(Vec::new())
    }

    /// Remembers the id → display-name mapping for later lookups in `log`.
    fn log_metric_definition(&mut self, definition: MetricDefinition) {
        let name = definition.name.clone();
        self.metric_names.insert(definition.metric_id, name);
    }

    /// Emits a one-line text marker when a split finishes an epoch.
    fn log_epoch_summary(&mut self, summary: EpochSummary) {
        let label = split_label(&summary.split);
        self.recording
            .set_time_sequence("epoch", summary.epoch_number as i64);
        let message = format!("completed {} epoch {}", label, summary.epoch_number);
        let _ = self.recording.log(
            format!("metrics/{label}/epoch_summary"),
            &rerun::TextLog::new(message),
        );
    }
}

/// Lazily initializes and returns the global session slot.
fn active_session() -> &'static Mutex<Option<TrainingRerunSession>> {
    ACTIVE_SESSION.get_or_init(Mutex::default)
}

/// Clones the recording stream out of the active session, if one exists.
fn active_recording() -> Option<RecordingStream> {
    let guard = active_session().lock().expect("rerun session lock");
    guard.as_ref().map(|session| session.recording.clone())
}

/// Maps a dataset split to the stable label used in Rerun entity paths.
///
/// All labels are string literals, so the return type is `&'static str`
/// rather than a borrow tied to `split` — callers may store or send the
/// label freely. (Widening the lifetime is backward-compatible.)
fn split_label(split: &Split) -> &'static str {
    match split {
        Split::Train => "train",
        Split::Valid => "valid",
        // Every test split shares one label regardless of its payload.
        Split::Test(_) => "test",
    }
}

/// Normalizes a metric display name into a path-safe identifier: trims
/// surrounding whitespace, lowercases, and replaces spaces, slashes, and
/// colons with underscores.
fn sanitize_metric_name(metric_name: &str) -> String {
    metric_name
        .trim()
        .chars()
        .flat_map(char::to_lowercase)
        .map(|c| match c {
            ' ' | '/' | ':' => '_',
            other => other,
        })
        .collect()
}

/// Starts a background thread that polls `nvidia-smi` every `interval` and
/// streams GPU stats to Rerun until `shutdown` is set.
///
/// Returns `None` when `interval` is zero, which disables telemetry.
fn spawn_gpu_telemetry_thread(
    recording: RecordingStream,
    shutdown: Arc<AtomicBool>,
    interval: Duration,
) -> Option<JoinHandle<()>> {
    if interval.is_zero() {
        return None;
    }

    let worker = move || {
        let mut sample_index = 0_i64;
        loop {
            if shutdown.load(Ordering::Relaxed) {
                break;
            }
            // Only advance the timeline when a sample was actually read.
            if let Some(stats) = sample_nvidia_smi() {
                sample_index += 1;
                recording.set_time_sequence("gpu_sample", sample_index);
                let scalars = [
                    ("system/gpu/utilization_pct", stats.utilization_pct),
                    ("system/gpu/power_watts", stats.power_watts),
                    ("system/gpu/memory_used_mib", stats.memory_used_mib),
                ];
                for (path, value) in scalars {
                    let _ = recording.log(path, &rerun::Scalars::single(value));
                }
            }
            thread::sleep(interval);
        }
    };
    Some(thread::spawn(worker))
}

/// One snapshot of GPU telemetry parsed from `nvidia-smi` output.
#[derive(Debug, Clone, Copy, PartialEq)]
struct GpuStats {
    /// GPU utilization, in percent.
    utilization_pct: f64,
    /// Power draw, in watts.
    power_watts: f64,
    /// Device memory in use, in MiB.
    memory_used_mib: f64,
}

/// Queries `nvidia-smi` once for utilization, power draw, and memory usage.
///
/// Returns `None` when the binary cannot be run, exits non-zero, or the
/// first CSV line of its output cannot be parsed as three numbers.
fn sample_nvidia_smi() -> Option<GpuStats> {
    let query = [
        "--query-gpu=utilization.gpu,power.draw,memory.used",
        "--format=csv,noheader,nounits",
    ];
    let output = Command::new("nvidia-smi").args(query).output().ok()?;
    if !output.status.success() {
        return None;
    }

    let stdout = String::from_utf8(output.stdout).ok()?;
    // Only the first GPU's line is consumed.
    let first_line = stdout.lines().next()?;
    let mut values = first_line
        .split(',')
        .map(|field| field.trim().parse::<f64>());
    Some(GpuStats {
        utilization_pct: values.next()?.ok()?,
        power_watts: values.next()?.ok()?,
        memory_used_mib: values.next()?.ok()?,
    })
}

#[cfg(test)]
mod tests {
    use super::{rerun_server_url, rerun_viewer_url, sanitize_metric_name};

    // The viewer URL must embed the percent-encoded gRPC proxy URL.
    #[test]
    fn rerun_viewer_url_wraps_server_proxy_url() {
        let server_url = rerun_server_url("127.0.0.1", 9876);
        assert_eq!(server_url, "rerun+http://127.0.0.1:9876/proxy");
        assert_eq!(
            rerun_viewer_url(&server_url),
            "https://rerun.io/viewer?url=rerun%2Bhttp%3A%2F%2F127.0.0.1%3A9876%2Fproxy"
        );
    }

    // Spaces, slashes, and colons all become underscores; case is folded.
    #[test]
    fn sanitize_metric_name_normalizes_common_separators() {
        assert_eq!(sanitize_metric_name("Learning Rate"), "learning_rate");
        assert_eq!(sanitize_metric_name("Loss/Valid"), "loss_valid");
        assert_eq!(sanitize_metric_name("device:cuda"), "device_cuda");
    }
}
}