use burn::{
data::dataloader::DataLoaderBuilder,
module::Module,
optim::AdamConfig,
record::{CompactRecorder, FullPrecisionSettings, BinFileRecorder},
tensor::backend::AutodiffBackend,
train::{
metric::LossMetric,
renderer::{MetricState, MetricsRenderer, TrainingProgress},
LearnerBuilder,
},
};
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use std::path::Path;
use std::time::Instant;
use crate::config::{SensorLMConfig, TrainingConfig};
use crate::data::dataset::SyntheticSensorDataset;
use crate::model::sensorlm::{SensorLMBatcher, SensorLMModel};
use crate::training::scheduler::RsqrtScheduler;
use crate::error::Result;
/// Progress renderer handed to Burn's `LearnerBuilder`: one indicatif bar for
/// training iterations, one for validation iterations, and the most recent
/// loss value shown next to each.
struct SensorLMRenderer {
    /// The `MultiProgress` both bars were added to; kept alive for the
    /// renderer's lifetime but never read after construction.
    _multi: MultiProgress,
    /// Bar advanced by `render_train`.
    train_bar: ProgressBar,
    /// Bar advanced by `render_valid`.
    valid_bar: ProgressBar,
    /// Latest numeric value received by `update_train`; displayed as "loss".
    train_loss: Option<f64>,
    /// Latest numeric value received by `update_valid`; displayed as "loss".
    valid_loss: Option<f64>,
    /// Time of the previous `render_train` call (or construction), used to
    /// report seconds per step.
    step_start: Instant,
}
impl SensorLMRenderer {
    /// Create the renderer with a "Train" bar of `train_steps` positions and
    /// a "Valid" bar of `valid_steps` positions, both sharing one style and
    /// grouped under a single `MultiProgress`.
    fn new(train_steps: usize, valid_steps: usize) -> Self {
        let multi = MultiProgress::new();
        let shared_style = ProgressStyle::with_template(
            "{prefix:.bold.cyan} [{bar:45.green/dim}] \
             {pos:>3}/{len} \
             {elapsed_precise} eta {eta_precise} \
             {msg}",
        )
        .unwrap()
        .progress_chars("█▉▊▋▌▍▎▏ ");

        // Each bar is added to `multi` first and only then styled/prefixed,
        // presumably so any draw triggered by the setters goes through the
        // shared MultiProgress target.
        let make_bar = |len: usize, label: &'static str| {
            let bar = multi.add(ProgressBar::new(len as u64));
            bar.set_style(shared_style.clone());
            bar.set_prefix(label);
            bar
        };
        let train_bar = make_bar(train_steps, "Train");
        let valid_bar = make_bar(valid_steps, "Valid");

        Self {
            _multi: multi,
            train_bar,
            valid_bar,
            train_loss: None,
            valid_loss: None,
            step_start: Instant::now(),
        }
    }
}
impl MetricsRenderer for SensorLMRenderer {
    /// Remember the latest numeric training metric; rendered as the loss.
    fn update_train(&mut self, state: MetricState) {
        if let MetricState::Numeric(_entry, val) = state {
            self.train_loss = Some(val);
        }
    }

    /// Remember the latest numeric validation metric; rendered as the loss.
    fn update_valid(&mut self, state: MetricState) {
        if let MetricState::Numeric(_entry, val) = state {
            self.valid_loss = Some(val);
        }
    }

