Skip to main content

ariadnetor_core/einsum/
mod.rs

1//! Einstein summation notation parser and contraction plan
2
3use std::collections::{HashMap, HashSet};
4
/// Parsed einsum expression with N inputs (indices stored as ASCII codes).
///
/// Supports 1 to N input tensors. Output indices can be explicit (`->out`)
/// or implicitly inferred (free indices sorted alphabetically).
///
/// # Examples
///
/// ```
/// use ariadnetor_core::EinsumExpr;
///
/// // Matrix multiplication
/// let expr = EinsumExpr::parse("ij,jk->ik").unwrap();
/// assert_eq!(expr.num_inputs(), 2);
/// assert!(expr.is_matrix_multiply());
/// assert_eq!(expr.infer_output_shape(&[&[10, 20], &[20, 30]]).unwrap(), vec![10, 30]);
///
/// // Higher-dimensional contraction (not a plain matmul)
/// let expr = EinsumExpr::parse("ijk,jkl->il").unwrap();
/// assert_eq!(expr.out_indices(), &[b'i', b'l']);
/// assert_eq!(expr.contracted_indices(), vec![b'j', b'k']);
/// assert!(!expr.is_matrix_multiply());
///
/// // Element-wise: every index appears in the output, nothing is contracted
/// let expr = EinsumExpr::parse("ij,ij->ij").unwrap();
/// assert!(expr.contracted_indices().is_empty());
///
/// // Implicit output inference
/// let expr = EinsumExpr::parse("ij,jk").unwrap();
/// assert_eq!(expr.out_indices(), &[b'i', b'k']);
///
/// // Single tensor trace
/// let expr = EinsumExpr::parse("ii->").unwrap();
/// assert_eq!(expr.num_inputs(), 1);
///
/// // Errors: an output index absent from every input, a repeated output
/// // index, or a non-alphabetic index
/// assert!(EinsumExpr::parse("ij,jk->im").is_err());
/// assert!(EinsumExpr::parse("ij,jk->ii").is_err());
/// assert!(EinsumExpr::parse("i1,jk->ik").is_err());
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EinsumExpr {
    inputs: Vec<Vec<u8>>,
    out_indices: Vec<u8>,
}

impl EinsumExpr {
    /// Parse an einsum expression from string notation.
    ///
    /// Whitespace anywhere in `notation` is ignored and operands are
    /// separated by `,`. When `->` is present, the indices after it are the
    /// explicit output. When `->` is omitted, the output is inferred as the
    /// free indices (appearing exactly once across all inputs) sorted in
    /// ascending ASCII order, so uppercase letters sort before lowercase.
    ///
    /// There is always at least one operand: an empty operand string (as in
    /// `""` or `"->"`) is a scalar operand with no indices.
    ///
    /// # Errors
    ///
    /// Returns an error if any index is not an ASCII letter (this includes
    /// the `-`/`>` of a second `->`), or if [`validate`](Self::validate)
    /// rejects the resulting output indices.
    pub fn parse(notation: &str) -> Result<Self, String> {
        let notation: String = notation.chars().filter(|c| !c.is_whitespace()).collect();

        let (inputs_str, out_str) = match notation.split_once("->") {
            Some((inp, out)) => (inp, Some(out)),
            None => (notation.as_str(), None),
        };

        // `str::split` always yields at least one piece, so no "no inputs"
        // case exists; an empty piece becomes an operand with no indices.
        let inputs = inputs_str
            .split(',')
            .map(Self::parse_indices)
            .collect::<Result<Vec<_>, _>>()?;

        let out_indices = match out_str {
            Some(out) => Self::parse_indices(out)?,
            None => Self::infer_output(&inputs),
        };

        let expr = Self {
            inputs,
            out_indices,
        };
        expr.validate()?;
        Ok(expr)
    }

    /// Infer output indices when `->` is omitted.
    ///
    /// Returns the indices appearing exactly once across all inputs, sorted
    /// by ASCII code.
    fn infer_output(inputs: &[Vec<u8>]) -> Vec<u8> {
        let mut counts: HashMap<u8, usize> = HashMap::new();
        for &idx in inputs.iter().flatten() {
            *counts.entry(idx).or_default() += 1;
        }
        let mut free: Vec<u8> = counts
            .into_iter()
            .filter(|&(_, count)| count == 1)
            .map(|(idx, _)| idx)
            .collect();
        // `u8` has no distinguishable equal elements, so unstable sort is exact.
        free.sort_unstable();
        free
    }

