1use std::collections::HashMap;
39
40use flodl::{
41 Device, Tensor, Variable,
42 Module, NamedInputModule,
43 Linear, GELU, SiLU, LayerNorm, Dropout, BatchNorm,
44 FlowBuilder, MergeOp, Graph, modules,
45 SoftmaxRouter, ThresholdHalt, LearnedHalt,
46 Reshape, StateAdd,
47 Adam, Optimizer, mse_loss, clip_grad_norm,
48 save_checkpoint_file, load_checkpoint_file,
49 CosineScheduler,
50 no_grad,
51};
52use flodl::monitor::Monitor;
53
54fn ffn_block(dim: i64) -> flodl::Result<Graph> {
60 FlowBuilder::from(Linear::new(dim, dim)?)
61 .through(GELU)
62 .through(LayerNorm::new(dim)?)
63 .build()
64}
65
66fn read_head(dim: i64) -> flodl::Result<Graph> {
68 FlowBuilder::from(Linear::new(dim, dim)?)
69 .through(LayerNorm::new(dim)?)
70 .build()
71}
72
73fn silu_block(dim: i64) -> flodl::Result<Graph> {
75 FlowBuilder::from(Linear::new(dim, dim)?)
76 .through(SiLU)
77 .through(BatchNorm::new(dim)?)
78 .build()
79}
80
81struct RmsNorm {
88 eps: f64,
89}
90
91impl RmsNorm {
92 fn new() -> Self {
93 RmsNorm { eps: 1e-6 }
94 }
95}
96
97impl Module for RmsNorm {
98 fn name(&self) -> &str { "rmsnorm" }
99
100 fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
101 let sq = input.pow_scalar(2.0)?; let ms = sq.mean_dim(-1, true)?; let shifted = ms.add_scalar(self.eps)?; let rms = shifted.sqrt()?; input.div(&rms) }
107}
108
109struct SoftClamp {
112 scale: f64,
113 bound: f64,
114}
115
116impl SoftClamp {
117 fn new(scale: f64, bound: f64) -> Self {
118 SoftClamp { scale, bound }
119 }
120}
121
122impl Module for SoftClamp {
123 fn name(&self) -> &str { "softclamp" }
124
125 fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
126 let scaled = input.mul_scalar(self.scale)?; let clamped = scaled.clamp(-self.bound, self.bound)?; clamped.abs() }
130}
131
132struct Softplus;
135
136impl Module for Softplus {
137 fn name(&self) -> &str { "softplus" }
138
139 fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
140 let ex = input.exp()?; let shifted = ex.add_scalar(1.0)?; shifted.log() }
144}
145
146struct NegSigmoidGate;
149
150impl Module for NegSigmoidGate {
151 fn name(&self) -> &str { "neg_sigmoid_gate" }
152
153 fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
154 let negated = input.neg()?; let gate = negated.sigmoid()?; input.mul(&gate) }
158}
159
160struct ShapeOps {
165 batch: i64,
166 dim: i64,
167}
168
169impl ShapeOps {
170 fn new(batch: i64, dim: i64) -> Self {
171 ShapeOps { batch, dim }
172 }
173}
174
175impl Module for ShapeOps {
176 fn name(&self) -> &str { "shape_ops" }
177
178 fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
179 let flat = input.flatten(0, -1)?; let expanded = flat.unsqueeze(0)?; let squeezed = expanded.squeeze(0)?; squeezed.reshape(&[self.batch, self.dim]) }
184}
185
186struct LogSoftmaxReduce;
189
190impl Module for LogSoftmaxReduce {
191 fn name(&self) -> &str { "log_softmax_reduce" }
192
193 fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
194 let lsm = input.log_softmax(-1)?; lsm.sum_dim(-1, true) }
197}
198
199struct TransposeRoundTrip;
202
203impl Module for TransposeRoundTrip {
204 fn name(&self) -> &str { "transpose_rt" }
205
206 fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
207 let t = input.transpose(0, 1)?; t.permute(&[1, 0]) }
210}
211
212struct ContextBlend;
215
216impl Module for ContextBlend {
217 fn name(&self) -> &str { "context_blend" }
218
219 fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
220 Ok(input.clone())
221 }
222
223 fn as_named_input(&self) -> Option<&dyn NamedInputModule> {
224 Some(self)
225 }
226}
227
228impl NamedInputModule for ContextBlend {
229 fn forward_named(
230 &self,
231 input: &Variable,
232 refs: &HashMap<String, Variable>,
233 ) -> flodl::Result<Variable> {
234 let ctx = &refs["ctx"];
235 let scaled = ctx.div_scalar(2.0)?; let gate = scaled.sigmoid()?; let modulated = input.mul(&gate)?; modulated.add(input) }
240}
241
242struct SpectralBasis;
246
247impl Module for SpectralBasis {
248 fn name(&self) -> &str { "spectral_basis" }
249
250 fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
251 let s = input.sin()?; let c = input.cos()?; let sc = s.add(&c)?;
254 let r = sc.reciprocal()?; r.tanh() }
257}
258
259struct VarianceGate {
262 dim: i64,
263}
264
265impl VarianceGate {
266 fn new(dim: i64) -> Self {
267 VarianceGate { dim }
268 }
269}
270
271impl Module for VarianceGate {
272 fn name(&self) -> &str { "variance_gate" }
273
274 fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
275 let m = input.mean()?; let _v = input.var()?; let s = input.std()?; let gate_val = m.add(&s)?;
279 let gate = gate_val.expand(&[1, self.dim])?; input.mul(&gate) }
282}
283
284struct ChunkRecombine;
287
288impl Module for ChunkRecombine {
289 fn name(&self) -> &str { "chunk_recombine" }
290
291 fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
292 let chunks = input.chunk(2, -1)?; let a = chunks[0].relu()?; let b = chunks[1].neg()?;
295 a.cat(&b, -1) }
297}
298
299struct AttentionLikeOps {
302 dim: i64,
303}
304
305impl AttentionLikeOps {
306 fn new(dim: i64) -> Self {
307 AttentionLikeOps { dim }
308 }
309}
310
311impl Module for AttentionLikeOps {
312 fn name(&self) -> &str { "attention_ops" }
313
314 fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
315 let weights = input.softmax(-1)?; let row = input.select(0, 0)?; let row2d = row.unsqueeze(0)?;
320
321 let half_dim = self.dim / 2;
323 let first_half = row2d.narrow(-1, 0, half_dim)?; let idx = Tensor::from_i64(&[0, 1], &[2], Device::CPU)?;
327 let selected = first_half.index_select(-1, &idx)?; let scale = selected.mean()?; let scale_expanded = scale.expand(&[1, self.dim])?; weights.add(&scale_expanded)
333 }
334}
335
336struct TopKFilterOps {
339 dim: i64,
340}
341
342impl TopKFilterOps {
343 fn new(dim: i64) -> Self {
344 TopKFilterOps { dim }
345 }
346}
347
348impl Module for TopKFilterOps {
349 fn name(&self) -> &str { "topk_filter" }
350
351 fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
352 let (values, indices) = input.topk(4, -1, true, true)?; let (sorted, _sort_idx) = values.sort(-1, false)?; let gathered = input.gather(-1, &indices)?; let mn = gathered.min()?; let mx = gathered.max()?; let range = mx.sub(&mn)?;
365
366 let pad_amount = self.dim - 4;
368 let padded = sorted.pad(&[0, pad_amount], 0.0)?; padded.add(&range.expand(&[1, self.dim])?)
371 }
372}
373
374struct RepeatNarrow {
377 dim: i64,
378}
379
380impl RepeatNarrow {
381 fn new(dim: i64) -> Self {
382 RepeatNarrow { dim }
383 }
384}
385
386impl Module for RepeatNarrow {
387 fn name(&self) -> &str { "repeat_narrow" }
388
389 fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
390 let repeated = input.repeat(&[1, 2])?; repeated.narrow(-1, 0, self.dim) }
393}
394
395struct CounterModule {
398 count: std::cell::Cell<u32>,
399}
400
401impl CounterModule {
402 fn new() -> Self {
403 CounterModule { count: std::cell::Cell::new(0) }
404 }
405}
406
407impl Module for CounterModule {
408 fn name(&self) -> &str { "counter" }
409
410 fn forward(&self, input: &Variable) -> flodl::Result<Variable> {
411 self.count.set(self.count.get() + 1);
412 Ok(input.clone())
413 }
414
415 fn reset(&self) {
416 self.count.set(0);
417 }
418}
419
420struct HeavyPathSelector;
427
428impl Module for HeavyPathSelector {
429 fn name(&self) -> &str { "heavy_path_selector" }
430
431 fn forward(&self, _input: &Variable) -> flodl::Result<Variable> {
432 let t = Tensor::from_f32(&[0.0], &[1], Device::CPU)?;
433 Ok(Variable::new(t, false))
434 }
435
436 fn as_named_input(&self) -> Option<&dyn NamedInputModule> {
437 Some(self)
438 }
439}
440
441impl NamedInputModule for HeavyPathSelector {
442 fn forward_named(
443 &self,
444 _input: &Variable,
445 refs: &HashMap<String, Variable>,
446 ) -> flodl::Result<Variable> {
447 let refined = &refs["refined"];
448 let data = refined.data().to_f32_vec()?;
449 let max_val = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
450
451 let branch = if max_val > 5.0 { 1.0_f32 } else { 0.0 };
452 let t = Tensor::from_f32(&[branch], &[1], Device::CPU)?;
453 Ok(Variable::new(t, false))
454 }
455}
456
457fn spectral_monitor(dim: i64) -> flodl::Result<Graph> {
464 FlowBuilder::from(SpectralBasis)
465 .through(Linear::new(dim, dim)?)
466 .build()
467}
468
/// Builds the full demo graph: a single FlowBuilder chain exercising the
/// library's combinators (fork/split/map/loop/gate/switch/state) end to end.
/// Takes the main 2-wide input plus a named "ctx" input; emits a 2-wide output.
fn build_showcase() -> flodl::Result<Graph> {
    // B: batch size, H: hidden width used throughout the graph.
    const B: i64 = 2; const H: i64 = 8;

    FlowBuilder::from(Linear::new(2, H)?)
        // Declare the named side input "ctx", consumed by ContextBlend below.
        .input(&["ctx"])

        .tag("input")

        .through(GELU)
        .through(LayerNorm::new(H)?)
        .through(RmsNorm::new())

        // Modulate the stream by the named "ctx" reference.
        .through(ContextBlend)
        .using(&["ctx"])

        // Side branch whose output is captured under the "spectral" tag
        // (checked by test_fork_tag).
        .fork(spectral_monitor(H)?)
        .tag("spectral")

        // Parallel read heads, averaged back together.
        .split(modules![read_head(H)?, read_head(H)?])
        .merge(MergeOp::Mean)

        .also(Linear::new(H, H)?)
        .through(Dropout::new(0.1))
        .through(SoftClamp::new(0.5, 3.0))
        .through(Softplus)

        .through(VarianceGate::new(H))

        // Map a width-2 read head over slices of width H/2.
        .map(read_head(2)?)
        .slices(H / 2)

        .through(Reshape::new(&[B * 2, H / 2]))

        // Three map variants over the halves: per-element (tagged), over a
        // named tag group, and batched per-element.
        .map(Linear::new(H / 2, H / 2)?)
        .each()
        .tag("halves")

        .map(Linear::new(H / 2, H / 2)?)
        .over("halves")

        .map(Linear::new(H / 2, H / 2)?)
        .batched()
        .each()

        .through(Reshape::new(&[B, H]))
        .through(ShapeOps::new(B, H))
        .through(NegSigmoidGate)
        .through(TransposeRoundTrip)

        .through(CounterModule::new())

        .through(ChunkRecombine)

        .through(AttentionLikeOps::new(H))

        .through(TopKFilterOps::new(H))

        .through(RepeatNarrow::new(H))

        // Fixed-count refinement loop; result captured under "refined".
        .loop_body(silu_block(H)?)
        .for_n(2)
        .tag("refined")

        // Soft routing between two linear experts, keyed on the "input" tag.
        .gate(
            SoftmaxRouter::new(H, 2)?,
            modules![Linear::new(H, H)?, Linear::new(H, H)?],
        )
        .using(&["input"])

        // Hard branch selection driven by the "refined" activation
        // (HeavyPathSelector picks arm 0 or 1 at runtime).
        .switch(
            HeavyPathSelector,
            modules![Linear::new(H, H)?, ffn_block(H)?],
        )
        .using(&["refined"])

        // Recurrent state: adds the previous pass's "memory" value and
        // re-tags the result — this is what makes consecutive forward
        // passes differ (see test_forward_ref_carries_state).
        .through(StateAdd)
        .using(&["memory"])
        .tag("memory")

        // Bounded conditional loops with two halting strategies.
        .loop_body(Linear::new(H, H)?)
        .while_cond(ThresholdHalt::new(100.0), 5)

        .loop_body(Linear::new(H, H)?)
        .until_cond(LearnedHalt::new(H)?, 7)

        // Collapse to width 1, then project back out to H.
        .through(LogSoftmaxReduce)
        .through(Linear::new(1, H)?)

        .split(vec![
            Box::new(Linear::new(H, H)?),
            Box::new(Linear::new(H, H)?),
        ])
        .tag_group("final_heads")
        .merge(MergeOp::Add)

        // Final projection to the 2-wide output, tagged for observation.
        .through(Linear::new(H, 2)?)
        .tag("output")
        .build()
}
653
654fn make_input(requires_grad: bool) -> Variable {
659 let t = Tensor::from_f32(&[1.0, 2.0, 0.5, -1.0], &[2, 2], Device::CPU).unwrap();
660 Variable::new(t, requires_grad)
661}
662
663fn make_context() -> Variable {
664 let t = Tensor::from_f32(
665 &[0.5, -0.3, 0.8, 1.2, -0.5, 0.1, 0.9, -0.7,
666 0.2, 0.7, -0.4, 0.6, 1.0, -0.8, 0.3, -0.1],
667 &[2, 8],
668 Device::CPU,
669 ).unwrap();
670 Variable::new(t, false)
671}
672
673fn make_target() -> Variable {
674 let t = Tensor::from_f32(&[0.5, -0.5, -0.3, 0.8], &[2, 2], Device::CPU).unwrap();
675 Variable::new(t, false)
676}
677
/// Counts how many parameters received a nonzero gradient.
#[cfg(test)]
fn count_grads(params: &[flodl::Parameter]) -> usize {
    params
        .iter()
        .filter(|p| {
            // A parameter counts only if its gradient exists, is readable,
            // and contains at least one nonzero element.
            match p.variable.grad().and_then(|g| g.to_f32_vec().ok()) {
                Some(values) => values.iter().any(|v| *v != 0.0),
                None => false,
            }
        })
        .count()
}
689
690fn main() {
695 println!("=== floDl showcase ===\n");
696
697 println!("Building graph...");
699 let g = build_showcase().expect("build failed");
700 let n_params = g.parameters().len();
701 println!("Parameters: {}", n_params);
702
703 let result = g.forward_multi(&[make_input(false), make_context()])
705 .expect("forward failed");
706 println!("Output: {:?} (shape {:?})", result.data().to_f32_vec().unwrap(), result.shape());
707
708 g.reset_state();
710 let r1 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
711 let v1 = r1.data().to_f32_vec().unwrap();
712 let r2 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
713 let v2 = r2.data().to_f32_vec().unwrap();
714 println!("State drift: pass2 differs = {}", v1 != v2);
715
716 g.reset_state();
718 let r3 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
719 let v3 = r3.data().to_f32_vec().unwrap();
720 println!("Reset restores: {}", v1 == v3);
721
722 let dot = g.dot();
724 println!("DOT: {} bytes", dot.len());
725
726 let dot_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase.dot");
728 std::fs::write(dot_path, &dot).expect("write showcase.dot");
729 println!("Wrote {}", dot_path);
730
731 let svg_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase.svg");
733 let svg = g.svg(Some(svg_path)).expect("write showcase.svg");
734 println!("Wrote {} ({} bytes)", svg_path, svg.len());
735
736 println!("\n--- Training (5 epochs x 4 steps) ---");
738 g.train();
739 g.reset_state();
740 g.enable_profiling();
741
742 let params = g.parameters();
743 let mut optimizer = Adam::new(¶ms, 0.001);
744 let num_epochs = 5;
745 let total_steps = num_epochs * 4;
746 let sched = CosineScheduler::new(0.001, 1e-5, total_steps);
747 let mut monitor = Monitor::new(num_epochs);
748
749 let mut step_idx = 0;
750 for epoch in 0..num_epochs {
751 let t = std::time::Instant::now();
752 for _ in 0..4 {
753 optimizer.zero_grad();
754 let input = make_input(true);
755 let ctx = make_context();
756 let target = make_target();
757
758 let pred = g.forward_multi(&[input, ctx]).unwrap();
759 let loss = mse_loss(&pred, &target).unwrap();
760
761 loss.backward().unwrap();
762 clip_grad_norm(¶ms, 1.0).unwrap();
763 optimizer.set_lr(sched.lr(step_idx));
764 optimizer.step().unwrap();
765 step_idx += 1;
766
767 g.record_scalar("loss", loss.item().unwrap());
768 g.record_scalar("lr", sched.lr(step_idx - 1));
769 g.end_step();
770 }
771
772 g.end_epoch();
773 monitor.log(epoch, t.elapsed(), &g);
774 }
775
776 let trend = g.trend("loss");
778 println!(
779 "\nLoss trend: {} epochs, slope={:.4}, improving={}",
780 trend.len(),
781 trend.slope(0),
782 trend.improving(0),
783 );
784
785 let timing = g.timing_trend("input");
787 println!(
788 "Timing trend (input node): {} epochs, mean={:.1}us",
789 timing.len(),
790 timing.mean() * 1e6,
791 );
792
793 let profile_dot = g.dot_with_profile();
795 let profile_dot_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_profile.dot");
796 std::fs::write(profile_dot_path, &profile_dot).expect("write showcase_profile.dot");
797 println!("Wrote {}", profile_dot_path);
798
799 let profile_svg_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_profile.svg");
800 let profile_svg = g.svg_with_profile(Some(profile_svg_path)).expect("write showcase_profile.svg");
801 println!("Wrote {} ({} bytes)", profile_svg_path, profile_svg.len());
802
803 let html_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_training.html");
805 g.plot_html(html_path, &["loss"]).expect("write showcase_training.html");
806 println!("Wrote {}", html_path);
807
808 let log_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_training.log");
810 g.write_log(log_path, 5, &["loss"]).expect("write showcase_training.log");
811 println!("Wrote {}", log_path);
812
813 let path = "/tmp/flodl_showcase_checkpoint.fdl";
815 let named = g.named_parameters();
816 let named_bufs = g.named_buffers();
817 save_checkpoint_file(path, &named, &named_bufs, Some(g.structural_hash())).expect("save failed");
818 let report = load_checkpoint_file(path, &named, &named_bufs, Some(g.structural_hash())).expect("load failed");
819 println!("\nCheckpoint save/load: OK ({} loaded)", report.loaded.len());
820
821 g.eval();
823 g.reset_state();
824 let final_out = no_grad(|| g.forward_multi(&[make_input(false), make_context()])).unwrap();
825 let final_vals = final_out.data().to_f32_vec().unwrap();
826 println!("no_grad inference: {:?}", final_vals);
827 assert!(final_vals.iter().all(|v| v.is_finite()), "no_grad output should be finite");
828
829 println!("\nAll showcase checks passed.");
830}
831
#[cfg(test)]
mod tests {
    use super::*;

