//! Showcase: exercises every method of the graph fluent builder API
//! in a single coherent graph, plus training tools and observation.
//!
//! Builder methods exercised:
//!
//!     From, Through, Tag, Using (backward ref), Using (forward ref),
//!     Split, Merge, Also, Map.Slices, Map.Each, Map.Over, Map.Batched,
//!     Loop.For, Loop.While, Loop.Until,
//!     Gate, Gate.Using, Switch, Switch.Using,
//!     Fork, Input, TagGroup, Build
//!
//! Graph methods exercised:
//!
//!     forward, forward_multi, parameters, train, eval, set_training, reset_state,
//!     detach_state, dot, tagged, flush, trend, enable_profiling,
//!     flush_timings, timing_trend
//!
//! Variable ops exercised (via custom modules):
//!
//!     add, sub, mul, div, matmul, mul_scalar, add_scalar, div_scalar,
//!     relu, sigmoid, tanh, gelu, silu,
//!     sum, mean, sum_dim, mean_dim,
//!     exp, log, neg, abs, sqrt, pow_scalar, clamp,
//!     softmax, log_softmax,
//!     transpose, reshape, narrow, flatten, squeeze, unsqueeze, expand,
//!     cat, select, index_select, permute,
//!     sin, cos, reciprocal, var, std, gather, chunk, repeat, pad,
//!     topk, sort, min, max
//!
//! Training tools exercised:
//!
//!     Adam optimizer, CosineScheduler, mse_loss, clip_grad_norm,
//!     save/load checkpoint, no_grad, observation, profiling, trends
//!
//! Also demonstrates Graph-as-Module: sub-graphs used as reusable blocks
//! inside Split branches and Loop bodies.
38use std::collections::HashMap;
39
40use flodl::{
41    Device, Tensor, Variable,
42    Module, NamedInputModule,
43    Linear, GELU, SiLU, LayerNorm, Dropout, BatchNorm,
44    FlowBuilder, MergeOp, Graph, modules,
45    SoftmaxRouter, ThresholdHalt, LearnedHalt,
46    Reshape, StateAdd,
47    Adam, Optimizer, mse_loss, clip_grad_norm,
48    save_checkpoint_file, load_checkpoint_file,
49    CosineScheduler,
50    no_grad,
51};
52use flodl::monitor::Monitor;
53
54// ---------------------------------------------------------------------------
55// Reusable sub-graph builders
56// ---------------------------------------------------------------------------
57
58/// Feed-forward block: Linear -> GELU -> LayerNorm.
59fn ffn_block(dim: i64) -> flodl::Result<Graph> {
60    FlowBuilder::from(Linear::new(dim, dim)?)
61        .through(GELU)
62        .through(LayerNorm::new(dim)?)
63        .build()
64}
65
66/// Projection head: Linear -> LayerNorm.
67fn read_head(dim: i64) -> flodl::Result<Graph> {
68    FlowBuilder::from(Linear::new(dim, dim)?)
69        .through(LayerNorm::new(dim)?)
70        .build()
71}
72
73/// SiLU block: Linear -> SiLU -> BatchNorm.
74fn silu_block(dim: i64) -> flodl::Result<Graph> {
75    FlowBuilder::from(Linear::new(dim, dim)?)
76        .through(SiLU)
77        .through(BatchNorm::new(dim)?)
78        .build()
79}
80
81// ---------------------------------------------------------------------------
82// Custom modules exercising Variable ops
83// ---------------------------------------------------------------------------
84
85/// RMS normalization: x / sqrt(mean(x^2) + eps).
86/// Exercises: pow_scalar, mean_dim, add_scalar, sqrt, div.
87struct RmsNorm {
88    eps: f64,
89}
90
91impl RmsNorm {
92    fn new() -> Self {
93        RmsNorm { eps: 1e-6 }
94    }
95}
96
97impl Module for RmsNorm {
98    fn name(&self) -> &str { "rmsnorm" }
99
100    fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
101        let sq = input.pow_scalar(2.0)?;                           // pow_scalar
102        let ms = sq.mean_dim(-1, true)?;                           // mean_dim
103        let shifted = ms.add_scalar(self.eps)?;                    // add_scalar
104        let rms = shifted.sqrt()?;                                 // sqrt
105        input.div(&rms)                                            // div
106    }
107}
108
109/// Soft clamping: clamp(x * scale, -bound, bound) then abs.
110/// Exercises: mul_scalar, clamp, abs.
111struct SoftClamp {
112    scale: f64,
113    bound: f64,
114}
115
116impl SoftClamp {
117    fn new(scale: f64, bound: f64) -> Self {
118        SoftClamp { scale, bound }
119    }
120}
121
122impl Module for SoftClamp {
123    fn name(&self) -> &str { "softclamp" }
124
125    fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
126        let scaled = input.mul_scalar(self.scale)?;                // mul_scalar
127        let clamped = scaled.clamp(-self.bound, self.bound)?;      // clamp
128        clamped.abs()                                              // abs
129    }
130}
131
132/// Log-space transform: log(exp(x) + 1) (softplus).
133/// Exercises: exp, add_scalar (as +1), log.
134struct Softplus;
135
136impl Module for Softplus {
137    fn name(&self) -> &str { "softplus" }
138
139    fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
140        let ex = input.exp()?;                                     // exp
141        let shifted = ex.add_scalar(1.0)?;                         // add_scalar (+1)
142        shifted.log()                                              // log
143    }
144}
145
146/// Negated sigmoid gate: sigmoid(-x) * x.
147/// Exercises: neg, sigmoid (direct op), mul.
148struct NegSigmoidGate;
149
150impl Module for NegSigmoidGate {
151    fn name(&self) -> &str { "neg_sigmoid_gate" }
152
153    fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
154        let negated = input.neg()?;                                // neg
155        let gate = negated.sigmoid()?;                             // sigmoid
156        input.mul(&gate)                                           // mul
157    }
158}
159
160/// Shape gymnastics: flatten -> unsqueeze -> squeeze -> transpose.
161/// Input [B, D] -> flatten [B*D] -> unsqueeze [1, B*D] -> squeeze [B*D]
162/// -> reshape back to [B, D].
163/// Exercises: flatten, unsqueeze, squeeze, reshape.
164struct ShapeOps {
165    batch: i64,
166    dim: i64,
167}
168
169impl ShapeOps {
170    fn new(batch: i64, dim: i64) -> Self {
171        ShapeOps { batch, dim }
172    }
173}
174
175impl Module for ShapeOps {
176    fn name(&self) -> &str { "shape_ops" }
177
178    fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
179        let flat = input.flatten(0, -1)?;                          // flatten
180        let expanded = flat.unsqueeze(0)?;                         // unsqueeze
181        let squeezed = expanded.squeeze(0)?;                       // squeeze
182        squeezed.reshape(&[self.batch, self.dim])                  // reshape (back)
183    }
184}
185
186/// Log-softmax along last dim, then sum_dim to scalar per batch.
187/// Exercises: log_softmax, sum_dim.
188struct LogSoftmaxReduce;
189
190impl Module for LogSoftmaxReduce {
191    fn name(&self) -> &str { "log_softmax_reduce" }
192
193    fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
194        let lsm = input.log_softmax(-1)?;                         // log_softmax
195        lsm.sum_dim(-1, true)                                     // sum_dim (keepdim)
196    }
197}
198
199/// Transpose dim 0 and dim 1 (exercises transpose + permute).
200/// For [A, B] input: transpose(0,1) -> [B, A], permute back to [A, B].
201struct TransposeRoundTrip;
202
203impl Module for TransposeRoundTrip {
204    fn name(&self) -> &str { "transpose_rt" }
205
206    fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
207        let t = input.transpose(0, 1)?;                           // transpose
208        t.permute(&[1, 0])                                        // permute (back)
209    }
210}
211
212/// Context blending: uses auxiliary input to modulate the stream.
213/// Exercises: div_scalar, sigmoid, mul, add (via NamedInputModule).
214struct ContextBlend;
215
216impl Module for ContextBlend {
217    fn name(&self) -> &str { "context_blend" }
218
219    fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
220        Ok(input.clone())
221    }
222
223    fn as_named_input(&self) -> Option<&dyn NamedInputModule> {
224        Some(self)
225    }
226}
227
228impl NamedInputModule for ContextBlend {
229    fn forward_named(
230        &self,
231        input: &Variable,
232        refs: &HashMap<String, Variable>,
233    ) -> flodl::Result<Variable> {
234        let ctx = &refs["ctx"];
235        let scaled = ctx.div_scalar(2.0)?;                        // div_scalar
236        let gate = scaled.sigmoid()?;                              // sigmoid
237        let modulated = input.mul(&gate)?;                         // mul
238        modulated.add(input)                                       // add
239    }
240}
241
242/// Spectral basis: sin/cos/reciprocal projections.
243/// Used as a fork side-branch (output captured, stream continues unchanged).
244/// Exercises: sin, cos, reciprocal, tanh.
245struct SpectralBasis;
246
247impl Module for SpectralBasis {
248    fn name(&self) -> &str { "spectral_basis" }
249
250    fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
251        let s = input.sin()?;                                      // sin
252        let c = input.cos()?;                                      // cos
253        let sc = s.add(&c)?;
254        let r = sc.reciprocal()?;                                  // reciprocal
255        r.tanh()                                                   // tanh
256    }
257}
258
259/// Variance gate: gate stream by normalized variance.
260/// Exercises: mean (scalar), var, std, expand.
261struct VarianceGate {
262    dim: i64,
263}
264
265impl VarianceGate {
266    fn new(dim: i64) -> Self {
267        VarianceGate { dim }
268    }
269}
270
271impl Module for VarianceGate {
272    fn name(&self) -> &str { "variance_gate" }
273
274    fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
275        let m = input.mean()?;                                     // mean (scalar)
276        let _v = input.var()?;                                     // var (scalar)
277        let s = input.std()?;                                      // std (scalar)
278        let gate_val = m.add(&s)?;
279        let gate = gate_val.expand(&[1, self.dim])?;               // expand
280        input.mul(&gate)                                           // mul
281    }
282}
283
284/// Chunk-recombine: split along last dim, process, cat back.
285/// Exercises: chunk, relu (Variable op), cat.
286struct ChunkRecombine;
287
288impl Module for ChunkRecombine {
289    fn name(&self) -> &str { "chunk_recombine" }
290
291    fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
292        let chunks = input.chunk(2, -1)?;                          // chunk
293        let a = chunks[0].relu()?;                                 // relu (Variable op)
294        let b = chunks[1].neg()?;
295        a.cat(&b, -1)                                              // cat
296    }
297}
298
299/// Attention-like op exercise: softmax, select, narrow, index_select.
300/// Input [1, D] -> exercises each op then returns [1, D].
301struct AttentionLikeOps {
302    dim: i64,
303}
304
305impl AttentionLikeOps {
306    fn new(dim: i64) -> Self {
307        AttentionLikeOps { dim }
308    }
309}
310
311impl Module for AttentionLikeOps {
312    fn name(&self) -> &str { "attention_ops" }
313
314    fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
315        let weights = input.softmax(-1)?;                          // softmax
316
317        // select dim 0, index 0 -> [D]
318        let row = input.select(0, 0)?;                             // select
319        let row2d = row.unsqueeze(0)?;
320
321        // narrow: take first half along last dim
322        let half_dim = self.dim / 2;
323        let first_half = row2d.narrow(-1, 0, half_dim)?;           // narrow
324
325        // index_select: pick specific indices from first_half [1, half_dim]
326        let idx = Tensor::from_i64(&[0, 1], &[2], Device::CPU)?;
327        let selected = first_half.index_select(-1, &idx)?;         // index_select
328
329        // Combine: scale weights by mean of selected values
330        let scale = selected.mean()?;                              // scalar
331        let scale_expanded = scale.expand(&[1, self.dim])?;        // expand (scalar -> [1,D])
332        weights.add(&scale_expanded)
333    }
334}
335
336/// TopK/sort/gather/min/max/pad exercise.
337/// Input [1, D] -> exercises each op then returns [1, D].
338struct TopKFilterOps {
339    dim: i64,
340}
341
342impl TopKFilterOps {
343    fn new(dim: i64) -> Self {
344        TopKFilterOps { dim }
345    }
346}
347
348impl Module for TopKFilterOps {
349    fn name(&self) -> &str { "topk_filter" }
350
351    fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
352        // topk: get top 4 values
353        let (values, indices) = input.topk(4, -1, true, true)?;    // topk
354
355        // sort the top values
356        let (sorted, _sort_idx) = values.sort(-1, false)?;         // sort
357
358        // gather: use indices to rearrange
359        let gathered = input.gather(-1, &indices)?;                 // gather
360
361        // min/max as scalar ops
362        let mn = gathered.min()?;                                   // min
363        let mx = gathered.max()?;                                   // max
364        let range = mx.sub(&mn)?;
365
366        // pad: pad sorted [1,4] to [1, D] with zeros on the right
367        let pad_amount = self.dim - 4;
368        let padded = sorted.pad(&[0, pad_amount], 0.0)?;           // pad
369
370        padded.add(&range.expand(&[1, self.dim])?)
371    }
372}
373
374/// Repeat exercise: repeat tensor along dims.
375/// Input [1, D] -> repeat [1, 2] -> [1, 2D] -> narrow back to [1, D].
376struct RepeatNarrow {
377    dim: i64,
378}
379
380impl RepeatNarrow {
381    fn new(dim: i64) -> Self {
382        RepeatNarrow { dim }
383    }
384}
385
386impl Module for RepeatNarrow {
387    fn name(&self) -> &str { "repeat_narrow" }
388
389    fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
390        let repeated = input.repeat(&[1, 2])?;                    // repeat
391        repeated.narrow(-1, 0, self.dim)                           // narrow (trim back)
392    }
393}
394
395/// Resettable module: exercises Module::reset() auto-detection by loops.
396/// Accumulates a call counter, reset clears it.
397struct CounterModule {
398    count: std::cell::Cell<u32>,
399}
400
401impl CounterModule {
402    fn new() -> Self {
403        CounterModule { count: std::cell::Cell::new(0) }
404    }
405}
406
407impl Module for CounterModule {
408    fn name(&self) -> &str { "counter" }
409
410    fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
411        self.count.set(self.count.get() + 1);
412        Ok(input.clone())
413    }
414
415    fn reset(&self) {
416        self.count.set(0);
417    }
418}
419
420// ---------------------------------------------------------------------------
421// Custom Switch selector (user-defined NamedInputModule)
422// ---------------------------------------------------------------------------
423
424/// Picks branch 0 (lightweight) or branch 1 (heavy) based on
425/// activation magnitude of the "refined" reference.
426struct HeavyPathSelector;
427
428impl Module for HeavyPathSelector {
429    fn name(&self) -> &str { "heavy_path_selector" }
430
431    fn forward(&self, _input: &Variable) -> flodl::Result<Variable> {
432        let t = Tensor::from_f32(&[0.0], &[1], Device::CPU)?;
433        Ok(Variable::new(t, false))
434    }
435
436    fn as_named_input(&self) -> Option<&dyn NamedInputModule> {
437        Some(self)
438    }
439}
440
441impl NamedInputModule for HeavyPathSelector {
442    fn forward_named(
443        &self,
444        _input: &Variable,
445        refs: &HashMap<String, Variable>,
446    ) -> flodl::Result<Variable> {
447        let refined = &refs["refined"];
448        let data = refined.data().to_f32_vec()?;
449        let max_val = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
450
451        let branch = if max_val > 5.0 { 1.0_f32 } else { 0.0 };
452        let t = Tensor::from_f32(&[branch], &[1], Device::CPU)?;
453        Ok(Variable::new(t, false))
454    }
455}
456
457// ---------------------------------------------------------------------------
458// Sub-graph builder for fork side-branch
459// ---------------------------------------------------------------------------
460
461/// Spectral monitor sub-graph: SpectralBasis -> Linear projection.
462/// Used as a fork target — its output gets tagged but doesn't affect the main stream.
463fn spectral_monitor(dim: i64) -> flodl::Result<Graph> {
464    FlowBuilder::from(SpectralBasis)
465        .through(Linear::new(dim, dim)?)
466        .build()
467}
468
469// ---------------------------------------------------------------------------
470// Graph construction
471// ---------------------------------------------------------------------------
472
/// Build the extended showcase graph.
///
/// Data flow `[B,2]` -> `[B,2]` (B=2 for BatchNorm compatibility):
///
/// ```text
/// Input("ctx")                                     auxiliary input port
/// From(Linear 2->8)                                Tag("input")
/// Through(GELU) -> Through(LayerNorm)
/// Through(RmsNorm)                                 pow_scalar, mean_dim, add_scalar, sqrt, div
/// Through(ContextBlend).Using("ctx")               div_scalar, sigmoid, mul, add (NamedInputModule)
/// Fork(spectral_monitor).Tag("spectral")           sin, cos, reciprocal, tanh (side branch)
/// Split(read_head, read_head) -> Mean()            multi-head read
/// Also(Linear 8->8)                                residual
/// Through(Dropout(0.1))
/// Through(SoftClamp(0.5, 3.0))                     mul_scalar, clamp, abs
/// Through(Softplus)                                exp, add_scalar, log
/// Through(VarianceGate)                            mean, var, std, expand
/// Map(read_head(2)).Slices(4)                      per-position processing
/// Through(Reshape [2,4])
/// Map(Linear 4->4).Each()                          Tag("halves")
/// Map(Linear 4->4).Over("halves")                  refine tagged halves
/// Map(Linear 4->4).Batched().Each()                batched fast-path map
/// Through(Reshape [1,8])
/// Through(ShapeOps)                                flatten, unsqueeze, squeeze, reshape
/// Through(NegSigmoidGate)                          neg, sigmoid, mul
/// Through(TransposeRoundTrip)                      transpose, permute
/// Through(CounterModule)                           exercises Module::reset()
/// Through(ChunkRecombine)                          chunk, relu, cat
/// Through(AttentionLikeOps)                        softmax, select, narrow, index_select
/// Through(TopKFilterOps)                           topk, sort, gather, min, max, pad
/// Through(RepeatNarrow)                            repeat
/// Loop(silu_block).For(2)                          SiLU + BatchNorm, Tag("refined")
/// Gate(SoftmaxRouter, Linear, Linear)              .Using("input")
/// Switch(HeavyPathSelector, Linear, ffn_block)     .Using("refined")
/// Through(StateAdd).Using("memory").Tag("memory")  forward ref
/// Loop(Linear).While(ThresholdHalt(100), 5)
/// Loop(Linear).Until(LearnedHalt(8), 7)
/// Through(LogSoftmaxReduce)                        log_softmax, sum_dim
/// Through(Linear 1->8)                             widen back
/// Split(Linear, Linear).TagGroup("final_heads") -> Add()
/// Through(Linear 8->2)                             output projection   Tag("output")
/// ```
fn build_showcase() -> flodl::Result<Graph> {
    const B: i64 = 2;  // batch size (>= 2 for BatchNorm training mode)
    const H: i64 = 8;  // hidden width used by most intermediate stages

    FlowBuilder::from(Linear::new(2, H)?)
        // input() declares auxiliary graph inputs — forward_multi receives them
        .input(&["ctx"])

        // Tag names a position in the stream for later reference via .using()
        .tag("input")

        // .through() chains modules sequentially: stream -> module -> stream
        .through(GELU)
        .through(LayerNorm::new(H)?)
        .through(RmsNorm::new())

        // ContextBlend is a NamedInputModule that reads the "ctx" auxiliary input
        .through(ContextBlend)
        .using(&["ctx"])

        // .fork() runs a side-branch: output can be tagged, main stream unchanged
        .fork(spectral_monitor(H)?)
        .tag("spectral")

        // .split() forks the stream into parallel branches, .merge() recombines.
        // modules![] is shorthand for vec![Box::new(...) as Box<dyn Module>, ...]
        .split(modules![read_head(H)?, read_head(H)?])
        .merge(MergeOp::Mean)

        // .also() adds a residual connection: output = stream + module(stream)
        .also(Linear::new(H, H)?)
        .through(Dropout::new(0.1))
        .through(SoftClamp::new(0.5, 3.0))
        .through(Softplus)

        // VarianceGate exercises mean/var/std/expand
        .through(VarianceGate::new(H))

        // .map().slices(n) decomposes [B,D] -> [B*n,D/n], applies body, recomposes
        .map(read_head(2)?)
        .slices(H / 2)

        // Reshape changes tensor dimensions without copying data
        .through(Reshape::new(&[B * 2, H / 2]))

        // .map().each() applies body independently to each element in a multi-stream
        .map(Linear::new(H / 2, H / 2)?)
        .each()
        .tag("halves")

        // .map().over(tag) iterates over a tagged tensor (backward ref) instead
        // of the current stream — useful for refining previously computed features
        .map(Linear::new(H / 2, H / 2)?)
        .over("halves")

        // .map().batched().each() — fast path: full batch in one call
        .map(Linear::new(H / 2, H / 2)?)
        .batched()
        .each()

        .through(Reshape::new(&[B, H]))
        .through(ShapeOps::new(B, H))
        .through(NegSigmoidGate)
        .through(TransposeRoundTrip)

        // CounterModule overrides reset() — loops auto-call it before iterating
        .through(CounterModule::new())

        // ChunkRecombine: chunk, relu (Variable op), cat
        .through(ChunkRecombine)

        // AttentionLikeOps: softmax, select, narrow, index_select
        .through(AttentionLikeOps::new(H))

        // TopKFilterOps: topk, sort, gather, min, max, pad
        .through(TopKFilterOps::new(H))

        // RepeatNarrow: repeat
        .through(RepeatNarrow::new(H))

        // .loop_body().for_n(n) repeats the body n times, feeding output back as input.
        // silu_block is a sub-graph (Graph implements Module) — graphs compose freely.
        .loop_body(silu_block(H)?)
        .for_n(2)
        .tag("refined")

        // .gate() is soft routing (mixture of experts): all experts run, router
        // produces weights, outputs are combined. .using() feeds the tagged "input"
        // stream to the router as a backward reference.
        .gate(
            SoftmaxRouter::new(H, 2)?,
            modules![Linear::new(H, H)?, Linear::new(H, H)?],
        )
        .using(&["input"])

        // .switch() is hard routing: router picks one branch, others are skipped.
        // HeavyPathSelector is a custom NamedInputModule — it receives the "refined"
        // ref via forward_named() and decides which branch to activate.
        .switch(
            HeavyPathSelector,
            modules![Linear::new(H, H)?, ffn_block(H)?],
        )
        .using(&["refined"])

        // Forward reference: .using("memory") reads a tag that doesn't exist yet —
        // the framework creates a state buffer. .tag("memory") writes to it.
        // On the first pass, the state is zero (pass-through). On subsequent passes,
        // StateAdd accumulates: stream + previous_memory.
        .through(StateAdd)
        .using(&["memory"])
        .tag("memory")

        // .while_cond() repeats until the halt module signals stop (or max iterations).
        // ThresholdHalt stops when the stream's L2 norm exceeds the threshold.
        .loop_body(Linear::new(H, H)?)
        .while_cond(ThresholdHalt::new(100.0), 5)

        // .until_cond() is the inverse: repeats until halt signals true.
        // LearnedHalt has trainable parameters — it learns when to stop.
        .loop_body(Linear::new(H, H)?)
        .until_cond(LearnedHalt::new(H)?, 7)

        // LogSoftmaxReduce collapses features to [B,1]; the next Linear widens back.
        .through(LogSoftmaxReduce)
        .through(Linear::new(1, H)?)

        // Split with tag_group: names each branch ("final_heads_0", "final_heads_1")
        .split(vec![
            Box::new(Linear::new(H, H)?),
            Box::new(Linear::new(H, H)?),
        ])
        .tag_group("final_heads")
        .merge(MergeOp::Add)

        // Final projection and output tag for observation
        .through(Linear::new(H, 2)?)
        .tag("output")
        .build()
}
653
654// ---------------------------------------------------------------------------
655// Helpers
656// ---------------------------------------------------------------------------
657
658fn make_input(requires_grad: bool) -> Variable {
659    let t = Tensor::from_f32(&[1.0, 2.0, 0.5, -1.0], &[2, 2], Device::CPU).unwrap();
660    Variable::new(t, requires_grad)
661}
662
663fn make_context() -> Variable {
664    let t = Tensor::from_f32(
665        &[0.5, -0.3, 0.8, 1.2, -0.5, 0.1, 0.9, -0.7,
666          0.2, 0.7, -0.4, 0.6, 1.0, -0.8, 0.3, -0.1],
667        &[2, 8],
668        Device::CPU,
669    ).unwrap();
670    Variable::new(t, false)
671}
672
673fn make_target() -> Variable {
674    let t = Tensor::from_f32(&[0.5, -0.5, -0.3, 0.8], &[2, 2], Device::CPU).unwrap();
675    Variable::new(t, false)
676}
677
/// Count parameters whose gradient exists and contains at least one
/// non-zero entry (i.e. actually received signal from backward()).
#[cfg(test)]
fn count_grads(params: &[flodl::Parameter]) -> usize {
    let mut with_signal = 0;
    for p in params {
        let nonzero = p.variable.grad()
            .and_then(|g| g.to_f32_vec().ok())
            .map(|d| d.iter().any(|v| *v != 0.0))
            .unwrap_or(false);
        if nonzero {
            with_signal += 1;
        }
    }
    with_signal
}
689
690// ---------------------------------------------------------------------------
691// Main: demo run
692// ---------------------------------------------------------------------------
693
/// End-to-end demo driver: build, forward, check state semantics, visualize,
/// train with observation + profiling, dump trends/artifacts, checkpoint
/// round-trip, and finish with a no_grad inference sanity check.
fn main() {
    println!("=== floDl showcase ===\n");

    // -- Build --
    println!("Building graph...");
    let g = build_showcase().expect("build failed");
    let n_params = g.parameters().len();
    println!("Parameters: {}", n_params);

    // -- Forward (with auxiliary input) --
    // Order matches .input(&["ctx"]): main stream first, then auxiliary inputs.
    let result = g.forward_multi(&[make_input(false), make_context()])
        .expect("forward failed");
    println!("Output: {:?} (shape {:?})", result.data().to_f32_vec().unwrap(), result.shape());

    // -- Forward ref carries state --
    // The "memory" state buffer accumulates, so two identical passes differ.
    g.reset_state();
    let r1 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
    let v1 = r1.data().to_f32_vec().unwrap();
    let r2 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
    let v2 = r2.data().to_f32_vec().unwrap();
    println!("State drift: pass2 differs = {}", v1 != v2);

    // -- Reset --
    // After reset_state(), the first pass should reproduce v1 exactly.
    g.reset_state();
    let r3 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
    let v3 = r3.data().to_f32_vec().unwrap();
    println!("Reset restores: {}", v1 == v3);

    // -- DOT + SVG (structural) --
    let dot = g.dot();
    println!("DOT: {} bytes", dot.len());

    // Write structural DOT
    let dot_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase.dot");
    std::fs::write(dot_path, &dot).expect("write showcase.dot");
    println!("Wrote {}", dot_path);

    // Write structural SVG
    let svg_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase.svg");
    let svg = g.svg(Some(svg_path)).expect("write showcase.svg");
    println!("Wrote {} ({} bytes)", svg_path, svg.len());

    // -- Training loop with observation + profiling + monitor --
    println!("\n--- Training (5 epochs x 4 steps) ---");
    g.train();
    g.reset_state();
    g.enable_profiling();

    let params = g.parameters();
    let mut optimizer = Adam::new(&params, 0.001);
    let num_epochs = 5;
    let total_steps = num_epochs * 4;
    let sched = CosineScheduler::new(0.001, 1e-5, total_steps);
    let mut monitor = Monitor::new(num_epochs);

    // Global step counter driving the cosine LR schedule.
    let mut step_idx = 0;
    for epoch in 0..num_epochs {
        let t = std::time::Instant::now();
        for _ in 0..4 {
            optimizer.zero_grad();
            let input = make_input(true);
            let ctx = make_context();
            let target = make_target();

            let pred = g.forward_multi(&[input, ctx]).unwrap();
            let loss = mse_loss(&pred, &target).unwrap();

            loss.backward().unwrap();
            clip_grad_norm(&params, 1.0).unwrap();
            optimizer.set_lr(sched.lr(step_idx));
            optimizer.step().unwrap();
            step_idx += 1;

            g.record_scalar("loss", loss.item().unwrap());
            // step_idx was just incremented — log the LR actually used this step.
            g.record_scalar("lr", sched.lr(step_idx - 1));
            g.end_step();
        }

        g.end_epoch();
        monitor.log(epoch, t.elapsed(), &g);
    }

    // -- Trends --
    let trend = g.trend("loss");
    println!(
        "\nLoss trend: {} epochs, slope={:.4}, improving={}",
        trend.len(),
        trend.slope(0),
        trend.improving(0),
    );

    // Timing trends use node IDs — pick the first tagged one
    let timing = g.timing_trend("input");
    println!(
        "Timing trend (input node): {} epochs, mean={:.1}us",
        timing.len(),
        timing.mean() * 1e6,
    );

    // -- Write profiling DOT + SVG --
    let profile_dot = g.dot_with_profile();
    let profile_dot_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_profile.dot");
    std::fs::write(profile_dot_path, &profile_dot).expect("write showcase_profile.dot");
    println!("Wrote {}", profile_dot_path);

    let profile_svg_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_profile.svg");
    let profile_svg = g.svg_with_profile(Some(profile_svg_path)).expect("write showcase_profile.svg");
    println!("Wrote {} ({} bytes)", profile_svg_path, profile_svg.len());

    // -- Write training HTML --
    let html_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_training.html");
    g.plot_html(html_path, &["loss"]).expect("write showcase_training.html");
    println!("Wrote {}", html_path);

    // -- Write training log --
    let log_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_training.log");
    g.write_log(log_path, 5, &["loss"]).expect("write showcase_training.log");
    println!("Wrote {}", log_path);

    // -- Checkpoint round-trip --
    let path = "/tmp/flodl_showcase_checkpoint.fdl";
    let named = g.named_parameters();
    let named_bufs = g.named_buffers();
    save_checkpoint_file(path, &named, &named_bufs, Some(g.structural_hash())).expect("save failed");
    let report = load_checkpoint_file(path, &named, &named_bufs, Some(g.structural_hash())).expect("load failed");
    println!("\nCheckpoint save/load: OK ({} loaded)", report.loaded.len());

    // -- no_grad inference (eval mode works now — BatchNorm has running stats from training) --
    g.eval();
    g.reset_state();
    let final_out = no_grad(|| g.forward_multi(&[make_input(false), make_context()])).unwrap();
    let final_vals = final_out.data().to_f32_vec().unwrap();
    println!("no_grad inference: {:?}", final_vals);
    assert!(final_vals.iter().all(|v| v.is_finite()), "no_grad output should be finite");

    println!("\nAll showcase checks passed.");
}
831
832// ---------------------------------------------------------------------------
833// Tests
834// ---------------------------------------------------------------------------
835
#[cfg(test)]
mod tests {
    use super::*;

