opensrdk_symbolic_computation/expression/variable.rs

1use crate::{Expression, Size, TensorExpression};
2use std::{collections::HashSet, iter::once};
3
4pub fn new_variable(id: String) -> Expression {
5    Expression::Variable(id, vec![])
6}
7
8impl Expression {
9    pub fn variable_ids(&self) -> HashSet<&str> {
10        match self {
11            Expression::Variable(id, _) => once(id.as_str()).collect::<HashSet<_>>(),
12            Expression::Constant(_) => HashSet::new(),
13            Expression::PartialVariable(v) => v
14                .elems()
15                .values()
16                .into_iter()
17                .flat_map(|v| v.variable_ids())
18                .collect(),
19            Expression::Add(l, r) => l
20                .variable_ids()
21                .into_iter()
22                .chain(r.variable_ids().into_iter())
23                .collect(),
24            Expression::Sub(l, r) => l
25                .variable_ids()
26                .into_iter()
27                .chain(r.variable_ids().into_iter())
28                .collect(),
29            Expression::Mul(l, r) => l
30                .variable_ids()
31                .into_iter()
32                .chain(r.variable_ids().into_iter())
33                .collect(),
34            Expression::Div(l, r) => l
35                .variable_ids()
36                .into_iter()
37                .chain(r.variable_ids().into_iter())
38                .collect(),
39            Expression::Neg(v) => v.variable_ids(),
40            Expression::Transcendental(v) => v.variable_ids(),
41            Expression::Tensor(v) => v.variable_ids(),
42            Expression::Matrix(v) => v.variable_ids(),
43        }
44    }
45
46    pub(crate) fn diff_variable(
47        symbol: &String,
48        sizes: &Vec<Size>,
49        variable_ids: &[&str],
50    ) -> Vec<Expression> {
51        let rank = sizes.len();
52        variable_ids
53            .iter()
54            .map(|&s| {
55                if s == symbol.as_str() {
56                    if rank == 0 {
57                        1.0.into()
58                    } else {
59                        TensorExpression::KroneckerDeltas(
60                            (0..rank).map(|r| [r, r + rank]).collect(),
61                        )
62                        .into()
63                    }
64                } else {
65                    0.0.into()
66                }
67            })
68            .collect()
69    }
70}
71
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use crate::new_variable;

    /// A lone variable reports exactly its own id.
    #[test]
    fn it_works() {
        let name = "x";
        let expected = HashSet::from([name]);
        let variable = new_variable(name.to_string());
        let actual = variable.variable_ids();

        assert_eq!(expected, actual);
    }

    /// Smoke test: differentiate `x * mu / sigma` with respect to each of its variables,
    /// plus one absent variable (`anpan`), and print the TeX of every derivative.
    #[test]
    fn it_works2() {
        let expression = new_variable("x".to_string()) * new_variable("mu".to_string())
            / new_variable("sigma".to_string());

        let tex_symbols = [("x", "x"), ("mu", r"\mu"), ("sigma", r"\Sigma")]
            .iter()
            .copied()
            .collect();

        for wrt in ["x", "mu", "sigma", "anpan"] {
            let derivative = expression.differential(&[wrt])[0].clone();
            println!("{:#?}", derivative.tex_code(&tex_symbols));
        }
    }
}