1pub mod node;
23pub mod flow;
24pub mod loop_node;
25pub mod switch;
26pub mod gate;
27pub mod map;
28pub mod observe;
29pub mod trend;
30pub mod profile;
31pub mod dot;
32pub mod plot;
33pub mod router;
34pub mod halt;
35pub mod reshape;
36pub mod state;
37pub mod snapshot;
38pub mod tree;
39pub mod verbose;
40
41use std::cell::{Cell, OnceCell, RefCell};
42use std::collections::{BTreeSet, HashMap, HashSet};
43use std::rc::Rc;
44use std::time::Instant;
45
46use indexmap::IndexMap;
47use hmac_sha256::Hash as Sha256;
48
49use node::*;
50use crate::autograd::Variable;
51use crate::nn::{Buffer, Module, Parameter};
52use crate::tensor::{Result, Tensor, TensorError};
53
54pub use flow::FlowBuilder;
55pub use loop_node::LoopBuilder;
56pub use map::MapBuilder;
57pub use trend::{Trend, TrendGroup};
58pub use profile::{Profile, NodeTiming, LevelTiming};
59pub use plot::format_duration;
60pub use router::{SoftmaxRouter, SigmoidRouter, FixedSelector, ArgmaxSelector};
61pub use halt::{ThresholdHalt, LearnedHalt};
62pub use reshape::Reshape;
63pub use state::StateAdd;
64pub use observe::Reduce;
65pub use tree::PathKind;
66pub use snapshot::ModelSnapshot;
67
/// How the outputs of parallel branches are combined when merged back
/// into a single stream (see `FlowBuilder::merge`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MergeOp {
    /// Element-wise sum of all branch outputs.
    Add,
    /// Element-wise arithmetic mean of all branch outputs.
    Mean,
}
75
/// One resolved edge: moves a value from a source node's output port to
/// a destination node's input port. All fields are pre-resolved indices
/// so the forward pass does no name lookups.
#[derive(Clone)]
struct Route {
    // Index into the source node's output list.
    from_port_idx: usize,
    // Index of the destination node in `Graph::nodes`.
    to_node_idx: usize,
    // Index into the destination node's input slots.
    to_port_idx: usize,
}
84
/// Destination slot for one exposed graph-level input: the node and
/// input-port index the caller's value is delivered to.
struct InputRoute {
    node_idx: usize,
    port_idx: usize,
}
90
/// One forward-reference (state) slot.
///
/// `value` is shared with the reader node's `run` closure (installed in
/// `build`), which yields the current contents — empty until the writer
/// has produced a value.
struct StateEntry {
    // Index of the node whose output fills this slot; left at the
    // placeholder 0 if the writer id was unknown at build time.
    writer_ni: usize,
    value: Rc<RefCell<Option<Variable>>>,
}
96
/// Executable dataflow graph of modules.
///
/// Built via the builder APIs (`FlowBuilder`, `LoopBuilder`,
/// `MapBuilder`, ...) and executed level-by-level in `forward_impl`.
/// `Cell`/`RefCell` fields give the `&self` `Module` interface interior
/// mutability for per-pass scratch space, metrics, and profiling.
pub struct Graph {
    // --- topology (fixed at build time) ---
    nodes: Vec<Node>,
    // Node id -> index into `nodes`.
    node_index: HashMap<String, usize>,
    // Topological levels; nodes within one level have no mutual deps.
    levels: Vec<Vec<usize>>,
    edges: Vec<Edge>,
    // Edge indices grouped by source node.
    #[allow(dead_code)] edges_from: HashMap<usize, Vec<usize>>,
    // Exposed graph-level input/output ports.
    inputs: Vec<ExposedPort>,
    outputs: Vec<ExposedPort>,
    // Flattened topological order (levels concatenated).
    order: Vec<usize>,
    // --- recurrent state (forward references) ---
    state: Vec<StateEntry>,
    // Writer node index -> (state slot, output port) pairs it fills.
    state_writers: HashMap<usize, Vec<(usize, usize)>>,
    // --- tags ---
    tag_groups: HashMap<String, Vec<String>>,
    // Tag name -> (node index, output port index).
    tag_names: HashMap<String, (usize, usize)>,
    // Node index -> tags captured from that node's outputs.
    tag_capture: HashMap<usize, Vec<(String, usize)>>,
    // Tag values captured during the most recent forward pass.
    tagged_outputs: RefCell<HashMap<String, Variable>>,
    // --- metric collection (managed by observe/plot helpers) ---
    batch_buffer: RefCell<HashMap<String, Vec<f64>>>,
    epoch_history: RefCell<HashMap<String, Vec<f64>>>,
    metric_order: RefCell<Vec<String>>,
    flush_count: Cell<usize>,
    // --- profiling ---
    profiling: Cell<bool>,
    last_profile: RefCell<Option<profile::Profile>>,
    timing_buffer: RefCell<HashMap<String, Vec<f64>>>,
    timing_history: RefCell<HashMap<String, Vec<f64>>>,
    flush_times: RefCell<Vec<f64>>,
    // Unix-epoch seconds of the first forward pass; 0.0 = not started.
    training_start: Cell<f64>,
    step_count: Cell<usize>,
    epoch_count: Cell<usize>,
    // --- identity & composition ---
    label: Option<String>,
    // Lazily computed architecture hash (see `structural_hash`).
    structural_hash_cache: OnceCell<String>,
    // Labelled child graphs -> node index (for tree/path navigation).
    children: HashMap<String, usize>,
    // Set when this graph is embedded inside a parent graph.
    composed: Cell<bool>,
    internal_tags: HashSet<String>,
    // --- execution fast path (all indices pre-resolved in `build`) ---
    routes_from: Vec<Vec<Route>>,
    input_routes: Vec<InputRoute>,
    output_node_idx: usize,
    output_port_idx: usize,
    // Input-port count per node, cached for slot management.
    node_input_count: Vec<usize>,
    // Reusable per-node input slots, cleared at the start of each pass.
    exec_slots: RefCell<Vec<Vec<Option<Variable>>>>,
}
171
172impl Graph {
173 #[allow(clippy::too_many_arguments)]
174 pub(crate) fn build(
175 mut node_map: IndexMap<String, Node>,
176 edges: Vec<Edge>,
177 inputs: Vec<ExposedPort>,
178 outputs: Vec<ExposedPort>,
179 tags: HashMap<String, NodeRef>,
180 forward_refs: Vec<ForwardRefSpec>,
181 tag_groups: HashMap<String, Vec<String>>,
182 label: Option<String>,
183 mut internal_tags: HashSet<String>,
184 verbose: bool,
185 ) -> Result<Self> {
186 let mut state = Vec::with_capacity(forward_refs.len());
188 for fr in &forward_refs {
189 let value: Rc<RefCell<Option<Variable>>> = Rc::new(RefCell::new(None));
190 let reader_value = value.clone();
191
192 if let Some(node) = node_map.get_mut(&fr.reader_id) {
194 node.run = Box::new(move |_: &[Variable]| {
195 match reader_value.borrow().as_ref() {
196 Some(v) => Ok(vec![v.clone()]),
197 None => Ok(vec![]), }
199 });
200 }
201
202 state.push(StateEntry {
203 writer_ni: 0, value,
205 });
206 }
207
208 let mut nodes = Vec::with_capacity(node_map.len());
210 let mut node_index = HashMap::with_capacity(node_map.len());
211
212 for (_key, node) in node_map {
213 let idx = nodes.len();
214 node_index.insert(node.id.clone(), idx);
215 nodes.push(node);
216 }
217
218 for edge in &edges {
220 if !node_index.contains_key(&edge.from_node) {
221 return Err(TensorError::new(&format!(
222 "unknown source node: {}",
223 edge.from_node
224 )));
225 }
226 if !node_index.contains_key(&edge.to_node) {
227 return Err(TensorError::new(&format!(
228 "unknown target node: {}",
229 edge.to_node
230 )));
231 }
232 }
233
234 let mut edges_from: HashMap<usize, Vec<usize>> = HashMap::new();
236 for (ei, edge) in edges.iter().enumerate() {
237 let from_idx = node_index[&edge.from_node];
238 edges_from.entry(from_idx).or_default().push(ei);
239 }
240
241 let levels = topological_levels(&nodes, &node_index, &edges)?;
243 let order: Vec<usize> = levels.iter().flat_map(|l| l.iter().copied()).collect();
244
245 let mut tag_names_map: HashMap<String, (usize, usize)> = HashMap::new();
247 let mut tag_capture: HashMap<usize, Vec<(String, usize)>> = HashMap::new();
248 for (name, node_ref) in &tags {
249 if let Some(&ni) = node_index.get(&node_ref.node_id) {
250 let port_idx = nodes[ni]
251 .output_ports
252 .iter()
253 .position(|p| p == &node_ref.port)
254 .unwrap_or(0);
255 tag_names_map.insert(name.clone(), (ni, port_idx));
256 tag_capture
257 .entry(ni)
258 .or_default()
259 .push((name.clone(), port_idx));
260 }
261 }
262
263 let mut children: HashMap<String, usize> = HashMap::new();
265 for (idx, node) in nodes.iter().enumerate() {
266 if let Some(ref module) = node.module {
267 if let Some(child_graph) = module.as_graph() {
268 if let Some(child_label) = child_graph.label() {
269 if child_label.contains('.') {
270 return Err(TensorError::new(&format!(
271 "child graph label {:?} contains a dot — \
272 dots are reserved for path separators",
273 child_label
274 )));
275 }
276 if children.contains_key(child_label) {
277 return Err(TensorError::new(&format!(
278 "duplicate child graph label {:?} at the same tree level",
279 child_label
280 )));
281 }
282 if let Some(&(tag_ni, _)) = tag_names_map.get(child_label) {
284 if tag_ni != idx {
285 return Err(TensorError::new(&format!(
286 "child graph label {:?} collides with a tag \
287 on a different node",
288 child_label
289 )));
290 }
291 }
292 children.insert(child_label.to_string(), idx);
293 child_graph.composed.set(true);
294 }
295 }
297 }
298 }
299
300 for name in tag_names_map.keys() {
302 if name.starts_with('_') {
303 internal_tags.insert(name.clone());
304 }
305 }
306
307 let mut state_writers: HashMap<usize, Vec<(usize, usize)>> = HashMap::new();
310 for (si, fr) in forward_refs.iter().enumerate() {
311 if let Some(&ni) = node_index.get(&fr.writer_id) {
312 state[si].writer_ni = ni;
313 let port_idx = nodes[ni]
314 .output_ports
315 .iter()
316 .position(|p| p == &fr.writer_port)
317 .unwrap_or(0);
318 state_writers.entry(ni).or_default().push((si, port_idx));
319 }
320 }
321
322 let n = nodes.len();
324 let mut routes_from: Vec<Vec<Route>> = vec![Vec::new(); n];
325 for edge in &edges {
326 let from_ni = node_index[&edge.from_node];
327 let to_ni = node_index[&edge.to_node];
328 let from_port_idx = nodes[from_ni]
329 .output_ports
330 .iter()
331 .position(|p| p == &edge.from_port)
332 .unwrap_or(0);
333 let to_port_idx = nodes[to_ni]
334 .input_ports
335 .iter()
336 .position(|p| p == &edge.to_port)
337 .unwrap_or(0);
338 routes_from[from_ni].push(Route {
339 from_port_idx,
340 to_node_idx: to_ni,
341 to_port_idx,
342 });
343 }
344
345 let input_routes: Vec<InputRoute> = inputs
347 .iter()
348 .map(|ep| {
349 let ni = node_index[&ep.node_id];
350 let port_idx = nodes[ni]
351 .input_ports
352 .iter()
353 .position(|p| p == &ep.port)
354 .unwrap_or(0);
355 InputRoute {
356 node_idx: ni,
357 port_idx,
358 }
359 })
360 .collect();
361
362 let output_node_idx = node_index[&outputs[0].node_id];
364 let output_port_idx = nodes[output_node_idx]
365 .output_ports
366 .iter()
367 .position(|p| p == &outputs[0].port)
368 .unwrap_or(0);
369
370 let node_input_count: Vec<usize> = nodes.iter().map(|nd| nd.input_ports.len()).collect();
372 let exec_slots = RefCell::new(
373 node_input_count.iter().map(|&c| vec![None; c]).collect(),
374 );
375
376 let graph = Ok(Graph {
377 nodes,
378 node_index,
379 levels,
380 edges,
381 edges_from,
382 inputs,
383 outputs,
384 order,
385 state,
386 state_writers,
387 tag_groups,
388 tag_names: tag_names_map,
389 tag_capture,
390 tagged_outputs: RefCell::new(HashMap::new()),
391 batch_buffer: RefCell::new(HashMap::new()),
392 epoch_history: RefCell::new(HashMap::new()),
393 metric_order: RefCell::new(Vec::new()),
394 flush_count: Cell::new(0),
395 profiling: Cell::new(false),
396 last_profile: RefCell::new(None),
397 timing_buffer: RefCell::new(HashMap::new()),
398 timing_history: RefCell::new(HashMap::new()),
399 flush_times: RefCell::new(Vec::new()),
400 training_start: Cell::new(0.0),
401 step_count: Cell::new(0),
402 epoch_count: Cell::new(0),
403 label,
404 structural_hash_cache: OnceCell::new(),
405 children,
406 composed: Cell::new(false),
407 internal_tags,
408 routes_from,
409 input_routes,
410 output_node_idx,
411 output_port_idx,
412 node_input_count,
413 exec_slots,
414 });
415
416 if verbose {
417 if let Ok(ref g) = graph {
418 eprintln!("{}", g.tree_summary());
419 }
420 }
421
422 graph
423 }
424
    /// Execute one forward pass: seed the exposed input slots, run each
    /// topological level in order, route outputs to downstream slots,
    /// latch state writes and tag captures, and return the value at the
    /// primary output port.
    ///
    /// Errors if the input count doesn't match the exposed inputs, if a
    /// node's primary input was never filled, or if any node's `run`
    /// fails.
    fn forward_impl(&self, graph_inputs: &[Variable]) -> Result<Variable> {
        if graph_inputs.len() != self.inputs.len() {
            return Err(TensorError::new(&format!(
                "expected {} inputs, got {}",
                self.inputs.len(),
                graph_inputs.len()
            )));
        }

        // First forward pass stamps the wall-clock training start.
        if self.training_start.get() == 0.0 {
            self.training_start.set(instant_secs());
        }

        let is_profiling = self.profiling.get();
        let forward_start = if is_profiling { Some(Instant::now()) } else { None };
        let mut prof_nodes: Vec<profile::NodeTiming> = Vec::new();
        let mut prof_levels: Vec<profile::LevelTiming> = Vec::new();

        // Reverse tag lookup (node -> one tag name), built only when
        // profiling since it's used solely for timing labels.
        let tags_by_node: HashMap<usize, String> = if is_profiling {
            let mut m = HashMap::new();
            for (name, &(ni, _)) in &self.tag_names {
                m.entry(ni).or_insert_with(|| name.clone());
            }
            m
        } else {
            HashMap::new()
        };

        let has_tags = !self.tag_capture.is_empty();

        // Borrow the reusable slots for the whole pass; released via
        // `drop(slots)` below before the final move out.
        let mut slots = self.exec_slots.borrow_mut();

        // Clear any values left over from a previous pass.
        for node_slots in slots.iter_mut() {
            for slot in node_slots.iter_mut() {
                *slot = None;
            }
        }

        if has_tags {
            self.tagged_outputs.borrow_mut().clear();
        }

        // Seed the exposed input ports with the caller's values.
        for (i, route) in self.input_routes.iter().enumerate() {
            slots[route.node_idx][route.port_idx] = Some(graph_inputs[i].clone());
        }

        let mut final_output: Option<Vec<Variable>> = None;

        for (level_idx, level) in self.levels.iter().enumerate() {
            let level_start = if is_profiling { Some(Instant::now()) } else { None };
            let mut level_sum_ns: u64 = 0;

            for &ni in level {
                let node = &self.nodes[ni];
                let input_count = self.node_input_count[ni];

                // Gather inputs: a missing secondary (ref) port is
                // filled with zeros shaped like the primary input; a
                // missing primary input is a hard error.
                let inputs: Vec<Variable> = (0..input_count)
                    .map(|i| {
                        match slots[ni][i].as_ref() {
                            Some(v) => Ok(v.clone()),
                            None if i > 0 => {
                                let first = slots[ni][0].as_ref().ok_or_else(|| {
                                    TensorError::new(&format!(
                                        "node '{}': ref port {} has no data and primary input \
                                         is also missing — check that all inputs are connected",
                                        node.id, i
                                    ))
                                })?;
                                Ok(Variable::new(
                                    Tensor::zeros_like(&first.data())?,
                                    false,
                                ))
                            }
                            _ => Err(TensorError::new(&format!(
                                "node '{}': missing primary input (port {}) — check that all \
                                 inputs to this node are connected in the graph builder",
                                node.id, i
                            ))),
                        }
                    })
                    .collect::<Result<Vec<Variable>>>()?;

                // Release the consumed input values immediately.
                for slot in slots[ni].iter_mut() {
                    *slot = None;
                }

                let node_start = if is_profiling { Some(Instant::now()) } else { None };
                let node_outputs = (node.run)(&inputs)?;
                if is_profiling {
                    let elapsed = node_start.unwrap().elapsed();
                    level_sum_ns += elapsed.as_nanos() as u64;
                    prof_nodes.push(profile::NodeTiming {
                        id: node.id.clone(),
                        tag: tags_by_node.get(&ni).cloned().unwrap_or_default(),
                        duration: elapsed,
                        level: level_idx,
                    });
                }

                // Fan the outputs out along the routing table; an
                // out-of-range port routes None (clears the slot).
                for route in &self.routes_from[ni] {
                    let value = if route.from_port_idx < node_outputs.len() {
                        Some(node_outputs[route.from_port_idx].clone())
                    } else {
                        None
                    };
                    slots[route.to_node_idx][route.to_port_idx] = value;
                }

                // Latch state values for forward-reference readers.
                if let Some(writers) = self.state_writers.get(&ni) {
                    for &(si, port_idx) in writers {
                        if port_idx < node_outputs.len() {
                            *self.state[si].value.borrow_mut() =
                                Some(node_outputs[port_idx].clone());
                        }
                    }
                }

                // Capture tagged outputs for later `tagged()` lookups.
                if has_tags {
                    if let Some(captures) = self.tag_capture.get(&ni) {
                        let mut tagged = self.tagged_outputs.borrow_mut();
                        for (tag_name, port_idx) in captures {
                            if *port_idx < node_outputs.len() {
                                tagged.insert(
                                    tag_name.clone(),
                                    node_outputs[*port_idx].clone(),
                                );
                            }
                        }
                    }
                }

                if ni == self.output_node_idx {
                    final_output = Some(node_outputs);
                }
            }

            if is_profiling {
                prof_levels.push(profile::LevelTiming {
                    index: level_idx,
                    wall_clock: level_start.unwrap().elapsed(),
                    sum_nodes: std::time::Duration::from_nanos(level_sum_ns),
                    num_nodes: level.len(),
                });
            }
        }

        // End the exec_slots borrow before returning.
        drop(slots);

        if is_profiling {
            *self.last_profile.borrow_mut() = Some(profile::Profile {
                total: forward_start.unwrap().elapsed(),
                levels: prof_levels,
                nodes: prof_nodes,
            });
        }

        final_output
            .and_then(|o| o.into_iter().nth(self.output_port_idx))
            .ok_or_else(|| TensorError::new("graph produced no output"))
    }
605}
606
607impl Graph {
608 pub fn reset_state(&self) {
611 for entry in &self.state {
612 *entry.value.borrow_mut() = None;
613 }
614 }
615
616 pub fn detach_state(&self) {
619 for entry in &self.state {
621 let mut val = entry.value.borrow_mut();
622 if let Some(ref v) = *val {
623 *val = Some(v.detach());
624 }
625 }
626 {
631 let mut tagged = self.tagged_outputs.borrow_mut();
632 for var in tagged.values_mut() {
633 *var = var.detach();
634 }
635 }
636 for node in &self.nodes {
638 if let Some(ref module) = node.module {
639 module.detach_state();
640 }
641 }
642 }
643
    /// True when this graph carries any forward-reference (recurrent)
    /// state between steps.
    pub fn has_state(&self) -> bool {
        !self.state.is_empty()
    }
648
    /// Mark the end of one optimization step: detach recurrent state
    /// from the autograd graph, collect node timings when profiling,
    /// and bump the step counter.
    pub fn end_step(&self) {
        self.detach_state();
        if self.profiling.get() {
            self.collect_timings(&[]);
        }
        self.step_count.set(self.step_count.get() + 1);
    }
673
    /// Mark the end of a sequence: clears all recurrent state so the
    /// next sequence starts fresh.
    pub fn end_sequence(&self) {
        self.reset_state();
    }
681
    /// Mark the end of an epoch: flush buffered metrics (and timings
    /// when profiling) into history, then bump the epoch counter.
    pub fn end_epoch(&self) {
        self.flush(&[]);
        if self.profiling.get() {
            self.flush_timings(&[]);
        }
        self.epoch_count.set(self.epoch_count.get() + 1);
    }
691
    /// Number of completed optimization steps (see `end_step`).
    pub fn step_count(&self) -> usize {
        self.step_count.get()
    }
696
    /// Number of completed epochs (see `end_epoch`).
    pub fn epoch_count(&self) -> usize {
        self.epoch_count.get()
    }
701
    /// Member tag names registered under the tag group `name`, if any.
    pub fn tag_group(&self, name: &str) -> Option<&[String]> {
        self.tag_groups.get(name).map(|v| v.as_slice())
    }
706
    /// Run a forward pass with multiple graph inputs, one per exposed
    /// input port in declaration order.
    pub fn forward_multi(&self, inputs: &[Variable]) -> Result<Variable> {
        self.forward_impl(inputs)
    }
712
    /// Move all parameters, recurrent state, and nested modules to
    /// `device`. Parameters are detached before transfer so the move
    /// records no autograd history; moved state values are rebuilt as
    /// fresh non-differentiable variables.
    pub fn set_device(&self, device: crate::tensor::Device) {
        // Tensors already on the target device are left untouched.
        for p in self.parameters() {
            if p.variable.data().device() != device
                && let Ok(t) = p.variable.data().detach()
                    .and_then(|d| d.to_device(device))
            {
                p.variable.set_data(t);
            }
        }
        for entry in &self.state {
            let mut val = entry.value.borrow_mut();
            if let Some(ref v) = *val
                && v.data().device() != device
                && let Ok(t) = v.data().to_device(device)
            {
                *val = Some(Variable::new(t, false));
            }
        }
        // Recurse into nested modules, visiting each shared module
        // only once.
        let mut visited = HashSet::new();
        for &ni in &self.order {
            if let Some(ref module) = self.nodes[ni].module {
                crate::nn::walk_modules_visited(
                    module.as_ref(),
                    &mut visited,
                    &mut |m: &dyn crate::nn::Module| m.move_to_device(device),
                );
            }
        }
    }
747
    /// Parameters with hierarchical names of the form
    /// `<tag-or-node-id>/<param-name>[_<i>]`.
    ///
    /// A node's tag (when one exists) is preferred over its raw id as
    /// the prefix. Parameters shared between nodes are emitted once,
    /// first node in topological order wins.
    pub fn named_parameters(&self) -> Vec<(String, Parameter)> {
        // Reverse map: node index -> one of its tags (first seen wins).
        let mut idx_to_tag: HashMap<usize, String> = HashMap::new();
        for (tag, &(ni, _)) in &self.tag_names {
            idx_to_tag.entry(ni).or_insert_with(|| tag.clone());
        }

        let mut result = Vec::new();
        // Storage-pointer identity set for deduping shared parameters.
        let mut seen = HashSet::new();

        for &ni in &self.order {
            if let Some(ref module) = self.nodes[ni].module {
                let prefix = idx_to_tag.get(&ni)
                    .cloned()
                    .unwrap_or_else(|| self.nodes[ni].id.clone());

                let params = module.parameters();
                // NOTE(review): counts are taken before the dedup skip
                // below, so a parameter already emitted by an earlier
                // node still counts toward `_N` suffixing here —
                // confirm this is the intended naming.
                let mut name_counts: HashMap<String, usize> = HashMap::new();
                for p in &params {
                    *name_counts.entry(p.name.clone()).or_insert(0) += 1;
                }

                let mut name_idx: HashMap<String, usize> = HashMap::new();
                for p in params {
                    let ptr = Rc::as_ptr(&p.variable.inner) as usize;
                    if !seen.insert(ptr) {
                        continue;
                    }

                    // Disambiguate duplicate names within one module
                    // with an `_<index>` suffix.
                    let qualified = if name_counts[&p.name] > 1 {
                        let idx = name_idx.entry(p.name.clone()).or_insert(0);
                        let q = format!("{}/{}_{}", prefix, p.name, idx);
                        *idx += 1;
                        q
                    } else {
                        format!("{}/{}", prefix, p.name)
                    };

                    result.push((qualified, p));
                }
            }
        }

        result
    }
800
    /// Buffers with hierarchical names, mirroring the naming scheme of
    /// `named_parameters` (`<tag-or-node-id>/<buffer-name>[_<i>]`).
    pub fn named_buffers(&self) -> Vec<(String, Buffer)> {
        // Reverse map: node index -> one of its tags (first seen wins).
        let mut idx_to_tag: HashMap<usize, String> = HashMap::new();
        for (tag, &(ni, _)) in &self.tag_names {
            idx_to_tag.entry(ni).or_insert_with(|| tag.clone());
        }

        let mut result = Vec::new();
        // Storage-pointer identity set for deduping shared buffers.
        let mut seen = HashSet::new();

        for &ni in &self.order {
            if let Some(ref module) = self.nodes[ni].module {
                let prefix = idx_to_tag.get(&ni)
                    .cloned()
                    .unwrap_or_else(|| self.nodes[ni].id.clone());

                let bufs = module.buffers();
                let mut name_counts: HashMap<String, usize> = HashMap::new();
                for b in &bufs {
                    *name_counts.entry(b.name.clone()).or_insert(0) += 1;
                }

                let mut name_idx: HashMap<String, usize> = HashMap::new();
                for b in bufs {
                    let ptr = Rc::as_ptr(&b.inner) as usize;
                    if !seen.insert(ptr) {
                        continue;
                    }

                    // Disambiguate duplicate names within one module
                    // with an `_<index>` suffix.
                    let qualified = if name_counts[&b.name] > 1 {
                        let idx = name_idx.entry(b.name.clone()).or_insert(0);
                        let q = format!("{}/{}_{}", prefix, b.name, idx);
                        *idx += 1;
                        q
                    } else {
                        format!("{}/{}", prefix, b.name)
                    };

                    result.push((qualified, b));
                }
            }
        }

        result
    }
847
    /// Optional human-readable label assigned at build time.
    pub fn label(&self) -> Option<&str> {
        self.label.as_deref()
    }
852
    /// Architecture hash (hex SHA-256), computed once and cached.
    pub fn structural_hash(&self) -> &str {
        self.structural_hash_cache.get_or_init(|| self.compute_structural_hash())
    }
857
    /// First 8 hex characters of the structural hash (64 chars total).
    pub fn short_hash(&self) -> &str {
        &self.structural_hash()[..8]
    }
862
    /// Save all named parameters and buffers to `path`, embedding the
    /// structural hash so later loads can detect architecture drift.
    pub fn save_checkpoint(&self, path: &str) -> Result<()> {
        let params = self.named_parameters();
        let buffers = self.named_buffers();
        let hash = self.structural_hash();
        crate::nn::save_checkpoint_file(path, &params, &buffers, Some(hash))
    }
873
    /// Load parameters and buffers from a checkpoint at `path`,
    /// verifying against this graph's structural hash; returns a
    /// report of what was matched/loaded.
    pub fn load_checkpoint(&self, path: &str) -> Result<crate::nn::LoadReport> {
        let params = self.named_parameters();
        let buffers = self.named_buffers();
        let hash = self.structural_hash();
        crate::nn::load_checkpoint_file(path, &params, &buffers, Some(hash))
    }
885
    /// Hash the graph's *architecture*: node ids, module names,
    /// parameter/buffer names and shapes, edges, tags, and exposed
    /// ports — never the weight values themselves.
    ///
    /// Node and edge sections follow build order (deterministic per
    /// construction); parameter, buffer, and tag sections are sorted
    /// by name so iteration order of the underlying maps can't change
    /// the digest. Returns the lowercase hex encoding of the digest.
    fn compute_structural_hash(&self) -> String {
        let mut hasher = Sha256::new();

        for &ni in &self.order {
            let node = &self.nodes[ni];
            hasher.update(node.id.as_bytes());
            hasher.update(b"\0");

            if let Some(ref module) = node.module {
                hasher.update(module.name().as_bytes());
                hasher.update(b"\0");

                // Parameter names + shapes, sorted for stability.
                let mut params: Vec<_> = module.parameters().into_iter()
                    .map(|p| (p.name.clone(), p.variable.shape()))
                    .collect();
                params.sort_by(|a, b| a.0.cmp(&b.0));
                for (name, shape) in &params {
                    hasher.update(b"P");
                    hasher.update(name.as_bytes());
                    hasher.update(b"\0");
                    for &dim in shape {
                        hasher.update(dim.to_le_bytes());
                    }
                }

                // Buffer names + shapes, sorted for stability.
                let mut bufs: Vec<_> = module.buffers().into_iter()
                    .map(|b| (b.name.clone(), b.shape()))
                    .collect();
                bufs.sort_by(|a, b| a.0.cmp(&b.0));
                for (name, shape) in &bufs {
                    hasher.update(b"B");
                    hasher.update(name.as_bytes());
                    hasher.update(b"\0");
                    for &dim in shape {
                        hasher.update(dim.to_le_bytes());
                    }
                }

                // Nested graphs contribute their own structural hash.
                if let Some(nested_hash) = module.structural_hash() {
                    hasher.update(b"G");
                    hasher.update(nested_hash.as_bytes());
                }
            }
        }

        // Edge list in declaration order.
        hasher.update(b"EDGES");
        for edge in &self.edges {
            hasher.update(edge.from_node.as_bytes());
            hasher.update(b"\0");
            hasher.update(edge.from_port.as_bytes());
            hasher.update(b"\0");
            hasher.update(edge.to_node.as_bytes());
            hasher.update(b"\0");
            hasher.update(edge.to_port.as_bytes());
            hasher.update(b"\0");
        }

        // Tags, sorted by name.
        hasher.update(b"TAGS");
        let mut tags: Vec<_> = self.tag_names.iter().collect();
        tags.sort_by(|a, b| a.0.cmp(b.0));
        for (name, (node_idx, port_idx)) in &tags {
            hasher.update(name.as_bytes());
            hasher.update(b"\0");
            hasher.update((*node_idx as u64).to_le_bytes());
            hasher.update((*port_idx as u64).to_le_bytes());
        }

        hasher.update(b"INPUTS");
        for port in &self.inputs {
            hasher.update(port.name.as_bytes());
            hasher.update(b"\0");
            hasher.update(port.node_id.as_bytes());
            hasher.update(b"\0");
            hasher.update(port.port.as_bytes());
            hasher.update(b"\0");
        }
        hasher.update(b"OUTPUTS");
        for port in &self.outputs {
            hasher.update(port.name.as_bytes());
            hasher.update(b"\0");
            hasher.update(port.node_id.as_bytes());
            hasher.update(b"\0");
            hasher.update(port.port.as_bytes());
            hasher.update(b"\0");
        }

        // Lowercase hex encoding of the 32-byte digest.
        hasher.finalize().iter().map(|b| format!("{b:02x}")).collect()
    }
981}
982
/// `Module` implementation so a `Graph` can itself be used as a node
/// inside another graph (composition).
impl Module for Graph {
    fn name(&self) -> &str { "graph" }

