1use std::{
2 fmt::{Debug, Display},
3 iter,
4 str::FromStr,
5};
6
7use smallvec::SmallVec;
8
9use crate::{
10 data_type::DataType,
11 definitions::N_BINOPS_OF_DEEPEX_ON_STACK,
12 exerr,
13 expression::{
14 deep::{prioritized_indices, DeepEx, DeepNode},
15 flat::ExprIdxVec,
16 },
17 DiffDataType, ExError, ExResult, Express, MakeOperators, MatchLiteral,
18};
19
20pub fn check_partial_index(var_idx: usize, n_vars: usize, unparsed: &str) -> ExResult<()> {
21 if var_idx >= n_vars {
22 Err(exerr!(
23 "index {} is invalid since we have only {} vars in {}",
24 var_idx,
25 n_vars,
26 unparsed
27 ))
28 } else {
29 Ok(())
30 }
31}
32
33pub trait Differentiate<'a, T>
36where
37 T: DiffDataType,
38 <T as FromStr>::Err: Debug,
39 Self: Sized + Express<'a, T> + Display + Debug,
40{
41 fn partial(self, var_idx: usize) -> ExResult<Self> {
75 self.partial_nth(var_idx, 1)
76 }
77
78 fn partial_relaxed(self, var_idx: usize, missing_op_mode: MissingOpMode) -> ExResult<Self> {
81 self.partial_nth_relaxed(var_idx, 1, missing_op_mode)
82 }
83
84 fn partial_nth(self, var_idx: usize, n: usize) -> ExResult<Self> {
115 self.partial_iter(iter::repeat_n(var_idx, n))
116 }
117
118 fn partial_nth_relaxed(
121 self,
122 var_idx: usize,
123 n: usize,
124 missing_op_mode: MissingOpMode,
125 ) -> ExResult<Self> {
126 self.partial_iter_relaxed(iter::repeat_n(var_idx, n), missing_op_mode)
127 }
128
129 fn partial_iter<I>(self, var_idxs: I) -> ExResult<Self>
161 where
162 I: Iterator<Item = usize> + Clone,
163 {
164 self.partial_iter_relaxed(var_idxs, MissingOpMode::Error)
165 }
166
167 fn partial_iter_relaxed<I>(self, var_idxs: I, missing_op_mode: MissingOpMode) -> ExResult<Self>
170 where
171 I: Iterator<Item = usize> + Clone,
172 {
173 let mut deepex = self.to_deepex()?;
174
175 let unparsed = deepex.unparse();
176 for var_idx in var_idxs.clone() {
177 check_partial_index(var_idx, deepex.var_names().len(), unparsed)?;
178 }
179 for var_idx in var_idxs {
180 deepex = partial_deepex(var_idx, deepex, missing_op_mode)?;
181 }
182 deepex.compile();
183 Self::from_deepex(deepex)
184 }
185}
186#[derive(Clone, Debug)]
187struct ValueDerivative<'a, T, OF, LM>
188where
189 T: DataType,
190 OF: MakeOperators<T>,
191 LM: MatchLiteral,
192 <T as FromStr>::Err: Debug,
193{
194 val: DeepEx<'a, T, OF, LM>,
195 der: DeepEx<'a, T, OF, LM>,
196}
197
/// Derivative rule of a binary operator: maps the value/derivative pairs of the
/// left and right operand to the value/derivative pair of the combined expression.
type BinOpPartial<'a, T, OF, LM> = fn(
    ValueDerivative<'a, T, OF, LM>,
    ValueDerivative<'a, T, OF, LM>,
) -> ExResult<ValueDerivative<'a, T, OF, LM>>;

