1use crate::deriv::log::{DerivationLog, DerivedExpr, RewriteStep};
12use crate::diff::diff_impl::DiffError;
13use crate::kernel::{ExprData, ExprId, ExprPool};
14use crate::simplify::engine::simplify;
15
16#[deprecated(
24 since = "2.0.0",
25 note = "use DiffError::ForwardUnknownFunction / ForwardNonIntegerExponent instead"
26)]
27pub type ForwardDiffError = DiffError;
28
29#[derive(Clone, Debug)]
39pub struct DualValue {
40 pub value: ExprId,
41 pub tangent: ExprId,
42}
43
44impl DualValue {
45 fn new(value: ExprId, tangent: ExprId) -> Self {
46 DualValue { value, tangent }
47 }
48
49 fn constant(value: ExprId, pool: &ExprPool) -> Self {
50 let zero = pool.integer(0_i32);
51 DualValue::new(value, zero)
52 }
53
54 fn seed(value: ExprId, pool: &ExprPool) -> Self {
55 let one = pool.integer(1_i32);
56 DualValue::new(value, one)
57 }
58
59 fn add(self, rhs: Self, pool: &ExprPool) -> Self {
60 let value = pool.add(vec![self.value, rhs.value]);
61 let tangent = pool.add(vec![self.tangent, rhs.tangent]);
62 DualValue::new(value, tangent)
63 }
64
65 fn mul(self, rhs: Self, pool: &ExprPool) -> Self {
66 let value = pool.mul(vec![self.value, rhs.value]);
68 let term1 = pool.mul(vec![self.value, rhs.tangent]);
69 let term2 = pool.mul(vec![rhs.value, self.tangent]);
70 let tangent = pool.add(vec![term1, term2]);
71 DualValue::new(value, tangent)
72 }
73
74 #[allow(dead_code)]
75 fn neg(self, pool: &ExprPool) -> Self {
76 let neg_one = pool.integer(-1_i32);
77 let value = pool.mul(vec![neg_one, self.value]);
78 let tangent = pool.mul(vec![neg_one, self.tangent]);
79 DualValue::new(value, tangent)
80 }
81
82 #[allow(dead_code)]
83 fn sub(self, rhs: Self, pool: &ExprPool) -> Self {
84 self.add(rhs.neg(pool), pool)
85 }
86
87 #[allow(dead_code)]
89 fn div(self, rhs: Self, pool: &ExprPool) -> Self {
90 let value = pool.mul(vec![self.value, pool.pow(rhs.value, pool.integer(-1_i32))]);
91 let bda = pool.mul(vec![rhs.value, self.tangent]);
92 let adb = pool.mul(vec![self.value, rhs.tangent]);
93 let neg_one = pool.integer(-1_i32);
94 let numerator = pool.add(vec![bda, pool.mul(vec![neg_one, adb])]);
95 let b_sq = pool.pow(rhs.value, pool.integer(2_i32));
96 let tangent = pool.mul(vec![numerator, pool.pow(b_sq, pool.integer(-1_i32))]);
97 DualValue::new(value, tangent)
98 }
99
100 fn pow_int(self, n: rug::Integer, pool: &ExprPool) -> Self {
102 if n == 0 {
103 let one = pool.integer(1_i32);
104 return DualValue::new(one, pool.integer(0_i32));
105 }
106 if n == 1 {
107 return self;
108 }
109 let n_id = pool.integer(n.clone());
110 let n_minus_1 = pool.integer(n - 1);
111 let value = pool.pow(self.value, n_id);
112 let base_pow = pool.pow(self.value, n_minus_1);
113 let tangent = pool.mul(vec![n_id, base_pow, self.tangent]);
114 DualValue::new(value, tangent)
115 }
116
117 fn sin(self, pool: &ExprPool) -> Self {
118 let value = pool.func("sin", vec![self.value]);
120 let cos_f = pool.func("cos", vec![self.value]);
121 let tangent = pool.mul(vec![cos_f, self.tangent]);
122 DualValue::new(value, tangent)
123 }
124
125 fn cos(self, pool: &ExprPool) -> Self {
126 let value = pool.func("cos", vec![self.value]);
128 let sin_f = pool.func("sin", vec![self.value]);
129 let neg_one = pool.integer(-1_i32);
130 let tangent = pool.mul(vec![neg_one, sin_f, self.tangent]);
131 DualValue::new(value, tangent)
132 }
133
134 fn exp(self, pool: &ExprPool) -> Self {
135 let value = pool.func("exp", vec![self.value]);
137 let tangent = pool.mul(vec![value, self.tangent]);
138 DualValue::new(value, tangent)
139 }
140
141 fn log(self, pool: &ExprPool) -> Self {
142 let value = pool.func("log", vec![self.value]);
144 let f_inv = pool.pow(self.value, pool.integer(-1_i32));
145 let tangent = pool.mul(vec![self.tangent, f_inv]);
146 DualValue::new(value, tangent)
147 }
148
149 fn sqrt(self, pool: &ExprPool) -> Self {
150 let value = pool.func("sqrt", vec![self.value]);
152 let two_sqrt = pool.mul(vec![pool.integer(2_i32), value]);
153 let tangent = pool.mul(vec![self.tangent, pool.pow(two_sqrt, pool.integer(-1_i32))]);
154 DualValue::new(value, tangent)
155 }
156}
157
158fn eval_dual(expr: ExprId, var: ExprId, pool: &ExprPool) -> Result<DualValue, DiffError> {
163 enum Node {
164 IsVar,
165 IsConst,
166 Add(Vec<ExprId>),
167 Mul(Vec<ExprId>),
168 Pow { base: ExprId, exp: ExprId },
169 Func { name: String, arg: ExprId },
170 }
171
172 let node = pool.with(expr, |data| match data {
173 ExprData::Symbol { .. } if expr == var => Node::IsVar,
174 ExprData::Symbol { .. }
175 | ExprData::Integer(_)
176 | ExprData::Rational(_)
177 | ExprData::Float(_) => Node::IsConst,
178 ExprData::Add(args) => Node::Add(args.clone()),
179 ExprData::Mul(args) => Node::Mul(args.clone()),
180 ExprData::Pow { base, exp } => Node::Pow {
181 base: *base,
182 exp: *exp,
183 },
184 ExprData::Func { name, args } if args.len() == 1 => Node::Func {
185 name: name.clone(),
186 arg: args[0],
187 },
188 ExprData::Func { name, .. } => Node::Func {
189 name: name.clone(),
190 arg: expr,
191 },
192 ExprData::Piecewise { .. } | ExprData::Predicate { .. } => Node::IsConst,
195 ExprData::Forall { .. } | ExprData::Exists { .. } => Node::IsConst,
196 ExprData::BigO(_) => Node::IsConst,
197 });
198
199 match node {
200 Node::IsVar => Ok(DualValue::seed(expr, pool)),
201 Node::IsConst => Ok(DualValue::constant(expr, pool)),
202 Node::Add(args) => {
203 let mut acc = DualValue::constant(pool.integer(0_i32), pool);
204 for a in args {
205 acc = acc.add(eval_dual(a, var, pool)?, pool);
206 }
207 Ok(acc)
208 }
209 Node::Mul(args) => {
210 let mut acc = DualValue::constant(pool.integer(1_i32), pool);
211 for a in args {
212 acc = acc.mul(eval_dual(a, var, pool)?, pool);
213 }
214 Ok(acc)
215 }
216 Node::Pow { base, exp } => {
217 let n = pool
218 .with(exp, |data| match data {
219 ExprData::Integer(n) => Some(n.0.clone()),
220 _ => None,
221 })
222 .ok_or(DiffError::ForwardNonIntegerExponent)?;
223 let b = eval_dual(base, var, pool)?;
224 Ok(b.pow_int(n, pool))
225 }
226 Node::Func { name, arg } => {
227 if arg == expr {
229 return Err(DiffError::ForwardUnknownFunction(name));
230 }
231 let inner = eval_dual(arg, var, pool)?;
232 match name.as_str() {
233 "sin" => Ok(inner.sin(pool)),
234 "cos" => Ok(inner.cos(pool)),
235 "exp" => Ok(inner.exp(pool)),
236 "log" => Ok(inner.log(pool)),
237 "sqrt" => Ok(inner.sqrt(pool)),
238 other => Err(DiffError::ForwardUnknownFunction(other.to_string())),
239 }
240 }
241 }
242}
243
244pub fn diff_forward(
260 expr: ExprId,
261 var: ExprId,
262 pool: &ExprPool,
263) -> Result<DerivedExpr<ExprId>, DiffError> {
264 let dual = eval_dual(expr, var, pool)?;
265 let tangent_raw = dual.tangent;
266
267 let simplified = simplify(tangent_raw, pool);
269
270 let mut log = DerivationLog::new();
272 log.push(RewriteStep::simple("diff_forward", expr, simplified.value));
273 let full_log = log.merge(simplified.log);
274 Ok(DerivedExpr::with_log(simplified.value, full_log))
275}
276
277#[cfg(test)]
282mod tests {
283 use super::*;
284 use crate::diff::diff as sym_diff;
285 use crate::kernel::{Domain, ExprPool};
286 use crate::poly::UniPoly;
287
288 fn p() -> ExprPool {
289 ExprPool::new()
290 }
291
292 #[test]
293 fn forward_diff_constant() {
294 let pool = p();
295 let x = pool.symbol("x", Domain::Real);
296 let r = diff_forward(pool.integer(5_i32), x, &pool).unwrap();
297 assert_eq!(r.value, pool.integer(0_i32));
298 }
299
300 #[test]
301 fn forward_diff_identity() {
302 let pool = p();
303 let x = pool.symbol("x", Domain::Real);
304 let r = diff_forward(x, x, &pool).unwrap();
305 assert_eq!(r.value, pool.integer(1_i32));
306 }
307
308 #[test]
309 fn forward_diff_other_var() {
310 let pool = p();
311 let x = pool.symbol("x", Domain::Real);
312 let y = pool.symbol("y", Domain::Real);
313 let r = diff_forward(y, x, &pool).unwrap();
314 assert_eq!(r.value, pool.integer(0_i32));
315 }
316
317 #[test]
318 fn forward_diff_linear() {
319 let pool = p();
321 let x = pool.symbol("x", Domain::Real);
322 let expr = pool.mul(vec![pool.integer(3_i32), x]);
323 let r = diff_forward(expr, x, &pool).unwrap();
324 assert_eq!(r.value, pool.integer(3_i32));
325 }
326
327 #[test]
328 fn forward_diff_quadratic_agrees_with_symbolic() {
329 let pool = p();
331 let x = pool.symbol("x", Domain::Real);
332 let expr = pool.pow(x, pool.integer(2_i32));
333 let fwd = diff_forward(expr, x, &pool).unwrap();
334 let sym = sym_diff(expr, x, &pool).unwrap();
335 let fwd_poly = UniPoly::from_symbolic(fwd.value, x, &pool).unwrap();
337 let sym_poly = UniPoly::from_symbolic(sym.value, x, &pool).unwrap();
338 assert_eq!(fwd_poly.coefficients_i64(), sym_poly.coefficients_i64());
339 }
340
341 #[test]
342 fn forward_diff_cubic_agrees_with_symbolic() {
343 let pool = p();
344 let x = pool.symbol("x", Domain::Real);
345 let expr = pool.pow(x, pool.integer(3_i32));
346 let fwd = diff_forward(expr, x, &pool).unwrap().value;
347 let sym = sym_diff(expr, x, &pool).unwrap().value;
348 let fwd_poly = UniPoly::from_symbolic(fwd, x, &pool).unwrap();
349 let sym_poly = UniPoly::from_symbolic(sym, x, &pool).unwrap();
350 assert_eq!(fwd_poly.coefficients_i64(), sym_poly.coefficients_i64());
351 }
352
353 #[test]
354 fn forward_diff_sin() {
355 let pool = p();
356 let x = pool.symbol("x", Domain::Real);
357 let r = diff_forward(pool.func("sin", vec![x]), x, &pool).unwrap();
358 assert_eq!(r.value, pool.func("cos", vec![x]));
359 }
360
361 #[test]
362 fn forward_diff_exp() {
363 let pool = p();
364 let x = pool.symbol("x", Domain::Real);
365 let exp_x = pool.func("exp", vec![x]);
366 let r = diff_forward(exp_x, x, &pool).unwrap();
367 assert_eq!(r.value, exp_x);
368 }
369
370 #[test]
371 fn forward_diff_log() {
372 let pool = p();
374 let x = pool.symbol("x", Domain::Real);
375 let r = diff_forward(pool.func("log", vec![x]), x, &pool).unwrap();
376 assert_eq!(r.value, pool.pow(x, pool.integer(-1_i32)));
377 }
378
379 #[test]
380 fn forward_diff_step_logged() {
381 let pool = p();
382 let x = pool.symbol("x", Domain::Real);
383 let r = diff_forward(x, x, &pool).unwrap();
384 assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_forward"));
385 }
386}