petite_ad/multi/multi_ad.rs
1use super::types::*;
2use crate::error::{AutodiffError, Result};
3
/// Multi-variable automatic differentiation operations.
///
/// Represents operations in a computational graph for functions with multiple inputs.
/// Each operation takes references to previous results via indices: inputs occupy
/// indices `0..inputs.len()`, and each subsequent non-`Inp` operation appends its
/// result to the value list.
///
/// # Examples
///
/// ```
/// use petite_ad::{MultiAD, multi_ops};
///
/// // Build graph: f(x, y) = sin(x) * (x + y)
/// let exprs = multi_ops![
///     (inp, 0),    // x at index 0
///     (inp, 1),    // y at index 1
///     (add, 0, 1), // x + y at index 2
///     (sin, 0),    // sin(x) at index 3
///     (mul, 2, 3), // sin(x) * (x + y) at index 4
/// ];
///
/// let (value, grad_fn) = MultiAD::compute_grad(&exprs, &[0.6, 1.4]).unwrap();
/// let gradients = grad_fn(1.0);
/// println!("f(0.6, 1.4) = {}", value);
/// println!("∇f = {:?}", gradients);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MultiAD {
    /// Input placeholder - references an input variable
    Inp,
    /// Addition: a + b
    Add,
    /// Subtraction: a - b
    Sub,
    /// Multiplication: a * b
    Mul,
    /// Division: a / b
    ///
    /// # Notes
    /// - Delegates to `f64` division, which returns `inf` for division by zero
    /// - Returns `NaN` for `0.0 / 0.0`
    Div,
    /// Power: a^b (a raised to the power of b)
    ///
    /// # Notes
    /// - Delegates to `f64::powf()`
    /// - For `x^n` where n is an integer, consider using repeated multiplication
    Pow,
    /// Sine function: sin(x)
    ///
    /// # Notes
    /// - Delegates to `f64::sin()`, which operates in radians
    /// - Returns values in the range `[-1.0, 1.0]`
    Sin,
    /// Cosine function: cos(x)
    ///
    /// # Notes
    /// - Delegates to `f64::cos()`, which operates in radians
    /// - Returns values in the range `[-1.0, 1.0]`
    Cos,
    /// Tangent function: tan(x)
    ///
    /// # Notes
    /// - Delegates to `f64::tan()`, which operates in radians
    /// - Returns very large values near `π/2 + kπ` (asymptotes)
    Tan,
    /// Exponential function: exp(x)
    ///
    /// # Notes
    /// - Delegates to `f64::exp()`
    /// - Returns `inf` for very large inputs (> ~709 for f64)
    /// - Returns `0.0` for very large negative inputs (< ~-745 for f64)
    Exp,
    /// Natural logarithm: ln(x)
    ///
    /// # Notes
    /// - Delegates to `f64::ln()`
    /// - Returns `NaN` for negative inputs
    /// - Returns `-inf` for `ln(0.0)`
    Ln,
    /// Square root: sqrt(x)
    ///
    /// # Notes
    /// - Delegates to `f64::sqrt()`
    /// - Returns `NaN` for negative inputs
    Sqrt,
    /// Absolute value: abs(x)
    ///
    /// # Notes
    /// - Delegates to `f64::abs()`
    /// - Subgradient at x=0 is 0 (consistent with common practice)
    Abs,
}
95
96impl MultiAD {
97 /// Get the name of this operation (for error messages and arity checking)
98 fn op_name(&self) -> &'static str {
99 match self {
100 MultiAD::Inp => "Inp",
101 MultiAD::Add => "Add",
102 MultiAD::Sub => "Sub",
103 MultiAD::Mul => "Mul",
104 MultiAD::Div => "Div",
105 MultiAD::Pow => "Pow",
106 MultiAD::Sin => "Sin",
107 MultiAD::Cos => "Cos",
108 MultiAD::Tan => "Tan",
109 MultiAD::Exp => "Exp",
110 MultiAD::Ln => "Ln",
111 MultiAD::Sqrt => "Sqrt",
112 MultiAD::Abs => "Abs",
113 }
114 }
115
116 /// Get the expected arity for this operation
117 fn expected_arity(&self) -> usize {
118 match self {
119 MultiAD::Inp
120 | MultiAD::Sin
121 | MultiAD::Cos
122 | MultiAD::Tan
123 | MultiAD::Exp
124 | MultiAD::Ln
125 | MultiAD::Sqrt
126 | MultiAD::Abs => 1,
127 MultiAD::Add | MultiAD::Sub | MultiAD::Mul | MultiAD::Div | MultiAD::Pow => 2,
128 }
129 }
130 /// Forward pass: compute the output of this operation given inputs
131 fn forward(&self, args: &[f64]) -> Result<f64> {
132 Ok(match self {
133 MultiAD::Inp => {
134 AutodiffError::check_arity("Inp", 1, args.len())?;
135 args[0]
136 }
137 MultiAD::Sin => {
138 AutodiffError::check_arity("Sin", 1, args.len())?;
139 args[0].sin()
140 }
141 MultiAD::Cos => {
142 AutodiffError::check_arity("Cos", 1, args.len())?;
143 args[0].cos()
144 }
145 MultiAD::Tan => {
146 AutodiffError::check_arity("Tan", 1, args.len())?;
147 args[0].tan()
148 }
149 MultiAD::Exp => {
150 AutodiffError::check_arity("Exp", 1, args.len())?;
151 args[0].exp()
152 }
153 MultiAD::Ln => {
154 AutodiffError::check_arity("Ln", 1, args.len())?;
155 args[0].ln()
156 }
157 MultiAD::Sqrt => {
158 AutodiffError::check_arity("Sqrt", 1, args.len())?;
159 args[0].sqrt()
160 }
161 MultiAD::Abs => {
162 AutodiffError::check_arity("Abs", 1, args.len())?;
163 args[0].abs()
164 }
165 MultiAD::Add => {
166 AutodiffError::check_arity("Add", 2, args.len())?;
167 args[0] + args[1]
168 }
169 MultiAD::Sub => {
170 AutodiffError::check_arity("Sub", 2, args.len())?;
171 args[0] - args[1]
172 }
173 MultiAD::Mul => {
174 AutodiffError::check_arity("Mul", 2, args.len())?;
175 args[0] * args[1]
176 }
177 MultiAD::Div => {
178 AutodiffError::check_arity("Div", 2, args.len())?;
179 args[0] / args[1]
180 }
181 MultiAD::Pow => {
182 AutodiffError::check_arity("Pow", 2, args.len())?;
183 args[0].powf(args[1])
184 }
185 })
186 }
187
188 /// Backward pass: compute local gradients ∂output/∂inputs
189 /// Returns a boxed closure that computes gradients given a cotangent value
190 fn backward_generic<W>(&self, args: &[f64]) -> Result<W>
191 where
192 W: From<Box<DynGradFn>>,
193 {
194 AutodiffError::check_arity(self.op_name(), self.expected_arity(), args.len())?;
195
196 let backward_fn: Box<dyn Fn(f64) -> Vec<f64>> = match self {
197 MultiAD::Inp => Box::new(|zcotangent: f64| vec![zcotangent]),
198 MultiAD::Sin => {
199 let arg_val = args[0];
200 Box::new(move |z_cotangent: f64| {
201 let x_cotangent = z_cotangent * arg_val.cos();
202 vec![x_cotangent]
203 })
204 }
205 MultiAD::Cos => {
206 let arg_val = args[0];
207 Box::new(move |z_cotangent: f64| {
208 let x_cotangent = z_cotangent * -arg_val.sin();
209 vec![x_cotangent]
210 })
211 }
212 MultiAD::Tan => {
213 let arg_val = args[0];
214 Box::new(move |z_cotangent: f64| {
215 let x_cotangent = z_cotangent * (1.0 / arg_val.cos().powi(2));
216 vec![x_cotangent]
217 })
218 }
219 MultiAD::Exp => {
220 let exp_val = args[0].exp();
221 Box::new(move |z_cotangent: f64| {
222 let x_cotangent = z_cotangent * exp_val;
223 vec![x_cotangent]
224 })
225 }
226 MultiAD::Ln => {
227 let arg_val = args[0];
228 Box::new(move |z_cotangent: f64| {
229 let x_cotangent = z_cotangent * (1.0 / arg_val);
230 vec![x_cotangent]
231 })
232 }
233 MultiAD::Add => Box::new(|z_cotangent: f64| vec![z_cotangent, z_cotangent]),
234 MultiAD::Sub => Box::new(|z_cotangent: f64| vec![z_cotangent, -z_cotangent]),
235 MultiAD::Mul => {
236 let arg0 = args[0];
237 let arg1 = args[1];
238 Box::new(move |z_cotangent: f64| vec![z_cotangent * arg1, z_cotangent * arg0])
239 }
240 MultiAD::Div => {
241 let arg0 = args[0];
242 let arg1 = args[1];
243 Box::new(move |z_cotangent: f64| {
244 vec![z_cotangent / arg1, -z_cotangent * arg0 / arg1.powi(2)]
245 })
246 }
247 MultiAD::Pow => {
248 let base = args[0];
249 let exp = args[1];
250 Box::new(move |z_cotangent: f64| {
251 // d(a^b)/da = b * a^(b-1)
252 let d_base = z_cotangent * exp * base.powf(exp - 1.0);
253 // d(a^b)/db = a^b * ln(a)
254 let d_exp = z_cotangent * base.powf(exp) * base.ln();
255 vec![d_base, d_exp]
256 })
257 }
258 MultiAD::Sqrt => {
259 let arg_val = args[0];
260 Box::new(move |z_cotangent: f64| {
261 // d(sqrt(x))/dx = 1/(2*sqrt(x))
262 let x_cotangent = z_cotangent / (2.0 * arg_val.sqrt());
263 vec![x_cotangent]
264 })
265 }
266 MultiAD::Abs => {
267 let arg_val = args[0];
268 Box::new(move |z_cotangent: f64| {
269 // d(|x|)/dx = sign(x) where sign(0) = 0
270 let sign = if arg_val >= 0.0 { 1.0 } else { -1.0 };
271 vec![z_cotangent * sign]
272 })
273 }
274 };
275 Ok(W::from(backward_fn))
276 }
277
278 /// Compute forward pass only (no gradient computation).
279 ///
280 /// Evaluates the computational graph to produce the final output value.
281 ///
282 /// # Arguments
283 ///
284 /// * `exprs` - Slice of (operation, indices) pairs defining the computation graph
285 /// * `inputs` - Input values for the function
286 ///
287 /// # Errors
288 ///
289 /// Returns `Err(AutodiffError)` if an operation receives incorrect arity.
290 ///
291 /// # Examples
292 ///
293 /// ```
294 /// use petite_ad::{MultiAD, multi_ops};
295 ///
296 /// let exprs = multi_ops![(inp, 0), (inp, 1), (add, 0, 1)];
297 /// let result = MultiAD::compute(&exprs, &[2.0, 3.0]).unwrap();
298 /// assert!((result - 5.0).abs() < 1e-10);
299 /// ```
300 #[must_use = "forward computation is expensive; discarding the result is likely a bug"]
301 pub fn compute(exprs: &[(MultiAD, Vec<usize>)], inputs: &[f64]) -> Result<f64> {
302 let mut values: Vec<f64> = inputs.to_vec();
303
304 for (op, arg_indices) in exprs {
305 if *op == MultiAD::Inp {
306 continue; // Input values are already in the values array
307 }
308
309 // Gather the argument values from the computation graph
310 let arg_values: Vec<f64> = arg_indices.iter().map(|&i| values[i]).collect();
311
312 // Compute this operation
313 let value = op.forward(&arg_values)?;
314 values.push(value);
315 }
316
317 // Return the final computed value
318 Ok(values.last().copied().unwrap_or(0.0))
319 }
320
321 /// Compute forward pass and return gradient function.
322 ///
323 /// Returns a tuple of (value, gradient_function). The gradient function
324 /// takes a cotangent (typically 1.0) and returns a vector of gradients
325 /// with respect to each input.
326 ///
327 /// The result is Box-wrapped by default. If you need Arc for sharing across threads,
328 /// convert using `Arc::from(box_fn)`.
329 ///
330 /// # Arguments
331 ///
332 /// * `exprs` - Computational graph as (operation, indices) pairs
333 /// * `inputs` - Input values to evaluate at
334 ///
335 /// # Returns
336 ///
337 /// Tuple of (output_value, gradient_function)
338 ///
339 /// # Errors
340 ///
341 /// Returns `Err(AutodiffError)` if an operation receives incorrect arity.
342 ///
343 /// # Examples
344 ///
345 /// ```
346 /// use petite_ad::{MultiAD, multi_ops};
347 /// use std::sync::Arc;
348 ///
349 /// let exprs = multi_ops![
350 /// (inp, 0), (inp, 1),
351 /// (add, 0, 1), (sin, 0), (mul, 2, 3)
352 /// ];
353 /// let (value, grad_fn) = MultiAD::compute_grad(&exprs, &[0.6, 1.4]).unwrap();
354 /// let gradients = grad_fn(1.0);
355 ///
356 /// // Convert to Arc if needed for sharing
357 /// let arc_grad_fn: Arc<dyn Fn(f64) -> Vec<f64>> = Arc::from(grad_fn);
358 /// ```
359 #[must_use = "gradient computation is expensive; discarding the result is likely a bug"]
360 pub fn compute_grad_generic<W>(
361 exprs: &[(MultiAD, Vec<usize>)],
362 inputs: &[f64],
363 ) -> Result<(f64, W)>
364 where
365 W: From<Box<DynGradFn>> + std::ops::Deref<Target = DynGradFn> + 'static,
366 {
367 // Pre-allocate with capacity for better performance
368 let estimated_size = inputs.len() + exprs.len();
369 let mut values: Vec<f64> = Vec::with_capacity(estimated_size);
370 values.extend_from_slice(inputs);
371
372 let mut backward_ops: Vec<Box<DynGradFn>> = Vec::with_capacity(exprs.len());
373 let mut arg_indices_list: Vec<Vec<usize>> = Vec::with_capacity(exprs.len());
374
375 // Forward pass: compute all values and track backward operations
376 for (op, args) in exprs {
377 if *op == MultiAD::Inp {
378 continue;
379 }
380 let arg_values: Vec<f64> = args.iter().map(|&i| values[i]).collect();
381 let value = op.forward(&arg_values)?;
382 values.push(value);
383
384 // Store the backward operation (which captures necessary values)
385 backward_ops.push(op.backward_generic(&arg_values)?);
386 arg_indices_list.push(args.clone());
387 }
388
389 let final_value = values.last().copied().unwrap_or(0.0);
390
391 // Clone the data we need for the backward pass
392 let num_inputs = inputs.len();
393 let values_clone = values;
394
395 let backward_fn = Box::new(move |cotangent: f64| -> Vec<f64> {
396 let mut cotangent_values = vec![0.0; values_clone.len()];
397 cotangent_values[values_clone.len() - 1] = cotangent;
398
399 // Backward pass: propagate cotangents from output to inputs
400 for (i, (backward_op, arg_indices)) in backward_ops
401 .iter()
402 .zip(arg_indices_list.iter())
403 .rev() // Process operations in reverse order
404 .enumerate()
405 {
406 let output_idx = values_clone.len() - 1 - i;
407 let current_cotangent_value = cotangent_values[output_idx];
408 let argv_cotangents = backward_op(current_cotangent_value);
409
410 // Accumulate gradients for each input argument
411 for (arg_idx, arg_cotangent) in arg_indices.iter().zip(argv_cotangents) {
412 cotangent_values[*arg_idx] += arg_cotangent;
413 }
414 }
415
416 cotangent_values[..num_inputs].to_vec()
417 });
418
419 Ok((final_value, W::from(backward_fn)))
420 }
421
422 #[must_use = "gradient computation is expensive; discarding the result is likely a bug"]
423 pub fn compute_grad(
424 exprs: &[(MultiAD, Vec<usize>)],
425 inputs: &[f64],
426 ) -> Result<BackwardResultBox> {
427 Self::compute_grad_generic::<Box<DynGradFn>>(exprs, inputs)
428 }
429}