1use crate::{parser::*, *};
3use anyhow::Result;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote, ToTokens, TokenStreamExt};
6use std::{
7 collections::{BTreeMap, BTreeSet},
8 fmt,
9 str::FromStr,
10};
11
12#[derive(Debug, Clone, PartialEq, Eq, Hash)]
13pub struct Subscript {
14 raw: RawSubscript,
15 position: Position,
16}
17
18impl Subscript {
19 pub fn raw(&self) -> &RawSubscript {
20 &self.raw
21 }
22
23 pub fn position(&self) -> &Position {
24 &self.position
25 }
26
27 pub fn indices(&self) -> Vec<char> {
28 match &self.raw {
29 RawSubscript::Indices(indices) => indices.clone(),
30 RawSubscript::Ellipsis { start, end } => {
31 start.iter().chain(end.iter()).cloned().collect()
32 }
33 }
34 }
35}
36
37impl ToTokens for Subscript {
38 fn to_tokens(&self, tokens: &mut TokenStream) {
39 ToTokens::to_tokens(&self.position, tokens)
40 }
41}
42
43#[cfg_attr(doc, katexit::katexit)]
44#[derive(Clone, PartialEq, Eq)]
46pub struct Subscripts {
47 pub inputs: Vec<Subscript>,
49 pub output: Subscript,
51}
52
53impl fmt::Debug for Subscripts {
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 for (n, input) in self.inputs.iter().enumerate() {
57 write!(f, "{}", input.raw)?;
58 if n < self.inputs.len() - 1 {
59 write!(f, ",")?;
60 }
61 }
62 write!(f, "->{} | ", self.output.raw)?;
63
64 for (n, input) in self.inputs.iter().enumerate() {
65 write!(f, "{}", input.position)?;
66 if n < self.inputs.len() - 1 {
67 write!(f, ",")?;
68 }
69 }
70 write!(f, "->{}", self.output.position)?;
71 Ok(())
72 }
73}
74
75impl fmt::Display for Subscripts {
76 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77 fmt::Debug::fmt(self, f)
78 }
79}
80
81impl ToTokens for Subscripts {
82 fn to_tokens(&self, tokens: &mut TokenStream) {
83 let fn_name = format_ident!("{}", self.escaped_ident());
84 let args = &self.inputs;
85 let out = &self.output;
86 tokens.append_all(quote! {
87 let #out = #fn_name(#(#args),*);
88 });
89 }
90}
91
92impl Subscripts {
93 pub fn compute_order(&self) -> usize {
95 self.memory_order() + self.contraction_indices().len()
96 }
97
98 pub fn memory_order(&self) -> usize {
100 self.output.indices().len()
101 }
102
103 pub fn from_raw(names: &mut Namespace, raw: RawSubscripts) -> Self {
136 let inputs = raw
137 .inputs
138 .iter()
139 .enumerate()
140 .map(|(i, indices)| Subscript {
141 raw: indices.clone(),
142 position: Position::Arg(i),
143 })
144 .collect();
145 let position = names.new_ident();
146 if let Some(output) = raw.output {
147 return Subscripts {
148 inputs,
149 output: Subscript {
150 raw: output,
151 position,
152 },
153 };
154 }
155
156 let count = count_indices(&inputs);
157 let output = Subscript {
158 raw: RawSubscript::Indices(
159 count
160 .iter()
161 .filter_map(|(key, value)| if *value == 1 { Some(*key) } else { None })
162 .collect(),
163 ),
164 position,
165 };
166 Subscripts { inputs, output }
167 }
168
169 pub fn from_raw_indices(names: &mut Namespace, indices: &str) -> Result<Self> {
170 let raw = RawSubscripts::from_str(indices)?;
171 Ok(Self::from_raw(names, raw))
172 }
173
174 pub fn contraction_indices(&self) -> BTreeSet<char> {
196 let count = count_indices(&self.inputs);
197 let mut subscripts: BTreeSet<char> = count
198 .into_iter()
199 .filter_map(|(key, value)| if value > 1 { Some(key) } else { None })
200 .collect();
201 for c in &self.output.indices() {
202 subscripts.remove(c);
203 }
204 subscripts
205 }
206
207 pub fn factorize(
233 &self,
234 names: &mut Namespace,
235 inners: BTreeSet<Position>,
236 ) -> Result<(Self, Self)> {
237 let mut inner_inputs = Vec::new();
238 let mut outer_inputs = Vec::new();
239 let mut indices: BTreeMap<char, (usize , usize )> = BTreeMap::new();
240 for input in &self.inputs {
241 if inners.contains(&input.position) {
242 inner_inputs.push(input.clone());
243 for c in input.indices() {
244 indices
245 .entry(c)
246 .and_modify(|(i, _)| *i += 1)
247 .or_insert((1, 0));
248 }
249 } else {
250 outer_inputs.push(input.clone());
251 for c in input.indices() {
252 indices
253 .entry(c)
254 .and_modify(|(_, o)| *o += 1)
255 .or_insert((0, 1));
256 }
257 }
258 }
259 let out = Subscript {
260 raw: RawSubscript::Indices(
261 indices
262 .into_iter()
263 .filter_map(|(key, (i, o))| {
264 if i == 1 || (i >= 2 && o > 0) {
265 Some(key)
266 } else {
267 None
268 }
269 })
270 .collect(),
271 ),
272 position: names.new_ident(),
273 };
274 outer_inputs.insert(0, out.clone());
275 Ok((
276 Subscripts {
277 inputs: inner_inputs,
278 output: out,
279 },
280 Subscripts {
281 inputs: outer_inputs,
282 output: self.output.clone(),
283 },
284 ))
285 }
286
287 pub fn escaped_ident(&self) -> String {
293 use std::fmt::Write;
294 let mut out = String::new();
295 for input in &self.inputs {
296 write!(out, "{}", input.raw).unwrap();
297 write!(out, "_").unwrap();
298 }
299 write!(out, "_{}", self.output.raw).unwrap();
300 out
301 }
302}
303
304fn count_indices(inputs: &[Subscript]) -> BTreeMap<char, u32> {
305 let mut count = BTreeMap::new();
306 for input in inputs {
307 for c in input.indices() {
308 count.entry(c).and_modify(|n| *n += 1).or_insert(1);
309 }
310 }
311 count
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Each case pairs an einsum string with the identifier `escaped_ident`
    /// must produce; one `Namespace` is shared across all cases.
    #[test]
    fn escaped_ident() {
        let cases = [
            // explicit mode
            ("ij,jk->ik", "ij_jk__ik"),
            // implicit mode: output `ik` is inferred
            ("ij,jk", "ij_jk__ik"),
            // implicit mode: full reduction, empty output
            ("i,i", "i_i__"),
            // ellipsis
            ("ij...,jk...->ik...", "ij____jk_____ik___"),
        ];

        let mut names = Namespace::init();
        for &(einsum, expected) in &cases {
            let subscripts = Subscripts::from_raw_indices(&mut names, einsum).unwrap();
            assert_eq!(subscripts.escaped_ident(), expected, "subscripts: {}", einsum);
        }
    }
}
337}