    /// Convert one operand's index string into ASCII codes.
    ///
    /// # Errors
    ///
    /// Returns an error naming the first character that is not an ASCII letter.
    fn parse_indices(s: &str) -> Result<Vec<u8>, String> {
        s.chars()
            .map(|c| {
                if c.is_ascii_alphabetic() {
                    Ok(c as u8)
                } else {
                    Err(format!("Invalid index '{}': must be A-Z or a-z", c))
                }
            })
            .collect()
    }

    /// Validate the einsum expression.
    ///
    /// Checks that every output index appears in at least one input tensor
    /// and that no output index is listed more than once.
    ///
    /// # Errors
    ///
    /// Returns an error naming the first output index missing from all
    /// inputs; if there is none, an error naming the first output index that
    /// is repeated in the output.
    pub fn validate(&self) -> Result<(), String> {
        let input_indices: HashSet<u8> = self.inputs.iter().flatten().copied().collect();

        if let Some(&idx) = self
            .out_indices
            .iter()
            .find(|&&idx| !input_indices.contains(&idx))
        {
            return Err(format!(
                "Output index '{}' does not appear in any input tensor",
                idx as char
            ));
        }

        // A repeated output axis (e.g. `ij->ii`) would give an output shape
        // with a duplicated dimension and break `ContractionPlan`'s
        // permutation helpers, which expect each output index once.
        let mut seen = HashSet::with_capacity(self.out_indices.len());
        if let Some(&idx) = self.out_indices.iter().find(|&&idx| !seen.insert(idx)) {
            return Err(format!(
                "Output index '{}' appears more than once in the output",
                idx as char
            ));
        }

        Ok(())
    }

    /// Get all input index lists, one per operand, in operand order.
    pub fn inputs(&self) -> &[Vec<u8>] {
        &self.inputs
    }

    /// Get output indices, in output axis order.
    pub fn out_indices(&self) -> &[u8] {
        &self.out_indices
    }

    /// Number of input tensors (always at least 1 for a parsed expression).
    pub fn num_inputs(&self) -> usize {
        self.inputs.len()
    }

    /// Convenience accessor for the first input's indices.
    ///
    /// # Panics
    ///
    /// Panics if the expression has no inputs.
    pub fn lhs_indices(&self) -> &[u8] {
        &self.inputs[0]
    }

    /// Convenience accessor for the second input's indices.
    ///
    /// # Panics
    ///
    /// Panics if the expression has fewer than 2 inputs.
    pub fn rhs_indices(&self) -> &[u8] {
        &self.inputs[1]
    }

    /// Get all unique indices across inputs and output.
    pub fn all_indices(&self) -> HashSet<u8> {
        self.inputs
            .iter()
            .flatten()
            .chain(&self.out_indices)
            .copied()
            .collect()
    }

    /// Get contracted indices (appear in inputs but not in output),
    /// preserving the order of first appearance across inputs.
    ///
    /// # Examples
    ///
    /// ```
    /// # use ariadnetor_core::EinsumExpr;
    /// let expr = EinsumExpr::parse("ijk,jkl->il").unwrap();
    /// assert_eq!(expr.contracted_indices(), vec![b'j', b'k']);
    /// ```
    pub fn contracted_indices(&self) -> Vec<u8> {
        let output_set: HashSet<u8> = self.out_indices.iter().copied().collect();
        let mut contracted = Vec::new();
        let mut seen = HashSet::new();

        for &idx in self.inputs.iter().flatten() {
            if !output_set.contains(&idx) && seen.insert(idx) {
                contracted.push(idx);
            }
        }

        contracted
    }

