// opensrdk_symbolic_computation/expression/tensor_expression/operations/direct.rs

use std::{collections::HashMap, iter::once};

use opensrdk_linear_algebra::sparse::SparseTensor;

use crate::{BracketsLevel, Expression, Size, TensorExpression};
5
6pub trait DirectProduct {
7    fn direct_product(self) -> Expression;
8}
9
10impl<I> DirectProduct for I
11where
12    I: Iterator<Item = Expression>,
13{
14    fn direct_product(self) -> Expression {
15        TensorExpression::DirectProduct(self.collect()).into()
16    }
17}
18
19impl Expression {
20    pub fn direct(self, rhs: Expression) -> Expression {
21        vec![self, rhs].into_iter().direct_product()
22    }
23}
24
25impl TensorExpression {
26    pub(crate) fn diff_direct_product(
27        terms: &Vec<Expression>,
28        symbols: &[&str],
29    ) -> Vec<Expression> {
30        let terms_len = terms.len();
31        let symbols_len = symbols.len();
32
33        let result = (0..terms_len)
34            .map(|i| {
35                let elems_left = (0..i).map(|j| terms[j].clone()).direct_product();
36                let elems_right = (i + 1..terms_len)
37                    .map(|k| terms[k].clone())
38                    .direct_product();
39                let elem_diff = terms[i].differential(symbols);
40
41                let elems = (0..symbols_len)
42                    .map(|l| {
43                        elems_left
44                            .clone()
45                            .direct(elem_diff[l].clone())
46                            .direct(elems_right.clone())
47                    })
48                    .collect::<Vec<Expression>>();
49                elems
50            })
51            .fold(vec![Expression::from(0f64); symbols_len], |sum, x| {
52                let result_orig = (0..symbols_len)
53                    .map(|m| sum[m].clone() + x[m].clone())
54                    .collect::<Vec<Expression>>();
55                result_orig
56            });
57        result
58    }
59
60    pub(crate) fn tex_code_direct_product(
61        terms: &Vec<Expression>,
62        symbols: &HashMap<&str, &str>,
63        brackets_level: BracketsLevel,
64    ) -> String {
65        let inner = terms
66            .into_iter()
67            .map(|t| t.tex_code(symbols))
68            .collect::<Vec<_>>()
69            .join(r" \otimes ");
70
71        match brackets_level {
72            BracketsLevel::None => inner,
73            BracketsLevel::ForMul | BracketsLevel::ForDiv | BracketsLevel::ForOperation => {
74                format!(r"\left({}\right)", inner)
75            }
76        }
77    }
78
79    pub(crate) fn size_direct_product(terms: &Vec<Expression>) -> Vec<Size> {
80        terms
81            .into_iter()
82            .map(|t| t.sizes())
83            .fold(vec![], |mut acc, next| {
84                if acc.len() < next.len() {
85                    for i in 0..acc.len() {
86                        if next[i] == Size::Many {
87                            acc[i] = next[i];
88                        }
89                    }
90                    acc.extend(next[acc.len()..].iter().copied());
91                } else {
92                    for i in 0..next.len() {
93                        if next[i] == Size::Many {
94                            acc[i] = next[i];
95                        }
96                    }
97                }
98                acc
99            })
100    }
101}
102
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use opensrdk_linear_algebra::sparse::SparseTensor;

    use crate::{new_variable, Expression, TensorExpression};

    /// Builds the two constant operands shared by the tests: rank-8 sparse
    /// tensors with every dimension of size 6 and a few non-zero entries on
    /// the diagonal.
    fn sample_operands() -> (Expression, Expression) {
        let mut hash1 = HashMap::new();
        hash1.insert(vec![3usize; 8], 2.0);
        hash1.insert(vec![1usize; 8], 3.0);
        hash1.insert(vec![4usize; 8], 4.0);
        hash1.insert(vec![5usize; 8], 2.0);
        let a = SparseTensor::from(vec![6usize; 8], hash1).unwrap();

        let mut hash2 = HashMap::new();
        hash2.insert(vec![3usize; 8], 2.0);
        hash2.insert(vec![2usize; 8], 3.0);
        hash2.insert(vec![4usize; 8], 1.0);
        let b = SparseTensor::from(vec![6usize; 8], hash2).unwrap();

        (Expression::from(a), Expression::from(b))
    }

    /// Smoke test: building a direct product of two tensors must not panic.
    #[test]
    fn it_works() {
        let (ea, eb) = sample_operands();

        let dp = ea.direct(eb);

        println!("{:?}", dp);
    }

    /// Differentiating a four-factor product yields one derivative per symbol,
    /// and each derivative can be rendered as TeX.
    #[test]
    fn it_works1() {
        let (ea, eb) = sample_operands();

        let id = "x";
        let ec = new_variable((id).to_string());

        let ids = &["x", "y"];

        let diff_dp = TensorExpression::diff_direct_product(&vec![ec.clone(), ea, eb, ec], ids);
        println!("{:?}", diff_dp);

        // One derivative per requested symbol.
        assert_eq!(diff_dp.len(), ids.len());

        let tex_symbols: HashMap<&str, &str> = vec![("x", "y")].into_iter().collect();
        let tex_x = diff_dp[0].tex_code(&tex_symbols);
        let tex_y = diff_dp[1].tex_code(&tex_symbols);

        println!("{:?}", tex_x);
        println!("{:?}", tex_y);
    }
}