    // Graph builds and a forward pass yields a 2x2 (4-element) output.
    #[test]
    fn test_build() {
        let g = build_showcase().unwrap();
        let result = g.forward_multi(&[make_input(false), make_context()]).unwrap();
        let vals = result.data().to_f32_vec().unwrap();
        assert_eq!(vals.len(), 4, "expected 4 outputs (2x2), got {}", vals.len());
    }

    // The StateAdd/"memory" path makes consecutive passes differ.
    #[test]
    fn test_forward_ref_carries_state() {
        let g = build_showcase().unwrap();

        let r1 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
        let v1 = r1.data().to_f32_vec().unwrap();

        let r2 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
        let v2 = r2.data().to_f32_vec().unwrap();

        assert_ne!(v1, v2, "pass 2 should differ from pass 1");
    }

    // reset_state restores the graph to its initial (pass-1) behavior.
    // eval() is used so stochastic layers (Dropout) do not mask the check.
    #[test]
    fn test_reset_state() {
        let g = build_showcase().unwrap();

        g.forward_multi(&[make_input(false), make_context()]).unwrap();
        g.eval();
        g.reset_state();

        let r1 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
        let v1 = r1.data().to_f32_vec().unwrap();

        g.forward_multi(&[make_input(false), make_context()]).unwrap();

        g.reset_state();
        let r3 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
        let v3 = r3.data().to_f32_vec().unwrap();

        assert_eq!(v1, v3, "after reset should match pass 1");
    }

