einsum_codegen/
parser.rs

//! Parse einsum subscripts
//!
//! These parsers are implemented with [nom](https://github.com/Geal/nom).
//! The EBNF-like grammar accepted by each parser is given in that parser's documentation.
//!
6
7use anyhow::{bail, Error, Result};
8use nom::{
9    bytes::complete::*, character::complete::*, combinator::*, multi::*, sequence::*, IResult,
10    Parser,
11};
12use std::fmt;
13
14/// index = `a` | `b` | `c` | `d` | `e` | `f` | `g` | `h` | `i` | `j` | `k` | `l` |`m` | `n` | `o` | `p` | `q` | `r` | `s` | `t` | `u` | `v` | `w` | `x` |`y` | `z`;
15pub fn index(input: &str) -> IResult<&str, char> {
16    satisfy(|c| c.is_ascii_lowercase()).parse(input)
17}
18
19/// ellipsis = `...`
20pub fn ellipsis(input: &str) -> IResult<&str, &str> {
21    tag("...").parse(input)
22}
23
24/// subscript = { [index] } [ [ellipsis] { [index] } ];
25pub fn subscript(input: &str) -> IResult<&str, RawSubscript> {
26    let mut indices = many0(tuple((multispace0, index)).map(|(_space, c)| c));
27    let (input, start) = indices(input)?;
28    let (input, end) = opt(tuple((multispace0, ellipsis, multispace0, indices))
29        .map(|(_space_pre, _ellipsis, _space_post, output)| output))(input)?;
30    if let Some(end) = end {
31        Ok((input, RawSubscript::Ellipsis { start, end }))
32    } else {
33        Ok((input, RawSubscript::Indices(start)))
34    }
35}
36
/// A single parsed operand subscript, as produced by the `subscript` parser.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum RawSubscript {
    /// Indices without ellipsis, e.g. `ijk`
    Indices(Vec<char>),
    /// Indices with ellipsis, e.g. `i...j`
    Ellipsis {
        /// Indices written before the `...`
        start: Vec<char>,
        /// Indices written after the `...`
        end: Vec<char>,
    },
}

/// Compare against a fixed list of indices, e.g. `sub == ['i', 'j']`.
///
/// Only [`RawSubscript::Indices`] can compare equal to an array. An
/// [`RawSubscript::Ellipsis`] never does, even when its `start` and `end`
/// would concatenate to the same characters.
impl<const N: usize> PartialEq<[char; N]> for RawSubscript {
    fn eq(&self, other: &[char; N]) -> bool {
        matches!(self, RawSubscript::Indices(indices) if indices == other)
    }
}

/// Render the subscript as its index characters.
///
/// The ellipsis is written as `___` (three underscores) rather than `...`, so
/// `i...j` is displayed as `i___j` and a bare `...` as `___`.
/// NOTE(review): presumably chosen so the text can be embedded in generated
/// identifiers — confirm before relying on it.
impl fmt::Display for RawSubscript {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Split into the part before the ellipsis and, if any, the part after it.
        let (head, tail) = match self {
            RawSubscript::Indices(indices) => (indices, None),
            RawSubscript::Ellipsis { start, end } => (start, Some(end)),
        };
        head.iter().try_for_each(|c| write!(f, "{}", c))?;
        if let Some(tail) = tail {
            write!(f, "___")?;
            tail.iter().try_for_each(|c| write!(f, "{}", c))?;
        }
        Ok(())
    }
}
75
76/// Einsum subscripts, e.g. `ij,jk->ik`
77#[derive(Debug, PartialEq, Eq)]
78pub struct RawSubscripts {
79    /// Input subscript, `ij` and `jk`
80    pub inputs: Vec<RawSubscript>,
81    /// Output subscript. This may be empty for "implicit mode".
82    pub output: Option<RawSubscript>,
83}
84
85impl std::str::FromStr for RawSubscripts {
86    type Err = Error;
87    fn from_str(input: &str) -> Result<Self> {
88        use nom::Finish;
89        if let Ok((_, ss)) = subscripts(input).finish() {
90            Ok(ss)
91        } else {
92            bail!("Invalid subscripts: {}", input);
93        }
94    }
95}
96
97/// subscripts = [subscript] {`,` [subscript]} \[ `->` [subscript] \]
98pub fn subscripts(input: &str) -> IResult<&str, RawSubscripts> {
99    let (input, _head) = multispace0(input)?;
100    let (input, inputs) = separated_list1(tuple((multispace0, char(','))), subscript)(input)?;
101    let (input, output) = opt(tuple((multispace0, tag("->"), multispace0, subscript))
102        .map(|(_space_pre, _arrow, _space_post, output)| output))(input)?;
103    Ok((input, RawSubscripts { inputs, output }))
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use nom::Finish;

    /// Build a `RawSubscript::Ellipsis` from the indices before and after `...`.
    fn with_ellipsis(start: &[char], end: &[char]) -> RawSubscript {
        RawSubscript::Ellipsis {
            start: start.to_vec(),
            end: end.to_vec(),
        }
    }

    /// Parse `input` with [`subscript`], requiring that no input is left over.
    fn parse_subscript(input: &str) -> RawSubscript {
        let (rest, out) = subscript(input)
            .finish()
            .unwrap_or_else(|e| panic!("failed to parse subscript {:?}: {:?}", input, e));
        assert_eq!(rest, "", "unconsumed input when parsing {:?}", input);
        out
    }

    /// Parse `input` with [`subscripts`]; the unparsed remainder is not checked.
    fn parse_subscripts(input: &str) -> RawSubscripts {
        let (_rest, out) = subscripts(input)
            .finish()
            .unwrap_or_else(|e| panic!("failed to parse subscripts {:?}: {:?}", input, e));
        out
    }

    #[test]
    fn test_subscript() {
        assert_eq!(
            parse_subscript("ijk"),
            RawSubscript::Indices(vec!['i', 'j', 'k'])
        );
        assert_eq!(parse_subscript("..."), with_ellipsis(&[], &[]));
        assert_eq!(parse_subscript("i..."), with_ellipsis(&['i'], &[]));
        assert_eq!(parse_subscript("...j"), with_ellipsis(&[], &['j']));
        assert_eq!(parse_subscript("i...j"), with_ellipsis(&['i'], &['j']));
    }

    #[test]
    fn test_operator() {
        let ij_jk = vec![
            RawSubscript::Indices(vec!['i', 'j']),
            RawSubscript::Indices(vec!['j', 'k']),
        ];
        let explicit = RawSubscripts {
            inputs: ij_jk.clone(),
            output: Some(RawSubscript::Indices(vec!['i', 'k'])),
        };

        for &input in &[
            "ij,jk->ik",
            // with space
            " ij,jk->ik",
            "i j,jk->ik",
            "ij ,jk->ik",
            "ij, jk->ik",
            "ij,j k->ik",
            "ij,jk ->ik",
            "ij,jk-> ik",
            "ij,jk->i k",
        ] {
            assert_eq!(parse_subscripts(input), explicit, "input: {:?}", input);
        }

        // implicit mode
        assert_eq!(
            parse_subscripts("ij,jk"),
            RawSubscripts {
                inputs: ij_jk,
                output: None,
            }
        );

        // with ...
        assert_eq!(
            parse_subscripts("i...,i...->..."),
            RawSubscripts {
                inputs: vec![with_ellipsis(&['i'], &[]), with_ellipsis(&['i'], &[])],
                output: Some(with_ellipsis(&[], &[])),
            }
        );
    }
}