1use crate::array::Array;
18use anyhow::Result;
19use std::cell::RefCell;
20use std::rc::Rc;
21
// Submodules of the autograd engine.
pub mod backward; // backward-pass machinery (presumably consumes `BackwardFn`) — TODO confirm
pub mod nn; // layers: Linear, Conv1d, BatchNorm1d, ... (re-exported below)
pub mod ops; // tensor operations (presumably the ops enumerated in `OpKind`) — TODO confirm
pub mod optim; // optimizers (SGD, Adam, ...) and LR schedulers
pub mod tensor; // the `Tensor` type
pub mod train; // Trainer, losses, metrics, Dataset
28
// Convenience re-exports: flatten the public API so the common types are
// reachable directly from this module instead of via submodule paths.
pub use nn::{
    BatchNorm1d, Conv1d, Dropout, Flatten, Linear, Module, ReLU, Sequential, Sigmoid, Softmax,
};
// Optimizers and learning-rate schedulers.
pub use optim::{
    AdaBound, AdaDelta, AdaGrad, Adam, AdamW, CosineAnnealingLR, ExponentialLR, LinearWarmup,
    Lookahead, NAdam, Optimizer, RAdam, RMSprop, ReduceLROnPlateau, Rprop, Scheduler, StepLR, LAMB,
    LBFGS, SGD,
};
pub use tensor::Tensor;
// Training-loop utilities.
pub use train::{
    CrossEntropyLoss, Dataset, LossFunction, MSELoss, Metrics, Trainer, TrainerBuilder,
};
41
/// Unique, process-wide identifier for a node in the computation graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(usize);

impl NodeId {
    /// Allocates the next unique id. Thread-safe: backed by a global
    /// atomic counter, so ids are unique across threads.
    pub fn new() -> Self {
        use std::sync::atomic::{AtomicUsize, Ordering};
        static COUNTER: AtomicUsize = AtomicUsize::new(0);
        // SeqCst is stronger than needed for a pure counter (Relaxed would
        // suffice for uniqueness) but keeps the ordering story trivial.
        Self(COUNTER.fetch_add(1, Ordering::SeqCst))
    }
}

impl Default for NodeId {
    /// Equivalent to [`NodeId::new`] (satisfies clippy's
    /// `new_without_default`). NOTE: `default()` therefore yields a *fresh*
    /// id on every call, not a fixed sentinel value.
    fn default() -> Self {
        Self::new()
    }
}
53
/// The operation that produced a tensor in the computation graph.
///
/// `PartialEq` is derived so ops can be compared (e.g. in tests or graph
/// inspection); `Eq` is impossible because `Pow`/`BatchNorm`/`Dropout`
/// carry `f32` payloads.
#[derive(Debug, Clone, PartialEq)]
pub enum OpKind {
    // Elementwise arithmetic and linear algebra.
    Add,
    Mul,
    Sub,
    Div,
    Neg,
    Abs,
    MatMul,
    Transpose,

    // Exponentials and powers.
    Exp,
    Log,
    /// Elementwise power with a constant exponent.
    Pow(f32),
    Sqrt,

    // Activations.
    ReLU,
    Sigmoid,
    Tanh,
    Softmax,

    // Trigonometric functions.
    Sin,
    Cos,
    Tan,

    // Loss functions.
    MSE,
    CrossEntropy,

    // Reductions. `axis: None` presumably reduces over all elements —
    // TODO confirm against the ops module.
    Sum {
        axis: Option<usize>,
    },
    Mean {
        axis: Option<usize>,
    },
    Max {
        axis: Option<usize>,
    },

    // Layer-level operations carrying their configuration.
    Conv1D {
        stride: usize,
        padding: usize,
    },
    Flatten {
        start_dim: usize,
        end_dim: usize,
    },
    Reshape {
        shape: Vec<usize>,
    },
    BatchNorm {
        training: bool,
        momentum: f32,
        eps: f32,
    },
    Dropout {
        p: f32,
        training: bool,
    },