    // detach_state must not break subsequent forward passes.
    #[test]
    fn test_detach_state() {
        let g = build_showcase().unwrap();

        g.forward_multi(&[make_input(false), make_context()]).unwrap();
        g.detach_state();

        let result = g.forward_multi(&[make_input(false), make_context()]).unwrap();
        assert_eq!(result.data().to_f32_vec().unwrap().len(), 4);
    }

    // Backprop through the whole graph reaches at least some parameters.
    #[test]
    fn test_backward() {
        let g = build_showcase().unwrap();

        let result = g.forward_multi(&[make_input(true), make_context()]).unwrap();
        let loss = result.sum().unwrap();
        loss.backward().unwrap();

        let with_grad = count_grads(&g.parameters());
        assert!(with_grad > 0, "no parameters received gradients");
    }

    // Sanity-check the parameter count of the extended graph.
    #[test]
    fn test_parameters() {
        let g = build_showcase().unwrap();
        let params = g.parameters();
        assert!(
            params.len() > 44,
            "expected more than 44 params (extended graph), got {}",
            params.len()
        );
    }

    // Toggling training mode must not change the output shape.
    #[test]
    fn test_set_training() {
        let g = build_showcase().unwrap();

        g.forward_multi(&[make_input(false), make_context()]).unwrap();

        g.set_training(false);
        g.reset_state();
        let r1 = g.forward_multi(&[make_input(false), make_context()]).unwrap();

        g.set_training(true);
        g.reset_state();
        let r2 = g.forward_multi(&[make_input(false), make_context()]).unwrap();

        assert_eq!(r1.data().to_f32_vec().unwrap().len(), 4);
        assert_eq!(r2.data().to_f32_vec().unwrap().len(), 4);
    }