    /// Check if this is a matrix multiplication pattern.
    ///
    /// True when there are exactly 2 inputs, each with 2 *distinct* indices,
    /// together using 3 unique indices; the output has 2 indices; and the
    /// single contracted index is the one shared by both inputs. The output
    /// may list its two indices in either order, so `ij,jk->ki` (matmul with
    /// a transposed result) and `ij,kj->ik` (`A * B^T`) also qualify.
    ///
    /// Patterns that merely look similar, such as `ij,jk->ij` (the summed
    /// index `k` occurs only in the rhs) or `ii,jk->jk` (a trace of the lhs),
    /// are rejected.
    ///
    /// # Examples
    ///
    /// ```
    /// # use ariadnetor_core::EinsumExpr;
    /// assert!(EinsumExpr::parse("ij,jk->ik").unwrap().is_matrix_multiply());
    /// assert!(!EinsumExpr::parse("ijk,jkl->il").unwrap().is_matrix_multiply());
    /// assert!(!EinsumExpr::parse("ij,jk->ij").unwrap().is_matrix_multiply());
    /// ```
    pub fn is_matrix_multiply(&self) -> bool {
        if self.inputs.len() != 2 || self.out_indices.len() != 2 {
            return false;
        }

        let (lhs, rhs) = (&self.inputs[0], &self.inputs[1]);
        if lhs.len() != 2 || rhs.len() != 2 {
            return false;
        }

        // A repeated index inside one operand is a diagonal/trace, not a matrix.
        if lhs[0] == lhs[1] || rhs[0] == rhs[1] {
            return false;
        }

        let unique: HashSet<u8> = lhs.iter().chain(rhs.iter()).copied().collect();
        if unique.len() != 3 {
            return false;
        }

        // The summed index must be the one index both operands share;
        // otherwise the "contraction" is a reduction inside a single operand.
        match self.contracted_indices().as_slice() {
            [shared] => lhs.contains(shared) && rhs.contains(shared),
            _ => false,
        }
    }

