burn_dragon_language 0.5.0

Language modeling components for burn_dragon
Documentation
use std::sync::{Mutex, OnceLock};

/// Point-in-time copy of the training profile counters, as returned by [`snapshot`].
///
/// Every counter is accumulated since the last [`reset`]. Sums use saturating
/// addition, so they clamp at the integer maximum instead of wrapping.
/// NOTE(review): the `_ns` suffix suggests nanoseconds, but the unit is decided
/// by the callers of the `record_*` functions — confirm there.
#[derive(Clone, Copy, Debug, Default)]
pub struct TrainProfileSnapshot {
    /// Sum of the `cpu_ns` values passed to [`record_dataloader`].
    pub dataloader_cpu_ns: u128,
    /// Sum of the `tensor_copy_ns` values passed to [`record_dataloader`].
    pub dataloader_tensor_copy_ns: u128,
    /// Sum of the `host_to_device_copy_bytes` values passed to [`record_dataloader`].
    pub dataloader_host_to_device_copy_bytes: u128,
    /// Sum of the `host_sync_points` values passed to [`record_dataloader`].
    pub host_sync_points: u64,
    /// Sum of the `forward_ns` values passed to [`record_train_step`].
    pub forward_ns: u128,
    /// Sum of the `loss_backward_ns` values passed to [`record_train_step`].
    pub loss_backward_ns: u128,
    /// Sum of the `embed_probe_ns` values passed to [`record_detail_probe`].
    pub embed_probe_ns: u128,
    /// Sum of the `first_layer_forward_probe_ns` values passed to [`record_detail_probe`].
    pub first_layer_forward_probe_ns: u128,
    /// Sum of the `first_layer_probe_ns` values passed to [`record_detail_probe`].
    pub first_layer_probe_ns: u128,
    /// Sum of the `logits_loss_probe_ns` values passed to [`record_detail_probe`].
    pub logits_loss_probe_ns: u128,
    /// Sum of the `hidden_logits_loss_probe_ns` values passed to [`record_detail_probe`].
    pub hidden_logits_loss_probe_ns: u128,
    /// Sum of the `hidden_model_forward_probe_ns` values passed to [`record_detail_probe`].
    pub hidden_model_forward_probe_ns: u128,
    /// Sum of the `hidden_model_probe_ns` values passed to [`record_detail_probe`].
    pub hidden_model_probe_ns: u128,
    /// Number of [`record_detail_probe`] calls.
    pub detail_probe_steps: u64,
    /// Number of [`record_train_step`] calls.
    pub train_steps: u64,
    /// Largest `before_reserved_bytes` seen by [`record_train_step_memory`].
    pub max_step_reserved_before_bytes: u64,
    /// Largest `before_in_use_bytes` seen by [`record_train_step_memory`].
    pub max_step_in_use_before_bytes: u64,
    /// Largest `after_forward_reserved_bytes` seen by [`record_train_step_memory`].
    pub max_step_reserved_after_forward_bytes: u64,
    /// Largest `after_forward_in_use_bytes` seen by [`record_train_step_memory`].
    pub max_step_in_use_after_forward_bytes: u64,
    /// Largest `after_backward_reserved_bytes` seen by [`record_train_step_memory`].
    pub max_step_reserved_after_backward_bytes: u64,
    /// Largest `after_backward_in_use_bytes` seen by [`record_train_step_memory`].
    pub max_step_in_use_after_backward_bytes: u64,
}

/// Mutable accumulator stored behind the global `TRAIN_PROFILE` mutex.
///
/// Has the same fields as [`TrainProfileSnapshot`]; [`snapshot`] copies each one
/// across unchanged. The derived `Default` (all zeros) is the state that
/// [`reset`] restores.
#[derive(Clone, Copy, Debug, Default)]
struct TrainProfileState {
    // Saturating sums updated by `record_dataloader`.
    dataloader_cpu_ns: u128,
    dataloader_tensor_copy_ns: u128,
    dataloader_host_to_device_copy_bytes: u128,
    host_sync_points: u64,
    // Saturating sums updated by `record_train_step`.
    forward_ns: u128,
    loss_backward_ns: u128,
    // Saturating sums updated by `record_detail_probe`.
    embed_probe_ns: u128,
    first_layer_forward_probe_ns: u128,
    first_layer_probe_ns: u128,
    logits_loss_probe_ns: u128,
    hidden_logits_loss_probe_ns: u128,
    hidden_model_forward_probe_ns: u128,
    hidden_model_probe_ns: u128,
    // Incremented once per `record_detail_probe` call.
    detail_probe_steps: u64,
    // Incremented once per `record_train_step` call.
    train_steps: u64,
    // Running maxima maintained by `record_train_step_memory`.
    max_step_reserved_before_bytes: u64,
    max_step_in_use_before_bytes: u64,
    max_step_reserved_after_forward_bytes: u64,
    max_step_in_use_after_forward_bytes: u64,
    max_step_reserved_after_backward_bytes: u64,
    max_step_in_use_after_backward_bytes: u64,
}

