ariadnetor_core/einsum/mod.rs
1//! Einstein summation notation parser and contraction plan
2
3use std::collections::{HashMap, HashSet};
4
/// Parsed einsum expression with N inputs (indices stored as ASCII codes).
///
/// Supports 1 to N input tensors. Output indices can be explicit (`->out`)
/// or implicitly inferred (free indices sorted by ASCII code).
///
/// # Examples
///
/// ```
/// use ariadnetor_core::EinsumExpr;
///
/// // Matrix multiplication
/// let expr = EinsumExpr::parse("ij,jk->ik").unwrap();
/// assert_eq!(expr.num_inputs(), 2);
/// assert!(expr.is_matrix_multiply());
/// assert_eq!(expr.infer_output_shape(&[&[10, 20], &[20, 30]]).unwrap(), vec![10, 30]);
///
/// // Higher-dimensional contraction (not a plain matmul)
/// let expr = EinsumExpr::parse("ijk,jkl->il").unwrap();
/// assert_eq!(expr.out_indices(), &[b'i', b'l']);
/// assert_eq!(expr.contracted_indices(), vec![b'j', b'k']);
/// assert!(!expr.is_matrix_multiply());
///
/// // Element-wise: every index appears in the output, nothing is contracted
/// let expr = EinsumExpr::parse("ij,ij->ij").unwrap();
/// assert!(expr.contracted_indices().is_empty());
///
/// // Implicit output inference
/// let expr = EinsumExpr::parse("ij,jk").unwrap();
/// assert_eq!(expr.out_indices(), &[b'i', b'k']);
///
/// // Single tensor trace
/// let expr = EinsumExpr::parse("ii->").unwrap();
/// assert_eq!(expr.num_inputs(), 1);
///
/// // Errors: an output index absent from every input, or a non-alphabetic index
/// assert!(EinsumExpr::parse("ij,jk->im").is_err());
/// assert!(EinsumExpr::parse("i1,jk->ik").is_err());
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EinsumExpr {
    /// Index labels of each input operand, in the order they were written.
    inputs: Vec<Vec<u8>>,
    /// Index labels of the output, either explicit or inferred.
    out_indices: Vec<u8>,
}

impl EinsumExpr {
    /// Parse an einsum expression from string notation.
    ///
    /// Whitespace is stripped before parsing. When `->` is present, output
    /// indices are explicit. When `->` is omitted, output is inferred as the
    /// free indices (appearing exactly once across all inputs) sorted by
    /// ASCII code.
    ///
    /// # Errors
    ///
    /// Returns an error message if any index is not an ASCII letter, or if an
    /// output index does not appear in any input.
    pub fn parse(notation: &str) -> Result<Self, String> {
        let notation: String = notation.chars().filter(|c| !c.is_whitespace()).collect();

        let (inputs_str, out_str) = match notation.split_once("->") {
            Some((inp, out)) => (inp, Some(out)),
            None => (notation.as_str(), None),
        };

        // `split` always yields at least one piece, so there is never an empty
        // operand list to reject; an empty piece parses to an operand with no
        // indices.
        let inputs: Vec<Vec<u8>> = inputs_str
            .split(',')
            .map(Self::parse_indices)
            .collect::<Result<_, _>>()?;

        let out_indices = match out_str {
            Some(out) => Self::parse_indices(out)?,
            None => Self::infer_output(&inputs),
        };

        let expr = Self {
            inputs,
            out_indices,
        };
        expr.validate()?;
        Ok(expr)
    }

    /// Infer output indices when `->` is omitted.
    ///
    /// Returns the free indices (appearing exactly once across all inputs),
    /// sorted by ASCII code.
    fn infer_output(inputs: &[Vec<u8>]) -> Vec<u8> {
        let mut occurrences: HashMap<u8, usize> = HashMap::new();
        for &idx in inputs.iter().flatten() {
            *occurrences.entry(idx).or_default() += 1;
        }

        let mut free: Vec<u8> = occurrences
            .into_iter()
            .filter(|&(_, count)| count == 1)
            .map(|(idx, _)| idx)
            .collect();
        // Map iteration order is arbitrary; sorting makes the result deterministic.
        free.sort_unstable();
        free
    }

