Skip to main content

flodl/graph/
tree.rs

1//! Graph tree: hierarchical subgraph composition with label-path addressing.
2//!
3//! When a labeled [`Graph`] is used inside a [`FlowBuilder`](super::FlowBuilder),
4//! the parent registers it as a child subgraph. Dot-separated label paths
5//! (`"encoder.scan.hidden"`) address subgraphs and tags across boundaries.
6//!
7//! All operations are build-time or explicit-query-time. The forward path is untouched.
8//!
9//! # Key methods on [`Graph`]
10//!
11//! - **Navigation**: [`tree_children()`](Graph::tree_children), [`child_graph()`](Graph::child_graph),
12//!   [`subgraph()`](Graph::subgraph), [`is_composed()`](Graph::is_composed)
13//! - **Parameters**: [`parameters_at()`](Graph::parameters_at), [`named_parameters_at()`](Graph::named_parameters_at)
14//! - **Freeze/thaw**: [`freeze()`](Graph::freeze), [`thaw()`](Graph::thaw), [`is_frozen()`](Graph::is_frozen)
15//! - **Checkpoints**: [`load_subgraph_checkpoint()`](Graph::load_subgraph_checkpoint)
16//! - **Observation**: [`tagged_at()`](Graph::tagged_at), [`collect_at()`](Graph::collect_at),
17//!   [`record_at()`](Graph::record_at), [`trend_at()`](Graph::trend_at)
18
19use std::collections::HashMap;
20use crate::autograd::Variable;
21use crate::nn::{self, Buffer, Module, Parameter};
22use crate::tensor::{Result, TensorError};
23use super::Graph;
24use super::trend::Trend;
25
/// What a label path resolves to.
///
/// Returned by [`Graph::validate_path`]. Besides equality, kinds are
/// hashable so they can be used as keys in hash maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PathKind {
    /// An entire child subgraph.
    Subgraph,
    /// A named tag within a graph.
    Tag,
}
34
35/// Internal resolution result (borrowed references, no ownership).
36#[allow(dead_code)]
37pub(crate) enum ResolvedPath<'a> {
38    /// Path resolves to an entire child subgraph.
39    Subgraph(&'a Graph),
40    /// Path resolves to a tag within a specific graph.
41    Tag { graph: &'a Graph, tag: String },
42}
43
44impl Graph {
45    // ── Path resolution ──────────────────────────────────────────────
46
47    /// Resolve a dot-separated label path to a subgraph or tag.
48    ///
49    /// **Strict dot semantics:**
50    /// - `"scan"` — local: check children first (Subgraph), then tags (Tag).
51    /// - `"letter.scan"` — child `"letter"`, then `"scan"` within it.
52    /// - `"letter.scan.location"` — child `"letter"`, child/tag `"scan"`, then `"location"`.
53    ///
54    /// Returns `Err` if any segment doesn't resolve.
55    pub(crate) fn resolve(&self, path: &str) -> Result<ResolvedPath<'_>> {
56        if path.is_empty() {
57            return Err(TensorError::new("empty label path"));
58        }
59        let segments: Vec<&str> = path.split('.').collect();
60        self.resolve_segments(&segments, path, false)
61    }
62
63    fn resolve_segments<'a>(
64        &'a self,
65        segments: &[&str],
66        full_path: &str,
67        cross_boundary: bool,
68    ) -> Result<ResolvedPath<'a>> {
69        debug_assert!(!segments.is_empty());
70        let first = segments[0];
71
72        if segments.len() == 1 {
73            // Single segment: children take priority over tags
74            if let Some(g) = self.child_graph(first) {
75                return Ok(ResolvedPath::Subgraph(g));
76            }
77            if self.tag_names.contains_key(first) {
78                // Block internal tags when accessed from outside
79                if cross_boundary && self.internal_tags.contains(first) {
80                    return Err(TensorError::new(&format!(
81                        "tag {:?} is internal and cannot be accessed from a parent graph (path: {:?})",
82                        first, full_path
83                    )));
84                }
85                return Ok(ResolvedPath::Tag { graph: self, tag: first.to_string() });
86            }
87            return Err(TensorError::new(&format!(
88                "{:?} is not a subgraph or tag of this graph (path: {:?})",
89                first, full_path
90            )));
91        }
92
93        // Multi-segment: first MUST be a child label
94        let child = self.child_graph(first).ok_or_else(|| {
95            TensorError::new(&format!(
96                "{:?} is not a subgraph of this graph (path: {:?})",
97                first, full_path
98            ))
99        })?;
100
101        // Once we cross into a child, all subsequent resolution is cross-boundary
102        child.resolve_segments(&segments[1..], full_path, true)
103    }
104
105    // ── Public navigation ────────────────────────────────────────────
106
107    /// Direct children: label -> child graph.
108    pub fn tree_children(&self) -> HashMap<&str, &Graph> {
109        self.children.iter()
110            .filter_map(|(label, &ni)| {
111                self.nodes[ni].module.as_ref()
112                    .and_then(|m| m.as_graph())
113                    .map(|g| (label.as_str(), g))
114            })
115            .collect()
116    }
117
118    /// Get a direct child graph by label (one level only).
119    pub fn child_graph(&self, label: &str) -> Option<&Graph> {
120        self.children.get(label)
121            .and_then(|&ni| self.nodes[ni].module.as_ref())
122            .and_then(|m| m.as_graph())
123    }
124
125    /// Get a subgraph at any depth via dot-path.
126    pub fn subgraph(&self, path: &str) -> Result<&Graph> {
127        match self.resolve(path)? {
128            ResolvedPath::Subgraph(g) => Ok(g),
129            ResolvedPath::Tag { .. } => Err(TensorError::new(&format!(
130                "path {:?} resolves to a tag, not a subgraph", path
131            ))),
132        }
133    }
134
135    /// Whether this graph has been composed into a parent graph.
136    pub fn is_composed(&self) -> bool {
137        self.composed.get()
138    }
139
140    /// Tags marked as internal (hidden from parent resolution).
141    pub fn internal_tags(&self) -> &std::collections::HashSet<String> {
142        &self.internal_tags
143    }
144
145    /// Validate that a path resolves, returning what it resolves to.
146    pub fn validate_path(&self, path: &str) -> Result<PathKind> {
147        match self.resolve(path)? {
148            ResolvedPath::Subgraph(_) => Ok(PathKind::Subgraph),
149            ResolvedPath::Tag { .. } => Ok(PathKind::Tag),
150        }
151    }
152
153    // ── Parameter operations ─────────────────────────────────────────
154
155    /// All parameters at a label path.
156    pub fn parameters_at(&self, path: &str) -> Result<Vec<Parameter>> {
157        match self.resolve(path)? {
158            ResolvedPath::Subgraph(g) => Ok(g.parameters()),
159            ResolvedPath::Tag { graph, ref tag } => {
160                if let Some(&(ni, _)) = graph.tag_names.get(tag.as_str()) {
161                    if let Some(ref module) = graph.nodes[ni].module {
162                        Ok(module.parameters())
163                    } else {
164                        Ok(vec![])
165                    }
166                } else {
167                    Ok(vec![])
168                }
169            }
170        }
171    }
172
173    /// Named parameters at a label path, using the target's own namespace.
174    /// For subgraphs: delegates to the child graph's `named_parameters()`.
175    /// For tags: qualifies with the tag name as prefix.
176    pub fn named_parameters_at(&self, path: &str) -> Result<Vec<(String, Parameter)>> {
177        match self.resolve(path)? {
178            ResolvedPath::Subgraph(g) => Ok(g.named_parameters()),
179            ResolvedPath::Tag { graph, ref tag } => {
180                if let Some(&(ni, _)) = graph.tag_names.get(tag.as_str()) {
181                    if let Some(ref module) = graph.nodes[ni].module {
182                        Ok(module.parameters().into_iter()
183                            .map(|p| (format!("{}/{}", tag, p.name), p))
184                            .collect())
185                    } else {
186                        Ok(vec![])
187                    }
188                } else {
189                    Ok(vec![])
190                }
191            }
192        }
193    }
194
195    /// Named buffers at a label path, using the target's own namespace.
196    pub fn named_buffers_at(&self, path: &str) -> Result<Vec<(String, Buffer)>> {
197        match self.resolve(path)? {
198            ResolvedPath::Subgraph(g) => Ok(g.named_buffers()),
199            ResolvedPath::Tag { graph, ref tag } => {
200                if let Some(&(ni, _)) = graph.tag_names.get(tag.as_str()) {
201                    if let Some(ref module) = graph.nodes[ni].module {
202                        Ok(module.buffers().into_iter()
203                            .map(|b| (format!("{}/{}", tag, b.name), b))
204                            .collect())
205                    } else {
206                        Ok(vec![])
207                    }
208                } else {
209                    Ok(vec![])
210                }
211            }
212        }
213    }
214
215    // ── Freeze / thaw ────────────────────────────────────────────────
216
217    /// Freeze all parameters at the given label path.
218    pub fn freeze(&self, path: &str) -> Result<()> {
219        for p in self.parameters_at(path)? {
220            p.freeze()?;
221        }
222        Ok(())
223    }
224
225    /// Thaw (unfreeze) all parameters at the given label path.
226    pub fn thaw(&self, path: &str) -> Result<()> {
227        for p in self.parameters_at(path)? {
228            p.unfreeze()?;
229        }
230        Ok(())
231    }
232
233    /// Check if all parameters at the path are frozen.
234    /// Returns true only if there are parameters and ALL are frozen.
235    pub fn is_frozen(&self, path: &str) -> Result<bool> {
236        let params = self.parameters_at(path)?;
237        if params.is_empty() {
238            return Ok(false);
239        }
240        Ok(params.iter().all(|p| p.is_frozen()))
241    }
242
243    // ── Training mode ────────────────────────────────────────────────
244
245    // ── Checkpoint composition ────────────────────────────────────────
246
247    /// Load a checkpoint into a specific subgraph.
248    ///
249    /// The checkpoint's structural hash is validated against the target
250    /// subgraph's hash. Named parameters/buffers are matched within the
251    /// subgraph's own namespace.
252    pub fn load_subgraph_checkpoint(&self, path: &str, file: &str) -> Result<nn::LoadReport> {
253        let target = self.subgraph(path)?;
254        let params = target.named_parameters();
255        let buffers = target.named_buffers();
256        let hash = target.structural_hash();
257        nn::load_checkpoint_file(file, &params, &buffers, Some(hash))
258    }
259
260    // ── Training mode ────────────────────────────────────────────────
261
262    /// Set training mode on a specific subgraph or tagged module.
263    pub fn set_training_at(&self, path: &str, training: bool) -> Result<()> {
264        match self.resolve(path)? {
265            ResolvedPath::Subgraph(g) => {
266                g.set_training(training);
267            }
268            ResolvedPath::Tag { graph, ref tag } => {
269                if let Some(&(ni, _)) = graph.tag_names.get(tag.as_str()) {
270                    if let Some(ref module) = graph.nodes[ni].module {
271                        crate::nn::walk_modules(module.as_ref(), &mut |m| {
272                            m.set_training(training);
273                        });
274                    }
275                }
276            }
277        }
278        Ok(())
279    }
280
281    // ── Cross-boundary observation ───────────────────────────────────
282
283    /// Get a tagged output by label path.
284    /// Returns `Err` if the path doesn't exist (null -- wiring bug).
285    /// Returns `Ok(None)` if the path exists but hasn't been computed yet (nil).
286    /// Returns `Ok(Some(v))` if the value is available.
287    pub fn tagged_at(&self, path: &str) -> Result<Option<Variable>> {
288        match self.resolve(path)? {
289            ResolvedPath::Subgraph(_) => Err(TensorError::new(&format!(
290                "path {:?} resolves to a subgraph, not a tag", path
291            ))),
292            ResolvedPath::Tag { graph, ref tag } => Ok(graph.tagged(tag)),
293        }
294    }
295
296    /// Collect metrics from label paths into observation buffers.
297    /// Each path must resolve to a tag (not a subgraph).
298    /// Metrics are stored in the target graph's batch buffer.
299    pub fn collect_at(&self, paths: &[&str]) -> Result<()> {
300        for &path in paths {
301            match self.resolve(path)? {
302                ResolvedPath::Subgraph(_) => {
303                    return Err(TensorError::new(&format!(
304                        "collect_at: {:?} resolves to a subgraph, not a tag", path
305                    )));
306                }
307                ResolvedPath::Tag { graph, ref tag } => {
308                    graph.collect(&[tag.as_str()])?;
309                }
310            }
311        }
312        Ok(())
313    }
314
315    /// Record a scalar metric at a label path.
316    /// For dotted paths, the metric is stored in the target graph's buffer
317    /// under the final segment name.
318    pub fn record_at(&self, path: &str, value: f64) -> Result<()> {
319        let segments: Vec<&str> = path.split('.').collect();
320        if segments.len() < 2 {
321            // Single segment: record into self
322            self.record_scalar(path, value);
323            return Ok(());
324        }
325        // Multi-segment: resolve parent graph, record under last segment
326        let parent_path = segments[..segments.len() - 1].join(".");
327        let tag = segments[segments.len() - 1];
328        let target = self.subgraph(&parent_path)?;
329        target.record_scalar(tag, value);
330        Ok(())
331    }
332
333    /// Get trend for a label-path metric.
334    /// For dotted paths, reads from the target graph's epoch history.
335    pub fn trend_at(&self, path: &str) -> Result<Trend> {
336        let segments: Vec<&str> = path.split('.').collect();
337        if segments.len() < 2 {
338            return Ok(self.trend(path));
339        }
340        let parent_path = segments[..segments.len() - 1].join(".");
341        let tag = segments[segments.len() - 1];
342        let target = self.subgraph(&parent_path)?;
343        Ok(target.trend(tag))
344    }
345}
346
347#[cfg(test)]
348mod tests {
349    use crate::autograd::Variable;
350    use crate::graph::FlowBuilder;
351    use crate::nn::{Linear, Module};
352    use crate::nn::ReLU;
353    use crate::tensor::{test_device, test_opts, Tensor};
354    use super::PathKind;
355
356    #[test]
357    fn test_unlabeled_graph_no_children() {
358        let dev = test_device();
359
360        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
361            .through(ReLU::new())
362            .build()
363            .unwrap();
364
365        let outer = FlowBuilder::from(inner)
366            .through(Linear::on_device(4, 2, dev).unwrap())
367            .build()
368            .unwrap();
369
370        // Unlabeled child is NOT registered
371        assert!(outer.tree_children().is_empty());
372        // But parameters are still collected (backward compat)
373        assert_eq!(outer.parameters().len(), 4); // 2 from inner Linear + 2 from outer Linear
374    }
375
376    #[test]
377    fn test_labeled_child_registered() {
378        let dev = test_device();
379
380        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
381            .through(ReLU::new())
382            .label("encoder")
383            .build()
384            .unwrap();
385
386        let outer = FlowBuilder::from(inner)
387            .through(Linear::on_device(4, 2, dev).unwrap())
388            .build()
389            .unwrap();
390
391        assert_eq!(outer.tree_children().len(), 1);
392        assert!(outer.tree_children().contains_key("encoder"));
393        assert!(outer.child_graph("encoder").is_some());
394        assert_eq!(outer.child_graph("encoder").unwrap().label(), Some("encoder"));
395    }
396
397    #[test]
398    fn test_composed_flag() {
399        let dev = test_device();
400
401        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
402            .label("child")
403            .build()
404            .unwrap();
405
406        // Standalone: not composed
407        assert!(!inner.is_composed());
408
409        let outer = FlowBuilder::from(inner)
410            .through(Linear::on_device(4, 2, dev).unwrap())
411            .build()
412            .unwrap();
413
414        // After composition: child is composed
415        let child = outer.child_graph("child").unwrap();
416        assert!(child.is_composed());
417        // Parent is not composed
418        assert!(!outer.is_composed());
419    }
420
421    #[test]
422    fn test_label_collision_error() {
423        let dev = test_device();
424
425        let a = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
426            .label("dupe")
427            .build()
428            .unwrap();
429        let b = FlowBuilder::from(Linear::on_device(4, 2, dev).unwrap())
430            .label("dupe")
431            .build()
432            .unwrap();
433
434        let result = FlowBuilder::from(a)
435            .through(b)
436            .build();
437
438        let msg = result.err().expect("should be Err").to_string();
439        assert!(msg.contains("duplicate child graph label"), "got: {}", msg);
440    }
441
442    #[test]
443    fn test_dot_in_label_error() {
444        let dev = test_device();
445
446        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
447            .label("a.b")
448            .build()
449            .unwrap();
450
451        let result = FlowBuilder::from(inner)
452            .through(Linear::on_device(4, 2, dev).unwrap())
453            .build();
454
455        let msg = result.err().expect("should be Err").to_string();
456        assert!(msg.contains("contains a dot"), "got: {}", msg);
457    }
458
459    #[test]
460    fn test_label_tag_same_node_ok() {
461        let dev = test_device();
462
463        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
464            .label("encoder")
465            .build()
466            .unwrap();
467
468        // Tag the same node as the child graph label
469        let outer = FlowBuilder::from(inner)
470            .tag("encoder")
471            .through(Linear::on_device(4, 2, dev).unwrap())
472            .build();
473
474        assert!(outer.is_ok());
475    }
476
477    #[test]
478    fn test_resolve_single_segment_child() {
479        let dev = test_device();
480
481        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
482            .label("encoder")
483            .build()
484            .unwrap();
485
486        let outer = FlowBuilder::from(inner)
487            .through(Linear::on_device(4, 2, dev).unwrap())
488            .build()
489            .unwrap();
490
491        assert_eq!(outer.validate_path("encoder").unwrap(), PathKind::Subgraph);
492    }
493
494    #[test]
495    fn test_resolve_single_segment_tag() {
496        let dev = test_device();
497
498        let outer = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
499            .tag("hidden")
500            .through(Linear::on_device(4, 2, dev).unwrap())
501            .build()
502            .unwrap();
503
504        assert_eq!(outer.validate_path("hidden").unwrap(), PathKind::Tag);
505    }
506
507    #[test]
508    fn test_resolve_multi_segment() {
509        let dev = test_device();
510
511        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
512            .tag("hidden")
513            .through(Linear::on_device(4, 2, dev).unwrap())
514            .label("encoder")
515            .build()
516            .unwrap();
517
518        let outer = FlowBuilder::from(inner)
519            .through(Linear::on_device(2, 1, dev).unwrap())
520            .build()
521            .unwrap();
522
523        assert_eq!(outer.validate_path("encoder.hidden").unwrap(), PathKind::Tag);
524    }
525
526    #[test]
527    fn test_resolve_multi_level() {
528        let dev = test_device();
529
530        let innermost = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
531            .label("read")
532            .build()
533            .unwrap();
534        let middle = FlowBuilder::from(innermost)
535            .through(Linear::on_device(4, 2, dev).unwrap())
536            .label("letter")
537            .build()
538            .unwrap();
539        let outer = FlowBuilder::from(middle)
540            .through(Linear::on_device(2, 1, dev).unwrap())
541            .build()
542            .unwrap();
543
544        assert_eq!(outer.validate_path("letter").unwrap(), PathKind::Subgraph);
545        assert_eq!(outer.validate_path("letter.read").unwrap(), PathKind::Subgraph);
546    }
547
548    #[test]
549    fn test_resolve_invalid_path_error() {
550        let dev = test_device();
551
552        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
553            .label("encoder")
554            .build()
555            .unwrap();
556
557        let outer = FlowBuilder::from(inner)
558            .through(Linear::on_device(4, 2, dev).unwrap())
559            .build()
560            .unwrap();
561
562        // Non-existent single segment
563        assert!(outer.validate_path("nonexistent").is_err());
564        // Non-existent dotted path
565        assert!(outer.validate_path("encoder.nonexistent").is_err());
566        // Dotting into non-child first segment
567        assert!(outer.validate_path("nonexistent.foo").is_err());
568    }
569
570    #[test]
571    fn test_subgraph_returns_graph() {
572        let dev = test_device();
573
574        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
575            .label("encoder")
576            .build()
577            .unwrap();
578
579        let outer = FlowBuilder::from(inner)
580            .through(Linear::on_device(4, 2, dev).unwrap())
581            .build()
582            .unwrap();
583
584        let sub = outer.subgraph("encoder").unwrap();
585        assert_eq!(sub.label(), Some("encoder"));
586        assert_eq!(sub.parameters().len(), 2); // 1 Linear: weight + bias
587    }
588
589    #[test]
590    fn test_forward_still_works_with_tree() {
591        let dev = test_device();
592        let opts = test_opts();
593
594        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
595            .through(ReLU::new())
596            .label("encoder")
597            .build()
598            .unwrap();
599
600        let outer = FlowBuilder::from(inner)
601            .through(Linear::on_device(4, 2, dev).unwrap())
602            .build()
603            .unwrap();
604
605        let x = Variable::new(
606            Tensor::randn(&[1, 3], opts).unwrap(),
607            false,
608        );
609        let y = outer.forward(&x).unwrap();
610        assert_eq!(y.shape(), vec![1, 2]);
611    }
612
613    // ── Phase B: Training control ────────────────────────────────────
614
615    #[test]
616    fn test_parameters_at_subgraph() {
617        let dev = test_device();
618        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
619            .through(Linear::on_device(4, 2, dev).unwrap())
620            .label("encoder")
621            .build()
622            .unwrap();
623
624        let outer = FlowBuilder::from(inner)
625            .through(Linear::on_device(2, 1, dev).unwrap())
626            .build()
627            .unwrap();
628
629        // Child has 2 Linear layers = 4 params (2 weight + 2 bias)
630        let params = outer.parameters_at("encoder").unwrap();
631        assert_eq!(params.len(), 4);
632        // Outer total = 4 (child) + 2 (outer Linear) = 6
633        assert_eq!(outer.parameters().len(), 6);
634    }
635
636    #[test]
637    fn test_parameters_at_tag() {
638        let dev = test_device();
639        let g = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
640            .tag("first")
641            .through(Linear::on_device(4, 2, dev).unwrap())
642            .build()
643            .unwrap();
644
645        let params = g.parameters_at("first").unwrap();
646        assert_eq!(params.len(), 2); // 1 Linear: weight + bias
647    }
648
649    #[test]
650    fn test_freeze_thaw_roundtrip() {
651        let dev = test_device();
652        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
653            .label("encoder")
654            .build()
655            .unwrap();
656
657        let outer = FlowBuilder::from(inner)
658            .through(Linear::on_device(4, 2, dev).unwrap())
659            .build()
660            .unwrap();
661
662        // Initially not frozen
663        assert!(!outer.is_frozen("encoder").unwrap());
664
665        // Freeze child
666        outer.freeze("encoder").unwrap();
667        assert!(outer.is_frozen("encoder").unwrap());
668        // All child params should have requires_grad = false
669        for p in outer.parameters_at("encoder").unwrap() {
670            assert!(p.is_frozen());
671        }
672        // Outer params still trainable
673        let outer_params = outer.parameters();
674        let outer_only: Vec<_> = outer_params.iter()
675            .filter(|p| !p.is_frozen())
676            .collect();
677        assert_eq!(outer_only.len(), 2); // outer Linear: weight + bias
678
679        // Thaw child
680        outer.thaw("encoder").unwrap();
681        assert!(!outer.is_frozen("encoder").unwrap());
682        for p in outer.parameters_at("encoder").unwrap() {
683            assert!(!p.is_frozen());
684        }
685    }
686
687    #[test]
688    fn test_freeze_deep_path() {
689        let dev = test_device();
690        let innermost = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
691            .label("read")
692            .build()
693            .unwrap();
694        let middle = FlowBuilder::from(innermost)
695            .through(Linear::on_device(4, 2, dev).unwrap())
696            .label("letter")
697            .build()
698            .unwrap();
699        let outer = FlowBuilder::from(middle)
700            .through(Linear::on_device(2, 1, dev).unwrap())
701            .build()
702            .unwrap();
703
704        // Freeze only the innermost
705        outer.freeze("letter.read").unwrap();
706        assert!(outer.is_frozen("letter.read").unwrap());
707        // "letter" overall is NOT fully frozen (it has its own Linear too)
708        assert!(!outer.is_frozen("letter").unwrap());
709    }
710
711    #[test]
712    fn test_named_parameters_at_uses_target_namespace() {
713        let dev = test_device();
714        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
715            .tag("hidden")
716            .through(Linear::on_device(4, 2, dev).unwrap())
717            .label("encoder")
718            .build()
719            .unwrap();
720
721        let outer = FlowBuilder::from(inner)
722            .through(Linear::on_device(2, 1, dev).unwrap())
723            .build()
724            .unwrap();
725
726        // Subgraph: uses child's own namespace
727        let named = outer.named_parameters_at("encoder").unwrap();
728        assert_eq!(named.len(), 4);
729        // Names should use child-local prefixes (tag "hidden" and node id)
730        assert!(named.iter().any(|(n, _)| n.starts_with("hidden/")));
731    }
732
733    #[test]
734    fn test_freeze_invalid_path_error() {
735        let dev = test_device();
736        let g = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
737            .build()
738            .unwrap();
739
740        assert!(g.freeze("nonexistent").is_err());
741        assert!(g.thaw("nonexistent").is_err());
742        assert!(g.is_frozen("nonexistent").is_err());
743        assert!(g.parameters_at("nonexistent").is_err());
744    }
745
746    #[test]
747    fn test_set_training_at() {
748        let dev = test_device();
749        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
750            .through(crate::nn::Dropout::new(0.5))
751            .label("encoder")
752            .build()
753            .unwrap();
754
755        let outer = FlowBuilder::from(inner)
756            .through(Linear::on_device(4, 2, dev).unwrap())
757            .build()
758            .unwrap();
759
760        // Set child to eval mode
761        outer.set_training_at("encoder", false).unwrap();
762        // Set child back to training mode
763        outer.set_training_at("encoder", true).unwrap();
764        // Invalid path errors
765        assert!(outer.set_training_at("nonexistent", false).is_err());
766    }
767
768    // ── Phase C: Checkpoint composition ──────────────────────────────
769
770    #[test]
771    fn test_subgraph_checkpoint_roundtrip() {
772        let dev = test_device();
773        // Build and "train" a child graph standalone
774        let child = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
775            .through(ReLU::new())
776            .through(Linear::on_device(4, 2, dev).unwrap())
777            .label("encoder")
778            .build()
779            .unwrap();
780
781        // Save child checkpoint
782        let dir = std::env::temp_dir().join("flodl_test_subgraph_ckpt");
783        std::fs::create_dir_all(&dir).unwrap();
784        let ckpt_path = dir.join("encoder.fdl");
785        child.save_checkpoint(ckpt_path.to_str().unwrap()).unwrap();
786
787        // Build parent with a fresh (randomly initialized) child of same architecture
788        let fresh_child = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
789            .through(ReLU::new())
790            .through(Linear::on_device(4, 2, dev).unwrap())
791            .label("encoder")
792            .build()
793            .unwrap();
794
795        let parent = FlowBuilder::from(fresh_child)
796            .through(Linear::on_device(2, 1, dev).unwrap())
797            .build()
798            .unwrap();
799
800        // Load child checkpoint into parent's subgraph
801        let report = parent.load_subgraph_checkpoint("encoder", ckpt_path.to_str().unwrap()).unwrap();
802        assert!(report.loaded.len() >= 4); // At least weight+bias from 2 Linears
803        assert!(report.missing.is_empty());
804
805        // Clean up
806        let _ = std::fs::remove_dir_all(&dir);
807    }
808
809    #[test]
810    fn test_subgraph_checkpoint_preserves_parent_params() {
811        let dev = test_device();
812        let child = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
813            .label("encoder")
814            .build()
815            .unwrap();
816
817        let dir = std::env::temp_dir().join("flodl_test_preserve_parent");
818        std::fs::create_dir_all(&dir).unwrap();
819        let ckpt_path = dir.join("encoder.fdl");
820        child.save_checkpoint(ckpt_path.to_str().unwrap()).unwrap();
821
822        // Build parent with fresh child
823        let fresh_child = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
824            .label("encoder")
825            .build()
826            .unwrap();
827        let parent = FlowBuilder::from(fresh_child)
828            .through(Linear::on_device(4, 2, dev).unwrap())
829            .build()
830            .unwrap();
831
832        // Snapshot parent-level param data
833        let parent_w = parent.parameters().last().unwrap().variable.data().clone();
834
835        // Load child checkpoint
836        parent.load_subgraph_checkpoint("encoder", ckpt_path.to_str().unwrap()).unwrap();
837
838        // Parent param unchanged
839        let parent_w_after = parent.parameters().last().unwrap().variable.data().clone();
840        let diff = parent_w.sub(&parent_w_after).unwrap().abs().unwrap().sum().unwrap().item().unwrap();
841        assert!(diff < 1e-10, "parent params should be unchanged, diff={}", diff);
842
843        let _ = std::fs::remove_dir_all(&dir);
844    }
845
846    // ── Phase D: Cross-boundary observation ──────────────────────────
847
848    #[test]
849    fn test_tagged_at_returns_value_after_forward() {
850        let dev = test_device();
851        let opts = test_opts();
852        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
853            .tag("hidden")
854            .through(Linear::on_device(4, 2, dev).unwrap())
855            .label("encoder")
856            .build()
857            .unwrap();
858
859        let outer = FlowBuilder::from(inner)
860            .through(Linear::on_device(2, 1, dev).unwrap())
861            .build()
862            .unwrap();
863
864        let x = Variable::new(Tensor::randn(&[1, 3], opts).unwrap(), false);
865        outer.forward(&x).unwrap();
866
867        let val = outer.tagged_at("encoder.hidden").unwrap();
868        assert!(val.is_some());
869        assert_eq!(val.unwrap().shape(), vec![1, 4]);
870    }
871
872    #[test]
873    fn test_tagged_at_before_forward_returns_none() {
874        let dev = test_device();
875        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
876            .tag("hidden")
877            .through(Linear::on_device(4, 2, dev).unwrap())
878            .label("encoder")
879            .build()
880            .unwrap();
881
882        let outer = FlowBuilder::from(inner)
883            .through(Linear::on_device(2, 1, dev).unwrap())
884            .build()
885            .unwrap();
886
887        // Before forward: path exists but no value computed
888        let val = outer.tagged_at("encoder.hidden").unwrap();
889        assert!(val.is_none());
890    }
891
892    #[test]
893    fn test_tagged_at_invalid_path_returns_err() {
894        let dev = test_device();
895        let g = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
896            .build()
897            .unwrap();
898
899        assert!(g.tagged_at("nonexistent.tag").is_err());
900    }
901
902    #[test]
903    fn test_record_at_and_trend_at() {
904        let dev = test_device();
905        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
906            .label("encoder")
907            .build()
908            .unwrap();
909
910        let outer = FlowBuilder::from(inner)
911            .through(Linear::on_device(4, 2, dev).unwrap())
912            .build()
913            .unwrap();
914
915        // Record into child's buffer
916        outer.record_at("encoder.loss", 0.5).unwrap();
917        outer.record_at("encoder.loss", 0.3).unwrap();
918
919        // Flush child's buffers to see the trend
920        let child = outer.child_graph("encoder").unwrap();
921        child.flush(&[]);
922
923        let trend = outer.trend_at("encoder.loss").unwrap();
924        assert_eq!(trend.len(), 1); // one epoch flushed
925    }
926
927    // ── Phase E: Developer experience ────────────────────────────────
928
929    #[test]
930    fn test_internal_tag_hidden_from_parent() {
931        let dev = test_device();
932        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
933            .tag("_plumbing")
934            .through(Linear::on_device(4, 2, dev).unwrap())
935            .tag("output")
936            .label("encoder")
937            .build()
938            .unwrap();
939
940        let outer = FlowBuilder::from(inner)
941            .through(Linear::on_device(2, 1, dev).unwrap())
942            .build()
943            .unwrap();
944
945        // Auto-internal: _plumbing starts with underscore
946        assert!(outer.child_graph("encoder").unwrap().internal_tags().contains("_plumbing"));
947        // Internal tag blocked from parent
948        assert!(outer.tagged_at("encoder._plumbing").is_err());
949        // Non-internal tag accessible
950        assert_eq!(outer.validate_path("encoder.output").unwrap(), PathKind::Tag);
951    }
952
953    #[test]
954    fn test_explicit_internal_tag() {
955        let dev = test_device();
956        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
957            .tag("intermediate")
958            .internal("intermediate")
959            .through(Linear::on_device(4, 2, dev).unwrap())
960            .label("encoder")
961            .build()
962            .unwrap();
963
964        let outer = FlowBuilder::from(inner)
965            .through(Linear::on_device(2, 1, dev).unwrap())
966            .build()
967            .unwrap();
968
969        // Explicitly internal: blocked from parent
970        assert!(outer.tagged_at("encoder.intermediate").is_err());
971    }
972
973    #[test]
974    fn test_tree_summary_output() {
975        let dev = test_device();
976        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
977            .tag("hidden")
978            .through(Linear::on_device(4, 2, dev).unwrap())
979            .label("encoder")
980            .build()
981            .unwrap();
982
983        let outer = FlowBuilder::from(inner)
984            .through(Linear::on_device(2, 1, dev).unwrap())
985            .build()
986            .unwrap();
987
988        let summary = outer.tree_summary();
989        assert!(summary.contains("Graph Tree"), "missing header:\n{}", summary);
990        assert!(summary.contains("encoder"), "missing child label:\n{}", summary);
991        assert!(summary.contains("Parameter Summary"), "missing param summary:\n{}", summary);
992    }
993
994    #[test]
995    fn test_param_summary_output() {
996        let dev = test_device();
997        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
998            .label("encoder")
999            .build()
1000            .unwrap();
1001
1002        let outer = FlowBuilder::from(inner)
1003            .through(Linear::on_device(4, 2, dev).unwrap())
1004            .build()
1005            .unwrap();
1006
1007        let summary = outer.param_summary();
1008        assert!(summary.contains("encoder"), "missing child:\n{}", summary);
1009        assert!(summary.contains("(own)"), "missing own params:\n{}", summary);
1010        assert!(summary.contains("trainable"), "missing trainable:\n{}", summary);
1011    }
1012
1013    // ── Phase F: Tree-aware observation ──────────────────────────────
1014
1015    #[test]
1016    fn test_flush_recurses_into_children() {
1017        let dev = test_device();
1018        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
1019            .label("encoder")
1020            .build()
1021            .unwrap();
1022
1023        let outer = FlowBuilder::from(inner)
1024            .through(Linear::on_device(4, 2, dev).unwrap())
1025            .build()
1026            .unwrap();
1027
1028        // Record into child via tree path
1029        outer.record_at("encoder.loss", 0.5).unwrap();
1030        outer.record_at("encoder.loss", 0.3).unwrap();
1031        // Record into parent
1032        outer.record_scalar("parent_loss", 1.0);
1033
1034        // Single flush on parent should flush both
1035        outer.flush(&[]);
1036
1037        // Parent flushed
1038        assert_eq!(outer.flush_count(), 1);
1039        assert_eq!(outer.trend("parent_loss").len(), 1);
1040
1041        // Child also flushed
1042        let child = outer.child_graph("encoder").unwrap();
1043        assert_eq!(child.flush_count(), 1);
1044        assert_eq!(child.trend("loss").len(), 1);
1045    }
1046
1047    #[test]
1048    fn test_latest_metrics_includes_children() {
1049        let dev = test_device();
1050        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
1051            .label("encoder")
1052            .build()
1053            .unwrap();
1054
1055        let outer = FlowBuilder::from(inner)
1056            .through(Linear::on_device(4, 2, dev).unwrap())
1057            .build()
1058            .unwrap();
1059
1060        // Record and flush
1061        outer.record_at("encoder.ce", 0.5).unwrap();
1062        outer.record_scalar("total_loss", 1.0);
1063        outer.flush(&[]);
1064
1065        let metrics = outer.latest_metrics();
1066        let names: Vec<&str> = metrics.iter().map(|(n, _)| n.as_str()).collect();
1067
1068        // Parent metric present
1069        assert!(names.contains(&"total_loss"), "missing parent metric: {:?}", names);
1070        // Child metric present with dotted prefix
1071        assert!(names.contains(&"encoder.ce"), "missing child metric: {:?}", names);
1072    }
1073
1074    #[test]
1075    fn test_latest_metrics_local_excludes_children() {
1076        let dev = test_device();
1077        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
1078            .label("encoder")
1079            .build()
1080            .unwrap();
1081
1082        let outer = FlowBuilder::from(inner)
1083            .through(Linear::on_device(4, 2, dev).unwrap())
1084            .build()
1085            .unwrap();
1086
1087        outer.record_at("encoder.ce", 0.5).unwrap();
1088        outer.record_scalar("total_loss", 1.0);
1089        outer.flush(&[]);
1090
1091        let local = outer.latest_metrics_local();
1092        let names: Vec<&str> = local.iter().map(|(n, _)| n.as_str()).collect();
1093
1094        assert!(names.contains(&"total_loss"));
1095        assert!(!names.contains(&"encoder.ce"), "local should not include children: {:?}", names);
1096    }
1097
1098    #[test]
1099    fn test_double_flush_is_safe() {
1100        let dev = test_device();
1101        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
1102            .label("encoder")
1103            .build()
1104            .unwrap();
1105
1106        let outer = FlowBuilder::from(inner)
1107            .through(Linear::on_device(4, 2, dev).unwrap())
1108            .build()
1109            .unwrap();
1110
1111        outer.record_at("encoder.loss", 0.5).unwrap();
1112
1113        // Flush child explicitly first
1114        let child = outer.child_graph("encoder").unwrap();
1115        child.flush(&[]);
1116        assert_eq!(child.flush_count(), 1);
1117
1118        // Parent flush recurses — child buffer already empty, no double epoch
1119        outer.flush(&[]);
1120        assert_eq!(child.flush_count(), 1); // still 1, not 2
1121        assert_eq!(child.trend("loss").len(), 1); // one epoch, not two
1122    }
1123
1124    #[test]
1125    fn test_flush_local_skips_children() {
1126        let dev = test_device();
1127        let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
1128            .label("encoder")
1129            .build()
1130            .unwrap();
1131
1132        let outer = FlowBuilder::from(inner)
1133            .through(Linear::on_device(4, 2, dev).unwrap())
1134            .build()
1135            .unwrap();
1136
1137        outer.record_at("encoder.loss", 0.5).unwrap();
1138        outer.record_scalar("parent_loss", 1.0);
1139
1140        // flush_local: only parent
1141        outer.flush_local(&[]);
1142
1143        assert_eq!(outer.flush_count(), 1);
1144        assert_eq!(outer.trend("parent_loss").len(), 1);
1145
1146        // Child NOT flushed — data still in batch buffer
1147        let child = outer.child_graph("encoder").unwrap();
1148        assert_eq!(child.flush_count(), 0);
1149        assert_eq!(child.collected("loss").len(), 1); // still in batch buffer
1150    }
1151
1152    #[test]
1153    fn test_flush_recurses_multi_level() {
1154        let dev = test_device();
1155        let innermost = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
1156            .label("read")
1157            .build()
1158            .unwrap();
1159        let middle = FlowBuilder::from(innermost)
1160            .through(Linear::on_device(4, 2, dev).unwrap())
1161            .label("letter")
1162            .build()
1163            .unwrap();
1164        let outer = FlowBuilder::from(middle)
1165            .through(Linear::on_device(2, 1, dev).unwrap())
1166            .build()
1167            .unwrap();
1168
1169        // Record into deepest child
1170        outer.record_at("letter.read.hidden_loss", 0.7).unwrap();
1171        // Record into middle child
1172        outer.record_at("letter.mid_loss", 0.4).unwrap();
1173
1174        outer.flush(&[]);
1175
1176        // All levels flushed
1177        let metrics = outer.latest_metrics();
1178        let names: Vec<&str> = metrics.iter().map(|(n, _)| n.as_str()).collect();
1179        assert!(names.contains(&"letter.mid_loss"), "missing middle: {:?}", names);
1180        assert!(names.contains(&"letter.read.hidden_loss"), "missing deep: {:?}", names);
1181    }
1182
1183    #[test]
1184    fn test_metrics_no_children_unchanged() {
1185        // Verify single-graph behavior is identical (no regression)
1186        let dev = test_device();
1187        let g = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
1188            .build()
1189            .unwrap();
1190
1191        g.record_scalar("loss", 0.5);
1192        g.record_scalar("loss", 0.3);
1193        g.flush(&[]);
1194
1195        let metrics = g.latest_metrics();
1196        assert_eq!(metrics.len(), 1);
1197        assert_eq!(metrics[0].0, "loss");
1198        assert!((metrics[0].1 - 0.4).abs() < 1e-10); // mean of 0.5 and 0.3
1199    }
1200}