    // DOT export is non-empty and well-formed.
    #[test]
    fn test_dot() {
        let g = build_showcase().unwrap();
        let dot = g.dot();
        assert!(!dot.is_empty(), "DOT output is empty");
        assert!(dot.contains("digraph"), "DOT should contain digraph");
    }

    // A short train loop produces finite losses at every step.
    #[test]
    fn test_training_loop() {
        let g = build_showcase().unwrap();
        g.train();

        let params = g.parameters();
        let mut optimizer = Adam::new(&params, 0.01);

        let mut losses = Vec::new();
        for _ in 0..3 {
            let input = make_input(true);
            let ctx = make_context();
            let target = make_target();

            let pred = g.forward_multi(&[input, ctx]).unwrap();
            let loss = mse_loss(&pred, &target).unwrap();
            losses.push(loss.item().unwrap());

            loss.backward().unwrap();
            clip_grad_norm(&params, 1.0).unwrap();
            optimizer.step().unwrap();
            optimizer.zero_grad();
            g.end_step();
        }

        for (i, &l) in losses.iter().enumerate() {
            assert!(l.is_finite(), "loss at step {} is not finite: {}", i, l);
        }
    }

    // Tag capture, scalar recording, flush counting, and trend length.
    #[test]
    fn test_observation() {
        let g = build_showcase().unwrap();

        let out = g.forward_multi(&[make_input(false), make_context()]).unwrap();

        let tagged = g.tagged("output");
        assert!(tagged.is_some(), "tagged 'output' not captured");
        assert_eq!(tagged.unwrap().shape(), &[2, 2]);

        let loss_val = out.data().to_f32_vec().unwrap().iter().map(|v| *v as f64).sum::<f64>();
        g.record("test_loss", &[loss_val]);
        g.flush(&["test_loss"]);
        assert_eq!(g.flush_count(), 1);

        let out2 = g.forward_multi(&[make_input(false), make_context()]).unwrap();
        let loss_val2 = out2.data().to_f32_vec().unwrap().iter().map(|v| *v as f64).sum::<f64>();
        g.record("test_loss", &[loss_val2]);
        g.flush(&["test_loss"]);
        assert_eq!(g.flush_count(), 2);

        let trend = g.trend("test_loss");
        assert_eq!(trend.len(), 2, "expected 2 epochs in trend");
    }