    // Lets `build` detect nested graphs for child registration.
    fn as_graph(&self) -> Option<&Graph> { Some(self) }

    fn structural_hash(&self) -> Option<String> {
        Some(self.structural_hash().to_string())
    }

    // Single-input forward; multi-input graphs use `forward_multi`.
    fn forward(&self, input: &Variable) -> Result<Variable> {
        self.forward_impl(std::slice::from_ref(input))
    }

    // Flat parameter list in topological order, deduplicated by the
    // parameter storage pointer (shared params appear once).
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        let mut seen = HashSet::new();

        for &ni in &self.order {
            if let Some(ref module) = self.nodes[ni].module {
                for p in module.parameters() {
                    let ptr = Rc::as_ptr(&p.variable.inner) as usize;
                    if seen.insert(ptr) {
                        params.push(p);
                    }
                }
            }
        }

        params
    }

    // Propagate train/eval mode to every nested module, visiting each
    // shared module only once.
    fn set_training(&self, training: bool) {
        let mut visited = HashSet::new();
        for &ni in &self.order {
            if let Some(ref module) = self.nodes[ni].module {
                crate::nn::walk_modules_visited(
                    module.as_ref(),
                    &mut visited,
                    &mut |m: &dyn crate::nn::Module| m.set_training(training),
                );
            }
        }
    }

    fn move_to_device(&self, device: crate::tensor::Device) {
        self.set_device(device);
    }
}
1031
/// Seconds since the Unix epoch as `f64`; yields 0.0 if the system
/// clock is somehow before the epoch.
fn instant_secs() -> f64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs_f64(),
        Err(_) => f64::default(),
    }
}
1040
1041fn topological_levels(
1043 nodes: &[Node],
1044 node_index: &HashMap<String, usize>,
1045 edges: &[Edge],
1046) -> Result<Vec<Vec<usize>>> {
1047 let n = nodes.len();
1048
1049 let mut deps: Vec<HashSet<usize>> = vec![HashSet::new(); n];
1053 let mut dependents: Vec<BTreeSet<usize>> = vec![BTreeSet::new(); n];
1054
1055 for edge in edges {
1056 let from_ni = node_index[&edge.from_node];
1057 let to_ni = node_index[&edge.to_node];
1058 deps[to_ni].insert(from_ni);
1059 dependents[from_ni].insert(to_ni);
1060 }
1061
1062 let mut in_degree: Vec<usize> = deps.iter().map(|d| d.len()).collect();
1063
1064 let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
1066 let mut levels = Vec::new();
1067 let mut visited = 0;
1068
1069 while !queue.is_empty() {
1070 levels.push(queue.clone());
1071 visited += queue.len();
1072
1073 let mut next_queue = Vec::new();
1074 for &ni in &queue {
1075 for &dep in &dependents[ni] {
1076 in_degree[dep] -= 1;
1077 if in_degree[dep] == 0 {
1078 next_queue.push(dep);
1079 }
1080 }
1081 }
1082 queue = next_queue;
1083 }
1084
1085 if visited != n {
1086 return Err(TensorError::new("cycle detected in graph"));
1087 }
1088
1089 Ok(levels)
1090}
1091
1092#[cfg(test)]
1093mod tests {
1094 use super::*;
1095 use crate::autograd::Variable;
1096 use crate::nn::{Linear, NamedInputModule, ReLU, Sigmoid, mse_loss, Optimizer, SGD};
1097 use crate::tensor::Tensor;
1098 use std::collections::HashMap;
1099
    // Helper: build a tensor on the test device from f32 data.
    fn from_f32(data: &[f32], shape: &[i64]) -> Tensor {
        Tensor::from_f32(data, shape, crate::tensor::test_device()).unwrap()
    }
1103
    // Parameterless test module that doubles its input (x + x).
    struct Doubler;
    impl Module for Doubler {
        fn forward(&self, input: &Variable) -> Result<Variable> {
            input.add(input)
        }
    }
1113
    // Test module with a single learnable bias (initialized to zeros),
    // used to check parameter exposure through graph constructs.
    struct BiasStep {
        bias: Parameter,
    }
    impl BiasStep {
        fn new(size: i64) -> Result<Self> {
            let data = Tensor::zeros(&[size], crate::tensor::test_opts())?;
            let var = Variable::new(data, true);
            Ok(BiasStep {
                bias: Parameter {
                    variable: var,
                    name: "loop_bias".to_string(),
                },
            })
        }
    }
    impl Module for BiasStep {
        fn forward(&self, input: &Variable) -> Result<Variable> {
            input.add(&self.bias.variable)
        }
        fn parameters(&self) -> Vec<Parameter> {
            vec![self.bias.clone()]
        }
    }
1138
    // Test module that adds the named reference input "ctx" when it is
    // present, and otherwise passes the input through unchanged.
    struct AddRefModule;
    impl Module for AddRefModule {
        fn forward(&self, input: &Variable) -> Result<Variable> {
            Ok(input.clone())
        }
        fn as_named_input(&self) -> Option<&dyn NamedInputModule> { Some(self) }
    }
    impl NamedInputModule for AddRefModule {
        fn forward_named(
            &self,
            input: &Variable,
            refs: &HashMap<String, Variable>,
        ) -> Result<Variable> {
            if let Some(ctx) = refs.get("ctx") {
                input.add(ctx)
            } else {
                Ok(input.clone())
            }
        }
    }
1160
    // A graph wrapping a single Linear behaves like the module itself.
    #[test]
    fn test_single_module() {
        let l = Linear::on_device(3, 2, crate::tensor::test_device()).unwrap();
        let graph = FlowBuilder::from(l).build().unwrap();

        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), false);
        let y = graph.forward(&x).unwrap();
        assert_eq!(y.shape(), vec![1, 2]);
    }
