quantrs2_symengine_pure/diff/
mod.rs1use egg::{Id, Language, RecExpr};
7
8use crate::error::{SymEngineError, SymEngineResult};
9use crate::expr::{ExprLang, Expression};
10
11pub fn differentiate(expr: &Expression, var: &Expression) -> Expression {
16 let var_name = match var.as_symbol() {
17 Some(name) => name.to_string(),
18 None => {
19 return Expression::zero();
21 }
22 };
23
24 differentiate_rec(
25 expr.as_rec_expr(),
26 expr.as_rec_expr().as_ref().len() - 1,
27 &var_name,
28 )
29}
30
31fn differentiate_rec(expr: &RecExpr<ExprLang>, idx: usize, var: &str) -> Expression {
33 let node = &expr[Id::from(idx)];
34
35 match node {
36 ExprLang::Num(s) => {
38 let name = s.as_str();
39 if name.parse::<f64>().is_ok() {
41 Expression::zero()
43 } else if name == var {
44 Expression::one()
46 } else {
47 Expression::zero()
49 }
50 }
51
52 ExprLang::Add([a, b]) => {
54 let da = differentiate_rec(expr, usize::from(*a), var);
55 let db = differentiate_rec(expr, usize::from(*b), var);
56 da + db
57 }
58
59 ExprLang::Mul([a, b]) => {
61 let a_expr = extract_subexpr(expr, usize::from(*a));
62 let b_expr = extract_subexpr(expr, usize::from(*b));
63 let da = differentiate_rec(expr, usize::from(*a), var);
64 let db = differentiate_rec(expr, usize::from(*b), var);
65 a_expr * db + da * b_expr
66 }
67
68 ExprLang::Div([a, b]) => {
70 let a_expr = extract_subexpr(expr, usize::from(*a));
71 let b_expr = extract_subexpr(expr, usize::from(*b));
72 let da = differentiate_rec(expr, usize::from(*a), var);
73 let db = differentiate_rec(expr, usize::from(*b), var);
74 (da * b_expr.clone() - a_expr * db) / (b_expr.clone() * b_expr)
75 }
76
77 ExprLang::Pow([a, b]) => {
79 let a_expr = extract_subexpr(expr, usize::from(*a));
80 let b_expr = extract_subexpr(expr, usize::from(*b));
81 let da = differentiate_rec(expr, usize::from(*a), var);
82
83 if let Some(n) = b_expr.to_f64() {
85 Expression::float_unchecked(n)
87 * a_expr.pow(&Expression::float_unchecked(n - 1.0))
88 * da
89 } else {
90 let db = differentiate_rec(expr, usize::from(*b), var);
92 let ln_a = crate::ops::trig::log(&a_expr);
93 let term1 = db * ln_a;
94 let term2 = b_expr.clone() * da / a_expr.clone();
95 a_expr.pow(&b_expr) * (term1 + term2)
96 }
97 }
98
99 ExprLang::Neg([a]) => {
101 let da = differentiate_rec(expr, usize::from(*a), var);
102 da.neg()
103 }
104
105 ExprLang::Inv([a]) => {
107 let a_expr = extract_subexpr(expr, usize::from(*a));
108 let da = differentiate_rec(expr, usize::from(*a), var);
109 da.neg() / (a_expr.clone() * a_expr)
110 }
111
112 ExprLang::Sin([a]) => {
114 let a_expr = extract_subexpr(expr, usize::from(*a));
115 let da = differentiate_rec(expr, usize::from(*a), var);
116 crate::ops::trig::cos(&a_expr) * da
117 }
118
119 ExprLang::Cos([a]) => {
120 let a_expr = extract_subexpr(expr, usize::from(*a));
121 let da = differentiate_rec(expr, usize::from(*a), var);
122 crate::ops::trig::sin(&a_expr).neg() * da
123 }
124
125 ExprLang::Tan([a]) => {
126 let a_expr = extract_subexpr(expr, usize::from(*a));
127 let da = differentiate_rec(expr, usize::from(*a), var);
128 let sec_sq =
129 Expression::one() + crate::ops::trig::tan(&a_expr).pow(&Expression::int(2));
130 sec_sq * da
131 }
132
133 ExprLang::Exp([a]) => {
135 let a_expr = extract_subexpr(expr, usize::from(*a));
136 let da = differentiate_rec(expr, usize::from(*a), var);
137 crate::ops::trig::exp(&a_expr) * da
138 }
139
140 ExprLang::Log([a]) => {
142 let a_expr = extract_subexpr(expr, usize::from(*a));
143 let da = differentiate_rec(expr, usize::from(*a), var);
144 da / a_expr
145 }
146
147 ExprLang::Sqrt([a]) => {
149 let a_expr = extract_subexpr(expr, usize::from(*a));
150 let da = differentiate_rec(expr, usize::from(*a), var);
151 da / (Expression::int(2) * crate::ops::trig::sqrt(&a_expr))
152 }
153
154 ExprLang::Abs([a]) => {
156 let a_expr = extract_subexpr(expr, usize::from(*a));
157 let da = differentiate_rec(expr, usize::from(*a), var);
158 a_expr.clone() / crate::ops::trig::abs(&a_expr) * da
159 }
160
161 ExprLang::Asin([a]) => {
163 let a_expr = extract_subexpr(expr, usize::from(*a));
164 let da = differentiate_rec(expr, usize::from(*a), var);
165 da / crate::ops::trig::sqrt(&(Expression::one() - a_expr.pow(&Expression::int(2))))
166 }
167
168 ExprLang::Acos([a]) => {
169 let a_expr = extract_subexpr(expr, usize::from(*a));
170 let da = differentiate_rec(expr, usize::from(*a), var);
171 da.neg()
172 / crate::ops::trig::sqrt(&(Expression::one() - a_expr.pow(&Expression::int(2))))
173 }
174
175 ExprLang::Atan([a]) => {
176 let a_expr = extract_subexpr(expr, usize::from(*a));
177 let da = differentiate_rec(expr, usize::from(*a), var);
178 da / (Expression::one() + a_expr.pow(&Expression::int(2)))
179 }
180
181 ExprLang::Sinh([a]) => {
183 let a_expr = extract_subexpr(expr, usize::from(*a));
184 let da = differentiate_rec(expr, usize::from(*a), var);
185 crate::ops::trig::cosh(&a_expr) * da
186 }
187
188 ExprLang::Cosh([a]) => {
189 let a_expr = extract_subexpr(expr, usize::from(*a));
190 let da = differentiate_rec(expr, usize::from(*a), var);
191 crate::ops::trig::sinh(&a_expr) * da
192 }
193
194 ExprLang::Tanh([a]) => {
195 let a_expr = extract_subexpr(expr, usize::from(*a));
196 let da = differentiate_rec(expr, usize::from(*a), var);
197 let sech_sq =
198 Expression::one() - crate::ops::trig::tanh(&a_expr).pow(&Expression::int(2));
199 sech_sq * da
200 }
201
202 ExprLang::Re([a]) => {
204 let da = differentiate_rec(expr, usize::from(*a), var);
205 crate::ops::complex::re(&da)
206 }
207
208 ExprLang::Im([a]) => {
209 let da = differentiate_rec(expr, usize::from(*a), var);
210 crate::ops::complex::im(&da)
211 }
212
213 ExprLang::Conj([a]) => {
214 let da = differentiate_rec(expr, usize::from(*a), var);
215 da.conjugate()
216 }
217
218 _ => Expression::zero(),
220 }
221}
222
223fn extract_subexpr(expr: &RecExpr<ExprLang>, idx: usize) -> Expression {
225 let mut new_expr = RecExpr::default();
226 extract_subexpr_rec(
227 expr,
228 idx,
229 &mut new_expr,
230 &mut std::collections::HashMap::new(),
231 );
232 Expression::from_rec_expr(new_expr)
233}
234
235fn extract_subexpr_rec(
236 expr: &RecExpr<ExprLang>,
237 idx: usize,
238 new_expr: &mut RecExpr<ExprLang>,
239 id_map: &mut std::collections::HashMap<usize, Id>,
240) -> Id {
241 if let Some(&new_id) = id_map.get(&idx) {
242 return new_id;
243 }
244
245 let node = &expr[Id::from(idx)];
246 let new_node = node.clone().map_children(|child_id| {
247 extract_subexpr_rec(expr, usize::from(child_id), new_expr, id_map)
248 });
249 let new_id = new_expr.add(new_node);
250 id_map.insert(idx, new_id);
251 new_id
252}
253
254#[cfg(test)]
255mod tests {
256 use super::*;
257
258 #[test]
259 fn test_diff_constant() {
260 let c = Expression::int(5);
261 let x = Expression::symbol("x");
262 let dc = differentiate(&c, &x);
263 assert!(dc.is_zero());
264 }
265
266 #[test]
267 fn test_diff_variable() {
268 let x = Expression::symbol("x");
269 let dx = differentiate(&x, &x);
270 assert!(dx.is_one());
271 }
272
273 #[test]
274 fn test_diff_other_variable() {
275 let x = Expression::symbol("x");
276 let y = Expression::symbol("y");
277 let dy = differentiate(&y, &x);
278 assert!(dy.is_zero());
279 }
280
281 #[test]
282 fn test_diff_sum() {
283 let x = Expression::symbol("x");
284 let c = Expression::int(5);
285 let expr = x.clone() + c; let dx = differentiate(&expr, &x);
287 assert!(!dx.to_string().is_empty());
290 }
291
292 #[test]
293 fn test_diff_product() {
294 let x = Expression::symbol("x");
295 let expr = x.clone() * x.clone(); let dx = differentiate(&expr, &x);
297 assert!(!dx.is_zero());
300 }
301}