    /// The full showcase graph should produce a 2x2 output (4 values).
    #[test]
    fn test_build() {
        let graph = build_showcase().unwrap();
        let out = graph.forward_multi(&[make_input(false), make_context()]).unwrap();
        let values = out.data().to_f32_vec().unwrap();
        assert_eq!(values.len(), 4, "expected 4 outputs (2x2), got {}", values.len());
    }

    /// Two consecutive passes must differ: the forward-ref edge carries
    /// state from one pass into the next.
    #[test]
    fn test_forward_ref_carries_state() {
        let graph = build_showcase().unwrap();

        let first_pass = graph.forward_multi(&[make_input(false), make_context()]).unwrap();
        let first = first_pass.data().to_f32_vec().unwrap();

        let second_pass = graph.forward_multi(&[make_input(false), make_context()]).unwrap();
        let second = second_pass.data().to_f32_vec().unwrap();

        assert_ne!(first, second, "pass 2 should differ from pass 1");
    }

    /// Resetting state should make the graph reproduce its first-pass output.
    #[test]
    fn test_reset_state() {
        let graph = build_showcase().unwrap();

        // One training-mode pass populates BatchNorm running stats; eval
        // mode then keeps forward passes deterministic (no stat updates).
        graph.forward_multi(&[make_input(false), make_context()]).unwrap();
        graph.eval();
        graph.reset_state();

        let baseline_out = graph.forward_multi(&[make_input(false), make_context()]).unwrap();
        let baseline = baseline_out.data().to_f32_vec().unwrap();

        // Advance the state with an extra pass, then wipe it again.
        graph.forward_multi(&[make_input(false), make_context()]).unwrap();

        graph.reset_state();
        let replay_out = graph.forward_multi(&[make_input(false), make_context()]).unwrap();
        let replay = replay_out.data().to_f32_vec().unwrap();

        assert_eq!(baseline, replay, "after reset should match pass 1");
    }

