// petite_ad/multi/multi_ad.rs

1use super::types::*;
2use crate::error::{AutodiffError, Result};
3
/// Multi-variable automatic differentiation operations.
///
/// Represents operations in a computational graph for functions with multiple inputs.
/// Each operation takes references to previous results via indices.
///
/// # Examples
///
/// ```
/// use petite_ad::{MultiAD, multi_ops};
///
/// // Build graph: f(x, y) = sin(x) * (x + y)
/// let exprs = multi_ops![
///     (inp, 0),    // x at index 0
///     (inp, 1),    // y at index 1
///     (add, 0, 1), // x + y at index 2
///     (sin, 0),    // sin(x) at index 3
///     (mul, 2, 3), // sin(x) * (x + y) at index 4
/// ];
///
/// let (value, grad_fn) = MultiAD::compute_grad(&exprs, &[0.6, 1.4]).unwrap();
/// let gradients = grad_fn(1.0);
/// println!("f(0.6, 1.4) = {}", value);
/// println!("∇f = {:?}", gradients);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MultiAD {
    /// Input placeholder - references an input variable
    Inp,
    /// Addition: a + b
    Add,
    /// Subtraction: a - b
    Sub,
    /// Multiplication: a * b
    Mul,
    /// Division: a / b
    ///
    /// # Notes
    /// - Uses the IEEE 754 `/` operator, which returns `inf` for division by zero
    /// - Returns `NaN` for `0.0 / 0.0`
    Div,
    /// Power: a^b (a raised to the power of b)
    ///
    /// # Notes
    /// - Delegates to `f64::powf()`
    /// - For `x^n` where n is an integer, consider using repeated multiplication
    Pow,
    /// Sine function: sin(x)
    ///
    /// # Notes
    /// - Delegates to `f64::sin()`, which operates in radians
    /// - Returns values in the range `[-1.0, 1.0]`
    Sin,
    /// Cosine function: cos(x)
    ///
    /// # Notes
    /// - Delegates to `f64::cos()`, which operates in radians
    /// - Returns values in the range `[-1.0, 1.0]`
    Cos,
    /// Tangent function: tan(x)
    ///
    /// # Notes
    /// - Delegates to `f64::tan()`, which operates in radians
    /// - Returns very large values near `π/2 + kπ` (asymptotes)
    Tan,
    /// Exponential function: exp(x)
    ///
    /// # Notes
    /// - Delegates to `f64::exp()`
    /// - Returns `inf` for very large inputs (> ~709 for f64)
    /// - Returns `0.0` for very large negative inputs (< ~-745 for f64)
    Exp,
    /// Natural logarithm: ln(x)
    ///
    /// # Notes
    /// - Delegates to `f64::ln()`
    /// - Returns `NaN` for negative inputs
    /// - Returns `-inf` for `ln(0.0)`
    Ln,
    /// Square root: sqrt(x)
    ///
    /// # Notes
    /// - Delegates to `f64::sqrt()`
    /// - Returns `NaN` for negative inputs
    Sqrt,
    /// Absolute value: abs(x)
    ///
    /// # Notes
    /// - Delegates to `f64::abs()`
    /// - Subgradient at x=0 is 0 (consistent with common practice)
    Abs,
}
95
96impl MultiAD {
97    /// Get the name of this operation (for error messages and arity checking)
98    fn op_name(&self) -> &'static str {
99        match self {
100            MultiAD::Inp => "Inp",
101            MultiAD::Add => "Add",
102            MultiAD::Sub => "Sub",
103            MultiAD::Mul => "Mul",
104            MultiAD::Div => "Div",
105            MultiAD::Pow => "Pow",
106            MultiAD::Sin => "Sin",
107            MultiAD::Cos => "Cos",
108            MultiAD::Tan => "Tan",
109            MultiAD::Exp => "Exp",
110            MultiAD::Ln => "Ln",
111            MultiAD::Sqrt => "Sqrt",
112            MultiAD::Abs => "Abs",
113        }
114    }
115
116    /// Get the expected arity for this operation
117    fn expected_arity(&self) -> usize {
118        match self {
119            MultiAD::Inp
120            | MultiAD::Sin
121            | MultiAD::Cos
122            | MultiAD::Tan
123            | MultiAD::Exp
124            | MultiAD::Ln
125            | MultiAD::Sqrt
126            | MultiAD::Abs => 1,
127            MultiAD::Add | MultiAD::Sub | MultiAD::Mul | MultiAD::Div | MultiAD::Pow => 2,
128        }
129    }
130    /// Forward pass: compute the output of this operation given inputs
131    fn forward(&self, args: &[f64]) -> Result<f64> {
132        Ok(match self {
133            MultiAD::Inp => {
134                AutodiffError::check_arity("Inp", 1, args.len())?;
135                args[0]
136            }
137            MultiAD::Sin => {
138                AutodiffError::check_arity("Sin", 1, args.len())?;
139                args[0].sin()
140            }
141            MultiAD::Cos => {
142                AutodiffError::check_arity("Cos", 1, args.len())?;
143                args[0].cos()
144            }
145            MultiAD::Tan => {
146                AutodiffError::check_arity("Tan", 1, args.len())?;
147                args[0].tan()
148            }
149            MultiAD::Exp => {
150                AutodiffError::check_arity("Exp", 1, args.len())?;
151                args[0].exp()
152            }
153            MultiAD::Ln => {
154                AutodiffError::check_arity("Ln", 1, args.len())?;
155                args[0].ln()
156            }
157            MultiAD::Sqrt => {
158                AutodiffError::check_arity("Sqrt", 1, args.len())?;
159                args[0].sqrt()
160            }
161            MultiAD::Abs => {
162                AutodiffError::check_arity("Abs", 1, args.len())?;
163                args[0].abs()
164            }
165            MultiAD::Add => {
166                AutodiffError::check_arity("Add", 2, args.len())?;
167                args[0] + args[1]
168            }
169            MultiAD::Sub => {
170                AutodiffError::check_arity("Sub", 2, args.len())?;
171                args[0] - args[1]
172            }
173            MultiAD::Mul => {
174                AutodiffError::check_arity("Mul", 2, args.len())?;
175                args[0] * args[1]
176            }
177            MultiAD::Div => {
178                AutodiffError::check_arity("Div", 2, args.len())?;
179                args[0] / args[1]
180            }
181            MultiAD::Pow => {
182                AutodiffError::check_arity("Pow", 2, args.len())?;
183                args[0].powf(args[1])
184            }
185        })
186    }
187
188    /// Backward pass: compute local gradients ∂output/∂inputs
189    /// Returns a boxed closure that computes gradients given a cotangent value
190    fn backward_generic<W>(&self, args: &[f64]) -> Result<W>
191    where
192        W: From<Box<DynGradFn>>,
193    {
194        AutodiffError::check_arity(self.op_name(), self.expected_arity(), args.len())?;
195
196        let backward_fn: Box<dyn Fn(f64) -> Vec<f64>> = match self {
197            MultiAD::Inp => Box::new(|zcotangent: f64| vec![zcotangent]),
198            MultiAD::Sin => {
199                let arg_val = args[0];
200                Box::new(move |z_cotangent: f64| {
201                    let x_cotangent = z_cotangent * arg_val.cos();
202                    vec![x_cotangent]
203                })
204            }
205            MultiAD::Cos => {
206                let arg_val = args[0];
207                Box::new(move |z_cotangent: f64| {
208                    let x_cotangent = z_cotangent * -arg_val.sin();
209                    vec![x_cotangent]
210                })
211            }
212            MultiAD::Tan => {
213                let arg_val = args[0];
214                Box::new(move |z_cotangent: f64| {
215                    let x_cotangent = z_cotangent * (1.0 / arg_val.cos().powi(2));
216                    vec![x_cotangent]
217                })
218            }
219            MultiAD::Exp => {
220                let exp_val = args[0].exp();
221                Box::new(move |z_cotangent: f64| {
222                    let x_cotangent = z_cotangent * exp_val;
223                    vec![x_cotangent]
224                })
225            }
226            MultiAD::Ln => {
227                let arg_val = args[0];
228                Box::new(move |z_cotangent: f64| {
229                    let x_cotangent = z_cotangent * (1.0 / arg_val);
230                    vec![x_cotangent]
231                })
232            }
233            MultiAD::Add => Box::new(|z_cotangent: f64| vec![z_cotangent, z_cotangent]),
234            MultiAD::Sub => Box::new(|z_cotangent: f64| vec![z_cotangent, -z_cotangent]),
235            MultiAD::Mul => {
236                let arg0 = args[0];
237                let arg1 = args[1];
238                Box::new(move |z_cotangent: f64| vec![z_cotangent * arg1, z_cotangent * arg0])
239            }
240            MultiAD::Div => {
241                let arg0 = args[0];
242                let arg1 = args[1];
243                Box::new(move |z_cotangent: f64| {
244                    vec![z_cotangent / arg1, -z_cotangent * arg0 / arg1.powi(2)]
245                })
246            }
247            MultiAD::Pow => {
248                let base = args[0];
249                let exp = args[1];
250                Box::new(move |z_cotangent: f64| {
251                    // d(a^b)/da = b * a^(b-1)
252                    let d_base = z_cotangent * exp * base.powf(exp - 1.0);
253                    // d(a^b)/db = a^b * ln(a)
254                    let d_exp = z_cotangent * base.powf(exp) * base.ln();
255                    vec![d_base, d_exp]
256                })
257            }
258            MultiAD::Sqrt => {
259                let arg_val = args[0];
260                Box::new(move |z_cotangent: f64| {
261                    // d(sqrt(x))/dx = 1/(2*sqrt(x))
262                    let x_cotangent = z_cotangent / (2.0 * arg_val.sqrt());
263                    vec![x_cotangent]
264                })
265            }
266            MultiAD::Abs => {
267                let arg_val = args[0];
268                Box::new(move |z_cotangent: f64| {
269                    // d(|x|)/dx = sign(x) where sign(0) = 0
270                    let sign = if arg_val >= 0.0 { 1.0 } else { -1.0 };
271                    vec![z_cotangent * sign]
272                })
273            }
274        };
275        Ok(W::from(backward_fn))
276    }
277
278    /// Compute forward pass only (no gradient computation).
279    ///
280    /// Evaluates the computational graph to produce the final output value.
281    ///
282    /// # Arguments
283    ///
284    /// * `exprs` - Slice of (operation, indices) pairs defining the computation graph
285    /// * `inputs` - Input values for the function
286    ///
287    /// # Errors
288    ///
289    /// Returns `Err(AutodiffError)` if an operation receives incorrect arity.
290    ///
291    /// # Examples
292    ///
293    /// ```
294    /// use petite_ad::{MultiAD, multi_ops};
295    ///
296    /// let exprs = multi_ops![(inp, 0), (inp, 1), (add, 0, 1)];
297    /// let result = MultiAD::compute(&exprs, &[2.0, 3.0]).unwrap();
298    /// assert!((result - 5.0).abs() < 1e-10);
299    /// ```
300    #[must_use = "forward computation is expensive; discarding the result is likely a bug"]
301    pub fn compute(exprs: &[(MultiAD, Vec<usize>)], inputs: &[f64]) -> Result<f64> {
302        let mut values: Vec<f64> = inputs.to_vec();
303
304        for (op, arg_indices) in exprs {
305            if *op == MultiAD::Inp {
306                continue; // Input values are already in the values array
307            }
308
309            // Gather the argument values from the computation graph
310            let arg_values: Vec<f64> = arg_indices.iter().map(|&i| values[i]).collect();
311
312            // Compute this operation
313            let value = op.forward(&arg_values)?;
314            values.push(value);
315        }
316
317        // Return the final computed value
318        Ok(values.last().copied().unwrap_or(0.0))
319    }
320
321    /// Compute forward pass and return gradient function.
322    ///
323    /// Returns a tuple of (value, gradient_function). The gradient function
324    /// takes a cotangent (typically 1.0) and returns a vector of gradients
325    /// with respect to each input.
326    ///
327    /// The result is Box-wrapped by default. If you need Arc for sharing across threads,
328    /// convert using `Arc::from(box_fn)`.
329    ///
330    /// # Arguments
331    ///
332    /// * `exprs` - Computational graph as (operation, indices) pairs
333    /// * `inputs` - Input values to evaluate at
334    ///
335    /// # Returns
336    ///
337    /// Tuple of (output_value, gradient_function)
338    ///
339    /// # Errors
340    ///
341    /// Returns `Err(AutodiffError)` if an operation receives incorrect arity.
342    ///
343    /// # Examples
344    ///
345    /// ```
346    /// use petite_ad::{MultiAD, multi_ops};
347    /// use std::sync::Arc;
348    ///
349    /// let exprs = multi_ops![
350    ///     (inp, 0), (inp, 1),
351    ///     (add, 0, 1), (sin, 0), (mul, 2, 3)
352    /// ];
353    /// let (value, grad_fn) = MultiAD::compute_grad(&exprs, &[0.6, 1.4]).unwrap();
354    /// let gradients = grad_fn(1.0);
355    ///
356    /// // Convert to Arc if needed for sharing
357    /// let arc_grad_fn: Arc<dyn Fn(f64) -> Vec<f64>> = Arc::from(grad_fn);
358    /// ```
359    #[must_use = "gradient computation is expensive; discarding the result is likely a bug"]
360    pub fn compute_grad_generic<W>(
361        exprs: &[(MultiAD, Vec<usize>)],
362        inputs: &[f64],
363    ) -> Result<(f64, W)>
364    where
365        W: From<Box<DynGradFn>> + std::ops::Deref<Target = DynGradFn> + 'static,
366    {
367        // Pre-allocate with capacity for better performance
368        let estimated_size = inputs.len() + exprs.len();
369        let mut values: Vec<f64> = Vec::with_capacity(estimated_size);
370        values.extend_from_slice(inputs);
371
372        let mut backward_ops: Vec<Box<DynGradFn>> = Vec::with_capacity(exprs.len());
373        let mut arg_indices_list: Vec<Vec<usize>> = Vec::with_capacity(exprs.len());
374
375        // Forward pass: compute all values and track backward operations
376        for (op, args) in exprs {
377            if *op == MultiAD::Inp {
378                continue;
379            }
380            let arg_values: Vec<f64> = args.iter().map(|&i| values[i]).collect();
381            let value = op.forward(&arg_values)?;
382            values.push(value);
383
384            // Store the backward operation (which captures necessary values)
385            backward_ops.push(op.backward_generic(&arg_values)?);
386            arg_indices_list.push(args.clone());
387        }
388
389        let final_value = values.last().copied().unwrap_or(0.0);
390
391        // Clone the data we need for the backward pass
392        let num_inputs = inputs.len();
393        let values_clone = values;
394
395        let backward_fn = Box::new(move |cotangent: f64| -> Vec<f64> {
396            let mut cotangent_values = vec![0.0; values_clone.len()];
397            cotangent_values[values_clone.len() - 1] = cotangent;
398
399            // Backward pass: propagate cotangents from output to inputs
400            for (i, (backward_op, arg_indices)) in backward_ops
401                .iter()
402                .zip(arg_indices_list.iter())
403                .rev() // Process operations in reverse order
404                .enumerate()
405            {
406                let output_idx = values_clone.len() - 1 - i;
407                let current_cotangent_value = cotangent_values[output_idx];
408                let argv_cotangents = backward_op(current_cotangent_value);
409
410                // Accumulate gradients for each input argument
411                for (arg_idx, arg_cotangent) in arg_indices.iter().zip(argv_cotangents) {
412                    cotangent_values[*arg_idx] += arg_cotangent;
413                }
414            }
415
416            cotangent_values[..num_inputs].to_vec()
417        });
418
419        Ok((final_value, W::from(backward_fn)))
420    }
421
    /// Compute forward pass and gradient function with the default
    /// `Box<DynGradFn>` wrapper.
    ///
    /// Convenience wrapper around [`MultiAD::compute_grad_generic`]; see
    /// that method for argument semantics and examples.
    ///
    /// # Errors
    ///
    /// Returns `Err(AutodiffError)` if an operation receives incorrect arity.
    #[must_use = "gradient computation is expensive; discarding the result is likely a bug"]
    pub fn compute_grad(
        exprs: &[(MultiAD, Vec<usize>)],
        inputs: &[f64],
    ) -> Result<BackwardResultBox> {
        Self::compute_grad_generic::<Box<DynGradFn>>(exprs, inputs)
    }
429}