mathhook_core/calculus/derivatives/partial/
gradient.rs

1//! Gradient and directional derivative operations
2
3use crate::calculus::derivatives::Derivative;
4use crate::core::{Expression, Symbol};
5use crate::simplify::Simplify;
6
7/// Gradient vector operations
8pub struct GradientOperations;
9
10impl GradientOperations {
11    /// Compute gradient vector
12    ///
13    /// # Examples
14    ///
15    /// ```rust
16    /// use mathhook_core::simplify::Simplify;
17    /// use mathhook_core::calculus::derivatives::Derivative;
18    /// use mathhook_core::{Expression};
19    /// use mathhook_core::symbol;
20    /// use mathhook_core::calculus::derivatives::GradientOperations;
21    ///
22    /// let x = symbol!(x);
23    /// let y = symbol!(y);
24    /// let expr = Expression::add(vec![
25    ///     Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
26    ///     Expression::pow(Expression::symbol(y.clone()), Expression::integer(2))
27    /// ]);
28    /// let gradient = GradientOperations::compute(&expr, vec![x, y]);
29    /// ```
30    pub fn compute(expr: &Expression, variables: Vec<Symbol>) -> Vec<Expression> {
31        let n = variables.len();
32        let mut gradient = Vec::with_capacity(n);
33
34        for var in variables {
35            let partial = expr.derivative(var).simplify();
36            gradient.push(partial);
37        }
38
39        gradient
40    }
41
42    /// Compute gradient with caching for repeated computations
43    ///
44    /// # Examples
45    ///
46    /// ```rust
47    /// use mathhook_core::{Expression};
48    /// use mathhook_core::symbol;
49    /// use mathhook_core::calculus::derivatives::GradientOperations;
50    /// use std::collections::HashMap;
51    ///
52    /// let x = symbol!(x);
53    /// let y = symbol!(y);
54    /// let expr = Expression::mul(vec![
55    ///     Expression::symbol(x.clone()),
56    ///     Expression::symbol(y.clone())
57    /// ]);
58    /// let mut cache = HashMap::new();
59    /// let gradient = GradientOperations::compute_cached(&expr, &[x, y], &mut cache);
60    /// ```
61    pub fn compute_cached(
62        expr: &Expression,
63        variables: &[Symbol],
64        cache: &mut std::collections::HashMap<Symbol, Expression>,
65    ) -> Vec<Expression> {
66        let mut gradient = Vec::with_capacity(variables.len());
67
68        for var in variables {
69            let partial = cache
70                .entry(var.clone())
71                .or_insert_with(|| expr.derivative(var.clone()).simplify())
72                .clone();
73            gradient.push(partial);
74        }
75
76        gradient
77    }
78}
79
80/// Directional derivative operations
81pub struct DirectionalDerivatives;
82
83impl DirectionalDerivatives {
84    /// Compute directional derivative
85    ///
86    /// # Examples
87    ///
88    /// ```rust
89    /// use mathhook_core::simplify::Simplify;
90    /// use mathhook_core::{Expression};
91    /// use mathhook_core::symbol;
92    /// use mathhook_core::calculus::derivatives::DirectionalDerivatives;
93    ///
94    /// let x = symbol!(x);
95    /// let y = symbol!(y);
96    /// let expr = Expression::add(vec![
97    ///     Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
98    ///     Expression::pow(Expression::symbol(y.clone()), Expression::integer(2))
99    /// ]);
100    /// let direction = vec![Expression::integer(1), Expression::integer(1)];
101    /// let dir_deriv = DirectionalDerivatives::compute(&expr, vec![x, y], direction);
102    /// ```
103    pub fn compute(
104        expr: &Expression,
105        variables: Vec<Symbol>,
106        direction: Vec<Expression>,
107    ) -> Expression {
108        if variables.len() != direction.len() {
109            panic!(
110                "Dimension mismatch: {} variables vs {} direction components",
111                variables.len(),
112                direction.len()
113            );
114        }
115
116        let gradient = GradientOperations::compute(expr, variables);
117        Self::dot_product(gradient, direction)
118    }
119
120    /// Compute dot product of gradient and direction
121    fn dot_product(gradient: Vec<Expression>, direction: Vec<Expression>) -> Expression {
122        let n = gradient.len();
123        let mut dot_terms = Vec::with_capacity(n);
124
125        for (grad_component, dir_component) in gradient.into_iter().zip(direction) {
126            dot_terms.push(Expression::mul(vec![grad_component, dir_component]));
127        }
128
129        Expression::add(dot_terms).simplify()
130    }
131
132    /// Compute unit directional derivative (normalizes direction vector)
133    ///
134    /// # Examples
135    ///
136    /// ```rust
137    /// use mathhook_core::{Expression, Symbol};
138    /// use mathhook_core::symbol;
139    /// use mathhook_core::calculus::derivatives::DirectionalDerivatives;
140    ///
141    /// let x = symbol!(x);
142    /// let y = symbol!(y);
143    /// let expr = Expression::add(vec![
144    ///     Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
145    ///     Expression::pow(Expression::symbol(y.clone()), Expression::integer(2))
146    /// ]);
147    /// let direction = vec![Expression::integer(3), Expression::integer(4)];
148    /// let unit_dir_deriv = DirectionalDerivatives::unit_directional(&expr, vec![x, y], direction);
149    /// ```
150    pub fn unit_directional(
151        expr: &Expression,
152        variables: Vec<Symbol>,
153        direction: Vec<Expression>,
154    ) -> Expression {
155        let magnitude_squared: Vec<Expression> = direction
156            .iter()
157            .map(|component| Expression::pow(component.clone(), Expression::integer(2)))
158            .collect();
159
160        let magnitude =
161            Expression::function("sqrt", vec![Expression::add(magnitude_squared).simplify()]);
162
163        let unit_direction: Vec<Expression> = direction
164            .into_iter()
165            .map(|component| {
166                Expression::mul(vec![
167                    component,
168                    Expression::pow(magnitude.clone(), Expression::integer(-1)),
169                ])
170            })
171            .collect();
172
173        Self::compute(expr, variables, unit_direction)
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;
    use std::collections::HashMap;

    /// Wrap a symbol in an expression without consuming the symbol.
    fn sym(s: &Symbol) -> Expression {
        Expression::symbol(s.clone())
    }

    /// Build `s^2`.
    fn square(s: &Symbol) -> Expression {
        Expression::pow(sym(s), Expression::integer(2))
    }

    /// Build `a^2 + b^2`.
    fn sum_of_squares(a: &Symbol, b: &Symbol) -> Expression {
        Expression::add(vec![square(a), square(b)])
    }

    // Gradient of x^2 + y^2 has two non-zero components.
    #[test]
    fn test_basic_gradient_computation() {
        let (x, y) = (symbol!(x), symbol!(y));
        let quadratic = sum_of_squares(&x, &y);

        let gradient = GradientOperations::compute(&quadratic, vec![x, y]);

        assert_eq!(gradient.len(), 2);
        assert!(!gradient[0].is_zero());
        assert!(!gradient[1].is_zero());
    }

    // Gradient of 3x + 4y + 5 is exactly (3, 4).
    #[test]
    fn test_linear_function_gradient() {
        let (x, y) = (symbol!(x), symbol!(y));
        let three_x = Expression::mul(vec![Expression::integer(3), sym(&x)]);
        let four_y = Expression::mul(vec![Expression::integer(4), sym(&y)]);
        let linear = Expression::add(vec![three_x, four_y, Expression::integer(5)]);

        let gradient = GradientOperations::compute(&linear, vec![x, y]);

        assert_eq!(gradient.len(), 2);
        assert_eq!(gradient[0].simplify(), Expression::integer(3));
        assert_eq!(gradient[1].simplify(), Expression::integer(4));
    }

    // Gradient of x^2*y + x*y^2 has two non-zero components.
    #[test]
    fn test_multivariate_polynomial_gradient() {
        let (x, y) = (symbol!(x), symbol!(y));
        let x_squared_y = Expression::mul(vec![square(&x), sym(&y)]);
        let x_y_squared = Expression::mul(vec![sym(&x), square(&y)]);
        let poly = Expression::add(vec![x_squared_y, x_y_squared]);

        let gradient = GradientOperations::compute(&poly, vec![x, y]);

        assert_eq!(gradient.len(), 2);
        assert!(!gradient[0].is_zero());
        assert!(!gradient[1].is_zero());
    }

    // Two cached runs over sin(x + y) agree and leave one cache entry per variable.
    #[test]
    fn test_gradient_caching() {
        let (x, y) = (symbol!(x), symbol!(y));
        let argument = Expression::add(vec![sym(&x), sym(&y)]);
        let expr = Expression::function("sin", vec![argument]);
        let variables = [x, y];

        let mut cache = HashMap::new();
        let first_run = GradientOperations::compute_cached(&expr, &variables, &mut cache);
        let second_run = GradientOperations::compute_cached(&expr, &variables, &mut cache);

        assert_eq!(first_run.len(), 2);
        assert_eq!(second_run.len(), 2);
        assert_eq!(first_run[0], second_run[0]);
        assert_eq!(first_run[1], second_run[1]);
        assert_eq!(cache.len(), 2);
    }

    // Gradient of x^3 + y^2 + z: first two components non-zero, third is 1.
    #[test]
    fn test_three_variable_gradient() {
        let (x, y, z) = (symbol!(x), symbol!(y), symbol!(z));
        let x_cubed = Expression::pow(sym(&x), Expression::integer(3));
        let expr = Expression::add(vec![x_cubed, square(&y), sym(&z)]);

        let gradient = GradientOperations::compute(&expr, vec![x, y, z]);

        assert_eq!(gradient.len(), 3);
        assert!(!gradient[0].is_zero());
        assert!(!gradient[1].is_zero());
        assert_eq!(gradient[2].simplify(), Expression::integer(1));
    }

    // Directional derivative of x^2 + y^2 along (1, 0) is non-zero.
    #[test]
    fn test_directional_derivative_basic() {
        let (x, y) = (symbol!(x), symbol!(y));
        let expr = sum_of_squares(&x, &y);
        let direction = vec![Expression::integer(1), Expression::integer(0)];

        let dir_deriv = DirectionalDerivatives::compute(&expr, vec![x, y], direction);

        assert!(!dir_deriv.is_zero());
    }

    // Directional derivative of x*y along (1, 1) is non-zero.
    #[test]
    fn test_directional_derivative_diagonal() {
        let (x, y) = (symbol!(x), symbol!(y));
        let expr = Expression::mul(vec![sym(&x), sym(&y)]);
        let direction = vec![Expression::integer(1), Expression::integer(1)];

        let dir_deriv = DirectionalDerivatives::compute(&expr, vec![x, y], direction);

        assert!(!dir_deriv.is_zero());
    }

    // Normalized directional derivative of x^2 + y^2 along (3, 4) is non-zero.
    #[test]
    fn test_unit_directional_derivative() {
        let (x, y) = (symbol!(x), symbol!(y));
        let expr = sum_of_squares(&x, &y);
        let direction = vec![Expression::integer(3), Expression::integer(4)];

        let unit_dir_deriv = DirectionalDerivatives::unit_directional(&expr, vec![x, y], direction);

        assert!(!unit_dir_deriv.is_zero());
    }

    // Gradient of the constant 42 is (0, 0).
    #[test]
    fn test_constant_function_gradient() {
        let (x, y) = (symbol!(x), symbol!(y));
        let constant = Expression::integer(42);

        let gradient = GradientOperations::compute(&constant, vec![x, y]);

        assert_eq!(gradient.len(), 2);
        assert_eq!(gradient[0].simplify(), Expression::integer(0));
        assert_eq!(gradient[1].simplify(), Expression::integer(0));
    }

    // Gradient of x^3 with respect to x alone has one non-zero component.
    #[test]
    fn test_single_variable_gradient() {
        let x = symbol!(x);
        let expr = Expression::pow(sym(&x), Expression::integer(3));

        let gradient = GradientOperations::compute(&expr, vec![x]);

        assert_eq!(gradient.len(), 1);
        assert!(!gradient[0].is_zero());
    }

    // Two variables with a one-component direction must panic.
    #[test]
    #[should_panic(expected = "Dimension mismatch")]
    fn test_directional_derivative_dimension_mismatch() {
        let (x, y) = (symbol!(x), symbol!(y));
        let expr = Expression::add(vec![sym(&x), sym(&y)]);
        let wrong_direction = vec![Expression::integer(1)];

        DirectionalDerivatives::compute(&expr, vec![x, y], wrong_direction);
    }

    // Gradient of sin(x) + cos(y) has two non-zero components.
    #[test]
    fn test_trigonometric_function_gradient() {
        let (x, y) = (symbol!(x), symbol!(y));
        let sin_x = Expression::function("sin", vec![sym(&x)]);
        let cos_y = Expression::function("cos", vec![sym(&y)]);
        let trig_expr = Expression::add(vec![sin_x, cos_y]);

        let gradient = GradientOperations::compute(&trig_expr, vec![x, y]);

        assert_eq!(gradient.len(), 2);
        assert!(!gradient[0].is_zero());
        assert!(!gradient[1].is_zero());
    }

    // Directional derivative of x + y along (0, 0) simplifies to 0.
    #[test]
    fn test_zero_direction_vector() {
        let (x, y) = (symbol!(x), symbol!(y));
        let expr = Expression::add(vec![sym(&x), sym(&y)]);
        let zero_direction = vec![Expression::integer(0), Expression::integer(0)];

        let dir_deriv = DirectionalDerivatives::compute(&expr, vec![x, y], zero_direction);

        assert_eq!(dir_deriv.simplify(), Expression::integer(0));
    }
}