opensrdk_symbolic_computation/expression_array/
mod.rs

1pub mod index;
2
3pub use index::*;
4use opensrdk_linear_algebra::indices_cartesian_product;
5
6use crate::Expression;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
11pub struct ExpressionArray {
12    sizes: Vec<usize>,
13    elems: HashMap<Vec<usize>, Expression>,
14    default: Box<Expression>,
15}
16
17impl ExpressionArray {
18    pub fn new(sizes: Vec<usize>) -> Self {
19        Self {
20            sizes,
21            elems: HashMap::new(),
22            default: Box::new(0.0.into()),
23        }
24    }
25
26    pub fn from_factory(sizes: Vec<usize>, factory: impl Fn(&[usize]) -> Expression) -> Self {
27        let mut elems = HashMap::new();
28        let indices = indices_cartesian_product(&sizes);
29
30        indices.iter().for_each(|index| {
31            elems.insert(index.clone(), factory(&index));
32        });
33
34        Self {
35            sizes,
36            elems,
37            default: Box::new(0.0.into()),
38        }
39    }
40
41    pub fn sizes(&self) -> &[usize] {
42        &self.sizes
43    }
44
45    pub fn elems(&self) -> &HashMap<Vec<usize>, Expression> {
46        &self.elems
47    }
48
49    pub fn eject(self) -> (Vec<usize>, HashMap<Vec<usize>, Expression>) {
50        (self.sizes, self.elems)
51    }
52}
53
#[cfg(test)]
mod tests {
    use crate::{Expression, ExpressionArray};

    /// `from_factory` keeps the sizes it was given and stores, under each
    /// index, exactly the expression the factory returns for that index.
    #[test]
    fn it_works1() {
        let test_orig = vec![
            Expression::from(1f64),
            Expression::from(3f64),
            Expression::from(1f64),
            Expression::from(2f64),
        ];
        // The factory looks only at the leading index component.
        let factory = |i: &[usize]| test_orig[i[0]].clone();
        let sizes = vec![1usize, 2usize, 3usize, 4usize];
        let test = ExpressionArray::from_factory(sizes.clone(), factory);
        println!("{:?}", test);

        assert_eq!(test.sizes(), sizes.as_slice());
        // These assertions hold whatever index set the cartesian-product
        // helper yields: each stored value must match the factory's output
        // for its own key.
        assert!(!test.elems().is_empty());
        for (index, value) in test.elems() {
            assert_eq!(value, &test_orig[index[0]]);
        }
    }
}