    /// detach_state must leave the graph usable for further passes.
    #[test]
    fn test_detach_state() {
        let graph = build_showcase().unwrap();

        graph.forward_multi(&[make_input(false), make_context()]).unwrap();
        graph.detach_state();

        let out = graph.forward_multi(&[make_input(false), make_context()]).unwrap();
        assert_eq!(out.data().to_f32_vec().unwrap().len(), 4);
    }

    /// A backward pass through the whole graph should reach parameters.
    #[test]
    fn test_backward() {
        let graph = build_showcase().unwrap();

        let out = graph.forward_multi(&[make_input(true), make_context()]).unwrap();
        let loss = out.sum().unwrap();
        loss.backward().unwrap();

        let grads = count_grads(&graph.parameters());
        assert!(grads > 0, "no parameters received gradients");
    }

    /// The extended graph carries a substantial parameter count.
    #[test]
    fn test_parameters() {
        let graph = build_showcase().unwrap();
        let param_count = graph.parameters().len();
        assert!(
            param_count > 44,
            "expected more than 44 params (extended graph), got {}",
            param_count
        );
    }

    /// Toggling training mode via set_training keeps outputs well-formed.
    #[test]
    fn test_set_training() {
        let graph = build_showcase().unwrap();

        // A single training pass seeds BatchNorm running stats so that
        // eval mode has something to normalize with.
        graph.forward_multi(&[make_input(false), make_context()]).unwrap();

        graph.set_training(false);
        graph.reset_state();
        let eval_out = graph.forward_multi(&[make_input(false), make_context()]).unwrap();

        // Back to training mode for the second pass.
        graph.set_training(true);
        graph.reset_state();
        let train_out = graph.forward_multi(&[make_input(false), make_context()]).unwrap();

        assert_eq!(eval_out.data().to_f32_vec().unwrap().len(), 4);
        assert_eq!(train_out.data().to_f32_vec().unwrap().len(), 4);
    }

