quantrs2_symengine_pure/diff/
mod.rs

1//! Symbolic differentiation module.
2//!
3//! This module implements automatic symbolic differentiation using
4//! the standard rules of calculus.
5
6use egg::{Id, Language, RecExpr};
7
8use crate::error::{SymEngineError, SymEngineResult};
9use crate::expr::{ExprLang, Expression};
10
11/// Compute the derivative of an expression with respect to a variable.
12///
13/// This implements symbolic differentiation using the chain rule,
14/// product rule, quotient rule, and standard derivative formulas.
15pub fn differentiate(expr: &Expression, var: &Expression) -> Expression {
16    let var_name = match var.as_symbol() {
17        Some(name) => name.to_string(),
18        None => {
19            // If var is not a symbol, return zero (constant derivative)
20            return Expression::zero();
21        }
22    };
23
24    differentiate_rec(
25        expr.as_rec_expr(),
26        expr.as_rec_expr().as_ref().len() - 1,
27        &var_name,
28    )
29}
30
31/// Recursive differentiation helper
32fn differentiate_rec(expr: &RecExpr<ExprLang>, idx: usize, var: &str) -> Expression {
33    let node = &expr[Id::from(idx)];
34
35    match node {
36        // d/dx(c) = 0 for constants, d/dx(x) = 1, d/dx(y) = 0
37        ExprLang::Num(s) => {
38            let name = s.as_str();
39            // Check if it's a number (constant) or a variable
40            if name.parse::<f64>().is_ok() {
41                // It's a numeric constant
42                Expression::zero()
43            } else if name == var {
44                // It's the variable we're differentiating with respect to
45                Expression::one()
46            } else {
47                // It's a different variable (treat as constant)
48                Expression::zero()
49            }
50        }
51
52        // d/dx(a + b) = da/dx + db/dx
53        ExprLang::Add([a, b]) => {
54            let da = differentiate_rec(expr, usize::from(*a), var);
55            let db = differentiate_rec(expr, usize::from(*b), var);
56            da + db
57        }
58
59        // d/dx(a * b) = a * db/dx + da/dx * b (product rule)
60        ExprLang::Mul([a, b]) => {
61            let a_expr = extract_subexpr(expr, usize::from(*a));
62            let b_expr = extract_subexpr(expr, usize::from(*b));
63            let da = differentiate_rec(expr, usize::from(*a), var);
64            let db = differentiate_rec(expr, usize::from(*b), var);
65            a_expr * db + da * b_expr
66        }
67
68        // d/dx(a / b) = (da/dx * b - a * db/dx) / b^2 (quotient rule)
69        ExprLang::Div([a, b]) => {
70            let a_expr = extract_subexpr(expr, usize::from(*a));
71            let b_expr = extract_subexpr(expr, usize::from(*b));
72            let da = differentiate_rec(expr, usize::from(*a), var);
73            let db = differentiate_rec(expr, usize::from(*b), var);
74            (da * b_expr.clone() - a_expr * db) / (b_expr.clone() * b_expr)
75        }
76
77        // d/dx(a^b) - power rule with chain rule
78        ExprLang::Pow([a, b]) => {
79            let a_expr = extract_subexpr(expr, usize::from(*a));
80            let b_expr = extract_subexpr(expr, usize::from(*b));
81            let da = differentiate_rec(expr, usize::from(*a), var);
82
83            // Simple case: constant exponent
84            if let Some(n) = b_expr.to_f64() {
85                // d/dx(a^n) = n * a^(n-1) * da/dx
86                Expression::float_unchecked(n)
87                    * a_expr.pow(&Expression::float_unchecked(n - 1.0))
88                    * da
89            } else {
90                // General power rule: a^b * (b' * ln(a) + b * a'/a)
91                let db = differentiate_rec(expr, usize::from(*b), var);
92                let ln_a = crate::ops::trig::log(&a_expr);
93                let term1 = db * ln_a;
94                let term2 = b_expr.clone() * da / a_expr.clone();
95                a_expr.pow(&b_expr) * (term1 + term2)
96            }
97        }
98
99        // d/dx(-a) = -da/dx
100        ExprLang::Neg([a]) => {
101            let da = differentiate_rec(expr, usize::from(*a), var);
102            da.neg()
103        }
104
105        // d/dx(1/a) = -da/dx / a^2
106        ExprLang::Inv([a]) => {
107            let a_expr = extract_subexpr(expr, usize::from(*a));
108            let da = differentiate_rec(expr, usize::from(*a), var);
109            da.neg() / (a_expr.clone() * a_expr)
110        }
111
112        // Trigonometric derivatives
113        ExprLang::Sin([a]) => {
114            let a_expr = extract_subexpr(expr, usize::from(*a));
115            let da = differentiate_rec(expr, usize::from(*a), var);
116            crate::ops::trig::cos(&a_expr) * da
117        }
118
119        ExprLang::Cos([a]) => {
120            let a_expr = extract_subexpr(expr, usize::from(*a));
121            let da = differentiate_rec(expr, usize::from(*a), var);
122            crate::ops::trig::sin(&a_expr).neg() * da
123        }
124
125        ExprLang::Tan([a]) => {
126            let a_expr = extract_subexpr(expr, usize::from(*a));
127            let da = differentiate_rec(expr, usize::from(*a), var);
128            let sec_sq =
129                Expression::one() + crate::ops::trig::tan(&a_expr).pow(&Expression::int(2));
130            sec_sq * da
131        }
132
133        // d/dx(exp(a)) = exp(a) * da/dx
134        ExprLang::Exp([a]) => {
135            let a_expr = extract_subexpr(expr, usize::from(*a));
136            let da = differentiate_rec(expr, usize::from(*a), var);
137            crate::ops::trig::exp(&a_expr) * da
138        }
139
140        // d/dx(log(a)) = da/dx / a
141        ExprLang::Log([a]) => {
142            let a_expr = extract_subexpr(expr, usize::from(*a));
143            let da = differentiate_rec(expr, usize::from(*a), var);
144            da / a_expr
145        }
146
147        // d/dx(sqrt(a)) = da/dx / (2 * sqrt(a))
148        ExprLang::Sqrt([a]) => {
149            let a_expr = extract_subexpr(expr, usize::from(*a));
150            let da = differentiate_rec(expr, usize::from(*a), var);
151            da / (Expression::int(2) * crate::ops::trig::sqrt(&a_expr))
152        }
153
154        // d/dx(|a|) = a/|a| * da/dx (for a != 0)
155        ExprLang::Abs([a]) => {
156            let a_expr = extract_subexpr(expr, usize::from(*a));
157            let da = differentiate_rec(expr, usize::from(*a), var);
158            a_expr.clone() / crate::ops::trig::abs(&a_expr) * da
159        }
160
161        // Inverse trig
162        ExprLang::Asin([a]) => {
163            let a_expr = extract_subexpr(expr, usize::from(*a));
164            let da = differentiate_rec(expr, usize::from(*a), var);
165            da / crate::ops::trig::sqrt(&(Expression::one() - a_expr.pow(&Expression::int(2))))
166        }
167
168        ExprLang::Acos([a]) => {
169            let a_expr = extract_subexpr(expr, usize::from(*a));
170            let da = differentiate_rec(expr, usize::from(*a), var);
171            da.neg()
172                / crate::ops::trig::sqrt(&(Expression::one() - a_expr.pow(&Expression::int(2))))
173        }
174
175        ExprLang::Atan([a]) => {
176            let a_expr = extract_subexpr(expr, usize::from(*a));
177            let da = differentiate_rec(expr, usize::from(*a), var);
178            da / (Expression::one() + a_expr.pow(&Expression::int(2)))
179        }
180
181        // Hyperbolic functions
182        ExprLang::Sinh([a]) => {
183            let a_expr = extract_subexpr(expr, usize::from(*a));
184            let da = differentiate_rec(expr, usize::from(*a), var);
185            crate::ops::trig::cosh(&a_expr) * da
186        }
187
188        ExprLang::Cosh([a]) => {
189            let a_expr = extract_subexpr(expr, usize::from(*a));
190            let da = differentiate_rec(expr, usize::from(*a), var);
191            crate::ops::trig::sinh(&a_expr) * da
192        }
193
194        ExprLang::Tanh([a]) => {
195            let a_expr = extract_subexpr(expr, usize::from(*a));
196            let da = differentiate_rec(expr, usize::from(*a), var);
197            let sech_sq =
198                Expression::one() - crate::ops::trig::tanh(&a_expr).pow(&Expression::int(2));
199            sech_sq * da
200        }
201
202        // Complex operations
203        ExprLang::Re([a]) => {
204            let da = differentiate_rec(expr, usize::from(*a), var);
205            crate::ops::complex::re(&da)
206        }
207
208        ExprLang::Im([a]) => {
209            let da = differentiate_rec(expr, usize::from(*a), var);
210            crate::ops::complex::im(&da)
211        }
212
213        ExprLang::Conj([a]) => {
214            let da = differentiate_rec(expr, usize::from(*a), var);
215            da.conjugate()
216        }
217
218        // For unhandled cases, return zero
219        _ => Expression::zero(),
220    }
221}
222
223/// Extract a subexpression from a RecExpr
224fn extract_subexpr(expr: &RecExpr<ExprLang>, idx: usize) -> Expression {
225    let mut new_expr = RecExpr::default();
226    extract_subexpr_rec(
227        expr,
228        idx,
229        &mut new_expr,
230        &mut std::collections::HashMap::new(),
231    );
232    Expression::from_rec_expr(new_expr)
233}
234
235fn extract_subexpr_rec(
236    expr: &RecExpr<ExprLang>,
237    idx: usize,
238    new_expr: &mut RecExpr<ExprLang>,
239    id_map: &mut std::collections::HashMap<usize, Id>,
240) -> Id {
241    if let Some(&new_id) = id_map.get(&idx) {
242        return new_id;
243    }
244
245    let node = &expr[Id::from(idx)];
246    let new_node = node.clone().map_children(|child_id| {
247        extract_subexpr_rec(expr, usize::from(child_id), new_expr, id_map)
248    });
249    let new_id = new_expr.add(new_node);
250    id_map.insert(idx, new_id);
251    new_id
252}
253
254#[cfg(test)]
255mod tests {
256    use super::*;
257
258    #[test]
259    fn test_diff_constant() {
260        let c = Expression::int(5);
261        let x = Expression::symbol("x");
262        let dc = differentiate(&c, &x);
263        assert!(dc.is_zero());
264    }
265
266    #[test]
267    fn test_diff_variable() {
268        let x = Expression::symbol("x");
269        let dx = differentiate(&x, &x);
270        assert!(dx.is_one());
271    }
272
273    #[test]
274    fn test_diff_other_variable() {
275        let x = Expression::symbol("x");
276        let y = Expression::symbol("y");
277        let dy = differentiate(&y, &x);
278        assert!(dy.is_zero());
279    }
280
281    #[test]
282    fn test_diff_sum() {
283        let x = Expression::symbol("x");
284        let c = Expression::int(5);
285        let expr = x.clone() + c; // x + 5
286        let dx = differentiate(&expr, &x);
287        // d/dx(x + 5) = 1 + 0 (unsimplified)
288        // After simplification it would be 1
289        assert!(!dx.to_string().is_empty());
290    }
291
292    #[test]
293    fn test_diff_product() {
294        let x = Expression::symbol("x");
295        let expr = x.clone() * x.clone(); // x^2 as x*x
296        let dx = differentiate(&expr, &x);
297        // d/dx(x*x) = x*1 + 1*x = 2x
298        // The result won't be simplified, but the structure should be correct
299        assert!(!dx.is_zero());
300    }
301}