    /// Convert one operand's index string into ASCII codes, rejecting any
    /// character that is not an ASCII letter.
    fn parse_indices(s: &str) -> Result<Vec<u8>, String> {
        s.chars()
            .map(|c| {
                if c.is_ascii_alphabetic() {
                    // Lossless: `c` was just checked to be ASCII.
                    Ok(c as u8)
                } else {
                    Err(format!("Invalid index '{}': must be A-Z or a-z", c))
                }
            })
            .collect()
    }

    /// Validate the einsum expression.
    ///
    /// Checks that all output indices appear in at least one input tensor.
    ///
    /// # Errors
    ///
    /// Returns an error naming the first output index that is missing from
    /// every input.
    pub fn validate(&self) -> Result<(), String> {
        let available: HashSet<u8> = self.inputs.iter().flatten().copied().collect();

        match self
            .out_indices
            .iter()
            .find(|&&idx| !available.contains(&idx))
        {
            Some(&missing) => Err(format!(
                "Output index '{}' does not appear in any input tensor",
                missing as char
            )),
            None => Ok(()),
        }
    }

    /// Get all input index lists
    pub fn inputs(&self) -> &[Vec<u8>] {
        &self.inputs
    }

    /// Get output indices
    pub fn out_indices(&self) -> &[u8] {
        &self.out_indices
    }

    /// Number of input tensors
    pub fn num_inputs(&self) -> usize {
        self.inputs.len()
    }

    /// Convenience accessor for the first input's indices.
    ///
    /// # Panics
    ///
    /// Panics if the expression has no inputs.
    pub fn lhs_indices(&self) -> &[u8] {
        &self.inputs[0]
    }

    /// Convenience accessor for the second input's indices.
    ///
    /// # Panics
    ///
    /// Panics if the expression has fewer than 2 inputs.
    pub fn rhs_indices(&self) -> &[u8] {
        &self.inputs[1]
    }

    /// Get all unique indices across inputs and output
    pub fn all_indices(&self) -> HashSet<u8> {
        self.inputs
            .iter()
            .flatten()
            .chain(&self.out_indices)
            .copied()
            .collect()
    }

    /// Get contracted indices (appear in inputs but not in output),
    /// preserving the order of first appearance across inputs.
    ///
    /// # Examples
    ///
    /// ```
    /// # use ariadnetor_core::EinsumExpr;
    /// let expr = EinsumExpr::parse("ijk,jkl->il").unwrap();
    /// assert_eq!(expr.contracted_indices(), vec![b'j', b'k']);
    /// ```
    pub fn contracted_indices(&self) -> Vec<u8> {
        let output_set: HashSet<u8> = self.out_indices.iter().copied().collect();
        let mut seen = HashSet::new();

        self.inputs
            .iter()
            .flatten()
            .copied()
            // `insert` returns false for repeats, so each index is kept once.
            .filter(|idx| !output_set.contains(idx) && seen.insert(*idx))
            .collect()
    }

