// aad/variable.rs

1use std::ops::{Add, Mul};
2
3use crate::gradients::{GradientError, Gradients};
4use crate::operation_record::OperationRecord;
5use crate::tape::Tape;
6use num_traits::{One, Zero};
7
8#[derive(Clone, Copy, Debug)]
9/// A variable type that tracks operations for automatic differentiation.
10///
11/// This struct represents a variable in the computation graph, storing its value
12/// and maintaining references to the tape that records operations performed on it.
13///
14/// # Type Parameters
15///
16/// * `'a` - The lifetime of the reference to the tape
17/// * `F` - The underlying numeric type (typically `f32` or `f64`)
18///
19/// # Fields
20///
21/// * `index` - The unique index of this variable in the computation tape
22/// * `tape` - Reference to the tape that records operations on this variable
23/// * `value` - The current value of the variable
24pub struct Variable<'a, F> {
25    pub(crate) index: Option<(usize, &'a Tape<F>)>,
26    pub(crate) value: F,
27}
28
/// Signature of a unary primal function, or of its derivative.
type UnaryFn<T> = fn(T) -> T;
/// Signature of a binary primal function combining a value with a second operand of type `S`.
type BinaryFn<T, S = T> = fn(T, S) -> T;
/// Signature of a function returning both partial derivatives of a binary operation.
type BinaryPairFn<T> = fn(T, T) -> (T, T);
32
33impl<F: Copy> Variable<'_, F> {
34    #[inline]
35    #[must_use]
36    pub const fn value(&self) -> F {
37        self.value
38    }
39    #[inline]
40    #[must_use]
41    pub fn apply_binary_function(&self, rhs: &Self, f: BinaryFn<F>, dfdx: BinaryPairFn<F>) -> Self {
42        #[inline]
43        fn create_index<'a, F>(
44            value: F,
45            rhs: Variable<'a, F>,
46            dfdx: fn(F, F) -> (F, F),
47            idx: [usize; 2],
48            tape: &'a Tape<F>,
49        ) -> usize {
50            let operations = &mut tape.operations.borrow_mut();
51            let count = (*operations).len();
52            let df = dfdx(value, rhs.value);
53            (*operations).push(OperationRecord([(idx[0], df.0), (idx[1], df.1)]));
54            count
55        }
56        let value = f(self.value, rhs.value);
57        match (self.index, rhs.index) {
58            (Some((i, tape)), Some((j, _))) => Variable {
59                index: Some((create_index(self.value, *rhs, dfdx, [i, j], tape), tape)),
60                value,
61            },
62            (None, None) => Variable { index: None, value },
63            (None, Some((j, tape))) => Variable {
64                index: Some((
65                    create_index(self.value, *rhs, dfdx, [usize::MAX, j], tape),
66                    tape,
67                )),
68                value,
69            },
70            (Some((i, tape)), None) => Variable {
71                index: Some((
72                    create_index(self.value, *rhs, dfdx, [i, usize::MAX], tape),
73                    tape,
74                )),
75                value,
76            },
77        }
78    }
79}
80
81impl<F: Copy + Zero> Variable<'_, F> {
82    #[inline]
83    #[must_use]
84    pub fn apply_unary_function(&self, f: UnaryFn<F>, df: UnaryFn<F>) -> Self {
85        let value = f(self.value);
86        match self.index {
87            Some((i, tape)) => Variable {
88                index: {
89                    let operations = &mut tape.operations.borrow_mut();
90                    let count = (*operations).len();
91                    (*operations).push(OperationRecord([
92                        (i, df(self.value)),
93                        (usize::MAX, F::zero()),
94                    ]));
95                    Some((count, tape))
96                },
97                value,
98            },
99            None => Variable { index: None, value },
100        }
101    }
102
103    #[inline]
104    #[must_use]
105    pub fn apply_scalar_function<T: Copy>(
106        &self,
107        f: BinaryFn<F, T>,
108        df: BinaryFn<F, T>,
109        scalar: T,
110    ) -> Self {
111        let value = f(self.value, scalar);
112        match self.index {
113            Some((i, tape)) => Variable {
114                index: {
115                    let operations = &mut tape.operations.borrow_mut();
116                    let count = (*operations).len();
117                    (*operations).push(OperationRecord([
118                        (i, df(self.value, scalar)),
119                        (usize::MAX, F::zero()),
120                    ]));
121                    Some((count, tape))
122                },
123                value,
124            },
125            None => Variable { index: None, value },
126        }
127    }
128}
129
130impl<F: Copy + One + Zero> Variable<'_, F> {
131    #[inline]
132    /// Computes gradients for this variable with respect to all variables in the computation graph.
133    ///
134    /// This performs reverse-mode automatic differentiation by traversing the computation graph
135    /// backwards from this variable to compute partial derivatives with respect to all variables.
136    ///
137    /// # Returns
138    ///
139    /// * `Ok(Gradients<F>)` - The computed gradients if successful
140    /// * `Err(GradientError)` - If this variable has no index in the computation graph
141    ///
142    /// # Errors
143    ///
144    /// * Returns `GradientError::MissingIndex` if this variable has no index in the computation graph
145    pub fn compute_gradients(&self) -> Result<Gradients<F>, GradientError> {
146        let (var_index, tape) = self.index.ok_or(GradientError::MissingIndex)?;
147        let operations = &tape.operations.borrow();
148        let mut grads = vec![F::zero(); operations.len()];
149        grads[var_index] = F::one();
150
151        for (i, operation) in (*operations).iter().enumerate().rev() {
152            let grad = grads[i];
153            if grad.is_zero() {
154                continue;
155            }
156            for j in 0..2 {
157                let (idx, val) = operation.0[j];
158                if idx == usize::MAX {
159                    continue;
160                }
161                grads[idx] = grads[idx] + val * grad;
162            }
163        }
164
165        Ok(Gradients(grads))
166    }
167}
168
169macro_rules! impl_partial_ord {
170    ($scalar:ty) => {
171        impl<'a> PartialOrd<Variable<'a, $scalar>> for $scalar {
172            #[inline]
173            fn partial_cmp(&self, other: &Variable<'a, $scalar>) -> Option<std::cmp::Ordering> {
174                self.partial_cmp(&other.value)
175            }
176        }
177
178        impl<'a> PartialEq<Variable<'a, $scalar>> for $scalar {
179            #[inline]
180            fn eq(&self, other: &Variable<'a, $scalar>) -> bool {
181                self == &other.value
182            }
183        }
184    };
185}
186
187impl_partial_ord!(f32);
188impl_partial_ord!(f64);
189
190macro_rules! impl_partial_ord_for_variable {
191    ($scalar:ty) => {
192        impl<'a, 'b> PartialOrd<Variable<'a, Variable<'b, $scalar>>> for $scalar {
193            #[inline]
194            fn partial_cmp(
195                &self,
196                other: &Variable<'a, Variable<'b, $scalar>>,
197            ) -> Option<std::cmp::Ordering> {
198                self.partial_cmp(&other.value)
199            }
200        }
201    };
202}
203
204impl_partial_ord_for_variable!(f64);
205
206impl<'a, 'b> PartialEq<Variable<'a, Variable<'b, f64>>> for f64 {
207    #[inline]
208    fn eq(&self, other: &Variable<'a, Variable<'b, f64>>) -> bool {
209        self == &other.value
210    }
211}
212
213impl<F: Zero> Zero for Variable<'_, F>
214where
215    Self: Add<Self, Output = Self>,
216{
217    #[inline]
218    #[must_use]
219    fn zero() -> Self {
220        Self::constant(F::zero())
221    }
222
223    #[inline]
224    fn is_zero(&self) -> bool {
225        self.value.is_zero()
226    }
227
228    #[inline]
229    fn set_zero(&mut self) {
230        *self = Self::zero();
231    }
232}
233
234impl<F: One> One for Variable<'_, F>
235where
236    Self: Mul<Self, Output = Self>,
237{
238    #[inline]
239    #[must_use]
240    fn one() -> Self {
241        Self::constant(F::one())
242    }
243
244    #[inline]
245    fn set_one(&mut self) {
246        *self = Self::one();
247    }
248
249    #[inline]
250    fn is_one(&self) -> bool
251    where
252        Self: PartialEq,
253    {
254        *self == Self::one()
255    }
256}
257
258impl<F> Variable<'_, F> {
259    #[inline]
260    #[must_use]
261    pub fn constant(value: F) -> Self {
262        Self { index: None, value }
263    }
264}
265
266impl<F: From<f64>> From<f64> for Variable<'_, F> {
267    #[inline]
268    fn from(value: f64) -> Self {
269        Self::constant(F::from(value))
270    }
271}
272
273impl<F: From<f32>> From<f32> for Variable<'_, F> {
274    #[inline]
275    fn from(value: f32) -> Self {
276        Self::constant(F::from(value))
277    }
278}
#[cfg(test)]
mod tests {
    use super::*;

    /// d²(x² + y)/dx² must be 2 when computed through two nested tapes.
    #[test]
    fn test_compute_second_gradients() {
        let inner_tape = Tape::new();
        let outer_tape = Tape::new();
        let [x, y] = inner_tape.create_variables(&[1.0, 2.0]);
        let [x, y] = outer_tape.create_variables(&[x, y]);

        let z = x * x + y;
        let first_grads = z.compute_gradients().expect("Failed to compute gradients");
        let dz_dx = first_grads.get_gradient(&x).expect("Failed to get gradient");

        let second_grads = dz_dx
            .compute_gradients()
            .expect("Failed to compute second gradients");
        let d2z_dx2 = second_grads
            .get_gradient(&x.value)
            .expect("Failed to get second gradient");

        assert_eq!(d2z_dx2, 2.0);
    }
}