1172
    // Sequential Linear -> ReLU -> Linear chain produces the expected
    // output shape.
    #[test]
    fn test_linear_chain() {
        let graph = FlowBuilder::from(Linear::on_device(3, 4, crate::tensor::test_device()).unwrap())
            .through(ReLU::new())
            .through(Linear::on_device(4, 2, crate::tensor::test_device()).unwrap())
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), false);
        let y = graph.forward(&x).unwrap();
        assert_eq!(y.shape(), vec![1, 2]);
    }
1185
    // also(): with both layers fixed to identity weights (l2 with bias
    // 1), the residual combination yields x + (x + 1) = 2x + 1.
    #[test]
    fn test_also_residual() {
        let l1 = Linear::on_device(3, 3, crate::tensor::test_device()).unwrap();
        l1.weight.variable.set_data(from_f32(
            &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
            &[3, 3],
        ));
        l1.bias
            .as_ref()
            .unwrap()
            .variable
            .set_data(from_f32(&[0.0, 0.0, 0.0], &[3]));

        let l2 = Linear::on_device(3, 3, crate::tensor::test_device()).unwrap();
        l2.weight.variable.set_data(from_f32(
            &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
            &[3, 3],
        ));
        l2.bias
            .as_ref()
            .unwrap()
            .variable
            .set_data(from_f32(&[1.0, 1.0, 1.0], &[3]));

        let graph = FlowBuilder::from(l1).also(l2).build().unwrap();

        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), false);
        let y = graph.forward(&x).unwrap();
        let data = y.data().to_f32_vec().unwrap();

        assert!((data[0] - 3.0).abs() < 1e-5);
        assert!((data[1] - 5.0).abs() < 1e-5);
        assert!((data[2] - 7.0).abs() < 1e-5);
    }
1221
    // fork(): the side branch's output is captured under its tag while
    // the main path continues — here ReLU(identity(x)).
    #[test]
    fn test_fork_basic() {
        let l = Linear::on_device(2, 3, crate::tensor::test_device()).unwrap();

        let graph = FlowBuilder::from(Identity)
            .fork(l)
            .tag("side")
            .through(ReLU::new())
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[1.0, -2.0], &[1, 2]), false);
        let y = graph.forward(&x).unwrap();

        assert_eq!(y.shape(), vec![1, 2]);
        let data = y.data().to_f32_vec().unwrap();
        assert!((data[0] - 1.0).abs() < 1e-5);
        assert!((data[1] - 0.0).abs() < 1e-5);
        let side = graph.tagged("side").unwrap();
        assert_eq!(side.shape(), vec![1, 3]);
    }
1252
    // Multiple forks off the same trunk: each head's output is
    // retrievable by tag, and the trunk output is the graph result.
    #[test]
    fn test_fork_multiple() {
        let head_a = Linear::on_device(4, 3, crate::tensor::test_device()).unwrap();
        let head_b = Linear::on_device(4, 2, crate::tensor::test_device()).unwrap();

        let graph = FlowBuilder::from(Linear::on_device(2, 4, crate::tensor::test_device()).unwrap())
            .tag("latent")
            .fork(head_a)
            .tag("head_a")
            .fork(head_b)
            .tag("head_b")
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
        let y = graph.forward(&x).unwrap();

        assert_eq!(y.shape(), vec![1, 4]);

        let a = graph.tagged("head_a").unwrap();
        assert_eq!(a.shape(), vec![1, 3]);
        let b = graph.tagged("head_b").unwrap();
        assert_eq!(b.shape(), vec![1, 2]);
    }
1280
    // Gradients flow through both the main path and the forked side
    // branch when the loss combines both outputs.
    #[test]
    fn test_fork_backward() {
        let graph = FlowBuilder::from(Linear::on_device(2, 4, crate::tensor::test_device()).unwrap())
            .fork(Linear::on_device(4, 3, crate::tensor::test_device()).unwrap())
            .tag("side")
            .through(Linear::on_device(4, 1, crate::tensor::test_device()).unwrap())
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
        let y = graph.forward(&x).unwrap();

        let side = graph.tagged("side").unwrap();
        let loss = y.sum().unwrap().add(&side.sum().unwrap()).unwrap();
        loss.backward().unwrap();

        assert!(x.grad().is_some(), "input should have gradient");
        for p in graph.parameters() {
            assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
        }
    }
