opensrdk_symbolic_computation/expression/tensor_expression/operations/
dot.rs1use crate::{BracketsLevel, Expression, ExpressionArray, Size, TensorExpression};
2use opensrdk_linear_algebra::{generate_rank_combinations, RankIndex};
3use std::{collections::HashMap, iter::once};
4
/// Position of a term inside a dot product's `terms` vector.
type TermIndex = usize;
6
7fn next_char(c: char, count: usize) -> char {
8 std::char::from_u32(c as u32 + count as u32).unwrap_or(c)
9}
10
11pub trait DotProduct {
12 fn dot_product(self, rank_combinations: &[HashMap<RankIndex, String>]) -> Expression;
13}
14
15impl<I> DotProduct for I
16where
17 I: Iterator<Item = Expression>,
18{
19 fn dot_product(self, rank_combinations: &[HashMap<RankIndex, String>]) -> Expression {
20 let terms = self
22 .zip(rank_combinations.iter())
23 .flat_map(|(t, rank_combination)| {
24 if let Expression::Tensor(t) = &t {
25 if let TensorExpression::DotProduct {
26 terms: t,
27 rank_combinations,
28 } = t.as_ref()
29 {
30 let t = t.clone();
31 let mut rank_combinations = rank_combinations.clone();
32 let not_1dimension_ranks =
33 TensorExpression::not_1dimension_ranks_in_dot_product(
34 &t,
35 &rank_combinations,
36 );
37
38 for (&rank, id) in rank_combination.iter() {
39 if let Some(&term_index) = not_1dimension_ranks.get(&rank) {
40 rank_combinations[term_index].insert(rank, id.to_owned());
41 }
42 }
43
44 return t
45 .into_iter()
46 .zip(rank_combinations.into_iter())
47 .collect::<Vec<_>>();
48 }
49 }
50
51 vec![(t, rank_combination.clone())]
52 })
53 .collect::<Vec<_>>();
54
55 if terms.iter().find(|&t| &t.0 == &0.0.into()).is_some() {
56 return 0.0.into();
57 }
58
59 let deltas = terms
61 .iter()
62 .filter_map(|(t, r)| {
63 if let Expression::Tensor(t) = t {
64 if let TensorExpression::KroneckerDeltas(rank_pairs) = t.as_ref() {
65 return Some((rank_pairs.clone(), r));
66 }
67 }
68
69 None
70 })
71 .collect::<Vec<_>>();
72 let not_deltas = terms
73 .iter()
74 .filter(|(t, _)| {
75 if let Expression::Tensor(t) = t {
76 if let &TensorExpression::KroneckerDeltas(_) = t.as_ref() {
77 return false;
78 }
79 }
80
81 true
82 })
83 .collect::<Vec<_>>();
84
85 let flatten_deltas = deltas
86 .iter()
87 .map(|(t, _)| t)
88 .flatten()
89 .cloned()
90 .collect::<Vec<_>>();
91 let flatten_deltas_combination = deltas
92 .iter()
93 .flat_map(|(_, r)| r.iter())
94 .map(|(&rank, id)| (rank, id.to_owned()))
95 .collect::<HashMap<_, _>>();
96
97 let mut new_terms = not_deltas
98 .iter()
99 .map(|(t, _)| t.clone())
100 .collect::<Vec<_>>();
101 let mut new_rank_combinations = not_deltas
102 .iter()
103 .map(|&(_, r)| r.clone())
104 .collect::<Vec<_>>();
105
106 if flatten_deltas.len() > 0 {
109 let merged_deltas = TensorExpression::KroneckerDeltas(flatten_deltas);
110
111 new_terms.insert(0, merged_deltas.into());
112 new_rank_combinations.insert(0, flatten_deltas_combination);
113 }
114
115 TensorExpression::DotProduct {
116 terms: new_terms,
117 rank_combinations: new_rank_combinations,
118 }
119 .into()
120 }
121}
122
123impl Expression {
124 pub fn dot(self, rhs: Expression, rank_pairs: &[[RankIndex; 2]]) -> Expression {
125 if let (Expression::PartialVariable(vl), Expression::PartialVariable(vr)) = (&self, &rhs) {
126 return Expression::PartialVariable(ExpressionArray::from_factory(
131 vr.sizes().to_vec(),
132 |indices| {
133 vec![vl[indices].clone(), vr[indices].clone()]
134 .into_iter()
135 .dot_product(&generate_rank_combinations(rank_pairs))
136 },
137 ));
138 }
139
140 vec![self, rhs]
141 .into_iter()
142 .dot_product(&generate_rank_combinations(rank_pairs))
143 }
144}
145
146impl TensorExpression {
147 pub(crate) fn diff_dot_product(
148 terms: &Vec<Expression>,
149 rank_combinations: &Vec<HashMap<RankIndex, String>>,
150 symbols: &[&str],
151 ) -> Vec<Expression> {
152 let mut result = terms[0]
153 .differential(symbols)
154 .into_iter()
155 .map(|d| {
156 once(d)
157 .chain(terms[1..].iter().cloned())
158 .dot_product(rank_combinations)
159 })
160 .collect::<Vec<_>>();
161
162 for i in 1..terms.len() {
163 result
164 .iter_mut()
165 .zip(terms[i].differential(symbols).into_iter())
166 .for_each(|(r, d)| {
167 *r = r.clone()
168 + terms[0..i]
169 .iter()
170 .cloned()
171 .chain(once(d))
172 .chain(terms[i + 1..].iter().cloned())
173 .dot_product(rank_combinations);
174 });
175 }
176
177 result
178 }
179
180 pub(crate) fn tex_code_dot_product(
181 terms: &Vec<Expression>,
182 rank_combinations: &Vec<HashMap<RankIndex, String>>,
183 symbols: &HashMap<&str, &str>,
184 ) -> String {
185 let mut ids = Vec::<String>::new();
186 let mut id_index = HashMap::<String, usize>::new();
187
188 for i in 0..terms.len() {
189 for (_, id) in rank_combinations[i].iter() {
190 if !id_index.contains_key(id) {
191 ids.push(id.clone());
192 id_index.insert(id.clone(), ids.len() - 1);
193 }
194 }
195 }
196
197 let mut result = String::new();
198 result.push_str(&format!(
199 r"\sum_{{{}}}",
200 ids.iter()
201 .enumerate()
202 .map(|(k, _)| format!("{}", next_char('i', k)))
203 .collect::<Vec<_>>()
204 .join(", ")
205 ));
206
207 for i in 0..terms.len() {
208 let mut sorted = rank_combinations[i].iter().collect::<Vec<_>>();
209 sorted.sort_by(|a, b| a.0.cmp(b.0));
210 result.push_str(&format!(
211 "{}_{{{}}}",
212 terms[i]._tex_code(symbols, BracketsLevel::ForMul),
213 sorted
214 .into_iter()
215 .map(|(j, id)| format!("[{}] = {}", j, next_char('i', id_index[id])))
216 .collect::<Vec<_>>()
217 .join(", ")
218 ));
219 }
220
221 format!("{{{}}}", result)
222 }
223
224 pub(crate) fn size_dot_product(
225 terms: &Vec<Expression>,
226 rank_combinations: &Vec<HashMap<RankIndex, String>>,
227 ) -> Vec<Size> {
228 let max_rank = terms.iter().map(|vi| vi.sizes().len()).max().unwrap();
229 let mut sizes = vec![Size::One; max_rank];
230
231 for i in 0..terms.len() {
232 let term_sizes = terms[i].sizes();
233
234 for (rank, size) in term_sizes.iter().enumerate() {
235 if sizes[rank] == Size::Many {
236 continue;
237 }
238 if let Some(_) = rank_combinations[i].get(&rank) {
239 continue;
240 }
241 sizes.insert(rank, size.clone());
242 }
243 }
244
245 sizes
246 }
247
248 pub fn not_1dimension_ranks_in_dot_product(
249 terms: &Vec<Expression>,
250 rank_combinations: &Vec<HashMap<RankIndex, String>>,
251 ) -> HashMap<RankIndex, TermIndex> {
252 let mut not_1dimension_ranks = HashMap::new();
253
254 for i in 0..terms.len() {
255 let term_sizes = terms[i].sizes();
256 for (rank, size) in term_sizes.iter().enumerate() {
257 if let Some(_) = rank_combinations[i].get(&rank) {
258 continue;
259 }
260
261 if *size != Size::One {
262 if not_1dimension_ranks.contains_key(&rank) {
263 panic!(
264 "Rank {} is not 1-dimension in terms[{}] and terms[{}]",
265 rank,
266 not_1dimension_ranks.get(&rank).unwrap(),
267 i
268 );
269 }
270 not_1dimension_ranks.insert(rank, i);
271 }
272 }
273 }
274
275 not_1dimension_ranks
276 }
277}