    /// Structural DOT export should be non-empty and well-formed.
    #[test]
    fn test_dot() {
        let graph = build_showcase().unwrap();
        let dot = graph.dot();
        assert!(!dot.is_empty(), "DOT output is empty");
        assert!(dot.contains("digraph"), "DOT should contain digraph");
    }

    /// A short optimizer loop should keep the loss finite at every step.
    #[test]
    fn test_training_loop() {
        let graph = build_showcase().unwrap();
        graph.train();

        let params = graph.parameters();
        let mut optimizer = Adam::new(&params, 0.01);

        let mut losses = Vec::new();
        for _ in 0..3 {
            let input = make_input(true);
            let ctx = make_context();
            let target = make_target();

            let pred = graph.forward_multi(&[input, ctx]).unwrap();
            let loss = mse_loss(&pred, &target).unwrap();
            losses.push(loss.item().unwrap());

            loss.backward().unwrap();
            clip_grad_norm(&params, 1.0).unwrap();
            optimizer.step().unwrap();
            optimizer.zero_grad();
            graph.end_step();
        }

        for (step, loss_val) in losses.iter().enumerate() {
            assert!(loss_val.is_finite(), "loss at step {} is not finite: {}", step, loss_val);
        }
    }

    /// Tagged captures and manual scalar metrics should flow into trends.
    #[test]
    fn test_observation() {
        let graph = build_showcase().unwrap();

        // A forward pass captures all tagged nodes.
        let out = graph.forward_multi(&[make_input(false), make_context()]).unwrap();

        let tagged = graph.tagged("output");
        assert!(tagged.is_some(), "tagged 'output' not captured");
        assert_eq!(tagged.unwrap().shape(), &[2, 2]);

        // The output is a full tensor rather than a scalar, so sum it down
        // to one value and record that as the metric.
        let epoch1: f64 = out.data().to_f32_vec().unwrap().iter().map(|&v| f64::from(v)).sum();
        graph.record("test_loss", &[epoch1]);
        graph.flush(&["test_loss"]);
        assert_eq!(graph.flush_count(), 1);

        // Second epoch: record and flush once more.
        let out2 = graph.forward_multi(&[make_input(false), make_context()]).unwrap();
        let epoch2: f64 = out2.data().to_f32_vec().unwrap().iter().map(|&v| f64::from(v)).sum();
        graph.record("test_loss", &[epoch2]);
        graph.flush(&["test_loss"]);
        assert_eq!(graph.flush_count(), 2);

        // Both epochs should now be visible in the trend.
        let trend = graph.trend("test_loss");
        assert_eq!(trend.len(), 2, "expected 2 epochs in trend");
    }

