use meganeura::{DataLoader, Graph, TrainConfig, Trainer, build_session};
use std::time::Instant;
/// Denoising-MLP training benchmark: generates synthetic noisy images,
/// builds a 4-layer MLP that predicts the added noise, trains it via the
/// meganeura session/trainer pipeline, and reports compile/train timings.
///
/// Set `MEGANEURA_TRACE=<path>` to enable the profiler and save a trace.
fn main() {
    env_logger::init();

    // Profiler is opt-in via env var; the trace is written at the end of main.
    let trace_path = std::env::var("MEGANEURA_TRACE").ok();
    if trace_path.is_some() {
        meganeura::profiler::init();
    }

    // Benchmark hyperparameters (fixed; this is a throughput benchmark,
    // not a convergence study).
    let batch = 64;
    let img_dim = 784; // 28x28 images, flattened
    let hidden = 256;
    let epochs = 5;
    let lr = 0.001_f32;
    let n_samples = 6400;
    let noise_level = 0.5_f32;

    println!(
        "generating {} denoising pairs (σ={})...",
        n_samples, noise_level
    );
    // Targets are the noise itself (diffusion-style): the net learns to
    // predict the perturbation added to each clean image.
    let (noisy_images, noise_targets) = generate_denoising_data(n_samples, img_dim, noise_level);
    let mut loader = DataLoader::new(noisy_images, noise_targets, img_dim, img_dim, batch);
    println!(
        "{} samples, {} batches/epoch",
        loader.len(),
        loader.num_batches()
    );

    // Build the compute graph: 4-layer MLP with ReLU activations,
    // img_dim -> hidden -> hidden -> hidden -> img_dim.
    let mut g = Graph::new();
    let x = g.input("x", &[batch, img_dim]);
    let target = g.input("labels", &[batch, img_dim]);

    let w1 = g.parameter("w1", &[img_dim, hidden]);
    let b1 = g.parameter("b1", &[hidden]);
    let mm1 = g.matmul(x, w1);
    let ba1 = g.bias_add(mm1, b1);
    let h1 = g.relu(ba1);

    let w2 = g.parameter("w2", &[hidden, hidden]);
    let b2 = g.parameter("b2", &[hidden]);
    let mm2 = g.matmul(h1, w2);
    let ba2 = g.bias_add(mm2, b2);
    let h2 = g.relu(ba2);

    let w3 = g.parameter("w3", &[hidden, hidden]);
    let b3 = g.parameter("b3", &[hidden]);
    let mm3 = g.matmul(h2, w3);
    let ba3 = g.bias_add(mm3, b3);
    let h3 = g.relu(ba3);

    // Output layer has no activation: noise prediction is unbounded.
    let w4 = g.parameter("w4", &[hidden, img_dim]);
    let b4 = g.parameter("b4", &[img_dim]);
    let mm4 = g.matmul(h3, w4);
    let pred = g.bias_add(mm4, b4);

    // MSE loss built from primitive ops: mean((pred - target)^2).
    let neg_target = g.neg(target);
    let diff = g.add(pred, neg_target);
    let sq = g.mul(diff, diff);
    let loss = g.mean_all(sq);
    g.set_outputs(vec![loss]);

    println!("denoising MLP: 4 layers, {} hidden units", hidden);
    // Exact parameter count per layer:
    //   layer 1: img_dim*hidden + hidden
    //   layers 2-3: hidden*hidden + hidden each
    //   layer 4: hidden*img_dim + img_dim
    // (The previous formula double-counted a `hidden` bias and omitted b4.)
    let param_count = (img_dim * hidden + hidden)
        + 2 * (hidden * hidden + hidden)
        + (hidden * img_dim + img_dim);
    println!("parameters: {}", param_count);

    println!("building session (autodiff + egglog + compile)...");
    let t0 = Instant::now();
    let mut session = build_session(&g);
    let compile_time = t0.elapsed();
    println!(
        "session ready in {:.2}s: {} buffers, {} dispatches",
        compile_time.as_secs_f32(),
        session.plan().buffers.len(),
        session.plan().dispatches.len(),
    );

    // Deterministic Xavier-scaled weight init; biases start at zero.
    init_param(&mut session, "w1", img_dim, hidden);
    init_param(&mut session, "w2", hidden, hidden);
    init_param(&mut session, "w3", hidden, hidden);
    init_param(&mut session, "w4", hidden, img_dim);
    session.set_parameter("b1", &vec![0.0; hidden]);
    session.set_parameter("b2", &vec![0.0; hidden]);
    session.set_parameter("b3", &vec![0.0; hidden]);
    session.set_parameter("b4", &vec![0.0; img_dim]);

    println!("\ntraining ({} epochs)...", epochs);
    let config = TrainConfig {
        learning_rate: lr,
        log_interval: 20,
        ..TrainConfig::default()
    };
    let mut trainer = Trainer::new(session, config);
    let t_train = Instant::now();
    let history = trainer.train(&mut loader, epochs);
    let elapsed = t_train.elapsed();

    // Throughput is measured over optimizer steps actually taken.
    let total_steps: usize = history.epochs.iter().map(|e| e.steps).sum();
    let samples_per_sec = (total_steps * batch) as f64 / elapsed.as_secs_f64();

    println!("\n=== diffusion benchmark results ===");
    println!("model: 4-layer MLP ({img_dim}→{hidden}→{hidden}→{hidden}→{img_dim})");
    println!("batch size: {batch}");
    println!("epochs: {epochs}");
    println!("total steps: {total_steps}");
    println!("compile time: {:.2}s", compile_time.as_secs_f64());
    println!("train time: {:.2}s", elapsed.as_secs_f64());
    println!("throughput: {:.0} samples/s", samples_per_sec);
    for stats in &history.epochs {
        println!(
            "  epoch {:>2}: avg_loss = {:.6} ({} steps)",
            stats.epoch, stats.avg_loss, stats.steps,
        );
    }
    if let Some(final_loss) = history.final_loss() {
        println!("final loss: {:.6}", final_loss);
    }

    if let Some(ref trace_file) = trace_path {
        let path = std::path::Path::new(trace_file);
        meganeura::profiler::save(path).expect("failed to save profile");
        println!("profile saved to {}", path.display());
    }
}
/// Generate `n` synthetic denoising pairs of `dim` pixels each (treated as a
/// square image of side `sqrt(dim)`). Returns `(noisy, noise)` where
/// `noisy = clamp(clean + noise, 0, 1)` and `noise ~ N(0, ·) * noise_level`.
/// Fully deterministic: the same arguments always yield the same data.
fn generate_denoising_data(n: usize, dim: usize, noise_level: f32) -> (Vec<f32>, Vec<f32>) {
    // Seeded LCG (Knuth MMIX constants) so data generation needs no RNG crate.
    struct Lcg(u64);
    impl Lcg {
        /// Uniform sample in [0, 1) from the top 31 bits of the state.
        fn uniform(&mut self) -> f32 {
            self.0 = self
                .0
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            (self.0 >> 33) as f32 / (1u64 << 31) as f32
        }
        /// Rough normal via Irwin–Hall: sum of 4 uniforms, centered at 0.
        fn normal(&mut self) -> f32 {
            self.uniform() + self.uniform() + self.uniform() + self.uniform() - 2.0
        }
    }

    let side = (dim as f32).sqrt() as usize;
    let mut noisy = vec![0.0_f32; n * dim];
    let mut noise = vec![0.0_f32; n * dim];
    let mut rng = Lcg(314159265);

    for sample in 0..n {
        let pattern = sample % 5; // cycle through 5 clean-image families
        let offset = sample * dim;
        let phase = (sample / 5) as f32 * 0.1; // slow drift across repeats

        for row in 0..side {
            for col in 0..side {
                let idx = offset + row * side + col;
                let fx = col as f32 / side as f32;
                let fy = row as f32 / side as f32;

                let clean = match pattern {
                    0 => fx, // horizontal gradient
                    1 => fy, // vertical gradient
                    2 => {
                        // filled disc whose radius grows with phase
                        let dx = fx - 0.5;
                        let dy = fy - 0.5;
                        let inside = (dx * dx + dy * dy).sqrt() < 0.3 + phase * 0.1;
                        if inside { 1.0 } else { 0.0 }
                    }
                    3 => {
                        // checkerboard with a sample-dependent cell size
                        let freq = 4 + (sample / 5) % 4;
                        if (col / freq + row / freq) % 2 == 0 { 0.8 } else { 0.2 }
                    }
                    _ => ((fx + fy) * 0.5 + phase).fract(), // diagonal ramp
                };

                let perturbation = rng.normal() * noise_level;
                noise[idx] = perturbation;
                noisy[idx] = (clean + perturbation).clamp(0.0, 1.0);
            }
        }
    }

    (noisy, noise)
}
/// Set parameter `name` to a deterministic Xavier/Glorot-scaled matrix of
/// `fan_in * fan_out` values. Entries come from a golden-ratio-stepped sine
/// wave, giving a reproducible quasi-random init without an RNG.
fn init_param(session: &mut meganeura::Session, name: &str, fan_in: usize, fan_out: usize) {
    use std::f32::consts::PI;
    // Xavier scale: sqrt(2 / (fan_in + fan_out)).
    let scale = (2.0 / (fan_in + fan_out) as f32).sqrt();
    let count = fan_in * fan_out;
    let mut data = Vec::with_capacity(count);
    for i in 0..count {
        // 0.618034 ≈ 1/φ; stepping by it spreads sine phases evenly.
        let t = (i as f32 + 1.0) * 0.618_034;
        data.push((t * PI * 2.0).sin() * scale);
    }
    session.set_parameter(name, &data);
}