1304
    // split/merge with MergeOp::Add keeps the branch output shape.
    #[test]
    fn test_split_merge_add() {
        let graph = FlowBuilder::from(Linear::on_device(3, 3, crate::tensor::test_device()).unwrap())
            .split(vec![Box::new(ReLU::new()), Box::new(Sigmoid::new())])
            .merge(MergeOp::Add)
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[1.0, -1.0, 2.0], &[1, 3]), false);
        let y = graph.forward(&x).unwrap();
        assert_eq!(y.shape(), vec![1, 3]);
    }
1319
    // split/merge with MergeOp::Mean: all layers fixed to identity
    // (zero bias), so the mean of two identical branches is the input.
    #[test]
    fn test_split_merge_mean() {
        let l = Linear::on_device(2, 2, crate::tensor::test_device()).unwrap();
        l.weight
            .variable
            .set_data(from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]));
        l.bias
            .as_ref()
            .unwrap()
            .variable
            .set_data(from_f32(&[0.0, 0.0], &[2]));

        let b1 = Linear::on_device(2, 2, crate::tensor::test_device()).unwrap();
        b1.weight
            .variable
            .set_data(from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]));
        b1.bias
            .as_ref()
            .unwrap()
            .variable
            .set_data(from_f32(&[0.0, 0.0], &[2]));
        let b2 = Linear::on_device(2, 2, crate::tensor::test_device()).unwrap();
        b2.weight
            .variable
            .set_data(from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]));
        b2.bias
            .as_ref()
            .unwrap()
            .variable
            .set_data(from_f32(&[0.0, 0.0], &[2]));

        let graph = FlowBuilder::from(l)
            .split(vec![Box::new(b1), Box::new(b2)])
            .merge(MergeOp::Mean)
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[3.0, 7.0], &[1, 2]), false);
        let y = graph.forward(&x).unwrap();
        let data = y.data().to_f32_vec().unwrap();

        assert!((data[0] - 3.0).abs() < 1e-5);
        assert!((data[1] - 7.0).abs() < 1e-5);
    }
1364
    // Two Linear layers contribute 4 parameters (weight + bias each);
    // ReLU contributes none.
    #[test]
    fn test_parameters() {
        let graph = FlowBuilder::from(Linear::on_device(3, 4, crate::tensor::test_device()).unwrap())
            .through(ReLU::new())
            .through(Linear::on_device(4, 2, crate::tensor::test_device()).unwrap())
            .build()
            .unwrap();

        let params = graph.parameters();
        assert_eq!(params.len(), 4);
    }
