use crate::{
autodiff, cache, compile,
data::DataLoader,
graph::Graph,
optimize::{self, OptimizeReport},
runtime::Session,
};
use std::path::Path;
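/// Optimizer selection for a training run. Each variant carries its own
/// hyperparameters, which [`Trainer::train_epoch`] pushes to the
/// [`Session`] before every step.
///
/// ```ignore
/// let sgd = Optimizer::sgd(0.01);
/// let adam = Optimizer::adam(1e-3); // beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8
/// ```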
#[derive(Clone, Debug)]
pub enum Optimizer {
Sgd { learning_rate: f32 },
Adam {
learning_rate: f32,
beta1: f32,
beta2: f32,
epsilon: f32,
},
}
impl Optimizer {
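    /// Plain SGD with the given learning rate.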
pub fn sgd(lr: f32) -> Self {
Self::Sgd { learning_rate: lr }
}
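    /// Adam with the conventional defaults: beta1 = 0.9, beta2 = 0.999,
    /// epsilon = 1e-8.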
pub fn adam(lr: f32) -> Self {
Self::Adam {
learning_rate: lr,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
}
}
}
impl Default for Optimizer {
fn default() -> Self {
Self::Sgd {
learning_rate: 0.01,
}
}
}
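/// Configuration for a [`Trainer`].
///
/// Note: the training loop applies the learning rate carried inside
/// `optimizer`; the standalone `learning_rate` field is not read by
/// [`Trainer::train_epoch`] itself.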
pub struct TrainConfig {
pub optimizer: Optimizer,
pub learning_rate: f32,
pub log_interval: usize,
pub data_input: String,
pub label_input: String,
}
impl Default for TrainConfig {
fn default() -> Self {
Self {
optimizer: Optimizer::default(),
learning_rate: 0.01,
log_interval: 100,
data_input: "x".into(),
label_input: "labels".into(),
}
}
}
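/// Metrics for a single training step, passed to [`MetricCallback::on_step`].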
#[derive(Clone, Debug)]
pub struct StepMetrics {
pub epoch: usize,
pub step: usize,
pub loss: f32,
}
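/// Observer hooks for training progress. Both methods default to no-ops,
/// so implementors override only what they need.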
pub trait MetricCallback {
fn on_step(&mut self, _metrics: &StepMetrics) {}
fn on_epoch(&mut self, _stats: &EpochStats) {}
}
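/// A [`MetricCallback`] that records every per-step loss it observes.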
#[derive(Default)]
pub struct LossHistory {
pub losses: Vec<f32>,
}
impl MetricCallback for LossHistory {
fn on_step(&mut self, m: &StepMetrics) {
self.losses.push(m.loss);
}
}
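/// Summary statistics for one completed epoch.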
#[derive(Clone, Debug)]
pub struct EpochStats {
pub epoch: usize,
pub avg_loss: f32,
pub steps: usize,
}
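/// Per-epoch statistics accumulated over a full training run.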
#[derive(Clone, Debug, Default)]
pub struct TrainHistory {
pub epochs: Vec<EpochStats>,
}
impl TrainHistory {
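    /// Average loss of the last recorded epoch, or `None` if no epochs
    /// have been recorded.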
pub fn final_loss(&self) -> Option<f32> {
self.epochs.last().map(|e| e.avg_loss)
}
}
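/// Drives a training loop over a compiled [`Session`], feeding batches
/// from a [`DataLoader`] and collecting per-epoch statistics.
///
/// A minimal sketch, assuming a built `session` and a `loader`:
///
/// ```ignore
/// let mut trainer = Trainer::new(session, TrainConfig::default());
/// let history = trainer.train(&mut loader, 10);
/// println!("final loss: {:?}", history.final_loss());
/// ```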
pub struct Trainer {
session: Session,
config: TrainConfig,
}
impl Trainer {
pub fn new(session: Session, config: TrainConfig) -> Self {
Self { session, config }
}
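    /// Runs `epochs` full passes over `loader`, logging and recording the
    /// average loss of each epoch.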
pub fn train(&mut self, loader: &mut DataLoader, epochs: usize) -> TrainHistory {
let mut history = TrainHistory::default();
for epoch in 0..epochs {
let stats = self.train_epoch(loader, epoch);
log::info!(
"epoch {}: avg_loss = {:.4} ({} steps)",
stats.epoch,
stats.avg_loss,
stats.steps,
);
history.epochs.push(stats);
}
history
}
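    /// Runs one epoch: shuffles and resets the loader, then for every
    /// batch uploads the inputs, applies the optimizer settings, steps
    /// the session, and accumulates the loss. Returns the epoch summary.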
pub fn train_epoch(&mut self, loader: &mut DataLoader, epoch: usize) -> EpochStats {
let _span = tracing::info_span!("train_epoch", epoch).entered();
loader.shuffle(epoch as u64);
loader.reset();
let mut total_loss = 0.0_f32;
let mut steps = 0usize;
while let Some(batch) = loader.next_batch() {
{
let _span = tracing::info_span!("set_input").entered();
self.session.set_input(&self.config.data_input, batch.data);
self.session
.set_input(&self.config.label_input, batch.labels);
}
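            // Push the optimizer hyperparameters to the session before
            // this step.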
match self.config.optimizer {
Optimizer::Sgd { learning_rate } => {
self.session.set_learning_rate(learning_rate);
}
Optimizer::Adam {
learning_rate,
beta1,
beta2,
epsilon,
} => {
self.session.set_adam(learning_rate, beta1, beta2, epsilon);
}
}
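            // Execute one training step and block until the loss is
            // readable back on the host.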
self.session.step();
self.session.wait();
let loss = self.session.read_loss();
total_loss += loss;
if self.config.log_interval > 0 && steps.is_multiple_of(self.config.log_interval) {
log::info!(" epoch {} step {}: loss = {:.4}", epoch, steps, loss);
}
steps += 1;
}
let avg_loss = if steps > 0 {
total_loss / steps as f32
} else {
0.0
};
EpochStats {
epoch,
avg_loss,
steps,
}
}
pub fn session(&self) -> &Session {
&self.session
}
pub fn session_mut(&mut self) -> &mut Session {
&mut self.session
}
pub fn into_session(self) -> Session {
self.session
}
}
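/// Builds a forward-only (inference) session: optimizes the graph,
/// applies inference-only fusions, compiles an execution plan, and
/// initializes the GPU session. No backward pass is generated.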
pub fn build_inference_session(forward_graph: &Graph) -> Session {
log::info!("building inference session...");
log::info!("forward graph:\n{}", forward_graph);
log::info!("running egglog optimization...");
let mut optimized = optimize::optimize(forward_graph);
log::info!("optimized graph: {} nodes", optimized.nodes().len());
let mut inference_fusions = Vec::new();
optimize::apply_group_norm_silu_fusions(&mut optimized, &mut inference_fusions);
if !inference_fusions.is_empty() {
log::info!(
"inference fusions: {}x GroupNorm+Silu→GroupNormSilu",
inference_fusions.len()
);
}
log::info!("compiling execution plan...");
let plan = compile::compile(&optimized);
log::info!(
"execution plan: {} buffers, {} dispatches",
plan.buffers.len(),
plan.dispatches.len()
);
log::info!("initializing GPU session...");
Session::new(plan)
}
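/// Builds a training session and logs the optimization report. Use
/// [`build_session_with_report`] to get the report back instead.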
pub fn build_session(forward_graph: &Graph) -> Session {
let (session, report) = build_session_with_report(forward_graph);
log::info!("{}", report);
session
}
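/// Full training-session pipeline: optimize the forward graph, toposort
/// it, differentiate to append the backward pass, optimize the combined
/// graph, compile an execution plan, and initialize the GPU session.
/// Returns the session together with the [`OptimizeReport`] from the
/// full-graph optimization.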
pub fn build_session_with_report(forward_graph: &Graph) -> (Session, OptimizeReport) {
let _span = tracing::info_span!("build_session").entered();
log::info!("building training session...");
log::info!("forward graph:\n{}", forward_graph);
log::info!("optimizing forward graph...");
let optimized_forward = {
let _span = tracing::info_span!("optimize_forward").entered();
optimize::optimize(forward_graph)
};
log::info!(
"optimized forward: {} nodes",
optimized_forward.nodes().len()
);
let sorted_forward = optimized_forward.toposort();
log::info!(
"sorted forward: {} nodes (from {} optimized, {} original)",
sorted_forward.nodes().len(),
optimized_forward.nodes().len(),
forward_graph.nodes().len(),
);
let full_graph = {
let _span = tracing::info_span!("autodiff").entered();
autodiff::differentiate(&sorted_forward)
};
log::info!(
"full graph (forward + backward): {} nodes",
full_graph.nodes().len()
);
log::info!("optimizing full graph...");
let (optimized, report) = {
let _span = tracing::info_span!("optimize_full").entered();
optimize::optimize_with_report(&full_graph)
};
log::info!("optimized graph: {} nodes", optimized.nodes().len());
log::info!("compiling execution plan...");
let plan = {
let _span = tracing::info_span!("compile").entered();
compile::compile(&optimized)
};
log::info!(
"execution plan: {} buffers, {} dispatches",
plan.buffers.len(),
plan.dispatches.len()
);
log::info!("initializing GPU session...");
let session = {
let _span = tracing::info_span!("gpu_init").entered();
Session::new(plan)
};
(session, report)
}
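/// Like [`build_session_with_report`], but skips the post-autodiff
/// optimization pass: the forward graph is still optimized and
/// toposorted, while the combined forward + backward graph is compiled
/// as-is.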
pub fn build_session_unoptimized(forward_graph: &Graph) -> Session {
let _span = tracing::info_span!("build_session_unoptimized").entered();
let optimized_forward = optimize::optimize(forward_graph);
let sorted_forward = optimized_forward.toposort();
let full_graph = autodiff::differentiate(&sorted_forward);
let plan = compile::compile(&full_graph);
Session::new(plan)
}
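/// Differentiates `forward_graph`, optimizes the combined graph, and
/// compiles it to an execution plan without creating a GPU session.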
pub fn compile_training_graph(
forward_graph: &Graph,
) -> (crate::compile::ExecutionPlan, OptimizeReport) {
let full_graph = autodiff::differentiate(forward_graph);
let (optimized, report) = optimize::optimize_with_report(&full_graph);
let plan = compile::compile(&optimized);
(plan, report)
}
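/// Builds a training session, reusing a cached execution plan from
/// `cache_path` when one matches `forward_graph`. On a cache miss or a
/// failed load it recompiles, then tries to write the fresh plan back to
/// the cache. Note that this path differentiates the forward graph
/// directly, without the standalone forward-optimization pass used by
/// [`build_session_with_report`].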
pub fn build_session_cached(forward_graph: &Graph, cache_path: &Path) -> Session {
match cache::load_plan(forward_graph, cache_path) {
Ok(Some(plan)) => {
log::info!("loaded cached execution plan from {}", cache_path.display());
return Session::new(plan);
}
Ok(None) => {
log::info!("no valid cache found, running full pipeline");
}
Err(e) => {
log::warn!("failed to load cache: {}, recompiling", e);
}
}
log::info!("building training session...");
log::info!("forward graph:\n{}", forward_graph);
let full_graph = autodiff::differentiate(forward_graph);
let (optimized, report) = optimize::optimize_with_report(&full_graph);
let plan = compile::compile(&optimized);
if let Err(e) = cache::save_plan(&plan, forward_graph, cache_path) {
log::warn!("failed to save cache: {}", e);
} else {
log::info!("saved execution plan cache to {}", cache_path.display());
}
log::info!("{}", report);
Session::new(plan)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::graph::Graph;
#[test]
fn test_compile_training_graph_simple() {
let mut g = Graph::new();
let x = g.input("x", &[4, 3]);
let w = g.parameter("w", &[3, 2]);
let y = g.matmul(x, w);
let loss = g.mean_all(y);
g.set_outputs(vec![loss]);
let (plan, report) = compile_training_graph(&g);
assert!(!plan.dispatches.is_empty());
assert!(plan.loss_buffer.is_some());
assert_eq!(plan.param_grad_pairs.len(), 1);
assert_eq!(plan.param_buffers.len(), 1);
assert_eq!(plan.input_buffers.len(), 1);
assert!(report.nodes_before > 0);
}
#[test]
fn test_compile_training_graph_mlp() {
let mut g = Graph::new();
let x = g.input("x", &[4, 784]);
let w1 = g.parameter("w1", &[784, 128]);
let b1 = g.parameter("b1", &[128]);
let mm1 = g.matmul(x, w1);
let h1 = g.bias_add(mm1, b1);
let a1 = g.relu(h1);
let w2 = g.parameter("w2", &[128, 10]);
let mm2 = g.matmul(a1, w2);
let labels = g.input("labels", &[4, 10]);
let loss = g.cross_entropy_loss(mm2, labels);
g.set_outputs(vec![loss]);
let (plan, report) = compile_training_graph(&g);
assert_eq!(plan.param_grad_pairs.len(), 3);
assert_eq!(plan.param_buffers.len(), 3);
        assert_eq!(plan.input_buffers.len(), 2);
        assert!(plan.loss_buffer.is_some());
assert!(
report.fusions_applied.is_empty(),
"unexpected fusions: {:?}",
report.fusions_applied
);
}
#[test]
fn test_train_config_default() {
let config = TrainConfig::default();
assert_eq!(config.learning_rate, 0.01);
assert_eq!(config.log_interval, 100);
assert_eq!(config.data_input, "x");
assert_eq!(config.label_input, "labels");
}
#[test]
fn test_train_history_final_loss() {
let mut h = TrainHistory::default();
assert_eq!(h.final_loss(), None);
h.epochs.push(EpochStats {
epoch: 0,
avg_loss: 2.5,
steps: 10,
});
h.epochs.push(EpochStats {
epoch: 1,
avg_loss: 1.2,
steps: 10,
});
assert_eq!(h.final_loss(), Some(1.2));
}
#[test]
fn test_epoch_stats_fields() {
let stats = EpochStats {
epoch: 3,
avg_loss: 0.42,
steps: 100,
};
assert_eq!(stats.epoch, 3);
assert!((stats.avg_loss - 0.42).abs() < 1e-6);
assert_eq!(stats.steps, 100);
}
#[test]
fn test_compile_and_cache_roundtrip() {
let mut g = Graph::new();
let x = g.input("x", &[4, 3]);
let w = g.parameter("w", &[3, 2]);
let y = g.matmul(x, w);
let loss = g.mean_all(y);
g.set_outputs(vec![loss]);
let (plan, _) = compile_training_graph(&g);
let dir = std::env::temp_dir().join("meganeura_test_train_cache");
std::fs::create_dir_all(&dir).unwrap();
let path = dir.join("train_plan.ron");
cache::save_plan(&plan, &g, &path).unwrap();
let loaded = cache::load_plan(&g, &path).unwrap().unwrap();
assert_eq!(loaded.dispatches.len(), plan.dispatches.len());
assert_eq!(loaded.buffers.len(), plan.buffers.len());
assert_eq!(loaded.param_grad_pairs.len(), plan.param_grad_pairs.len());
let _ = std::fs::remove_file(&path);
}
}