    /// Check if this is a matrix multiplication pattern:
    /// 2 inputs, 3 unique indices, each input has 2 indices, output has 2 indices,
    /// and exactly 1 contracted index that is shared by both inputs.
    ///
    /// # Examples
    ///
    /// ```
    /// # use ariadnetor_core::EinsumExpr;
    /// assert!(EinsumExpr::parse("ij,jk->ik").unwrap().is_matrix_multiply());
    /// assert!(!EinsumExpr::parse("ijk,jkl->il").unwrap().is_matrix_multiply());
    /// // `k` is summed but lives only in the second operand: not a matmul.
    /// assert!(!EinsumExpr::parse("ij,jk->ij").unwrap().is_matrix_multiply());
    /// ```
    pub fn is_matrix_multiply(&self) -> bool {
        if self.inputs.len() != 2 {
            return false;
        }
        let (lhs, rhs) = (&self.inputs[0], &self.inputs[1]);

        if lhs.len() != 2 || rhs.len() != 2 || self.out_indices.len() != 2 {
            return false;
        }

        let distinct: HashSet<u8> = lhs.iter().chain(rhs.iter()).copied().collect();
        if distinct.len() != 3 {
            return false;
        }

        // The one summed index must connect the two operands. Counting alone
        // would also accept a reduction over an index that only one operand
        // carries (e.g. "ij,jk->ij" or "ii,jk->jk").
        match self.contracted_indices().as_slice() {
            [shared] => lhs.contains(shared) && rhs.contains(shared),
            _ => false,
        }
    }

