burn_autodiff/graph/
node.rs1use alloc::{sync::Arc, vec::Vec};
2
3#[cfg(target_has_atomic = "64")]
4use core::sync::atomic::{AtomicU64, Ordering};
5#[cfg(not(target_has_atomic = "64"))]
6use portable_atomic::{AtomicU64, Ordering};
7
8use crate::checkpoint::retro_forward::RetroForward;
9use crate::runtime::AutodiffClientImpl;
10
11use super::Requirement;
12
13#[derive(Debug, Clone)]
14pub enum ComputingProperty {
15 ComputeBound,
16 MemoryBound {
17 retro_forward: Arc<dyn RetroForward>,
18 },
19 Ambiguous, }
21
22unsafe impl Send for ComputingProperty {}
28unsafe impl Sync for ComputingProperty {}
30
31#[derive(new, Debug)]
33pub struct Node {
34 pub parents: Vec<Parent>,
35 pub order: usize,
36 pub id: NodeId,
37 pub requirement: Requirement,
38 pub properties: ComputingProperty,
39 pub client: AutodiffClientImpl,
40}
41pub type NodeRef = Arc<Node>;
42
43#[derive(new, Debug, Clone, PartialEq, Eq)]
44pub struct Parent {
45 pub id: NodeId,
46}
47
48impl Node {
49 pub fn clone_if_require_grad(self: &Arc<Self>) -> Option<NodeRef> {
51 match self.requirement.is_none() {
52 true => None,
53 false => Some(self.clone()),
54 }
55 }
56}
57
/// Identifier of a node in the autodiff graph.
///
/// It is `Copy`, hashable and comparable for equality, so it can be passed by
/// value and used directly as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId {
    /// Raw integer value of the id.
    pub value: u64,
}
64
65impl core::fmt::Display for NodeId {
66 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
67 f.write_fmt(format_args!("NodeId({})", self.value))
68 }
69}
70
71impl NodeId {
72 pub fn new() -> Self {
74 static COUNTER: AtomicU64 = AtomicU64::new(0);
75 let value = COUNTER.fetch_add(1, Ordering::Relaxed);
76 if value == u64::MAX {
77 panic!("NodeId overflowed");
78 }
79 Self { value }
80 }
81}
82
83impl Default for NodeId {
84 fn default() -> Self {
85 Self::new()
86 }
87}