mathhook_core/calculus/derivatives/partial/
gradient.rs1use crate::calculus::derivatives::Derivative;
4use crate::core::{Expression, Symbol};
5use crate::simplify::Simplify;
6
7pub struct GradientOperations;
9
10impl GradientOperations {
11 pub fn compute(expr: &Expression, variables: Vec<Symbol>) -> Vec<Expression> {
31 let n = variables.len();
32 let mut gradient = Vec::with_capacity(n);
33
34 for var in variables {
35 let partial = expr.derivative(var).simplify();
36 gradient.push(partial);
37 }
38
39 gradient
40 }
41
42 pub fn compute_cached(
62 expr: &Expression,
63 variables: &[Symbol],
64 cache: &mut std::collections::HashMap<Symbol, Expression>,
65 ) -> Vec<Expression> {
66 let mut gradient = Vec::with_capacity(variables.len());
67
68 for var in variables {
69 let partial = cache
70 .entry(var.clone())
71 .or_insert_with(|| expr.derivative(var.clone()).simplify())
72 .clone();
73 gradient.push(partial);
74 }
75
76 gradient
77 }
78}
79
80pub struct DirectionalDerivatives;
82
83impl DirectionalDerivatives {
84 pub fn compute(
104 expr: &Expression,
105 variables: Vec<Symbol>,
106 direction: Vec<Expression>,
107 ) -> Expression {
108 if variables.len() != direction.len() {
109 panic!(
110 "Dimension mismatch: {} variables vs {} direction components",
111 variables.len(),
112 direction.len()
113 );
114 }
115
116 let gradient = GradientOperations::compute(expr, variables);
117 Self::dot_product(gradient, direction)
118 }
119
120 fn dot_product(gradient: Vec<Expression>, direction: Vec<Expression>) -> Expression {
122 let n = gradient.len();
123 let mut dot_terms = Vec::with_capacity(n);
124
125 for (grad_component, dir_component) in gradient.into_iter().zip(direction) {
126 dot_terms.push(Expression::mul(vec![grad_component, dir_component]));
127 }
128
129 Expression::add(dot_terms).simplify()
130 }
131
132 pub fn unit_directional(
151 expr: &Expression,
152 variables: Vec<Symbol>,
153 direction: Vec<Expression>,
154 ) -> Expression {
155 let magnitude_squared: Vec<Expression> = direction
156 .iter()
157 .map(|component| Expression::pow(component.clone(), Expression::integer(2)))
158 .collect();
159
160 let magnitude =
161 Expression::function("sqrt", vec![Expression::add(magnitude_squared).simplify()]);
162
163 let unit_direction: Vec<Expression> = direction
164 .into_iter()
165 .map(|component| {
166 Expression::mul(vec![
167 component,
168 Expression::pow(magnitude.clone(), Expression::integer(-1)),
169 ])
170 })
171 .collect();
172
173 Self::compute(expr, variables, unit_direction)
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;
    use std::collections::HashMap;

    #[test]
    fn test_basic_gradient_computation() {
        let x = symbol!(x);
        let y = symbol!(y);

        // f = x^2 + y^2
        let quadratic = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::pow(Expression::symbol(y.clone()), Expression::integer(2)),
        ]);

        let gradient = GradientOperations::compute(&quadratic, vec![x.clone(), y.clone()]);
        assert_eq!(gradient.len(), 2);
        assert!(!gradient[0].is_zero());
        assert!(!gradient[1].is_zero());
    }

    #[test]
    fn test_linear_function_gradient() {
        let x = symbol!(x);
        let y = symbol!(y);

        // f = 3x + 4y + 5, so ∇f = (3, 4)
        let linear = Expression::add(vec![
            Expression::mul(vec![Expression::integer(3), Expression::symbol(x.clone())]),
            Expression::mul(vec![Expression::integer(4), Expression::symbol(y.clone())]),
            Expression::integer(5),
        ]);

        let gradient = GradientOperations::compute(&linear, vec![x.clone(), y.clone()]);
        assert_eq!(gradient.len(), 2);
        assert_eq!(gradient[0].simplify(), Expression::integer(3));
        assert_eq!(gradient[1].simplify(), Expression::integer(4));
    }

    #[test]
    fn test_multivariate_polynomial_gradient() {
        let x = symbol!(x);
        let y = symbol!(y);

        // f = x^2 * y + x * y^2
        let poly = Expression::add(vec![
            Expression::mul(vec![
                Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
                Expression::symbol(y.clone()),
            ]),
            Expression::mul(vec![
                Expression::symbol(x.clone()),
                Expression::pow(Expression::symbol(y.clone()), Expression::integer(2)),
            ]),
        ]);

        let gradient = GradientOperations::compute(&poly, vec![x.clone(), y.clone()]);
        assert_eq!(gradient.len(), 2);
        assert!(!gradient[0].is_zero());
        assert!(!gradient[1].is_zero());
    }

    #[test]
    fn test_gradient_caching() {
        let x = symbol!(x);
        let y = symbol!(y);

        // f = sin(x + y)
        let expr = Expression::function(
            "sin",
            vec![Expression::add(vec![
                Expression::symbol(x.clone()),
                Expression::symbol(y.clone()),
            ])],
        );

        let mut cache = HashMap::new();
        let gradient1 =
            GradientOperations::compute_cached(&expr, &[x.clone(), y.clone()], &mut cache);
        let gradient2 =
            GradientOperations::compute_cached(&expr, &[x.clone(), y.clone()], &mut cache);

        assert_eq!(gradient1.len(), 2);
        assert_eq!(gradient2.len(), 2);
        assert_eq!(gradient1[0], gradient2[0]);
        assert_eq!(gradient1[1], gradient2[1]);
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_gradient_caching_reuses_cached_entries() {
        let x = symbol!(x);
        let y = symbol!(y);

        // f = 3x + 4y + 5
        let linear = Expression::add(vec![
            Expression::mul(vec![Expression::integer(3), Expression::symbol(x.clone())]),
            Expression::mul(vec![Expression::integer(4), Expression::symbol(y.clone())]),
            Expression::integer(5),
        ]);

        // Seed a deliberately wrong value for ∂f/∂x. If the cache is actually consulted,
        // the seeded value must come back instead of a freshly computed 3.
        let mut cache = HashMap::new();
        cache.insert(x.clone(), Expression::integer(99));

        let gradient =
            GradientOperations::compute_cached(&linear, &[x.clone(), y.clone()], &mut cache);

        assert_eq!(gradient.len(), 2);
        assert_eq!(gradient[0], Expression::integer(99));
        // The uncached variable is still differentiated and stored.
        assert_eq!(gradient[1].simplify(), Expression::integer(4));
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get(&y), Some(&gradient[1]));
    }

    #[test]
    fn test_cached_gradient_matches_uncached() {
        let x = symbol!(x);
        let y = symbol!(y);

        // f = x^2 + y^2
        let quadratic = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::pow(Expression::symbol(y.clone()), Expression::integer(2)),
        ]);

        let mut cache = HashMap::new();
        let cached =
            GradientOperations::compute_cached(&quadratic, &[x.clone(), y.clone()], &mut cache);
        let uncached = GradientOperations::compute(&quadratic, vec![x.clone(), y.clone()]);

        assert_eq!(cached, uncached);
    }

    #[test]
    fn test_empty_variable_list_gives_empty_gradient() {
        let x = symbol!(x);
        let expr = Expression::symbol(x.clone());

        assert!(GradientOperations::compute(&expr, Vec::new()).is_empty());

        let mut cache = HashMap::new();
        assert!(GradientOperations::compute_cached(&expr, &[], &mut cache).is_empty());
        assert!(cache.is_empty());
    }

    #[test]
    fn test_three_variable_gradient() {
        let x = symbol!(x);
        let y = symbol!(y);
        let z = symbol!(z);

        // f = x^3 + y^2 + z
        let expr = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(3)),
            Expression::pow(Expression::symbol(y.clone()), Expression::integer(2)),
            Expression::symbol(z.clone()),
        ]);

        let gradient = GradientOperations::compute(&expr, vec![x.clone(), y.clone(), z.clone()]);
        assert_eq!(gradient.len(), 3);
        assert!(!gradient[0].is_zero());
        assert!(!gradient[1].is_zero());
        assert_eq!(gradient[2].simplify(), Expression::integer(1));
    }

    #[test]
    fn test_directional_derivative_basic() {
        let x = symbol!(x);
        let y = symbol!(y);

        let expr = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::pow(Expression::symbol(y.clone()), Expression::integer(2)),
        ]);

        let direction = vec![Expression::integer(1), Expression::integer(0)];
        let dir_deriv =
            DirectionalDerivatives::compute(&expr, vec![x.clone(), y.clone()], direction);
        assert!(!dir_deriv.is_zero());
    }

    #[test]
    fn test_directional_derivative_diagonal() {
        let x = symbol!(x);
        let y = symbol!(y);

        let expr = Expression::mul(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);

        let direction = vec![Expression::integer(1), Expression::integer(1)];
        let dir_deriv =
            DirectionalDerivatives::compute(&expr, vec![x.clone(), y.clone()], direction);
        assert!(!dir_deriv.is_zero());
    }

    #[test]
    fn test_unit_directional_derivative() {
        let x = symbol!(x);
        let y = symbol!(y);

        let expr = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::pow(Expression::symbol(y.clone()), Expression::integer(2)),
        ]);

        let direction = vec![Expression::integer(3), Expression::integer(4)];
        let unit_dir_deriv =
            DirectionalDerivatives::unit_directional(&expr, vec![x.clone(), y.clone()], direction);
        assert!(!unit_dir_deriv.is_zero());
    }

    #[test]
    fn test_constant_function_gradient() {
        let x = symbol!(x);
        let y = symbol!(y);

        let constant = Expression::integer(42);
        let gradient = GradientOperations::compute(&constant, vec![x.clone(), y.clone()]);

        assert_eq!(gradient.len(), 2);
        assert_eq!(gradient[0].simplify(), Expression::integer(0));
        assert_eq!(gradient[1].simplify(), Expression::integer(0));
    }

    #[test]
    fn test_single_variable_gradient() {
        let x = symbol!(x);

        let expr = Expression::pow(Expression::symbol(x.clone()), Expression::integer(3));
        let gradient = GradientOperations::compute(&expr, vec![x.clone()]);

        assert_eq!(gradient.len(), 1);
        assert!(!gradient[0].is_zero());
    }

    #[test]
    #[should_panic(expected = "Dimension mismatch")]
    fn test_directional_derivative_dimension_mismatch() {
        let x = symbol!(x);
        let y = symbol!(y);

        let expr = Expression::add(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);

        let wrong_direction = vec![Expression::integer(1)];
        DirectionalDerivatives::compute(&expr, vec![x, y], wrong_direction);
    }

    #[test]
    fn test_trigonometric_function_gradient() {
        let x = symbol!(x);
        let y = symbol!(y);

        // f = sin(x) + cos(y)
        let trig_expr = Expression::add(vec![
            Expression::function("sin", vec![Expression::symbol(x.clone())]),
            Expression::function("cos", vec![Expression::symbol(y.clone())]),
        ]);

        let gradient = GradientOperations::compute(&trig_expr, vec![x.clone(), y.clone()]);
        assert_eq!(gradient.len(), 2);
        assert!(!gradient[0].is_zero());
        assert!(!gradient[1].is_zero());
    }

    #[test]
    fn test_zero_direction_vector() {
        let x = symbol!(x);
        let y = symbol!(y);

        let expr = Expression::add(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);

        // The un-normalized directional derivative along the zero vector is 0.
        let zero_direction = vec![Expression::integer(0), Expression::integer(0)];
        let dir_deriv =
            DirectionalDerivatives::compute(&expr, vec![x.clone(), y.clone()], zero_direction);
        assert_eq!(dir_deriv.simplify(), Expression::integer(0));
    }
}