1376
// Backprop through a Linear -> ReLU -> Linear chain must populate
// gradients on the input and on every graph parameter.
#[test]
fn test_graph_backward() {
    let l1 = Linear::on_device(3, 2, crate::tensor::test_device()).unwrap();
    let l2 = Linear::on_device(2, 1, crate::tensor::test_device()).unwrap();

    let graph = FlowBuilder::from(l1)
        .through(ReLU::new())
        .through(l2)
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), true);
    let y = graph.forward(&x).unwrap();
    let loss = y.sum().unwrap();
    loss.backward().unwrap();

    for p in graph.parameters() {
        assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
    }
    assert!(x.grad().is_some());
}
1398
// A built Graph is itself a Module: it can be nested inside another
// FlowBuilder, and the outer graph's parameters() includes the inner
// graph's (2 params inner linear + 2 outer = 4).
#[test]
fn test_graph_as_module() {
    let inner = FlowBuilder::from(Linear::on_device(3, 4, crate::tensor::test_device()).unwrap())
        .through(ReLU::new())
        .build()
        .unwrap();

    let outer = FlowBuilder::from(inner)
        .through(Linear::on_device(4, 2, crate::tensor::test_device()).unwrap())
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), false);
    let y = outer.forward(&x).unwrap();
    assert_eq!(y.shape(), vec![1, 2]);
    assert_eq!(outer.parameters().len(), 4);
}
1416
// End-to-end training sanity check: fit a 1-in/1-out linear layer with SGD
// on data following y = 2x + 1 and require the MSE loss to converge.
// Exercises the Graph::parameters() -> optimizer wiring.
#[test]
fn test_training_loop() {
    let graph = FlowBuilder::from(Linear::on_device(1, 1, crate::tensor::test_device()).unwrap())
        .build()
        .unwrap();

    let params = graph.parameters();
    // Fixed mojibake: the original read `SGD::new(¶ms, …)` — an HTML
    // entity (`&para;`) ate the `&` of `&params`, which does not compile.
    let mut optim = SGD::new(&params, 0.01, 0.0);

    let x = Variable::new(from_f32(&[1.0, 2.0, 3.0, 4.0], &[4, 1]), false);
    // Targets follow y = 2x + 1.
    let target = Variable::new(from_f32(&[3.0, 5.0, 7.0, 9.0], &[4, 1]), false);

    let mut last_loss = f64::MAX;
    for _ in 0..800 {
        optim.zero_grad();
        let pred = graph.forward(&x).unwrap();
        let loss = mse_loss(&pred, &target).unwrap();
        last_loss = loss.item().unwrap();
        loss.backward().unwrap();
        optim.step().unwrap();
    }

    assert!(last_loss < 0.01, "got loss={}", last_loss);
}
1441
// also() attaches a side branch; gradients must still flow to the input
// and to the parameters of both the main and the side module.
#[test]
fn test_also_backward() {
    let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .also(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
    let y = graph.forward(&x).unwrap();
    let loss = y.sum().unwrap();
    loss.backward().unwrap();

    assert!(x.grad().is_some());
    for p in graph.parameters() {
        assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
    }
}
1459
// Gradients must propagate through a split/merge(Add) fan-out: both
// branch linears and the stem linear receive gradients, as does the input.
#[test]
fn test_split_merge_backward() {
    let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .split(vec![
            Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
            Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
        ])
        .merge(MergeOp::Add)
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
    let y = graph.forward(&x).unwrap();
    let loss = y.sum().unwrap();
    loss.backward().unwrap();

    assert!(x.grad().is_some());
    for p in graph.parameters() {
        assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
    }
}
1481
// A split() with no matching merge() leaves parallel streams open;
// build() must reject such a graph.
#[test]
fn test_build_error_open_streams() {
    let unmerged = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .split(vec![Box::new(ReLU::new()), Box::new(Sigmoid::new())]);
    assert!(unmerged.build().is_err());
}
1489
// Declaring the same tag name twice in one flow is a build-time error.
#[test]
fn test_build_error_duplicate_tag() {
    let builder = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .tag("features")
        .through(ReLU::new())
        .tag("features");
    assert!(builder.build().is_err());
}
1499
// using() feeds a tagged value into a NamedInputModule. The stem linear is
// forced to identity, so the "ctx" ref equals the input and AddRefModule
// (input + ctx) doubles it: [3, 5] -> [6, 10].
#[test]
fn test_using_backward_ref() {
    let l = Linear::on_device(2, 2, crate::tensor::test_device()).unwrap();
    // Identity weight, zero bias.
    l.weight
        .variable
        .set_data(from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]));
    l.bias
        .as_ref()
        .unwrap()
        .variable
        .set_data(from_f32(&[0.0, 0.0], &[2]));

    let graph = FlowBuilder::from(l)
        .tag("ctx")
        .through(AddRefModule)
        .using(&["ctx"])
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[3.0, 5.0], &[1, 2]), false);
    let y = graph.forward(&x).unwrap();
    let data = y.data().to_f32_vec().unwrap();

    assert!((data[0] - 6.0).abs() < 1e-5);
    assert!((data[1] - 10.0).abs() < 1e-5);
}
1532
// Gradients must flow through the tagged-reference edge: backprop through
// AddRefModule+using() reaches both the input and the stem parameters.
#[test]
fn test_using_backward_gradients() {
    let l = Linear::on_device(2, 2, crate::tensor::test_device()).unwrap();
    let graph = FlowBuilder::from(l)
        .tag("ctx")
        .through(AddRefModule)
        .using(&["ctx"])
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
    let y = graph.forward(&x).unwrap();
    let loss = y.sum().unwrap();
    loss.backward().unwrap();

    assert!(x.grad().is_some());
    for p in graph.parameters() {
        assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
    }
}
1553
// using() on a module that is not a NamedInputModule (plain ReLU)
// must fail at build time.
#[test]
fn test_using_error_plain_module() {
    let build_result = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .tag("ctx")
        .through(ReLU::new())
        .using(&["ctx"])
        .build();
    assert!(build_result.is_err());
}
1564
// Referencing a tag that was never declared must fail at build time.
#[test]
fn test_using_error_unknown_tag() {
    let build_result = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .through(AddRefModule)
        .using(&["nonexistent"])
        .build();
    assert!(build_result.is_err());
}
1573
// for_n(3) applies the loop body a fixed 3 times: with an identity stem,
// Doubler turns x into x * 2^3.
#[test]
fn test_loop_for() {
    let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .loop_body(Doubler)
        .for_n(3)
        .build()
        .unwrap();

    // Force the stem linear to identity so only the loop transforms x.
    let params = graph.parameters();
    params[0].variable.set_data(from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]));
    params[1].variable.set_data(from_f32(&[0.0, 0.0], &[2]));

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
    let y = graph.forward(&x).unwrap();
    let data = y.data().to_f32_vec().unwrap();

    assert!((data[0] - 8.0).abs() < 1e-5, "1*2^3=8, got {}", data[0]);
    assert!((data[1] - 16.0).abs() < 1e-5, "2*2^3=16, got {}", data[1]);
}
1597
// Backprop through an unrolled for_n(3) loop: the body's bias is added on
// each of the 3 iterations, so with a sum loss d(loss)/d(bias) = 3.
#[test]
fn test_loop_for_backward() {
    let bias_step = BiasStep::new(2).unwrap();
    let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .loop_body(bias_step)
        .for_n(3)
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
    let y = graph.forward(&x).unwrap();
    let loss = y.sum().unwrap();
    loss.backward().unwrap();

    for p in graph.parameters() {
        assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
    }

    // Gradients from all 3 loop iterations must accumulate on the shared
    // loop bias parameter.
    let all_params = graph.parameters();
    let bias_param = all_params.iter().find(|p| p.name == "loop_bias").unwrap();
    let grad = bias_param.variable.grad().unwrap().to_f32_vec().unwrap();
    assert!(
        (grad[0] - 3.0).abs() < 1e-5,
        "bias grad should be 3, got {}",
        grad[0]
    );
}
1631
// while_cond keeps doubling only while the halt condition allows it: the
// asserted values show iteration stops once values pass threshold 10
// (1 -> 8, 2 -> 16), well before the max of 20 iterations.
#[test]
fn test_loop_while() {
    let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .loop_body(Doubler)
        .while_cond(ThresholdHalt::new(10.0), 20)
        .build()
        .unwrap();

    // Identity stem so only the loop transforms x.
    let params = graph.parameters();
    params[0].variable.set_data(from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]));
    params[1].variable.set_data(from_f32(&[0.0, 0.0], &[2]));

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
    let y = graph.forward(&x).unwrap();
    let data = y.data().to_f32_vec().unwrap();

    assert!((data[0] - 8.0).abs() < 1e-5, "got {}", data[0]);
    assert!((data[1] - 16.0).abs() < 1e-5, "got {}", data[1]);
}
1657
// while_cond checks the condition BEFORE the first iteration: with a
// threshold of 0.5 already exceeded by the input, the body never runs and
// the output equals the (identity-stem) input.
#[test]
fn test_loop_while_immediate_halt() {
    let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .loop_body(Doubler)
        .while_cond(ThresholdHalt::new(0.5), 20)
        .build()
        .unwrap();

    let params = graph.parameters();
    params[0].variable.set_data(from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]));
    params[1].variable.set_data(from_f32(&[0.0, 0.0], &[2]));

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
    let y = graph.forward(&x).unwrap();
    let data = y.data().to_f32_vec().unwrap();

    // Zero iterations: output == input.
    assert!((data[0] - 1.0).abs() < 1e-5);
    assert!((data[1] - 2.0).abs() < 1e-5);
}
1680
// until_cond loops until the halt condition triggers; for this input and
// threshold 10 it matches the while_cond behavior (1 -> 8, 2 -> 16).
#[test]
fn test_loop_until() {
    let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .loop_body(Doubler)
        .until_cond(ThresholdHalt::new(10.0), 20)
        .build()
        .unwrap();

    // Identity stem so only the loop transforms x.
    let params = graph.parameters();
    params[0].variable.set_data(from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]));
    params[1].variable.set_data(from_f32(&[0.0, 0.0], &[2]));

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
    let y = graph.forward(&x).unwrap();
    let data = y.data().to_f32_vec().unwrap();

    assert!((data[0] - 8.0).abs() < 1e-5, "got {}", data[0]);
    assert!((data[1] - 16.0).abs() < 1e-5, "got {}", data[1]);
}
1706
// Unlike while_cond, until_cond checks AFTER the body: even with threshold
// 0.5 already exceeded, the body runs exactly once (1 -> 2, 2 -> 4).
#[test]
fn test_loop_until_at_least_once() {
    let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .loop_body(Doubler)
        .until_cond(ThresholdHalt::new(0.5), 20)
        .build()
        .unwrap();

    let params = graph.parameters();
    params[0].variable.set_data(from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]));
    params[1].variable.set_data(from_f32(&[0.0, 0.0], &[2]));

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
    let y = graph.forward(&x).unwrap();
    let data = y.data().to_f32_vec().unwrap();

    assert!((data[0] - 2.0).abs() < 1e-5, "got {}", data[0]);
    assert!((data[1] - 4.0).abs() < 1e-5, "got {}", data[1]);
}
1729
// Loop body parameters are counted once even though the body executes
// multiple times: stem linear (2) + body linear (2) = 4.
#[test]
fn test_loop_parameters() {
    let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .loop_body(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .for_n(3)
        .build()
        .unwrap();

    let params = graph.parameters();
    assert_eq!(params.len(), 4);
}
1743
// A learned while-condition contributes parameters too:
// stem (2) + body (2) + condition linear (2) = 6.
#[test]
fn test_loop_while_parameters() {
    let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .loop_body(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .while_cond(Linear::on_device(2, 1, crate::tensor::test_device()).unwrap(), 10)
        .build()
        .unwrap();

    let params = graph.parameters();
    assert_eq!(params.len(), 6);
}
1757
// A loop node composes with ordinary through() stages before and after it;
// final shape follows the trailing 4->2 linear.
#[test]
fn test_loop_in_chain() {
    let graph = FlowBuilder::from(Linear::on_device(3, 4, crate::tensor::test_device()).unwrap())
        .loop_body(ReLU::new())
        .for_n(3)
        .through(Linear::on_device(4, 2, crate::tensor::test_device()).unwrap())
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), false);
    let y = graph.forward(&x).unwrap();
    assert_eq!(y.shape(), vec![1, 2]);
}
1772
// A tagged reference can be consumed inside a loop body: "ctx" is the
// (identity) input, and AddRefModule adds it once per iteration, so after
// 3 iterations y = x + 3x = 4x: [2, 3] -> [8, 12].
#[test]
fn test_loop_using_backward_ref() {
    let graph = FlowBuilder::from(Identity)
        .tag("ctx")
        .loop_body(AddRefModule)
        .for_n(3)
        .using(&["ctx"])
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[2.0, 3.0], &[1, 2]), false);
    let y = graph.forward(&x).unwrap();
    let data = y.data().to_f32_vec().unwrap();

    assert!((data[0] - 8.0).abs() < 1e-5, "got {}", data[0]);
    assert!((data[1] - 12.0).abs() < 1e-5, "got {}", data[1]);
}
1795
// Gradients must flow through a reference edge that is re-used on every
// loop iteration, back into the stem linear and the input.
#[test]
fn test_loop_using_backward_gradients() {
    let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .tag("ctx")
        .loop_body(AddRefModule)
        .for_n(2)
        .using(&["ctx"])
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
    let y = graph.forward(&x).unwrap();
    let loss = y.sum().unwrap();
    loss.backward().unwrap();

    assert!(x.grad().is_some(), "input should have gradient");
    for p in graph.parameters() {
        assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
    }
}
1817
1818 struct NilSafeAdd;
1822 impl Module for NilSafeAdd {
1823 fn forward(&self, input: &Variable) -> Result<Variable> {
1824 Ok(input.clone())
1825 }
1826 fn as_named_input(&self) -> Option<&dyn NamedInputModule> { Some(self) }
1827 }
1828 impl NamedInputModule for NilSafeAdd {
1829 fn forward_named(
1830 &self,
1831 input: &Variable,
1832 refs: &HashMap<String, Variable>,
1833 ) -> Result<Variable> {
1834 if let Some(memory) = refs.get("memory") {
1835 input.add(memory)
1836 } else {
1837 Ok(input.clone())
1838 }
1839 }
1840 }
1841
1842 use crate::nn::Identity;
1843
// FlowBuilder::new() starts a flow without an initial module; tagging the
// bare input and chaining through a linear still builds and runs.
#[test]
fn test_flowbuilder_new() {
    let graph = FlowBuilder::new()
        .tag("input")
        .through(Linear::on_device(3, 2, crate::tensor::test_device()).unwrap())
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), false);
    let y = graph.forward(&x).unwrap();
    assert_eq!(y.shape(), vec![1, 2]);
}
1857
// A tag declared AFTER the module that uses it creates a forward
// (recurrent) reference: "memory" holds the previous pass's output.
// Pass 1 has no memory (identity); each later pass adds the previous
// output, so repeated forward(x) accumulates x, 2x, 3x, ...
#[test]
fn test_forward_ref() {
    let graph = FlowBuilder::from(Identity)
        .through(NilSafeAdd)
        .using(&["memory"])
        .through(Identity)
        .tag("memory")
        .build()
        .unwrap();

    // A forward reference makes the graph stateful.
    assert!(graph.has_state());

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
    let y1 = graph.forward(&x).unwrap();
    let d1 = y1.data().to_f32_vec().unwrap();
    assert!((d1[0] - 1.0).abs() < 1e-5, "pass1[0]: got {}", d1[0]);
    assert!((d1[1] - 2.0).abs() < 1e-5, "pass1[1]: got {}", d1[1]);

    let y2 = graph.forward(&x).unwrap();
    let d2 = y2.data().to_f32_vec().unwrap();
    assert!((d2[0] - 2.0).abs() < 1e-5, "pass2[0]: got {}", d2[0]);
    assert!((d2[1] - 4.0).abs() < 1e-5, "pass2[1]: got {}", d2[1]);

    let y3 = graph.forward(&x).unwrap();
    let d3 = y3.data().to_f32_vec().unwrap();
    assert!((d3[0] - 3.0).abs() < 1e-5, "pass3[0]: got {}", d3[0]);
    assert!((d3[1] - 6.0).abs() < 1e-5, "pass3[1]: got {}", d3[1]);
}
1893
// reset_state() clears the recurrent "memory" value, so the next forward
// behaves like a first pass again.
#[test]
fn test_forward_ref_reset_state() {
    let graph = FlowBuilder::from(Identity)
        .through(NilSafeAdd)
        .using(&["memory"])
        .through(Identity)
        .tag("memory")
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);

    // Accumulate state over three passes (1.0 -> 2.0 -> 3.0).
    graph.forward(&x).unwrap();
    graph.forward(&x).unwrap();
    let y_before = graph.forward(&x).unwrap();
    let d_before = y_before.data().to_f32_vec().unwrap();
    assert!((d_before[0] - 3.0).abs() < 1e-5);

    graph.reset_state();
    let y_after = graph.forward(&x).unwrap();
    let d_after = y_after.data().to_f32_vec().unwrap();
    assert!((d_after[0] - 1.0).abs() < 1e-5, "after reset: got {}", d_after[0]);
}
1919
// detach_state() cuts the autograd history of the stored state but keeps
// its value, so the next pass still sees the accumulated memory (2.0).
#[test]
fn test_forward_ref_detach_state() {
    let graph = FlowBuilder::from(Identity)
        .through(NilSafeAdd)
        .using(&["memory"])
        .through(Identity)
        .tag("memory")
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);

    let y1 = graph.forward(&x).unwrap();
    let _ = y1.sum().unwrap();

    graph.detach_state();

    let y2 = graph.forward(&x).unwrap();
    let d2 = y2.data().to_f32_vec().unwrap();
    assert!((d2[0] - 2.0).abs() < 1e-5, "detach preserves values: got {}", d2[0]);
}
1944
// Backprop works on a stateful graph: on the first pass (no memory yet)
// gradients still reach the input and all stem parameters.
#[test]
fn test_forward_ref_backward() {
    let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .through(NilSafeAdd)
        .using(&["memory"])
        .through(Identity)
        .tag("memory")
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
    let y = graph.forward(&x).unwrap();
    let loss = y.sum().unwrap();
    loss.backward().unwrap();

    assert!(x.grad().is_some(), "input should have gradient");
    for p in graph.parameters() {
        assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
    }
}
1966
// A using() reference that no tag ever resolves must be a build error.
#[test]
fn test_forward_ref_unresolved_error() {
    let build_result = FlowBuilder::from(Identity)
        .through(NilSafeAdd)
        .using(&["nonexistent"])
        .build();
    assert!(build_result.is_err());
}
1976
// Backward ("ctx", same-pass) and forward ("memory", previous-pass)
// references coexist: pass 1 gives x+ctx = 2x (no memory yet); pass 2
// additionally adds the stored memory (2x), giving 4x for x = 1.
#[test]
fn test_forward_ref_mixed_refs() {
    let graph = FlowBuilder::from(Identity)
        .tag("ctx")
        .through(AddRefModule)
        .using(&["ctx"])
        .through(NilSafeAdd)
        .using(&["memory"])
        .through(Identity)
        .tag("memory")
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);

    let y1 = graph.forward(&x).unwrap();
    let d1 = y1.data().to_f32_vec().unwrap();
    assert!((d1[0] - 2.0).abs() < 1e-5, "mixed pass1[0]: got {}", d1[0]);

    let y2 = graph.forward(&x).unwrap();
    let d2 = y2.data().to_f32_vec().unwrap();
    assert!((d2[0] - 4.0).abs() < 1e-5, "mixed pass2[0]: got {}", d2[0]);
}
2004
2005 struct Tripler;
2009 impl Module for Tripler {
2010 fn forward(&self, input: &Variable) -> Result<Variable> {
2011 input.add(&input.add(input)?)
2012 }
2013 fn parameters(&self) -> Vec<Parameter> { vec![] }
2014 }
2015
// FixedSelector(1) routes the input through branch index 1 (Tripler) only:
// [1, 2] -> [3, 6].
#[test]
fn test_switch_selects_branch() {
    let graph = FlowBuilder::from(Identity)
        .switch(FixedSelector::new(1), vec![Box::new(Doubler), Box::new(Tripler)])
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
    let y = graph.forward(&x).unwrap();
    let data = y.data().to_f32_vec().unwrap();
    assert!((data[0] - 3.0).abs() < 1e-5, "triple [1]=3, got {}", data[0]);
    assert!((data[1] - 6.0).abs() < 1e-5, "triple [2]=6, got {}", data[1]);
}
2030
// FixedSelector(0) routes through branch 0 (Doubler): [1, 2] -> [2, 4].
#[test]
fn test_switch_branch0() {
    let graph = FlowBuilder::from(Identity)
        .switch(FixedSelector::new(0), vec![Box::new(Doubler), Box::new(Tripler)])
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
    let y = graph.forward(&x).unwrap();
    let data = y.data().to_f32_vec().unwrap();
    assert!((data[0] - 2.0).abs() < 1e-5, "double [1]=2, got {}", data[0]);
    assert!((data[1] - 4.0).abs() < 1e-5, "double [2]=4, got {}", data[1]);
}
2044
// Backprop through a switch reaches the input via the selected branch.
// (Only the input gradient is asserted; unselected-branch parameters may
// legitimately receive no gradient.)
#[test]
fn test_switch_backward() {
    let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .switch(FixedSelector::new(0), vec![
            Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
            Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
        ])
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
    let y = graph.forward(&x).unwrap();
    let loss = y.sum().unwrap();
    loss.backward().unwrap();

    assert!(x.grad().is_some());
}
2064
// A learned selector contributes parameters alongside all branches:
// selector linear (2) + two branch linears (2 each) = 6.
#[test]
fn test_switch_parameters() {
    let graph = FlowBuilder::from(Identity)
        .switch(
            Linear::on_device(2, 1, crate::tensor::test_device()).unwrap(),
            vec![
                Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
                Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
            ],
        )
        .build()
        .unwrap();

    let params = graph.parameters();
    assert_eq!(params.len(), 6);
}
2082
2083 struct EqualRouter(usize);
2087 impl Module for EqualRouter {
2088 fn forward(&self, input: &Variable) -> Result<Variable> {
2089 let batch = input.shape()[0];
2090 let w = 1.0 / self.0 as f32;
2091 let data = vec![w; batch as usize * self.0];
2092 Ok(Variable::new(
2093 Tensor::from_f32(&data, &[batch, self.0 as i64], crate::tensor::test_device())?,
2094 false,
2095 ))
2096 }
2097 fn parameters(&self) -> Vec<Parameter> { vec![] }
2098 }
2099
// gate() blends branch outputs with the router's weights. With equal 1/2
// weights over Doubler and Tripler, y = (2x + 3x)/2 = 2.5x:
// [2, 4] -> [5, 10].
#[test]
fn test_gate_equal_weights() {
    let graph = FlowBuilder::from(Identity)
        .gate(EqualRouter(2), vec![Box::new(Doubler), Box::new(Tripler)])
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[2.0, 4.0], &[1, 2]), false);
    let y = graph.forward(&x).unwrap();
    let data = y.data().to_f32_vec().unwrap();
    assert!((data[0] - 5.0).abs() < 1e-5, "gate[0]=5, got {}", data[0]);
    assert!((data[1] - 10.0).abs() < 1e-5, "gate[1]=10, got {}", data[1]);
}
2115
// Soft gating keeps all branches in the autograd graph, so every
// parameter (router and both branches) receives a gradient.
#[test]
fn test_gate_backward() {
    let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .gate(
            Linear::on_device(2, 2, crate::tensor::test_device()).unwrap(),
            vec![
                Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
                Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
            ],
        )
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
    let y = graph.forward(&x).unwrap();
    let loss = y.sum().unwrap();
    loss.backward().unwrap();

    assert!(x.grad().is_some());
    for p in graph.parameters() {
        assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
    }
}
2139
// Gate parameters = router linear (2) + two branch linears (2 each) = 6.
#[test]
fn test_gate_parameters() {
    let graph = FlowBuilder::from(Identity)
        .gate(
            Linear::on_device(2, 2, crate::tensor::test_device()).unwrap(),
            vec![
                Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
                Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
            ],
        )
        .build()
        .unwrap();

    let params = graph.parameters();
    assert_eq!(params.len(), 6);
}
2157
// map(...).each() applies the mapped module per batch row and restacks:
// shape is preserved and every element is doubled.
#[test]
fn test_map_each() {
    let graph = FlowBuilder::from(Identity)
        .map(Doubler)
        .each()
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]), false);
    let y = graph.forward(&x).unwrap();
    let data = y.data().to_f32_vec().unwrap();

    assert_eq!(y.shape(), vec![3, 2]);
    assert!((data[0] - 2.0).abs() < 1e-5);
    assert!((data[5] - 12.0).abs() < 1e-5);
}
2177
// batched() mode must produce the same element-wise result as per-row
// mapping for an element-wise module like Doubler.
#[test]
fn test_map_batched() {
    let graph = FlowBuilder::from(Identity)
        .map(Doubler)
        .batched()
        .each()
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2]), false);
    let y = graph.forward(&x).unwrap();
    let data = y.data().to_f32_vec().unwrap();

    assert_eq!(data, vec![2.0, 4.0, 6.0, 8.0]);
}
2194
// Backprop through a map node: gradients flow back through the per-row
// application to the input and both linears' parameters.
#[test]
fn test_map_backward() {
    let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .map(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
        .each()
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2]), true);
    let y = graph.forward(&x).unwrap();
    let loss = y.sum().unwrap();
    loss.backward().unwrap();

    assert!(x.grad().is_some());
    for p in graph.parameters() {
        assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
    }
}
2213
// Test fixture reducing its input to a scalar sum — used to produce
// scalar tag values that can be collected as observations.
struct ScalarSum;
impl Module for ScalarSum {
    fn forward(&self, input: &Variable) -> Result<Variable> {
        input.sum()
    }
}
2223
// tagged() returns the value captured at the tag point. The tag sits on
// the Identity output (BEFORE Doubler), so it holds the raw input [1, 2];
// unknown tags return None.
#[test]
fn test_tagged_capture() {
    let graph = FlowBuilder::from(Identity)
        .tag("features")
        .through(Doubler)
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
    let _ = graph.forward(&x).unwrap();

    let features = graph.tagged("features").unwrap();
    let data = features.data().to_f32_vec().unwrap();
    assert!((data[0] - 1.0).abs() < 1e-5);
    assert!((data[1] - 2.0).abs() < 1e-5);

    assert!(graph.tagged("nonexistent").is_none());
}
2244
// Each forward pass overwrites the captured tag value; tagged() always
// reflects the most recent pass.
#[test]
fn test_tagged_updates_each_forward() {
    let graph = FlowBuilder::from(Doubler)
        .tag("doubled")
        .build()
        .unwrap();

    let x1 = Variable::new(from_f32(&[1.0], &[1, 1]), false);
    let _ = graph.forward(&x1).unwrap();
    let v1 = graph.tagged("doubled").unwrap().item().unwrap();
    assert!((v1 - 2.0).abs() < 1e-5);

    let x2 = Variable::new(from_f32(&[5.0], &[1, 1]), false);
    let _ = graph.forward(&x2).unwrap();
    let v2 = graph.tagged("doubled").unwrap().item().unwrap();
    assert!((v2 - 10.0).abs() < 1e-5);
}
2262
// tag_names() lists every declared tag; order is not guaranteed, so the
// test sorts before comparing.
#[test]
fn test_tag_names() {
    let graph = FlowBuilder::from(Identity)
        .tag("a")
        .through(Identity)
        .tag("b")
        .build()
        .unwrap();

    let mut names = graph.tag_names();
    names.sort_unstable();
    assert_eq!(names, vec!["a", "b"]);
}
2276
// collect() records the current tag value; flush() folds the collected
// batch into one trend point (the asserted values show it is the mean:
// mean(1,2,3)=2.0 and mean(0.5,0.3,0.2)=1/3), and improving(0) reports
// the downward movement between the two points.
#[test]
fn test_collect_flush_trend() {
    let graph = FlowBuilder::from(ScalarSum)
        .tag("loss")
        .build()
        .unwrap();

    for val in &[1.0f32, 2.0, 3.0] {
        let x = Variable::new(from_f32(&[*val], &[1, 1]), false);
        let _ = graph.forward(&x).unwrap();
        graph.collect(&["loss"]).unwrap();
    }
    let collected = graph.collected("loss");
    assert_eq!(collected.len(), 3);

    graph.flush(&["loss"]);
    assert_eq!(graph.flush_count(), 1);

    for val in &[0.5f32, 0.3, 0.2] {
        let x = Variable::new(from_f32(&[*val], &[1, 1]), false);
        let _ = graph.forward(&x).unwrap();
        graph.collect(&["loss"]).unwrap();
    }
    graph.flush(&["loss"]);
    assert_eq!(graph.flush_count(), 2);

    let trend = graph.trend("loss");
    assert_eq!(trend.len(), 2);
    assert!((trend.values()[0] - 2.0).abs() < 1e-5);
    assert!((trend.values()[1] - (1.0 / 3.0)).abs() < 1e-5);
    assert!(trend.improving(0));
}
2314
// record() injects externally computed values (no forward pass needed);
// flush() reduces each batch to its mean: mean(0.5,0.4,0.3)=0.4 and
// mean(0.1,0.05)=0.075.
#[test]
fn test_record_external_values() {
    let graph = FlowBuilder::from(Identity).build().unwrap();

    graph.record("external_loss", &[0.5, 0.4, 0.3]);
    graph.flush(&["external_loss"]);

    graph.record("external_loss", &[0.1, 0.05]);
    graph.flush(&["external_loss"]);

    let trend = graph.trend("external_loss");
    assert_eq!(trend.len(), 2);
    assert!((trend.values()[0] - 0.4).abs() < 1e-5);
    assert!((trend.values()[1] - 0.075).abs() < 1e-5);
    assert!(trend.improving(0));
}
2331
// flush(&[]) with an empty name list flushes every recorded series.
#[test]
fn test_flush_all() {
    let graph = FlowBuilder::from(Identity).build().unwrap();

    graph.record("a", &[1.0, 2.0]);
    graph.record("b", &[3.0, 4.0]);
    graph.flush(&[]);
    assert_eq!(graph.trend("a").len(), 1);
    assert_eq!(graph.trend("b").len(), 1);
}
2343
// reset_trend() discards the accumulated trend points for the named series.
#[test]
fn test_reset_trend() {
    let graph = FlowBuilder::from(Identity).build().unwrap();

    graph.record("loss", &[1.0]);
    graph.flush(&[]);
    assert_eq!(graph.trend("loss").len(), 1);

    graph.reset_trend(&["loss"]);
    assert_eq!(graph.trend("loss").len(), 0);
}
2355
// trends() bundles multiple series into a TrendGroup; with both series
// strictly decreasing, all_improving(0) holds.
#[test]
fn test_trends_group() {
    let graph = FlowBuilder::from(Identity).build().unwrap();

    for epoch in &[10.0, 8.0, 6.0, 4.0] {
        graph.record("a", &[*epoch]);
        graph.record("b", &[*epoch * 0.5]);
        graph.flush(&[]);
    }

    let tg = graph.trends(&["a", "b"]);
    assert_eq!(tg.len(), 2);
    assert!(tg.all_improving(0));
}
2371
// tag_group() auto-names one tag per split branch ("branch_0",
// "branch_1", ...), queryable as a group and individually via tagged().
#[test]
fn test_tag_group() {
    let graph = FlowBuilder::from(Identity)
        .split(vec![
            Box::new(Doubler),
            Box::new(Tripler),
            Box::new(Identity),
        ])
        .tag_group("branch")
        .merge(MergeOp::Add)
        .build()
        .unwrap();

    let members = graph.tag_group("branch").unwrap();
    assert_eq!(members, &["branch_0", "branch_1", "branch_2"]);

    assert!(graph.tag_group("nonexistent").is_none());

    let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
    let _ = graph.forward(&x).unwrap();

    // Member tags capture each branch's own output.
    let b0 = graph.tagged("branch_0").unwrap();
    let b0_data = b0.data().to_f32_vec().unwrap();
    assert!((b0_data[0] - 2.0).abs() < 1e-5, "doubler: got {}", b0_data[0]);

    let b1 = graph.tagged("branch_1").unwrap();
    let b1_data = b1.data().to_f32_vec().unwrap();
    assert!((b1_data[0] - 3.0).abs() < 1e-5, "tripler: got {}", b1_data[0]);
}
2407
// Passing a group name to trends() expands it to its member tags:
// trends(&["head"]) yields the two series head_0 / head_1.
#[test]
fn test_tag_group_observation() {
    let graph = FlowBuilder::from(Identity)
        .split(vec![Box::new(ScalarSum), Box::new(ScalarSum)])
        .tag_group("head")
        .merge(MergeOp::Add)
        .build()
        .unwrap();

    for epoch in &[1.0f32, 2.0, 3.0] {
        let x = Variable::new(from_f32(&[*epoch], &[1, 1]), false);
        let _ = graph.forward(&x).unwrap();
        graph.collect(&["head_0", "head_1"]).unwrap();
        graph.flush(&["head_0", "head_1"]);
    }

    let tg = graph.trends(&["head"]);
    assert_eq!(tg.len(), 2);
}
2430
// tag_group() misuse fails at build time: (1) used outside a split,
// (2) the same group name declared twice.
#[test]
fn test_tag_group_errors() {
    let result = FlowBuilder::from(Identity)
        .tag_group("bad")
        .build();
    assert!(result.is_err());

    let result = FlowBuilder::from(Identity)
        .split(vec![Box::new(Doubler), Box::new(Tripler)])
        .tag_group("x")
        .merge(MergeOp::Add)
        .split(vec![Box::new(Doubler), Box::new(Tripler)])
        .tag_group("x")
        .merge(MergeOp::Add)
        .build();
    assert!(result.is_err());
}
2450
2451 struct SumRefs;
2455 impl Module for SumRefs {
2456 fn forward(&self, input: &Variable) -> Result<Variable> {
2457 Ok(input.clone())
2458 }
2459 fn as_named_input(&self) -> Option<&dyn NamedInputModule> { Some(self) }
2460 }
2461 impl NamedInputModule for SumRefs {
2462 fn forward_named(
2463 &self,
2464 input: &Variable,
2465 refs: &HashMap<String, Variable>,
2466 ) -> Result<Variable> {
2467 let mut result = input.clone();
2468 for v in refs.values() {
2469 result = result.add(v)?;
2470 }
2471 Ok(result)
2472 }
2473 }
2474
// input() exposes a named auxiliary graph input; forward_multi() supplies
// [main, ctx] positionally and SumRefs adds them: [1,2]+[10,20]=[11,22].
#[test]
fn test_input_auxiliary() {
    let graph = FlowBuilder::from(Identity)
        .input(&["ctx"])
        .through(SumRefs)
        .using(&["ctx"])
        .build()
        .unwrap();

    let main = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
    let ctx = Variable::new(from_f32(&[10.0, 20.0], &[1, 2]), false);

    let y = graph.forward_multi(&[main, ctx]).unwrap();
    let data = y.data().to_f32_vec().unwrap();
    assert!((data[0] - 11.0).abs() < 1e-5, "got {}", data[0]);
    assert!((data[1] - 22.0).abs() < 1e-5, "got {}", data[1]);
}
2495
// Multiple auxiliary inputs are matched positionally after the main
// input: y = main + a + b = 1 + 10 + 100 = 111.
#[test]
fn test_input_multiple() {
    let graph = FlowBuilder::from(Identity)
        .input(&["a", "b"])
        .through(SumRefs)
        .using(&["a", "b"])
        .build()
        .unwrap();

    let main = Variable::new(from_f32(&[1.0], &[1, 1]), false);
    let a = Variable::new(from_f32(&[10.0], &[1, 1]), false);
    let b = Variable::new(from_f32(&[100.0], &[1, 1]), false);

    let y = graph.forward_multi(&[main, a, b]).unwrap();
    let data = y.data().to_f32_vec().unwrap();
    assert!((data[0] - 111.0).abs() < 1e-5, "got {}", data[0]);
}
2515
// A graph declaring an auxiliary input cannot be run via single-input
// forward(): the input-count mismatch is a runtime error.
#[test]
fn test_input_error_count_mismatch() {
    let graph = FlowBuilder::from(Identity)
        .input(&["ctx"])
        .build()
        .unwrap();

    let x = Variable::new(from_f32(&[1.0], &[1, 1]), false);
    assert!(graph.forward(&x).is_err());
}
2527
    #[test]
    fn test_graph_set_training() {
        use crate::nn::Dropout;

        // Linear -> Dropout(0.5); dropout behavior depends on the training flag.
        let graph = FlowBuilder::from(Linear::on_device(3, 3, crate::tensor::test_device()).unwrap())
            .through(Dropout::new(0.5))
            .build()
            .unwrap();

        let x = Variable::new(from_f32(&[1.0; 12], &[4, 3]), false);
        // Forward works and preserves shape (training mode may zero elements).
        let y1 = graph.forward(&x).unwrap();
        assert_eq!(y1.shape(), vec![4, 3]);

        // Eval mode: two forwards on the same input must produce identical
        // outputs, i.e. dropout stops injecting randomness.
        graph.set_training(false);
        let y2 = graph.forward(&x).unwrap();
        let y3 = graph.forward(&x).unwrap();
        assert_eq!(y2.shape(), vec![4, 3]);

        let d2 = y2.data().to_f32_vec().unwrap();
        let d3 = y3.data().to_f32_vec().unwrap();
        let same = d2.iter().zip(d3.iter()).all(|(a, b)| (a - b).abs() < 1e-6);
        assert!(same, "eval mode should be deterministic (no dropout)");
    }
