opensrdk_symbolic_computation/expression/
assign.rs1use crate::{AbstractSize, ConstantValue, Expression, ExpressionArray};
2use std::collections::HashMap;
3
4impl Expression {
5 pub fn assign(self, variables: &HashMap<&str, ConstantValue>) -> Expression {
6 match self {
7 Expression::Variable(id, sizes) => {
8 let v = variables.get(id.as_str());
9
10 match v {
11 Some(v) => {
12 if sizes != v.sizes().into_abstract_size() {
13 panic!("Variable {} has sizes {:?} but is assigned a value with sizes {:?}", id, sizes, v.sizes());
14 }
15 v.clone().into()
16 }
17 None => Expression::Variable(id.clone(), sizes.clone()),
18 }
19 }
20 Expression::Constant(_) => self,
21 Expression::PartialVariable(v) => Expression::PartialVariable(
22 ExpressionArray::from_factory(v.sizes().to_vec(), |indices| {
23 v[indices].clone().assign(variables)
24 }),
25 ),
26
27 Expression::Add(l, r) => l.assign(variables) + r.assign(variables),
28 Expression::Sub(l, r) => l.assign(variables) - r.assign(variables),
29 Expression::Mul(l, r) => l.assign(variables) * r.assign(variables),
30 Expression::Div(l, r) => l.assign(variables) / r.assign(variables),
31 Expression::Neg(v) => -v.assign(variables),
32 Expression::Transcendental(v) => v.assign(variables),
33 Expression::Tensor(v) => v.assign(variables),
34 Expression::Matrix(v) => v.assign(variables),
35 }
36 }
37}
38
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use opensrdk_linear_algebra::sparse::SparseTensor;

    use crate::{new_variable, new_variable_tensor, AbstractSize, ConstantValue, Expression};

    /// A scalar variable with an entry in the map is replaced by that constant.
    #[test]
    fn it_works1() {
        let id = "x";
        let ea = new_variable(id.to_string());
        let mut hash1 = HashMap::new();
        hash1.insert(id, ConstantValue::Scalar(2.0));
        let result = ea.assign(&hash1);
        assert_eq!(result, Expression::from(ConstantValue::Scalar(2.0)))
    }

    /// A tensor variable is replaced by a sparse tensor constant of matching sizes.
    #[test]
    fn it_works2() {
        let id = "x";

        let mut hash = HashMap::new();
        hash.insert(vec![3usize; 8], 2.0);
        hash.insert(vec![1usize; 8], 3.0);
        hash.insert(vec![4usize; 8], 4.0);
        hash.insert(vec![5usize; 8], 2.0);
        let c = SparseTensor::from(vec![6usize; 8], hash.clone()).unwrap();

        let mut hash1 = HashMap::new();
        hash1.insert(id, ConstantValue::Tensor(c));

        let ec = new_variable_tensor(id.to_string(), [6usize; 8].into_abstract_size());

        let result = ec.assign(&hash1);
        assert_eq!(
            result,
            Expression::from(ConstantValue::Tensor(
                SparseTensor::from(vec![6usize; 8], hash).unwrap()
            ))
        )
    }

    /// Assigning every variable inside transcendental sub-expressions changes the expression.
    #[test]
    fn it_works3() {
        let x = new_variable("x".to_string());
        let y = new_variable("y".to_string());

        let expression = x.sin() + y.cos().exp();

        let mut theta_map = HashMap::new();
        theta_map.insert("x", ConstantValue::Scalar(3f64));
        theta_map.insert("y", ConstantValue::Scalar(7f64));

        let assigned = expression.clone().assign(&theta_map);

        println!("{:#?}", expression);
        println!("{:#?}", assigned);

        // Every variable has a value, so substitution must produce a different tree.
        assert_ne!(assigned, expression);
    }

    /// Substitution reaches variables nested under mixed arithmetic and transcendental nodes.
    #[test]
    fn it_works4() {
        let a = new_variable("a".to_string());
        let b = new_variable("b".to_string());
        let c = new_variable("c".to_string());
        let d = new_variable("d".to_string());
        let e = new_variable("e".to_string());
        let f = new_variable("f".to_string());

        let add_1 = a.sin() + b.cos().exp();
        // `add_1` is inspected again below, so it is cloned into the product.
        let add_2 = add_1.clone() * c;
        let add_3 = d.sin() - e.cos().exp();
        let add_4 = add_3 / f;
        let add_5 = add_2 + add_4;

        let mut theta_map = HashMap::new();
        theta_map.insert("a", ConstantValue::Scalar(3f64));
        theta_map.insert("b", ConstantValue::Scalar(7f64));
        theta_map.insert("c", ConstantValue::Scalar(-3f64));
        theta_map.insert("d", ConstantValue::Scalar(-7f64));
        theta_map.insert("e", ConstantValue::Scalar(4f64));
        theta_map.insert("f", ConstantValue::Scalar(6f64));

        let assigned_1 = add_1.clone().assign(&theta_map);
        let assigned_5 = add_5.clone().assign(&theta_map);

        println!("{:#?}", add_1);
        println!("{:#?}", assigned_1);
        println!("{:#?}", add_5);
        println!("{:#?}", assigned_5);

        // All six variables have values, so both trees must change under substitution.
        assert_ne!(assigned_1, add_1);
        assert_ne!(assigned_5, add_5);
    }
}