    /// Infer the output tensor shape from input shapes.
    ///
    /// The number of shapes must match `num_inputs()`, and each shape's rank
    /// must match its corresponding input index count. Shared indices must
    /// have matching dimensions.
    ///
    /// # Errors
    ///
    /// Returns an error if the number of shapes is wrong, a shape's rank
    /// differs from its operand's index count, or one index is given two
    /// different dimensions.
    ///
    /// # Examples
    ///
    /// ```
    /// # use ariadnetor_core::EinsumExpr;
    /// let expr = EinsumExpr::parse("ij,jk->ik").unwrap();
    /// assert_eq!(expr.infer_output_shape(&[&[10, 20], &[20, 30]]).unwrap(), vec![10, 30]);
    /// ```
    pub fn infer_output_shape(&self, shapes: &[&[usize]]) -> Result<Vec<usize>, String> {
        if shapes.len() != self.inputs.len() {
            return Err(format!(
                "Expected {} input shapes, got {}",
                self.inputs.len(),
                shapes.len()
            ));
        }

        for (i, (input, shape)) in self.inputs.iter().zip(shapes).enumerate() {
            if input.len() != shape.len() {
                return Err(format!(
                    "Input {} shape rank {} does not match index count {}",
                    i,
                    shape.len(),
                    input.len()
                ));
            }
        }

        // The first occurrence of an index fixes its dimension; every later
        // occurrence must agree with it.
        let mut index_dims: HashMap<u8, usize> = HashMap::new();
        for (input, shape) in self.inputs.iter().zip(shapes) {
            for (&idx, &dim) in input.iter().zip(shape.iter()) {
                let existing_dim = *index_dims.entry(idx).or_insert(dim);
                if existing_dim != dim {
                    return Err(format!(
                        "Dimension mismatch for index '{}': found {} and {}",
                        idx as char, existing_dim, dim
                    ));
                }
            }
        }

        self.out_indices
            .iter()
            .map(|&idx| {
                index_dims.get(&idx).copied().ok_or_else(|| {
                    format!("Output index '{}' not found in input tensors", idx as char)
                })
            })
            .collect()
    }
}
310
311/// Contraction plan identifying batch, contracted, and free indices for a 2-input einsum.
312///
313/// - **batch**: indices in both inputs AND in output (iterated over, not contracted)
314/// - **contracted**: indices in both inputs but NOT in output (summed over)
315/// - **free_lhs**: indices only in lhs and output (not in rhs)
316/// - **free_rhs**: indices only in rhs and output (not in lhs)
317#[derive(Debug, Clone, PartialEq, Eq)]
318pub struct ContractionPlan {
319    /// Indices in both inputs and the output (iterated over, not contracted).
320    pub batch: Vec<u8>,
321    /// Indices in both inputs but not the output (summed over).
322    pub contracted: Vec<u8>,
323    /// Indices only in the lhs and the output.
324    pub free_lhs: Vec<u8>,
325    /// Indices only in the rhs and the output.
326    pub free_rhs: Vec<u8>,
327}
328
329impl ContractionPlan {
330    /// Derive the batch / contracted / free index partition from a parsed einsum expression.
331    pub fn from_expr(expr: &EinsumExpr) -> Self {
332        let lhs = expr.lhs_indices();
333        let rhs = expr.rhs_indices();
334        let out = expr.out_indices();
335
336        let lhs_set: HashSet<u8> = lhs.iter().copied().collect();
337        let rhs_set: HashSet<u8> = rhs.iter().copied().collect();
338        let out_set: HashSet<u8> = out.iter().copied().collect();
339
340        // Contracted: in both lhs and rhs, not in output (preserve LHS order)
341        let contracted: Vec<u8> = lhs
342            .iter()
343            .filter(|idx| rhs_set.contains(idx) && !out_set.contains(idx))
344            .copied()
345            .collect();
346
347        // Batch: in both lhs and rhs AND in output (output order)
348        let batch: Vec<u8> = out
349            .iter()
350            .filter(|idx| lhs_set.contains(idx) && rhs_set.contains(idx))
351            .copied()
352            .collect();
353
354        let batch_set: HashSet<u8> = batch.iter().copied().collect();
355
356        // Free lhs: in output and lhs, not in rhs (excludes batch)
357        let free_lhs: Vec<u8> = out
358            .iter()
359            .filter(|idx| lhs_set.contains(idx) && !batch_set.contains(idx))
360            .copied()
361            .collect();
362
363        // Free rhs: in output and rhs, not in lhs (excludes batch)
364        let free_rhs: Vec<u8> = out
365            .iter()
366            .filter(|idx| rhs_set.contains(idx) && !batch_set.contains(idx))
367            .copied()
368            .collect();
369
370        Self {
371            batch,
372            contracted,
373            free_lhs,
374            free_rhs,
375        }
376    }
377
378    /// Compute LHS permutation to [batch, free_lhs, contracted] order.
379    pub fn lhs_permutation(&self, lhs_indices: &[u8], rhs_indices: &[u8]) -> Option<Vec<usize>> {
380        let contracted_set: HashSet<u8> = self.contracted.iter().copied().collect();
381        let mut target = self.batch.clone();
382        target.extend(&self.free_lhs);
383        for &idx in rhs_indices {
384            if contracted_set.contains(&idx) && !target.contains(&idx) {
385                target.push(idx);
386            }
387        }
388        compute_permutation(lhs_indices, &target)
389    }
390
391    /// Compute RHS permutation to [batch, contracted, free_rhs] order.
392    pub fn rhs_permutation(&self, rhs_indices: &[u8]) -> Option<Vec<usize>> {
393        let contracted_set: HashSet<u8> = self.contracted.iter().copied().collect();
394        let mut target = self.batch.clone();
395        for &idx in rhs_indices {
396            if contracted_set.contains(&idx) && !target.contains(&idx) {
397                target.push(idx);
398            }
399        }
400        target.extend(&self.free_rhs);
401        compute_permutation(rhs_indices, &target)
402    }
403}
404
/// Compute the permutation that reorders `current` into `target`.
///
/// Returns `None` when `target` already equals `current` (the identity
/// permutation). Otherwise returns `Some(perm)`, where `perm[k]` is the
/// position in `current` that supplies `target[k]`, i.e.
/// `target[k] == current[perm[k]]`.
///
/// Repeated indices are matched to distinct positions in order of
/// appearance, so the result is always a true permutation of
/// `0..current.len()`.
///
/// # Panics
///
/// Panics if `current` and `target` have different lengths, or if `target`
/// names an index (or more copies of an index) than `current` provides.
pub fn compute_permutation(current: &[u8], target: &[u8]) -> Option<Vec<usize>> {
    assert_eq!(
        current.len(),
        target.len(),
        "current and target orders must have the same number of indices"
    );

    // Positions of `current` already claimed by an earlier target entry.
    // Without this, every copy of a repeated index (e.g. a diagonal `ii`)
    // would map to its first occurrence and the result would not be a
    // permutation.
    let mut used = vec![false; current.len()];
    let mut perm = Vec::with_capacity(target.len());
    for &idx in target {
        let pos = (0..current.len())
            .find(|&p| !used[p] && current[p] == idx)
            .unwrap_or_else(|| {
                panic!(
                    "Index not found: '{}' has no unmatched occurrence in the current order",
                    idx as char
                )
            });
        used[pos] = true;
        perm.push(pos);
    }

    if perm.iter().enumerate().all(|(i, &p)| i == p) {
        None
    } else {
        Some(perm)
    }
}
423
424#[cfg(test)]
425mod tests;