    /// Move the training bar to `item.iteration` and show the loss (if any),
    /// wall time since the previous call, and the epoch counter.
    fn render_train(&mut self, item: TrainingProgress) {
        let step = item.iteration;
        // Steps per epoch are inferred from the dataloader's item counts:
        // items_processed / step approximates the batch size. On step 0 that
        // division yields NaN/inf (and a bogus bar length), so keep the
        // current bar length instead.
        let total = if item.progress.items_total > 0 && step > 0 {
            let batch = (item.progress.items_processed as f64 / step as f64).round() as u64;
            (item.progress.items_total as u64).div_ceil(batch.max(1))
        } else {
            self.train_bar.length().unwrap_or(0)
        };
        self.train_bar.set_length(total);
        self.train_bar.set_position(step as u64);
        // Wall time since the previous render_train call (or construction).
        let elapsed = self.step_start.elapsed().as_secs_f64();
        self.step_start = Instant::now();
        let msg = match self.train_loss {
            Some(l) => format!(
                "loss {l:.4} ({elapsed:.1}s/step) epoch {}/{}",
                item.epoch, item.epoch_total
            ),
            None => format!(
                "{elapsed:.1}s/step epoch {}/{}",
                item.epoch, item.epoch_total
            ),
        };
        self.train_bar.set_message(msg);
    }

    /// Move the validation bar to `item.iteration`, clamped to the bar's
    /// length, and show the latest validation loss if one has arrived.
    fn render_valid(&mut self, item: TrainingProgress) {
        let step = item.iteration;
        let total = self.valid_bar.length().unwrap_or(0);
        self.valid_bar.set_position(step.min(total as usize) as u64);
        let msg = match self.valid_loss {
            Some(l) => format!("loss {l:.4}"),
            None => String::new(),
        };
        self.valid_bar.set_message(msg);
    }
}
/// Rough byte counts for the sensor encoder's attention score/weight
/// tensors. Every figure uses 4 bytes per element (presumably f32).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct AttnMemEstimate {
    /// One chunked attention dispatch: `batch × heads × chunk × patches × 4`.
    per_dispatch_bytes: u64,
    /// Tape kept for one layer: `2 × batch × heads × patches² × 4`
    /// (the factor 2 presumably covers scores and weights).
    per_layer_bwd_bytes: u64,
    /// `depth × per_layer_bwd_bytes`: every layer's tape alive at once.
    all_layers_bwd_bytes: u64,
}

/// Estimate attention memory for the given encoder shape and batch.
///
/// `chunk_size == 0` means "no chunking" (one dispatch covers all
/// `num_patches` query rows). A chunk larger than `num_patches` is clamped,
/// since a dispatch can never hold more rows than exist.
///
/// All products saturate at `u64::MAX`, so an absurd configuration reads as
/// "too large" instead of panicking (debug) or wrapping small (release).
fn estimate_attn_memory(
    batch_size: usize,
    depth: usize,
    num_heads: usize,
    num_patches: usize,
    chunk_size: usize,
) -> AttnMemEstimate {
    let effective_chunk = if chunk_size == 0 {
        num_patches
    } else {
        chunk_size.min(num_patches)
    };
    // batch × heads × patches is shared by both formulas.
    let bhn = (batch_size as u64)
        .saturating_mul(num_heads as u64)
        .saturating_mul(num_patches as u64);
    let per_dispatch_bytes = bhn
        .saturating_mul(effective_chunk as u64)
        .saturating_mul(4);
    let per_layer_bwd_bytes = bhn
        .saturating_mul(num_patches as u64)
        .saturating_mul(2 * 4);
    AttnMemEstimate {
        per_dispatch_bytes,
        per_layer_bwd_bytes,
        all_layers_bwd_bytes: (depth as u64).saturating_mul(per_layer_bwd_bytes),
    }
}
/// Attention budget in GB used when no VRAM size is supplied.
const ALL_LAYERS_LIMIT_GB: f64 = 11.0;
/// Share of the supplied VRAM granted to attention score/weight tensors.
const ATTN_VRAM_FRACTION: f64 = 0.70;
/// Target ceiling for one chunked attention dispatch: 512 MiB.
const DISPATCH_LIMIT_BYTES: u64 = 512 << 20;
/// Per-dispatch size (GB) above which a GPU-watchdog (TDR) warning is printed.
const PER_DISPATCH_WARN_GB: f64 = 0.5;

