pulldown_latex/parser/macros.rs

1use std::collections::HashMap;
2
3use crate::parser::{ErrorKind, InnerResult, Token};
4
5use super::{lex, Argument};
6
7#[derive(Debug)]
8pub struct MacroContext<'input> {
9    definitions: HashMap<&'input str, Definition<'input>>,
10}
11
12impl<'input> MacroContext<'input> {
13    pub fn new() -> Self {
14        Self {
15            definitions: HashMap::new(),
16        }
17    }
18
19    // Problem 20.7 shows a convoluted example
20    //
21    // To check:
22    // - Strip braces of arguments
23
24    /// Define a new macro, from its name, parameter text, and replacement text.
25    ///
26    /// - The replacement text must be properly balanced.
27    /// - The parameter text must not contain '{' or '}'.
28    pub(crate) fn define(
29        &mut self,
30        name: &'input str,
31        mut parameter_text: &'input str,
32        replacement_text: &'input str,
33    ) -> InnerResult<()> {
34        // Check for the '#{' rule of the last parameter (TeXBook p. 204).
35        let last_param_brace_delimited = parameter_text.ends_with('#');
36        if last_param_brace_delimited {
37            // We know the parameter text is at least 1 character long, and the character in
38            // question is ASCII so we are fine slicing.
39            parameter_text = &parameter_text[..parameter_text.len() - 1];
40        };
41
42        let mut parameters = parameter_text.split('#').enumerate();
43
44        let prefix = parameters
45            .next()
46            .expect("split always yields at least one element")
47            .1;
48        let prefix = if prefix.is_empty() {
49            None
50        } else {
51            Some(prefix)
52        };
53
54        // Parse the arguments, making sure that they are in order and that the number of arguments
55        // is less than 10.
56        let parameters: Vec<_> = parameters
57            .map(|(i, arg)| -> InnerResult<Parameter> {
58                let mut chars = arg.chars();
59                let param_index = chars
60                    .next()
61                    .and_then(|c| c.is_ascii_digit().then_some(c as u8 - b'0'))
62                    .ok_or(ErrorKind::StandaloneHashSign)?;
63                if param_index != i as u8 {
64                    return Err(ErrorKind::IncorrectMacroParams(param_index, i as u8));
65                };
66                let suffix = chars.as_str();
67                Ok(if suffix.is_empty() {
68                    None
69                } else {
70                    Some(suffix)
71                })
72            })
73            .collect::<InnerResult<Vec<_>>>()?;
74
75        let replacement = parse_replacement_text(replacement_text, parameters.len() as u8)?;
76
77        self.definitions.insert(
78            name,
79            Definition::Macro(MacroDef {
80                prefix,
81                last_param_brace_delimited,
82                parameters,
83                replacement,
84            }),
85        );
86        Ok(())
87    }
88
89    pub(crate) fn contains(&self, name: &str) -> bool {
90        self.definitions.contains_key(name)
91    }
92
93    /// Assign a new control sequence to a token.
94    pub(crate) fn assign(&mut self, name: &'input str, alias_for: Token<'input>) {
95        self.definitions.insert(name, Definition::Alias(alias_for));
96    }
97
98    /// The argument count must be less than 9 if the optional argument is None, and less than 8 if
99    /// the optional argument is Some.
100    pub(crate) fn insert_command(
101        &mut self,
102        name: &'input str,
103        argument_count: u8,
104        first_arg_default: Option<&'input str>,
105        replacement: &'input str,
106    ) -> InnerResult<()> {
107        let replacement = parse_replacement_text(replacement, argument_count)?;
108        self.definitions.insert(
109            name,
110            Definition::Command(CommandDef {
111                argument_count,
112                first_arg_default,
113                replacement,
114            }),
115        );
116        Ok(())
117    }
118
119    /// If a macro is successfully expanded, the rest of the input must be discarded and the
120    /// returned string, which will contain the rest of the input appended, must be used instead.
121    ///
122    /// Along with the expanded string, the function returns the number of characters consumed
123    /// from the input.
124    pub(crate) fn try_expand_in(
125        &self,
126        name: &'input str,
127        input_rest: &'input str,
128        storage: &'input bumpalo::Bump,
129    ) -> Option<InnerResult<(&'input str, usize)>> {
130        Some(self.expand_definition_in(self.definitions.get(name)?, input_rest, storage))
131    }
132
133    /// Expand a definition in the storage and return the full expanded string along with the
134    /// length of the input that was consumed by the definition.
135    fn expand_definition_in(
136        &self,
137        definition: &Definition<'input>,
138        mut input_rest: &'input str,
139        storage: &'input bumpalo::Bump,
140    ) -> InnerResult<(&'input str, usize)> {
141        let original_length = input_rest.len();
142        Ok(match definition {
143            Definition::Macro(MacroDef {
144                prefix,
145                parameters,
146                last_param_brace_delimited,
147                replacement,
148            }) => {
149                if let Some(prefix) = prefix {
150                    input_rest = input_rest
151                        .strip_prefix(prefix)
152                        .ok_or(ErrorKind::IncorrectMacroPrefix)?;
153                };
154
155                let mut arguments: Vec<Result<Argument, &str>> =
156                    Vec::with_capacity(parameters.len());
157                for (index, param) in parameters.iter().enumerate() {
158                    if index == parameters.len() - 1 && *last_param_brace_delimited {
159                        if let Some(suffix) = param {
160                            let full_suffix = format!("{}{{", suffix);
161                            let (before, _) = input_rest
162                                .split_once(&full_suffix)
163                                .ok_or(ErrorKind::MacroSuffixNotFound)?;
164                            arguments.push(Err(before));
165                            input_rest = &input_rest[before.len()..];
166                        } else {
167                            let (before, _) = input_rest
168                                .split_once('{')
169                                .ok_or(ErrorKind::MacroSuffixNotFound)?;
170                            arguments.push(Err(before));
171                            input_rest = &input_rest[before.len()..];
172                        }
173                        break;
174                    }
175                    match param {
176                        None => arguments.push(Ok(lex::argument(&mut input_rest)?)),
177                        Some(suffix) => {
178                            arguments.push(Err(lex::content_with_suffix(&mut input_rest, suffix)?));
179                        }
180                    }
181                }
182
183                (
184                    expand_replacement(storage, replacement, &arguments, input_rest),
185                    original_length - input_rest.len(),
186                )
187            }
188            Definition::Alias(Token::Character(c)) => {
189                let ch = char::from(*c);
190                let mut string = bumpalo::collections::String::with_capacity_in(
191                    ch.len_utf8() + input_rest.len(),
192                    storage,
193                );
194                string.push(ch);
195                string.push_str(input_rest);
196                (string.into_bump_str(), 0)
197            }
198            Definition::Alias(Token::ControlSequence(cs)) => {
199                let mut string = bumpalo::collections::String::with_capacity_in(
200                    cs.len() + input_rest.len() + 1,
201                    storage,
202                );
203                string.push('\\');
204                string.push_str(cs);
205                string.push_str(input_rest);
206                (string.into_bump_str(), 0)
207            }
208            Definition::Command(CommandDef {
209                argument_count,
210                first_arg_default,
211                replacement,
212            }) => {
213                let mut arguments = Vec::with_capacity(*argument_count as usize);
214
215                if let Some(default_argument) = first_arg_default {
216                    arguments.push(Ok(Argument::Group(
217                        lex::optional_argument(&mut input_rest).unwrap_or(default_argument),
218                    )));
219                }
220
221                (0..(*argument_count - first_arg_default.is_some() as u8)).try_for_each(|_| {
222                    arguments.push(Ok(lex::argument(&mut input_rest)?));
223                    Ok(())
224                })?;
225
226                (
227                    expand_replacement(storage, replacement, &arguments, input_rest),
228                    original_length - input_rest.len(),
229                )
230            }
231        })
232    }
233}
234
235fn parse_replacement_text(
236    replacement_text: &str,
237    parameter_count: u8,
238) -> InnerResult<Vec<ReplacementToken>> {
239    let mut replacement_splits = replacement_text.split_inclusive('#').peekable();
240    let mut replacement_tokens: Vec<ReplacementToken> = Vec::new();
241
242    while let Some(split) = replacement_splits.next() {
243        replacement_tokens.push(ReplacementToken::String(split));
244
245        let next_split = match replacement_splits.peek_mut() {
246            Some(next_split) => next_split,
247            None if split.is_empty() => {
248                replacement_tokens.pop();
249                break;
250            }
251            None if *split
252                .as_bytes()
253                .last()
254                .expect("checked for not none in previous branch")
255                != b'#' =>
256            {
257                break;
258            }
259            None => {
260                return Err(ErrorKind::StandaloneHashSign);
261            }
262        };
263        let first_char = next_split
264            .chars()
265            .next()
266            .expect("split inclusive always yields at least one char per element");
267        if first_char == '#' {
268            // skip the next split since it will contain the second '#'
269            replacement_splits.next();
270        } else if first_char.is_ascii_digit() {
271            let param_index = first_char as u8 - b'0';
272            if param_index > parameter_count || param_index == 0 {
273                return Err(ErrorKind::IncorrectReplacementParams(
274                    param_index,
275                    parameter_count,
276                ));
277            };
278
279            match replacement_tokens
280                .last_mut()
281                .expect("was pushed previously in the loop")
282            {
283                ReplacementToken::String(s) => {
284                    if s.len() == 1 {
285                        replacement_tokens.pop();
286                    } else {
287                        *s = &s[..s.len() - 1];
288                    }
289                }
290                _ => unreachable!(),
291            }
292
293            replacement_tokens.push(ReplacementToken::Parameter(param_index));
294            // Make it so that the next split wont begin with the digit.
295            *next_split = &next_split[1..];
296        } else {
297            return Err(ErrorKind::StandaloneHashSign);
298        }
299    }
300
301    replacement_tokens.shrink_to_fit();
302    Ok(replacement_tokens)
303}
304
305fn expand_replacement<'store>(
306    storage: &'store bumpalo::Bump,
307    replacement: &[ReplacementToken],
308    // If Ok, its a regular argument, if Err, its a raw string to be inserted.
309    arguments: &[Result<Argument, &str>],
310    input_rest: &str,
311) -> &'store str {
312    let mut replacement_string = bumpalo::collections::String::new_in(storage);
313
314    for token in replacement {
315        match token {
316            ReplacementToken::Parameter(idx) => match &arguments[*idx as usize - 1] {
317                Ok(Argument::Token(Token::Character(ch))) => {
318                    replacement_string.push(char::from(*ch));
319                }
320                Ok(Argument::Token(Token::ControlSequence(cs))) => {
321                    replacement_string.push('\\');
322                    replacement_string.push_str(cs);
323                }
324                Ok(Argument::Group(group)) => {
325                    replacement_string.push('{');
326                    replacement_string.push_str(group);
327                    replacement_string.push('}');
328                }
329                Err(str) => {
330                    replacement_string.push_str(str);
331                }
332            },
333            ReplacementToken::String(str) => {
334                replacement_string.push_str(str);
335            }
336        }
337    }
338
339    replacement_string.push_str(input_rest);
340    replacement_string.shrink_to_fit();
341
342    replacement_string.into_bump_str()
343}
344
345impl<'input> Default for MacroContext<'input> {
346    fn default() -> Self {
347        Self::new()
348    }
349}
350
351#[derive(Debug)]
352struct MacroDef<'a> {
353    prefix: Option<&'a str>,
354    parameters: Vec<Parameter<'a>>,
355    last_param_brace_delimited: bool,
356    replacement: Vec<ReplacementToken<'a>>,
357}
358
359#[derive(Debug)]
360struct CommandDef<'a> {
361    argument_count: u8,
362    first_arg_default: Option<&'a str>,
363    replacement: Vec<ReplacementToken<'a>>,
364}
365
/// A macro parameter: `Some(suffix)` when the parameter is delimited by the text following it in
/// the parameter text, `None` when it is undelimited.
type Parameter<'a> = Option<&'a str>;
368
/// One piece of a parsed replacement text.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ReplacementToken<'a> {
    /// Placeholder for argument number `n` (1-based, written `#n` in the source).
    Parameter(u8),
    /// Literal text copied into the expansion unchanged (`##` already reduced to `#`).
    String(&'a str),
}
374
375#[derive(Debug)]
376enum Definition<'a> {
377    Macro(MacroDef<'a>),
378    Alias(Token<'a>),
379    Command(CommandDef<'a>),
380}
381
#[cfg(test)]
mod tests {
    use super::{Definition, MacroContext, MacroDef, ReplacementToken};

    /// Define `name`, printing the error before failing the test if the definition is rejected.
    fn define_ok<'a>(
        ctx: &mut MacroContext<'a>,
        name: &'a str,
        parameter_text: &'a str,
        replacement_text: &'a str,
    ) {
        ctx.define(name, parameter_text, replacement_text)
            .map_err(|e| eprintln!("{e}"))
            .unwrap();
    }

    /// Look up `name`, failing the test unless it is bound to a macro definition.
    fn macro_def<'c, 'a>(ctx: &'c MacroContext<'a>, name: &str) -> &'c MacroDef<'a> {
        match ctx.definitions.get(name).unwrap() {
            Definition::Macro(def) => def,
            _ => unreachable!(),
        }
    }

    #[test]
    fn no_params() {
        let mut ctx = MacroContext::new();
        define_ok(&mut ctx, "foo", "", "\\this {} is a ## test");

        let def = macro_def(&ctx, "foo");
        assert_eq!(def.prefix, None);
        assert!(def.parameters.is_empty());

        let literal_text: String = def
            .replacement
            .iter()
            .filter_map(|token| match token {
                ReplacementToken::String(s) => Some(*s),
                ReplacementToken::Parameter(_) => None,
            })
            .collect();
        assert_eq!(literal_text, "\\this {} is a # test");
    }

    #[test]
    fn with_params() {
        let mut ctx = MacroContext::new();
        define_ok(
            &mut ctx,
            "foo",
            "this#1test#2. should #",
            "\\this {} is a ## test#1",
        );

        let def = macro_def(&ctx, "foo");
        assert_eq!(def.prefix, Some("this"));
        assert_eq!(def.parameters, vec![Some("test"), Some(". should ")]);
        assert!(def.last_param_brace_delimited);
        assert_eq!(
            def.replacement,
            vec![
                ReplacementToken::String("\\this {} is a #"),
                ReplacementToken::String(" test"),
                ReplacementToken::Parameter(1),
            ]
        );
    }

    // A complex example from the TeXBook (problem 20.7):
    // \def\cs AB#1#2C$#3\$ {#3{ab#1}#1 c##\x #2}
    #[test]
    fn texbook() {
        let mut ctx = MacroContext::new();
        define_ok(&mut ctx, "cs", r"AB#1#2C$#3\$ ", r"#3{ab#1}#1 c##\x #2");

        let def = macro_def(&ctx, "cs");
        assert_eq!(def.prefix, Some("AB"));
        assert_eq!(def.parameters, vec![None, Some("C$"), Some(r"\$ ")]);
        assert_eq!(
            def.replacement,
            vec![
                ReplacementToken::Parameter(3),
                ReplacementToken::String(r"{ab"),
                ReplacementToken::Parameter(1),
                ReplacementToken::String(r"}"),
                ReplacementToken::Parameter(1),
                ReplacementToken::String(r" c#"),
                ReplacementToken::String(r"\x "),
                ReplacementToken::Parameter(2),
            ]
        );
    }

    #[test]
    fn brace_delim_no_text() {
        let mut ctx = MacroContext::new();
        define_ok(&mut ctx, "foo", "#", "2 + 2 = 4");

        let def = macro_def(&ctx, "foo");
        assert_eq!(def.prefix, None);
        assert!(def.parameters.is_empty());
        assert!(def.last_param_brace_delimited);
        assert_eq!(def.replacement, vec![ReplacementToken::String("2 + 2 = 4")]);
    }
}