    /// Profiling should record a positive timing for the input node.
    #[test]
    fn test_profiling() {
        let graph = build_showcase().unwrap();
        graph.enable_profiling();

        graph.forward_multi(&[make_input(false), make_context()]).unwrap();
        graph.collect_timings(&[]); // snapshot node timings into the buffer
        graph.flush_timings(&[]);   // push the buffer into epoch history

        let timing = graph.timing_trend("input");
        assert_eq!(timing.len(), 1, "expected 1 timing epoch");
        assert!(timing.latest() > 0.0, "timing should be positive");
    }

    /// Save -> mutate -> load must restore parameters exactly.
    #[test]
    fn test_checkpoint_roundtrip() {
        let graph = build_showcase().unwrap();
        let params = graph.parameters();
        let named = graph.named_parameters();

        // Seed BatchNorm running stats, then go deterministic via eval mode.
        graph.forward_multi(&[make_input(false), make_context()]).unwrap();
        graph.eval();
        graph.reset_state();

        // Snapshot everything to disk before touching the weights.
        let path = "/tmp/flodl_showcase_test_ckpt.fdl";
        let named_bufs = graph.named_buffers();
        save_checkpoint_file(path, &named, &named_bufs, Some(graph.structural_hash())).unwrap();

        let before = graph.forward_multi(&[make_input(false), make_context()]).unwrap();
        let baseline = before.data().to_f32_vec().unwrap();
        assert!(baseline.iter().all(|v| v.is_finite()), "pre-train output NaN");

        // Keep a copy of the first parameter tensor for direct comparison.
        let p0_original = params[0].variable.data().to_f32_vec().unwrap();

        // One optimizer step is enough to perturb the weights.
        graph.reset_state();
        graph.train();
        let pred = graph.forward_multi(&[make_input(true), make_context()]).unwrap();
        let loss = pred.sum().unwrap();
        loss.backward().unwrap();
        let mut opt = Adam::new(&params, 0.1);
        opt.step().unwrap();

        // The step must actually have changed something.
        let p0_mutated = params[0].variable.data().to_f32_vec().unwrap();
        assert_ne!(p0_original, p0_mutated, "training should change parameters");

        // Reload the checkpoint and confirm the weights round-tripped.
        let report = load_checkpoint_file(path, &named, &named_bufs, Some(graph.structural_hash())).unwrap();
        assert_eq!(report.loaded.len(), named.len());
        let p0_restored = params[0].variable.data().to_f32_vec().unwrap();
        assert_eq!(p0_original, p0_restored, "checkpoint restore should match original params");

        // Best-effort cleanup of the temp file.
        let _ = std::fs::remove_file(path);
    }

