1use std::ops::{Add, Mul};
2
3use crate::gradients::{GradientError, Gradients};
4use crate::operation_record::OperationRecord;
5use crate::tape::Tape;
6use num_traits::{One, Zero};
7
8#[derive(Clone, Copy, Debug)]
9pub struct Variable<'a, F> {
25 pub(crate) index: Option<(usize, &'a Tape<F>)>,
26 pub(crate) value: F,
27}
28
/// Function pointer taking an `Lhs` and an `Rhs` (which defaults to `Lhs`)
/// and producing an `Lhs`.
type BinaryFn<Lhs, Rhs = Lhs> = fn(Lhs, Rhs) -> Lhs;
/// Function pointer mapping a value to another value of the same type.
type UnaryFn<Arg> = fn(Arg) -> Arg;
/// Function pointer taking two values of one type and returning a pair of
/// that type.
type BinaryPairFn<Arg> = fn(Arg, Arg) -> (Arg, Arg);
32
33impl<F: Copy> Variable<'_, F> {
34 #[inline]
35 #[must_use]
36 pub const fn value(&self) -> F {
37 self.value
38 }
39 #[inline]
40 #[must_use]
41 pub fn apply_binary_function(&self, rhs: &Self, f: BinaryFn<F>, dfdx: BinaryPairFn<F>) -> Self {
42 #[inline]
43 fn create_index<'a, F>(
44 value: F,
45 rhs: Variable<'a, F>,
46 dfdx: fn(F, F) -> (F, F),
47 idx: [usize; 2],
48 tape: &'a Tape<F>,
49 ) -> usize {
50 let operations = &mut tape.operations.borrow_mut();
51 let count = (*operations).len();
52 let df = dfdx(value, rhs.value);
53 (*operations).push(OperationRecord([(idx[0], df.0), (idx[1], df.1)]));
54 count
55 }
56 let value = f(self.value, rhs.value);
57 match (self.index, rhs.index) {
58 (Some((i, tape)), Some((j, _))) => Variable {
59 index: Some((create_index(self.value, *rhs, dfdx, [i, j], tape), tape)),
60 value,
61 },
62 (None, None) => Variable { index: None, value },
63 (None, Some((j, tape))) => Variable {
64 index: Some((
65 create_index(self.value, *rhs, dfdx, [usize::MAX, j], tape),
66 tape,
67 )),
68 value,
69 },
70 (Some((i, tape)), None) => Variable {
71 index: Some((
72 create_index(self.value, *rhs, dfdx, [i, usize::MAX], tape),
73 tape,
74 )),
75 value,
76 },
77 }
78 }
79}
80
81impl<F: Copy + Zero> Variable<'_, F> {
82 #[inline]
83 #[must_use]
84 pub fn apply_unary_function(&self, f: UnaryFn<F>, df: UnaryFn<F>) -> Self {
85 let value = f(self.value);
86 match self.index {
87 Some((i, tape)) => Variable {
88 index: {
89 let operations = &mut tape.operations.borrow_mut();
90 let count = (*operations).len();
91 (*operations).push(OperationRecord([
92 (i, df(self.value)),
93 (usize::MAX, F::zero()),
94 ]));
95 Some((count, tape))
96 },
97 value,
98 },
99 None => Variable { index: None, value },
100 }
101 }
102
103 #[inline]
104 #[must_use]
105 pub fn apply_scalar_function<T: Copy>(
106 &self,
107 f: BinaryFn<F, T>,
108 df: BinaryFn<F, T>,
109 scalar: T,
110 ) -> Self {
111 let value = f(self.value, scalar);
112 match self.index {
113 Some((i, tape)) => Variable {
114 index: {
115 let operations = &mut tape.operations.borrow_mut();
116 let count = (*operations).len();
117 (*operations).push(OperationRecord([
118 (i, df(self.value, scalar)),
119 (usize::MAX, F::zero()),
120 ]));
121 Some((count, tape))
122 },
123 value,
124 },
125 None => Variable { index: None, value },
126 }
127 }
128}
129
130impl<F: Copy + One + Zero> Variable<'_, F> {
131 #[inline]
132 pub fn compute_gradients(&self) -> Result<Gradients<F>, GradientError> {
146 let (var_index, tape) = self.index.ok_or(GradientError::MissingIndex)?;
147 let operations = &tape.operations.borrow();
148 let mut grads = vec![F::zero(); operations.len()];
149 grads[var_index] = F::one();
150
151 for (i, operation) in (*operations).iter().enumerate().rev() {
152 let grad = grads[i];
153 if grad.is_zero() {
154 continue;
155 }
156 for j in 0..2 {
157 let (idx, val) = operation.0[j];
158 if idx == usize::MAX {
159 continue;
160 }
161 grads[idx] = grads[idx] + val * grad;
162 }
163 }
164
165 Ok(Gradients(grads))
166 }
167}
168
169macro_rules! impl_partial_ord {
170 ($scalar:ty) => {
171 impl<'a> PartialOrd<Variable<'a, $scalar>> for $scalar {
172 #[inline]
173 fn partial_cmp(&self, other: &Variable<'a, $scalar>) -> Option<std::cmp::Ordering> {
174 self.partial_cmp(&other.value)
175 }
176 }
177
178 impl<'a> PartialEq<Variable<'a, $scalar>> for $scalar {
179 #[inline]
180 fn eq(&self, other: &Variable<'a, $scalar>) -> bool {
181 self == &other.value
182 }
183 }
184 };
185}
186
187impl_partial_ord!(f32);
188impl_partial_ord!(f64);
189
190macro_rules! impl_partial_ord_for_variable {
191 ($scalar:ty) => {
192 impl<'a, 'b> PartialOrd<Variable<'a, Variable<'b, $scalar>>> for $scalar {
193 #[inline]
194 fn partial_cmp(
195 &self,
196 other: &Variable<'a, Variable<'b, $scalar>>,
197 ) -> Option<std::cmp::Ordering> {
198 self.partial_cmp(&other.value)
199 }
200 }
201 };
202}
203
204impl_partial_ord_for_variable!(f64);
205
206impl<'a, 'b> PartialEq<Variable<'a, Variable<'b, f64>>> for f64 {
207 #[inline]
208 fn eq(&self, other: &Variable<'a, Variable<'b, f64>>) -> bool {
209 self == &other.value
210 }
211}
212
213impl<F: Zero> Zero for Variable<'_, F>
214where
215 Self: Add<Self, Output = Self>,
216{
217 #[inline]
218 #[must_use]
219 fn zero() -> Self {
220 Self::constant(F::zero())
221 }
222
223 #[inline]
224 fn is_zero(&self) -> bool {
225 self.value.is_zero()
226 }
227
228 #[inline]
229 fn set_zero(&mut self) {
230 *self = Self::zero();
231 }
232}
233
234impl<F: One> One for Variable<'_, F>
235where
236 Self: Mul<Self, Output = Self>,
237{
238 #[inline]
239 #[must_use]
240 fn one() -> Self {
241 Self::constant(F::one())
242 }
243
244 #[inline]
245 fn set_one(&mut self) {
246 *self = Self::one();
247 }
248
249 #[inline]
250 fn is_one(&self) -> bool
251 where
252 Self: PartialEq,
253 {
254 *self == Self::one()
255 }
256}
257
258impl<F> Variable<'_, F> {
259 #[inline]
260 #[must_use]
261 pub fn constant(value: F) -> Self {
262 Self { index: None, value }
263 }
264}
265
266impl<F: From<f64>> From<f64> for Variable<'_, F> {
267 #[inline]
268 fn from(value: f64) -> Self {
269 Self::constant(F::from(value))
270 }
271}
272
273impl<F: From<f32>> From<f32> for Variable<'_, F> {
274 #[inline]
275 fn from(value: f32) -> Self {
276 Self::constant(F::from(value))
277 }
278}
#[cfg(test)]
mod tests {
    use super::*;

    /// Differentiates `z = x * x + y` twice with respect to `x` by nesting
    /// variables from one tape inside variables of a second tape.
    #[test]
    fn test_compute_second_gradients() {
        let inner_tape = Tape::new();
        let outer_tape = Tape::new();

        // The outer variables carry inner-tape variables as their values.
        let [x_inner, y_inner] = inner_tape.create_variables(&[1.0, 2.0]);
        let [x, y] = outer_tape.create_variables(&[x_inner, y_inner]);

        let z = x * x + y;

        // First derivative, taken on the outer tape.
        let first_order = z.compute_gradients().expect("Failed to compute gradients");
        let dz_dx = first_order
            .get_gradient(&x)
            .expect("Failed to get gradient");

        // Second derivative, taken from the first-order gradient.
        let second_order = dz_dx
            .compute_gradients()
            .expect("Failed to compute second gradients");
        let d2z_dx2 = second_order
            .get_gradient(&x.value)
            .expect("Failed to get second gradient");

        // d²/dx² (x² + y) = 2
        assert_eq!(d2z_dx2, 2.0);
    }
}