2556
2557 #[test]
2560 fn test_walk_modules() {
2561 use crate::nn::walk_modules;
2562
2563 let l1 = Linear::on_device(2, 2, crate::tensor::test_device()).unwrap();
2564 let mut count = 0;
2565 walk_modules(&l1, &mut |_| count += 1);
2566 assert_eq!(count, 1); }
2568
    #[test]
    fn test_profiling_basic() {
        // Profiling is off by default, records per-node/per-level timings while
        // enabled, and stops producing profiles once disabled again.
        let graph = FlowBuilder::from(Linear::on_device(3, 4, crate::tensor::test_device()).unwrap())
            .tag("encoder")
            .through(ReLU::new())
            .through(Linear::on_device(4, 2, crate::tensor::test_device()).unwrap())
            .tag("decoder")
            .build()
            .unwrap();

        // Disabled: a forward pass leaves no profile behind.
        assert!(!graph.profiling());
        let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), false);
        graph.forward(&x).unwrap();
        assert!(graph.profile().is_none());

        // Enabled: the next forward populates total, node and level timings.
        graph.enable_profiling();
        assert!(graph.profiling());
        graph.forward(&x).unwrap();

        let p = graph.profile().unwrap();
        assert!(p.total.as_nanos() > 0, "total should be nonzero");
        assert!(!p.nodes.is_empty(), "should have node timings");
        assert!(!p.levels.is_empty(), "should have level timings");

        // Tagged nodes are addressable by tag; unknown tags report zero.
        let enc_dur = p.timing("encoder");
        assert!(enc_dur.as_nanos() > 0, "encoder timing should be nonzero");
        let dec_dur = p.timing("decoder");
        assert!(dec_dur.as_nanos() > 0, "decoder timing should be nonzero");
        assert!(p.timing("nonexistent").is_zero());

        // Graph-level convenience accessor mirrors the profile's timing().
        assert!(graph.timing("encoder").as_nanos() > 0);

        // Display output includes the summary header and level rows.
        let s = p.to_string();
        assert!(s.contains("Forward:"));
        assert!(s.contains("Level"));

        // Disabled again: profiles stop being produced.
        graph.disable_profiling();
        assert!(!graph.profiling());
        graph.forward(&x).unwrap();
        assert!(graph.profile().is_none());
    }
