opensrdk_symbolic_computation/expression/
assign.rs

1use crate::{AbstractSize, ConstantValue, Expression, ExpressionArray};
2use std::collections::HashMap;
3
4impl Expression {
5    pub fn assign(self, variables: &HashMap<&str, ConstantValue>) -> Expression {
6        match self {
7            Expression::Variable(id, sizes) => {
8                let v = variables.get(id.as_str());
9
10                match v {
11                    Some(v) => {
12                        if sizes != v.sizes().into_abstract_size() {
13                            panic!("Variable {} has sizes {:?} but is assigned a value with sizes {:?}", id, sizes, v.sizes());
14                        }
15                        v.clone().into()
16                    }
17                    None => Expression::Variable(id.clone(), sizes.clone()),
18                }
19            }
20            Expression::Constant(_) => self,
21            Expression::PartialVariable(v) => Expression::PartialVariable(
22                ExpressionArray::from_factory(v.sizes().to_vec(), |indices| {
23                    v[indices].clone().assign(variables)
24                }),
25            ),
26            
27            Expression::Add(l, r) => l.assign(variables) + r.assign(variables),
28            Expression::Sub(l, r) => l.assign(variables) - r.assign(variables),
29            Expression::Mul(l, r) => l.assign(variables) * r.assign(variables),
30            Expression::Div(l, r) => l.assign(variables) / r.assign(variables),
31            Expression::Neg(v) => -v.assign(variables),
32            Expression::Transcendental(v) => v.assign(variables),
33            Expression::Tensor(v) => v.assign(variables),
34            Expression::Matrix(v) => v.assign(variables),
35        }
36    }
37}
38
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use opensrdk_linear_algebra::sparse::SparseTensor;

    use crate::{
        new_partial_variable, new_variable, new_variable_tensor, AbstractSize, ConstantValue,
        Expression, ExpressionArray,
    };

    /// A scalar variable with an entry in the map is replaced by that constant.
    #[test]
    fn it_works1() {
        let name = "x";
        let variable = new_variable(name.to_string());

        let mut values = HashMap::new();
        values.insert(name, ConstantValue::Scalar(2.0));

        let assigned = variable.assign(&values);
        assert_eq!(assigned, Expression::from(ConstantValue::Scalar(2.0)))
    }

    /// A tensor variable is replaced by a sparse tensor of the same sizes.
    #[test]
    fn it_works2() {
        let name = "x";
        let shape = vec![6usize; 8];

        // Each entry sits at an index that repeats one coordinate along all 8 axes.
        let entries: HashMap<_, _> = vec![(3usize, 2.0), (1, 3.0), (4, 4.0), (5, 2.0)]
            .into_iter()
            .map(|(coordinate, value)| (vec![coordinate; 8], value))
            .collect();
        let tensor = SparseTensor::from(shape.clone(), entries.clone()).unwrap();

        let mut values = HashMap::new();
        values.insert(name, ConstantValue::Tensor(tensor));

        let variable = new_variable_tensor(name.to_string(), [6usize; 8].into_abstract_size());

        let assigned = variable.assign(&values);
        assert_eq!(
            assigned,
            Expression::from(ConstantValue::Tensor(
                SparseTensor::from(shape, entries).unwrap()
            ))
        )
    }

    /// Prints a transcendental expression before and after assigning both variables.
    #[test]
    fn it_works3() {
        let x = new_variable("x".to_string());
        let y = new_variable("y".to_string());

        let expression = x.sin() + y.cos().exp();

        let mut values = HashMap::new();
        values.insert("x", ConstantValue::Scalar(3f64));
        values.insert("y", ConstantValue::Scalar(7f64));

        println!("{:#?}", expression);
        println!("{:#?}", expression.assign(&values));
    }

    /// Prints nested arithmetic expressions before and after assigning all variables.
    #[test]
    fn it_works4() {
        let a = new_variable("a".to_string());
        let b = new_variable("b".to_string());
        let c = new_variable("c".to_string());
        let d = new_variable("d".to_string());
        let e = new_variable("e".to_string());
        let f = new_variable("f".to_string());

        // `sum_ab` is printed and assigned on its own below, so only this use clones it.
        let sum_ab = a.sin() + b.cos().exp();
        let product = sum_ab.clone() * c;
        let difference = d.sin() - e.cos().exp();
        let quotient = difference / f;
        let total = product + quotient;

        let values: HashMap<_, _> = vec![
            ("a", 3f64),
            ("b", 7f64),
            ("c", -3f64),
            ("d", -7f64),
            ("e", 4f64),
            ("f", 6f64),
        ]
        .into_iter()
        .map(|(name, value)| (name, ConstantValue::Scalar(value)))
        .collect();

        println!("{:#?}", sum_ab);
        println!("{:#?}", sum_ab.assign(&values));
        println!("{:#?}", total);
        println!("{:#?}", total.assign(&values));
    }
}