/// Process-wide profile accumulator; left uninitialised until `state()` first
/// runs `get_or_init` on it.
static TRAIN_PROFILE: OnceLock<Mutex<TrainProfileState>> = OnceLock::new();

/// Reports whether stage profiling was requested through the environment.
///
/// Returns `true` whenever the `BDH_STAGE_PROFILE` variable is present; its
/// value is never inspected. The environment is consulted on every call.
pub fn enabled() -> bool {
    const FLAG: &str = "BDH_STAGE_PROFILE";
    matches!(std::env::var_os(FLAG), Some(_))
}

/// Reports whether detail probing was requested through the environment.
///
/// Returns `true` whenever the `BDH_STAGE_PROFILE_DETAIL` variable is present;
/// its value is never inspected. The environment is consulted on every call.
pub fn detail_enabled() -> bool {
    const FLAG: &str = "BDH_STAGE_PROFILE_DETAIL";
    matches!(std::env::var_os(FLAG), Some(_))
}

/// Reports whether memory profiling was requested through the environment.
///
/// Returns `true` whenever the `BDH_STAGE_PROFILE_MEMORY` variable is present;
/// its value is never inspected. The environment is consulted on every call.
pub fn memory_enabled() -> bool {
    const FLAG: &str = "BDH_STAGE_PROFILE_MEMORY";
    matches!(std::env::var_os(FLAG), Some(_))
}

/// Returns the process-wide profile accumulator, creating it on first use.
///
/// The freshly created mutex wraps `TrainProfileState::default()`.
fn state() -> &'static Mutex<TrainProfileState> {
    // `Mutex<T>: Default` builds `Mutex::new(T::default())`.
    TRAIN_PROFILE.get_or_init(Mutex::default)
}

/// Applies `mutator` to the shared profile state while holding its lock.
///
/// Recording is best-effort: if the mutex is poisoned the closure is never
/// invoked and the update is silently dropped.
fn record(mutator: impl FnOnce(&mut TrainProfileState)) {
    let Ok(mut guard) = state().lock() else {
        return;
    };
    mutator(&mut guard);
}

/// Restores every profile counter to its default (zero) value.
///
/// Does nothing when the profile mutex is poisoned.
pub fn reset() {
    match state().lock() {
        Ok(mut guard) => *guard = TrainProfileState::default(),
        Err(_poisoned) => {}
    }
}

pub fn snapshot() -> TrainProfileSnapshot {
    if let Ok(profile) = state().lock() {
        return TrainProfileSnapshot {
            dataloader_cpu_ns: profile.dataloader_cpu_ns,
            dataloader_tensor_copy_ns: profile.dataloader_tensor_copy_ns,
            dataloader_host_to_device_copy_bytes: profile.dataloader_host_to_device_copy_bytes,
            host_sync_points: profile.host_sync_points,
            forward_ns: profile.forward_ns,
            loss_backward_ns: profile.loss_backward_ns,
            embed_probe_ns: profile.embed_probe_ns,
            first_layer_forward_probe_ns: profile.first_layer_forward_probe_ns,
            first_layer_probe_ns: profile.first_layer_probe_ns,
            logits_loss_probe_ns: profile.logits_loss_probe_ns,
            hidden_logits_loss_probe_ns: profile.hidden_logits_loss_probe_ns,
            hidden_model_forward_probe_ns: profile.hidden_model_forward_probe_ns,
            hidden_model_probe_ns: profile.hidden_model_probe_ns,
            detail_probe_steps: profile.detail_probe_steps,
            train_steps: profile.train_steps,
            max_step_reserved_before_bytes: profile.max_step_reserved_before_bytes,
            max_step_in_use_before_bytes: profile.max_step_in_use_before_bytes,
            max_step_reserved_after_forward_bytes: profile.max_step_reserved_after_forward_bytes,
            max_step_in_use_after_forward_bytes: profile.max_step_in_use_after_forward_bytes,
            max_step_reserved_after_backward_bytes: profile.max_step_reserved_after_backward_bytes,
            max_step_in_use_after_backward_bytes: profile.max_step_in_use_after_backward_bytes,
        };
    }
    TrainProfileSnapshot::default()
}

