// exmex/expression/partial.rs

1use std::{
2    fmt::{Debug, Display},
3    iter,
4    str::FromStr,
5};
6
7use smallvec::SmallVec;
8
9use crate::{
10    data_type::DataType,
11    definitions::N_BINOPS_OF_DEEPEX_ON_STACK,
12    exerr,
13    expression::{
14        deep::{prioritized_indices, DeepEx, DeepNode},
15        flat::ExprIdxVec,
16    },
17    DiffDataType, ExError, ExResult, Express, MakeOperators, MatchLiteral,
18};
19
20pub fn check_partial_index(var_idx: usize, n_vars: usize, unparsed: &str) -> ExResult<()> {
21    if var_idx >= n_vars {
22        Err(exerr!(
23            "index {} is invalid since we have only {} vars in {}",
24            var_idx,
25            n_vars,
26            unparsed
27        ))
28    } else {
29        Ok(())
30    }
31}
32
33/// *`feature = "partial"`* - Trait for partial differentiation. This is implemented for expressions
34/// with datatypes that implement `DiffDataType`.  
35pub trait Differentiate<'a, T>
36where
37    T: DiffDataType,
38    <T as FromStr>::Err: Debug,
39    Self: Sized + Express<'a, T> + Display + Debug,
40{
41    /// *`feature = "partial"`* - This method computes a new expression
42    /// that is the partial derivative of `self` with default operators.
43    ///
44    /// # Example
45    ///
46    /// ```rust
47    /// # use std::error::Error;
48    /// # fn main() -> Result<(), Box<dyn Error>> {
49    /// #
50    /// use exmex::prelude::*;
51    ///
52    /// let expr = FlatEx::<f64>::parse("sin(1+y^2)*x")?;
53    /// let dexpr_dx = expr.partial(0)?;
54    ///
55    /// assert!((dexpr_dx.eval(&[9e5, 2.0])? - (5.0 as f64).sin()).abs() < 1e-12);
56    /// //                        |    
57    /// //           The partial derivative dexpr_dx does depend on x. Still, it
58    /// //           expects the same number of parameters as the corresponding
59    /// //           antiderivative. Hence, you can pass any number for x.  
60    ///
61    /// #
62    /// #     Ok(())
63    /// # }
64    /// ```
65    /// # Arguments
66    ///
67    /// * `var_idx` - variable with respect to which the partial derivative is computed
68    ///
69    /// # Errors
70    ///
71    /// * If you use custom operators this might not work as expected. It could return an [`ExError`](crate::ExError) if
72    ///   an operator is not found or compute a wrong result if an operator is defined in an un-expected way.
73    ///
74    fn partial(self, var_idx: usize) -> ExResult<Self> {
75        self.partial_nth(var_idx, 1)
76    }
77
78    /// Like [`Differentiate::partial`]. The only difference is that in case there is no differentation defined for
79    /// a binary operator this will not necessarily throw an error depending on `missing_op_mode`, see [`MissingOpMode`].
80    fn partial_relaxed(self, var_idx: usize, missing_op_mode: MissingOpMode) -> ExResult<Self> {
81        self.partial_nth_relaxed(var_idx, 1, missing_op_mode)
82    }
83
84    /// *`feature = "partial"`* - Computes the nth partial derivative with respect to one variable
85    /// # Example
86    /// ```rust
87    /// # use std::error::Error;
88    /// # fn main() -> Result<(), Box<dyn Error>> {
89    /// #
90    /// use exmex::prelude::*;
91    ///
92    /// let mut expr = FlatEx::<f64>::parse("x^4+y^4")?;
93    ///
94    /// let dexpr_dxx_nth = expr.clone().partial_nth(0, 2)?;
95    ///
96    /// let dexpr_dx = expr.partial(0)?;
97    /// let dexpr_dxx_2step = dexpr_dx.partial(0)?;
98    ///
99    /// assert!((dexpr_dxx_2step.eval(&[4.3, 2.1])? - dexpr_dxx_nth.eval(&[4.3, 2.1])?).abs() < 1e-12);
100    /// #
101    /// #     Ok(())
102    /// # }
103    /// ```
104    /// # Arguments
105    ///
106    /// * `var_idx` - variable with respect to which the partial derivative is computed
107    /// * `n` - order of derivation
108    ///
109    /// # Errors
110    ///
111    /// * If you use custom operators this might not work as expected. It could return an [`ExError`](crate::ExError) if
112    ///   an operator is not found or compute a wrong result if an operator is defined in an un-expected way.
113    ///
114    fn partial_nth(self, var_idx: usize, n: usize) -> ExResult<Self> {
115        self.partial_iter(iter::repeat_n(var_idx, n))
116    }
117
118    /// Like [`Differentiate::partial_nth`]. The only difference is that in case there is no differentation defined for
119    /// a binary operator this will not necessarily throw an error depending on `missing_op_mode`, see [`MissingOpMode`].
120    fn partial_nth_relaxed(
121        self,
122        var_idx: usize,
123        n: usize,
124        missing_op_mode: MissingOpMode,
125    ) -> ExResult<Self> {
126        self.partial_iter_relaxed(iter::repeat_n(var_idx, n), missing_op_mode)
127    }
128
129    /// *`feature = "partial"`* - Computes a chain of partial derivatives with respect to the variables passed as iterator
130    ///
131    /// # Example
132    /// ```rust
133    /// # use std::error::Error;
134    /// # fn main() -> Result<(), Box<dyn Error>> {
135    /// #
136    /// use exmex::prelude::*;
137    ///
138    /// let mut expr = FlatEx::<f64>::parse("x^4+y^4")?;
139    ///
140    /// let dexpr_dxy_iter = expr.clone().partial_iter([0, 1].iter().copied())?;
141    ///
142    /// let dexpr_dx = expr.partial(0)?;
143    /// let dexpr_dxy_2step = dexpr_dx.partial(1)?;
144    ///
145    /// assert!((dexpr_dxy_2step.eval(&[4.3, 2.1])? - dexpr_dxy_iter.eval(&[4.3, 2.1])?).abs() < 1e-12);
146    /// #
147    /// #     Ok(())
148    /// # }
149    /// ```
150    /// # Arguments
151    ///
152    /// * `var_idxs` - variables with respect to which the partial derivative is computed
153    /// * `n` - order of derivation
154    ///
155    /// # Errors
156    ///
157    /// * If you use custom operators this might not work as expected. It could return an [`ExError`](crate::ExError) if
158    ///   an operator is not found or compute a wrong result if an operator is defined in an un-expected way.
159    ///
160    fn partial_iter<I>(self, var_idxs: I) -> ExResult<Self>
161    where
162        I: Iterator<Item = usize> + Clone,
163    {
164        self.partial_iter_relaxed(var_idxs, MissingOpMode::Error)
165    }
166
167    /// Like [`Differentiate::partial_iter`]. The only difference is that in case there is no differentation defined for
168    /// a binary this will not necessarily throw an error depending on `missing_op_mode`, see [`MissingOpMode`].
169    fn partial_iter_relaxed<I>(self, var_idxs: I, missing_op_mode: MissingOpMode) -> ExResult<Self>
170    where
171        I: Iterator<Item = usize> + Clone,
172    {
173        let mut deepex = self.to_deepex()?;
174
175        let unparsed = deepex.unparse();
176        for var_idx in var_idxs.clone() {
177            check_partial_index(var_idx, deepex.var_names().len(), unparsed)?;
178        }
179        for var_idx in var_idxs {
180            deepex = partial_deepex(var_idx, deepex, missing_op_mode)?;
181        }
182        deepex.compile();
183        Self::from_deepex(deepex)
184    }
185}
186#[derive(Clone, Debug)]
187struct ValueDerivative<'a, T, OF, LM>
188where
189    T: DataType,
190    OF: MakeOperators<T>,
191    LM: MatchLiteral,
192    <T as FromStr>::Err: Debug,
193{
194    val: DeepEx<'a, T, OF, LM>,
195    der: DeepEx<'a, T, OF, LM>,
196}
197
198type BinOpPartial<'a, T, OF, LM> = fn(
199    ValueDerivative<'a, T, OF, LM>,
200    ValueDerivative<'a, T, OF, LM>,
201) -> ExResult<ValueDerivative<'a, T, OF, LM>>;
202
203type UnaryOpOuter<'a, T, OF, LM> = fn(DeepEx<'a, T, OF, LM>) -> ExResult<DeepEx<'a, T, OF, LM>>;
204
205#[derive(Debug)]
206pub struct PartialDerivative<'a, T: DataType, OF, LM>
207where
208    OF: MakeOperators<T>,
209    LM: MatchLiteral,
210    <T as FromStr>::Err: Debug,
211{
212    repr: &'a str,
213    bin_op: Option<BinOpPartial<'a, T, OF, LM>>,
214    unary_outer_op: Option<UnaryOpOuter<'a, T, OF, LM>>,
215}
216
217fn make_op_missing_err(repr: &str) -> ExError {
218    exerr!("operator {} needed for outer partial derivative", repr)
219}
220
221fn partial_derivative_outer<'a, T: DiffDataType, OF, LM>(
222    deepex: DeepEx<'a, T, OF, LM>,
223    partial_derivative_ops: &[PartialDerivative<'a, T, OF, LM>],
224) -> ExResult<DeepEx<'a, T, OF, LM>>
225where
226    OF: MakeOperators<T>,
227    LM: MatchLiteral,
228    <T as FromStr>::Err: Debug,
229{
230    let mut factorexes = deepex
231        .unary_op()
232        .reprs
233        .iter()
234        .enumerate()
235        .map(|(idx, repr)| {
236            let op = partial_derivative_ops
237                .iter()
238                .find(|pdo| pdo.repr == *repr)
239                .ok_or_else(|| make_op_missing_err(repr))?;
240            let unary_deri_op = op.unary_outer_op.ok_or_else(|| make_op_missing_err(repr))?;
241            let mut new_deepex = deepex.clone();
242            for _ in 0..idx {
243                new_deepex = new_deepex.without_latest_unary();
244            }
245            unary_deri_op(new_deepex)
246        });
247    factorexes.try_fold(DeepEx::one(), |dp1, dp2| -> ExResult<DeepEx<T, OF, LM>> {
248        dp2.and_then(|dp2| dp2 * dp1)
249    })
250}
251
/// *`feature = "partial"`* - Determines what happens if no derivative is defined for a binary
/// operator.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MissingOpMode {
    /// Apply the operator to the partial derivatives of the operands, like for `+`
    PerOperand,
    /// Do not compute partial derivatives and keep the operands as they were
    None,
    /// Return an error
    Error,
}
262
263fn partial_derivative_inner<'a, T: DiffDataType, OF, LM>(
264    var_idx: usize,
265    deepex: DeepEx<'a, T, OF, LM>,
266    partial_derivative_ops: &[PartialDerivative<'a, T, OF, LM>],
267    missing_op_mode: MissingOpMode,
268) -> ExResult<DeepEx<'a, T, OF, LM>>
269where
270    OF: MakeOperators<T>,
271    LM: MatchLiteral,
272    <T as FromStr>::Err: Debug,
273{
274    // special case, partial derivative of only 1 node
275    if deepex.nodes().len() == 1 {
276        let res = match deepex.nodes()[0].clone() {
277            DeepNode::Num(_) => DeepEx::zero(),
278            DeepNode::Var((var_i, _)) => {
279                if var_i == var_idx {
280                    DeepEx::one()
281                } else {
282                    DeepEx::zero()
283                }
284            }
285            DeepNode::Expr(e) => partial_deepex(var_idx, *e, missing_op_mode)?,
286        };
287        let (res, _) = res.var_names_union(deepex);
288        return Ok(res);
289    }
290
291    let prio_indices = prioritized_indices(&deepex.bin_ops().ops, deepex.nodes());
292
293    let make_deepex = |node: DeepNode<'a, T, OF, LM>| match node {
294        DeepNode::Expr(e) => e,
295        _ => Box::new(DeepEx::from_node(node)),
296    };
297
298    let mut nodes = deepex
299        .nodes()
300        .iter()
301        .map(|node| -> ExResult<_> {
302            let deepex_val = make_deepex(node.clone());
303            let deepex_der = partial_deepex(var_idx, (*deepex_val).clone(), missing_op_mode)?;
304            Ok(Some(ValueDerivative {
305                val: *deepex_val,
306                der: deepex_der,
307            }))
308        })
309        .collect::<ExResult<Vec<_>>>()?;
310
311    let partial_bin_ops_of_deepex =
312        deepex
313            .bin_ops()
314            .reprs
315            .iter()
316            .map(|repr| {
317                (
318                    *repr,
319                    partial_derivative_ops.iter().find(|pdo| &pdo.repr == repr),
320                )
321            })
322            .collect::<SmallVec<
323                [(&str, Option<&PartialDerivative<'a, T, OF, LM>>); N_BINOPS_OF_DEEPEX_ON_STACK],
324            >>();
325
326    let mut num_inds = prio_indices.clone();
327    let mut used_prio_indices = ExprIdxVec::new();
328
329    for (i, &bin_op_idx) in prio_indices.iter().enumerate() {
330        let num_idx = num_inds[i];
331        let node_1 = nodes[num_idx].take();
332        let node_2 = nodes[num_idx + 1].take();
333
334        let pd_deepex = if let (Some(n1), Some(n2)) = (node_1, node_2) {
335            let pdo = &partial_bin_ops_of_deepex[bin_op_idx];
336            match pdo {
337                (_, Some(pdo)) => pdo
338                    .bin_op
339                    .ok_or_else(|| exerr!("cannot find binary op for {}", pdo.repr))?(
340                    n1, n2
341                ),
342                (repr, None) => match missing_op_mode {
343                    MissingOpMode::PerOperand => partial_deri_per_operand(repr, n1, n2),
344                    MissingOpMode::None => partial_derisval(repr, n1, n2),
345                    MissingOpMode::Error => Err(exerr!("cannot find binary op for {repr}",))?,
346                },
347            }
348        } else {
349            Err(ExError::new(
350                "nodes do not contain values in partial derivative",
351            ))
352        }?;
353        nodes[num_idx] = Some(pd_deepex);
354        nodes.remove(num_idx + 1);
355        // reduce indices after removed position
356        for num_idx_after in num_inds.iter_mut() {
357            if *num_idx_after > num_idx {
358                *num_idx_after -= 1;
359            }
360        }
361        used_prio_indices.push(bin_op_idx);
362    }
363    let res = nodes[0]
364        .take()
365        .ok_or_else(|| {
366            ExError::new("node 0 needs to contain valder at the end of partial derviative")
367        })?
368        .der;
369    let (res, _) = res.var_names_union(deepex);
370    Ok(res)
371}
372
373pub fn partial_deepex<T: DiffDataType, OF, LM>(
374    var_idx: usize,
375    deepex: DeepEx<'_, T, OF, LM>,
376    missing_op_mode: MissingOpMode,
377) -> ExResult<DeepEx<'_, T, OF, LM>>
378where
379    OF: MakeOperators<T>,
380    LM: MatchLiteral,
381    <T as FromStr>::Err: Debug,
382{
383    let partial_derivative_ops = make_partial_derivative_ops::<T, OF, LM>();
384    let inner = partial_derivative_inner(
385        var_idx,
386        deepex.clone(),
387        &partial_derivative_ops,
388        missing_op_mode,
389    )?;
390    let outer = partial_derivative_outer(deepex, &partial_derivative_ops)?;
391    inner * outer
392}
393
/// Base of a logarithm; selects the derivative formula in `log_deri`.
enum Base {
    /// natural logarithm
    Euler,
    /// logarithm to base 2
    Two,
    /// logarithm to base 10
    Ten,
}
399fn log_deri<T: DiffDataType, OF, LM>(
400    f: DeepEx<'_, T, OF, LM>,
401    base: Base,
402) -> ExResult<DeepEx<'_, T, OF, LM>>
403where
404    OF: MakeOperators<T>,
405    LM: MatchLiteral,
406    <T as FromStr>::Err: Debug,
407{
408    let ln_base = |base_float: f32| DeepEx::from_num(T::from(base_float)).ln();
409    let x = f.without_latest_unary();
410    let denominator = match base {
411        Base::Ten => (x * ln_base(10.0)?)?,
412        Base::Two => (x * ln_base(2.0)?)?,
413        Base::Euler => x,
414    };
415    DeepEx::one() / denominator
416}
417
418fn partial_deri_per_operand<'a, T, OF, LM>(
419    repr: &'a str,
420    f: ValueDerivative<'a, T, OF, LM>,
421    g: ValueDerivative<'a, T, OF, LM>,
422) -> ExResult<ValueDerivative<'a, T, OF, LM>>
423where
424    T: DiffDataType,
425    OF: MakeOperators<T>,
426    LM: MatchLiteral,
427    <T as FromStr>::Err: Debug,
428{
429    Ok(ValueDerivative {
430        val: f.val.clone().operate_bin(g.val.clone(), repr)?,
431        der: f.der.operate_bin(g.der, repr)?,
432    })
433}
434
// Builds a `PartialDerivative` entry for the binary operator `$repr` whose derivative is the
// operator applied per operand, see `partial_deri_per_operand`. Delegating keeps this rule in
// sync with the `MissingOpMode::PerOperand` fallback, analogous to `make_partial_derisval`.
// Must be expanded where the generic parameters `T`, `OF` and `LM` are in scope.
macro_rules! make_partial_per_operand {
    ($repr:expr) => {
        PartialDerivative {
            repr: $repr,
            bin_op: Some(
                |f: ValueDerivative<T, OF, LM>,
                 g: ValueDerivative<T, OF, LM>|
                 -> ExResult<ValueDerivative<T, OF, LM>> {
                    partial_deri_per_operand($repr, f, g)
                },
            ),
            unary_outer_op: None,
        }
    };
}
453
454fn partial_derisval<'a, T, OF, LM>(
455    repr: &'a str,
456    f: ValueDerivative<'a, T, OF, LM>,
457    g: ValueDerivative<'a, T, OF, LM>,
458) -> ExResult<ValueDerivative<'a, T, OF, LM>>
459where
460    T: DiffDataType,
461    OF: MakeOperators<T>,
462    LM: MatchLiteral,
463    <T as FromStr>::Err: Debug,
464{
465    Ok(ValueDerivative {
466        val: f.val.clone().operate_bin(g.val.clone(), repr)?,
467        der: f.val.operate_bin(g.val, repr)?,
468    })
469}
470
// Builds a `PartialDerivative` entry for the binary operator `$op_repr` whose derivative is its
// value, see `partial_derisval`. Must be expanded where the generic parameters `T`, `OF` and
// `LM` are in scope.
macro_rules! make_partial_derisval {
    ($op_repr:expr) => {
        PartialDerivative {
            repr: $op_repr,
            unary_outer_op: None,
            bin_op: Some(
                |lhs: ValueDerivative<T, OF, LM>,
                 rhs: ValueDerivative<T, OF, LM>|
                 -> ExResult<ValueDerivative<T, OF, LM>> {
                    partial_derisval($op_repr, lhs, rhs)
                },
            ),
        }
    };
}
486
487pub fn make_partial_derivative_ops<'a, T, OF, LM>() -> Vec<PartialDerivative<'a, T, OF, LM>>
488where
489    T: DiffDataType,
490    OF: MakeOperators<T>,
491    LM: MatchLiteral,
492    <T as FromStr>::Err: Debug,
493{
494    vec![
495        PartialDerivative {
496            repr: "^",
497            bin_op: Some(
498                |f: ValueDerivative<T, OF, LM>,
499                 g: ValueDerivative<T, OF, LM>|
500                 -> ExResult<ValueDerivative<T, OF, LM>> {
501                    let one = DeepEx::one();
502                    let val = f.val.clone().pow(g.val.clone())?;
503                    let g_minus_1 = (g.val.clone() - one)?;
504                    let der_1 = ((f.val.clone().pow(g_minus_1)? * g.val)? * f.der)?;
505                    let der_2 = ((val.clone() * f.val.ln()?)? * g.der)?;
506                    let der = (der_1 + der_2)?;
507                    Ok(ValueDerivative { val, der })
508                },
509            ),
510            unary_outer_op: None,
511        },
512        PartialDerivative {
513            repr: "+",
514            bin_op: Some(
515                |f: ValueDerivative<T, OF, LM>,
516                 g: ValueDerivative<T, OF, LM>|
517                 -> ExResult<ValueDerivative<T, OF, LM>> {
518                    Ok(ValueDerivative {
519                        val: (f.val + g.val)?,
520                        der: (f.der + g.der)?,
521                    })
522                },
523            ),
524            unary_outer_op: Some(|_: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
525                Ok(DeepEx::one())
526            }),
527        },
528        PartialDerivative {
529            repr: "-",
530            bin_op: Some(
531                |f: ValueDerivative<T, OF, LM>,
532                 g: ValueDerivative<T, OF, LM>|
533                 -> ExResult<ValueDerivative<T, OF, LM>> {
534                    Ok(ValueDerivative {
535                        val: (f.val - g.val)?,
536                        der: (f.der - g.der)?,
537                    })
538                },
539            ),
540            unary_outer_op: Some(
541                |_: DeepEx<'a, T, OF, LM>| -> ExResult<DeepEx<'a, T, OF, LM>> { -DeepEx::one() },
542            ),
543        },
544        PartialDerivative {
545            repr: "*",
546            bin_op: Some(
547                |f: ValueDerivative<T, OF, LM>,
548                 g: ValueDerivative<T, OF, LM>|
549                 -> ExResult<ValueDerivative<T, OF, LM>> {
550                    let val = (f.val.clone() * g.val.clone())?;
551                    let der_1 = (g.val * f.der)?;
552                    let der_2 = (g.der * f.val)?;
553                    let der = (der_1 + der_2)?;
554                    Ok(ValueDerivative { val, der })
555                },
556            ),
557            unary_outer_op: None,
558        },
559        make_partial_derisval!(">"),
560        make_partial_derisval!("<"),
561        make_partial_derisval!("!="),
562        make_partial_derisval!("=="),
563        make_partial_derisval!("<="),
564        make_partial_derisval!(">="),
565        make_partial_per_operand!("if"),
566        make_partial_per_operand!("else"),
567        PartialDerivative {
568            repr: "/",
569            bin_op: Some(
570                |f: ValueDerivative<T, OF, LM>,
571                 g: ValueDerivative<T, OF, LM>|
572                 -> ExResult<ValueDerivative<T, OF, LM>> {
573                    let val = (f.val.clone() / g.val.clone())?;
574                    let numerator = ((f.der * g.val.clone())? - (g.der * f.val)?)?;
575                    let denominator = (g.val.clone() * g.val)?;
576                    Ok(ValueDerivative {
577                        val,
578                        der: (numerator / denominator)?,
579                    })
580                },
581            ),
582            unary_outer_op: None,
583        },
584        PartialDerivative {
585            repr: "sqrt",
586            bin_op: None,
587            unary_outer_op: Some(
588                |f: DeepEx<'a, T, OF, LM>| -> ExResult<DeepEx<'a, T, OF, LM>> {
589                    let one = DeepEx::one();
590                    let two = DeepEx::from_num(T::from(2.0));
591                    one / (two * f)?
592                },
593            ),
594        },
595        PartialDerivative {
596            repr: "ln",
597            bin_op: None,
598            unary_outer_op: Some(
599                |f: DeepEx<'a, T, OF, LM>| -> ExResult<DeepEx<'a, T, OF, LM>> {
600                    log_deri(f, Base::Euler)
601                },
602            ),
603        },
604        PartialDerivative {
605            repr: "log",
606            bin_op: None,
607            unary_outer_op: Some(
608                |f: DeepEx<'a, T, OF, LM>| -> ExResult<DeepEx<'a, T, OF, LM>> {
609                    log_deri(f, Base::Euler)
610                },
611            ),
612        },
613        PartialDerivative {
614            repr: "log10",
615            bin_op: None,
616            unary_outer_op: Some(
617                |f: DeepEx<'a, T, OF, LM>| -> ExResult<DeepEx<'a, T, OF, LM>> {
618                    log_deri(f, Base::Ten)
619                },
620            ),
621        },
622        PartialDerivative {
623            repr: "log2",
624            bin_op: None,
625            unary_outer_op: Some(
626                |f: DeepEx<'a, T, OF, LM>| -> ExResult<DeepEx<'a, T, OF, LM>> {
627                    log_deri(f, Base::Two)
628                },
629            ),
630        },
631        PartialDerivative {
632            repr: "exp",
633            bin_op: None,
634            unary_outer_op: Some(
635                |f: DeepEx<'a, T, OF, LM>| -> ExResult<DeepEx<'a, T, OF, LM>> { Ok(f) },
636            ),
637        },
638        PartialDerivative {
639            repr: "sin",
640            bin_op: None,
641            unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
642                f.without_latest_unary().cos()
643            }),
644        },
645        PartialDerivative {
646            repr: "cos",
647            bin_op: None,
648            unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
649                let sin = f.without_latest_unary().sin()?;
650                -sin
651            }),
652        },
653        PartialDerivative {
654            repr: "tan",
655            bin_op: None,
656            unary_outer_op: Some(
657                |f: DeepEx<'a, T, OF, LM>| -> ExResult<DeepEx<'a, T, OF, LM>> {
658                    let two = DeepEx::from_num(T::from(2.0));
659                    let cos_squared_ex = f.clone().without_latest_unary().cos()?.pow(two)?;
660                    DeepEx::one() / cos_squared_ex
661                },
662            ),
663        },
664        PartialDerivative {
665            repr: "asin",
666            bin_op: None,
667            unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
668                let one = DeepEx::one();
669                let two = DeepEx::from_num(T::from(2.0));
670                let inner_squared = f.without_latest_unary().pow(two)?;
671                let insq_min1_sqrt = (one.clone() - inner_squared)?.sqrt()?;
672                one / insq_min1_sqrt
673            }),
674        },
675        PartialDerivative {
676            repr: "acos",
677            bin_op: None,
678            unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
679                let one = DeepEx::one();
680                let two = DeepEx::from_num(T::from(2.0));
681                let inner_squared = f.without_latest_unary().pow(two)?;
682                let denominator = (one.clone() - inner_squared)?.sqrt()?;
683                let div = (one / denominator)?;
684                -div
685            }),
686        },
687        PartialDerivative {
688            repr: "atan",
689            bin_op: None,
690            unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
691                let one = DeepEx::one();
692                let two = DeepEx::from_num(T::from(2.0));
693                let inner_squared = f.without_latest_unary().pow(two)?;
694                one.clone() / (one + inner_squared)?
695            }),
696        },
697        PartialDerivative {
698            repr: "sinh",
699            bin_op: None,
700            unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
701                f.without_latest_unary().cosh()
702            }),
703        },
704        PartialDerivative {
705            repr: "cosh",
706            bin_op: None,
707            unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
708                f.without_latest_unary().sinh()
709            }),
710        },
711        PartialDerivative {
712            repr: "tanh",
713            bin_op: None,
714            unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
715                let one = DeepEx::one();
716                let two = DeepEx::from_num(T::from(2.0));
717                one - f.without_latest_unary().tanh()?.pow(two)?
718            }),
719        },
720        PartialDerivative {
721            repr: "asinh",
722            bin_op: None,
723            unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
724                let one = DeepEx::one();
725                let two = DeepEx::from_num(T::from(2.0));
726                one.clone() / (one + f.without_latest_unary().pow(two)?)?.sqrt()?
727            }),
728        },
729        PartialDerivative {
730            repr: "acosh",
731            bin_op: None,
732            unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
733                let one = DeepEx::one();
734                one.clone()
735                    / ((f.clone().without_latest_unary() - one.clone())?.sqrt()?
736                        * (f.without_latest_unary() + one)?.sqrt()?)?
737            }),
738        },
739        PartialDerivative {
740            repr: "atanh",
741            bin_op: None,
742            unary_outer_op: Some(|f: DeepEx<T, OF, LM>| -> ExResult<DeepEx<T, OF, LM>> {
743                let one = DeepEx::one();
744                let two = DeepEx::from_num(T::from(2.0));
745                one.clone() / (one - f.without_latest_unary().pow(two)?)?
746            }),
747        },
748    ]
749}
750
751#[cfg(test)]
752use crate::{util::assert_float_eq_f64, FlatEx, FloatOpsFactory, NumberMatcher};
753
/// Stacked unary `+`, `-`, `+` must differentiate to the constant -1.
#[test]
fn test_pmp() -> ExResult<()> {
    let flatex = FlatEx::<f64>::parse("+-+x")?;
    let derivative = flatex.partial(0)?;
    println!("{}", derivative);
    assert_float_eq_f64(derivative.eval(&[1.5f64])?, -1.0);
    Ok(())
}
/// Checks the textual form of a parsed expression and of its compiled derivative.
#[test]
fn test_compile() -> ExResult<()> {
    let expr = DeepEx::<f64>::parse("1+(((a+x^2*x^2)))")?;
    let expr_text = expr.to_string();
    println!("{}", expr_text);
    assert_eq!(expr_text, "1.0+({a}+{x}^2.0*{x}^2.0)");

    let mut derivative = partial_deepex(1, expr, MissingOpMode::Error)?;
    derivative.compile();
    let derivative_text = derivative.to_string();
    println!("{}", derivative_text);
    assert_eq!(
        derivative_text,
        "(({x}^2.0)*({x}*2.0))+(({x}*2.0)*({x}^2.0))"
    );
    Ok(())
}
/// Chain rule through three nested unary functions.
#[test]
fn test_sincosin() -> ExResult<()> {
    let x = 1.5f64;
    let derivative = FlatEx::<f64>::parse("sin(cos(sin(x)))")?.partial(0)?;
    println!("{}", derivative);
    // cos(x) * (-sin(sin(x))) * cos(cos(sin(x)))
    let expected = x.cos() * (-x.sin().sin()) * x.sin().cos().cos();
    assert_float_eq_f64(derivative.eval(&[x])?, expected);
    Ok(())
}
788
/// Derivatives of mixed expressions evaluated at fixed points.
#[test]
fn test_partial() {
    // Differentiates `text` w.r.t. variable `var_idx` and compares the value at `vars`.
    fn check(text: &str, var_idx: usize, vars: &[f64], expected: f64) {
        let expr = DeepEx::<f64>::parse(text).unwrap();
        let derivative = partial_deepex(var_idx, expr, MissingOpMode::Error).unwrap();
        assert_float_eq_f64(derivative.eval(vars).unwrap(), expected);
    }
    check(
        "z*sin(x)+cos(y)^(sin(z))",
        2,
        &[-0.18961918881278095, -6.383306547710852, 3.1742139703464503],
        -0.18346624475117082,
    );
    check(
        "sin(x)/x^2",
        0,
        &[-0.18961918881278095],
        -27.977974668662565,
    );
    check("x^y", 0, &[7.5, 3.5], 539.164392544148);
}
809
/// All three partial derivatives of expressions in x, y and z.
#[test]
fn test_partial_3_vars() {
    // Compares d/dx, d/dy and d/dz of `s` at `vars` with the first three `ref_vals`.
    fn check_all_partials(s: &str, vars: &[f64], ref_vals: &[f64]) {
        let dut = DeepEx::<f64>::parse(s).unwrap();
        for (var_idx, &expected) in ref_vals[..3].iter().enumerate() {
            let derivative = partial_deepex(var_idx, dut.clone(), MissingOpMode::Error).unwrap();
            assert_float_eq_f64(derivative.eval(vars).unwrap(), expected);
        }
    }
    check_all_partials("x+y+z", &[2345.3, 4523.5, 1.2], &[1.0, 1.0, 1.0]);
    check_all_partials(
        "x^2+y^2+z^2",
        &[2345.3, 4523.5, 1.2],
        &[2345.3 * 2.0, 4523.5 * 2.0, 2.4],
    );
}
831
/// d/dx (x * 2 * x) = 4x
#[test]
fn test_partial_x2x() {
    let expr = DeepEx::<f64>::parse("x * 2 * x").unwrap();
    let derivative = partial_deepex(0, expr, MissingOpMode::Error).unwrap();
    for (x, expected) in [(0.0, 0.0), (1.0, 4.0)] {
        assert_float_eq_f64(derivative.eval(&[x]).unwrap(), expected);
    }
}
841
/// d/dy cos(y)^2 = -sin(2y)
#[test]
fn test_partial_cos_squared() {
    let expr = DeepEx::<f64>::parse("cos(y) ^ 2").unwrap();
    let derivative = partial_deepex(0, expr, MissingOpMode::Error).unwrap();
    for (y, expected) in [(0.0, 0.0), (1.0, -0.9092974268256818)] {
        assert_float_eq_f64(derivative.eval(&[y]).unwrap(), expected);
    }
}
851
/// Multiplying two numeric expressions must fold into a single numeric node.
#[test]
fn test_num_ops() {
    let minus_one = DeepEx::<f64>::parse("-1").unwrap();
    let one = (minus_one.clone() * minus_one).unwrap();

    // one node and hence no binary operators
    assert_eq!(one.nodes().len(), 1);
    assert_eq!(one.bin_ops().ops.len(), 0);
    assert_eq!(one.bin_ops().reprs.len(), 0);

    assert_float_eq_f64(one.eval(&[]).unwrap(), 1.0);
}
872
/// Partial derivatives of `sin(x) + cos(y) ^ 2`:
/// d/dy = -2 cos(y) sin(y) = -sin(2y) and d/dx = cos(x).
#[test]
fn test_partial_combined() {
    let expr = DeepEx::<f64>::parse("sin(x) + cos(y) ^ 2").unwrap();
    // Differentiates w.r.t. `var_idx` and checks the derivative at each (point, expected) pair.
    let check = |var_idx: usize, cases: [([f64; 2], f64); 2]| {
        let derivative = partial_deepex(var_idx, expr.clone(), MissingOpMode::Error).unwrap();
        for (point, expected) in cases {
            assert_float_eq_f64(derivative.eval(&point).unwrap(), expected);
        }
    };
    check(1, [([231.431, 0.0], 0.0), ([-12.0, 1.0], -0.9092974268256818)]);
    check(0, [([231.431, 0.0], 0.5002954462477305), ([-12.0, 1.0], 0.8438539587324921)]);
}
887
/// d/dy (sin(x) + cos(y)) = -sin(y); the chosen values of `x` must not affect the result.
#[test]
fn test_partial_derivative_second_var() {
    let d_dy = partial_deepex(
        1,
        DeepEx::<f64>::parse("sin(x) + cos(y)").unwrap(),
        MissingOpMode::Error,
    )
    .unwrap();
    let cases = [([231.431, 0.0], 0.0), ([-12.0, 1.0], -0.8414709848078965)];
    for (point, expected) in cases.iter() {
        assert_float_eq_f64(d_dy.eval(point).unwrap(), *expected);
    }
}
897
/// d/dx (sin(x) + cos(y)) = cos(x), evaluated with arbitrary values of `y`.
#[test]
fn test_partial_derivative_first_var() {
    let expr = DeepEx::<f64>::parse("sin(x) + cos(y)").unwrap();
    let d_dx = partial_deepex(0, expr, MissingOpMode::Error).unwrap();
    [([0.0, 2345.03], 1.0), ([1.0, 43212.43], 0.5403023058681398)]
        .iter()
        .for_each(|(point, expected)| {
            assert_float_eq_f64(d_dx.eval(point).unwrap(), *expected);
        });
}
907
#[test]
fn test_partial_inner() {
    /// Differentiates `text` w.r.t. variable `var_idx` via `partial_derivative_inner`.
    /// Then compares the result at each single-variable point in `vals` with the matching
    /// entry of `ref_vals`.
    ///
    /// Panics if `vals` and `ref_vals` differ in length, so a malformed case cannot be
    /// silently truncated.
    fn test(text: &str, vals: &[f64], ref_vals: &[f64], var_idx: usize) {
        assert_eq!(
            vals.len(),
            ref_vals.len(),
            "each evaluation point needs exactly one reference value"
        );
        let partial_derivative_ops =
            make_partial_derivative_ops::<f64, FloatOpsFactory<f64>, NumberMatcher>();
        let deepex = DeepEx::<f64>::parse(text).unwrap();
        let deri = partial_derivative_inner(
            var_idx,
            deepex,
            &partial_derivative_ops,
            MissingOpMode::Error,
        )
        .unwrap();
        for (&val, &ref_val) in vals.iter().zip(ref_vals) {
            assert_float_eq_f64(deri.eval(&[val]).unwrap(), ref_val);
        }
    }
    // Reference values are the derivative of the argument of `sin`: 1 for `x`, 2x for `x^2`.
    test("sin(x)", &[1.0, 0.0, 2.0], &[1.0, 1.0, 1.0], 0);
    test("sin(x^2)", &[1.0, 0.0, 2.0], &[2.0, 0.0, 4.0], 0);
}
928
#[test]
fn test_partial_outer() {
    /// Parses `text` and takes its first node.
    ///
    /// If that node is a nested `DeepNode::Expr`, this applies `partial_derivative_outer`
    /// to it and compares the result at each single-variable point in `vals` with the
    /// matching entry of `ref_vals`.
    ///
    /// NOTE(review): a first node that is not `DeepNode::Expr` is skipped without any
    /// assertion. Confirm whether the `"x"` case below actually reaches the checks.
    fn test(text: &str, vals: &[f64], ref_vals: &[f64]) {
        assert_eq!(
            vals.len(),
            ref_vals.len(),
            "each evaluation point needs exactly one reference value"
        );
        let partial_derivative_ops =
            make_partial_derivative_ops::<f64, FloatOpsFactory<f64>, NumberMatcher>();
        let deepex = DeepEx::<f64>::parse(text).unwrap();
        let first_node = deepex.nodes()[0].clone();

        if let DeepNode::Expr(e) = first_node {
            let deri = partial_derivative_outer(*e, &partial_derivative_ops).unwrap();
            for (&val, &ref_val) in vals.iter().zip(ref_vals) {
                assert_float_eq_f64(deri.eval(&[val]).unwrap(), ref_val);
            }
        }
    }
    test("x", &[1.0, 0.0, 2.0], &[1.0, 0.0, 2.0]);
    // Outer derivative of sin evaluated at the argument: cos(1), cos(0), cos(2).
    test(
        "sin(x)",
        &[1.0, 0.0, 2.0],
        &[0.5403023058681398, 1.0, -0.4161468365471424],
    );
}
951
/// Derivatives of simple single-variable expressions: a constant, the variable itself,
/// `x^2`, and `sin(x)`.
#[test]
fn test_partial_derivative_simple() -> ExResult<()> {
    /// Asserts that `derivative` has no binary operators and consists of exactly one
    /// numeric node equal to `expected`.
    fn assert_single_num(
        derivative: &DeepEx<f64, FloatOpsFactory<f64>, NumberMatcher>,
        expected: f64,
    ) {
        assert_eq!(derivative.nodes().len(), 1);
        assert_eq!(derivative.bin_ops().ops.len(), 0);
        match derivative.nodes()[0] {
            DeepNode::Num(n) => assert_float_eq_f64(n, expected),
            // A non-numeric node here is a test failure, not an impossible state,
            // so report it explicitly instead of using `unreachable!()`.
            _ => panic!("expected the derivative to be a single numeric node"),
        }
    }

    // d/dx 1 = 0 and d/dx x = 1 are both expected to be a single constant node.
    let derivative = partial_deepex(0, DeepEx::<f64>::parse("1")?, MissingOpMode::Error)?;
    assert_single_num(&derivative, 0.0);
    let derivative = partial_deepex(0, DeepEx::<f64>::parse("x")?, MissingOpMode::Error)?;
    assert_single_num(&derivative, 1.0);

    // d/dx x^2 = 2x, i.e. 9 at x = 4.5.
    let derivative = partial_deepex(0, DeepEx::<f64>::parse("x^2")?, MissingOpMode::Error)?;
    assert_float_eq_f64(derivative.eval(&[4.5])?, 9.0);

    // d/dx sin(x) = cos(x).
    let derivative = partial_deepex(0, DeepEx::<f64>::parse("sin(x)")?, MissingOpMode::Error)?;
    assert_float_eq_f64(derivative.eval(&[0.0])?, 1.0);
    assert_float_eq_f64(derivative.eval(&[1.0])?, 0.5403023058681398);
    Ok(())
}