//! End-to-end showcase of the `flodl` graph API: custom `Module` and
//! `NamedInputModule` implementations, `FlowBuilder` composition, training
//! with scheduling and profiling, checkpointing, and visualization.
use std::collections::HashMap;
use flodl::{
Device, Tensor, Variable,
Module, NamedInputModule,
Linear, GELU, SiLU, LayerNorm, Dropout, BatchNorm,
FlowBuilder, MergeOp, Graph, modules,
SoftmaxRouter, ThresholdHalt, LearnedHalt,
Reshape, StateAdd,
Adam, Optimizer, mse_loss, clip_grad_norm,
save_checkpoint_file, load_checkpoint_file,
CosineScheduler,
no_grad,
};
use flodl::monitor::Monitor;
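/// Feed-forward block: Linear -> GELU -> LayerNorm.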
fn ffn_block(dim: i64) -> flodl::Result<Graph> {
FlowBuilder::from(Linear::new(dim, dim)?)
.through(GELU)
.through(LayerNorm::new(dim)?)
.build()
}
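/// Small read head: Linear -> LayerNorm.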
fn read_head(dim: i64) -> flodl::Result<Graph> {
FlowBuilder::from(Linear::new(dim, dim)?)
.through(LayerNorm::new(dim)?)
.build()
}
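/// Linear -> SiLU -> BatchNorm block, used below as a refinement loop body.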
fn silu_block(dim: i64) -> flodl::Result<Graph> {
FlowBuilder::from(Linear::new(dim, dim)?)
.through(SiLU)
.through(BatchNorm::new(dim)?)
.build()
}
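/// Custom `Module`: RMS normalization without a learned scale,
/// computing `x / sqrt(mean(x^2) + eps)` over the last dimension.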
struct RmsNorm {
eps: f64,
}
impl RmsNorm {
fn new() -> Self {
RmsNorm { eps: 1e-6 }
}
}
impl Module for RmsNorm {
fn name(&self) -> &str { "rmsnorm" }
fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
        let sq = input.pow_scalar(2.0)?;
        let ms = sq.mean_dim(-1, true)?;
        let shifted = ms.add_scalar(self.eps)?;
        let rms = shifted.sqrt()?;
        input.div(&rms)
    }
}
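/// Custom `Module`: scales the input, clamps to `[-bound, bound]`,
/// then takes the absolute value.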
struct SoftClamp {
scale: f64,
bound: f64,
}
impl SoftClamp {
fn new(scale: f64, bound: f64) -> Self {
SoftClamp { scale, bound }
}
}
impl Module for SoftClamp {
fn name(&self) -> &str { "softclamp" }
fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
        let scaled = input.mul_scalar(self.scale)?;
        let clamped = scaled.clamp(-self.bound, self.bound)?;
        clamped.abs()
    }
}
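/// Softplus activation, `log(1 + exp(x))`, built from primitive ops.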
struct Softplus;
impl Module for Softplus {
fn name(&self) -> &str { "softplus" }
fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
        let ex = input.exp()?;
        let shifted = ex.add_scalar(1.0)?;
        shifted.log()
    }
}
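/// Self-gate: `x * sigmoid(-x)`.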
struct NegSigmoidGate;
impl Module for NegSigmoidGate {
fn name(&self) -> &str { "neg_sigmoid_gate" }
fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
        let negated = input.neg()?;
        let gate = negated.sigmoid()?;
        input.mul(&gate)
    }
}
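/// Shape round trip: flatten, unsqueeze, squeeze, then reshape back to
/// `[batch, dim]`.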
struct ShapeOps {
batch: i64,
dim: i64,
}
impl ShapeOps {
fn new(batch: i64, dim: i64) -> Self {
ShapeOps { batch, dim }
}
}
impl Module for ShapeOps {
fn name(&self) -> &str { "shape_ops" }
fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
        let flat = input.flatten(0, -1)?;
        let expanded = flat.unsqueeze(0)?;
        let squeezed = expanded.squeeze(0)?;
        squeezed.reshape(&[self.batch, self.dim])
    }
}
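/// `log_softmax` over the last dimension followed by a keepdim sum,
/// collapsing the last dimension to size 1.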
struct LogSoftmaxReduce;
impl Module for LogSoftmaxReduce {
fn name(&self) -> &str { "log_softmax_reduce" }
fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
        let lsm = input.log_softmax(-1)?;
        lsm.sum_dim(-1, true)
    }
}
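/// Transposes dims 0 and 1, then permutes them back: an identity that
/// exercises both ops.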
struct TransposeRoundTrip;
impl Module for TransposeRoundTrip {
fn name(&self) -> &str { "transpose_rt" }
fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
        let t = input.transpose(0, 1)?;
        t.permute(&[1, 0])
    }
}
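/// `NamedInputModule` that gates the main path with the "ctx" reference:
/// `out = input * sigmoid(ctx / 2) + input`.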
struct ContextBlend;
impl Module for ContextBlend {
fn name(&self) -> &str { "context_blend" }
fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
Ok(input.clone())
}
fn as_named_input(&self) -> Option<&dyn NamedInputModule> {
Some(self)
}
}
impl NamedInputModule for ContextBlend {
fn forward_named(
&self,
input: &Variable,
refs: &HashMap<String, Variable>,
) -> flodl::Result<Variable> {
let ctx = &refs["ctx"];
        let scaled = ctx.div_scalar(2.0)?;
        let gate = scaled.sigmoid()?;
        let modulated = input.mul(&gate)?;
        modulated.add(input)
    }
}
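/// Feature map `tanh(1 / (sin(x) + cos(x)))`, used on the forked
/// monitor branch.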
struct SpectralBasis;
impl Module for SpectralBasis {
fn name(&self) -> &str { "spectral_basis" }
fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
        let s = input.sin()?;
        let c = input.cos()?;
        let sc = s.add(&c)?;
        let r = sc.reciprocal()?;
        r.tanh()
    }
}
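/// Gates the input by `(mean + std)` broadcast to `[1, dim]`; the variance
/// is computed but its value is discarded.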
struct VarianceGate {
dim: i64,
}
impl VarianceGate {
fn new(dim: i64) -> Self {
VarianceGate { dim }
}
}
impl Module for VarianceGate {
fn name(&self) -> &str { "variance_gate" }
fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
        let m = input.mean()?;
        let _v = input.var()?;
        let s = input.std()?;
        let gate_val = m.add(&s)?;
        let gate = gate_val.expand(&[1, self.dim])?;
        input.mul(&gate)
    }
}
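/// Splits the last dimension in two, applies ReLU to one half and negation
/// to the other, then concatenates them back.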
struct ChunkRecombine;
impl Module for ChunkRecombine {
fn name(&self) -> &str { "chunk_recombine" }
fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
        let chunks = input.chunk(2, -1)?;
        let a = chunks[0].relu()?;
        let b = chunks[1].neg()?;
        a.cat(&b, -1)
    }
}
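/// Attention-flavored ops: softmax weights plus a broadcast scalar bias
/// derived via select / narrow / index_select on the first row.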
struct AttentionLikeOps {
dim: i64,
}
impl AttentionLikeOps {
fn new(dim: i64) -> Self {
AttentionLikeOps { dim }
}
}
impl Module for AttentionLikeOps {
fn name(&self) -> &str { "attention_ops" }
fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
let weights = input.softmax(-1)?;
        let row = input.select(0, 0)?;
        let row2d = row.unsqueeze(0)?;
let half_dim = self.dim / 2;
let first_half = row2d.narrow(-1, 0, half_dim)?;
let idx = Tensor::from_i64(&[0, 1], &[2], Device::CPU)?;
let selected = first_half.index_select(-1, &idx)?;
        let scale = selected.mean()?;
        let scale_expanded = scale.expand(&[1, self.dim])?;
        weights.add(&scale_expanded)
}
}
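/// Exercises topk, sort, gather, and pad; adds the gathered value range
/// back onto the padded result.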
struct TopKFilterOps {
dim: i64,
}
impl TopKFilterOps {
fn new(dim: i64) -> Self {
TopKFilterOps { dim }
}
}
impl Module for TopKFilterOps {
fn name(&self) -> &str { "topk_filter" }
fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
let (values, indices) = input.topk(4, -1, true, true)?;
let (sorted, _sort_idx) = values.sort(-1, false)?;
let gathered = input.gather(-1, &indices)?;
        let mn = gathered.min()?;
        let mx = gathered.max()?;
        let range = mx.sub(&mn)?;
let pad_amount = self.dim - 4;
let padded = sorted.pad(&[0, pad_amount], 0.0)?;
padded.add(&range.expand(&[1, self.dim])?)
}
}
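/// Repeats the input along the last dimension, then narrows back to
/// `dim` columns.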
struct RepeatNarrow {
dim: i64,
}
impl RepeatNarrow {
fn new(dim: i64) -> Self {
RepeatNarrow { dim }
}
}
impl Module for RepeatNarrow {
fn name(&self) -> &str { "repeat_narrow" }
fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
        let repeated = input.repeat(&[1, 2])?;
        repeated.narrow(-1, 0, self.dim)
    }
}
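/// Stateful pass-through that counts forward calls; `reset` clears it.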
struct CounterModule {
count: std::cell::Cell<u32>,
}
impl CounterModule {
fn new() -> Self {
CounterModule { count: std::cell::Cell::new(0) }
}
}
impl Module for CounterModule {
fn name(&self) -> &str { "counter" }
fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
self.count.set(self.count.get() + 1);
Ok(input.clone())
}
fn reset(&self) {
self.count.set(0);
}
}
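/// Selector for the `.switch(...)` below: emits 1.0 when the max of the
/// tagged "refined" activation exceeds 5.0, else 0.0, picking between the
/// two switch branches.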
struct HeavyPathSelector;
impl Module for HeavyPathSelector {
fn name(&self) -> &str { "heavy_path_selector" }
fn forward(&self, _input: &Variable) -> flodl::Result<Variable> {
let t = Tensor::from_f32(&[0.0], &[1], Device::CPU)?;
Ok(Variable::new(t, false))
}
fn as_named_input(&self) -> Option<&dyn NamedInputModule> {
Some(self)
}
}
impl NamedInputModule for HeavyPathSelector {
fn forward_named(
&self,
_input: &Variable,
refs: &HashMap<String, Variable>,
) -> flodl::Result<Variable> {
let refined = &refs["refined"];
let data = refined.data().to_f32_vec()?;
let max_val = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let branch = if max_val > 5.0 { 1.0_f32 } else { 0.0 };
let t = Tensor::from_f32(&[branch], &[1], Device::CPU)?;
Ok(Variable::new(t, false))
}
}
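/// Fork branch: SpectralBasis features followed by a Linear projection.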
fn spectral_monitor(dim: i64) -> flodl::Result<Graph> {
FlowBuilder::from(SpectralBasis)
.through(Linear::new(dim, dim)?)
.build()
}
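/// Assembles the full showcase graph: tags, a fork, split/merge heads,
/// per-slice maps, bounded loops, gated and switched branches, and
/// cross-references via `.using`.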
fn build_showcase() -> flodl::Result<Graph> {
const B: i64 = 2; const H: i64 = 8;
FlowBuilder::from(Linear::new(2, H)?)
.input(&["ctx"])
.tag("input")
.through(GELU)
.through(LayerNorm::new(H)?)
.through(RmsNorm::new())
.through(ContextBlend)
.using(&["ctx"])
.fork(spectral_monitor(H)?)
.tag("spectral")
.split(modules![read_head(H)?, read_head(H)?])
.merge(MergeOp::Mean)
.also(Linear::new(H, H)?)
.through(Dropout::new(0.1))
.through(SoftClamp::new(0.5, 3.0))
.through(Softplus)
.through(VarianceGate::new(H))
.map(read_head(2)?)
.slices(H / 2)
.through(Reshape::new(&[B * 2, H / 2]))
.map(Linear::new(H / 2, H / 2)?)
.each()
.tag("halves")
.map(Linear::new(H / 2, H / 2)?)
.over("halves")
.map(Linear::new(H / 2, H / 2)?)
.batched()
.each()
.through(Reshape::new(&[B, H]))
.through(ShapeOps::new(B, H))
.through(NegSigmoidGate)
.through(TransposeRoundTrip)
.through(CounterModule::new())
.through(ChunkRecombine)
.through(AttentionLikeOps::new(H))
.through(TopKFilterOps::new(H))
.through(RepeatNarrow::new(H))
.loop_body(silu_block(H)?)
.for_n(2)
.tag("refined")
.gate(
SoftmaxRouter::new(H, 2)?,
modules![Linear::new(H, H)?, Linear::new(H, H)?],
)
.using(&["input"])
.switch(
HeavyPathSelector,
modules![Linear::new(H, H)?, ffn_block(H)?],
)
.using(&["refined"])
.through(StateAdd)
.using(&["memory"])
.tag("memory")
.loop_body(Linear::new(H, H)?)
.while_cond(ThresholdHalt::new(100.0), 5)
.loop_body(Linear::new(H, H)?)
.until_cond(LearnedHalt::new(H)?, 7)
.through(LogSoftmaxReduce)
.through(Linear::new(1, H)?)
.split(vec![
Box::new(Linear::new(H, H)?),
Box::new(Linear::new(H, H)?),
])
.tag_group("final_heads")
.merge(MergeOp::Add)
.through(Linear::new(H, 2)?)
.tag("output")
.build()
}
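/// Fixed 2x2 input batch for the graph's primary input.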
fn make_input(requires_grad: bool) -> Variable {
let t = Tensor::from_f32(&[1.0, 2.0, 0.5, -1.0], &[2, 2], Device::CPU).unwrap();
Variable::new(t, requires_grad)
}
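/// Fixed `[2, 8]` tensor fed as the named "ctx" input.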
fn make_context() -> Variable {
let t = Tensor::from_f32(
&[0.5, -0.3, 0.8, 1.2, -0.5, 0.1, 0.9, -0.7,
0.2, 0.7, -0.4, 0.6, 1.0, -0.8, 0.3, -0.1],
&[2, 8],
Device::CPU,
).unwrap();
Variable::new(t, false)
}
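/// Regression target matching the graph's `[2, 2]` output.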
fn make_target() -> Variable {
let t = Tensor::from_f32(&[0.5, -0.5, -0.3, 0.8], &[2, 2], Device::CPU).unwrap();
Variable::new(t, false)
}
#[cfg(test)]
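/// Counts parameters whose gradient contains at least one nonzero value.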
fn count_grads(params: &[flodl::Parameter]) -> usize {
params
.iter()
.filter(|p| {
p.variable.grad()
.and_then(|g| g.to_f32_vec().ok())
.is_some_and(|d| d.iter().any(|v| *v != 0.0))
})
.count()
}
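/// Runs the showcase end to end: build, forward passes with state checks,
/// DOT/SVG export, a short training loop with scheduling and profiling,
/// checkpoint round trip, and `no_grad` inference.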
fn main() {
flodl::manual_seed(42);
println!("=== floDl showcase ===\n");
println!("Building graph...");
let g = build_showcase().expect("build failed");
let n_params = g.parameters().len();
println!("Parameters: {}", n_params);
let result = g.forward_multi(&[make_input(false), make_context()])
.expect("forward failed");
println!("Output: {:?} (shape {:?})", result.data().to_f32_vec().unwrap(), result.shape());
g.reset_state();
let r1 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
let v1 = r1.data().to_f32_vec().unwrap();
let r2 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
let v2 = r2.data().to_f32_vec().unwrap();
println!("State drift: pass2 differs = {}", v1 != v2);
g.reset_state();
let r3 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
let v3 = r3.data().to_f32_vec().unwrap();
println!("Reset restores: {}", v1 == v3);
let dot = g.dot();
println!("DOT: {} bytes", dot.len());
let dot_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase.dot");
std::fs::write(dot_path, &dot).expect("write showcase.dot");
println!("Wrote {}", dot_path);
let svg_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase.svg");
let svg = g.svg(Some(svg_path)).expect("write showcase.svg");
println!("Wrote {} ({} bytes)", svg_path, svg.len());
println!("\n--- Training (5 epochs x 4 steps) ---");
g.train();
g.reset_state();
g.enable_profiling();
let params = g.parameters();
    let mut optimizer = Adam::new(&params, 0.001);
let num_epochs = 5;
let total_steps = num_epochs * 4;
let sched = CosineScheduler::new(0.001, 1e-5, total_steps);
let mut monitor = Monitor::new(num_epochs);
let mut step_idx = 0;
for epoch in 0..num_epochs {
let t = std::time::Instant::now();
for _ in 0..4 {
optimizer.zero_grad();
let input = make_input(true);
let ctx = make_context();
let target = make_target();
let pred = g.forward_multi(&[input, ctx]).unwrap();
let loss = mse_loss(&pred, &target).unwrap();
loss.backward().unwrap();
            clip_grad_norm(&params, 1.0).unwrap();
optimizer.set_lr(sched.lr(step_idx));
optimizer.step().unwrap();
step_idx += 1;
g.record_scalar("loss", loss.item().unwrap());
g.record_scalar("lr", sched.lr(step_idx - 1));
g.end_step();
}
g.end_epoch();
monitor.log(epoch, t.elapsed(), &g);
}
let trend = g.trend("loss");
println!(
"\nLoss trend: {} epochs, slope={:.4}, improving={}",
trend.len(),
trend.slope(0),
trend.improving(0),
);
let timing = g.timing_trend("input");
println!(
"Timing trend (input node): {} epochs, mean={:.1}us",
timing.len(),
timing.mean() * 1e6,
);
let profile_dot = g.dot_with_profile();
let profile_dot_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_profile.dot");
std::fs::write(profile_dot_path, &profile_dot).expect("write showcase_profile.dot");
println!("Wrote {}", profile_dot_path);
let profile_svg_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_profile.svg");
let profile_svg = g.svg_with_profile(Some(profile_svg_path)).expect("write showcase_profile.svg");
println!("Wrote {} ({} bytes)", profile_svg_path, profile_svg.len());
let html_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_training.html");
g.plot_html(html_path, &["loss"]).expect("write showcase_training.html");
println!("Wrote {}", html_path);
let log_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_training.log");
g.write_log(log_path, 5, &["loss"]).expect("write showcase_training.log");
println!("Wrote {}", log_path);
let path = "/tmp/flodl_showcase_checkpoint.fdl";
let named = g.named_parameters();
let named_bufs = g.named_buffers();
save_checkpoint_file(path, &named, &named_bufs, Some(g.structural_hash())).expect("save failed");
let report = load_checkpoint_file(path, &named, &named_bufs, Some(g.structural_hash())).expect("load failed");
println!("\nCheckpoint save/load: OK ({} loaded)", report.loaded.len());
g.eval();
g.reset_state();
let final_out = no_grad(|| g.forward_multi(&[make_input(false), make_context()])).unwrap();
let final_vals = final_out.data().to_f32_vec().unwrap();
println!("no_grad inference: {:?}", final_vals);
assert!(final_vals.iter().all(|v| v.is_finite()), "no_grad output should be finite");
println!("\nAll showcase checks passed.");
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_build() {
let g = build_showcase().unwrap();
let result = g.forward_multi(&[make_input(false), make_context()]).unwrap();
let vals = result.data().to_f32_vec().unwrap();
assert_eq!(vals.len(), 4, "expected 4 outputs (2x2), got {}", vals.len());
}
#[test]
fn test_forward_ref_carries_state() {
let g = build_showcase().unwrap();
let r1 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
let v1 = r1.data().to_f32_vec().unwrap();
let r2 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
let v2 = r2.data().to_f32_vec().unwrap();
assert_ne!(v1, v2, "pass 2 should differ from pass 1");
}
#[test]
fn test_reset_state() {
let g = build_showcase().unwrap();
g.forward_multi(&[make_input(false), make_context()]).unwrap();
g.eval();
g.reset_state();
let r1 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
let v1 = r1.data().to_f32_vec().unwrap();
g.forward_multi(&[make_input(false), make_context()]).unwrap();
g.reset_state();
let r3 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
let v3 = r3.data().to_f32_vec().unwrap();
assert_eq!(v1, v3, "after reset should match pass 1");
}
#[test]
fn test_detach_state() {
let g = build_showcase().unwrap();
g.forward_multi(&[make_input(false), make_context()]).unwrap();
g.detach_state();
let result = g.forward_multi(&[make_input(false), make_context()]).unwrap();
assert_eq!(result.data().to_f32_vec().unwrap().len(), 4);
}
#[test]
fn test_backward() {
let g = build_showcase().unwrap();
let result = g.forward_multi(&[make_input(true), make_context()]).unwrap();
let loss = result.sum().unwrap();
loss.backward().unwrap();
let with_grad = count_grads(&g.parameters());
assert!(with_grad > 0, "no parameters received gradients");
}
#[test]
fn test_parameters() {
let g = build_showcase().unwrap();
let params = g.parameters();
assert!(
params.len() > 44,
"expected more than 44 params (extended graph), got {}",
params.len()
);
}
#[test]
fn test_set_training() {
let g = build_showcase().unwrap();
g.forward_multi(&[make_input(false), make_context()]).unwrap();
g.set_training(false);
g.reset_state();
let r1 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
g.set_training(true);
g.reset_state();
let r2 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
assert_eq!(r1.data().to_f32_vec().unwrap().len(), 4);
assert_eq!(r2.data().to_f32_vec().unwrap().len(), 4);
}
#[test]
fn test_dot() {
let g = build_showcase().unwrap();
let dot = g.dot();
assert!(!dot.is_empty(), "DOT output is empty");
assert!(dot.contains("digraph"), "DOT should contain digraph");
}
#[test]
fn test_training_loop() {
let g = build_showcase().unwrap();
g.train();
let params = g.parameters();
        let mut optimizer = Adam::new(&params, 0.01);
let mut losses = Vec::new();
for _ in 0..3 {
let input = make_input(true);
let ctx = make_context();
let target = make_target();
let pred = g.forward_multi(&[input, ctx]).unwrap();
let loss = mse_loss(&pred, &target).unwrap();
losses.push(loss.item().unwrap());
loss.backward().unwrap();
            clip_grad_norm(&params, 1.0).unwrap();
optimizer.step().unwrap();
optimizer.zero_grad();
g.end_step();
}
for (i, &l) in losses.iter().enumerate() {
assert!(l.is_finite(), "loss at step {} is not finite: {}", i, l);
}
}
#[test]
fn test_observation() {
let g = build_showcase().unwrap();
let out = g.forward_multi(&[make_input(false), make_context()]).unwrap();
let tagged = g.tagged("output");
assert!(tagged.is_some(), "tagged 'output' not captured");
assert_eq!(tagged.unwrap().shape(), &[2, 2]);
let loss_val = out.data().to_f32_vec().unwrap().iter().map(|v| *v as f64).sum::<f64>();
g.record("test_loss", &[loss_val]);
g.flush(&["test_loss"]);
assert_eq!(g.flush_count(), 1);
let out2 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
let loss_val2 = out2.data().to_f32_vec().unwrap().iter().map(|v| *v as f64).sum::<f64>();
g.record("test_loss", &[loss_val2]);
g.flush(&["test_loss"]);
assert_eq!(g.flush_count(), 2);
let trend = g.trend("test_loss");
assert_eq!(trend.len(), 2, "expected 2 epochs in trend");
}
#[test]
fn test_profiling() {
let g = build_showcase().unwrap();
g.enable_profiling();
g.forward_multi(&[make_input(false), make_context()]).unwrap();
        g.collect_timings(&[]);
        g.flush_timings(&[]);
let timing = g.timing_trend("input");
assert_eq!(timing.len(), 1, "expected 1 timing epoch");
assert!(timing.latest() > 0.0, "timing should be positive");
}
#[test]
fn test_checkpoint_roundtrip() {
let g = build_showcase().unwrap();
let params = g.parameters();
let named = g.named_parameters();
g.forward_multi(&[make_input(false), make_context()]).unwrap();
g.eval();
g.reset_state();
let path = "/tmp/flodl_showcase_test_ckpt.fdl";
let named_bufs = g.named_buffers();
save_checkpoint_file(path, &named, &named_bufs, Some(g.structural_hash())).unwrap();
let before = g.forward_multi(&[make_input(false), make_context()]).unwrap();
let v_before = before.data().to_f32_vec().unwrap();
assert!(v_before.iter().all(|v| v.is_finite()), "pre-train output NaN");
let p0_before = params[0].variable.data().to_f32_vec().unwrap();
g.reset_state();
g.train();
let pred = g.forward_multi(&[make_input(true), make_context()]).unwrap();
let loss = pred.sum().unwrap();
loss.backward().unwrap();
        let mut opt = Adam::new(&params, 0.1);
opt.step().unwrap();
let p0_after = params[0].variable.data().to_f32_vec().unwrap();
assert_ne!(p0_before, p0_after, "training should change parameters");
let report = load_checkpoint_file(path, &named, &named_bufs, Some(g.structural_hash())).unwrap();
assert_eq!(report.loaded.len(), named.len());
let p0_restored = params[0].variable.data().to_f32_vec().unwrap();
assert_eq!(p0_before, p0_restored, "checkpoint restore should match original params");
let _ = std::fs::remove_file(path);
}
#[test]
fn test_no_grad() {
let g = build_showcase().unwrap();
let result = no_grad(|| g.forward_multi(&[make_input(true), make_context()])).unwrap();
let vals = result.data().to_f32_vec().unwrap();
assert_eq!(vals.len(), 4);
assert!(vals.iter().all(|v| v.is_finite()), "no_grad should produce finite values");
}
#[test]
fn test_visualization() {
let g = build_showcase().unwrap();
let dot = g.dot();
assert!(dot.contains("digraph"), "DOT should contain digraph");
assert!(dot.contains("#input"), "DOT should contain #input tag");
let dot_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase.dot");
std::fs::write(dot_path, &dot).unwrap();
let svg_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase.svg");
let svg = g.svg(Some(svg_path)).unwrap();
assert!(svg.len() > 100, "SVG should have content");
g.enable_profiling();
g.forward_multi(&[make_input(false), make_context()]).unwrap();
let profile_dot = g.dot_with_profile();
assert!(profile_dot.contains("Forward:"), "profile DOT should show total time");
let profile_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_profile.dot");
std::fs::write(profile_path, &profile_dot).unwrap();
let profile_svg_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_profile.svg");
let profile_svg = g.svg_with_profile(Some(profile_svg_path)).unwrap();
assert!(profile_svg.len() > 100, "profile SVG should have content");
g.train();
g.reset_state();
let params = g.parameters();
        let mut optimizer = Adam::new(&params, 0.01);
for _epoch in 0..3 {
for _ in 0..4 {
optimizer.zero_grad();
let pred = g.forward_multi(&[make_input(true), make_context()]).unwrap();
let loss = mse_loss(&pred, &make_target()).unwrap();
loss.backward().unwrap();
optimizer.step().unwrap();
g.record_scalar("loss", loss.item().unwrap());
g.end_step();
}
g.end_epoch();
}
let html_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_training.html");
g.plot_html(html_path, &["loss"]).unwrap();
let log_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_training.log");
g.write_log(log_path, 3, &["loss"]).unwrap();
assert!(std::fs::metadata(dot_path).unwrap().len() > 100);
assert!(std::fs::metadata(svg_path).unwrap().len() > 100);
assert!(std::fs::metadata(profile_path).unwrap().len() > 100);
assert!(std::fs::metadata(profile_svg_path).unwrap().len() > 100);
assert!(std::fs::metadata(html_path).unwrap().len() > 100);
assert!(std::fs::metadata(log_path).unwrap().len() > 10);
}
#[test]
fn test_cosine_scheduler() {
let sched = CosineScheduler::new(0.01, 1e-5, 10);
let lr_start = sched.lr(0);
let lr_end = sched.lr(10);
assert!(lr_end < lr_start, "LR should decrease: {} -> {}", lr_start, lr_end);
assert!((lr_end - 1e-5).abs() < 1e-4, "LR should reach min_lr");
}
#[test]
fn test_fork_tag() {
let g = build_showcase().unwrap();
g.forward_multi(&[make_input(false), make_context()]).unwrap();
let spectral = g.tagged("spectral");
assert!(spectral.is_some(), "fork tag 'spectral' not captured");
}
}