2618
    #[test]
    fn test_profiling_timing_trend() {
        // collect_timings accumulates per-tag timings during an "epoch";
        // flush_timings closes the epoch, so two flushes yield a trend of length 2.
        let graph = FlowBuilder::from(ScalarSum)
            .tag("loss")
            .build()
            .unwrap();

        graph.enable_profiling();

        for _ in 0..2 {
            for val in &[1.0f32, 2.0, 3.0] {
                let x = Variable::new(from_f32(&[*val], &[1, 1]), false);
                graph.forward(&x).unwrap();
                graph.collect_timings(&["loss"]);
            }
            // Flush with an empty tag list — presumably flushes all collected
            // tags; TODO confirm against flush_timings' contract.
            graph.flush_timings(&[]);
        }

        let trend = graph.timing_trend("loss");
        assert_eq!(trend.len(), 2, "2 epochs flushed");
        assert!(trend.values()[0] > 0.0, "timing values should be positive");

        // Resetting clears the recorded trend for the tag.
        graph.reset_timing_trend(&["loss"]);
        assert_eq!(graph.timing_trend("loss").len(), 0);
    }
2646
2647 #[test]
2650 fn test_dot_basic() {
2651 let graph = FlowBuilder::from(Linear::on_device(3, 4, crate::tensor::test_device()).unwrap())
2652 .tag("enc")
2653 .through(ReLU::new())
2654 .through(Linear::on_device(4, 2, crate::tensor::test_device()).unwrap())
2655 .build()
2656 .unwrap();
2657
2658 let dot = graph.dot();
2659 assert!(dot.contains("digraph G"));
2660 assert!(dot.contains("level 0"));
2661 assert!(dot.contains("#enc"));
2662 assert!(dot.contains("->"));
2663 }
2664
2665 #[test]
2666 fn test_dot_with_profile() {
2667 let graph = FlowBuilder::from(Linear::on_device(3, 4, crate::tensor::test_device()).unwrap())
2668 .tag("enc")
2669 .through(Linear::on_device(4, 2, crate::tensor::test_device()).unwrap())
2670 .build()
2671 .unwrap();
2672
2673 let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), false);
2674
2675 let dot1 = graph.dot_with_profile();
2677 assert!(dot1.contains("digraph G"));
2678
2679 graph.enable_profiling();
2681 graph.forward(&x).unwrap();
2682 let dot2 = graph.dot_with_profile();
2683 assert!(dot2.contains("digraph G"));
2684 assert!(dot2.contains("Forward:"));
2685 }
2686
    /// Test module that doubles its input and remembers the most recent output
    /// so loop iterations can be inspected through `Module::trace`.
    struct TracingDoubler {
        // Output of the latest forward() call; None until the first call.
        last_output: RefCell<Option<Variable>>,
    }
    impl TracingDoubler {
        fn new() -> Self {
            TracingDoubler {
                last_output: RefCell::new(None),
            }
        }
    }
    impl Module for TracingDoubler {
        fn forward(&self, input: &Variable) -> Result<Variable> {
            // Double via self-addition and stash the result for trace().
            let out = input.add(input)?;
            *self.last_output.borrow_mut() = Some(out.clone());
            Ok(out)
        }
        fn trace(&self) -> Option<Variable> {
            self.last_output.borrow().clone()
        }
    }
2710
2711 #[test]
2712 fn test_loop_traces() {
2713 let graph = FlowBuilder::from(Identity)
2716 .loop_body(TracingDoubler::new())
2717 .for_n(3)
2718 .build()
2719 .unwrap();
2720
2721 let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
2722 let y = graph.forward(&x).unwrap();
2723 let data = y.data().to_f32_vec().unwrap();
2724 assert!((data[0] - 8.0).abs() < 1e-5);
2725
2726 let traces = graph.traces("any").unwrap();
2728 assert_eq!(traces.len(), 3, "3 iterations = 3 traces");
2729
2730 let t0 = traces[0].data().to_f32_vec().unwrap();
2731 assert!((t0[0] - 2.0).abs() < 1e-5, "iter0: [2,4], got {}", t0[0]);
2732
2733 let t1 = traces[1].data().to_f32_vec().unwrap();
2734 assert!((t1[0] - 4.0).abs() < 1e-5, "iter1: [4,8], got {}", t1[0]);
2735
2736 let t2 = traces[2].data().to_f32_vec().unwrap();
2737 assert!((t2[0] - 8.0).abs() < 1e-5, "iter2: [8,16], got {}", t2[0]);
2738 }
2739
2740 #[test]
2741 fn test_loop_traces_cleared_each_forward() {
2742 let graph = FlowBuilder::from(Identity)
2743 .loop_body(TracingDoubler::new())
2744 .for_n(2)
2745 .build()
2746 .unwrap();
2747
2748 let x = Variable::new(from_f32(&[1.0], &[1, 1]), false);
2749 graph.forward(&x).unwrap();
2750 let traces1 = graph.traces("any").unwrap();
2751 assert_eq!(traces1.len(), 2);
2752
2753 graph.forward(&x).unwrap();
2755 let traces2 = graph.traces("any").unwrap();
2756 assert_eq!(traces2.len(), 2);
2757 }
2758
2759 #[test]
2760 fn test_loop_no_traces_without_trace_impl() {
2761 let graph = FlowBuilder::from(Identity)
2763 .loop_body(Doubler)
2764 .for_n(3)
2765 .build()
2766 .unwrap();
2767
2768 let x = Variable::new(from_f32(&[1.0], &[1, 1]), false);
2769 graph.forward(&x).unwrap();
2770
2771 assert!(graph.traces("any").is_none());
2773 }
2774
2775 #[test]
2778 fn test_softmax_router_gate() {
2779 let graph = FlowBuilder::from(Identity)
2781 .gate(
2782 SoftmaxRouter::on_device(2, 2, crate::tensor::test_device()).unwrap(),
2783 vec![Box::new(Doubler), Box::new(Tripler)],
2784 )
2785 .build()
2786 .unwrap();
2787
2788 let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
2789 let y = graph.forward(&x).unwrap();
2790 assert_eq!(y.shape(), vec![1, 2]);
2792 let params = graph.parameters();
2794 assert_eq!(params.len(), 2);
2795 }
2796
2797 #[test]
2798 fn test_softmax_router_backward() {
2799 let graph = FlowBuilder::from(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
2800 .gate(
2801 SoftmaxRouter::on_device(2, 2, crate::tensor::test_device()).unwrap(),
2802 vec![
2803 Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
2804 Box::new(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap()),
2805 ],
2806 )
2807 .build()
2808 .unwrap();
2809
2810 let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), true);
2811 let y = graph.forward(&x).unwrap();
2812 let loss = y.sum().unwrap();
2813 loss.backward().unwrap();
2814
2815 assert!(x.grad().is_some());
2816 for p in graph.parameters() {
2817 assert!(p.variable.grad().is_some(), "{} missing gradient", p.name);
2818 }
2819 }
2820
2821 #[test]
2822 fn test_sigmoid_router_gate() {
2823 let graph = FlowBuilder::from(Identity)
2824 .gate(
2825 SigmoidRouter::on_device(2, 2, crate::tensor::test_device()).unwrap(),
2826 vec![Box::new(Doubler), Box::new(Tripler)],
2827 )
2828 .build()
2829 .unwrap();
2830
2831 let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
2832 let y = graph.forward(&x).unwrap();
2833 assert_eq!(y.shape(), vec![1, 2]);
2834 }
2835
2836 #[test]
2837 fn test_fixed_selector_switch() {
2838 let graph = FlowBuilder::from(Identity)
2840 .switch(FixedSelector::new(1), vec![Box::new(Doubler), Box::new(Tripler)])
2841 .build()
2842 .unwrap();
2843
2844 let x = Variable::new(from_f32(&[2.0, 3.0], &[1, 2]), false);
2845 let y = graph.forward(&x).unwrap();
2846 let data = y.data().to_f32_vec().unwrap();
2847 assert!((data[0] - 6.0).abs() < 1e-5, "triple 2=6, got {}", data[0]);
2848 assert!((data[1] - 9.0).abs() < 1e-5, "triple 3=9, got {}", data[1]);
2849 }
2850
2851 #[test]
2852 fn test_argmax_selector_switch() {
2853 let graph = FlowBuilder::from(Identity)
2854 .switch(
2855 ArgmaxSelector::on_device(2, 2, crate::tensor::test_device()).unwrap(),
2856 vec![Box::new(Doubler), Box::new(Tripler)],
2857 )
2858 .build()
2859 .unwrap();
2860
2861 let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
2862 let y = graph.forward(&x).unwrap();
2863 assert_eq!(y.shape(), vec![1, 2]);
2865 assert_eq!(graph.parameters().len(), 2);
2867 }
2868
2869 #[test]
2872 fn test_threshold_halt_while() {
2873 let graph = FlowBuilder::from(Identity)
2876 .loop_body(Doubler)
2877 .while_cond(ThresholdHalt::new(10.0), 20)
2878 .build()
2879 .unwrap();
2880
2881 let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
2882 let y = graph.forward(&x).unwrap();
2883 let data = y.data().to_f32_vec().unwrap();
2884 assert!((data[0] - 8.0).abs() < 1e-5, "expected 8, got {}", data[0]);
2886 assert!((data[1] - 16.0).abs() < 1e-5, "expected 16, got {}", data[1]);
2887 }
2888
2889 #[test]
2890 fn test_threshold_halt_until() {
2891 let graph = FlowBuilder::from(Identity)
2896 .loop_body(Doubler)
2897 .until_cond(ThresholdHalt::new(10.0), 20)
2898 .build()
2899 .unwrap();
2900
2901 let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
2902 let y = graph.forward(&x).unwrap();
2903 let data = y.data().to_f32_vec().unwrap();
2904 assert!((data[0] - 8.0).abs() < 1e-5, "expected 8, got {}", data[0]);
2906 assert!((data[1] - 16.0).abs() < 1e-5, "expected 16, got {}", data[1]);
2907 }
2908
2909 #[test]
2910 fn test_threshold_halt_immediate() {
2911 let graph = FlowBuilder::from(Identity)
2913 .loop_body(Doubler)
2914 .while_cond(ThresholdHalt::new(0.5), 20)
2915 .build()
2916 .unwrap();
2917
2918 let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
2919 let y = graph.forward(&x).unwrap();
2920 let data = y.data().to_f32_vec().unwrap();
2921 assert!((data[0] - 1.0).abs() < 1e-5, "expected 1, got {}", data[0]);
2923 assert!((data[1] - 2.0).abs() < 1e-5, "expected 2, got {}", data[1]);
2924 }
2925
2926 #[test]
2927 fn test_learned_halt_parameters() {
2928 let graph = FlowBuilder::from(Identity)
2929 .loop_body(Linear::on_device(2, 2, crate::tensor::test_device()).unwrap())
2930 .until_cond(LearnedHalt::on_device(2, crate::tensor::test_device()).unwrap(), 5)
2931 .build()
2932 .unwrap();
2933
2934 let params = graph.parameters();
2936 assert_eq!(params.len(), 4);
2937 }
2938
2939 #[test]
2940 fn test_named_parameters_unique() {
2941 let graph = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
2942 .through(ReLU::new())
2943 .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
2944 .build()
2945 .unwrap();
2946
2947 let named = graph.named_parameters();
2948 assert_eq!(named.len(), 4);
2950
2951 let names: Vec<&str> = named.iter().map(|(n, _)| n.as_str()).collect();
2953 let unique: std::collections::HashSet<&str> = names.iter().copied().collect();
2954 assert_eq!(names.len(), unique.len(), "duplicate names: {:?}", names);
2955 }
2956
2957 #[test]
2958 fn test_named_parameters_tagged_prefix() {
2959 let graph = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
2960 .tag("encoder")
2961 .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
2962 .build()
2963 .unwrap();
2964
2965 let named = graph.named_parameters();
2966 let encoder_params: Vec<&str> = named.iter()
2968 .filter(|(n, _)| n.starts_with("encoder/"))
2969 .map(|(n, _)| n.as_str())
2970 .collect();
2971 assert_eq!(encoder_params.len(), 2, "tagged node should have 2 params with 'encoder/' prefix");
2972
2973 let untagged: Vec<&str> = named.iter()
2975 .filter(|(n, _)| !n.starts_with("encoder/"))
2976 .map(|(n, _)| n.as_str())
2977 .collect();
2978 assert_eq!(untagged.len(), 2, "untagged node should have 2 params");
2979 assert!(untagged[0].contains('/'), "should have prefix/name format: {}", untagged[0]);
2980 }
2981
2982 #[test]
2985 fn test_structural_hash_deterministic() {
2986 let g1 = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
2987 .through(ReLU::new())
2988 .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
2989 .build()
2990 .unwrap();
2991
2992 let g2 = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
2993 .through(ReLU::new())
2994 .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
2995 .build()
2996 .unwrap();
2997
2998 assert_eq!(g1.structural_hash(), g2.structural_hash());
2999 }
3000
3001 #[test]
3002 fn test_structural_hash_differs() {
3003 let g1 = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
3004 .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
3005 .build()
3006 .unwrap();
3007
3008 let g2 = FlowBuilder::from(Linear::on_device(4, 16, crate::tensor::test_device()).unwrap())
3010 .through(Linear::on_device(16, 2, crate::tensor::test_device()).unwrap())
3011 .build()
3012 .unwrap();
3013
3014 assert_ne!(g1.structural_hash(), g2.structural_hash());
3015 }
3016
3017 #[test]
3018 fn test_short_hash_length() {
3019 let g = FlowBuilder::from(Linear::on_device(2, 3, crate::tensor::test_device()).unwrap())
3020 .build()
3021 .unwrap();
3022
3023 assert_eq!(g.structural_hash().len(), 64);
3024 assert_eq!(g.short_hash().len(), 8);
3025 assert!(g.structural_hash().starts_with(g.short_hash()));
3026 }
3027
3028 #[test]
3029 fn test_label_default_none() {
3030 let g = FlowBuilder::from(Linear::on_device(2, 3, crate::tensor::test_device()).unwrap())
3031 .build()
3032 .unwrap();
3033 assert!(g.label().is_none());
3034 }
3035
3036 #[test]
3037 fn test_label_set() {
3038 let g = FlowBuilder::from(Linear::on_device(2, 3, crate::tensor::test_device()).unwrap())
3039 .label("my-model")
3040 .build()
3041 .unwrap();
3042 assert_eq!(g.label(), Some("my-model"));
3043 }
3044
3045 #[test]
3046 fn test_label_does_not_affect_hash() {
3047 let g1 = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
3048 .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
3049 .build()
3050 .unwrap();
3051
3052 let g2 = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
3053 .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
3054 .label("different-label")
3055 .build()
3056 .unwrap();
3057
3058 assert_eq!(g1.structural_hash(), g2.structural_hash());
3059 }
3060
    #[test]
    fn test_graph_save_load_checkpoint() {
        // Round-trip: save parameters of one graph, load into an identically
        // structured graph, and verify every tensor matches name-by-name.
        let g = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
            .tag("enc")
            .through(ReLU::new())
            .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
            .tag("dec")
            .build()
            .unwrap();

        let dir = std::env::temp_dir();
        let path = dir.join("test_graph_ckpt.fdl");
        let path_str = path.to_str().unwrap();

        g.save_checkpoint(path_str).unwrap();

        // Fresh graph, same architecture (weights initialized independently).
        let g2 = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
            .tag("enc")
            .through(ReLU::new())
            .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
            .tag("dec")
            .build()
            .unwrap();

        // All 4 parameters load; nothing is skipped or missing.
        let report = g2.load_checkpoint(path_str).unwrap();
        assert_eq!(report.loaded.len(), 4);
        assert!(report.skipped.is_empty());
        assert!(report.missing.is_empty());

        // Loaded values must equal the saved ones, pairwise by name.
        for ((n1, p1), (n2, p2)) in g.named_parameters().iter().zip(g2.named_parameters().iter()) {
            assert_eq!(n1, n2);
            assert_eq!(p1.variable.data().to_f32_vec().unwrap(),
                       p2.variable.data().to_f32_vec().unwrap());
        }

        // Best-effort cleanup; failure to remove is not a test failure.
        std::fs::remove_file(path_str).ok();
    }
