// opensrdk_symbolic_computation/expression/tensor_expression/operations/dot.rs

1use crate::{BracketsLevel, Expression, ExpressionArray, Size, TensorExpression};
2use opensrdk_linear_algebra::{generate_rank_combinations, RankIndex};
3use std::{collections::HashMap, iter::once};
4
/// Position of a term inside the `terms` vector of a dot product.
type TermIndex = usize;

/// Returns the character `count` code points after `c`.
///
/// If the shifted code point is not a valid `char` (surrogate range or past
/// `U+10FFFF`), or `count` is so large that the addition would overflow, `c`
/// itself is returned. The original `c as u32 + count as u32` silently
/// truncated `count` and could overflow (panicking in debug builds).
fn next_char(c: char, count: usize) -> char {
    (c as u64)
        .checked_add(count as u64)
        .filter(|&code| code <= u64::from(u32::MAX))
        .and_then(|code| std::char::from_u32(code as u32))
        .unwrap_or(c)
}
10
11pub trait DotProduct {
12    fn dot_product(self, rank_combinations: &[HashMap<RankIndex, String>]) -> Expression;
13}
14
15impl<I> DotProduct for I
16where
17    I: Iterator<Item = Expression>,
18{
19    fn dot_product(self, rank_combinations: &[HashMap<RankIndex, String>]) -> Expression {
20        // Flatten InnerProd
21        let terms = self
22            .zip(rank_combinations.iter())
23            .flat_map(|(t, rank_combination)| {
24                if let Expression::Tensor(t) = &t {
25                    if let TensorExpression::DotProduct {
26                        terms: t,
27                        rank_combinations,
28                    } = t.as_ref()
29                    {
30                        let t = t.clone();
31                        let mut rank_combinations = rank_combinations.clone();
32                        let not_1dimension_ranks =
33                            TensorExpression::not_1dimension_ranks_in_dot_product(
34                                &t,
35                                &rank_combinations,
36                            );
37
38                        for (&rank, id) in rank_combination.iter() {
39                            if let Some(&term_index) = not_1dimension_ranks.get(&rank) {
40                                rank_combinations[term_index].insert(rank, id.to_owned());
41                            }
42                        }
43
44                        return t
45                            .into_iter()
46                            .zip(rank_combinations.into_iter())
47                            .collect::<Vec<_>>();
48                    }
49                }
50
51                vec![(t, rank_combination.clone())]
52            })
53            .collect::<Vec<_>>();
54
55        if terms.iter().find(|&t| &t.0 == &0.0.into()).is_some() {
56            return 0.0.into();
57        }
58
59        // Merge KroneckerDeltas
60        let deltas = terms
61            .iter()
62            .filter_map(|(t, r)| {
63                if let Expression::Tensor(t) = t {
64                    if let TensorExpression::KroneckerDeltas(rank_pairs) = t.as_ref() {
65                        return Some((rank_pairs.clone(), r));
66                    }
67                }
68
69                None
70            })
71            .collect::<Vec<_>>();
72        let not_deltas = terms
73            .iter()
74            .filter(|(t, _)| {
75                if let Expression::Tensor(t) = t {
76                    if let &TensorExpression::KroneckerDeltas(_) = t.as_ref() {
77                        return false;
78                    }
79                }
80
81                true
82            })
83            .collect::<Vec<_>>();
84
85        let flatten_deltas = deltas
86            .iter()
87            .map(|(t, _)| t)
88            .flatten()
89            .cloned()
90            .collect::<Vec<_>>();
91        let flatten_deltas_combination = deltas
92            .iter()
93            .flat_map(|(_, r)| r.iter())
94            .map(|(&rank, id)| (rank, id.to_owned()))
95            .collect::<HashMap<_, _>>();
96
97        let mut new_terms = not_deltas
98            .iter()
99            .map(|(t, _)| t.clone())
100            .collect::<Vec<_>>();
101        let mut new_rank_combinations = not_deltas
102            .iter()
103            .map(|&(_, r)| r.clone())
104            .collect::<Vec<_>>();
105
106        // TODO: Merge constants
107
108        if flatten_deltas.len() > 0 {
109            let merged_deltas = TensorExpression::KroneckerDeltas(flatten_deltas);
110
111            new_terms.insert(0, merged_deltas.into());
112            new_rank_combinations.insert(0, flatten_deltas_combination);
113        }
114
115        TensorExpression::DotProduct {
116            terms: new_terms,
117            rank_combinations: new_rank_combinations,
118        }
119        .into()
120    }
121}
122
123impl Expression {
124    pub fn dot(self, rhs: Expression, rank_pairs: &[[RankIndex; 2]]) -> Expression {
125        if let (Expression::PartialVariable(vl), Expression::PartialVariable(vr)) = (&self, &rhs) {
126            // if vl.sizes() == vr.sizes() {
127            //     panic!("Mistach Sizes of Variables");
128            // }
129
130            return Expression::PartialVariable(ExpressionArray::from_factory(
131                vr.sizes().to_vec(),
132                |indices| {
133                    vec![vl[indices].clone(), vr[indices].clone()]
134                        .into_iter()
135                        .dot_product(&generate_rank_combinations(rank_pairs))
136                },
137            ));
138        }
139
140        vec![self, rhs]
141            .into_iter()
142            .dot_product(&generate_rank_combinations(rank_pairs))
143    }
144}
145
146impl TensorExpression {
147    pub(crate) fn diff_dot_product(
148        terms: &Vec<Expression>,
149        rank_combinations: &Vec<HashMap<RankIndex, String>>,
150        symbols: &[&str],
151    ) -> Vec<Expression> {
152        let mut result = terms[0]
153            .differential(symbols)
154            .into_iter()
155            .map(|d| {
156                once(d)
157                    .chain(terms[1..].iter().cloned())
158                    .dot_product(rank_combinations)
159            })
160            .collect::<Vec<_>>();
161
162        for i in 1..terms.len() {
163            result
164                .iter_mut()
165                .zip(terms[i].differential(symbols).into_iter())
166                .for_each(|(r, d)| {
167                    *r = r.clone()
168                        + terms[0..i]
169                            .iter()
170                            .cloned()
171                            .chain(once(d))
172                            .chain(terms[i + 1..].iter().cloned())
173                            .dot_product(rank_combinations);
174                });
175        }
176
177        result
178    }
179
180    pub(crate) fn tex_code_dot_product(
181        terms: &Vec<Expression>,
182        rank_combinations: &Vec<HashMap<RankIndex, String>>,
183        symbols: &HashMap<&str, &str>,
184    ) -> String {
185        let mut ids = Vec::<String>::new();
186        let mut id_index = HashMap::<String, usize>::new();
187
188        for i in 0..terms.len() {
189            for (_, id) in rank_combinations[i].iter() {
190                if !id_index.contains_key(id) {
191                    ids.push(id.clone());
192                    id_index.insert(id.clone(), ids.len() - 1);
193                }
194            }
195        }
196
197        let mut result = String::new();
198        result.push_str(&format!(
199            r"\sum_{{{}}}",
200            ids.iter()
201                .enumerate()
202                .map(|(k, _)| format!("{}", next_char('i', k)))
203                .collect::<Vec<_>>()
204                .join(", ")
205        ));
206
207        for i in 0..terms.len() {
208            let mut sorted = rank_combinations[i].iter().collect::<Vec<_>>();
209            sorted.sort_by(|a, b| a.0.cmp(b.0));
210            result.push_str(&format!(
211                "{}_{{{}}}",
212                terms[i]._tex_code(symbols, BracketsLevel::ForMul),
213                sorted
214                    .into_iter()
215                    .map(|(j, id)| format!("[{}] = {}", j, next_char('i', id_index[id])))
216                    .collect::<Vec<_>>()
217                    .join(", ")
218            ));
219        }
220
221        format!("{{{}}}", result)
222    }
223
224    pub(crate) fn size_dot_product(
225        terms: &Vec<Expression>,
226        rank_combinations: &Vec<HashMap<RankIndex, String>>,
227    ) -> Vec<Size> {
228        let max_rank = terms.iter().map(|vi| vi.sizes().len()).max().unwrap();
229        let mut sizes = vec![Size::One; max_rank];
230
231        for i in 0..terms.len() {
232            let term_sizes = terms[i].sizes();
233
234            for (rank, size) in term_sizes.iter().enumerate() {
235                if sizes[rank] == Size::Many {
236                    continue;
237                }
238                if let Some(_) = rank_combinations[i].get(&rank) {
239                    continue;
240                }
241                sizes.insert(rank, size.clone());
242            }
243        }
244
245        sizes
246    }
247
248    pub fn not_1dimension_ranks_in_dot_product(
249        terms: &Vec<Expression>,
250        rank_combinations: &Vec<HashMap<RankIndex, String>>,
251    ) -> HashMap<RankIndex, TermIndex> {
252        let mut not_1dimension_ranks = HashMap::new();
253
254        for i in 0..terms.len() {
255            let term_sizes = terms[i].sizes();
256            for (rank, size) in term_sizes.iter().enumerate() {
257                if let Some(_) = rank_combinations[i].get(&rank) {
258                    continue;
259                }
260
261                if *size != Size::One {
262                    if not_1dimension_ranks.contains_key(&rank) {
263                        panic!(
264                            "Rank {} is not 1-dimension in terms[{}] and terms[{}]",
265                            rank,
266                            not_1dimension_ranks.get(&rank).unwrap(),
267                            i
268                        );
269                    }
270                    not_1dimension_ranks.insert(rank, i);
271                }
272            }
273        }
274
275        not_1dimension_ranks
276    }
277}