// burn_autodiff/graph/node.rs

1use alloc::{sync::Arc, vec::Vec};
2
3#[cfg(target_has_atomic = "64")]
4use core::sync::atomic::{AtomicU64, Ordering};
5#[cfg(not(target_has_atomic = "64"))]
6use portable_atomic::{AtomicU64, Ordering};
7
8use crate::checkpoint::retro_forward::RetroForward;
9use crate::runtime::AutodiffClientImpl;
10
11use super::Requirement;
12
13#[derive(Debug, Clone)]
14pub enum ComputingProperty {
15    ComputeBound,
16    MemoryBound {
17        retro_forward: Arc<dyn RetroForward>,
18    },
19    Ambiguous, // Maybe autotune someday
20}
21
22/// This is safe only because we only call RetroForward on the autodiff server.
23/// Therefore, the trait will never be used by multiple threads at the same time.
24///
25/// TODO: Find a way to avoid cloning the compute property, which will remove the need to add the
26/// Arc, which will make (dyn RetroForward) safely implement Send.
27unsafe impl Send for ComputingProperty {}
28/// unsafe Sync is required because Send is only implemented for Arc<Sync>, not Arc<Send>.
29unsafe impl Sync for ComputingProperty {}
30
31/// A node contains graph metadata and should be used wrapped in an Arc for cheap cloning.
32#[derive(new, Debug)]
33pub struct Node {
34    pub parents: Vec<Parent>,
35    pub order: usize,
36    pub id: NodeId,
37    pub requirement: Requirement,
38    pub properties: ComputingProperty,
39    pub client: AutodiffClientImpl,
40}
41pub type NodeRef = Arc<Node>;
42
43#[derive(new, Debug, Clone, PartialEq, Eq)]
44pub struct Parent {
45    pub id: NodeId,
46}
47
48impl Node {
49    /// Returns the [node](Node) only if gradients are required.
50    pub fn clone_if_require_grad(self: &Arc<Self>) -> Option<NodeRef> {
51        match self.requirement.is_none() {
52            true => None,
53            false => Some(self.clone()),
54        }
55    }
56}
57
/// Unique identifier generated for each node.
///
/// Fresh ids are obtained with [`NodeId::new`]. The type is `Copy`, so ids can be passed around
/// and used as hash keys freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId {
    /// The raw integer behind this id.
    pub value: u64,
}
64
65impl core::fmt::Display for NodeId {
66    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
67        f.write_fmt(format_args!("NodeId({})", self.value))
68    }
69}
70
71impl NodeId {
72    /// Create a unique [node id](NodeId).
73    pub fn new() -> Self {
74        static COUNTER: AtomicU64 = AtomicU64::new(0);
75        let value = COUNTER.fetch_add(1, Ordering::Relaxed);
76        if value == u64::MAX {
77            panic!("NodeId overflowed");
78        }
79        Self { value }
80    }
81}
82
83impl Default for NodeId {
84    fn default() -> Self {
85        Self::new()
86    }
87}