/// Outer derivative rule of a unary operator: maps an expression to the factor that
/// the chain rule contributes for that operator.
///
/// NOTE(review): the expression is presumably passed with the operator still applied
/// as its latest unary op (the rules call `without_latest_unary()` to reach the
/// argument) — confirm.
type UnaryOpOuter<'a, T, OF, LM> = fn(DeepEx<'a, T, OF, LM>) -> ExResult<DeepEx<'a, T, OF, LM>>;
204
205#[derive(Debug)]
206pub struct PartialDerivative<'a, T: DataType, OF, LM>
207where
208 OF: MakeOperators<T>,
209 LM: MatchLiteral,
210 <T as FromStr>::Err: Debug,
211{
212 repr: &'a str,
213 bin_op: Option<BinOpPartial<'a, T, OF, LM>>,
214 unary_outer_op: Option<UnaryOpOuter<'a, T, OF, LM>>,
215}
216
217fn make_op_missing_err(repr: &str) -> ExError {
218 exerr!("operator {} needed for outer partial derivative", repr)
219}
220
221fn partial_derivative_outer<'a, T: DiffDataType, OF, LM>(
222 deepex: DeepEx<'a, T, OF, LM>,
223 partial_derivative_ops: &[PartialDerivative<'a, T, OF, LM>],
224) -> ExResult<DeepEx<'a, T, OF, LM>>
225where
226 OF: MakeOperators<T>,
227 LM: MatchLiteral,
228 <T as FromStr>::Err: Debug,
229{
230 let mut factorexes = deepex
231 .unary_op()
232 .reprs
233 .iter()
234 .enumerate()
235 .map(|(idx, repr)| {
236 let op = partial_derivative_ops
237 .iter()
238 .find(|pdo| pdo.repr == *repr)
239 .ok_or_else(|| make_op_missing_err(repr))?;
240 let unary_deri_op = op.unary_outer_op.ok_or_else(|| make_op_missing_err(repr))?;
241 let mut new_deepex = deepex.clone();
242 for _ in 0..idx {
243 new_deepex = new_deepex.without_latest_unary();
244 }
245 unary_deri_op(new_deepex)
246 });
247 factorexes.try_fold(DeepEx::one(), |dp1, dp2| -> ExResult<DeepEx<T, OF, LM>> {
248 dp2.and_then(|dp2| dp2 * dp1)
249 })
250}
251
/// What to do when a binary operator of an expression has no derivative rule.
///
/// Unary operators without a rule always lead to an error.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MissingOpMode {
    /// Apply the operator separately to the operands' values and to their
    /// derivatives.
    PerOperand,
    /// Use the operator applied to the operands' values as the derivative as well;
    /// the operands' derivatives are ignored.
    None,
    /// Fail with an error.
    Error,
}
262
/// Computes the "inner" part of the partial derivative of `deepex` with respect to
/// the variable with index `var_idx`.
///
/// This part covers the expression's nodes combined by its binary operators. The
/// expression's own unary operators are not looked at here; `partial_deepex`
/// multiplies this result with the one of `partial_derivative_outer`.
///
/// `partial_derivative_ops` supplies the binary derivative rules. For a binary
/// operator without a rule, `missing_op_mode` selects `partial_deri_per_operand`,
/// `partial_derisval`, or an error.
///
/// # Errors
///
/// Fails in any of these cases:
/// - a sub-expression cannot be differentiated;
/// - a binary operator has no rule and `missing_op_mode` is `Error`;
/// - a matching rule has no `bin_op`;
/// - a derivative rule itself fails.
fn partial_derivative_inner<'a, T: DiffDataType, OF, LM>(
    var_idx: usize,
    deepex: DeepEx<'a, T, OF, LM>,
    partial_derivative_ops: &[PartialDerivative<'a, T, OF, LM>],
    missing_op_mode: MissingOpMode,
) -> ExResult<DeepEx<'a, T, OF, LM>>
where
    OF: MakeOperators<T>,
    LM: MatchLiteral,
    <T as FromStr>::Err: Debug,
{
    // Base case: a single node has no binary operators to apply.
    if deepex.nodes().len() == 1 {
        let res = match deepex.nodes()[0].clone() {
            // constants do not depend on any variable
            DeepNode::Num(_) => DeepEx::zero(),
            // 1 for the variable we differentiate by, 0 for any other variable
            DeepNode::Var((var_i, _)) => {
                if var_i == var_idx {
                    DeepEx::one()
                } else {
                    DeepEx::zero()
                }
            }
            // nested sub-expressions are differentiated completely, including
            // their own unary operators
            DeepNode::Expr(e) => partial_deepex(var_idx, *e, missing_op_mode)?,
        };
        // presumably merges the variable names of `deepex` into the result so it
        // can be evaluated with the same variables — TODO confirm
        let (res, _) = res.var_names_union(deepex);
        return Ok(res);
    }

    // Order in which the binary operators are applied (presumably by operator
    // priority).
    let prio_indices = prioritized_indices(&deepex.bin_ops().ops, deepex.nodes());

    // Wraps any node into a boxed expression so values and derivatives can be
    // combined with `DeepEx` arithmetic.
    let make_deepex = |node: DeepNode<'a, T, OF, LM>| match node {
        DeepNode::Expr(e) => e,
        _ => Box::new(DeepEx::from_node(node)),
    };

    // One (value, derivative) pair per node. Each is wrapped in `Option` so that
    // consumed operands can be taken out below.
    let mut nodes = deepex
        .nodes()
        .iter()
        .map(|node| -> ExResult<_> {
            let deepex_val = make_deepex(node.clone());
            let deepex_der = partial_deepex(var_idx, (*deepex_val).clone(), missing_op_mode)?;
            Ok(Some(ValueDerivative {
                val: *deepex_val,
                der: deepex_der,
            }))
        })
        .collect::<ExResult<Vec<_>>>()?;

    // Rule for every binary operator of the expression, looked up once by its
    // representation; `None` if no rule is registered.
    let partial_bin_ops_of_deepex = deepex
        .bin_ops()
        .reprs
        .iter()
        .map(|repr| {
            (
                *repr,
                partial_derivative_ops.iter().find(|pdo| &pdo.repr == repr),
            )
        })
        .collect::<SmallVec<
            [(&str, Option<&PartialDerivative<'a, T, OF, LM>>); N_BINOPS_OF_DEEPEX_ON_STACK],
        >>();

    // `num_inds[i]` is the current position in `nodes` of the left operand of the
    // i-th operator to apply. It starts as the operator index and shifts left as
    // operands get merged.
    let mut num_inds = prio_indices.clone();
    // NOTE(review): `used_prio_indices` is pushed to below but never read in this
    // function.
    let mut used_prio_indices = ExprIdxVec::new();

    for (i, &bin_op_idx) in prio_indices.iter().enumerate() {
        let num_idx = num_inds[i];
        let node_1 = nodes[num_idx].take();
        let node_2 = nodes[num_idx + 1].take();

        let pd_deepex = if let (Some(n1), Some(n2)) = (node_1, node_2) {
            let pdo = &partial_bin_ops_of_deepex[bin_op_idx];
            match pdo {
                (_, Some(pdo)) => pdo
                    .bin_op
                    .ok_or_else(|| exerr!("cannot find binary op for {}", pdo.repr))?(
                    n1, n2
                ),
                // no rule registered for this operator
                (repr, None) => match missing_op_mode {
                    MissingOpMode::PerOperand => partial_deri_per_operand(repr, n1, n2),
                    MissingOpMode::None => partial_derisval(repr, n1, n2),
                    MissingOpMode::Error => Err(exerr!("cannot find binary op for {repr}",))?,
                },
            }
        } else {
            Err(ExError::new(
                "nodes do not contain values in partial derivative",
            ))
        }?;
        // The merged pair replaces the left operand and the right slot is removed.
        nodes[num_idx] = Some(pd_deepex);
        nodes.remove(num_idx + 1);
        // Every operand to the right of the removed slot moved one position left.
        for num_idx_after in num_inds.iter_mut() {
            if *num_idx_after > num_idx {
                *num_idx_after -= 1;
            }
        }
        used_prio_indices.push(bin_op_idx);
    }
    // After all operators have been applied, the remaining pair sits at position 0.
    let res = nodes[0]
        .take()
        .ok_or_else(|| {
            ExError::new("node 0 needs to contain valder at the end of partial derviative")
        })?
        .der;
    let (res, _) = res.var_names_union(deepex);
    Ok(res)
}
372
373pub fn partial_deepex<T: DiffDataType, OF, LM>(
374 var_idx: usize,
375 deepex: DeepEx<'_, T, OF, LM>,
376 missing_op_mode: MissingOpMode,
377) -> ExResult<DeepEx<'_, T, OF, LM>>
378where
379 OF: MakeOperators<T>,
380 LM: MatchLiteral,
381 <T as FromStr>::Err: Debug,
382{
383 let partial_derivative_ops = make_partial_derivative_ops::<T, OF, LM>();
384 let inner = partial_derivative_inner(
385 var_idx,
386 deepex.clone(),
387 &partial_derivative_ops,
388 missing_op_mode,
389 )?;
390 let outer = partial_derivative_outer(deepex, &partial_derivative_ops)?;
391 inner * outer
392}
393
394enum Base {
395 Two,
396 Ten,
397 Euler,
398}
399fn log_deri<T: DiffDataType, OF, LM>(
400 f: DeepEx<'_, T, OF, LM>,
401 base: Base,
402) -> ExResult<DeepEx<'_, T, OF, LM>>
403where
404 OF: MakeOperators<T>,
405 LM: MatchLiteral,
406 <T as FromStr>::Err: Debug,
407{
408 let ln_base = |base_float: f32| DeepEx::from_num(T::from(base_float)).ln();
409 let x = f.without_latest_unary();
410 let denominator = match base {
411 Base::Ten => (x * ln_base(10.0)?)?,
412 Base::Two => (x * ln_base(2.0)?)?,
413 Base::Euler => x,
414 };
415 DeepEx::one() / denominator
416}
417
418fn partial_deri_per_operand<'a, T, OF, LM>(
419 repr: &'a str,
420 f: ValueDerivative<'a, T, OF, LM>,
421 g: ValueDerivative<'a, T, OF, LM>,
422) -> ExResult<ValueDerivative<'a, T, OF, LM>>
423where
424 T: DiffDataType,
425 OF: MakeOperators<T>,
426 LM: MatchLiteral,
427 <T as FromStr>::Err: Debug,
428{
429 Ok(ValueDerivative {
430 val: f.val.clone().operate_bin(g.val.clone(), repr)?,
431 der: f.der.operate_bin(g.der, repr)?,
432 })
433}
434
/// Builds a `PartialDerivative` entry for the binary operator `$repr`.
///
/// The derivative is obtained by applying `$repr` separately to the operands' values
/// and to their derivatives (delegating to `partial_deri_per_operand`, the same rule
/// `MissingOpMode::PerOperand` uses). The entry has no unary rule.
macro_rules! make_partial_per_operand {
    ($repr:expr) => {
        PartialDerivative {
            repr: $repr,
            bin_op: Some(
                |f: ValueDerivative<T, OF, LM>,
                 g: ValueDerivative<T, OF, LM>|
                 -> ExResult<ValueDerivative<T, OF, LM>> {
                    partial_deri_per_operand($repr, f, g)
                },
            ),
            unary_outer_op: None,
        }
    };
}
453
454fn partial_derisval<'a, T, OF, LM>(
455 repr: &'a str,
456 f: ValueDerivative<'a, T, OF, LM>,
457 g: ValueDerivative<'a, T, OF, LM>,
458) -> ExResult<ValueDerivative<'a, T, OF, LM>>
459where
460 T: DiffDataType,
461 OF: MakeOperators<T>,
462 LM: MatchLiteral,
463 <T as FromStr>::Err: Debug,
464{
465 Ok(ValueDerivative {
466 val: f.val.clone().operate_bin(g.val.clone(), repr)?,
467 der: f.val.operate_bin(g.val, repr)?,
468 })
469}
470
/// Builds a `PartialDerivative` entry for the binary operator `$repr`.
///
/// The derivative comes from `partial_derisval`, i.e. the operator applied to the
/// operands' values. The entry has no unary rule.
macro_rules! make_partial_derisval {
    ($repr:expr) => {
        PartialDerivative {
            repr: $repr,
            unary_outer_op: None,
            bin_op: Some(
                |lhs: ValueDerivative<T, OF, LM>,
                 rhs: ValueDerivative<T, OF, LM>|
                 -> ExResult<ValueDerivative<T, OF, LM>> { partial_derisval($repr, lhs, rhs) },
            ),
        }
    };
}
486
487pub fn make_partial_derivative_ops<'a, T, OF, LM>() -> Vec<PartialDerivative<'a, T, OF, LM>>
488where
489 T: DiffDataType,
490 OF: MakeOperators<T>,
491 LM: MatchLiteral,
492 <T as FromStr>::Err: Debug,
493{
494 vec![
495 PartialDerivative {
496 repr: "^",
497 bin_op: Some(
498 |f: ValueDerivative<T, OF, LM>,
499 g: ValueDerivative<T, OF, LM>|
500 -> ExResult<ValueDerivative<T, OF, LM>> {
501 let one = DeepEx::one();
502 let val = f.val.clone().pow(g.val.clone())?;
503 let g_minus_1 = (g.val.clone() - one)?;
504 let der_1 = ((f.val.clone().pow(g_minus_1)? * g.val)? * f.der)?;
505 let der_2 = ((val.clone() * f.val.ln()?)? * g.der)?;
506 let der = (der_1 + der_2)?;
507 Ok(ValueDerivative { val, der })
508 },
509 ),
510 unary_outer_op: None,
511 },
512 PartialDerivative {
513 repr: "+",
514 bin_op: Some(
515 |f: ValueDerivative<T, OF, LM>,
516 g: ValueDerivative<T, OF, LM>|
517 -> ExResult<ValueDerivative<T, OF, LM>> {
518 Ok(ValueDerivative {
519 val: (f.val + g.val)?,
520 der: (f.der + g.der)?,
521 })
522 },
523 ),
524 unary_outer_op: Some(|_: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
525 Ok(DeepEx::one())
526 }),
527 },
528 PartialDerivative {
529 repr: "-",
530 bin_op: Some(
531 |f: ValueDerivative<T, OF, LM>,
532 g: ValueDerivative<T, OF, LM>|
533 -> ExResult<ValueDerivative<T, OF, LM>> {
534 Ok(ValueDerivative {
535 val: (f.val - g.val)?,
536 der: (f.der - g.der)?,
537 })
538 },
539 ),
540 unary_outer_op: Some(
541 |_: DeepEx<'a, T, OF, LM>| -> ExResult<DeepEx<'a, T, OF, LM>> { -DeepEx::one() },
542 ),
543 },
544 PartialDerivative {
545 repr: "*",
546 bin_op: Some(
547 |f: ValueDerivative<T, OF, LM>,
548 g: ValueDerivative<T, OF, LM>|
549 -> ExResult<ValueDerivative<T, OF, LM>> {
550 let val = (f.val.clone() * g.val.clone())?;
551 let der_1 = (g.val * f.der)?;
552 let der_2 = (g.der * f.val)?;
553 let der = (der_1 + der_2)?;
554 Ok(ValueDerivative { val, der })
555 },
556 ),
557 unary_outer_op: None,
558 },
559 make_partial_derisval!(">"),
560 make_partial_derisval!("<"),
561 make_partial_derisval!("!="),
562 make_partial_derisval!("=="),
563 make_partial_derisval!("<="),
564 make_partial_derisval!(">="),
565 make_partial_per_operand!("if"),
566 make_partial_per_operand!("else"),
567 PartialDerivative {
568 repr: "/",
569 bin_op: Some(
570 |f: ValueDerivative<T, OF, LM>,
571 g: ValueDerivative<T, OF, LM>|
572 -> ExResult<ValueDerivative<T, OF, LM>> {
573 let val = (f.val.clone() / g.val.clone())?;
574 let numerator = ((f.der * g.val.clone())? - (g.der * f.val)?)?;
575 let denominator = (g.val.clone() * g.val)?;
576 Ok(ValueDerivative {
577 val,
578 der: (numerator / denominator)?,
579 })
580 },
581 ),
582 unary_outer_op: None,
583 },
584 PartialDerivative {
585 repr: "sqrt",
586 bin_op: None,
587 unary_outer_op: Some(
588 |f: DeepEx<'a, T, OF, LM>| -> ExResult<DeepEx<'a, T, OF, LM>> {
589 let one = DeepEx::one();
590 let two = DeepEx::from_num(T::from(2.0));
591 one / (two * f)?
592 },
593 ),
594 },
595 PartialDerivative {
596 repr: "ln",
597 bin_op: None,
598 unary_outer_op: Some(
599 |f: DeepEx<'a, T, OF, LM>| -> ExResult<DeepEx<'a, T, OF, LM>> {
600 log_deri(f, Base::Euler)
601 },
602 ),
603 },
604 PartialDerivative {
605 repr: "log",
606 bin_op: None,
607 unary_outer_op: Some(
608 |f: DeepEx<'a, T, OF, LM>| -> ExResult<DeepEx<'a, T, OF, LM>> {
609 log_deri(f, Base::Euler)
610 },
611 ),
612 },
613 PartialDerivative {
614 repr: "log10",
615 bin_op: None,
616 unary_outer_op: Some(
617 |f: DeepEx<'a, T, OF, LM>| -> ExResult<DeepEx<'a, T, OF, LM>> {
618 log_deri(f, Base::Ten)
619 },
620 ),
621 },
622 PartialDerivative {
623 repr: "log2",
624 bin_op: None,
625 unary_outer_op: Some(
626 |f: DeepEx<'a, T, OF, LM>| -> ExResult<DeepEx<'a, T, OF, LM>> {
627 log_deri(f, Base::Two)
628 },
629 ),
630 },
631 PartialDerivative {
632 repr: "exp",
633 bin_op: None,
634 unary_outer_op: Some(
635 |f: DeepEx<'a, T, OF, LM>| -> ExResult<DeepEx<'a, T, OF, LM>> { Ok(f) },
636 ),
637 },
638 PartialDerivative {
639 repr: "sin",
640 bin_op: None,
641 unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
642 f.without_latest_unary().cos()
643 }),
644 },
645 PartialDerivative {
646 repr: "cos",
647 bin_op: None,
648 unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
649 let sin = f.without_latest_unary().sin()?;
650 -sin
651 }),
652 },
653 PartialDerivative {
654 repr: "tan",
655 bin_op: None,
656 unary_outer_op: Some(
657 |f: DeepEx<'a, T, OF, LM>| -> ExResult<DeepEx<'a, T, OF, LM>> {
658 let two = DeepEx::from_num(T::from(2.0));
659 let cos_squared_ex = f.clone().without_latest_unary().cos()?.pow(two)?;
660 DeepEx::one() / cos_squared_ex
661 },
662 ),
663 },
664 PartialDerivative {
665 repr: "asin",
666 bin_op: None,
667 unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
668 let one = DeepEx::one();
669 let two = DeepEx::from_num(T::from(2.0));
670 let inner_squared = f.without_latest_unary().pow(two)?;
671 let insq_min1_sqrt = (one.clone() - inner_squared)?.sqrt()?;
672 one / insq_min1_sqrt
673 }),
674 },
675 PartialDerivative {
676 repr: "acos",
677 bin_op: None,
678 unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
679 let one = DeepEx::one();
680 let two = DeepEx::from_num(T::from(2.0));
681 let inner_squared = f.without_latest_unary().pow(two)?;
682 let denominator = (one.clone() - inner_squared)?.sqrt()?;
683 let div = (one / denominator)?;
684 -div
685 }),
686 },
687 PartialDerivative {
688 repr: "atan",
689 bin_op: None,
690 unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
691 let one = DeepEx::one();
692 let two = DeepEx::from_num(T::from(2.0));
693 let inner_squared = f.without_latest_unary().pow(two)?;
694 one.clone() / (one + inner_squared)?
695 }),
696 },
697 PartialDerivative {
698 repr: "sinh",
699 bin_op: None,
700 unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
701 f.without_latest_unary().cosh()
702 }),
703 },
704 PartialDerivative {
705 repr: "cosh",
706 bin_op: None,
707 unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
708 f.without_latest_unary().sinh()
709 }),
710 },
711 PartialDerivative {
712 repr: "tanh",
713 bin_op: None,
714 unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
715 let one = DeepEx::one();
716 let two = DeepEx::from_num(T::from(2.0));
717 one - f.without_latest_unary().tanh()?.pow(two)?
718 }),
719 },
720 PartialDerivative {
721 repr: "asinh",
722 bin_op: None,
723 unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
724 let one = DeepEx::one();
725 let two = DeepEx::from_num(T::from(2.0));
726 one.clone() / (one + f.without_latest_unary().pow(two)?)?.sqrt()?
727 }),
728 },
729 PartialDerivative {
730 repr: "acosh",
731 bin_op: None,
732 unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
733 let one = DeepEx::one();
734 one.clone()
735 / ((f.clone().without_latest_unary() - one.clone())?.sqrt()?
736 * (f.without_latest_unary() + one)?.sqrt()?)?
737 }),
738 },
739 PartialDerivative {
740 repr: "atanh",
741 bin_op: None,
742 unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
743 let one = DeepEx::one();
744 let two = DeepEx::from_num(T::from(2.0));
745 one.clone() / (one - f.without_latest_unary().pow(two)?)?
746 }),
747 },
748 ]
749}
750
751#[cfg(test)]
752use crate::{util::assert_float_eq_f64, FlatEx, FloatOpsFactory, NumberMatcher};
753
#[test]
fn test_pmp() -> ExResult<()> {
    // `+-+x` reduces to `-x`, whose derivative is the constant -1.
    let flat = FlatEx::<f64>::parse("+-+x")?;
    let derivative = flat.partial(0)?;
    println!("{}", derivative);
    let evaluated = derivative.eval(&[1.5f64])?;
    assert_float_eq_f64(evaluated, -1.0);
    Ok(())
}
#[test]
fn test_compile() -> ExResult<()> {
    let expr = DeepEx::<f64>::parse("1+(((a+x^2*x^2)))")?;
    println!("{}", expr);
    assert_eq!(expr.to_string(), "1.0+({a}+{x}^2.0*{x}^2.0)");

    let mut derivative = partial_deepex(1, expr, MissingOpMode::Error)?;
    derivative.compile();
    println!("{}", derivative);
    let expected = "(({x}^2.0)*({x}*2.0))+(({x}*2.0)*({x}^2.0))";
    assert_eq!(derivative.to_string(), expected);
    Ok(())
}
777}
#[test]
fn test_sincosin() -> ExResult<()> {
    let x = 1.5f64;
    let derivative = FlatEx::<f64>::parse("sin(cos(sin(x)))")?.partial(0)?;
    println!("{}", derivative);
    // chain rule, innermost factor first
    let reference = x.cos() * (-x.sin().sin()) * x.sin().cos().cos();
    assert_float_eq_f64(derivative.eval(&[x])?, reference);
    Ok(())
}
788
#[test]
fn test_partial() {
    // (expression, variable index, variable values, expected derivative value)
    let cases: [(&str, usize, &[f64], f64); 3] = [
        (
            "z*sin(x)+cos(y)^(sin(z))",
            2,
            &[-0.18961918881278095, -6.383306547710852, 3.1742139703464503],
            -0.18346624475117082,
        ),
        ("sin(x)/x^2", 0, &[-0.18961918881278095], -27.977974668662565),
        ("x^y", 0, &[7.5, 3.5], 539.164392544148),
    ];
    for (text, var_idx, vars, expected) in cases {
        let parsed = DeepEx::<f64>::parse(text).unwrap();
        let derivative = partial_deepex(var_idx, parsed, MissingOpMode::Error).unwrap();
        assert_float_eq_f64(derivative.eval(vars).unwrap(), expected);
    }
}
809
#[test]
fn test_partial_3_vars() {
    // Checks the derivatives by variables 0, 1 and 2 of `s` at `vars` against
    // the first three entries of `ref_vals`.
    fn assert(s: &str, vars: &[f64], ref_vals: &[f64]) {
        let parsed = DeepEx::<f64>::parse(s).unwrap();
        for (var_idx, &expected) in ref_vals.iter().enumerate().take(3) {
            let derivative =
                partial_deepex(var_idx, parsed.clone(), MissingOpMode::Error).unwrap();
            let evaluated: f64 = derivative.eval(vars).unwrap();
            assert_float_eq_f64(evaluated, expected);
        }
    }
    assert("x+y+z", &[2345.3, 4523.5, 1.2], &[1.0, 1.0, 1.0]);
    assert(
        "x^2+y^2+z^2",
        &[2345.3, 4523.5, 1.2],
        &[2345.3 * 2.0, 4523.5 * 2.0, 2.4],
    );
}
831
#[test]
fn test_partial_x2x() {
    // d/dx (x * 2 * x) = 4x
    let parsed = DeepEx::<f64>::parse("x * 2 * x").unwrap();
    let derivative = partial_deepex(0, parsed, MissingOpMode::Error).unwrap();
    let cases: [(f64, f64); 2] = [(0.0, 0.0), (1.0, 4.0)];
    for (x, expected) in cases {
        assert_float_eq_f64(derivative.eval(&[x]).unwrap(), expected);
    }
}
841
#[test]
fn test_partial_cos_squared() {
    // d/dy cos(y)^2 = -2 cos(y) sin(y) = -sin(2y)
    let parsed = DeepEx::<f64>::parse("cos(y) ^ 2").unwrap();
    let derivative = partial_deepex(0, parsed, MissingOpMode::Error).unwrap();
    let cases: [(f64, f64); 2] = [(0.0, 0.0), (1.0, -0.9092974268256818)];
    for (y, expected) in cases {
        assert_float_eq_f64(derivative.eval(&[y]).unwrap(), expected);
    }
}
851
#[test]
fn test_num_ops() {
    // Asserts `n_nodes` nodes with one binary operator (and its representation)
    // between each pair of neighbouring nodes.
    fn check_shape(deepex: &DeepEx<f64, FloatOpsFactory<f64>, NumberMatcher>, n_nodes: usize) {
        assert_eq!(deepex.nodes().len(), n_nodes);
        let n_bin_ops = n_nodes - 1;
        assert_eq!(deepex.bin_ops().ops.len(), n_bin_ops);
        assert_eq!(deepex.bin_ops().reprs.len(), n_bin_ops);
    }

    // (-1) * (-1) is expected to collapse into a single node with value 1.
    let minus_one = DeepEx::<f64>::parse("-1").unwrap();
    let product = (minus_one.clone() * minus_one).unwrap();
    check_shape(&product, 1);
    assert_float_eq_f64(product.eval(&[]).unwrap(), 1.0);
}
872
#[test]
fn test_partial_combined() {
    let deepex = DeepEx::<f64>::parse("sin(x) + cos(y) ^ 2").unwrap();
    // Derives by `var_idx` and checks the result at two [x, y] points.
    let check = |var_idx: usize, cases: [([f64; 2], f64); 2]| {
        let derivative = partial_deepex(var_idx, deepex.clone(), MissingOpMode::Error).unwrap();
        for (vars, expected) in cases {
            assert_float_eq_f64(derivative.eval(&vars).unwrap(), expected);
        }
    };
    // d/dy = -sin(2y), independent of x
    check(1, [([231.431, 0.0], 0.0), ([-12.0, 1.0], -0.9092974268256818)]);
    // d/dx = cos(x), independent of y
    check(
        0,
        [
            ([231.431, 0.0], 0.5002954462477305),
            ([-12.0, 1.0], 0.8438539587324921),
        ],
    );
}
887
#[test]
fn test_partial_derivative_second_var() {
    let deepex = DeepEx::<f64>::parse("sin(x) + cos(y)").unwrap();
    let d_y = partial_deepex(1, deepex.clone(), MissingOpMode::Error).unwrap();
    // d/dy (sin(x) + cos(y)) = -sin(y), independent of x
    let cases: [([f64; 2], f64); 2] = [
        ([231.431, 0.0], 0.0),
        ([-12.0, 1.0], -0.8414709848078965),
    ];
    for (vars, expected) in cases {
        assert_float_eq_f64(d_y.eval(&vars).unwrap(), expected);
    }
}
897
#[test]
fn test_partial_derivative_first_var() {
    let deepex = DeepEx::<f64>::parse("sin(x) + cos(y)").unwrap();
    let d_x = partial_deepex(0, deepex.clone(), MissingOpMode::Error).unwrap();
    // d/dx (sin(x) + cos(y)) = cos(x), independent of y
    let cases: [([f64; 2], f64); 2] = [
        ([0.0, 2345.03], 1.0),
        ([1.0, 43212.43], 0.5403023058681398),
    ];
    for (vars, expected) in cases {
        assert_float_eq_f64(d_x.eval(&vars).unwrap(), expected);
    }
}
907
#[test]
fn test_partial_inner() {
    // Evaluates the inner derivative of `text` by `var_idx` at every value in
    // `vals` and compares with the matching entry of `ref_vals`.
    fn test(text: &str, vals: &[f64], ref_vals: &[f64], var_idx: usize) {
        let rules = make_partial_derivative_ops::<f64, FloatOpsFactory<f64>, NumberMatcher>();
        let parsed = DeepEx::<f64>::parse(text).unwrap();
        let inner =
            partial_derivative_inner(var_idx, parsed, &rules, MissingOpMode::Error).unwrap();
        for (val, &expected) in vals.iter().zip(ref_vals) {
            assert_float_eq_f64(inner.eval(&[*val]).unwrap(), expected);
        }
    }
    test("sin(x)", &[1.0, 0.0, 2.0], &[1.0, 1.0, 1.0], 0);
    test("sin(x^2)", &[1.0, 0.0, 2.0], &[2.0, 0.0, 4.0], 0);
}
928
#[test]
fn test_partial_outer() {
    // Evaluates the outer derivative of the first node of `text` (when that node
    // is a sub-expression) at every value in `vals`.
    fn test(text: &str, vals: &[f64], ref_vals: &[f64]) {
        let rules = make_partial_derivative_ops::<f64, FloatOpsFactory<f64>, NumberMatcher>();
        let parsed = DeepEx::<f64>::parse(text).unwrap();
        // NOTE(review): any other node kind makes the case a no-op; for "x" the
        // first node is presumably a plain variable, so that case checks nothing —
        // confirm.
        let DeepNode::Expr(sub) = parsed.nodes()[0].clone() else {
            return;
        };
        let outer = partial_derivative_outer(*sub, &rules).unwrap();
        for (val, &expected) in vals.iter().zip(ref_vals) {
            assert_float_eq_f64(outer.eval(&[*val]).unwrap(), expected);
        }
    }
    test("x", &[1.0, 0.0, 2.0], &[1.0, 0.0, 2.0]);
    test(
        "sin(x)",
        &[1.0, 0.0, 2.0],
        &[0.5403023058681398, 1.0, -0.4161468365471424],
    );
}
951
#[test]
fn test_partial_derivative_simple() -> ExResult<()> {
    // A constant and the variable itself derive to a single numeric node.
    for (text, expected) in [("1", 0.0), ("x", 1.0)] {
        let derivative = partial_deepex(0, DeepEx::<f64>::parse(text)?, MissingOpMode::Error)?;
        assert_eq!(derivative.nodes().len(), 1);
        assert_eq!(derivative.bin_ops().ops.len(), 0);
        match derivative.nodes()[0] {
            DeepNode::Num(n) => assert_float_eq_f64(n, expected),
            _ => unreachable!(),
        }
    }

    // d/dx x^2 = 2x
    let d_square = partial_deepex(0, DeepEx::<f64>::parse("x^2")?, MissingOpMode::Error)?;
    assert_float_eq_f64(d_square.eval(&[4.5])?, 9.0);

    // d/dx sin(x) = cos(x)
    let d_sin = partial_deepex(0, DeepEx::<f64>::parse("sin(x)")?, MissingOpMode::Error)?;
    let cases: [(f64, f64); 2] = [(0.0, 1.0), (1.0, 0.5403023058681398)];
    for (x, expected) in cases {
        assert_float_eq_f64(d_sin.eval(&[x])?, expected);
    }
    Ok(())
}