3101
    #[test]
    fn test_graph_checkpoint_hash_mismatch() {
        // A checkpoint saved from one architecture must be rejected when
        // loaded into a structurally different graph.
        let g1 = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
            .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
            .build()
            .unwrap();

        let dir = std::env::temp_dir();
        let path = dir.join("test_graph_ckpt_mismatch.fdl");
        let path_str = path.to_str().unwrap();

        g1.save_checkpoint(path_str).unwrap();

        // Different hidden width (16 vs 8) -> different structural hash.
        let g2 = FlowBuilder::from(Linear::on_device(4, 16, crate::tensor::test_device()).unwrap())
            .through(Linear::on_device(16, 2, crate::tensor::test_device()).unwrap())
            .build()
            .unwrap();

        let result = g2.load_checkpoint(path_str);
        assert!(result.is_err());
        assert!(format!("{}", result.unwrap_err()).contains("architecture mismatch"));

        // Best-effort cleanup; failure to remove is not a test failure.
        std::fs::remove_file(path_str).ok();
    }
3127
    #[test]
    fn test_graph_checkpoint_gz() {
        // Same round-trip as the plain checkpoint test but via a ".gz" path —
        // presumably the extension selects a compressed format; TODO confirm.
        let g = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
            .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
            .build()
            .unwrap();

        let dir = std::env::temp_dir();
        let path = dir.join("test_graph_ckpt.fdl.gz");
        let path_str = path.to_str().unwrap();

        g.save_checkpoint(path_str).unwrap();

        let g2 = FlowBuilder::from(Linear::on_device(4, 8, crate::tensor::test_device()).unwrap())
            .through(Linear::on_device(8, 2, crate::tensor::test_device()).unwrap())
            .build()
            .unwrap();

        // All 4 parameters load back through the compressed file.
        let report = g2.load_checkpoint(path_str).unwrap();
        assert_eq!(report.loaded.len(), 4);

        // Best-effort cleanup; failure to remove is not a test failure.
        std::fs::remove_file(path_str).ok();
    }
3151
3152 #[test]
3155 fn test_collect_with_sum_reduction() {
3156 let graph = FlowBuilder::from(Identity)
3158 .tag("features")
3159 .build()
3160 .unwrap();
3161
3162 let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), false);
3163 let _ = graph.forward(&x).unwrap();
3164 graph.collect_with(&["features"], Reduce::Sum).unwrap();
3165
3166 let collected = graph.collected("features");
3167 assert_eq!(collected.len(), 1);
3168 assert!((collected[0] - 6.0).abs() < 1e-5, "sum([1,2,3]) = 6, got {}", collected[0]);
3169 }
3170
3171 #[test]
3172 fn test_collect_with_mean_reduction() {
3173 let graph = FlowBuilder::from(Identity)
3174 .tag("out")
3175 .build()
3176 .unwrap();
3177
3178 let x = Variable::new(from_f32(&[2.0, 4.0, 6.0], &[1, 3]), false);
3179 let _ = graph.forward(&x).unwrap();
3180 graph.collect_with(&["out"], Reduce::Mean).unwrap();
3181
3182 let collected = graph.collected("out");
3183 assert!((collected[0] - 4.0).abs() < 1e-5, "mean([2,4,6]) = 4, got {}", collected[0]);
3184 }
3185
3186 #[test]
3187 fn test_collect_with_max_reduction() {
3188 let graph = FlowBuilder::from(Identity)
3189 .tag("out")
3190 .build()
3191 .unwrap();
3192
3193 let x = Variable::new(from_f32(&[1.0, 5.0, 3.0], &[1, 3]), false);
3194 let _ = graph.forward(&x).unwrap();
3195 graph.collect_with(&["out"], Reduce::Max).unwrap();
3196
3197 let collected = graph.collected("out");
3198 assert!((collected[0] - 5.0).abs() < 1e-5, "max([1,5,3]) = 5, got {}", collected[0]);
3199 }
3200
3201 #[test]
3202 fn test_collect_with_min_reduction() {
3203 let graph = FlowBuilder::from(Identity)
3204 .tag("out")
3205 .build()
3206 .unwrap();
3207
3208 let x = Variable::new(from_f32(&[-2.0, 0.0, 3.0], &[1, 3]), false);
3209 let _ = graph.forward(&x).unwrap();
3210 graph.collect_with(&["out"], Reduce::Min).unwrap();
3211
3212 let collected = graph.collected("out");
3213 assert!((collected[0] - (-2.0)).abs() < 1e-5, "min([-2,0,3]) = -2, got {}", collected[0]);
3214 }
3215
3216 #[test]
3217 fn test_collect_with_norm_reduction() {
3218 let graph = FlowBuilder::from(Identity)
3219 .tag("out")
3220 .build()
3221 .unwrap();
3222
3223 let x = Variable::new(from_f32(&[3.0, 4.0], &[1, 2]), false);
3224 let _ = graph.forward(&x).unwrap();
3225 graph.collect_with(&["out"], Reduce::Norm).unwrap();
3226
3227 let collected = graph.collected("out");
3228 assert!((collected[0] - 5.0).abs() < 1e-4, "norm([3,4]) = 5, got {}", collected[0]);
3230 }
3231
3232 #[test]
3233 fn test_collect_rejects_non_scalar() {
3234 let graph = FlowBuilder::from(Identity)
3236 .tag("out")
3237 .build()
3238 .unwrap();
3239
3240 let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
3241 let _ = graph.forward(&x).unwrap();
3242 assert!(graph.collect(&["out"]).is_err());
3243 }
3244
3245 #[test]
3246 fn test_collect_with_scalar_passthrough() {
3247 let graph = FlowBuilder::from(ScalarSum)
3249 .tag("loss")
3250 .build()
3251 .unwrap();
3252
3253 let x = Variable::new(from_f32(&[3.0, 7.0], &[1, 2]), false);
3254 let _ = graph.forward(&x).unwrap();
3255 graph.collect_with(&["loss"], Reduce::Max).unwrap();
3256
3257 let collected = graph.collected("loss");
3258 assert!((collected[0] - 10.0).abs() < 1e-5);
3260 }
3261
    #[test]
    fn test_collect_with_flush_trend_pipeline() {
        // Collect norms over several forwards, flush per "epoch", then check
        // the trend. The first trend point is 3.0 = (5 + 1) / 2, so flush
        // appears to average collected values — NOTE(review): confirm.
        let graph = FlowBuilder::from(Identity)
            .tag("h")
            .build()
            .unwrap();

        // Epoch 1: norms 5.0 (from [3,4]) and 1.0 (from [1,0]).
        let x1 = Variable::new(from_f32(&[3.0, 4.0], &[1, 2]), false);
        let _ = graph.forward(&x1).unwrap();
        graph.collect_with(&["h"], Reduce::Norm).unwrap();

        let x2 = Variable::new(from_f32(&[1.0, 0.0], &[1, 2]), false);
        let _ = graph.forward(&x2).unwrap();
        graph.collect_with(&["h"], Reduce::Norm).unwrap();

        graph.flush(&["h"]);

        // Epoch 2: a single smaller norm (~0.707 from [0.5,0.5]).
        let x3 = Variable::new(from_f32(&[0.5, 0.5], &[1, 2]), false);
        let _ = graph.forward(&x3).unwrap();
        graph.collect_with(&["h"], Reduce::Norm).unwrap();
        graph.flush(&["h"]);

        let trend = graph.trend("h");
        assert_eq!(trend.len(), 2);
        assert!((trend.values()[0] - 3.0).abs() < 1e-4);
        // Value dropped between epochs; improving(0) treats that as progress
        // here — NOTE(review): confirm improving() direction semantics.
        assert!(trend.improving(0));
    }
3293
3294 #[test]
3297 fn test_map_over_tag() {
3298 let graph = FlowBuilder::from(Identity)
3300 .tag("features")
3301 .through(Doubler) .map(Doubler)
3303 .over("features") .build()
3305 .unwrap();
3306
3307 let x = Variable::new(from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2]), false);
3308 let y = graph.forward(&x).unwrap();
3309 let data = y.data().to_f32_vec().unwrap();
3310 assert_eq!(y.shape(), vec![2, 2]);
3313 assert!((data[0] - 2.0).abs() < 1e-5); assert!((data[1] - 4.0).abs() < 1e-5); assert!((data[2] - 6.0).abs() < 1e-5); assert!((data[3] - 8.0).abs() < 1e-5); }
3318
3319 #[test]
3320 fn test_map_over_unknown_tag_error() {
3321 let result = FlowBuilder::from(Identity)
3322 .map(Doubler)
3323 .over("nonexistent")
3324 .build();
3325 assert!(result.is_err());
3326 }
3327
3328 #[test]
3329 fn test_map_slices() {
3330 let graph = FlowBuilder::from(Identity)
3332 .map(Doubler)
3333 .slices(2)
3334 .build()
3335 .unwrap();
3336
3337 let x = Variable::new(
3338 from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]),
3339 false,
3340 );
3341 let y = graph.forward(&x).unwrap();
3342 let data = y.data().to_f32_vec().unwrap();
3343
3344 assert_eq!(y.shape(), vec![2, 4]);
3346 assert!((data[0] - 2.0).abs() < 1e-5);
3347 assert!((data[7] - 16.0).abs() < 1e-5);
3348 }
3349
3350 #[test]
3351 fn test_map_slices_batched() {
3352 let graph = FlowBuilder::from(Identity)
3354 .map(Doubler)
3355 .batched()
3356 .slices(2)
3357 .build()
3358 .unwrap();
3359
3360 let x = Variable::new(
3361 from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]),
3362 false,
3363 );
3364 let y = graph.forward(&x).unwrap();
3365 let data = y.data().to_f32_vec().unwrap();
3366
3367 assert_eq!(y.shape(), vec![2, 4]);
3368 assert!((data[0] - 2.0).abs() < 1e-5);
3369 assert!((data[7] - 16.0).abs() < 1e-5);
3370 }
3371
3372 #[test]
3373 fn test_map_slices_gradient() {
3374 let graph = FlowBuilder::from(Identity)
3376 .map(Linear::on_device(2, 3, crate::tensor::test_device()).unwrap())
3377 .slices(2)
3378 .build()
3379 .unwrap();
3380
3381 let x = Variable::new(from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]), true);
3382 let y = graph.forward(&x).unwrap();
3383 assert_eq!(y.shape(), vec![2, 6]); let loss = y.sum().unwrap();
3385 loss.backward().unwrap();
3386
3387 assert!(x.grad().is_some());
3388 for p in graph.parameters() {
3389 assert!(p.variable.grad().is_some(), "{} should have gradient", p.name);
3390 }
3391 }
3392
3393 #[test]
3394 fn test_map_slices_not_divisible_error() {
3395 let graph = FlowBuilder::from(Identity)
3396 .map(Doubler)
3397 .slices(3)
3398 .build()
3399 .unwrap();
3400
3401 let x = Variable::new(from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]), false);
3403 assert!(graph.forward(&x).is_err());
3404 }
3405}