    /// Inference under no_grad still yields a full, finite output.
    #[test]
    fn test_no_grad() {
        let graph = build_showcase().unwrap();

        let out = no_grad(|| graph.forward_multi(&[make_input(true), make_context()])).unwrap();
        let values = out.data().to_f32_vec().unwrap();
        assert_eq!(values.len(), 4);
        assert!(values.iter().all(|v| v.is_finite()), "no_grad should produce finite values");
    }

    /// Exercises every visualization artifact: DOT, SVG, the profiled
    /// variants of both, the training HTML plot, and the training log.
    #[test]
    fn test_visualization() {
        let graph = build_showcase().unwrap();

        // Structural DOT output.
        let dot = graph.dot();
        assert!(dot.contains("digraph"), "DOT should contain digraph");
        assert!(dot.contains("#input"), "DOT should contain #input tag");

        let dot_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase.dot");
        std::fs::write(dot_path, &dot).unwrap();

        // Structural SVG output.
        let svg_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase.svg");
        let svg = graph.svg(Some(svg_path)).unwrap();
        assert!(svg.len() > 100, "SVG should have content");

        // A profiled forward pass feeds the timing-annotated DOT.
        graph.enable_profiling();
        graph.forward_multi(&[make_input(false), make_context()]).unwrap();

        let profile_dot = graph.dot_with_profile();
        assert!(profile_dot.contains("Forward:"), "profile DOT should show total time");

        let profile_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_profile.dot");
        std::fs::write(profile_path, &profile_dot).unwrap();

        // Profiled SVG output.
        let profile_svg_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_profile.svg");
        let profile_svg = graph.svg_with_profile(Some(profile_svg_path)).unwrap();
        assert!(profile_svg.len() > 100, "profile SVG should have content");

        // A short observed training run backs the HTML plot and the log.
        graph.train();
        graph.reset_state();
        let params = graph.parameters();
        let mut optimizer = Adam::new(&params, 0.01);

        for _ in 0..3 {
            for _ in 0..4 {
                optimizer.zero_grad();
                let pred = graph.forward_multi(&[make_input(true), make_context()]).unwrap();
                let loss = mse_loss(&pred, &make_target()).unwrap();
                loss.backward().unwrap();
                optimizer.step().unwrap();

                graph.record_scalar("loss", loss.item().unwrap());
                graph.end_step();
            }
            graph.end_epoch();
        }

        // Training curve as HTML.
        let html_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_training.html");
        graph.plot_html(html_path, &["loss"]).unwrap();

        // Plain-text training log.
        let log_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_training.log");
        graph.write_log(log_path, 3, &["loss"]).unwrap();

        // Every artifact must exist on disk with real content.
        assert!(std::fs::metadata(dot_path).unwrap().len() > 100);
        assert!(std::fs::metadata(svg_path).unwrap().len() > 100);
        assert!(std::fs::metadata(profile_path).unwrap().len() > 100);
        assert!(std::fs::metadata(profile_svg_path).unwrap().len() > 100);
        assert!(std::fs::metadata(html_path).unwrap().len() > 100);
        assert!(std::fs::metadata(log_path).unwrap().len() > 10);
    }

    /// Cosine schedule must decay from the initial LR toward min_lr.
    #[test]
    fn test_cosine_scheduler() {
        let sched = CosineScheduler::new(0.01, 1e-5, 10);

        let lr_first = sched.lr(0);
        let lr_last = sched.lr(10);

        assert!(lr_last < lr_first, "LR should decrease: {} -> {}", lr_first, lr_last);
        assert!((lr_last - 1e-5).abs() < 1e-4, "LR should reach min_lr");
    }

    /// The Fork branch is observable through its "spectral" tag.
    #[test]
    fn test_fork_tag() {
        let graph = build_showcase().unwrap();
        graph.forward_multi(&[make_input(false), make_context()]).unwrap();

        let spectral = graph.tagged("spectral");
        assert!(spectral.is_some(), "fork tag 'spectral' not captured");
    }
}