    // Profiling timings are collected and positive after one pass.
    #[test]
    fn test_profiling() {
        let g = build_showcase().unwrap();
        g.enable_profiling();

        g.forward_multi(&[make_input(false), make_context()]).unwrap();
        g.collect_timings(&[]);
        g.flush_timings(&[]);
        let timing = g.timing_trend("input");
        assert_eq!(timing.len(), 1, "expected 1 timing epoch");
        assert!(timing.latest() > 0.0, "timing should be positive");
    }

    // Save, train (params change), load, verify params are restored.
    #[test]
    fn test_checkpoint_roundtrip() {
        let g = build_showcase().unwrap();
        let params = g.parameters();
        let named = g.named_parameters();

        g.forward_multi(&[make_input(false), make_context()]).unwrap();
        g.eval();
        g.reset_state();

        let path = "/tmp/flodl_showcase_test_ckpt.fdl";
        let named_bufs = g.named_buffers();
        save_checkpoint_file(path, &named, &named_bufs, Some(g.structural_hash())).unwrap();

        let before = g.forward_multi(&[make_input(false), make_context()]).unwrap();
        let v_before = before.data().to_f32_vec().unwrap();
        assert!(v_before.iter().all(|v| v.is_finite()), "pre-train output NaN");

        let p0_before = params[0].variable.data().to_f32_vec().unwrap();

        // One aggressive optimizer step to guarantee the parameters move.
        g.reset_state();
        g.train();
        let pred = g.forward_multi(&[make_input(true), make_context()]).unwrap();
        let loss = pred.sum().unwrap();
        loss.backward().unwrap();
        let mut opt = Adam::new(&params, 0.1);
        opt.step().unwrap();

        let p0_after = params[0].variable.data().to_f32_vec().unwrap();
        assert_ne!(p0_before, p0_after, "training should change parameters");

        let report = load_checkpoint_file(path, &named, &named_bufs, Some(g.structural_hash())).unwrap();
        assert_eq!(report.loaded.len(), named.len());
        let p0_restored = params[0].variable.data().to_f32_vec().unwrap();
        assert_eq!(p0_before, p0_restored, "checkpoint restore should match original params");

        let _ = std::fs::remove_file(path);
    }