    /// A graph leaf: a tensor created directly, not produced by an op.
    Leaf,
}
126
/// Gradient function stored on a graph node.
/// Signature: `(&Array, &[Tensor], &Tensor) -> Result<Vec<Array>>` —
/// presumably (upstream gradient, node inputs, node output) returning one
/// gradient per input; confirm roles against the `backward` module.
pub type BackwardFn = Box<dyn Fn(&Array, &[Tensor], &Tensor) -> Result<Vec<Array>>>;
137
/// A node in the computation graph: the op that produced a value, the
/// tensors it consumed, and an optional gradient function.
///
/// `Clone` is shallow for the gradient function: `backward_fn` is shared
/// through `Rc`, so clones reference the same closure.
#[derive(Clone)]
pub struct ComputeNode {
    /// Unique id of this node (see [`NodeId`]).
    pub id: NodeId,
    /// The operation that produced this node's value.
    pub op: OpKind,
    /// Input tensors of the operation; empty for leaves (see `leaf()`).
    pub inputs: Vec<Tensor>,
    /// Gradient function; `None` when no backward is registered (e.g. leaves).
    pub backward_fn: Option<Rc<BackwardFn>>,
}
146
147impl ComputeNode {
148 pub fn new(op: OpKind, inputs: Vec<Tensor>, backward_fn: Option<BackwardFn>) -> Self {
149 Self {
150 id: NodeId::new(),
151 op,
152 inputs,
153 backward_fn: backward_fn.map(Rc::new),
154 }
155 }
156
157 pub fn leaf() -> Self {
158 Self {
159 id: NodeId::new(),
160 op: OpKind::Leaf,
161 inputs: vec![],
162 backward_fn: None,
163 }
164 }
165}
166
167impl std::fmt::Debug for ComputeNode {
168 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169 f.debug_struct("ComputeNode")
170 .field("id", &self.id)
171 .field("op", &self.op)
172 .field("inputs", &self.inputs.len())
173 .field("has_backward", &self.backward_fn.is_some())
174 .finish()
175 }
176}
177
/// A gradient-tracking setting carried as an explicit value.
///
/// NOTE(review): nothing else in this file references this type; the
/// thread-local `is_grad_enabled`/`set_grad_enabled` API below looks like
/// the operative mechanism — confirm whether this is still in use.
#[derive(Debug, Clone)]
pub struct AutogradContext {
    /// True when gradients should be recorded.
    pub enabled: bool,
}

impl AutogradContext {
    /// Creates a context with the given gradient-tracking setting.
    pub fn new(enabled: bool) -> Self {
        Self { enabled }
    }

    /// Shorthand for a context with gradient tracking turned off.
    pub fn no_grad() -> Self {
        Self::new(false)
    }
}
193
thread_local! {
    // Per-thread flag: whether ops should record autograd graph nodes.
    // `Cell` replaces `RefCell`: for a `Copy` bool it is the right interior-
    // mutability primitive — cheaper (no borrow flag) and cannot panic.
    static AUTOGRAD_ENABLED: std::cell::Cell<bool> = std::cell::Cell::new(true);
}

/// Returns whether gradient tracking is enabled on the current thread.
/// Defaults to `true` for threads that have not changed it.
pub fn is_grad_enabled() -> bool {
    AUTOGRAD_ENABLED.with(|enabled| enabled.get())
}

/// Enables or disables gradient tracking on the current thread only.
/// Prefer the [`NoGrad`] guard (or the `no_grad!` macro) for scoped use,
/// since they restore the previous state automatically.
pub fn set_grad_enabled(enabled: bool) {
    AUTOGRAD_ENABLED.with(|e| e.set(enabled));
}
207
208pub struct NoGrad {
210 prev_state: bool,
211}
212
213impl NoGrad {
214 pub fn new() -> Self {
215 let prev_state = is_grad_enabled();
216 set_grad_enabled(false);
217 Self { prev_state }
218 }
219}
220
221impl Drop for NoGrad {
222 fn drop(&mut self) {
223 set_grad_enabled(self.prev_state);
224 }
225}
226
/// Runs a block of code with gradient tracking disabled on this thread.
///
/// Expands to a scope holding a [`NoGrad`] guard, so the previous
/// gradient-tracking state is restored when the block finishes; the
/// block's final expression becomes the macro's value.
///
/// NOTE(review): the expansion path assumes this module is mounted as
/// `$crate::autograd` — confirm the module path matches the crate layout.
#[macro_export]
macro_rules! no_grad {
    ($($body:tt)*) => {{
        let _guard = $crate::autograd::NoGrad::new();
        $($body)*
    }};
}