/// Choose an attention chunk size so one dispatch stays within
/// [`DISPATCH_LIMIT_BYTES`].
///
/// Returns `0` ("no chunking") when any dimension is zero or when every
/// query row fits in a single dispatch. Otherwise returns the largest
/// multiple of 64 rows that fits, floored at 16 rows. Because of that floor,
/// the result can exceed the limit for very large batch × heads × patches.
fn optimal_chunk_size(batch_size: usize, num_heads: usize, num_patches: usize) -> usize {
    // Bytes occupied by one query row of the score tensor across the batch.
    let row_bytes = [batch_size as u64, num_heads as u64, num_patches as u64, 4]
        .iter()
        .fold(1u64, |acc, &dim| acc.saturating_mul(dim));
    if row_bytes == 0 {
        return 0;
    }

    let rows_that_fit = DISPATCH_LIMIT_BYTES / row_bytes;
    if rows_that_fit >= num_patches as u64 {
        return 0;
    }

    let rows = rows_that_fit as usize;
    let aligned = rows - rows % 64;
    std::cmp::max(aligned, 16)
}
/// Largest batch whose all-layers attention tape fits within `limit_gb`.
///
/// Per sample, the tape is `depth × 2 × heads × patches² × 4` bytes (the
/// same formula `estimate_attn_memory` uses for one sample). The result is
/// never below 1, even if a single sample does not fit. It is `usize::MAX`
/// when the per-sample cost is zero.
///
/// The per-sample product saturates instead of overflowing, so huge configs
/// yield 1 rather than a debug-mode panic or a wrapped (too small) cost.
/// A negative or NaN `limit_gb` converts to a 0-byte budget, which yields 1.
fn max_safe_batch(depth: usize, num_heads: usize, num_patches: usize, limit_gb: f64) -> usize {
    let limit_bytes = (limit_gb * (1u64 << 30) as f64) as u64;
    let per_sample = (depth as u64)
        .saturating_mul(2)
        .saturating_mul(num_heads as u64)
        .saturating_mul(num_patches as u64)
        .saturating_mul(num_patches as u64)
        .saturating_mul(4);
    if per_sample == 0 {
        return usize::MAX;
    }
    let batches = (limit_bytes / per_sample).max(1);
    // Clamp before casting so a 32-bit usize cannot silently truncate.
    batches.min(usize::MAX as u64) as usize
}
pub fn train<B: AutodiffBackend>(
mut model_cfg: SensorLMConfig,
mut train_cfg: TrainingConfig,
) -> Result<()>
where
B::Device: Clone + Default + Send + Sync + std::fmt::Debug + 'static,
B::InnerBackend: burn::tensor::backend::Backend<Device = B::Device>,
{
let num_patches = model_cfg.sensor_encoder.num_patches();
let attn_limit_gb: f64 = match train_cfg.vram_gb {
Some(vram) => {
let limit = vram * ATTN_VRAM_FRACTION;
eprintln!(
"[sensorlm] VRAM budget: {vram:.0} GB \
→ attention limit: {limit:.2} GB (= VRAM × {ATTN_VRAM_FRACTION})"
);
limit
}
None => ALL_LAYERS_LIMIT_GB,
};
if train_cfg.vram_gb.is_some() {
let safe = max_safe_batch(
model_cfg.sensor_encoder.depth,
model_cfg.sensor_encoder.num_heads,
num_patches,
attn_limit_gb,
);
if train_cfg.batch_size > safe {
eprintln!(
"[sensorlm] Auto-reducing batch_size {} → {safe} \
(largest that fits in {attn_limit_gb:.2} GB attention budget).",
train_cfg.batch_size,
);
train_cfg.batch_size = safe;
} else {
eprintln!(
"[sensorlm] batch_size={} fits (max safe for this VRAM: {safe}).",
train_cfg.batch_size,
);
}
}
{
let new_chunk = optimal_chunk_size(
train_cfg.batch_size,
model_cfg.sensor_encoder.num_heads,
num_patches,
);
let old_chunk = model_cfg.sensor_encoder.attn_chunk_size;
if new_chunk != old_chunk {
let old_subs = if old_chunk == 0 { 1 } else { num_patches.div_ceil(old_chunk) };
let new_subs = if new_chunk == 0 { 1 } else { num_patches.div_ceil(new_chunk) };
eprintln!(
"[sensorlm] Auto-tuning attn_chunk_size {old_chunk} → {new_chunk} \
({old_subs} → {new_subs} GPU submissions/layer, \
dispatch ≤ {} MB).",
DISPATCH_LIMIT_BYTES / (1024 * 1024),
);
model_cfg.sensor_encoder.attn_chunk_size = new_chunk;
}
}
let enc = &model_cfg.sensor_encoder;
let mem = estimate_attn_memory(
train_cfg.batch_size,
enc.depth,
enc.num_heads,
num_patches,
enc.attn_chunk_size,
);
let gb = |b: u64| b as f64 / (1024.0_f64.powi(3));
let dispatch_gb = gb(mem.per_dispatch_bytes);
let per_layer_gb = gb(mem.per_layer_bwd_bytes);
let all_layers_gb = gb(mem.all_layers_bwd_bytes);
eprintln!(
"[sensorlm] Sensor encoder: N={num_patches} patches, \
depth={}, heads={}, chunk_size={}, batch={}",
enc.depth, enc.num_heads, enc.attn_chunk_size, train_cfg.batch_size,
);
eprintln!("[sensorlm] Attention VRAM (score/weight tensors only; add ~1–2 GB for weights+Adam+activations):");
eprintln!("[sensorlm] per GPU dispatch : {dispatch_gb:.3} GB (TDR risk if > {PER_DISPATCH_WARN_GB} GB)");
eprintln!("[sensorlm] per layer tape : {per_layer_gb:.2} GB × {} layers", enc.depth);
eprintln!("[sensorlm] ALL layers peak : {all_layers_gb:.2} GB ← actual training peak (limit: {attn_limit_gb:.2} GB)");
if dispatch_gb > PER_DISPATCH_WARN_GB {
eprintln!(
"[sensorlm] ⚠ Per-dispatch ({dispatch_gb:.2} GB) > {PER_DISPATCH_WARN_GB} GB — \
GPU watchdog (TDR) risk. Reduce attn_chunk_size (current: {}).",
enc.attn_chunk_size,
);
}
if all_layers_gb > attn_limit_gb {
let safe_batch = max_safe_batch(
enc.depth,
enc.num_heads,
num_patches,
attn_limit_gb,
);
let safe_chunk = (enc.attn_chunk_size / 2).max(16);
let vram_hint = if train_cfg.vram_gb.is_none() {
"Specify your GPU memory with --vram-gb <GB> to auto-select the \
right batch size, or pass --no-vram-check to skip this guard."
.to_string()
} else {
format!("Pass --no-vram-check to proceed despite the estimate, or lower --batch-size to {safe_batch}.")
};
let msg = format!(
"All-layers attention peak ({all_layers_gb:.2} GB) exceeds \
the budget ({attn_limit_gb:.2} GB).\n\
\n\
WHY: Burn builds autodiff tape for all {depth} transformer layers \
during the forward pass. At the forward→backward boundary all \
{depth} layers' chunk tensors are simultaneously in GPU memory — \
the peak is depth × per-layer, not just per-layer.\n\
\n\
Largest safe batch for this model + VRAM: {safe_batch}\n\
\n\
Options:\n\
• --vram-gb <GB> tell the tool your GPU — batch auto-selected\n\
• --batch-size {safe_batch:<4} largest batch that fits\n\
• --model-size tiny ~11 M params, much lower attention memory\n\
• --model-size small ~44 M params, moderate memory\n\
• attn_chunk_size {safe_chunk} halving chunk halves per-layer tape\n\
• --no-vram-check bypass guard (crashes are your responsibility)\n\
\n\
{vram_hint}",
depth = enc.depth,
);
if train_cfg.skip_vram_check {
eprintln!("[sensorlm] ⚠ Guard exceeded but --no-vram-check set:\n{msg}");
eprintln!("[sensorlm] ⚠ Proceeding — monitor GPU memory carefully.");
} else {
return Err(crate::error::SensorLMError::Other(anyhow::anyhow!("{msg}")));
}
}
let device = B::Device::default();
let max_seq_len = train_cfg.caption_key.max_tokens();
let train_samples = train_cfg.batch_size * 20;
let valid_samples = train_cfg.batch_size * 4;
let train_dataset = SyntheticSensorDataset::new(train_samples, train_cfg.seed, max_seq_len);
let valid_dataset = SyntheticSensorDataset::new(valid_samples, train_cfg.seed + 1, max_seq_len);
let num_workers = train_cfg.num_workers.max(1);
let train_steps = train_samples / train_cfg.batch_size;
let valid_steps = valid_samples / train_cfg.batch_size;
eprintln!(
"[sensorlm] Training plan: {train_steps} train steps + \
{valid_steps} validation steps per epoch \
(dataset: {train_samples} train / {valid_samples} valid samples)"
);
let batcher_train = SensorLMBatcher::<B>::new(
device.clone(),
model_cfg.sensor_encoder.time_steps,
model_cfg.sensor_encoder.num_channels,
max_seq_len,
);
let batcher_valid = SensorLMBatcher::<B::InnerBackend>::new(
device.clone(),
model_cfg.sensor_encoder.time_steps,
model_cfg.sensor_encoder.num_channels,
max_seq_len,
);
let train_loader = DataLoaderBuilder::new(batcher_train)
.batch_size(train_cfg.batch_size)
.shuffle(train_cfg.seed)
.num_workers(num_workers)
.build(train_dataset);
let valid_loader = DataLoaderBuilder::new(batcher_valid)
.batch_size(train_cfg.batch_size)
.num_workers(num_workers)
.build(valid_dataset);
let model = SensorLMModel::<B>::new(&model_cfg, &device);
let optimizer = AdamConfig::new()
.with_beta_1(train_cfg.beta1 as f32)
.with_beta_2(train_cfg.beta2 as f32)
.with_epsilon(train_cfg.epsilon as f32)
.with_weight_decay(Some(burn::optim::decay::WeightDecayConfig::new(
train_cfg.weight_decay, )))
.init();
let lr_scheduler = RsqrtScheduler::new(
train_cfg.lr,
train_cfg.total_steps,
train_cfg.warmup_fraction,
train_cfg.cooldown_fraction,
);
std::fs::create_dir_all(&train_cfg.artifact_dir)?;
let renderer = SensorLMRenderer::new(train_steps, valid_steps);
let builder = LearnerBuilder::new(&train_cfg.artifact_dir)
.metric_train_numeric(LossMetric::<B>::new())
.metric_valid_numeric(LossMetric::<B::InnerBackend>::new())
.with_file_checkpointer(CompactRecorder::new())
.renderer(renderer)
.devices(vec![device])
.num_epochs(1);
let builder = if train_cfg.show_summary { builder.summary() } else { builder };
let learner = builder.build(model, optimizer, lr_scheduler);
let _trained_model = learner.fit(train_loader, valid_loader);
eprintln!(
"\n[sensorlm] Training complete — \
{train_steps} train + {valid_steps} valid steps."
);
Ok(())
}
pub fn save_model<B: AutodiffBackend>(
model: SensorLMModel<B>,
path: &Path,
) -> Result<()> {
let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
model
.save_file(path, &recorder)
.map_err(|e| crate::error::SensorLMError::Other(anyhow::anyhow!("{e}")))?;
Ok(())
}
pub fn load_model<B: AutodiffBackend>(
cfg: &SensorLMConfig,
path: &Path,
device: &B::Device,
) -> Result<SensorLMModel<B>> {
let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
let model = SensorLMModel::<B>::new(cfg, device)
.load_file(path, &recorder, device)
.map_err(|e| crate::error::SensorLMError::Other(anyhow::anyhow!("{e}")))?;
Ok(model)
}