opensrdk_symbolic_computation/expression/tensor_expression/operations/
direct.rs1use opensrdk_linear_algebra::sparse::SparseTensor;
2
3use crate::{BracketsLevel, Expression, Size, TensorExpression};
4use std::{collections::HashMap, iter::once};
5
6pub trait DirectProduct {
7 fn direct_product(self) -> Expression;
8}
9
10impl<I> DirectProduct for I
11where
12 I: Iterator<Item = Expression>,
13{
14 fn direct_product(self) -> Expression {
15 TensorExpression::DirectProduct(self.collect()).into()
16 }
17}
18
19impl Expression {
20 pub fn direct(self, rhs: Expression) -> Expression {
21 vec![self, rhs].into_iter().direct_product()
22 }
23}
24
25impl TensorExpression {
26 pub(crate) fn diff_direct_product(
27 terms: &Vec<Expression>,
28 symbols: &[&str],
29 ) -> Vec<Expression> {
30 let terms_len = terms.len();
31 let symbols_len = symbols.len();
32
33 let result = (0..terms_len)
34 .map(|i| {
35 let elems_left = (0..i).map(|j| terms[j].clone()).direct_product();
36 let elems_right = (i + 1..terms_len)
37 .map(|k| terms[k].clone())
38 .direct_product();
39 let elem_diff = terms[i].differential(symbols);
40
41 let elems = (0..symbols_len)
42 .map(|l| {
43 elems_left
44 .clone()
45 .direct(elem_diff[l].clone())
46 .direct(elems_right.clone())
47 })
48 .collect::<Vec<Expression>>();
49 elems
50 })
51 .fold(vec![Expression::from(0f64); symbols_len], |sum, x| {
52 let result_orig = (0..symbols_len)
53 .map(|m| sum[m].clone() + x[m].clone())
54 .collect::<Vec<Expression>>();
55 result_orig
56 });
57 result
58 }
59
60 pub(crate) fn tex_code_direct_product(
61 terms: &Vec<Expression>,
62 symbols: &HashMap<&str, &str>,
63 brackets_level: BracketsLevel,
64 ) -> String {
65 let inner = terms
66 .into_iter()
67 .map(|t| t.tex_code(symbols))
68 .collect::<Vec<_>>()
69 .join(r" \otimes ");
70
71 match brackets_level {
72 BracketsLevel::None => inner,
73 BracketsLevel::ForMul | BracketsLevel::ForDiv | BracketsLevel::ForOperation => {
74 format!(r"\left({}\right)", inner)
75 }
76 }
77 }
78
79 pub(crate) fn size_direct_product(terms: &Vec<Expression>) -> Vec<Size> {
80 terms
81 .into_iter()
82 .map(|t| t.sizes())
83 .fold(vec![], |mut acc, next| {
84 if acc.len() < next.len() {
85 for i in 0..acc.len() {
86 if next[i] == Size::Many {
87 acc[i] = next[i];
88 }
89 }
90 acc.extend(next[acc.len()..].iter().copied());
91 } else {
92 for i in 0..next.len() {
93 if next[i] == Size::Many {
94 acc[i] = next[i];
95 }
96 }
97 }
98 acc
99 })
100 }
101}
102
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use opensrdk_linear_algebra::sparse::SparseTensor;

    use crate::{new_variable, Expression, TensorExpression};

    /// Rank (number of axes) of the sample tensors used by these tests.
    const RANK: usize = 8;
    /// Length of every axis of the sample tensors.
    const DIM: usize = 6;

    /// Builds a sparse `DIM^RANK` tensor whose only non-zero entries are at the diagonal
    /// indices `[i; RANK]` with the given values, wrapped in an `Expression`.
    fn diagonal_tensor(entries: &[(usize, f64)]) -> Expression {
        let elems = entries
            .iter()
            .map(|&(i, value)| (vec![i; RANK], value))
            .collect::<HashMap<_, _>>();
        Expression::from(SparseTensor::from(vec![DIM; RANK], elems).unwrap())
    }

    #[test]
    fn it_works() {
        let ea = diagonal_tensor(&[(3, 2.0), (1, 3.0), (4, 4.0), (5, 2.0)]);
        let eb = diagonal_tensor(&[(3, 2.0), (2, 3.0), (4, 1.0)]);

        let dp = ea.direct(eb);

        println!("{:?}", dp);
    }

    #[test]
    fn it_works1() {
        let ea = diagonal_tensor(&[(3, 2.0), (1, 3.0), (4, 4.0), (5, 2.0)]);
        let eb = diagonal_tensor(&[(3, 2.0), (2, 3.0), (4, 1.0)]);
        let ec = new_variable("x".to_string());

        let ids = &["x", "y"];

        let diff_dp = TensorExpression::diff_direct_product(&vec![ec.clone(), ea, eb, ec], ids);
        println!("{:?}", diff_dp);

        // Exactly one partial derivative is produced per differentiation symbol.
        assert_eq!(diff_dp.len(), ids.len());

        let tex_symbols: HashMap<&str, &str> = vec![("x", "y")].into_iter().collect();
        let tex_x = diff_dp[0].tex_code(&tex_symbols);
        let tex_y = diff_dp[1].tex_code(&tex_symbols);

        println!("{:?}", tex_x);
        println!("{:?}", tex_y);
    }
}
172}