    // no_grad inference still produces a full, finite output.
    #[test]
    fn test_no_grad() {
        let g = build_showcase().unwrap();

        let result = no_grad(|| g.forward_multi(&[make_input(true), make_context()])).unwrap();
        let vals = result.data().to_f32_vec().unwrap();
        assert_eq!(vals.len(), 4);
        assert!(vals.iter().all(|v| v.is_finite()), "no_grad should produce finite values");
    }

    // Exercises every visualization/reporting artifact end to end.
    #[test]
    fn test_visualization() {
        let g = build_showcase().unwrap();

        let dot = g.dot();
        assert!(dot.contains("digraph"), "DOT should contain digraph");
        assert!(dot.contains("#input"), "DOT should contain #input tag");

        let dot_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase.dot");
        std::fs::write(dot_path, &dot).unwrap();

        let svg_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase.svg");
        let svg = g.svg(Some(svg_path)).unwrap();
        assert!(svg.len() > 100, "SVG should have content");

        // A profiled pass is needed before dot_with_profile has timing data.
        g.enable_profiling();
        g.forward_multi(&[make_input(false), make_context()]).unwrap();

        let profile_dot = g.dot_with_profile();
        assert!(profile_dot.contains("Forward:"), "profile DOT should show total time");

        let profile_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_profile.dot");
        std::fs::write(profile_path, &profile_dot).unwrap();

        let profile_svg_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_profile.svg");
        let profile_svg = g.svg_with_profile(Some(profile_svg_path)).unwrap();
        assert!(profile_svg.len() > 100, "profile SVG should have content");

        // Short training run to populate the loss history for plot/log output.
        g.train();
        g.reset_state();
        let params = g.parameters();
        let mut optimizer = Adam::new(&params, 0.01);

        for _epoch in 0..3 {
            for _ in 0..4 {
                optimizer.zero_grad();
                let pred = g.forward_multi(&[make_input(true), make_context()]).unwrap();
                let loss = mse_loss(&pred, &make_target()).unwrap();
                loss.backward().unwrap();
                optimizer.step().unwrap();

                g.record_scalar("loss", loss.item().unwrap());
                g.end_step();
            }
            g.end_epoch();
        }

        let html_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_training.html");
        g.plot_html(html_path, &["loss"]).unwrap();

        let log_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/showcase/showcase_training.log");
        g.write_log(log_path, 3, &["loss"]).unwrap();

        assert!(std::fs::metadata(dot_path).unwrap().len() > 100);
        assert!(std::fs::metadata(svg_path).unwrap().len() > 100);
        assert!(std::fs::metadata(profile_path).unwrap().len() > 100);
        assert!(std::fs::metadata(profile_svg_path).unwrap().len() > 100);
        assert!(std::fs::metadata(html_path).unwrap().len() > 100);
        assert!(std::fs::metadata(log_path).unwrap().len() > 10);
    }

    // The cosine schedule decays from max LR toward min LR.
    #[test]
    fn test_cosine_scheduler() {
        let sched = CosineScheduler::new(0.01, 1e-5, 10);

        let lr_start = sched.lr(0);
        let lr_end = sched.lr(10);

        assert!(lr_end < lr_start, "LR should decrease: {} -> {}", lr_start, lr_end);
        assert!((lr_end - 1e-5).abs() < 1e-4, "LR should reach min_lr");
    }

    // The forked branch's "spectral" tag is captured during forward.
    #[test]
    fn test_fork_tag() {
        let g = build_showcase().unwrap();
        g.forward_multi(&[make_input(false), make_context()]).unwrap();

        let spectral = g.tagged("spectral");
        assert!(spectral.is_some(), "fork tag 'spectral' not captured");
    }
}