einsum_codegen/
namespace.rs

1use proc_macro2::TokenStream;
2use quote::{format_ident, ToTokens, TokenStreamExt};
3use std::fmt;
4
/// Names of tensors
///
/// As the crate-level documentation explains, einsum factorization has to
/// track the names of tensors in addition to their subscripts, and this
/// struct manages those names.
/// It works as a simple counter: it remembers how many intermediate tensors
/// `out{N}` have been issued so far and hands out the next unused index.
///
/// `Namespace::default()` is equivalent to [`Namespace::init`]: both start
/// counting at `out0`.
#[derive(Debug, Default, PartialEq, Eq, Clone)]
pub struct Namespace {
    /// Index that the next issued `out{N}` identifier will carry.
    last: usize,
}
17
18impl Namespace {
19    /// Create new namespace
20    pub fn init() -> Self {
21        Namespace { last: 0 }
22    }
23
24    /// Issue new identifier
25    pub fn new_ident(&mut self) -> Position {
26        let pos = Position::Out(self.last);
27        self.last += 1;
28        pos
29    }
30}
31
/// Which tensor the subscript specifies
///
/// The derived ordering follows the declaration order of the variants, so
/// every `Arg` compares less than every `Out`; within one variant the
/// indices are compared.
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
pub enum Position {
    /// The tensor which the user passes as the N-th argument of einsum
    Arg(usize),
    /// The tensor created by einsum in its N-th step
    Out(usize),
}
40
41impl fmt::Debug for Position {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        match self {
44            Position::Arg(n) => write!(f, "arg{}", n),
45            Position::Out(n) => write!(f, "out{}", n),
46        }
47    }
48}
49
50impl fmt::Display for Position {
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        fmt::Debug::fmt(self, f)
53    }
54}
55
56impl ToTokens for Position {
57    fn to_tokens(&self, tokens: &mut TokenStream) {
58        match self {
59            Position::Arg(n) => tokens.append(format_ident!("arg{}", n)),
60            Position::Out(n) => tokens.append(format_ident!("out{}", n)),
61        }
62    }
63}