    /// Infer the output tensor shape from input shapes.
    ///
    /// The number of shapes must match `num_inputs()`, and each shape's rank
    /// must match its corresponding input index count. Shared indices must
    /// have matching dimensions.
    ///
    /// # Errors
    ///
    /// Returns an error message when the shape count is wrong, when a shape's
    /// rank differs from its operand's index count, when one index is bound to
    /// two different dimensions, or when an output index has no dimension.
    /// All ranks are checked before any dimension is compared.
    ///
    /// # Examples
    ///
    /// ```
    /// # use ariadnetor_core::EinsumExpr;
    /// let expr = EinsumExpr::parse("ij,jk->ik").unwrap();
    /// assert_eq!(expr.infer_output_shape(&[&[10, 20], &[20, 30]]).unwrap(), vec![10, 30]);
    /// ```
    pub fn infer_output_shape(&self, shapes: &[&[usize]]) -> Result<Vec<usize>, String> {
        if shapes.len() != self.inputs.len() {
            return Err(format!(
                "Expected {} input shapes, got {}",
                self.inputs.len(),
                shapes.len()
            ));
        }

        for (i, (input, shape)) in self.inputs.iter().zip(shapes).enumerate() {
            if input.len() != shape.len() {
                return Err(format!(
                    "Input {} shape rank {} does not match index count {}",
                    i,
                    shape.len(),
                    input.len()
                ));
            }
        }

        // Build index → dimension mapping; the first binding of an index wins
        // and every later occurrence must agree with it.
        let mut index_dims: HashMap<u8, usize> = HashMap::new();
        for (input, shape) in self.inputs.iter().zip(shapes) {
            for (&idx, &dim) in input.iter().zip(shape.iter()) {
                let existing_dim = *index_dims.entry(idx).or_insert(dim);
                if existing_dim != dim {
                    return Err(format!(
                        "Dimension mismatch for index '{}': found {} and {}",
                        idx as char, existing_dim, dim
                    ));
                }
            }
        }

        self.out_indices
            .iter()
            .map(|&idx| {
                index_dims.get(&idx).copied().ok_or_else(|| {
                    format!("Output index '{}' not found in input tensors", idx as char)
                })
            })
            .collect()
    }
}
310
311/// Contraction plan identifying batch, contracted, and free indices for a 2-input einsum.
312///
313/// - **batch**: indices in both inputs AND in output (iterated over, not contracted)
314/// - **contracted**: indices in both inputs but NOT in output (summed over)
315/// - **free_lhs**: indices only in lhs and output (not in rhs)
316/// - **free_rhs**: indices only in rhs and output (not in lhs)
317#[derive(Debug, Clone, PartialEq, Eq)]
318pub struct ContractionPlan {
319 /// Indices in both inputs and the output (iterated over, not contracted).
320 pub batch: Vec<u8>,
321 /// Indices in both inputs but not the output (summed over).
322 pub contracted: Vec<u8>,
323 /// Indices only in the lhs and the output.
324 pub free_lhs: Vec<u8>,
325 /// Indices only in the rhs and the output.
326 pub free_rhs: Vec<u8>,
327}
328
329impl ContractionPlan {
330 /// Derive the batch / contracted / free index partition from a parsed einsum expression.
331 pub fn from_expr(expr: &EinsumExpr) -> Self {
332 let lhs = expr.lhs_indices();
333 let rhs = expr.rhs_indices();
334 let out = expr.out_indices();
335
336 let lhs_set: HashSet<u8> = lhs.iter().copied().collect();
337 let rhs_set: HashSet<u8> = rhs.iter().copied().collect();
338 let out_set: HashSet<u8> = out.iter().copied().collect();
339
340 // Contracted: in both lhs and rhs, not in output (preserve LHS order)
341 let contracted: Vec<u8> = lhs
342 .iter()
343 .filter(|idx| rhs_set.contains(idx) && !out_set.contains(idx))
344 .copied()
345 .collect();
346
347 // Batch: in both lhs and rhs AND in output (output order)
348 let batch: Vec<u8> = out
349 .iter()
350 .filter(|idx| lhs_set.contains(idx) && rhs_set.contains(idx))
351 .copied()
352 .collect();
353
354 let batch_set: HashSet<u8> = batch.iter().copied().collect();
355
356 // Free lhs: in output and lhs, not in rhs (excludes batch)
357 let free_lhs: Vec<u8> = out
358 .iter()
359 .filter(|idx| lhs_set.contains(idx) && !batch_set.contains(idx))
360 .copied()
361 .collect();
362
363 // Free rhs: in output and rhs, not in lhs (excludes batch)
364 let free_rhs: Vec<u8> = out
365 .iter()
366 .filter(|idx| rhs_set.contains(idx) && !batch_set.contains(idx))
367 .copied()
368 .collect();
369
370 Self {
371 batch,
372 contracted,
373 free_lhs,
374 free_rhs,
375 }
376 }
377
378 /// Compute LHS permutation to [batch, free_lhs, contracted] order.
379 pub fn lhs_permutation(&self, lhs_indices: &[u8], rhs_indices: &[u8]) -> Option<Vec<usize>> {
380 let contracted_set: HashSet<u8> = self.contracted.iter().copied().collect();
381 let mut target = self.batch.clone();
382 target.extend(&self.free_lhs);
383 for &idx in rhs_indices {
384 if contracted_set.contains(&idx) && !target.contains(&idx) {
385 target.push(idx);
386 }
387 }
388 compute_permutation(lhs_indices, &target)
389 }
390
391 /// Compute RHS permutation to [batch, contracted, free_rhs] order.
392 pub fn rhs_permutation(&self, rhs_indices: &[u8]) -> Option<Vec<usize>> {
393 let contracted_set: HashSet<u8> = self.contracted.iter().copied().collect();
394 let mut target = self.batch.clone();
395 for &idx in rhs_indices {
396 if contracted_set.contains(&idx) && !target.contains(&idx) {
397 target.push(idx);
398 }
399 }
400 target.extend(&self.free_rhs);
401 compute_permutation(rhs_indices, &target)
402 }
403}
404
/// Compute the permutation that reorders `current` into `target`.
///
/// Entry `i` of the result is the position in `current` of `target[i]`
/// (the first such position if `current` repeats a value). Returns `None`
/// when that mapping is the identity, i.e. no reordering is needed.
///
/// # Panics
///
/// Panics if the two slices differ in length, or if `target` contains a
/// value that `current` does not.
pub fn compute_permutation(current: &[u8], target: &[u8]) -> Option<Vec<usize>> {
    assert_eq!(current.len(), target.len());

    let mut perm = Vec::with_capacity(target.len());
    for &wanted in target {
        let source = current
            .iter()
            .position(|&have| have == wanted)
            .expect("Index not found");
        perm.push(source);
    }

    // An identity mapping is reported as "nothing to do".
    let is_identity = perm.iter().copied().eq(0..perm.len());
    if is_identity {
        None
    } else {
        Some(perm)
    }
}
423
424#[cfg(test)]
425mod tests;