1use std::collections::HashMap;
20use crate::autograd::Variable;
21use crate::nn::{self, Buffer, Module, Parameter};
22use crate::tensor::{Result, TensorError};
23use super::Graph;
24use super::trend::Trend;
25
/// What a label path points at once it has been resolved.
///
/// Returned by `Graph::validate_path`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PathKind {
    /// The path names a nested child graph.
    Subgraph,
    /// The path names a tag inside some graph.
    Tag,
}
34
35#[allow(dead_code)]
37pub(crate) enum ResolvedPath<'a> {
38 Subgraph(&'a Graph),
40 Tag { graph: &'a Graph, tag: String },
42}
43
44impl Graph {
45 pub(crate) fn resolve(&self, path: &str) -> Result<ResolvedPath<'_>> {
56 if path.is_empty() {
57 return Err(TensorError::new("empty label path"));
58 }
59 let segments: Vec<&str> = path.split('.').collect();
60 self.resolve_segments(&segments, path, false)
61 }
62
63 fn resolve_segments<'a>(
64 &'a self,
65 segments: &[&str],
66 full_path: &str,
67 cross_boundary: bool,
68 ) -> Result<ResolvedPath<'a>> {
69 debug_assert!(!segments.is_empty());
70 let first = segments[0];
71
72 if segments.len() == 1 {
73 if let Some(g) = self.child_graph(first) {
75 return Ok(ResolvedPath::Subgraph(g));
76 }
77 if self.tag_names.contains_key(first) {
78 if cross_boundary && self.internal_tags.contains(first) {
80 return Err(TensorError::new(&format!(
81 "tag {:?} is internal and cannot be accessed from a parent graph (path: {:?})",
82 first, full_path
83 )));
84 }
85 return Ok(ResolvedPath::Tag { graph: self, tag: first.to_string() });
86 }
87 return Err(TensorError::new(&format!(
88 "{:?} is not a subgraph or tag of this graph (path: {:?})",
89 first, full_path
90 )));
91 }
92
93 let child = self.child_graph(first).ok_or_else(|| {
95 TensorError::new(&format!(
96 "{:?} is not a subgraph of this graph (path: {:?})",
97 first, full_path
98 ))
99 })?;
100
101 child.resolve_segments(&segments[1..], full_path, true)
103 }
104
105 pub fn tree_children(&self) -> HashMap<&str, &Graph> {
109 self.children.iter()
110 .filter_map(|(label, &ni)| {
111 self.nodes[ni].module.as_ref()
112 .and_then(|m| m.as_graph())
113 .map(|g| (label.as_str(), g))
114 })
115 .collect()
116 }
117
118 pub fn child_graph(&self, label: &str) -> Option<&Graph> {
120 self.children.get(label)
121 .and_then(|&ni| self.nodes[ni].module.as_ref())
122 .and_then(|m| m.as_graph())
123 }
124
125 pub fn subgraph(&self, path: &str) -> Result<&Graph> {
127 match self.resolve(path)? {
128 ResolvedPath::Subgraph(g) => Ok(g),
129 ResolvedPath::Tag { .. } => Err(TensorError::new(&format!(
130 "path {:?} resolves to a tag, not a subgraph", path
131 ))),
132 }
133 }
134
135 pub fn is_composed(&self) -> bool {
137 self.composed.get()
138 }
139
140 pub fn internal_tags(&self) -> &std::collections::HashSet<String> {
142 &self.internal_tags
143 }
144
145 pub fn validate_path(&self, path: &str) -> Result<PathKind> {
147 match self.resolve(path)? {
148 ResolvedPath::Subgraph(_) => Ok(PathKind::Subgraph),
149 ResolvedPath::Tag { .. } => Ok(PathKind::Tag),
150 }
151 }
152
153 pub fn parameters_at(&self, path: &str) -> Result<Vec<Parameter>> {
157 match self.resolve(path)? {
158 ResolvedPath::Subgraph(g) => Ok(g.parameters()),
159 ResolvedPath::Tag { graph, ref tag } => {
160 if let Some(&(ni, _)) = graph.tag_names.get(tag.as_str()) {
161 if let Some(ref module) = graph.nodes[ni].module {
162 Ok(module.parameters())
163 } else {
164 Ok(vec![])
165 }
166 } else {
167 Ok(vec![])
168 }
169 }
170 }
171 }
172
173 pub fn named_parameters_at(&self, path: &str) -> Result<Vec<(String, Parameter)>> {
177 match self.resolve(path)? {
178 ResolvedPath::Subgraph(g) => Ok(g.named_parameters()),
179 ResolvedPath::Tag { graph, ref tag } => {
180 if let Some(&(ni, _)) = graph.tag_names.get(tag.as_str()) {
181 if let Some(ref module) = graph.nodes[ni].module {
182 Ok(module.parameters().into_iter()
183 .map(|p| (format!("{}/{}", tag, p.name), p))
184 .collect())
185 } else {
186 Ok(vec![])
187 }
188 } else {
189 Ok(vec![])
190 }
191 }
192 }
193 }
194
195 pub fn named_buffers_at(&self, path: &str) -> Result<Vec<(String, Buffer)>> {
197 match self.resolve(path)? {
198 ResolvedPath::Subgraph(g) => Ok(g.named_buffers()),
199 ResolvedPath::Tag { graph, ref tag } => {
200 if let Some(&(ni, _)) = graph.tag_names.get(tag.as_str()) {
201 if let Some(ref module) = graph.nodes[ni].module {
202 Ok(module.buffers().into_iter()
203 .map(|b| (format!("{}/{}", tag, b.name), b))
204 .collect())
205 } else {
206 Ok(vec![])
207 }
208 } else {
209 Ok(vec![])
210 }
211 }
212 }
213 }
214
215 pub fn freeze(&self, path: &str) -> Result<()> {
219 for p in self.parameters_at(path)? {
220 p.freeze()?;
221 }
222 Ok(())
223 }
224
225 pub fn thaw(&self, path: &str) -> Result<()> {
227 for p in self.parameters_at(path)? {
228 p.unfreeze()?;
229 }
230 Ok(())
231 }
232
233 pub fn is_frozen(&self, path: &str) -> Result<bool> {
236 let params = self.parameters_at(path)?;
237 if params.is_empty() {
238 return Ok(false);
239 }
240 Ok(params.iter().all(|p| p.is_frozen()))
241 }
242
243 pub fn load_subgraph_checkpoint(&self, path: &str, file: &str) -> Result<nn::LoadReport> {
253 let target = self.subgraph(path)?;
254 let params = target.named_parameters();
255 let buffers = target.named_buffers();
256 let hash = target.structural_hash();
257 nn::load_checkpoint_file(file, ¶ms, &buffers, Some(hash))
258 }
259
260 pub fn set_training_at(&self, path: &str, training: bool) -> Result<()> {
264 match self.resolve(path)? {
265 ResolvedPath::Subgraph(g) => {
266 g.set_training(training);
267 }
268 ResolvedPath::Tag { graph, ref tag } => {
269 if let Some(&(ni, _)) = graph.tag_names.get(tag.as_str()) {
270 if let Some(ref module) = graph.nodes[ni].module {
271 crate::nn::walk_modules(module.as_ref(), &mut |m| {
272 m.set_training(training);
273 });
274 }
275 }
276 }
277 }
278 Ok(())
279 }
280
281 pub fn tagged_at(&self, path: &str) -> Result<Option<Variable>> {
288 match self.resolve(path)? {
289 ResolvedPath::Subgraph(_) => Err(TensorError::new(&format!(
290 "path {:?} resolves to a subgraph, not a tag", path
291 ))),
292 ResolvedPath::Tag { graph, ref tag } => Ok(graph.tagged(tag)),
293 }
294 }
295
296 pub fn collect_at(&self, paths: &[&str]) -> Result<()> {
300 for &path in paths {
301 match self.resolve(path)? {
302 ResolvedPath::Subgraph(_) => {
303 return Err(TensorError::new(&format!(
304 "collect_at: {:?} resolves to a subgraph, not a tag", path
305 )));
306 }
307 ResolvedPath::Tag { graph, ref tag } => {
308 graph.collect(&[tag.as_str()])?;
309 }
310 }
311 }
312 Ok(())
313 }
314
315 pub fn record_at(&self, path: &str, value: f64) -> Result<()> {
319 let segments: Vec<&str> = path.split('.').collect();
320 if segments.len() < 2 {
321 self.record_scalar(path, value);
323 return Ok(());
324 }
325 let parent_path = segments[..segments.len() - 1].join(".");
327 let tag = segments[segments.len() - 1];
328 let target = self.subgraph(&parent_path)?;
329 target.record_scalar(tag, value);
330 Ok(())
331 }
332
333 pub fn trend_at(&self, path: &str) -> Result<Trend> {
336 let segments: Vec<&str> = path.split('.').collect();
337 if segments.len() < 2 {
338 return Ok(self.trend(path));
339 }
340 let parent_path = segments[..segments.len() - 1].join(".");
341 let tag = segments[segments.len() - 1];
342 let target = self.subgraph(&parent_path)?;
343 Ok(target.trend(tag))
344 }
345}
346
347#[cfg(test)]
348mod tests {
349 use crate::autograd::Variable;
350 use crate::graph::FlowBuilder;
351 use crate::nn::{Linear, Module};
352 use crate::nn::ReLU;
353 use crate::tensor::{test_device, test_opts, Tensor};
354 use super::PathKind;
355
356 #[test]
357 fn test_unlabeled_graph_no_children() {
358 let dev = test_device();
359
360 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
361 .through(ReLU::new())
362 .build()
363 .unwrap();
364
365 let outer = FlowBuilder::from(inner)
366 .through(Linear::on_device(4, 2, dev).unwrap())
367 .build()
368 .unwrap();
369
370 assert!(outer.tree_children().is_empty());
372 assert_eq!(outer.parameters().len(), 4); }
375
376 #[test]
377 fn test_labeled_child_registered() {
378 let dev = test_device();
379
380 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
381 .through(ReLU::new())
382 .label("encoder")
383 .build()
384 .unwrap();
385
386 let outer = FlowBuilder::from(inner)
387 .through(Linear::on_device(4, 2, dev).unwrap())
388 .build()
389 .unwrap();
390
391 assert_eq!(outer.tree_children().len(), 1);
392 assert!(outer.tree_children().contains_key("encoder"));
393 assert!(outer.child_graph("encoder").is_some());
394 assert_eq!(outer.child_graph("encoder").unwrap().label(), Some("encoder"));
395 }
396
397 #[test]
398 fn test_composed_flag() {
399 let dev = test_device();
400
401 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
402 .label("child")
403 .build()
404 .unwrap();
405
406 assert!(!inner.is_composed());
408
409 let outer = FlowBuilder::from(inner)
410 .through(Linear::on_device(4, 2, dev).unwrap())
411 .build()
412 .unwrap();
413
414 let child = outer.child_graph("child").unwrap();
416 assert!(child.is_composed());
417 assert!(!outer.is_composed());
419 }
420
421 #[test]
422 fn test_label_collision_error() {
423 let dev = test_device();
424
425 let a = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
426 .label("dupe")
427 .build()
428 .unwrap();
429 let b = FlowBuilder::from(Linear::on_device(4, 2, dev).unwrap())
430 .label("dupe")
431 .build()
432 .unwrap();
433
434 let result = FlowBuilder::from(a)
435 .through(b)
436 .build();
437
438 let msg = result.err().expect("should be Err").to_string();
439 assert!(msg.contains("duplicate child graph label"), "got: {}", msg);
440 }
441
442 #[test]
443 fn test_dot_in_label_error() {
444 let dev = test_device();
445
446 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
447 .label("a.b")
448 .build()
449 .unwrap();
450
451 let result = FlowBuilder::from(inner)
452 .through(Linear::on_device(4, 2, dev).unwrap())
453 .build();
454
455 let msg = result.err().expect("should be Err").to_string();
456 assert!(msg.contains("contains a dot"), "got: {}", msg);
457 }
458
459 #[test]
460 fn test_label_tag_same_node_ok() {
461 let dev = test_device();
462
463 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
464 .label("encoder")
465 .build()
466 .unwrap();
467
468 let outer = FlowBuilder::from(inner)
470 .tag("encoder")
471 .through(Linear::on_device(4, 2, dev).unwrap())
472 .build();
473
474 assert!(outer.is_ok());
475 }
476
477 #[test]
478 fn test_resolve_single_segment_child() {
479 let dev = test_device();
480
481 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
482 .label("encoder")
483 .build()
484 .unwrap();
485
486 let outer = FlowBuilder::from(inner)
487 .through(Linear::on_device(4, 2, dev).unwrap())
488 .build()
489 .unwrap();
490
491 assert_eq!(outer.validate_path("encoder").unwrap(), PathKind::Subgraph);
492 }
493
494 #[test]
495 fn test_resolve_single_segment_tag() {
496 let dev = test_device();
497
498 let outer = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
499 .tag("hidden")
500 .through(Linear::on_device(4, 2, dev).unwrap())
501 .build()
502 .unwrap();
503
504 assert_eq!(outer.validate_path("hidden").unwrap(), PathKind::Tag);
505 }
506
507 #[test]
508 fn test_resolve_multi_segment() {
509 let dev = test_device();
510
511 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
512 .tag("hidden")
513 .through(Linear::on_device(4, 2, dev).unwrap())
514 .label("encoder")
515 .build()
516 .unwrap();
517
518 let outer = FlowBuilder::from(inner)
519 .through(Linear::on_device(2, 1, dev).unwrap())
520 .build()
521 .unwrap();
522
523 assert_eq!(outer.validate_path("encoder.hidden").unwrap(), PathKind::Tag);
524 }
525
526 #[test]
527 fn test_resolve_multi_level() {
528 let dev = test_device();
529
530 let innermost = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
531 .label("read")
532 .build()
533 .unwrap();
534 let middle = FlowBuilder::from(innermost)
535 .through(Linear::on_device(4, 2, dev).unwrap())
536 .label("letter")
537 .build()
538 .unwrap();
539 let outer = FlowBuilder::from(middle)
540 .through(Linear::on_device(2, 1, dev).unwrap())
541 .build()
542 .unwrap();
543
544 assert_eq!(outer.validate_path("letter").unwrap(), PathKind::Subgraph);
545 assert_eq!(outer.validate_path("letter.read").unwrap(), PathKind::Subgraph);
546 }
547
548 #[test]
549 fn test_resolve_invalid_path_error() {
550 let dev = test_device();
551
552 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
553 .label("encoder")
554 .build()
555 .unwrap();
556
557 let outer = FlowBuilder::from(inner)
558 .through(Linear::on_device(4, 2, dev).unwrap())
559 .build()
560 .unwrap();
561
562 assert!(outer.validate_path("nonexistent").is_err());
564 assert!(outer.validate_path("encoder.nonexistent").is_err());
566 assert!(outer.validate_path("nonexistent.foo").is_err());
568 }
569
570 #[test]
571 fn test_subgraph_returns_graph() {
572 let dev = test_device();
573
574 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
575 .label("encoder")
576 .build()
577 .unwrap();
578
579 let outer = FlowBuilder::from(inner)
580 .through(Linear::on_device(4, 2, dev).unwrap())
581 .build()
582 .unwrap();
583
584 let sub = outer.subgraph("encoder").unwrap();
585 assert_eq!(sub.label(), Some("encoder"));
586 assert_eq!(sub.parameters().len(), 2); }
588
589 #[test]
590 fn test_forward_still_works_with_tree() {
591 let dev = test_device();
592 let opts = test_opts();
593
594 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
595 .through(ReLU::new())
596 .label("encoder")
597 .build()
598 .unwrap();
599
600 let outer = FlowBuilder::from(inner)
601 .through(Linear::on_device(4, 2, dev).unwrap())
602 .build()
603 .unwrap();
604
605 let x = Variable::new(
606 Tensor::randn(&[1, 3], opts).unwrap(),
607 false,
608 );
609 let y = outer.forward(&x).unwrap();
610 assert_eq!(y.shape(), vec![1, 2]);
611 }
612
613 #[test]
616 fn test_parameters_at_subgraph() {
617 let dev = test_device();
618 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
619 .through(Linear::on_device(4, 2, dev).unwrap())
620 .label("encoder")
621 .build()
622 .unwrap();
623
624 let outer = FlowBuilder::from(inner)
625 .through(Linear::on_device(2, 1, dev).unwrap())
626 .build()
627 .unwrap();
628
629 let params = outer.parameters_at("encoder").unwrap();
631 assert_eq!(params.len(), 4);
632 assert_eq!(outer.parameters().len(), 6);
634 }
635
636 #[test]
637 fn test_parameters_at_tag() {
638 let dev = test_device();
639 let g = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
640 .tag("first")
641 .through(Linear::on_device(4, 2, dev).unwrap())
642 .build()
643 .unwrap();
644
645 let params = g.parameters_at("first").unwrap();
646 assert_eq!(params.len(), 2); }
648
649 #[test]
650 fn test_freeze_thaw_roundtrip() {
651 let dev = test_device();
652 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
653 .label("encoder")
654 .build()
655 .unwrap();
656
657 let outer = FlowBuilder::from(inner)
658 .through(Linear::on_device(4, 2, dev).unwrap())
659 .build()
660 .unwrap();
661
662 assert!(!outer.is_frozen("encoder").unwrap());
664
665 outer.freeze("encoder").unwrap();
667 assert!(outer.is_frozen("encoder").unwrap());
668 for p in outer.parameters_at("encoder").unwrap() {
670 assert!(p.is_frozen());
671 }
672 let outer_params = outer.parameters();
674 let outer_only: Vec<_> = outer_params.iter()
675 .filter(|p| !p.is_frozen())
676 .collect();
677 assert_eq!(outer_only.len(), 2); outer.thaw("encoder").unwrap();
681 assert!(!outer.is_frozen("encoder").unwrap());
682 for p in outer.parameters_at("encoder").unwrap() {
683 assert!(!p.is_frozen());
684 }
685 }
686
687 #[test]
688 fn test_freeze_deep_path() {
689 let dev = test_device();
690 let innermost = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
691 .label("read")
692 .build()
693 .unwrap();
694 let middle = FlowBuilder::from(innermost)
695 .through(Linear::on_device(4, 2, dev).unwrap())
696 .label("letter")
697 .build()
698 .unwrap();
699 let outer = FlowBuilder::from(middle)
700 .through(Linear::on_device(2, 1, dev).unwrap())
701 .build()
702 .unwrap();
703
704 outer.freeze("letter.read").unwrap();
706 assert!(outer.is_frozen("letter.read").unwrap());
707 assert!(!outer.is_frozen("letter").unwrap());
709 }
710
711 #[test]
712 fn test_named_parameters_at_uses_target_namespace() {
713 let dev = test_device();
714 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
715 .tag("hidden")
716 .through(Linear::on_device(4, 2, dev).unwrap())
717 .label("encoder")
718 .build()
719 .unwrap();
720
721 let outer = FlowBuilder::from(inner)
722 .through(Linear::on_device(2, 1, dev).unwrap())
723 .build()
724 .unwrap();
725
726 let named = outer.named_parameters_at("encoder").unwrap();
728 assert_eq!(named.len(), 4);
729 assert!(named.iter().any(|(n, _)| n.starts_with("hidden/")));
731 }
732
733 #[test]
734 fn test_freeze_invalid_path_error() {
735 let dev = test_device();
736 let g = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
737 .build()
738 .unwrap();
739
740 assert!(g.freeze("nonexistent").is_err());
741 assert!(g.thaw("nonexistent").is_err());
742 assert!(g.is_frozen("nonexistent").is_err());
743 assert!(g.parameters_at("nonexistent").is_err());
744 }
745
746 #[test]
747 fn test_set_training_at() {
748 let dev = test_device();
749 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
750 .through(crate::nn::Dropout::new(0.5))
751 .label("encoder")
752 .build()
753 .unwrap();
754
755 let outer = FlowBuilder::from(inner)
756 .through(Linear::on_device(4, 2, dev).unwrap())
757 .build()
758 .unwrap();
759
760 outer.set_training_at("encoder", false).unwrap();
762 outer.set_training_at("encoder", true).unwrap();
764 assert!(outer.set_training_at("nonexistent", false).is_err());
766 }
767
768 #[test]
771 fn test_subgraph_checkpoint_roundtrip() {
772 let dev = test_device();
773 let child = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
775 .through(ReLU::new())
776 .through(Linear::on_device(4, 2, dev).unwrap())
777 .label("encoder")
778 .build()
779 .unwrap();
780
781 let dir = std::env::temp_dir().join("flodl_test_subgraph_ckpt");
783 std::fs::create_dir_all(&dir).unwrap();
784 let ckpt_path = dir.join("encoder.fdl");
785 child.save_checkpoint(ckpt_path.to_str().unwrap()).unwrap();
786
787 let fresh_child = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
789 .through(ReLU::new())
790 .through(Linear::on_device(4, 2, dev).unwrap())
791 .label("encoder")
792 .build()
793 .unwrap();
794
795 let parent = FlowBuilder::from(fresh_child)
796 .through(Linear::on_device(2, 1, dev).unwrap())
797 .build()
798 .unwrap();
799
800 let report = parent.load_subgraph_checkpoint("encoder", ckpt_path.to_str().unwrap()).unwrap();
802 assert!(report.loaded.len() >= 4); assert!(report.missing.is_empty());
804
805 let _ = std::fs::remove_dir_all(&dir);
807 }
808
809 #[test]
810 fn test_subgraph_checkpoint_preserves_parent_params() {
811 let dev = test_device();
812 let child = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
813 .label("encoder")
814 .build()
815 .unwrap();
816
817 let dir = std::env::temp_dir().join("flodl_test_preserve_parent");
818 std::fs::create_dir_all(&dir).unwrap();
819 let ckpt_path = dir.join("encoder.fdl");
820 child.save_checkpoint(ckpt_path.to_str().unwrap()).unwrap();
821
822 let fresh_child = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
824 .label("encoder")
825 .build()
826 .unwrap();
827 let parent = FlowBuilder::from(fresh_child)
828 .through(Linear::on_device(4, 2, dev).unwrap())
829 .build()
830 .unwrap();
831
832 let parent_w = parent.parameters().last().unwrap().variable.data().clone();
834
835 parent.load_subgraph_checkpoint("encoder", ckpt_path.to_str().unwrap()).unwrap();
837
838 let parent_w_after = parent.parameters().last().unwrap().variable.data().clone();
840 let diff = parent_w.sub(&parent_w_after).unwrap().abs().unwrap().sum().unwrap().item().unwrap();
841 assert!(diff < 1e-10, "parent params should be unchanged, diff={}", diff);
842
843 let _ = std::fs::remove_dir_all(&dir);
844 }
845
846 #[test]
849 fn test_tagged_at_returns_value_after_forward() {
850 let dev = test_device();
851 let opts = test_opts();
852 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
853 .tag("hidden")
854 .through(Linear::on_device(4, 2, dev).unwrap())
855 .label("encoder")
856 .build()
857 .unwrap();
858
859 let outer = FlowBuilder::from(inner)
860 .through(Linear::on_device(2, 1, dev).unwrap())
861 .build()
862 .unwrap();
863
864 let x = Variable::new(Tensor::randn(&[1, 3], opts).unwrap(), false);
865 outer.forward(&x).unwrap();
866
867 let val = outer.tagged_at("encoder.hidden").unwrap();
868 assert!(val.is_some());
869 assert_eq!(val.unwrap().shape(), vec![1, 4]);
870 }
871
872 #[test]
873 fn test_tagged_at_before_forward_returns_none() {
874 let dev = test_device();
875 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
876 .tag("hidden")
877 .through(Linear::on_device(4, 2, dev).unwrap())
878 .label("encoder")
879 .build()
880 .unwrap();
881
882 let outer = FlowBuilder::from(inner)
883 .through(Linear::on_device(2, 1, dev).unwrap())
884 .build()
885 .unwrap();
886
887 let val = outer.tagged_at("encoder.hidden").unwrap();
889 assert!(val.is_none());
890 }
891
892 #[test]
893 fn test_tagged_at_invalid_path_returns_err() {
894 let dev = test_device();
895 let g = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
896 .build()
897 .unwrap();
898
899 assert!(g.tagged_at("nonexistent.tag").is_err());
900 }
901
902 #[test]
903 fn test_record_at_and_trend_at() {
904 let dev = test_device();
905 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
906 .label("encoder")
907 .build()
908 .unwrap();
909
910 let outer = FlowBuilder::from(inner)
911 .through(Linear::on_device(4, 2, dev).unwrap())
912 .build()
913 .unwrap();
914
915 outer.record_at("encoder.loss", 0.5).unwrap();
917 outer.record_at("encoder.loss", 0.3).unwrap();
918
919 let child = outer.child_graph("encoder").unwrap();
921 child.flush(&[]);
922
923 let trend = outer.trend_at("encoder.loss").unwrap();
924 assert_eq!(trend.len(), 1); }
926
927 #[test]
930 fn test_internal_tag_hidden_from_parent() {
931 let dev = test_device();
932 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
933 .tag("_plumbing")
934 .through(Linear::on_device(4, 2, dev).unwrap())
935 .tag("output")
936 .label("encoder")
937 .build()
938 .unwrap();
939
940 let outer = FlowBuilder::from(inner)
941 .through(Linear::on_device(2, 1, dev).unwrap())
942 .build()
943 .unwrap();
944
945 assert!(outer.child_graph("encoder").unwrap().internal_tags().contains("_plumbing"));
947 assert!(outer.tagged_at("encoder._plumbing").is_err());
949 assert_eq!(outer.validate_path("encoder.output").unwrap(), PathKind::Tag);
951 }
952
953 #[test]
954 fn test_explicit_internal_tag() {
955 let dev = test_device();
956 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
957 .tag("intermediate")
958 .internal("intermediate")
959 .through(Linear::on_device(4, 2, dev).unwrap())
960 .label("encoder")
961 .build()
962 .unwrap();
963
964 let outer = FlowBuilder::from(inner)
965 .through(Linear::on_device(2, 1, dev).unwrap())
966 .build()
967 .unwrap();
968
969 assert!(outer.tagged_at("encoder.intermediate").is_err());
971 }
972
973 #[test]
974 fn test_tree_summary_output() {
975 let dev = test_device();
976 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
977 .tag("hidden")
978 .through(Linear::on_device(4, 2, dev).unwrap())
979 .label("encoder")
980 .build()
981 .unwrap();
982
983 let outer = FlowBuilder::from(inner)
984 .through(Linear::on_device(2, 1, dev).unwrap())
985 .build()
986 .unwrap();
987
988 let summary = outer.tree_summary();
989 assert!(summary.contains("Graph Tree"), "missing header:\n{}", summary);
990 assert!(summary.contains("encoder"), "missing child label:\n{}", summary);
991 assert!(summary.contains("Parameter Summary"), "missing param summary:\n{}", summary);
992 }
993
994 #[test]
995 fn test_param_summary_output() {
996 let dev = test_device();
997 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
998 .label("encoder")
999 .build()
1000 .unwrap();
1001
1002 let outer = FlowBuilder::from(inner)
1003 .through(Linear::on_device(4, 2, dev).unwrap())
1004 .build()
1005 .unwrap();
1006
1007 let summary = outer.param_summary();
1008 assert!(summary.contains("encoder"), "missing child:\n{}", summary);
1009 assert!(summary.contains("(own)"), "missing own params:\n{}", summary);
1010 assert!(summary.contains("trainable"), "missing trainable:\n{}", summary);
1011 }
1012
1013 #[test]
1016 fn test_flush_recurses_into_children() {
1017 let dev = test_device();
1018 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
1019 .label("encoder")
1020 .build()
1021 .unwrap();
1022
1023 let outer = FlowBuilder::from(inner)
1024 .through(Linear::on_device(4, 2, dev).unwrap())
1025 .build()
1026 .unwrap();
1027
1028 outer.record_at("encoder.loss", 0.5).unwrap();
1030 outer.record_at("encoder.loss", 0.3).unwrap();
1031 outer.record_scalar("parent_loss", 1.0);
1033
1034 outer.flush(&[]);
1036
1037 assert_eq!(outer.flush_count(), 1);
1039 assert_eq!(outer.trend("parent_loss").len(), 1);
1040
1041 let child = outer.child_graph("encoder").unwrap();
1043 assert_eq!(child.flush_count(), 1);
1044 assert_eq!(child.trend("loss").len(), 1);
1045 }
1046
1047 #[test]
1048 fn test_latest_metrics_includes_children() {
1049 let dev = test_device();
1050 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
1051 .label("encoder")
1052 .build()
1053 .unwrap();
1054
1055 let outer = FlowBuilder::from(inner)
1056 .through(Linear::on_device(4, 2, dev).unwrap())
1057 .build()
1058 .unwrap();
1059
1060 outer.record_at("encoder.ce", 0.5).unwrap();
1062 outer.record_scalar("total_loss", 1.0);
1063 outer.flush(&[]);
1064
1065 let metrics = outer.latest_metrics();
1066 let names: Vec<&str> = metrics.iter().map(|(n, _)| n.as_str()).collect();
1067
1068 assert!(names.contains(&"total_loss"), "missing parent metric: {:?}", names);
1070 assert!(names.contains(&"encoder.ce"), "missing child metric: {:?}", names);
1072 }
1073
1074 #[test]
1075 fn test_latest_metrics_local_excludes_children() {
1076 let dev = test_device();
1077 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
1078 .label("encoder")
1079 .build()
1080 .unwrap();
1081
1082 let outer = FlowBuilder::from(inner)
1083 .through(Linear::on_device(4, 2, dev).unwrap())
1084 .build()
1085 .unwrap();
1086
1087 outer.record_at("encoder.ce", 0.5).unwrap();
1088 outer.record_scalar("total_loss", 1.0);
1089 outer.flush(&[]);
1090
1091 let local = outer.latest_metrics_local();
1092 let names: Vec<&str> = local.iter().map(|(n, _)| n.as_str()).collect();
1093
1094 assert!(names.contains(&"total_loss"));
1095 assert!(!names.contains(&"encoder.ce"), "local should not include children: {:?}", names);
1096 }
1097
1098 #[test]
1099 fn test_double_flush_is_safe() {
1100 let dev = test_device();
1101 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
1102 .label("encoder")
1103 .build()
1104 .unwrap();
1105
1106 let outer = FlowBuilder::from(inner)
1107 .through(Linear::on_device(4, 2, dev).unwrap())
1108 .build()
1109 .unwrap();
1110
1111 outer.record_at("encoder.loss", 0.5).unwrap();
1112
1113 let child = outer.child_graph("encoder").unwrap();
1115 child.flush(&[]);
1116 assert_eq!(child.flush_count(), 1);
1117
1118 outer.flush(&[]);
1120 assert_eq!(child.flush_count(), 1); assert_eq!(child.trend("loss").len(), 1); }
1123
1124 #[test]
1125 fn test_flush_local_skips_children() {
1126 let dev = test_device();
1127 let inner = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
1128 .label("encoder")
1129 .build()
1130 .unwrap();
1131
1132 let outer = FlowBuilder::from(inner)
1133 .through(Linear::on_device(4, 2, dev).unwrap())
1134 .build()
1135 .unwrap();
1136
1137 outer.record_at("encoder.loss", 0.5).unwrap();
1138 outer.record_scalar("parent_loss", 1.0);
1139
1140 outer.flush_local(&[]);
1142
1143 assert_eq!(outer.flush_count(), 1);
1144 assert_eq!(outer.trend("parent_loss").len(), 1);
1145
1146 let child = outer.child_graph("encoder").unwrap();
1148 assert_eq!(child.flush_count(), 0);
1149 assert_eq!(child.collected("loss").len(), 1); }
1151
1152 #[test]
1153 fn test_flush_recurses_multi_level() {
1154 let dev = test_device();
1155 let innermost = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
1156 .label("read")
1157 .build()
1158 .unwrap();
1159 let middle = FlowBuilder::from(innermost)
1160 .through(Linear::on_device(4, 2, dev).unwrap())
1161 .label("letter")
1162 .build()
1163 .unwrap();
1164 let outer = FlowBuilder::from(middle)
1165 .through(Linear::on_device(2, 1, dev).unwrap())
1166 .build()
1167 .unwrap();
1168
1169 outer.record_at("letter.read.hidden_loss", 0.7).unwrap();
1171 outer.record_at("letter.mid_loss", 0.4).unwrap();
1173
1174 outer.flush(&[]);
1175
1176 let metrics = outer.latest_metrics();
1178 let names: Vec<&str> = metrics.iter().map(|(n, _)| n.as_str()).collect();
1179 assert!(names.contains(&"letter.mid_loss"), "missing middle: {:?}", names);
1180 assert!(names.contains(&"letter.read.hidden_loss"), "missing deep: {:?}", names);
1181 }
1182
1183 #[test]
1184 fn test_metrics_no_children_unchanged() {
1185 let dev = test_device();
1187 let g = FlowBuilder::from(Linear::on_device(3, 4, dev).unwrap())
1188 .build()
1189 .unwrap();
1190
1191 g.record_scalar("loss", 0.5);
1192 g.record_scalar("loss", 0.3);
1193 g.flush(&[]);
1194
1195 let metrics = g.latest_metrics();
1196 assert_eq!(metrics.len(), 1);
1197 assert_eq!(metrics[0].0, "loss");
1198 assert!((metrics[0].1 - 0.4).abs() < 1e-10); }
1200}