srcsrv/
ast.rs

1use crate::errors::{EvalError, ParseError};
2use std::result::Result;
3
4use memchr::{memchr, memchr2};
5
/// A node of a parsed template string.
///
/// When produced by [`AstNode::parse`], nodes borrow their text from the input
/// string. Use [`AstNode::eval`] to turn a tree into its final string value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AstNode<'a> {
    /// String concatenation of the evaluated child nodes, in order.
    Sequence(Vec<AstNode<'a>>),
    /// A literal string, emitted unchanged.
    LiteralString(&'a str),
    /// `%name%`: substitute with the value of the variable with this name.
    Variable(&'a str),
    /// `%fnvar%(arg)`: evaluate `arg`, then substitute with the value of the
    /// variable whose name is that result.
    FnVar(Box<AstNode<'a>>),
    /// `%fnbksl%(arg)`: substitute with the evaluated argument, with every `/`
    /// replaced by `\`.
    FnBackslash(Box<AstNode<'a>>),
    /// `%fnfile%(arg)`: substitute with the part of the evaluated argument after
    /// its last `\`, or the whole value if it contains no `\`. Forward slashes
    /// are not treated as separators.
    FnFile(Box<AstNode<'a>>),
}
22
23impl<'a> AstNode<'a> {
24    pub fn parse(s: &'a str) -> Result<AstNode<'a>, ParseError> {
25        if s.is_empty() {
26            return Ok(AstNode::LiteralString(""));
27        }
28        let s = s.as_bytes();
29        let (node, _rest) = Self::parse_all(s, false)?;
30        Ok(node)
31    }
32
33    fn parse_all(s: &'a [u8], nested: bool) -> Result<(AstNode<'a>, &'a [u8]), ParseError> {
34        let (node, rest) = Self::parse_one(s, nested)?;
35        if rest.is_empty() || (nested && rest[0] == b')') {
36            return Ok((node, rest));
37        }
38
39        let mut nodes = vec![node];
40        let mut rest = rest;
41        loop {
42            let (node, r) = Self::parse_one(rest, nested)?;
43            nodes.push(node);
44            rest = r;
45            if rest.is_empty() || (nested && rest[0] == b')') {
46                return Ok((AstNode::Sequence(nodes), rest));
47            }
48        }
49    }
50
51    // s must not be empty
52    fn parse_one(s: &'a [u8], nested: bool) -> Result<(AstNode<'a>, &'a [u8]), ParseError> {
53        if s[0] != b'%' {
54            // We have a literal at the beginning.
55            let literal_end = if nested {
56                memchr2(b'%', b')', s)
57            } else {
58                memchr(b'%', s)
59            };
60            let literal_end = literal_end.unwrap_or(s.len());
61            let (literal, rest) = s.split_at(literal_end);
62            let string = std::str::from_utf8(literal).map_err(|_| ParseError::InvalidUtf8)?;
63            return Ok((AstNode::LiteralString(string), rest));
64        }
65
66        // We start with a %.
67        let s = &s[1..];
68        let second_percent_pos = memchr(b'%', s).ok_or(ParseError::MissingPercent)?;
69        let rest = &s[second_percent_pos + 1..];
70        let var_name =
71            std::str::from_utf8(&s[..second_percent_pos]).map_err(|_| ParseError::InvalidUtf8)?;
72        match var_name.to_ascii_lowercase().as_str() {
73            "fnvar" => {
74                let (node, rest) = Self::try_parse_args(rest, "fnvar")?;
75                Ok((AstNode::FnVar(Box::new(node)), rest))
76            }
77            "fnbksl" => {
78                let (node, rest) = Self::try_parse_args(rest, "fnbksl")?;
79                Ok((AstNode::FnBackslash(Box::new(node)), rest))
80            }
81            "fnfile" => {
82                let (node, rest) = Self::try_parse_args(rest, "fnfile")?;
83                Ok((AstNode::FnFile(Box::new(node)), rest))
84            }
85            _ => Ok((AstNode::Variable(var_name), rest)),
86        }
87    }
88
89    fn try_parse_args(s: &'a [u8], function: &str) -> Result<(AstNode<'a>, &'a [u8]), ParseError> {
90        if s.is_empty() || s[0] != b'(' {
91            return Err(ParseError::MissingOpeningParen(function.to_string()));
92        }
93        let (node, rest) = Self::parse_all(&s[1..], true)?;
94        if rest.is_empty() || rest[0] != b')' {
95            return Err(ParseError::MissingClosingParen(function.to_string()));
96        }
97        Ok((node, &rest[1..]))
98    }
99
100    pub fn eval<F>(&self, f: &mut F) -> Result<String, EvalError>
101    where
102        F: FnMut(&str) -> Result<String, EvalError>,
103    {
104        match self {
105            AstNode::Sequence(nodes) => {
106                let values: Result<Vec<String>, EvalError> =
107                    nodes.iter().map(|node| node.eval(f)).collect();
108                Ok(values?.join(""))
109            }
110            AstNode::LiteralString(s) => Ok(s.to_string()),
111            AstNode::Variable(var_name) => f(var_name),
112            AstNode::FnVar(node) => {
113                let var_name = node.eval(f)?;
114                f(&var_name)
115            }
116            AstNode::FnBackslash(node) => {
117                let val = node.eval(f)?;
118                Ok(val.replace('/', "\\"))
119            }
120            AstNode::FnFile(node) => {
121                let val = node.eval(f)?;
122                match val.rsplit_once('\\') {
123                    Some((_base, file)) => Ok(file.to_string()),
124                    None => Ok(val),
125                }
126            }
127        }
128    }
129}
130
#[cfg(test)]
mod tests {
    use crate::{AstNode, ParseError};

    #[test]
    fn basic_parsing() -> Result<(), ParseError> {
        assert_eq!(AstNode::parse("")?, AstNode::LiteralString(""));
        assert_eq!(AstNode::parse("hello")?, AstNode::LiteralString("hello"));
        assert_eq!(
            AstNode::parse("hello%world%")?,
            AstNode::Sequence(vec![
                AstNode::LiteralString("hello"),
                AstNode::Variable("world")
            ])
        );
        assert_eq!(
            AstNode::parse("%hello%world")?,
            AstNode::Sequence(vec![
                AstNode::Variable("hello"),
                AstNode::LiteralString("world")
            ])
        );
        assert_eq!(
            AstNode::parse("%fnfile%(world)")?,
            AstNode::FnFile(Box::new(AstNode::LiteralString("world")))
        );
        Ok(())
    }

    #[test]
    fn function_names_are_case_insensitive() -> Result<(), ParseError> {
        assert_eq!(
            AstNode::parse("%FNBKSL%(%path%)")?,
            AstNode::FnBackslash(Box::new(AstNode::Variable("path")))
        );
        Ok(())
    }

    #[test]
    fn nested_function_arguments() -> Result<(), ParseError> {
        assert_eq!(
            AstNode::parse("%fnfile%(%fnbksl%(a/%b%))")?,
            AstNode::FnFile(Box::new(AstNode::FnBackslash(Box::new(
                AstNode::Sequence(vec![
                    AstNode::LiteralString("a/"),
                    AstNode::Variable("b"),
                ])
            ))))
        );
        Ok(())
    }

    #[test]
    fn parse_errors() {
        assert!(matches!(
            AstNode::parse("%unterminated"),
            Err(ParseError::MissingPercent)
        ));
        assert!(matches!(
            AstNode::parse("%fnvar%x"),
            Err(ParseError::MissingOpeningParen(_))
        ));
        assert!(matches!(
            AstNode::parse("%fnvar%(x"),
            Err(ParseError::MissingClosingParen(_))
        ));
    }

    #[test]
    fn evaluation() -> Result<(), ParseError> {
        let mut lookup = |name: &str| {
            Ok(match name {
                "root" => "C:/src",
                "path" => r"dir\file.rs",
                "which" => "root",
                other => panic!("unexpected variable {}", other),
            }
            .to_string())
        };

        let node = AstNode::parse("%fnbksl%(%root%/%fnfile%(%path%))")?;
        assert_eq!(
            node.eval(&mut lookup).ok().as_deref(),
            Some(r"C:\src\file.rs")
        );

        let node = AstNode::parse("%fnvar%(%which%)")?;
        assert_eq!(node.eval(&mut lookup).ok().as_deref(), Some("C:/src"));
        Ok(())
    }
}