opensrdk_symbolic_computation/expression/
variable.rs1use crate::{Expression, Size, TensorExpression};
2use std::{collections::HashSet, iter::once};
3
4pub fn new_variable(id: String) -> Expression {
5 Expression::Variable(id, vec![])
6}
7
8impl Expression {
9 pub fn variable_ids(&self) -> HashSet<&str> {
10 match self {
11 Expression::Variable(id, _) => once(id.as_str()).collect::<HashSet<_>>(),
12 Expression::Constant(_) => HashSet::new(),
13 Expression::PartialVariable(v) => v
14 .elems()
15 .values()
16 .into_iter()
17 .flat_map(|v| v.variable_ids())
18 .collect(),
19 Expression::Add(l, r) => l
20 .variable_ids()
21 .into_iter()
22 .chain(r.variable_ids().into_iter())
23 .collect(),
24 Expression::Sub(l, r) => l
25 .variable_ids()
26 .into_iter()
27 .chain(r.variable_ids().into_iter())
28 .collect(),
29 Expression::Mul(l, r) => l
30 .variable_ids()
31 .into_iter()
32 .chain(r.variable_ids().into_iter())
33 .collect(),
34 Expression::Div(l, r) => l
35 .variable_ids()
36 .into_iter()
37 .chain(r.variable_ids().into_iter())
38 .collect(),
39 Expression::Neg(v) => v.variable_ids(),
40 Expression::Transcendental(v) => v.variable_ids(),
41 Expression::Tensor(v) => v.variable_ids(),
42 Expression::Matrix(v) => v.variable_ids(),
43 }
44 }
45
46 pub(crate) fn diff_variable(
47 symbol: &String,
48 sizes: &Vec<Size>,
49 variable_ids: &[&str],
50 ) -> Vec<Expression> {
51 let rank = sizes.len();
52 variable_ids
53 .iter()
54 .map(|&s| {
55 if s == symbol.as_str() {
56 if rank == 0 {
57 1.0.into()
58 } else {
59 TensorExpression::KroneckerDeltas(
60 (0..rank).map(|r| [r, r + rank]).collect(),
61 )
62 .into()
63 }
64 } else {
65 0.0.into()
66 }
67 })
68 .collect()
69 }
70}
71
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use crate::new_variable;

    /// A bare variable reports exactly its own id.
    #[test]
    fn it_works() {
        let id = "x";
        let expected: HashSet<&str> = HashSet::from([id]);

        // Keep the expression in a binding: `variable_ids` borrows from it.
        let variable = new_variable(id.to_string());
        let actual = variable.variable_ids();

        assert_eq!(expected, actual);
    }

    /// Differentiates `x * mu / sigma` with respect to each of its variables, plus one
    /// id ("anpan") that does not occur in it, and prints each result as TeX.
    #[test]
    fn it_works2() {
        let x = new_variable("x".to_string());
        let mu = new_variable("mu".to_string());
        let sigma = new_variable("sigma".to_string());
        let expression = x * mu / sigma;

        let diffs: Vec<_> = ["x", "mu", "sigma", "anpan"]
            .iter()
            .map(|&name| expression.differential(&[name])[0].clone())
            .collect();

        let tex_symbols = vec![("x", "x"), ("mu", r"\mu"), ("sigma", r"\Sigma")]
            .into_iter()
            .collect();

        for diff in &diffs {
            println!("{:#?}", diff.tex_code(&tex_symbols));
        }
    }
}