/// Adds one dataloader measurement to the running totals.
///
/// Each argument is added to its matching counter (`dataloader_cpu_ns`,
/// `dataloader_tensor_copy_ns`, `dataloader_host_to_device_copy_bytes`,
/// `host_sync_points`) with saturating addition, so totals clamp at the
/// integer maximum instead of wrapping.
pub fn record_dataloader(
    cpu_ns: u128,
    tensor_copy_ns: u128,
    host_to_device_copy_bytes: u128,
    host_sync_points: u64,
) {
    record(|totals| {
        totals.dataloader_cpu_ns = u128::saturating_add(totals.dataloader_cpu_ns, cpu_ns);
        totals.dataloader_tensor_copy_ns =
            u128::saturating_add(totals.dataloader_tensor_copy_ns, tensor_copy_ns);
        totals.dataloader_host_to_device_copy_bytes = u128::saturating_add(
            totals.dataloader_host_to_device_copy_bytes,
            host_to_device_copy_bytes,
        );
        totals.host_sync_points = u64::saturating_add(totals.host_sync_points, host_sync_points);
    });
}

/// Records the timings of one training step.
///
/// Adds `forward_ns` and `loss_backward_ns` to their running totals and bumps
/// the `train_steps` counter by one; all three updates saturate.
pub fn record_train_step(forward_ns: u128, loss_backward_ns: u128) {
    record(|totals| {
        totals.train_steps = u64::saturating_add(totals.train_steps, 1);
        totals.forward_ns = u128::saturating_add(totals.forward_ns, forward_ns);
        totals.loss_backward_ns = u128::saturating_add(totals.loss_backward_ns, loss_backward_ns);
    });
}

/// Folds one step's memory readings into the running maxima.
///
/// Each of the six `max_step_*` counters keeps the largest value seen so far
/// for its matching argument; a smaller reading leaves the counter unchanged.
pub fn record_train_step_memory(
    before_reserved_bytes: u64,
    before_in_use_bytes: u64,
    after_forward_reserved_bytes: u64,
    after_forward_in_use_bytes: u64,
    after_backward_reserved_bytes: u64,
    after_backward_in_use_bytes: u64,
) {
    /// Raises `peak` to `sample` when the sample is larger.
    fn raise_to(peak: &mut u64, sample: u64) {
        *peak = (*peak).max(sample);
    }

    record(|peaks| {
        raise_to(&mut peaks.max_step_reserved_before_bytes, before_reserved_bytes);
        raise_to(&mut peaks.max_step_in_use_before_bytes, before_in_use_bytes);
        raise_to(
            &mut peaks.max_step_reserved_after_forward_bytes,
            after_forward_reserved_bytes,
        );
        raise_to(
            &mut peaks.max_step_in_use_after_forward_bytes,
            after_forward_in_use_bytes,
        );
        raise_to(
            &mut peaks.max_step_reserved_after_backward_bytes,
            after_backward_reserved_bytes,
        );
        raise_to(
            &mut peaks.max_step_in_use_after_backward_bytes,
            after_backward_in_use_bytes,
        );
    });
}

/// Adds one set of detail-probe timings to the running totals.
///
/// Each argument is added to the identically named counter with saturating
/// addition, and `detail_probe_steps` is incremented by one.
pub fn record_detail_probe(
    embed_probe_ns: u128,
    first_layer_forward_probe_ns: u128,
    first_layer_probe_ns: u128,
    logits_loss_probe_ns: u128,
    hidden_logits_loss_probe_ns: u128,
    hidden_model_forward_probe_ns: u128,
    hidden_model_probe_ns: u128,
) {
    /// Adds `delta` to `total`, clamping at `u128::MAX`.
    fn accumulate(total: &mut u128, delta: u128) {
        *total = total.saturating_add(delta);
    }

    record(|totals| {
        accumulate(&mut totals.embed_probe_ns, embed_probe_ns);
        accumulate(
            &mut totals.first_layer_forward_probe_ns,
            first_layer_forward_probe_ns,
        );
        accumulate(&mut totals.first_layer_probe_ns, first_layer_probe_ns);
        accumulate(&mut totals.logits_loss_probe_ns, logits_loss_probe_ns);
        accumulate(
            &mut totals.hidden_logits_loss_probe_ns,
            hidden_logits_loss_probe_ns,
        );
        accumulate(
            &mut totals.hidden_model_forward_probe_ns,
            hidden_model_forward_probe_ns,
        );
        accumulate(&mut totals.hidden_model_probe_ns, hidden_model_probe_ns);
        totals.detail_probe_steps = totals.detail_probe_steps.saturating_add(1);
    });
}