
flodl/graph/mod.rs

//! Computation graph: fluent builder, parallel execution, observation, profiling,
//! visualization, and hierarchical composition.
//!
//! Build graphs with [`FlowBuilder`], execute via the [`Module`] trait.
//! Label subgraphs for tree features: selective freeze/thaw, subgraph
//! checkpoints, cross-boundary observation, and per-subgraph optimizer groups.
//!
//! ```ignore
//! let encoder = FlowBuilder::from(Linear::new(4, 8)?)
//!     .through(GELU)
//!     .label("encoder")
//!     .build()?;
//!
//! let model = FlowBuilder::from(encoder)
//!     .through(Linear::new(8, 2)?)
//!     .build()?;
//!
//! let y = model.forward(&x)?;
//! model.freeze("encoder")?;  // freeze by label path
//! ```

pub mod node;
pub mod flow;
pub mod loop_node;
pub mod switch;
pub mod gate;
pub mod map;
pub mod observe;
pub mod trend;
pub mod profile;
pub mod dot;
pub mod plot;
pub mod router;
pub mod halt;
pub mod reshape;
pub mod state;
pub mod snapshot;
pub mod tree;
pub mod verbose;

use std::cell::{Cell, OnceCell, RefCell};
use std::collections::{BTreeSet, HashMap, HashSet};
use std::rc::Rc;
use std::time::Instant;

use indexmap::IndexMap;
use hmac_sha256::Hash as Sha256;

use node::*;
use crate::autograd::Variable;
use crate::nn::{Buffer, Module, Parameter};
use crate::tensor::{Result, Tensor, TensorError};

pub use flow::FlowBuilder;
pub use loop_node::LoopBuilder;
pub use map::MapBuilder;
pub use trend::{Trend, TrendGroup};
pub use profile::{Profile, NodeTiming, LevelTiming};
pub use plot::format_duration;
pub use router::{SoftmaxRouter, SigmoidRouter, FixedSelector, ArgmaxSelector};
pub use halt::{ThresholdHalt, LearnedHalt};
pub use reshape::Reshape;
pub use state::StateAdd;
pub use observe::Reduce;
pub use tree::PathKind;
pub use snapshot::ModelSnapshot;

/// Merge operation for combining split branches.
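///
/// A minimal sketch of the split/merge pattern; the branch modules here are
/// placeholders (see `test_split_merge_mean` in the tests below for a
/// concrete version):
///
/// ```ignore
/// let g = FlowBuilder::from(stem)
///     .split(vec![Box::new(branch_a), Box::new(branch_b)])
///     .merge(MergeOp::Mean)
///     .build()?;
/// ```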
pub enum MergeOp {
    /// Element-wise sum of all branches.
    Add,
    /// Element-wise mean of all branches.
    Mean,
}

/// Pre-computed route from one node's output port to another node's input port.
/// Replaces HashMap-based edge routing in forward_impl for O(1) access.
#[derive(Clone)]
struct Route {
    from_port_idx: usize,
    to_node_idx: usize,
    to_port_idx: usize,
}

/// Pre-computed graph input → node input slot mapping.
struct InputRoute {
    node_idx: usize,
    port_idx: usize,
}

/// Forward-reference state buffer. Persists across `forward()` calls.
struct StateEntry {
    writer_ni: usize,
    value: Rc<RefCell<Option<Variable>>>,
}

/// An executable computation graph. Implements `Module` for composability.
///
/// Built via [`FlowBuilder`]. Supports parallel execution of independent nodes,
/// observation of tagged outputs, profiling, and DOT/SVG visualization.
///
/// ```ignore
/// let g = FlowBuilder::from(Linear::new(4, 8)?)
///     .through(GELU)
///     .through(Linear::new(8, 2)?)
///     .build()?;
///
/// // Forward pass (graph implements Module)
/// let y = g.forward(&x)?;
///
/// // Observation
/// g.end_step();
/// g.end_epoch();
/// let loss_trend = g.trend("loss");
///
/// // Visualization
/// let dot = g.dot();
/// g.svg(Some("graph.svg"))?;
/// ```
pub struct Graph {
    nodes: Vec<Node>,
    node_index: HashMap<String, usize>,
    levels: Vec<Vec<usize>>,
    edges: Vec<Edge>,
    #[allow(dead_code)] // kept for DOT/debug introspection
    edges_from: HashMap<usize, Vec<usize>>,
    inputs: Vec<ExposedPort>,
    outputs: Vec<ExposedPort>,
    order: Vec<usize>,
    state: Vec<StateEntry>,
    // State writer lookup: node_idx → [(state_entry_idx, output_port_idx)]
    state_writers: HashMap<usize, Vec<(usize, usize)>>,
    // Tag groups: group name → suffixed tag names
    tag_groups: HashMap<String, Vec<String>>,
    // Observation: tag mapping (immutable after build)
    tag_names: HashMap<String, (usize, usize)>,           // tag name → (node_idx, port_idx)
    tag_capture: HashMap<usize, Vec<(String, usize)>>,     // node_idx → [(tag_name, port_idx)]
    // Observation: mutable state (RefCell/Cell for &self methods)
    tagged_outputs: RefCell<HashMap<String, Variable>>,
    batch_buffer: RefCell<HashMap<String, Vec<f64>>>,
    epoch_history: RefCell<HashMap<String, Vec<f64>>>,
    metric_order: RefCell<Vec<String>>,
    flush_count: Cell<usize>,
    // Profiling
    profiling: Cell<bool>,
    last_profile: RefCell<Option<profile::Profile>>,
    timing_buffer: RefCell<HashMap<String, Vec<f64>>>,
    timing_history: RefCell<HashMap<String, Vec<f64>>>,
    // Flush timestamps (seconds since first forward — for ETA in write_log)
    flush_times: RefCell<Vec<f64>>,
    training_start: Cell<f64>,
    // Step/epoch counters
    step_count: Cell<usize>,
    epoch_count: Cell<usize>,
    // Identity: label + structural hash
    label: Option<String>,
    structural_hash_cache: OnceCell<String>,
    // Graph tree: hierarchical composition
    children: HashMap<String, usize>,
    composed: Cell<bool>,
    internal_tags: HashSet<String>,
    // Pre-computed execution plan (built once, used every forward call)
    routes_from: Vec<Vec<Route>>,
    input_routes: Vec<InputRoute>,
    output_node_idx: usize,
    output_port_idx: usize,
    node_input_count: Vec<usize>,
    // Cached execution buffers (reused across forward calls, avoids re-allocation)
    exec_slots: RefCell<Vec<Vec<Option<Variable>>>>,
}

impl Graph {
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn build(
        mut node_map: IndexMap<String, Node>,
        edges: Vec<Edge>,
        inputs: Vec<ExposedPort>,
        outputs: Vec<ExposedPort>,
        tags: HashMap<String, NodeRef>,
        forward_refs: Vec<ForwardRefSpec>,
        tag_groups: HashMap<String, Vec<String>>,
        label: Option<String>,
        mut internal_tags: HashSet<String>,
        verbose: bool,
    ) -> Result<Self> {
        // Set up forward-reference state buffers and wire state read nodes
        let mut state = Vec::with_capacity(forward_refs.len());
        for fr in &forward_refs {
            let value: Rc<RefCell<Option<Variable>>> = Rc::new(RefCell::new(None));
            let reader_value = value.clone();

            // Wire the state read node to return the buffer value
            if let Some(node) = node_map.get_mut(&fr.reader_id) {
                node.run = Box::new(move |_: &[Variable]| {
                    match reader_value.borrow().as_ref() {
                        Some(v) => Ok(vec![v.clone()]),
                        None => Ok(vec![]), // empty = no state yet
                    }
                });
            }

            state.push(StateEntry {
                writer_ni: 0, // resolved after node indexing
                value,
            });
        }

        // Convert to indexed storage
        let mut nodes = Vec::with_capacity(node_map.len());
        let mut node_index = HashMap::with_capacity(node_map.len());

        for (_key, node) in node_map {
            let idx = nodes.len();
            node_index.insert(node.id.clone(), idx);
            nodes.push(node);
        }

        // Validate edges
        for edge in &edges {
            if !node_index.contains_key(&edge.from_node) {
                return Err(TensorError::new(&format!(
                    "unknown source node: {}",
                    edge.from_node
                )));
            }
            if !node_index.contains_key(&edge.to_node) {
                return Err(TensorError::new(&format!(
                    "unknown target node: {}",
                    edge.to_node
                )));
            }
        }

        // Build edges_from lookup
        let mut edges_from: HashMap<usize, Vec<usize>> = HashMap::new();
        for (ei, edge) in edges.iter().enumerate() {
            let from_idx = node_index[&edge.from_node];
            edges_from.entry(from_idx).or_default().push(ei);
        }

        // Topological levels (Kahn's algorithm)
        let levels = topological_levels(&nodes, &node_index, &edges)?;
        let order: Vec<usize> = levels.iter().flat_map(|l| l.iter().copied()).collect();

        // Build tag capture indices for observation
        let mut tag_names_map: HashMap<String, (usize, usize)> = HashMap::new();
        let mut tag_capture: HashMap<usize, Vec<(String, usize)>> = HashMap::new();
        for (name, node_ref) in &tags {
            if let Some(&ni) = node_index.get(&node_ref.node_id) {
                let port_idx = nodes[ni]
                    .output_ports
                    .iter()
                    .position(|p| p == &node_ref.port)
                    .unwrap_or(0);
                tag_names_map.insert(name.clone(), (ni, port_idx));
                tag_capture
                    .entry(ni)
                    .or_default()
                    .push((name.clone(), port_idx));
            }
        }

        // Detect child subgraphs: labeled Graphs become tree children
        let mut children: HashMap<String, usize> = HashMap::new();
        for (idx, node) in nodes.iter().enumerate() {
            if let Some(ref module) = node.module {
                if let Some(child_graph) = module.as_graph() {
                    if let Some(child_label) = child_graph.label() {
                        if child_label.contains('.') {
                            return Err(TensorError::new(&format!(
                                "child graph label {:?} contains a dot — \
                                 dots are reserved for path separators",
                                child_label
                            )));
                        }
                        if children.contains_key(child_label) {
                            return Err(TensorError::new(&format!(
                                "duplicate child graph label {:?} at the same tree level",
                                child_label
                            )));
                        }
                        // Validate: label doesn't shadow a tag on a different node
                        if let Some(&(tag_ni, _)) = tag_names_map.get(child_label) {
                            if tag_ni != idx {
                                return Err(TensorError::new(&format!(
                                    "child graph label {:?} collides with a tag \
                                     on a different node",
                                    child_label
                                )));
                            }
                        }
                        children.insert(child_label.to_string(), idx);
                        child_graph.composed.set(true);
                    }
                    // Unlabeled graphs: not registered, no tree features, no error
                }
            }
        }

        // Auto-internal inference: underscore-prefixed tags
        for name in tag_names_map.keys() {
            if name.starts_with('_') {
                internal_tags.insert(name.clone());
            }
        }

        // Build state writer lookup: node_idx → [(state_entry_idx, port_idx)]
        // Also resolve writer_ni on each state entry for DOT rendering.
        let mut state_writers: HashMap<usize, Vec<(usize, usize)>> = HashMap::new();
        for (si, fr) in forward_refs.iter().enumerate() {
            if let Some(&ni) = node_index.get(&fr.writer_id) {
                state[si].writer_ni = ni;
                let port_idx = nodes[ni]
                    .output_ports
                    .iter()
                    .position(|p| p == &fr.writer_port)
                    .unwrap_or(0);
                state_writers.entry(ni).or_default().push((si, port_idx));
            }
        }

        // Pre-compute routing table: flat Vec lookups replace HashMap edge routing
        let n = nodes.len();
        let mut routes_from: Vec<Vec<Route>> = vec![Vec::new(); n];
        for edge in &edges {
            let from_ni = node_index[&edge.from_node];
            let to_ni = node_index[&edge.to_node];
            let from_port_idx = nodes[from_ni]
                .output_ports
                .iter()
                .position(|p| p == &edge.from_port)
                .unwrap_or(0);
            let to_port_idx = nodes[to_ni]
                .input_ports
                .iter()
                .position(|p| p == &edge.to_port)
                .unwrap_or(0);
            routes_from[from_ni].push(Route {
                from_port_idx,
                to_node_idx: to_ni,
                to_port_idx,
            });
        }

        // Pre-compute graph input → slot mapping
        let input_routes: Vec<InputRoute> = inputs
            .iter()
            .map(|ep| {
                let ni = node_index[&ep.node_id];
                let port_idx = nodes[ni]
                    .input_ports
                    .iter()
                    .position(|p| p == &ep.port)
                    .unwrap_or(0);
                InputRoute {
                    node_idx: ni,
                    port_idx,
                }
            })
            .collect();

        // Pre-compute output location
        let output_node_idx = node_index[&outputs[0].node_id];
        let output_port_idx = nodes[output_node_idx]
            .output_ports
            .iter()
            .position(|p| p == &outputs[0].port)
            .unwrap_or(0);

        // Pre-compute input port counts and allocate execution buffers
        let node_input_count: Vec<usize> = nodes.iter().map(|nd| nd.input_ports.len()).collect();
        let exec_slots = RefCell::new(
            node_input_count.iter().map(|&c| vec![None; c]).collect(),
        );

        let graph = Ok(Graph {
            nodes,
            node_index,
            levels,
            edges,
            edges_from,
            inputs,
            outputs,
            order,
            state,
            state_writers,
            tag_groups,
            tag_names: tag_names_map,
            tag_capture,
            tagged_outputs: RefCell::new(HashMap::new()),
            batch_buffer: RefCell::new(HashMap::new()),
            epoch_history: RefCell::new(HashMap::new()),
            metric_order: RefCell::new(Vec::new()),
            flush_count: Cell::new(0),
            profiling: Cell::new(false),
            last_profile: RefCell::new(None),
            timing_buffer: RefCell::new(HashMap::new()),
            timing_history: RefCell::new(HashMap::new()),
            flush_times: RefCell::new(Vec::new()),
            training_start: Cell::new(0.0),
            step_count: Cell::new(0),
            epoch_count: Cell::new(0),
            label,
            structural_hash_cache: OnceCell::new(),
            children,
            composed: Cell::new(false),
            internal_tags,
            routes_from,
            input_routes,
            output_node_idx,
            output_port_idx,
            node_input_count,
            exec_slots,
        });

        if verbose {
            if let Ok(ref g) = graph {
                eprintln!("{}", g.tree_summary());
            }
        }

        graph
    }

    fn forward_impl(&self, graph_inputs: &[Variable]) -> Result<Variable> {
        if graph_inputs.len() != self.inputs.len() {
            return Err(TensorError::new(&format!(
                "expected {} inputs, got {}",
                self.inputs.len(),
                graph_inputs.len()
            )));
        }

        // Record training start on first forward (for ETA).
        if self.training_start.get() == 0.0 {
            self.training_start.set(instant_secs());
        }

        let is_profiling = self.profiling.get();
        let forward_start = if is_profiling { Some(Instant::now()) } else { None };
        let mut prof_nodes: Vec<profile::NodeTiming> = Vec::new();
        let mut prof_levels: Vec<profile::LevelTiming> = Vec::new();

        // Build reverse tag lookup for profiling: node_idx → first tag name
        let tags_by_node: HashMap<usize, String> = if is_profiling {
            let mut m = HashMap::new();
            for (name, &(ni, _)) in &self.tag_names {
                m.entry(ni).or_insert_with(|| name.clone());
            }
            m
        } else {
            HashMap::new()
        };

        let has_tags = !self.tag_capture.is_empty();

        // Reuse cached execution buffers (Vec-indexed, no HashMap overhead)
        let mut slots = self.exec_slots.borrow_mut();

        // Clear previous values (drops old Variables, reuses allocations)
        for node_slots in slots.iter_mut() {
            for slot in node_slots.iter_mut() {
                *slot = None;
            }
        }

        // Clear tagged outputs
        if has_tags {
            self.tagged_outputs.borrow_mut().clear();
        }

        // Route graph inputs via pre-computed index mapping
        for (i, route) in self.input_routes.iter().enumerate() {
            slots[route.node_idx][route.port_idx] = Some(graph_inputs[i].clone());
        }

        // Will hold the output node's results until we can extract the final value
        let mut final_output: Option<Vec<Variable>> = None;

        // Execute levels sequentially
        for (level_idx, level) in self.levels.iter().enumerate() {
            let level_start = if is_profiling { Some(Instant::now()) } else { None };
            let mut level_sum_ns: u64 = 0;

            for &ni in level {
                let node = &self.nodes[ni];
                let input_count = self.node_input_count[ni];

                // Collect inputs from pre-indexed slots (no HashMap lookups)
                let inputs: Vec<Variable> = (0..input_count)
                    .map(|i| {
                        match slots[ni][i].as_ref() {
                            Some(v) => Ok(v.clone()),
                            None if i > 0 => {
                                // Zero fill for unconnected ref ports (forward refs)
                                let first = slots[ni][0].as_ref().ok_or_else(|| {
                                    TensorError::new(&format!(
                                        "node '{}': ref port {} has no data and primary input \
                                         is also missing — check that all inputs are connected",
                                        node.id, i
                                    ))
                                })?;
                                Ok(Variable::new(
                                    Tensor::zeros_like(&first.data())?,
                                    false,
                                ))
                            }
                            _ => Err(TensorError::new(&format!(
                                "node '{}': missing primary input (port {}) — check that all \
                                 inputs to this node are connected in the graph builder",
                                node.id, i
                            ))),
                        }
                    })
                    .collect::<Result<Vec<Variable>>>()?;

                // Release input slots early (frees Rc references)
                for slot in slots[ni].iter_mut() {
                    *slot = None;
                }

                // Execute node (with optional per-node timing)
                let node_start = if is_profiling { Some(Instant::now()) } else { None };
                let node_outputs = (node.run)(&inputs)?;
                if is_profiling {
                    let elapsed = node_start.unwrap().elapsed();
                    level_sum_ns += elapsed.as_nanos() as u64;
                    prof_nodes.push(profile::NodeTiming {
                        id: node.id.clone(),
                        tag: tags_by_node.get(&ni).cloned().unwrap_or_default(),
                        duration: elapsed,
                        level: level_idx,
                    });
                }

                // Route outputs via pre-computed routing table (no HashMap, no String ops)
                for route in &self.routes_from[ni] {
                    let value = if route.from_port_idx < node_outputs.len() {
                        Some(node_outputs[route.from_port_idx].clone())
                    } else {
                        None
                    };
                    slots[route.to_node_idx][route.to_port_idx] = value;
                }

                // Capture state: if this node is a state writer, store its output
                if let Some(writers) = self.state_writers.get(&ni) {
                    for &(si, port_idx) in writers {
                        if port_idx < node_outputs.len() {
                            *self.state[si].value.borrow_mut() =
                                Some(node_outputs[port_idx].clone());
                        }
                    }
                }

                // Capture tagged outputs for observation
                if has_tags {
                    if let Some(captures) = self.tag_capture.get(&ni) {
                        let mut tagged = self.tagged_outputs.borrow_mut();
                        for (tag_name, port_idx) in captures {
                            if *port_idx < node_outputs.len() {
                                tagged.insert(
                                    tag_name.clone(),
                                    node_outputs[*port_idx].clone(),
                                );
                            }
                        }
                    }
                }

                // Keep output node's results; all others drop here (early release)
                if ni == self.output_node_idx {
                    final_output = Some(node_outputs);
                }
            }

            // Record level timing
            if is_profiling {
                prof_levels.push(profile::LevelTiming {
                    index: level_idx,
                    wall_clock: level_start.unwrap().elapsed(),
                    sum_nodes: std::time::Duration::from_nanos(level_sum_ns),
                    num_nodes: level.len(),
                });
            }
        }

        // Drop the borrow before storing profile (which also borrows RefCells)
        drop(slots);

        // Store profile
        if is_profiling {
            *self.last_profile.borrow_mut() = Some(profile::Profile {
                total: forward_start.unwrap().elapsed(),
                levels: prof_levels,
                nodes: prof_nodes,
            });
        }

        // Extract graph output
        final_output
            .and_then(|o| o.into_iter().nth(self.output_port_idx))
            .ok_or_else(|| TensorError::new("graph produced no output"))
    }
}

impl Graph {
    /// Clear all forward-reference state buffers to None.
    /// Call when starting inference on a new sequence.
    pub fn reset_state(&self) {
        for entry in &self.state {
            *entry.value.borrow_mut() = None;
        }
    }

    /// Break gradient chain on forward-reference state buffers and module state.
    /// Call between training steps to prevent unbounded graph growth.
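    ///
    /// Usually invoked indirectly via [`end_step`](Self::end_step); called
    /// directly it looks like:
    ///
    /// ```ignore
    /// let y = graph.forward(&x)?;
    /// // ... backward, optimize ...
    /// graph.detach_state(); // state values survive, gradient history does not
    /// ```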
    pub fn detach_state(&self) {
        // Detach graph-level state buffers (forward references).
        for entry in &self.state {
            let mut val = entry.value.borrow_mut();
            if let Some(ref v) = *val {
                *val = Some(v.detach());
            }
        }
        // Detach tagged outputs — these hold Variables from the forward
        // pass whose grad_fn chains reference the C++ autograd graph.
        // Without this, the Node objects persist until the next forward
        // pass replaces tagged_outputs.
        {
            let mut tagged = self.tagged_outputs.borrow_mut();
            for var in tagged.values_mut() {
                *var = var.detach();
            }
        }
        // Propagate detach to modules that hold internal state.
        for node in &self.nodes {
            if let Some(ref module) = node.module {
                module.detach_state();
            }
        }
    }

    /// Returns true if this graph has forward-reference state.
    pub fn has_state(&self) -> bool {
        !self.state.is_empty()
    }

    /// End-of-step housekeeping: detach state (cut gradient chain but
    /// preserve values for the next forward), collect timings,
    /// increment step counter.
    ///
    /// For recurrent models this implements truncated BPTT — state carries
    /// over between steps but gradients don't flow across step boundaries.
    /// Call [`end_sequence`](Self::end_sequence) to fully wipe state
    /// when starting a new independent sequence.
    ///
    /// ```ignore
    /// for token in sequence {
    ///     let y = graph.forward(&token)?;
    ///     // ... backward, optimize ...
    ///     graph.end_step();       // keep state, cut gradients
    /// }
    /// graph.end_sequence();       // wipe state for next sequence
    /// ```
    pub fn end_step(&self) {
        self.detach_state();
        if self.profiling.get() {
            self.collect_timings(&[]);
        }
        self.step_count.set(self.step_count.get() + 1);
    }

    /// End-of-sequence housekeeping: fully reset state buffers to None.
    /// Call between independent sequences so the model starts fresh.
    ///
    /// For non-recurrent graphs (no forward refs) this is a no-op.
    pub fn end_sequence(&self) {
        self.reset_state();
    }

    /// End-of-epoch housekeeping: flush all observation and timing buffers,
    /// increment epoch counter.
    pub fn end_epoch(&self) {
        self.flush(&[]);
        if self.profiling.get() {
            self.flush_timings(&[]);
        }
        self.epoch_count.set(self.epoch_count.get() + 1);
    }

    /// Number of completed training steps.
    pub fn step_count(&self) -> usize {
        self.step_count.get()
    }

    /// Number of completed training epochs.
    pub fn epoch_count(&self) -> usize {
        self.epoch_count.get()
    }

    /// Get member tags of a tag group, or None if not registered.
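    ///
    /// A sketch, assuming a group named `"heads"` was registered at build time:
    ///
    /// ```ignore
    /// if let Some(members) = g.tag_group("heads") {
    ///     for tag in members {
    ///         let v = g.tagged(tag); // captured output of each member tag
    ///     }
    /// }
    /// ```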
    pub fn tag_group(&self, name: &str) -> Option<&[String]> {
        self.tag_groups.get(name).map(|v| v.as_slice())
    }

    /// Forward with multiple inputs (for graphs with Input ports).
    /// Inputs are in declaration order: From entry first, then each Input.
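    ///
    /// A sketch, assuming the graph was built with one extra Input port
    /// (`entry_x` and `extra_input` are placeholder Variables):
    ///
    /// ```ignore
    /// let y = g.forward_multi(&[entry_x, extra_input])?;
    /// ```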
    pub fn forward_multi(&self, inputs: &[Variable]) -> Result<Variable> {
        self.forward_impl(inputs)
    }

    /// Move all parameters, state buffers, and module buffers to a device.
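    ///
    /// A sketch; `device` is any `crate::tensor::Device` (the tests use
    /// `crate::tensor::test_device()`):
    ///
    /// ```ignore
    /// graph.set_device(device);
    /// ```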
    pub fn set_device(&self, device: crate::tensor::Device) {
        // Move parameters — detach first so the moved tensor is a fresh leaf,
        // not a non-leaf with CopyBackward from native autograd.
        for p in self.parameters() {
            if p.variable.data().device() != device
                && let Ok(t) = p.variable.data().detach()
                    .and_then(|d| d.to_device(device))
            {
                p.variable.set_data(t);
            }
        }
        // Move state buffers
        for entry in &self.state {
            let mut val = entry.value.borrow_mut();
            if let Some(ref v) = *val
                && v.data().device() != device
                && let Ok(t) = v.data().to_device(device)
            {
                *val = Some(Variable::new(t, false));
            }
        }
        // Walk modules for move_to_device (BatchNorm running stats, etc.)
        let mut visited = HashSet::new();
        for &ni in &self.order {
            if let Some(ref module) = self.nodes[ni].module {
                crate::nn::walk_modules_visited(
                    module.as_ref(),
                    &mut visited,
                    &mut |m: &dyn crate::nn::Module| m.move_to_device(device),
                );
            }
        }
    }

    /// Return parameters with qualified names: `"prefix/param_name"`.
    ///
    /// The prefix is the tag name if the node is tagged, otherwise the node ID
    /// (e.g. `"linear_1"`). When a node has multiple parameters with the same
    /// name, suffixes `_0`, `_1`, ... are appended to disambiguate.
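    ///
    /// A sketch (the names shown are illustrative):
    ///
    /// ```ignore
    /// for (name, p) in g.named_parameters() {
    ///     println!("{name}"); // e.g. "encoder/weight" or "linear_1/bias"
    /// }
    /// ```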
    pub fn named_parameters(&self) -> Vec<(String, Parameter)> {
        // Build reverse map: node_idx → tag name
        let mut idx_to_tag: HashMap<usize, String> = HashMap::new();
        for (tag, &(ni, _)) in &self.tag_names {
            // First tag encountered wins (we only need one prefix per node)
            idx_to_tag.entry(ni).or_insert_with(|| tag.clone());
        }

        let mut result = Vec::new();
        let mut seen = HashSet::new();

        for &ni in &self.order {
            if let Some(ref module) = self.nodes[ni].module {
                let prefix = idx_to_tag.get(&ni)
                    .cloned()
                    .unwrap_or_else(|| self.nodes[ni].id.clone());

                let params = module.parameters();
                // Check for duplicate param names within this node
                let mut name_counts: HashMap<String, usize> = HashMap::new();
                for p in &params {
                    *name_counts.entry(p.name.clone()).or_insert(0) += 1;
                }

                let mut name_idx: HashMap<String, usize> = HashMap::new();
                for p in params {
                    let ptr = Rc::as_ptr(&p.variable.inner) as usize;
                    if !seen.insert(ptr) {
                        continue;
                    }

                    let qualified = if name_counts[&p.name] > 1 {
                        let idx = name_idx.entry(p.name.clone()).or_insert(0);
                        let q = format!("{}/{}_{}", prefix, p.name, idx);
                        *idx += 1;
                        q
                    } else {
                        format!("{}/{}", prefix, p.name)
                    };

                    result.push((qualified, p));
                }
            }
        }

        result
    }

    /// Return buffers with qualified names, using the same prefix logic
    /// as `named_parameters()`.
    pub fn named_buffers(&self) -> Vec<(String, Buffer)> {
        let mut idx_to_tag: HashMap<usize, String> = HashMap::new();
        for (tag, &(ni, _)) in &self.tag_names {
            idx_to_tag.entry(ni).or_insert_with(|| tag.clone());
        }

        let mut result = Vec::new();
        let mut seen = HashSet::new();

        for &ni in &self.order {
            if let Some(ref module) = self.nodes[ni].module {
                let prefix = idx_to_tag.get(&ni)
                    .cloned()
                    .unwrap_or_else(|| self.nodes[ni].id.clone());

                let bufs = module.buffers();
                let mut name_counts: HashMap<String, usize> = HashMap::new();
                for b in &bufs {
                    *name_counts.entry(b.name.clone()).or_insert(0) += 1;
                }

                let mut name_idx: HashMap<String, usize> = HashMap::new();
                for b in bufs {
                    let ptr = Rc::as_ptr(&b.inner) as usize;
                    if !seen.insert(ptr) {
                        continue;
                    }

                    let qualified = if name_counts[&b.name] > 1 {
                        let idx = name_idx.entry(b.name.clone()).or_insert(0);
                        let q = format!("{}/{}_{}", prefix, b.name, idx);
                        *idx += 1;
                        q
                    } else {
                        format!("{}/{}", prefix, b.name)
                    };

                    result.push((qualified, b));
                }
            }
        }

        result
    }

    /// Human-readable label set via `FlowBuilder::label()`.
    pub fn label(&self) -> Option<&str> {
        self.label.as_deref()
    }

    /// Full 64-character hex structural hash (computed lazily, cached).
    pub fn structural_hash(&self) -> &str {
        self.structural_hash_cache.get_or_init(|| self.compute_structural_hash())
    }

    /// First 8 characters of the structural hash.
    pub fn short_hash(&self) -> &str {
        &self.structural_hash()[..8]
    }

    /// Save all parameters and buffers to a checkpoint file.
    ///
    /// Embeds the structural hash for architecture validation on load.
    /// Supports `.gz` extension for gzip compression.
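    ///
    /// A sketch of a save/load round trip (the path is arbitrary):
    ///
    /// ```ignore
    /// g.save_checkpoint("model.ckpt.gz")?; // gzip-compressed per the extension
    /// let report = g.load_checkpoint("model.ckpt.gz")?;
    /// ```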
    pub fn save_checkpoint(&self, path: &str) -> Result<()> {
        let params = self.named_parameters();
        let buffers = self.named_buffers();
        let hash = self.structural_hash();
        crate::nn::save_checkpoint_file(path, &params, &buffers, Some(hash))
    }

    /// Load parameters and buffers from a checkpoint file.
    ///
    /// Validates the structural hash against this graph's architecture.
    /// Returns a [`LoadReport`](crate::nn::LoadReport) describing what was
    /// loaded, skipped, or missing.
    pub fn load_checkpoint(&self, path: &str) -> Result<crate::nn::LoadReport> {
        let params = self.named_parameters();
        let buffers = self.named_buffers();
        let hash = self.structural_hash();
        crate::nn::load_checkpoint_file(path, &params, &buffers, Some(hash))
    }

    fn compute_structural_hash(&self) -> String {
        let mut hasher = Sha256::new();

        // 1. Nodes in topological order
        for &ni in &self.order {
            let node = &self.nodes[ni];
            hasher.update(node.id.as_bytes());
            hasher.update(b"\0");

            if let Some(ref module) = node.module {
                hasher.update(module.name().as_bytes());
                hasher.update(b"\0");

                // Sorted parameters
                let mut params: Vec<_> = module.parameters().into_iter()
                    .map(|p| (p.name.clone(), p.variable.shape()))
                    .collect();
                params.sort_by(|a, b| a.0.cmp(&b.0));
                for (name, shape) in &params {
                    hasher.update(b"P");
                    hasher.update(name.as_bytes());
                    hasher.update(b"\0");
                    for &dim in shape {
                        hasher.update(dim.to_le_bytes());
                    }
                }

                // Sorted buffers
                let mut bufs: Vec<_> = module.buffers().into_iter()
                    .map(|b| (b.name.clone(), b.shape()))
                    .collect();
                bufs.sort_by(|a, b| a.0.cmp(&b.0));
                for (name, shape) in &bufs {
                    hasher.update(b"B");
                    hasher.update(name.as_bytes());
                    hasher.update(b"\0");
                    for &dim in shape {
                        hasher.update(dim.to_le_bytes());
                    }
                }

                // Nested graph hash
                if let Some(nested_hash) = module.structural_hash() {
                    hasher.update(b"G");
                    hasher.update(nested_hash.as_bytes());
                }
            }
        }

        // 2. Edges
        hasher.update(b"EDGES");
        for edge in &self.edges {
            hasher.update(edge.from_node.as_bytes());
            hasher.update(b"\0");
            hasher.update(edge.from_port.as_bytes());
            hasher.update(b"\0");
            hasher.update(edge.to_node.as_bytes());
            hasher.update(b"\0");
            hasher.update(edge.to_port.as_bytes());
            hasher.update(b"\0");
        }

        // 3. Tags (sorted)
        hasher.update(b"TAGS");
        let mut tags: Vec<_> = self.tag_names.iter().collect();
        tags.sort_by(|a, b| a.0.cmp(b.0));
        for (name, (node_idx, port_idx)) in &tags {
            hasher.update(name.as_bytes());
            hasher.update(b"\0");
            hasher.update((*node_idx as u64).to_le_bytes());
            hasher.update((*port_idx as u64).to_le_bytes());
        }

        // 4. Input/output ports
        hasher.update(b"INPUTS");
        for port in &self.inputs {
            hasher.update(port.name.as_bytes());
            hasher.update(b"\0");
            hasher.update(port.node_id.as_bytes());
            hasher.update(b"\0");
            hasher.update(port.port.as_bytes());
            hasher.update(b"\0");
        }
        hasher.update(b"OUTPUTS");
        for port in &self.outputs {
            hasher.update(port.name.as_bytes());
            hasher.update(b"\0");
            hasher.update(port.node_id.as_bytes());
            hasher.update(b"\0");
            hasher.update(port.port.as_bytes());
            hasher.update(b"\0");
        }

        hasher.finalize().iter().map(|b| format!("{b:02x}")).collect()
    }
}

impl Module for Graph {
    fn name(&self) -> &str { "graph" }

    fn as_graph(&self) -> Option<&Graph> { Some(self) }

    fn structural_hash(&self) -> Option<String> {
        Some(self.structural_hash().to_string())
    }

    fn forward(&self, input: &Variable) -> Result<Variable> {
        self.forward_impl(std::slice::from_ref(input))
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        let mut seen = HashSet::new();

        for &ni in &self.order {
            if let Some(ref module) = self.nodes[ni].module {
                for p in module.parameters() {
                    let ptr = Rc::as_ptr(&p.variable.inner) as usize;
                    if seen.insert(ptr) {
                        params.push(p);
                    }
                }
            }
        }

        params
    }

    fn set_training(&self, training: bool) {
        let mut visited = HashSet::new();
        for &ni in &self.order {
            if let Some(ref module) = self.nodes[ni].module {
                crate::nn::walk_modules_visited(
                    module.as_ref(),
                    &mut visited,
                    &mut |m: &dyn crate::nn::Module| m.set_training(training),
                );
            }
        }
    }

    fn move_to_device(&self, device: crate::tensor::Device) {
        self.set_device(device);
    }
}

/// Current time as seconds since the Unix epoch (wall clock, used as an
/// approximation of a monotonic clock for ETA purposes).
fn instant_secs() -> f64 {
    use std::time::SystemTime;
    SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs_f64()
}

/// Kahn's algorithm with level grouping for parallel execution.
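/// For a diamond `a → {b, c} → d`, the levels are `[[a], [b, c], [d]]`:
/// nodes within a level have no mutual dependencies and can run in parallel.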
1042fn topological_levels(
1043    nodes: &[Node],
1044    node_index: &HashMap<String, usize>,
1045    edges: &[Edge],
1046) -> Result<Vec<Vec<usize>>> {
1047    let n = nodes.len();
1048
1049    // Build unique dependency sets (node-level, not edge-level).
1050    // dependents uses BTreeSet so iteration follows node index order,
1051    // making the topological sort deterministic across runs.
1052    let mut deps: Vec<HashSet<usize>> = vec![HashSet::new(); n];
1053    let mut dependents: Vec<BTreeSet<usize>> = vec![BTreeSet::new(); n];
1054
1055    for edge in edges {
1056        let from_ni = node_index[&edge.from_node];
1057        let to_ni = node_index[&edge.to_node];
1058        deps[to_ni].insert(from_ni);
1059        dependents[from_ni].insert(to_ni);
1060    }
1061
1062    let mut in_degree: Vec<usize> = deps.iter().map(|d| d.len()).collect();
1063
1064    // Seed with zero in-degree nodes
1065    let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
1066    let mut levels = Vec::new();
1067    let mut visited = 0;
1068
1069    while !queue.is_empty() {
1070        levels.push(queue.clone());
1071        visited += queue.len();
1072
1073        let mut next_queue = Vec::new();
1074        for &ni in &queue {
1075            for &dep in &dependents[ni] {
1076                in_degree[dep] -= 1;
1077                if in_degree[dep] == 0 {
1078                    next_queue.push(dep);
1079                }
1080            }
1081        }
1082        queue = next_queue;
1083    }
1084
1085    if visited != n {
1086        return Err(TensorError::new("cycle detected in graph"));
1087    }
1088
1089    Ok(levels)
1090}
1091
1092#[cfg(test)]
1093mod tests {
1094    use super::*;
1095    use crate::autograd::Variable;
1096    use crate::nn::{Linear, NamedInputModule, ReLU, Sigmoid, mse_loss, Optimizer, SGD};
1097    use crate::tensor::Tensor;
1098    use std::collections::HashMap;
1099
1100    fn from_f32(data: &[f32], shape: &[i64]) -> Tensor {
1101        Tensor::from_f32(data, shape, crate::tensor::test_device()).unwrap()
1102    }
1103
1104    // --- Helper modules for testing ---
1105
1106    /// Doubles the input: forward(x) = 2*x
1107    struct Doubler;
1108    impl Module for Doubler {
1109        fn forward(&self, input: &Variable) -> Result<Variable> {
1110            input.add(input)
1111        }
1112    }
1113
1114    /// Adds a learnable bias at each step (for gradient accumulation testing).
1115    struct BiasStep {
1116        bias: Parameter,
1117    }
1118    impl BiasStep {
1119        fn new(size: i64) -> Result<Self> {
1120            let data = Tensor::zeros(&[size], crate::tensor::test_opts())?;
1121            let var = Variable::new(data, true);
1122            Ok(BiasStep {
1123                bias: Parameter {
1124                    variable: var,
1125                    name: "loop_bias".to_string(),
1126                },
1127            })
1128        }
1129    }
1130    impl Module for BiasStep {
1131        fn forward(&self, input: &Variable) -> Result<Variable> {
1132            input.add(&self.bias.variable)
1133        }
1134        fn parameters(&self) -> Vec<Parameter> {
1135            vec![self.bias.clone()]
1136        }
1137    }
1138
1139    /// Module that adds a tagged ref to the stream (for Using tests).
1140    struct AddRefModule;
1141    impl Module for AddRefModule {
1142        fn forward(&self, input: &Variable) -> Result<Variable> {
1143            Ok(input.clone())
1144        }
1145        fn as_named_input(&self) -> Option<&dyn NamedInputModule> { Some(self) }
1146    }
1147    impl NamedInputModule for AddRefModule {
1148        fn forward_named(
1149            &self,
1150            input: &Variable,
1151            refs: &HashMap<String, Variable>,
1152        ) -> Result<Variable> {
1153            if let Some(ctx) = refs.get("ctx") {
1154                input.add(ctx)
1155            } else {
1156                Ok(input.clone())
1157            }
1158        }
1159    }
1160
1161    // --- Core graph tests (from before) ---
1162
1163    #[test]
1164    fn test_single_module() {
1165        let l = Linear::on_device(3, 2, crate::tensor::test_device()).unwrap();
1166        let graph = FlowBuilder::from(l).build().unwrap();
1167
1168        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), false);
1169        let y = graph.forward(&x).unwrap();
1170        assert_eq!(y.shape(), vec![1, 2]);
1171    }
1172
1173    #[test]
1174    fn test_linear_chain() {
1175        let graph = FlowBuilder::from(Linear::on_device(3, 4, crate::tensor::test_device()).unwrap())
1176            .through(ReLU::new())
1177            .through(Linear::on_device(4, 2, crate::tensor::test_device()).unwrap())
1178            .build()
1179            .unwrap();
1180
1181        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), false);
1182        let y = graph.forward(&x).unwrap();
1183        assert_eq!(y.shape(), vec![1, 2]);
1184    }
1185
1186    #[test]
1187    fn test_also_residual() {
1188        let l1 = Linear::on_device(3, 3, crate::tensor::test_device()).unwrap();
1189        l1.weight.variable.set_data(from_f32(
1190            &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
1191            &[3, 3],
1192        ));
1193        l1.bias
1194            .as_ref()
1195            .unwrap()
1196            .variable
1197            .set_data(from_f32(&[0.0, 0.0, 0.0], &[3]));
1198
1199        let l2 = Linear::on_device(3, 3, crate::tensor::test_device()).unwrap();
1200        l2.weight.variable.set_data(from_f32(
1201            &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
1202            &[3, 3],
1203        ));
1204        l2.bias
1205            .as_ref()
1206            .unwrap()
1207            .variable
1208            .set_data(from_f32(&[1.0, 1.0, 1.0], &[3]));
1209
1210        // l1(x) + l2(l1(x)) = x + (x + 1) = 2x + 1
1211        let graph = FlowBuilder::from(l1).also(l2).build().unwrap();
1212
1213        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), false);
1214        let y = graph.forward(&x).unwrap();
1215        let data = y.data().to_f32_vec().unwrap();
1216
1217        assert!((data[0] - 3.0).abs() < 1e-5);
1218        assert!((data[1] - 5.0).abs() < 1e-5);
1219        assert!((data[2] - 7.0).abs() < 1e-5);
1220    }
1221
1222    // --- Fork tests ---
1223
1224    #[test]
1225    fn test_fork_basic() {
1226        // Fork runs a side module but main stream continues unchanged.
1227        // identity(x) → fork(linear) tagged "side" → through(ReLU)
1228        // Main stream: ReLU(identity(x)) = ReLU(x)
1229        // Side output: linear(x) accessible via tagged("side")
1230        let l = Linear::on_device(2, 3, crate::tensor::test_device()).unwrap();
1231
1232        let graph = FlowBuilder::from(Identity)
1233            .fork(l)
1234            .tag("side")
1235            .through(ReLU::new())
1236            .build()
1237            .unwrap();
1238
1239        let x = Variable::new(from_f32(&[1.0, -2.0], &[1, 2]), false);
1240        let y = graph.forward(&x).unwrap();
1241
1242        // Main stream went through ReLU(identity(x)) → shape [1, 2]
1243        assert_eq!(y.shape(), vec![1, 2]);
1244        let data = y.data().to_f32_vec().unwrap();
1245        assert!((data[0] - 1.0).abs() < 1e-5);
1246        assert!((data[1] - 0.0).abs() < 1e-5); // ReLU(-2) = 0
1247
1248        // Side output is linear(x) → shape [1, 3]
1249        let side = graph.tagged("side").unwrap();
1250        assert_eq!(side.shape(), vec![1, 3]);
1251    }
1252
1253    #[test]
1254    fn test_fork_multiple() {
1255        // Two forks from the same stream: letter_head and case_head pattern
1256        let head_a = Linear::on_device(4, 3, crate::tensor::test_device()).unwrap();
1257        let head_b = Linear::on_device(4, 2, crate::tensor::test_device()).unwrap();
1258
1259        let graph = FlowBuilder::from(Linear::on_device(2, 4, crate::tensor::test_device()).unwrap())
1260            .tag("latent")
1261            .fork(head_a)
1262            .tag("head_a")
1263            .fork(head_b)
1264            .tag("head_b")
1265            .build()
1266            .unwrap();
1267
1268        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
1269        let y = graph.forward(&x).unwrap();
1270
1271        // Main stream is still the linear(2→4) output
1272        assert_eq!(y.shape(), vec![1, 4]);
1273
1274        // Both forks produced their outputs
1275        let a = graph.tagged("head_a").unwrap();
1276        assert_eq!(a.shape(), vec![1, 3]);
1277        let b = graph.tagged("head_b").unwrap();
1278        assert_eq!(b.shape(), vec![1, 2]);
1279    }
1280
1281    #[test]
1282    fn test_fork_backward() {
1283        // Gradients flow through both forks and the main stream
1284        let graph = FlowBuilder::from(Linear::on_device(2, 4, crate::tensor::test_device()).unwrap())
1285            .fork(Linear::on_device(4, 3, crate::tensor::test_device()).unwrap())
1286            .tag("side")
1287            .through(Linear::on_device(4, 1, crate::tensor::test_device()).unwrap())
1288            .build()
1289            .unwrap();
1290
1291        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
1292        let y = graph.forward(&x).unwrap();
1293
1294        // Loss from main stream + side output
1295        let side = graph.tagged("side").unwrap();
1296        let loss = y.sum().unwrap().add(&side.sum().unwrap()).unwrap();
1297        loss.backward().unwrap();
1298
1299        assert!(x.grad().is_some(), "input should have gradient");
1300        for p in graph.parameters() {
1301            assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
1302        }
1303    }
1304
1305    // --- Split/Merge tests ---
1306
1307    #[test]
1308    fn test_split_merge_add() {
1309        let graph = FlowBuilder::from(Linear::on_device(3, 3, crate::tensor::test_device()).unwrap())
1310            .split(vec![Box::new(ReLU::new()), Box::new(Sigmoid::new())])
1311            .merge(MergeOp::Add)
1312            .build()
1313            .unwrap();
1314
1315        let x = Variable::new(from_f32(&[1.0, -1.0, 2.0], &[1, 3]), false);
1316        let y = graph.forward(&x).unwrap();
1317        assert_eq!(y.shape(), vec![1, 3]);
1318    }
1319
1320    #[test]
1321    fn test_split_merge_mean() {
1322        let l = Linear::on_device(2, 2, crate::tensor::test_device()).unwrap();
1323        l.weight
1324            .variable
1325            .set_data(from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]));
1326        l.bias
1327            .as_ref()
1328            .unwrap()
1329            .variable
1330            .set_data(from_f32(&[0.0, 0.0], &[2]));
1331
1332        let b1 = Linear::on_device(2, 2, crate::tensor::test_device()).unwrap();
1333        b1.weight
1334            .variable
1335            .set_data(from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]));
1336        b1.bias
1337            .as_ref()
1338            .unwrap()
1339            .variable
1340            .set_data(from_f32(&[0.0, 0.0], &[2]));
1341        let b2 = Linear::on_device(2, 2, crate::tensor::test_device()).unwrap();
1342        b2.weight
1343            .variable
1344            .set_data(from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]));
1345        b2.bias
1346            .as_ref()
1347            .unwrap()
1348            .variable
1349            .set_data(from_f32(&[0.0, 0.0], &[2]));
1350
1351        let graph = FlowBuilder::from(l)
1352            .split(vec![Box::new(b1), Box::new(b2)])
1353            .merge(MergeOp::Mean)
1354            .build()
1355            .unwrap();
1356
1357        let x = Variable::new(from_f32(&[3.0, 7.0], &[1, 2]), false);
1358        let y = graph.forward(&x).unwrap();
1359        let data = y.data().to_f32_vec().unwrap();
1360
1361        assert!((data[0] - 3.0).abs() < 1e-5);
1362        assert!((data[1] - 7.0).abs() < 1e-5);
1363    }
1364
1365    #[test]
1366    fn test_parameters() {
1367        let graph = FlowBuilder::from(Linear::on_device(3, 4, crate::tensor::test_device()).unwrap())
1368            .through(ReLU::new())
1369            .through(Linear::on_device(4, 2, crate::tensor::test_device()).unwrap())
1370            .build()
1371            .unwrap();
1372
1373        let params = graph.parameters();
1374        assert_eq!(params.len(), 4);
1375    }
1376
1377    #[test]
1378    fn test_graph_backward() {
1379        let l1 = Linear::on_device(3, 2, crate::tensor::test_device()).unwrap();
1380        let l2 = Linear::on_device(2, 1, crate::tensor::test_device()).unwrap();
1381
1382        let graph = FlowBuilder::from(l1)
1383            .through(ReLU::new())
1384            .through(l2)
1385            .build()
1386            .unwrap();
1387
1388        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), true);
1389        let y = graph.forward(&x).unwrap();
1390        let loss = y.sum().unwrap();
1391        loss.backward().unwrap();
1392
1393        for p in graph.parameters() {
1394            assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
1395        }
1396        assert!(x.grad().is_some());
1397    }
1398
1399    #[test]
1400    fn test_graph_as_module() {
1401        let inner = FlowBuilder::from(Linear::on_device(3, 4, crate::tensor::test_device()).unwrap())
1402            .through(ReLU::new())
1403            .build()
1404            .unwrap();
1405
1406        let outer = FlowBuilder::from(inner)
1407            .through(Linear::on_device(4, 2, crate::tensor::test_device()).unwrap())
1408            .build()
1409            .unwrap();
1410
1411        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), false);
1412        let y = outer.forward(&x).unwrap();
1413        assert_eq!(y.shape(), vec![1, 2]);
1414        assert_eq!(outer.parameters().len(), 4);
1415    }
1416
1417    #[test]
1418    fn test_training_loop() {
1419        let graph = FlowBuilder::from(Linear::on_device(1, 1, crate::tensor::test_device()).unwrap())
1420            .build()
1421            .unwrap();
1422
1423        let params = graph.parameters();
1424        let mut optim = SGD::new(&params, 0.01, 0.0);
1425
1426        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0, 4.0], &[4, 1]), false);
1427        let target = Variable::new(from_f32(&[3.0, 5.0, 7.0, 9.0], &[4, 1]), false);
1428
1429        let mut last_loss = f64::MAX;
1430        for _ in 0..800 {
1431            optim.zero_grad();
1432            let pred = graph.forward(&x).unwrap();
1433            let loss = mse_loss(&pred, &target).unwrap();
1434            last_loss = loss.item().unwrap();
1435            loss.backward().unwrap();
1436            optim.step().unwrap();
1437        }
1438
1439        assert!(last_loss < 0.01, "got loss={}", last_loss);
1440    }
1441
1442    #[test]
1443    fn test_also_backward() {
1444        let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
1445            .also(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
1446            .build()
1447            .unwrap();
1448
1449        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
1450        let y = graph.forward(&x).unwrap();
1451        let loss = y.sum().unwrap();
1452        loss.backward().unwrap();
1453
1454        assert!(x.grad().is_some());
1455        for p in graph.parameters() {
1456            assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
1457        }
1458    }
1459
1460    #[test]
1461    fn test_split_merge_backward() {
1462        let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
1463            .split(vec![
1464                Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
1465                Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
1466            ])
1467            .merge(MergeOp::Add)
1468            .build()
1469            .unwrap();
1470
1471        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
1472        let y = graph.forward(&x).unwrap();
1473        let loss = y.sum().unwrap();
1474        loss.backward().unwrap();
1475
1476        assert!(x.grad().is_some());
1477        for p in graph.parameters() {
1478            assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
1479        }
1480    }
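
    // Illustrative sketch for the other merge mode: MergeOp::Mean, assumed to
    // average branch outputs element-wise (mirroring MergeOp::Add above), so
    // mean(2x, 3x) = 2.5x.
    #[test]
    fn test_split_merge_mean_forward() {
        let graph = FlowBuilder::from(Identity)
            .split(vec![Box::new(Doubler), Box::new(Tripler)])
            .merge(MergeOp::Mean)
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[2.0, 4.0], &[1, 2]), false);
        let y = graph.forward(&x).unwrap();
        let data = y.data().to_f32_vec().unwrap();
        assert!((data[0] - 5.0).abs() < 1e-5, "mean(4, 6)=5, got {}", data[0]);
        assert!((data[1] - 10.0).abs() < 1e-5, "mean(8, 12)=10, got {}", data[1]);
    }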
1481
1482    #[test]
1483    fn test_build_error_open_streams() {
1484        let result = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
1485            .split(vec![Box::new(ReLU::new()), Box::new(Sigmoid::new())])
1486            .build();
1487        assert!(result.is_err());
1488    }
1489
1490    #[test]
1491    fn test_build_error_duplicate_tag() {
1492        let result = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
1493            .tag("features")
1494            .through(ReLU::new())
1495            .tag("features")
1496            .build();
1497        assert!(result.is_err());
1498    }
1499
1500    // --- Using tests ---
1501
1502    #[test]
1503    fn test_using_backward_ref() {
1504        // Tag a point, then use it downstream
1505        // Graph: linear(x) → tag("ctx") → through(AddRef).using("ctx")
1506        // AddRef adds ctx to stream: stream + ctx = 2 * linear(x)
1507        let l = Linear::on_device(2, 2, crate::tensor::test_device()).unwrap();
1508        l.weight
1509            .variable
1510            .set_data(from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]));
1511        l.bias
1512            .as_ref()
1513            .unwrap()
1514            .variable
1515            .set_data(from_f32(&[0.0, 0.0], &[2]));
1516
1517        let graph = FlowBuilder::from(l)
1518            .tag("ctx")
1519            .through(AddRefModule)
1520            .using(&["ctx"])
1521            .build()
1522            .unwrap();
1523
1524        let x = Variable::new(from_f32(&[3.0, 5.0], &[1, 2]), false);
1525        let y = graph.forward(&x).unwrap();
1526        let data = y.data().to_f32_vec().unwrap();
1527
1528        // identity(x) = [3, 5], then AddRef adds ctx ([3, 5]) = [6, 10]
1529        assert!((data[0] - 6.0).abs() < 1e-5);
1530        assert!((data[1] - 10.0).abs() < 1e-5);
1531    }
1532
1533    #[test]
1534    fn test_using_backward_gradients() {
1535        let l = Linear::on_device(2, 2, crate::tensor::test_device()).unwrap();
1536        let graph = FlowBuilder::from(l)
1537            .tag("ctx")
1538            .through(AddRefModule)
1539            .using(&["ctx"])
1540            .build()
1541            .unwrap();
1542
1543        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
1544        let y = graph.forward(&x).unwrap();
1545        let loss = y.sum().unwrap();
1546        loss.backward().unwrap();
1547
1548        assert!(x.grad().is_some());
1549        for p in graph.parameters() {
1550            assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
1551        }
1552    }
1553
1554    #[test]
1555    fn test_using_error_plain_module() {
1556        // Calling .using() on a plain module (not a NamedInputModule) should error at build
1557        let result = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
1558            .tag("ctx")
1559            .through(ReLU::new())
1560            .using(&["ctx"])
1561            .build();
1562        assert!(result.is_err());
1563    }
1564
1565    #[test]
1566    fn test_using_error_unknown_tag() {
1567        let result = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
1568            .through(AddRefModule)
1569            .using(&["nonexistent"])
1570            .build();
1571        assert!(result.is_err());
1572    }
1573
1574    // --- Loop tests ---
1575
1576    #[test]
1577    fn test_loop_for() {
1578        // Doubler × 3 iterations: [1, 2] → [8, 16]
1579        let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
1580            .loop_body(Doubler)
1581            .for_n(3)
1582            .build()
1583            .unwrap();
1584
1585        // Set linear to identity
1586        let params = graph.parameters();
1587        params[0].variable.set_data(from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]));
1588        params[1].variable.set_data(from_f32(&[0.0, 0.0], &[2]));
1589
1590        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
1591        let y = graph.forward(&x).unwrap();
1592        let data = y.data().to_f32_vec().unwrap();
1593
1594        assert!((data[0] - 8.0).abs() < 1e-5, "1*2^3=8, got {}", data[0]);
1595        assert!((data[1] - 16.0).abs() < 1e-5, "2*2^3=16, got {}", data[1]);
1596    }
1597
1598    #[test]
1599    fn test_loop_for_backward() {
1600        // Loop with a learnable bias — gradient should accumulate across iterations
1601        let bias_step = BiasStep::new(2).unwrap();
1602        let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
1603            .loop_body(bias_step)
1604            .for_n(3)
1605            .build()
1606            .unwrap();
1607
1608        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
1609        let y = graph.forward(&x).unwrap();
1610        let loss = y.sum().unwrap();
1611        loss.backward().unwrap();
1612
1613        // All parameters should have gradients
1614        for p in graph.parameters() {
1615            assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
1616        }
1617
1618        // The bias gradient should be 3 (accumulated from 3 iterations)
1619        // dL/db = 1 per iteration, 3 iterations → grad = [3, 3]
1620        // (because sum reduces to scalar, dL/d_each_element = 1, and bias contributes at each step)
1621        let all_params = graph.parameters();
1622        // Find the loop_bias parameter (from BiasStep, not Linear's "bias")
1623        let bias_param = all_params.iter().find(|p| p.name == "loop_bias").unwrap();
1624        let grad = bias_param.variable.grad().unwrap().to_f32_vec().unwrap();
1625        assert!(
1626            (grad[0] - 3.0).abs() < 1e-5,
1627            "bias grad should be 3, got {}",
1628            grad[0]
1629        );
1630    }
1631
1632    #[test]
1633    fn test_loop_while() {
1634        // While max < 10: double. Input [1, 2] → double until max >= 10
1635        // Iter 0: check [1,2] max=2 < 10 → double → [2,4]
1636        // Iter 1: check [2,4] max=4 < 10 → double → [4,8]
1637        // Iter 2: check [4,8] max=8 < 10 → double → [8,16]
1638        // Iter 3: check [8,16] max=16 >= 10 → halt
1639        // Result: [8, 16]
1640        let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
1641            .loop_body(Doubler)
1642            .while_cond(ThresholdHalt::new(10.0), 20)
1643            .build()
1644            .unwrap();
1645
1646        let params = graph.parameters();
1647        params[0].variable.set_data(from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]));
1648        params[1].variable.set_data(from_f32(&[0.0, 0.0], &[2]));
1649
1650        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
1651        let y = graph.forward(&x).unwrap();
1652        let data = y.data().to_f32_vec().unwrap();
1653
1654        assert!((data[0] - 8.0).abs() < 1e-5, "got {}", data[0]);
1655        assert!((data[1] - 16.0).abs() < 1e-5, "got {}", data[1]);
1656    }
1657
1658    #[test]
1659    fn test_loop_while_immediate_halt() {
1660        // Threshold 0.5 — input [1, 2] max=2 > 0.5, halt immediately
1661        // While checks before body, so body never runs
1662        let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
1663            .loop_body(Doubler)
1664            .while_cond(ThresholdHalt::new(0.5), 20)
1665            .build()
1666            .unwrap();
1667
1668        let params = graph.parameters();
1669        params[0].variable.set_data(from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]));
1670        params[1].variable.set_data(from_f32(&[0.0, 0.0], &[2]));
1671
1672        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
1673        let y = graph.forward(&x).unwrap();
1674        let data = y.data().to_f32_vec().unwrap();
1675
1676        // Body never ran — output = input
1677        assert!((data[0] - 1.0).abs() < 1e-5);
1678        assert!((data[1] - 2.0).abs() < 1e-5);
1679    }
1680
1681    #[test]
1682    fn test_loop_until() {
1683        // Until max > 10: double. Body runs at least once.
1684        // Input [1, 2]
1685        // Iter 0: double → [2, 4], check max=4 <= 10 → continue
1686        // Iter 1: double → [4, 8], check max=8 <= 10 → continue
1687        // Iter 2: double → [8, 16], check max=16 > 10 → halt
1688        // Result: [8, 16]
1689        let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
1690            .loop_body(Doubler)
1691            .until_cond(ThresholdHalt::new(10.0), 20)
1692            .build()
1693            .unwrap();
1694
1695        let params = graph.parameters();
1696        params[0].variable.set_data(from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]));
1697        params[1].variable.set_data(from_f32(&[0.0, 0.0], &[2]));
1698
1699        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
1700        let y = graph.forward(&x).unwrap();
1701        let data = y.data().to_f32_vec().unwrap();
1702
1703        assert!((data[0] - 8.0).abs() < 1e-5, "got {}", data[0]);
1704        assert!((data[1] - 16.0).abs() < 1e-5, "got {}", data[1]);
1705    }
1706
1707    #[test]
1708    fn test_loop_until_at_least_once() {
1709        // Until with threshold 0.5 — input [1, 2] would halt immediately in While,
1710        // but Until always runs body at least once
1711        let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
1712            .loop_body(Doubler)
1713            .until_cond(ThresholdHalt::new(0.5), 20)
1714            .build()
1715            .unwrap();
1716
1717        let params = graph.parameters();
1718        params[0].variable.set_data(from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]));
1719        params[1].variable.set_data(from_f32(&[0.0, 0.0], &[2]));
1720
1721        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
1722        let y = graph.forward(&x).unwrap();
1723        let data = y.data().to_f32_vec().unwrap();
1724
1725        // Body ran once: [2, 4]
1726        assert!((data[0] - 2.0).abs() < 1e-5, "got {}", data[0]);
1727        assert!((data[1] - 4.0).abs() < 1e-5, "got {}", data[1]);
1728    }
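
    // Taken together: while_cond checks before each body run (zero iterations
    // are possible), until_cond checks after it (the body runs at least once).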
1729
1730    #[test]
1731    fn test_loop_parameters() {
1732        // Loop with learnable body — parameters should include body params
1733        let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
1734            .loop_body(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
1735            .for_n(3)
1736            .build()
1737            .unwrap();
1738
1739        let params = graph.parameters();
1740        // From module: weight + bias = 2, loop body Linear: weight + bias = 2
1741        assert_eq!(params.len(), 4);
1742    }
1743
1744    #[test]
1745    fn test_loop_while_parameters() {
1746        // While loop with body + condition — both contribute parameters
1747        let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
1748            .loop_body(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
1749            .while_cond(Linear::on_device(2, 1, crate::tensor::test_device()).unwrap(), 10)
1750            .build()
1751            .unwrap();
1752
1753        let params = graph.parameters();
1754        // From module: 2, loop body: 2, condition: 2 = 6
1755        assert_eq!(params.len(), 6);
1756    }
1757
1758    #[test]
1759    fn test_loop_in_chain() {
1760        // Linear → Loop(ReLU) × 3 → Linear
1761        let graph = FlowBuilder::from(Linear::on_device(3, 4, crate::tensor::test_device()).unwrap())
1762            .loop_body(ReLU::new())
1763            .for_n(3)
1764            .through(Linear::on_device(4, 2, crate::tensor::test_device()).unwrap())
1765            .build()
1766            .unwrap();
1767
1768        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), false);
1769        let y = graph.forward(&x).unwrap();
1770        assert_eq!(y.shape(), vec![1, 2]);
1771    }
1772
1773    #[test]
1774    fn test_loop_using_backward_ref() {
1775        // Tag a tensor, then use it inside a loop body via .using()
1776        // Graph: identity → tag("ctx") → loop_body(AddRefModule).for_n(3).using("ctx")
1777        // Each iteration: state = state + ctx
1778        // So after 3 iterations: state = x + 3*x = 4*x
1779        let graph = FlowBuilder::from(Identity)
1780            .tag("ctx")
1781            .loop_body(AddRefModule)
1782            .for_n(3)
1783            .using(&["ctx"])
1784            .build()
1785            .unwrap();
1786
1787        let x = Variable::new(from_f32(&[2.0, 3.0], &[1, 2]), false);
1788        let y = graph.forward(&x).unwrap();
1789        let data = y.data().to_f32_vec().unwrap();
1790
1791        // x = [2, 3], after 3 iterations of (state + ctx): [8, 12]
1792        assert!((data[0] - 8.0).abs() < 1e-5, "got {}", data[0]);
1793        assert!((data[1] - 12.0).abs() < 1e-5, "got {}", data[1]);
1794    }
1795
1796    #[test]
1797    fn test_loop_using_backward_gradients() {
1798        // Ensure gradients flow through loop+using
1799        let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
1800            .tag("ctx")
1801            .loop_body(AddRefModule)
1802            .for_n(2)
1803            .using(&["ctx"])
1804            .build()
1805            .unwrap();
1806
1807        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
1808        let y = graph.forward(&x).unwrap();
1809        let loss = y.sum().unwrap();
1810        loss.backward().unwrap();
1811
1812        assert!(x.grad().is_some(), "input should have gradient");
1813        for p in graph.parameters() {
1814            assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
1815        }
1816    }
1817
1818    // --- Forward reference tests ---
1819
1820    /// Nil-safe add: adds the "memory" ref to the input when present and passes the input through otherwise. For forward-ref state accumulation.
1821    struct NilSafeAdd;
1822    impl Module for NilSafeAdd {
1823        fn forward(&self, input: &Variable) -> Result<Variable> {
1824            Ok(input.clone())
1825        }
1826        fn as_named_input(&self) -> Option<&dyn NamedInputModule> { Some(self) }
1827    }
1828    impl NamedInputModule for NilSafeAdd {
1829        fn forward_named(
1830            &self,
1831            input: &Variable,
1832            refs: &HashMap<String, Variable>,
1833        ) -> Result<Variable> {
1834            if let Some(memory) = refs.get("memory") {
1835                input.add(memory)
1836            } else {
1837                Ok(input.clone())
1838            }
1839        }
1840    }
1841
1842    use crate::nn::Identity;
1843
1844    #[test]
1845    fn test_flowbuilder_new() {
1846        // FlowBuilder::new() starts with implicit Identity
1847        let graph = FlowBuilder::new()
1848            .tag("input")
1849            .through(Linear::on_device(3, 2, crate::tensor::test_device()).unwrap())
1850            .build()
1851            .unwrap();
1852
1853        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), false);
1854        let y = graph.forward(&x).unwrap();
1855        assert_eq!(y.shape(), vec![1, 2]);
1856    }
1857
1858    #[test]
1859    fn test_forward_ref() {
1860        // Forward reference: using() before tag(). State carries between forward() calls.
1861        // Graph: entry → NilSafeAdd via .using("memory") → Identity via .tag("memory")
1862        // Pass 1: memory is unset, so the add passes the stream through → Identity → state captured
1863        // Pass 2: add gets [stream, prev_output] → sum → Identity → state captured
1864        let graph = FlowBuilder::from(Identity)
1865            .through(NilSafeAdd)
1866            .using(&["memory"])
1867            .through(Identity)
1868            .tag("memory")
1869            .build()
1870            .unwrap();
1871
1872        assert!(graph.has_state());
1873
1874        // Pass 1: [1,2] + zeros → [1,2]
1875        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
1876        let y1 = graph.forward(&x).unwrap();
1877        let d1 = y1.data().to_f32_vec().unwrap();
1878        assert!((d1[0] - 1.0).abs() < 1e-5, "pass1[0]: got {}", d1[0]);
1879        assert!((d1[1] - 2.0).abs() < 1e-5, "pass1[1]: got {}", d1[1]);
1880
1881        // Pass 2: [1,2] + [1,2] → [2,4]
1882        let y2 = graph.forward(&x).unwrap();
1883        let d2 = y2.data().to_f32_vec().unwrap();
1884        assert!((d2[0] - 2.0).abs() < 1e-5, "pass2[0]: got {}", d2[0]);
1885        assert!((d2[1] - 4.0).abs() < 1e-5, "pass2[1]: got {}", d2[1]);
1886
1887        // Pass 3: [1,2] + [2,4] → [3,6]
1888        let y3 = graph.forward(&x).unwrap();
1889        let d3 = y3.data().to_f32_vec().unwrap();
1890        assert!((d3[0] - 3.0).abs() < 1e-5, "pass3[0]: got {}", d3[0]);
1891        assert!((d3[1] - 6.0).abs() < 1e-5, "pass3[1]: got {}", d3[1]);
1892    }
1893
1894    #[test]
1895    fn test_forward_ref_reset_state() {
1896        let graph = FlowBuilder::from(Identity)
1897            .through(NilSafeAdd)
1898            .using(&["memory"])
1899            .through(Identity)
1900            .tag("memory")
1901            .build()
1902            .unwrap();
1903
1904        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
1905
1906        // Build up state
1907        graph.forward(&x).unwrap();
1908        graph.forward(&x).unwrap();
1909        let y_before = graph.forward(&x).unwrap();
1910        let d_before = y_before.data().to_f32_vec().unwrap();
1911        assert!((d_before[0] - 3.0).abs() < 1e-5);
1912
1913        // Reset and verify state is cleared
1914        graph.reset_state();
1915        let y_after = graph.forward(&x).unwrap();
1916        let d_after = y_after.data().to_f32_vec().unwrap();
1917        assert!((d_after[0] - 1.0).abs() < 1e-5, "after reset: got {}", d_after[0]);
1918    }
1919
1920    #[test]
1921    fn test_forward_ref_detach_state() {
1922        let graph = FlowBuilder::from(Identity)
1923            .through(NilSafeAdd)
1924            .using(&["memory"])
1925            .through(Identity)
1926            .tag("memory")
1927            .build()
1928            .unwrap();
1929
1930        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
1931
1932        // Run forward, accumulate state
1933        let y1 = graph.forward(&x).unwrap();
1934        let _ = y1.sum().unwrap();
1935
1936        // Detach state — values preserved but gradient chain broken
1937        graph.detach_state();
1938
1939        // State should still have values (not reset)
1940        let y2 = graph.forward(&x).unwrap();
1941        let d2 = y2.data().to_f32_vec().unwrap();
1942        assert!((d2[0] - 2.0).abs() < 1e-5, "detach preserves values: got {}", d2[0]);
1943    }
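
    // Hedged sketch of the intended detach_state() pattern: step-wise training
    // where each backward() covers only the current step's graph. Assumes, as
    // the tests above do, that forward/backward can repeat across calls and
    // that detached state keeps its value while dropping gradient history.
    #[test]
    fn test_forward_ref_detach_between_steps() {
        let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
            .through(NilSafeAdd)
            .using(&["memory"])
            .through(Identity)
            .tag("memory")
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
        for _ in 0..3 {
            let y = graph.forward(&x).unwrap();
            let loss = y.sum().unwrap();
            loss.backward().unwrap();
            graph.detach_state(); // next step starts from a value-only state
        }

        for p in graph.parameters() {
            assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
        }
    }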
1944
1945    #[test]
1946    fn test_forward_ref_backward() {
1947        // Gradients should flow through forward-ref connections
1948        let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
1949            .through(NilSafeAdd)
1950            .using(&["memory"])
1951            .through(Identity)
1952            .tag("memory")
1953            .build()
1954            .unwrap();
1955
1956        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
1957        let y = graph.forward(&x).unwrap();
1958        let loss = y.sum().unwrap();
1959        loss.backward().unwrap();
1960
1961        assert!(x.grad().is_some(), "input should have gradient");
1962        for p in graph.parameters() {
1963            assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
1964        }
1965    }
1966
1967    #[test]
1968    fn test_forward_ref_unresolved_error() {
1969        // Using a tag that is never defined should error at build
1970        let result = FlowBuilder::from(Identity)
1971            .through(NilSafeAdd)
1972            .using(&["nonexistent"])
1973            .build();
1974        assert!(result.is_err());
1975    }
1976
1977    #[test]
1978    fn test_forward_ref_mixed_refs() {
1979        // Mix backward ref (tag before using) and forward ref (using before tag)
1980        // "ctx" is backward (AddRefModule expects "ctx"), "memory" is forward (NilSafeAdd expects "memory")
1981        let graph = FlowBuilder::from(Identity)
1982            .tag("ctx")
1983            .through(AddRefModule)
1984            .using(&["ctx"])
1985            .through(NilSafeAdd)
1986            .using(&["memory"])
1987            .through(Identity)
1988            .tag("memory")
1989            .build()
1990            .unwrap();
1991
1992        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
1993
1994        // Pass 1: entry=[1,2], AddRef adds ctx=[1,2] → [2,4], NilSafeAdd +zeros → [2,4]
1995        let y1 = graph.forward(&x).unwrap();
1996        let d1 = y1.data().to_f32_vec().unwrap();
1997        assert!((d1[0] - 2.0).abs() < 1e-5, "mixed pass1[0]: got {}", d1[0]);
1998
1999        // Pass 2: entry=[1,2], AddRef adds ctx=[1,2] → [2,4], NilSafeAdd +[2,4] → [4,8]
2000        let y2 = graph.forward(&x).unwrap();
2001        let d2 = y2.data().to_f32_vec().unwrap();
2002        assert!((d2[0] - 4.0).abs() < 1e-5, "mixed pass2[0]: got {}", d2[0]);
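
        // In general pass k yields 2k·x: the backward ref adds x on every
        // pass, while the forward ref carries in the previous pass's output.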
2003    }
2004
2005    // --- Switch tests ---
2006
2007    /// Triples input.
2008    struct Tripler;
2009    impl Module for Tripler {
2010        fn forward(&self, input: &Variable) -> Result<Variable> {
2011            input.add(&input.add(input)?)
2012        }
2013        fn parameters(&self) -> Vec<Parameter> { vec![] }
2014    }
2015
2016    #[test]
2017    fn test_switch_selects_branch() {
2018        // Branch 0: double, Branch 1: triple. Router selects branch 1.
2019        let graph = FlowBuilder::from(Identity)
2020            .switch(FixedSelector::new(1), vec![Box::new(Doubler), Box::new(Tripler)])
2021            .build()
2022            .unwrap();
2023
2024        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
2025        let y = graph.forward(&x).unwrap();
2026        let data = y.data().to_f32_vec().unwrap();
2027        assert!((data[0] - 3.0).abs() < 1e-5, "triple [1]=3, got {}", data[0]);
2028        assert!((data[1] - 6.0).abs() < 1e-5, "triple [2]=6, got {}", data[1]);
2029    }
2030
2031    #[test]
2032    fn test_switch_branch0() {
2033        let graph = FlowBuilder::from(Identity)
2034            .switch(FixedSelector::new(0), vec![Box::new(Doubler), Box::new(Tripler)])
2035            .build()
2036            .unwrap();
2037
2038        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
2039        let y = graph.forward(&x).unwrap();
2040        let data = y.data().to_f32_vec().unwrap();
2041        assert!((data[0] - 2.0).abs() < 1e-5, "double [1]=2, got {}", data[0]);
2042        assert!((data[1] - 4.0).abs() < 1e-5, "double [2]=4, got {}", data[1]);
2043    }
2044
2045    #[test]
2046    fn test_switch_backward() {
2047        let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
2048            .switch(FixedSelector::new(0), vec![
2049                Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
2050                Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
2051            ])
2052            .build()
2053            .unwrap();
2054
2055        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
2056        let y = graph.forward(&x).unwrap();
2057        let loss = y.sum().unwrap();
2058        loss.backward().unwrap();
2059
2060        assert!(x.grad().is_some());
2061        // Only entry + selected branch params should have gradients
2062        // (router has no params, unselected branch wasn't executed)
2063    }
2064
2065    #[test]
2066    fn test_switch_parameters() {
2067        let graph = FlowBuilder::from(Identity)
2068            .switch(
2069                Linear::on_device(2, 1, crate::tensor::test_device()).unwrap(),
2070                vec![
2071                    Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
2072                    Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
2073                ],
2074            )
2075            .build()
2076            .unwrap();
2077
2078        let params = graph.parameters();
2079        // Router: 2, Branch0: 2, Branch1: 2 = 6
2080        assert_eq!(params.len(), 6);
2081    }
2082
2083    // --- Gate tests ---
2084
2085    /// Router that outputs equal weights for all experts.
2086    struct EqualRouter(usize);
2087    impl Module for EqualRouter {
2088        fn forward(&self, input: &Variable) -> Result<Variable> {
2089            let batch = input.shape()[0];
2090            let w = 1.0 / self.0 as f32;
2091            let data = vec![w; batch as usize * self.0];
2092            Ok(Variable::new(
2093                Tensor::from_f32(&data, &[batch, self.0 as i64], crate::tensor::test_device())?,
2094                false,
2095            ))
2096        }
2097        fn parameters(&self) -> Vec<Parameter> { vec![] }
2098    }
2099
2100    #[test]
2101    fn test_gate_equal_weights() {
2102        // Equal weights: output = mean of expert outputs
2103        let graph = FlowBuilder::from(Identity)
2104            .gate(EqualRouter(2), vec![Box::new(Doubler), Box::new(Tripler)])
2105            .build()
2106            .unwrap();
2107
2108        let x = Variable::new(from_f32(&[2.0, 4.0], &[1, 2]), false);
2109        let y = graph.forward(&x).unwrap();
2110        let data = y.data().to_f32_vec().unwrap();
2111        // double=[4,8], triple=[6,12], mean = [5, 10]
2112        assert!((data[0] - 5.0).abs() < 1e-5, "gate[0]=5, got {}", data[0]);
2113        assert!((data[1] - 10.0).abs() < 1e-5, "gate[1]=10, got {}", data[1]);
2114    }
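
    // Sketch with three experts: equal weights over (2x, 3x, x) average to 2x.
    // Assumes the gate weights each expert's output by the router's column and
    // sums, as the two-expert test above demonstrates.
    #[test]
    fn test_gate_three_experts() {
        let graph = FlowBuilder::from(Identity)
            .gate(
                EqualRouter(3),
                vec![Box::new(Doubler), Box::new(Tripler), Box::new(Identity)],
            )
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[3.0, 6.0], &[1, 2]), false);
        let y = graph.forward(&x).unwrap();
        let data = y.data().to_f32_vec().unwrap();
        // (6 + 9 + 3) / 3 = 6 and (12 + 18 + 6) / 3 = 12
        assert!((data[0] - 6.0).abs() < 1e-5, "got {}", data[0]);
        assert!((data[1] - 12.0).abs() < 1e-5, "got {}", data[1]);
    }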
2115
2116    #[test]
2117    fn test_gate_backward() {
2118        let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
2119            .gate(
2120                Linear::on_device(2, 2, crate::tensor::test_device()).unwrap(),
2121                vec![
2122                    Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
2123                    Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
2124                ],
2125            )
2126            .build()
2127            .unwrap();
2128
2129        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
2130        let y = graph.forward(&x).unwrap();
2131        let loss = y.sum().unwrap();
2132        loss.backward().unwrap();
2133
2134        assert!(x.grad().is_some());
2135        for p in graph.parameters() {
2136            assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
2137        }
2138    }
2139
2140    #[test]
2141    fn test_gate_parameters() {
2142        let graph = FlowBuilder::from(Identity)
2143            .gate(
2144                Linear::on_device(2, 2, crate::tensor::test_device()).unwrap(),
2145                vec![
2146                    Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
2147                    Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
2148                ],
2149            )
2150            .build()
2151            .unwrap();
2152
2153        let params = graph.parameters();
2154        // Router: 2, Expert0: 2, Expert1: 2 = 6
2155        assert_eq!(params.len(), 6);
2156    }
2157
2158    // --- Map tests ---
2159
2160    #[test]
2161    fn test_map_each() {
2162        // Map doubler over 3 elements along dim 0
2163        let graph = FlowBuilder::from(Identity)
2164            .map(Doubler)
2165            .each()
2166            .build()
2167            .unwrap();
2168
2169        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]), false);
2170        let y = graph.forward(&x).unwrap();
2171        let data = y.data().to_f32_vec().unwrap();
2172
2173        assert_eq!(y.shape(), vec![3, 2]);
2174        assert!((data[0] - 2.0).abs() < 1e-5);
2175        assert!((data[5] - 12.0).abs() < 1e-5);
2176    }
2177
2178    #[test]
2179    fn test_map_batched() {
2180        // Batched: the body receives the full tensor in a single call instead of per-element slices
2181        let graph = FlowBuilder::from(Identity)
2182            .map(Doubler)
2183            .batched()
2184            .each()
2185            .build()
2186            .unwrap();
2187
2188        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2]), false);
2189        let y = graph.forward(&x).unwrap();
2190        let data = y.data().to_f32_vec().unwrap();
2191
2192        assert_eq!(data, vec![2.0, 4.0, 6.0, 8.0]);
2193    }
2194
2195    #[test]
2196    fn test_map_backward() {
2197        let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
2198            .map(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
2199            .each()
2200            .build()
2201            .unwrap();
2202
2203        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2]), true);
2204        let y = graph.forward(&x).unwrap();
2205        let loss = y.sum().unwrap();
2206        loss.backward().unwrap();
2207
2208        assert!(x.grad().is_some());
2209        for p in graph.parameters() {
2210            assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
2211        }
2212    }
2213
2214    // --- Observation tests ---
2215
2216    /// Scalar output module: sums all elements into a single value.
2217    struct ScalarSum;
2218    impl Module for ScalarSum {
2219        fn forward(&self, input: &Variable) -> Result<Variable> {
2220            input.sum()
2221        }
2222    }
2223
2224    #[test]
2225    fn test_tagged_capture() {
2226        // Tag intermediate output and retrieve it after forward
2227        let graph = FlowBuilder::from(Identity)
2228            .tag("features")
2229            .through(Doubler)
2230            .build()
2231            .unwrap();
2232
2233        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
2234        let _ = graph.forward(&x).unwrap();
2235
2236        // Tagged value should be the identity output (before doubling)
2237        let features = graph.tagged("features").unwrap();
2238        let data = features.data().to_f32_vec().unwrap();
2239        assert!((data[0] - 1.0).abs() < 1e-5);
2240        assert!((data[1] - 2.0).abs() < 1e-5);
2241
2242        assert!(graph.tagged("nonexistent").is_none());
2243    }
2244
2245    #[test]
2246    fn test_tagged_updates_each_forward() {
2247        let graph = FlowBuilder::from(Doubler)
2248            .tag("doubled")
2249            .build()
2250            .unwrap();
2251
2252        let x1 = Variable::new(from_f32(&[1.0], &[1, 1]), false);
2253        let _ = graph.forward(&x1).unwrap();
2254        let v1 = graph.tagged("doubled").unwrap().item().unwrap();
2255        assert!((v1 - 2.0).abs() < 1e-5);
2256
2257        let x2 = Variable::new(from_f32(&[5.0], &[1, 1]), false);
2258        let _ = graph.forward(&x2).unwrap();
2259        let v2 = graph.tagged("doubled").unwrap().item().unwrap();
2260        assert!((v2 - 10.0).abs() < 1e-5);
2261    }
2262
2263    #[test]
2264    fn test_tag_names() {
2265        let graph = FlowBuilder::from(Identity)
2266            .tag("a")
2267            .through(Identity)
2268            .tag("b")
2269            .build()
2270            .unwrap();
2271
2272        let mut names = graph.tag_names();
2273        names.sort();
2274        assert_eq!(names, vec!["a", "b"]);
2275    }
2276
2277    #[test]
2278    fn test_collect_flush_trend() {
2279        // Simulate a training loop with collect → flush → trend
2280        let graph = FlowBuilder::from(ScalarSum)
2281            .tag("loss")
2282            .build()
2283            .unwrap();
2284
2285        // Epoch 1: 3 batches with different inputs
2286        for val in &[1.0f32, 2.0, 3.0] {
2287            let x = Variable::new(from_f32(&[*val], &[1, 1]), false);
2288            let _ = graph.forward(&x).unwrap();
2289            graph.collect(&["loss"]).unwrap();
2290        }
2291        // batch buffer should have [1, 2, 3]
2292        let collected = graph.collected("loss");
2293        assert_eq!(collected.len(), 3);
2294
2295        graph.flush(&["loss"]);
2296        assert_eq!(graph.flush_count(), 1);
2297
2298        // Epoch 2: 3 batches
2299        for val in &[0.5f32, 0.3, 0.2] {
2300            let x = Variable::new(from_f32(&[*val], &[1, 1]), false);
2301            let _ = graph.forward(&x).unwrap();
2302            graph.collect(&["loss"]).unwrap();
2303        }
2304        graph.flush(&["loss"]);
2305        assert_eq!(graph.flush_count(), 2);
2306
2307        // Trend should show decrease: epoch1 mean=2.0, epoch2 mean≈0.333
2308        let trend = graph.trend("loss");
2309        assert_eq!(trend.len(), 2);
2310        assert!((trend.values()[0] - 2.0).abs() < 1e-5);
2311        assert!((trend.values()[1] - (1.0 / 3.0)).abs() < 1e-5);
2312        assert!(trend.improving(0));
2313    }
2314
2315    #[test]
2316    fn test_record_external_values() {
2317        let graph = FlowBuilder::from(Identity).build().unwrap();
2318
2319        graph.record("external_loss", &[0.5, 0.4, 0.3]);
2320        graph.flush(&["external_loss"]);
2321
2322        graph.record("external_loss", &[0.1, 0.05]);
2323        graph.flush(&["external_loss"]);
2324
2325        let trend = graph.trend("external_loss");
2326        assert_eq!(trend.len(), 2);
2327        assert!((trend.values()[0] - 0.4).abs() < 1e-5); // mean(0.5, 0.4, 0.3)
2328        assert!((trend.values()[1] - 0.075).abs() < 1e-5); // mean(0.1, 0.05)
2329        assert!(trend.improving(0));
2330    }
2331
2332    #[test]
2333    fn test_flush_all() {
2334        let graph = FlowBuilder::from(Identity).build().unwrap();
2335
2336        graph.record("a", &[1.0, 2.0]);
2337        graph.record("b", &[3.0, 4.0]);
2338        graph.flush(&[]); // flush all
2339
2340        assert_eq!(graph.trend("a").len(), 1);
2341        assert_eq!(graph.trend("b").len(), 1);
2342    }
2343
2344    #[test]
2345    fn test_reset_trend() {
2346        let graph = FlowBuilder::from(Identity).build().unwrap();
2347
2348        graph.record("loss", &[1.0]);
2349        graph.flush(&[]);
2350        assert_eq!(graph.trend("loss").len(), 1);
2351
2352        graph.reset_trend(&["loss"]);
2353        assert_eq!(graph.trend("loss").len(), 0);
2354    }
2355
2356    #[test]
2357    fn test_trends_group() {
2358        let graph = FlowBuilder::from(Identity).build().unwrap();
2359
2360        // Two decreasing series
2361        for epoch in &[10.0, 8.0, 6.0, 4.0] {
2362            graph.record("a", &[*epoch]);
2363            graph.record("b", &[*epoch * 0.5]);
2364            graph.flush(&[]);
2365        }
2366
2367        let tg = graph.trends(&["a", "b"]);
2368        assert_eq!(tg.len(), 2);
2369        assert!(tg.all_improving(0));
2370    }
2371
2372    // --- TagGroup tests ---
2373
2374    #[test]
2375    fn test_tag_group() {
2376        // Split into 3 branches with tag_group, then merge
2377        let graph = FlowBuilder::from(Identity)
2378            .split(vec![
2379                Box::new(Doubler),
2380                Box::new(Tripler),
2381                Box::new(Identity),
2382            ])
2383            .tag_group("branch")
2384            .merge(MergeOp::Add)
2385            .build()
2386            .unwrap();
2387
2388        // Check group registration
2389        let members = graph.tag_group("branch").unwrap();
2390        assert_eq!(members, &["branch_0", "branch_1", "branch_2"]);
2391
2392        // Non-existent group returns None
2393        assert!(graph.tag_group("nonexistent").is_none());
2394
2395        // Tags work for observation
2396        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
2397        let _ = graph.forward(&x).unwrap();
2398
2399        let b0 = graph.tagged("branch_0").unwrap();
2400        let b0_data = b0.data().to_f32_vec().unwrap();
2401        assert!((b0_data[0] - 2.0).abs() < 1e-5, "doubler: got {}", b0_data[0]);
2402
2403        let b1 = graph.tagged("branch_1").unwrap();
2404        let b1_data = b1.data().to_f32_vec().unwrap();
2405        assert!((b1_data[0] - 3.0).abs() < 1e-5, "tripler: got {}", b1_data[0]);
2406    }
2407
2408    #[test]
2409    fn test_tag_group_observation() {
2410        // Tag group with collect/flush and trends expansion
2411        let graph = FlowBuilder::from(Identity)
2412            .split(vec![Box::new(ScalarSum), Box::new(ScalarSum)])
2413            .tag_group("head")
2414            .merge(MergeOp::Add)
2415            .build()
2416            .unwrap();
2417
2418        // Run a few epochs
2419        for epoch in &[1.0f32, 2.0, 3.0] {
2420            let x = Variable::new(from_f32(&[*epoch], &[1, 1]), false);
2421            let _ = graph.forward(&x).unwrap();
2422            graph.collect(&["head_0", "head_1"]).unwrap();
2423            graph.flush(&["head_0", "head_1"]);
2424        }
2425
2426        // Trends with group expansion
2427        let tg = graph.trends(&["head"]);
2428        assert_eq!(tg.len(), 2); // head_0 and head_1
2429    }
2430
2431    #[test]
2432    fn test_tag_group_errors() {
2433        // tag_group on single stream should error
2434        let result = FlowBuilder::from(Identity)
2435            .tag_group("bad")
2436            .build();
2437        assert!(result.is_err());
2438
2439        // Duplicate group name
2440        let result = FlowBuilder::from(Identity)
2441            .split(vec![Box::new(Doubler), Box::new(Tripler)])
2442            .tag_group("x")
2443            .merge(MergeOp::Add)
2444            .split(vec![Box::new(Doubler), Box::new(Tripler)])
2445            .tag_group("x")
2446            .merge(MergeOp::Add)
2447            .build();
2448        assert!(result.is_err());
2449    }
2450
2451    // --- Input tests ---
2452
2453    /// Module that adds all refs to input (for multi-input testing).
2454    struct SumRefs;
2455    impl Module for SumRefs {
2456        fn forward(&self, input: &Variable) -> Result<Variable> {
2457            Ok(input.clone())
2458        }
2459        fn as_named_input(&self) -> Option<&dyn NamedInputModule> { Some(self) }
2460    }
2461    impl NamedInputModule for SumRefs {
2462        fn forward_named(
2463            &self,
2464            input: &Variable,
2465            refs: &HashMap<String, Variable>,
2466        ) -> Result<Variable> {
2467            let mut result = input.clone();
2468            for v in refs.values() {
2469                result = result.add(v)?;
2470            }
2471            Ok(result)
2472        }
2473    }
2474
2475    #[test]
2476    fn test_input_auxiliary() {
2477        // Graph with auxiliary inputs: from(Identity) plus input(&["ctx"])
2478        // Downstream: through(SumRefs).using("ctx")
2479        let graph = FlowBuilder::from(Identity)
2480            .input(&["ctx"])
2481            .through(SumRefs)
2482            .using(&["ctx"])
2483            .build()
2484            .unwrap();
2485
2486        let main = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
2487        let ctx = Variable::new(from_f32(&[10.0, 20.0], &[1, 2]), false);
2488
2489        let y = graph.forward_multi(&[main, ctx]).unwrap();
2490        let data = y.data().to_f32_vec().unwrap();
2491        // SumRefs adds ctx to main: [1+10, 2+20] = [11, 22]
2492        assert!((data[0] - 11.0).abs() < 1e-5, "got {}", data[0]);
2493        assert!((data[1] - 22.0).abs() < 1e-5, "got {}", data[1]);
2494    }
2495
2496    #[test]
2497    fn test_input_multiple() {
2498        // Graph with two auxiliary inputs
2499        let graph = FlowBuilder::from(Identity)
2500            .input(&["a", "b"])
2501            .through(SumRefs)
2502            .using(&["a", "b"])
2503            .build()
2504            .unwrap();
2505
2506        let main = Variable::new(from_f32(&[1.0], &[1, 1]), false);
2507        let a = Variable::new(from_f32(&[10.0], &[1, 1]), false);
2508        let b = Variable::new(from_f32(&[100.0], &[1, 1]), false);
2509
2510        let y = graph.forward_multi(&[main, a, b]).unwrap();
2511        let data = y.data().to_f32_vec().unwrap();
2512        // 1 + 10 + 100 = 111
2513        assert!((data[0] - 111.0).abs() < 1e-5, "got {}", data[0]);
2514    }
2515
2516    #[test]
2517    fn test_input_error_count_mismatch() {
2518        let graph = FlowBuilder::from(Identity)
2519            .input(&["ctx"])
2520            .build()
2521            .unwrap();
2522
2523        // forward() with single input should fail (expects 2: main + ctx)
2524        let x = Variable::new(from_f32(&[1.0], &[1, 1]), false);
2525        assert!(graph.forward(&x).is_err());
2526    }
2527
2528    // --- Graph set_training test ---
2529
2530    #[test]
2531    fn test_graph_set_training() {
2532        use crate::nn::Dropout;
2533
2534        let graph = FlowBuilder::from(Linear::on_device(3, 3, crate::tensor::test_device()).unwrap())
2535            .through(Dropout::new(0.5))
2536            .build()
2537            .unwrap();
2538
2539        // Training mode: dropout is active
2540        let x = Variable::new(from_f32(&[1.0; 12], &[4, 3]), false);
2541        let y1 = graph.forward(&x).unwrap();
2542        assert_eq!(y1.shape(), vec![4, 3]);
2543
2544        // Set eval via graph
2545        graph.set_training(false);
2546        let y2 = graph.forward(&x).unwrap();
2547        let y3 = graph.forward(&x).unwrap();
2548        assert_eq!(y2.shape(), vec![4, 3]);
2549
2550        // In eval: dropout is identity, so repeated forward gives same output
2551        let d2 = y2.data().to_f32_vec().unwrap();
2552        let d3 = y3.data().to_f32_vec().unwrap();
2553        let same = d2.iter().zip(d3.iter()).all(|(a, b)| (a - b).abs() < 1e-6);
2554        assert!(same, "eval mode should be deterministic (no dropout)");
2555    }
2556
2557    // --- walk_modules test ---
2558
2559    #[test]
2560    fn test_walk_modules() {
2561        use crate::nn::walk_modules;
2562
2563        let l1 = Linear::on_device(2, 2, crate::tensor::test_device()).unwrap();
2564        let mut count = 0;
2565        walk_modules(&l1, &mut |_| count += 1);
2566        assert_eq!(count, 1); // leaf module, no children
2567    }
2568
2569    // --- Profiling tests ---
2570
2571    #[test]
2572    fn test_profiling_basic() {
2573        let graph = FlowBuilder::from(Linear::on_device(3, 4, crate::tensor::test_device()).unwrap())
2574            .tag("encoder")
2575            .through(ReLU::new())
2576            .through(Linear::on_device(4, 2, crate::tensor::test_device()).unwrap())
2577            .tag("decoder")
2578            .build()
2579            .unwrap();
2580
2581        // No profiling by default
2582        assert!(!graph.profiling());
2583        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), false);
2584        graph.forward(&x).unwrap();
2585        assert!(graph.profile().is_none());
2586
2587        // Enable profiling
2588        graph.enable_profiling();
2589        assert!(graph.profiling());
2590        graph.forward(&x).unwrap();
2591
2592        let p = graph.profile().unwrap();
2593        assert!(p.total.as_nanos() > 0, "total should be nonzero");
2594        assert!(!p.nodes.is_empty(), "should have node timings");
2595        assert!(!p.levels.is_empty(), "should have level timings");
2596
2597        // Tagged node timing
2598        let enc_dur = p.timing("encoder");
2599        assert!(enc_dur.as_nanos() > 0, "encoder timing should be nonzero");
2600        let dec_dur = p.timing("decoder");
2601        assert!(dec_dur.as_nanos() > 0, "decoder timing should be nonzero");
2602        assert!(p.timing("nonexistent").is_zero());
2603
2604        // Graph-level timing shortcut
2605        assert!(graph.timing("encoder").as_nanos() > 0);
2606
2607        // Display
2608        let s = p.to_string();
2609        assert!(s.contains("Forward:"));
2610        assert!(s.contains("Level"));
2611
2612        // Disable
2613        graph.disable_profiling();
2614        assert!(!graph.profiling());
2615        graph.forward(&x).unwrap();
2616        assert!(graph.profile().is_none());
2617    }
2618
2619    #[test]
2620    fn test_profiling_timing_trend() {
2621        let graph = FlowBuilder::from(ScalarSum)
2622            .tag("loss")
2623            .build()
2624            .unwrap();
2625
2626        graph.enable_profiling();
2627
2628        // Simulate 2 epochs, 3 batches each
2629        for _ in 0..2 {
2630            for val in &[1.0f32, 2.0, 3.0] {
2631                let x = Variable::new(from_f32(&[*val], &[1, 1]), false);
2632                graph.forward(&x).unwrap();
2633                graph.collect_timings(&["loss"]);
2634            }
2635            graph.flush_timings(&[]);
2636        }
2637
2638        let trend = graph.timing_trend("loss");
2639        assert_eq!(trend.len(), 2, "2 epochs flushed");
2640        assert!(trend.values()[0] > 0.0, "timing values should be positive");
2641
2642        // Reset
2643        graph.reset_timing_trend(&["loss"]);
2644        assert_eq!(graph.timing_trend("loss").len(), 0);
2645    }
2646
2647    // --- DOT tests ---
2648
2649    #[test]
2650    fn test_dot_basic() {
2651        let graph = FlowBuilder::from(Linear::on_device(3, 4, crate::tensor::test_device()).unwrap())
2652            .tag("enc")
2653            .through(ReLU::new())
2654            .through(Linear::on_device(4, 2, crate::tensor::test_device()).unwrap())
2655            .build()
2656            .unwrap();
2657
2658        let dot = graph.dot();
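        // The DOT text can be written to a file and rendered with Graphviz,
        // e.g. `dot -Tsvg graph.dot -o graph.svg`.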
2659        assert!(dot.contains("digraph G"));
2660        assert!(dot.contains("level 0"));
2661        assert!(dot.contains("#enc"));
2662        assert!(dot.contains("->"));
2663    }
2664
2665    #[test]
2666    fn test_dot_with_profile() {
2667        let graph = FlowBuilder::from(Linear::on_device(3, 4, crate::tensor::test_device()).unwrap())
2668            .tag("enc")
2669            .through(Linear::on_device(4, 2, crate::tensor::test_device()).unwrap())
2670            .build()
2671            .unwrap();
2672
2673        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), false);
2674
2675        // Without profiling: dot_with_profile falls back to the plain structural DOT
2676        let dot1 = graph.dot_with_profile();
2677        assert!(dot1.contains("digraph G"));
2678
2679        // With profiling: includes timing annotations
2680        graph.enable_profiling();
2681        graph.forward(&x).unwrap();
2682        let dot2 = graph.dot_with_profile();
2683        assert!(dot2.contains("digraph G"));
2684        assert!(dot2.contains("Forward:"));
2685    }
2686
2687    // --- Traced tests ---
2688
2689    /// A loop body that implements trace() — captures per-iteration side data.
2690    struct TracingDoubler {
2691        last_output: RefCell<Option<Variable>>,
2692    }
2693    impl TracingDoubler {
2694        fn new() -> Self {
2695            TracingDoubler {
2696                last_output: RefCell::new(None),
2697            }
2698        }
2699    }
2700    impl Module for TracingDoubler {
2701        fn forward(&self, input: &Variable) -> Result<Variable> {
2702            let out = input.add(input)?;
2703            *self.last_output.borrow_mut() = Some(out.clone());
2704            Ok(out)
2705        }
2706        fn trace(&self) -> Option<Variable> {
2707            self.last_output.borrow().clone()
2708        }
2709    }
2710
2711    #[test]
2712    fn test_loop_traces() {
2713        // Loop(TracingDoubler) × 3: [1,2] → [2,4] → [4,8] → [8,16]
2714        // traces should capture [2,4], [4,8], [8,16]
2715        let graph = FlowBuilder::from(Identity)
2716            .loop_body(TracingDoubler::new())
2717            .for_n(3)
2718            .build()
2719            .unwrap();
2720
2721        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
2722        let y = graph.forward(&x).unwrap();
2723        let data = y.data().to_f32_vec().unwrap();
2724        assert!((data[0] - 8.0).abs() < 1e-5);
2725
2726        // Get traces — should find them on the loop node
2727        let traces = graph.traces("any").unwrap();
2728        assert_eq!(traces.len(), 3, "3 iterations = 3 traces");
2729
2730        let t0 = traces[0].data().to_f32_vec().unwrap();
2731        assert!((t0[0] - 2.0).abs() < 1e-5, "iter0: [2,4], got {}", t0[0]);
2732
2733        let t1 = traces[1].data().to_f32_vec().unwrap();
2734        assert!((t1[0] - 4.0).abs() < 1e-5, "iter1: [4,8], got {}", t1[0]);
2735
2736        let t2 = traces[2].data().to_f32_vec().unwrap();
2737        assert!((t2[0] - 8.0).abs() < 1e-5, "iter2: [8,16], got {}", t2[0]);
2738    }
2739
2740    #[test]
2741    fn test_loop_traces_cleared_each_forward() {
2742        let graph = FlowBuilder::from(Identity)
2743            .loop_body(TracingDoubler::new())
2744            .for_n(2)
2745            .build()
2746            .unwrap();
2747
2748        let x = Variable::new(from_f32(&[1.0], &[1, 1]), false);
2749        graph.forward(&x).unwrap();
2750        let traces1 = graph.traces("any").unwrap();
2751        assert_eq!(traces1.len(), 2);
2752
2753        // Second forward should clear and re-populate
2754        graph.forward(&x).unwrap();
2755        let traces2 = graph.traces("any").unwrap();
2756        assert_eq!(traces2.len(), 2);
2757    }
2758
2759    #[test]
2760    fn test_loop_no_traces_without_trace_impl() {
2761        // Doubler doesn't implement trace() (returns None by default)
2762        let graph = FlowBuilder::from(Identity)
2763            .loop_body(Doubler)
2764            .for_n(3)
2765            .build()
2766            .unwrap();
2767
2768        let x = Variable::new(from_f32(&[1.0], &[1, 1]), false);
2769        graph.forward(&x).unwrap();
2770
2771        // No traces since Doubler's trace() returns None
2772        assert!(graph.traces("any").is_none());
2773    }
2774
2775    // --- Router tests ---
2776
2777    #[test]
2778    fn test_softmax_router_gate() {
2779        // SoftmaxRouter with 2 experts: double + triple, weights from learned router
2780        let graph = FlowBuilder::from(Identity)
2781            .gate(
2782                SoftmaxRouter::on_device(2, 2, crate::tensor::test_device()).unwrap(),
2783                vec![Box::new(Doubler), Box::new(Tripler)],
2784            )
2785            .build()
2786            .unwrap();
2787
2788        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
2789        let y = graph.forward(&x).unwrap();
2790        // Output should be a weighted combination — just verify it runs and has correct shape
2791        assert_eq!(y.shape(), vec![1, 2]);
2792        // Router has 2 params (weight + bias), experts have 0
2793        let params = graph.parameters();
2794        assert_eq!(params.len(), 2);
2795    }
2796
2797    #[test]
2798    fn test_softmax_router_backward() {
2799        let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
2800            .gate(
2801                SoftmaxRouter::on_device(2, 2, crate::tensor::test_device()).unwrap(),
2802                vec![
2803                    Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
2804                    Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
2805                ],
2806            )
2807            .build()
2808            .unwrap();
2809
2810        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
2811        let y = graph.forward(&x).unwrap();
2812        let loss = y.sum().unwrap();
2813        loss.backward().unwrap();
2814
2815        assert!(x.grad().is_some());
2816        for p in graph.parameters() {
2817            assert!(p.variable.grad().is_some(), "{} missing gradient", p.name);
2818        }
2819    }
2820
2821    #[test]
2822    fn test_sigmoid_router_gate() {
2823        let graph = FlowBuilder::from(Identity)
2824            .gate(
2825                SigmoidRouter::on_device(2, 2, crate::tensor::test_device()).unwrap(),
2826                vec![Box::new(Doubler), Box::new(Tripler)],
2827            )
2828            .build()
2829            .unwrap();
2830
2831        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
2832        let y = graph.forward(&x).unwrap();
2833        assert_eq!(y.shape(), vec![1, 2]);
2834    }
2835
2836    #[test]
2837    fn test_fixed_selector_switch() {
2838        // FixedSelector(1) always picks branch 1 (Tripler)
2839        let graph = FlowBuilder::from(Identity)
2840            .switch(FixedSelector::new(1), vec![Box::new(Doubler), Box::new(Tripler)])
2841            .build()
2842            .unwrap();
2843
2844        let x = Variable::new(from_f32(&[2.0, 3.0], &[1, 2]), false);
2845        let y = graph.forward(&x).unwrap();
2846        let data = y.data().to_f32_vec().unwrap();
2847        assert!((data[0] - 6.0).abs() < 1e-5, "triple 2=6, got {}", data[0]);
2848        assert!((data[1] - 9.0).abs() < 1e-5, "triple 3=9, got {}", data[1]);
2849    }
2850
    #[test]
    fn test_argmax_selector_switch() {
        let graph = FlowBuilder::from(Identity)
            .switch(
                ArgmaxSelector::on_device(2, 2, crate::tensor::test_device()).unwrap(),
                vec![Box::new(Doubler), Box::new(Tripler)],
            )
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
        let y = graph.forward(&x).unwrap();
        // Should select one branch — just verify it runs and has correct shape
        assert_eq!(y.shape(), vec![1, 2]);
        // ArgmaxSelector has params from its Linear projection
        assert_eq!(graph.parameters().len(), 2);
    }

    // --- Halt tests ---

    #[test]
    fn test_threshold_halt_while() {
        // body = Doubler, halt when max > 10
        // input [1,2] → iter1 [2,4] → iter2 [4,8] → iter3 [8,16] halt (16 > 10)
        let graph = FlowBuilder::from(Identity)
            .loop_body(Doubler)
            .while_cond(ThresholdHalt::new(10.0), 20)
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
        let y = graph.forward(&x).unwrap();
        let data = y.data().to_f32_vec().unwrap();
        // Should stop at [8, 16] (max=16 > 10)
        assert!((data[0] - 8.0).abs() < 1e-5, "expected 8, got {}", data[0]);
        assert!((data[1] - 16.0).abs() < 1e-5, "expected 16, got {}", data[1]);
    }

    #[test]
    fn test_threshold_halt_until() {
        // Until: body runs first, then check
        // input [1,2] → iter1 body [2,4] check (max=4 < 10 continue)
        //             → iter2 body [4,8] check (max=8 < 10 continue)
        //             → iter3 body [8,16] check (max=16 > 10 halt)
        let graph = FlowBuilder::from(Identity)
            .loop_body(Doubler)
            .until_cond(ThresholdHalt::new(10.0), 20)
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
        let y = graph.forward(&x).unwrap();
        let data = y.data().to_f32_vec().unwrap();
        // Should stop at [8, 16] (max=16 > 10)
        assert!((data[0] - 8.0).abs() < 1e-5, "expected 8, got {}", data[0]);
        assert!((data[1] - 16.0).abs() < 1e-5, "expected 16, got {}", data[1]);
    }

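    // `while` and `until` agree on the two tests above because the halt condition
    // is false on entry. They diverge when it already holds: `while` skips the
    // body entirely (this test), whereas `until` would still run it once.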
    #[test]
    fn test_threshold_halt_immediate() {
        // Threshold already exceeded: while should not iterate
        let graph = FlowBuilder::from(Identity)
            .loop_body(Doubler)
            .while_cond(ThresholdHalt::new(0.5), 20)
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
        let y = graph.forward(&x).unwrap();
        let data = y.data().to_f32_vec().unwrap();
        // max=2.0 > 0.5 → halt immediately, input passes through
        assert!((data[0] - 1.0).abs() < 1e-5, "expected 1, got {}", data[0]);
        assert!((data[1] - 2.0).abs() < 1e-5, "expected 2, got {}", data[1]);
    }

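    // Added sketch, not in the original suite (test name hypothetical), assuming
    // the second argument of while_cond is an iteration cap: with an unreachable
    // threshold, the loop should stop after exactly that many body runs.
    #[test]
    fn test_threshold_halt_iteration_cap() {
        let graph = FlowBuilder::from(Identity)
            .loop_body(Doubler)
            .while_cond(ThresholdHalt::new(1e9), 3)
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
        let y = graph.forward(&x).unwrap();
        let data = y.data().to_f32_vec().unwrap();
        // Three doublings: [1, 2] → [2, 4] → [4, 8] → [8, 16]
        assert!((data[0] - 8.0).abs() < 1e-5, "expected 8, got {}", data[0]);
        assert!((data[1] - 16.0).abs() < 1e-5, "expected 16, got {}", data[1]);
    }
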
    #[test]
    fn test_learned_halt_parameters() {
        let graph = FlowBuilder::from(Identity)
            .loop_body(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
            .until_cond(LearnedHalt::on_device(2, crate::tensor::test_device()).unwrap(), 5)
            .build()
            .unwrap();

        // Body Linear: 2 params, LearnedHalt Linear(2→1): 2 params = 4
        let params = graph.parameters();
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_named_parameters_unique() {
        let graph = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
            .through(ReLU::new())
            .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
            .build()
            .unwrap();

        let named = graph.named_parameters();
        // Two Linear layers: 2 params each (weight + bias) = 4
        assert_eq!(named.len(), 4);

        // All names should be unique
        let names: Vec<&str> = named.iter().map(|(n, _)| n.as_str()).collect();
        let unique: std::collections::HashSet<&str> = names.iter().copied().collect();
        assert_eq!(names.len(), unique.len(), "duplicate names: {:?}", names);
    }

    #[test]
    fn test_named_parameters_tagged_prefix() {
        let graph = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
            .tag("encoder")
            .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
            .build()
            .unwrap();

        let named = graph.named_parameters();
        // First Linear is tagged "encoder", second is untagged
        let encoder_params: Vec<&str> = named.iter()
            .filter(|(n, _)| n.starts_with("encoder/"))
            .map(|(n, _)| n.as_str())
            .collect();
        assert_eq!(encoder_params.len(), 2, "tagged node should have 2 params with 'encoder/' prefix");

        // Untagged node uses its node_id (like "linear_2")
        let untagged: Vec<&str> = named.iter()
            .filter(|(n, _)| !n.starts_with("encoder/"))
            .map(|(n, _)| n.as_str())
            .collect();
        assert_eq!(untagged.len(), 2, "untagged node should have 2 params");
        assert!(untagged[0].contains('/'), "should have prefix/name format: {}", untagged[0]);
    }

    // --- Structural hash tests ---

    #[test]
    fn test_structural_hash_deterministic() {
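        // Two independently built (hence independently initialized) graphs must
        // hash identically: the hash reflects architecture, not weight values.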
        let g1 = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
            .through(ReLU::new())
            .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
            .build()
            .unwrap();

        let g2 = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
            .through(ReLU::new())
            .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
            .build()
            .unwrap();

        assert_eq!(g1.structural_hash(), g2.structural_hash());
    }

    #[test]
    fn test_structural_hash_differs() {
        let g1 = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
            .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
            .build()
            .unwrap();

        // Different architecture: different hidden size
        let g2 = FlowBuilder::from(Linear::on_device(4, 16, crate::tensor::test_device()).unwrap())
            .through(Linear::on_device(16, 2, crate::tensor::test_device()).unwrap())
            .build()
            .unwrap();

        assert_ne!(g1.structural_hash(), g2.structural_hash());
    }

    #[test]
    fn test_short_hash_length() {
        let g = FlowBuilder::from(Linear::on_device(2, 3, crate::tensor::test_device()).unwrap())
            .build()
            .unwrap();

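        // 64 hex chars = a 256-bit digest; short_hash() is its first 8 characters.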
        assert_eq!(g.structural_hash().len(), 64);
        assert_eq!(g.short_hash().len(), 8);
        assert!(g.structural_hash().starts_with(g.short_hash()));
    }

    #[test]
    fn test_label_default_none() {
        let g = FlowBuilder::from(Linear::on_device(2, 3, crate::tensor::test_device()).unwrap())
            .build()
            .unwrap();
        assert!(g.label().is_none());
    }

    #[test]
    fn test_label_set() {
        let g = FlowBuilder::from(Linear::on_device(2, 3, crate::tensor::test_device()).unwrap())
            .label("my-model")
            .build()
            .unwrap();
        assert_eq!(g.label(), Some("my-model"));
    }

    #[test]
    fn test_label_does_not_affect_hash() {
        let g1 = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
            .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
            .build()
            .unwrap();

        let g2 = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
            .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
            .label("different-label")
            .build()
            .unwrap();

        assert_eq!(g1.structural_hash(), g2.structural_hash());
    }

    #[test]
    fn test_graph_save_load_checkpoint() {
        let g = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
            .tag("enc")
            .through(ReLU::new())
            .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
            .tag("dec")
            .build()
            .unwrap();

        let dir = std::env::temp_dir();
        let path = dir.join("test_graph_ckpt.fdl");
        let path_str = path.to_str().unwrap();

        // Save
        g.save_checkpoint(path_str).unwrap();

        // Build identical architecture, load into it
        let g2 = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
            .tag("enc")
            .through(ReLU::new())
            .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
            .tag("dec")
            .build()
            .unwrap();

        let report = g2.load_checkpoint(path_str).unwrap();
        assert_eq!(report.loaded.len(), 4); // 2 Linear × (weight + bias)
        assert!(report.skipped.is_empty());
        assert!(report.missing.is_empty());

        // Verify weights match
        for ((n1, p1), (n2, p2)) in g.named_parameters().iter().zip(g2.named_parameters().iter()) {
            assert_eq!(n1, n2);
            assert_eq!(p1.variable.data().to_f32_vec().unwrap(),
                       p2.variable.data().to_f32_vec().unwrap());
        }

        std::fs::remove_file(path_str).ok();
    }

    #[test]
    fn test_graph_checkpoint_hash_mismatch() {
        let g1 = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
            .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
            .build()
            .unwrap();

        let dir = std::env::temp_dir();
        let path = dir.join("test_graph_ckpt_mismatch.fdl");
        let path_str = path.to_str().unwrap();

        g1.save_checkpoint(path_str).unwrap();

        // Different architecture
        let g2 = FlowBuilder::from(Linear::on_device(4, 16, crate::tensor::test_device()).unwrap())
            .through(Linear::on_device(16, 2, crate::tensor::test_device()).unwrap())
            .build()
            .unwrap();

        let result = g2.load_checkpoint(path_str);
        assert!(result.is_err());
        assert!(format!("{}", result.unwrap_err()).contains("architecture mismatch"));

        std::fs::remove_file(path_str).ok();
    }

    #[test]
    fn test_graph_checkpoint_gz() {
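        // Same round-trip as the uncompressed test above; the .gz suffix
        // presumably selects gzip compression on save and load.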
        let g = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
            .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
            .build()
            .unwrap();

        let dir = std::env::temp_dir();
        let path = dir.join("test_graph_ckpt.fdl.gz");
        let path_str = path.to_str().unwrap();

        g.save_checkpoint(path_str).unwrap();

        let g2 = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
            .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
            .build()
            .unwrap();

        let report = g2.load_checkpoint(path_str).unwrap();
        assert_eq!(report.loaded.len(), 4);

        std::fs::remove_file(path_str).ok();
    }

    // --- collect_with reduction tests ---

    #[test]
    fn test_collect_with_sum_reduction() {
        // Non-scalar tagged output reduced via Sum
        let graph = FlowBuilder::from(Identity)
            .tag("features")
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), false);
        let _ = graph.forward(&x).unwrap();
        graph.collect_with(&["features"], Reduce::Sum).unwrap();

        let collected = graph.collected("features");
        assert_eq!(collected.len(), 1);
        assert!((collected[0] - 6.0).abs() < 1e-5, "sum([1,2,3]) = 6, got {}", collected[0]);
    }

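    // Added sketch, not in the original suite (test name hypothetical), assuming
    // collected() buffers one entry per collect_with call until flushed, which is
    // the behavior the flush/trend pipeline test below relies on.
    #[test]
    fn test_collect_with_accumulates_per_call() {
        let graph = FlowBuilder::from(Identity)
            .tag("out")
            .build()
            .unwrap();

        for v in [1.0_f32, 2.0] {
            let x = Variable::new(from_f32(&[v], &[1, 1]), false);
            let _ = graph.forward(&x).unwrap();
            graph.collect_with(&["out"], Reduce::Sum).unwrap();
        }

        assert_eq!(graph.collected("out").len(), 2);
    }
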
    #[test]
    fn test_collect_with_mean_reduction() {
        let graph = FlowBuilder::from(Identity)
            .tag("out")
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[2.0, 4.0, 6.0], &[1, 3]), false);
        let _ = graph.forward(&x).unwrap();
        graph.collect_with(&["out"], Reduce::Mean).unwrap();

        let collected = graph.collected("out");
        assert!((collected[0] - 4.0).abs() < 1e-5, "mean([2,4,6]) = 4, got {}", collected[0]);
    }

    #[test]
    fn test_collect_with_max_reduction() {
        let graph = FlowBuilder::from(Identity)
            .tag("out")
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[1.0, 5.0, 3.0], &[1, 3]), false);
        let _ = graph.forward(&x).unwrap();
        graph.collect_with(&["out"], Reduce::Max).unwrap();

        let collected = graph.collected("out");
        assert!((collected[0] - 5.0).abs() < 1e-5, "max([1,5,3]) = 5, got {}", collected[0]);
    }

    #[test]
    fn test_collect_with_min_reduction() {
        let graph = FlowBuilder::from(Identity)
            .tag("out")
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[-2.0, 0.0, 3.0], &[1, 3]), false);
        let _ = graph.forward(&x).unwrap();
        graph.collect_with(&["out"], Reduce::Min).unwrap();

        let collected = graph.collected("out");
        assert!((collected[0] - (-2.0)).abs() < 1e-5, "min([-2,0,3]) = -2, got {}", collected[0]);
    }

    #[test]
    fn test_collect_with_norm_reduction() {
        let graph = FlowBuilder::from(Identity)
            .tag("out")
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[3.0, 4.0], &[1, 2]), false);
        let _ = graph.forward(&x).unwrap();
        graph.collect_with(&["out"], Reduce::Norm).unwrap();

        let collected = graph.collected("out");
        // L2 norm of [3, 4] = 5
        assert!((collected[0] - 5.0).abs() < 1e-4, "norm([3,4]) = 5, got {}", collected[0]);
    }

    #[test]
    fn test_collect_rejects_non_scalar() {
        // Plain collect() should reject non-scalar outputs
        let graph = FlowBuilder::from(Identity)
            .tag("out")
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
        let _ = graph.forward(&x).unwrap();
        assert!(graph.collect(&["out"]).is_err());
    }

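    // Added sketch, not in the original suite (test name hypothetical): the
    // converse of the test above, so plain collect() should accept a tagged
    // output that is already scalar.
    #[test]
    fn test_collect_accepts_scalar() {
        let graph = FlowBuilder::from(ScalarSum)
            .tag("loss")
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[3.0, 7.0], &[1, 2]), false);
        let _ = graph.forward(&x).unwrap();
        assert!(graph.collect(&["loss"]).is_ok());
    }
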
    #[test]
    fn test_collect_with_scalar_passthrough() {
        // collect_with on already-scalar output should work without reduction
        let graph = FlowBuilder::from(ScalarSum)
            .tag("loss")
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[3.0, 7.0], &[1, 2]), false);
        let _ = graph.forward(&x).unwrap();
        graph.collect_with(&["loss"], Reduce::Max).unwrap();

        let collected = graph.collected("loss");
        // ScalarSum yields 10.0 (scalar), so it should pass through directly
        assert!((collected[0] - 10.0).abs() < 1e-5);
    }

    #[test]
    fn test_collect_with_flush_trend_pipeline() {
        // Full pipeline: non-scalar → reduce → flush → trend
        let graph = FlowBuilder::from(Identity)
            .tag("h")
            .build()
            .unwrap();

        // Epoch 1: two batches with decreasing norms
        let x1 = Variable::new(from_f32(&[3.0, 4.0], &[1, 2]), false);
        let _ = graph.forward(&x1).unwrap();
        graph.collect_with(&["h"], Reduce::Norm).unwrap();

        let x2 = Variable::new(from_f32(&[1.0, 0.0], &[1, 2]), false);
        let _ = graph.forward(&x2).unwrap();
        graph.collect_with(&["h"], Reduce::Norm).unwrap();

        graph.flush(&["h"]);

        // Epoch 2
        let x3 = Variable::new(from_f32(&[0.5, 0.5], &[1, 2]), false);
        let _ = graph.forward(&x3).unwrap();
        graph.collect_with(&["h"], Reduce::Norm).unwrap();
        graph.flush(&["h"]);

        let trend = graph.trend("h");
        assert_eq!(trend.len(), 2);
        // Epoch 1 mean: (5.0 + 1.0) / 2 = 3.0
        assert!((trend.values()[0] - 3.0).abs() < 1e-4);
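        // Epoch 2 mean: norm([0.5, 0.5]) = sqrt(0.5) ≈ 0.707, below 3.0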
        assert!(trend.improving(0)); // norms should be decreasing
    }

    // --- Map.over and Map.slices tests ---

    #[test]
    fn test_map_over_tag() {
        // Tag a tensor, then map over it from a different stream position
        let graph = FlowBuilder::from(Identity)
            .tag("features")
            .through(Doubler)        // stream is now 2x
            .map(Doubler)
            .over("features")        // map over original (1x), not current stream (2x)
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2]), false);
        let y = graph.forward(&x).unwrap();
        let data = y.data().to_f32_vec().unwrap();
        // .over("features") maps Doubler over the tagged value (the original x),
        // applying Doubler (x + x = 2x) to each slice along dim 0.
        assert_eq!(y.shape(), vec![2, 2]);
        assert!((data[0] - 2.0).abs() < 1e-5);  // 1.0 * 2
        assert!((data[1] - 4.0).abs() < 1e-5);  // 2.0 * 2
        assert!((data[2] - 6.0).abs() < 1e-5);  // 3.0 * 2
        assert!((data[3] - 8.0).abs() < 1e-5);  // 4.0 * 2
    }

    #[test]
    fn test_map_over_unknown_tag_error() {
        let result = FlowBuilder::from(Identity)
            .map(Doubler)
            .over("nonexistent")
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_map_slices() {
        // Input [2, 4], slices(2): decompose → [4, 2], map Doubler, recompose → [2, 4]
        let graph = FlowBuilder::from(Identity)
            .map(Doubler)
            .slices(2)
            .build()
            .unwrap();

        let x = Variable::new(
            from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]),
            false,
        );
        let y = graph.forward(&x).unwrap();
        let data = y.data().to_f32_vec().unwrap();

        // Each element doubled
        assert_eq!(y.shape(), vec![2, 4]);
        assert!((data[0] - 2.0).abs() < 1e-5);
        assert!((data[7] - 16.0).abs() < 1e-5);
    }

    #[test]
    fn test_map_slices_batched() {
        // Same as above but with batched fast path
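        // .batched() is an optimization only; outputs must match the unbatched path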
        let graph = FlowBuilder::from(Identity)
            .map(Doubler)
            .batched()
            .slices(2)
            .build()
            .unwrap();

        let x = Variable::new(
            from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]),
            false,
        );
        let y = graph.forward(&x).unwrap();
        let data = y.data().to_f32_vec().unwrap();

        assert_eq!(y.shape(), vec![2, 4]);
        assert!((data[0] - 2.0).abs() < 1e-5);
        assert!((data[7] - 16.0).abs() < 1e-5);
    }

    #[test]
    fn test_map_slices_gradient() {
        // Input [2, 4] → slices(2) decomposes to [4, 2] → Linear(2, 3) → [4, 3] → recompose [2, 6]
        let graph = FlowBuilder::from(Identity)
            .map(Linear::on_device(2, 3, crate::tensor::test_device()).unwrap())
            .slices(2)
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]), true);
        let y = graph.forward(&x).unwrap();
        assert_eq!(y.shape(), vec![2, 6]); // 3 * 2 slices = 6
        let loss = y.sum().unwrap();
        loss.backward().unwrap();

        assert!(x.grad().is_some());
        for p in graph.parameters() {
            assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
        }
    }

    #[test]
    fn test_map_slices_not_divisible_error() {
        let graph = FlowBuilder::from(Identity)
            .map(Doubler)
            .slices(3)
            .build()
            .unwrap();

        // [2, 4] with slices(3) — 4 not divisible by 3
        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]), false);
        assert!(graph.forward(&x).is_err());
    }
}