1use std::collections::HashMap;
2
3use crate::parser::{ErrorKind, InnerResult, Token};
4
5use super::{lex, Argument};
6
7#[derive(Debug)]
8pub struct MacroContext<'input> {
9 definitions: HashMap<&'input str, Definition<'input>>,
10}
11
12impl<'input> MacroContext<'input> {
13 pub fn new() -> Self {
14 Self {
15 definitions: HashMap::new(),
16 }
17 }
18
19 pub(crate) fn define(
29 &mut self,
30 name: &'input str,
31 mut parameter_text: &'input str,
32 replacement_text: &'input str,
33 ) -> InnerResult<()> {
34 let last_param_brace_delimited = parameter_text.ends_with('#');
36 if last_param_brace_delimited {
37 parameter_text = ¶meter_text[..parameter_text.len() - 1];
40 };
41
42 let mut parameters = parameter_text.split('#').enumerate();
43
44 let prefix = parameters
45 .next()
46 .expect("split always yields at least one element")
47 .1;
48 let prefix = if prefix.is_empty() {
49 None
50 } else {
51 Some(prefix)
52 };
53
54 let parameters: Vec<_> = parameters
57 .map(|(i, arg)| -> InnerResult<Parameter> {
58 let mut chars = arg.chars();
59 let param_index = chars
60 .next()
61 .and_then(|c| c.is_ascii_digit().then_some(c as u8 - b'0'))
62 .ok_or(ErrorKind::StandaloneHashSign)?;
63 if param_index != i as u8 {
64 return Err(ErrorKind::IncorrectMacroParams(param_index, i as u8));
65 };
66 let suffix = chars.as_str();
67 Ok(if suffix.is_empty() {
68 None
69 } else {
70 Some(suffix)
71 })
72 })
73 .collect::<InnerResult<Vec<_>>>()?;
74
75 let replacement = parse_replacement_text(replacement_text, parameters.len() as u8)?;
76
77 self.definitions.insert(
78 name,
79 Definition::Macro(MacroDef {
80 prefix,
81 last_param_brace_delimited,
82 parameters,
83 replacement,
84 }),
85 );
86 Ok(())
87 }
88
89 pub(crate) fn contains(&self, name: &str) -> bool {
90 self.definitions.contains_key(name)
91 }
92
93 pub(crate) fn assign(&mut self, name: &'input str, alias_for: Token<'input>) {
95 self.definitions.insert(name, Definition::Alias(alias_for));
96 }
97
98 pub(crate) fn insert_command(
101 &mut self,
102 name: &'input str,
103 argument_count: u8,
104 first_arg_default: Option<&'input str>,
105 replacement: &'input str,
106 ) -> InnerResult<()> {
107 let replacement = parse_replacement_text(replacement, argument_count)?;
108 self.definitions.insert(
109 name,
110 Definition::Command(CommandDef {
111 argument_count,
112 first_arg_default,
113 replacement,
114 }),
115 );
116 Ok(())
117 }
118
119 pub(crate) fn try_expand_in(
125 &self,
126 name: &'input str,
127 input_rest: &'input str,
128 storage: &'input bumpalo::Bump,
129 ) -> Option<InnerResult<(&'input str, usize)>> {
130 Some(self.expand_definition_in(self.definitions.get(name)?, input_rest, storage))
131 }
132
133 fn expand_definition_in(
136 &self,
137 definition: &Definition<'input>,
138 mut input_rest: &'input str,
139 storage: &'input bumpalo::Bump,
140 ) -> InnerResult<(&'input str, usize)> {
141 let original_length = input_rest.len();
142 Ok(match definition {
143 Definition::Macro(MacroDef {
144 prefix,
145 parameters,
146 last_param_brace_delimited,
147 replacement,
148 }) => {
149 if let Some(prefix) = prefix {
150 input_rest = input_rest
151 .strip_prefix(prefix)
152 .ok_or(ErrorKind::IncorrectMacroPrefix)?;
153 };
154
155 let mut arguments: Vec<Result<Argument, &str>> =
156 Vec::with_capacity(parameters.len());
157 for (index, param) in parameters.iter().enumerate() {
158 if index == parameters.len() - 1 && *last_param_brace_delimited {
159 if let Some(suffix) = param {
160 let full_suffix = format!("{}{{", suffix);
161 let (before, _) = input_rest
162 .split_once(&full_suffix)
163 .ok_or(ErrorKind::MacroSuffixNotFound)?;
164 arguments.push(Err(before));
165 input_rest = &input_rest[before.len()..];
166 } else {
167 let (before, _) = input_rest
168 .split_once('{')
169 .ok_or(ErrorKind::MacroSuffixNotFound)?;
170 arguments.push(Err(before));
171 input_rest = &input_rest[before.len()..];
172 }
173 break;
174 }
175 match param {
176 None => arguments.push(Ok(lex::argument(&mut input_rest)?)),
177 Some(suffix) => {
178 arguments.push(Err(lex::content_with_suffix(&mut input_rest, suffix)?));
179 }
180 }
181 }
182
183 (
184 expand_replacement(storage, replacement, &arguments, input_rest),
185 original_length - input_rest.len(),
186 )
187 }
188 Definition::Alias(Token::Character(c)) => {
189 let ch = char::from(*c);
190 let mut string = bumpalo::collections::String::with_capacity_in(
191 ch.len_utf8() + input_rest.len(),
192 storage,
193 );
194 string.push(ch);
195 string.push_str(input_rest);
196 (string.into_bump_str(), 0)
197 }
198 Definition::Alias(Token::ControlSequence(cs)) => {
199 let mut string = bumpalo::collections::String::with_capacity_in(
200 cs.len() + input_rest.len() + 1,
201 storage,
202 );
203 string.push('\\');
204 string.push_str(cs);
205 string.push_str(input_rest);
206 (string.into_bump_str(), 0)
207 }
208 Definition::Command(CommandDef {
209 argument_count,
210 first_arg_default,
211 replacement,
212 }) => {
213 let mut arguments = Vec::with_capacity(*argument_count as usize);
214
215 if let Some(default_argument) = first_arg_default {
216 arguments.push(Ok(Argument::Group(
217 lex::optional_argument(&mut input_rest).unwrap_or(default_argument),
218 )));
219 }
220
221 (0..(*argument_count - first_arg_default.is_some() as u8)).try_for_each(|_| {
222 arguments.push(Ok(lex::argument(&mut input_rest)?));
223 Ok(())
224 })?;
225
226 (
227 expand_replacement(storage, replacement, &arguments, input_rest),
228 original_length - input_rest.len(),
229 )
230 }
231 })
232 }
233}
234
235fn parse_replacement_text(
236 replacement_text: &str,
237 parameter_count: u8,
238) -> InnerResult<Vec<ReplacementToken>> {
239 let mut replacement_splits = replacement_text.split_inclusive('#').peekable();
240 let mut replacement_tokens: Vec<ReplacementToken> = Vec::new();
241
242 while let Some(split) = replacement_splits.next() {
243 replacement_tokens.push(ReplacementToken::String(split));
244
245 let next_split = match replacement_splits.peek_mut() {
246 Some(next_split) => next_split,
247 None if split.is_empty() => {
248 replacement_tokens.pop();
249 break;
250 }
251 None if *split
252 .as_bytes()
253 .last()
254 .expect("checked for not none in previous branch")
255 != b'#' =>
256 {
257 break;
258 }
259 None => {
260 return Err(ErrorKind::StandaloneHashSign);
261 }
262 };
263 let first_char = next_split
264 .chars()
265 .next()
266 .expect("split inclusive always yields at least one char per element");
267 if first_char == '#' {
268 replacement_splits.next();
270 } else if first_char.is_ascii_digit() {
271 let param_index = first_char as u8 - b'0';
272 if param_index > parameter_count || param_index == 0 {
273 return Err(ErrorKind::IncorrectReplacementParams(
274 param_index,
275 parameter_count,
276 ));
277 };
278
279 match replacement_tokens
280 .last_mut()
281 .expect("was pushed previously in the loop")
282 {
283 ReplacementToken::String(s) => {
284 if s.len() == 1 {
285 replacement_tokens.pop();
286 } else {
287 *s = &s[..s.len() - 1];
288 }
289 }
290 _ => unreachable!(),
291 }
292
293 replacement_tokens.push(ReplacementToken::Parameter(param_index));
294 *next_split = &next_split[1..];
296 } else {
297 return Err(ErrorKind::StandaloneHashSign);
298 }
299 }
300
301 replacement_tokens.shrink_to_fit();
302 Ok(replacement_tokens)
303}
304
305fn expand_replacement<'store>(
306 storage: &'store bumpalo::Bump,
307 replacement: &[ReplacementToken],
308 arguments: &[Result<Argument, &str>],
310 input_rest: &str,
311) -> &'store str {
312 let mut replacement_string = bumpalo::collections::String::new_in(storage);
313
314 for token in replacement {
315 match token {
316 ReplacementToken::Parameter(idx) => match &arguments[*idx as usize - 1] {
317 Ok(Argument::Token(Token::Character(ch))) => {
318 replacement_string.push(char::from(*ch));
319 }
320 Ok(Argument::Token(Token::ControlSequence(cs))) => {
321 replacement_string.push('\\');
322 replacement_string.push_str(cs);
323 }
324 Ok(Argument::Group(group)) => {
325 replacement_string.push('{');
326 replacement_string.push_str(group);
327 replacement_string.push('}');
328 }
329 Err(str) => {
330 replacement_string.push_str(str);
331 }
332 },
333 ReplacementToken::String(str) => {
334 replacement_string.push_str(str);
335 }
336 }
337 }
338
339 replacement_string.push_str(input_rest);
340 replacement_string.shrink_to_fit();
341
342 replacement_string.into_bump_str()
343}
344
345impl<'input> Default for MacroContext<'input> {
346 fn default() -> Self {
347 Self::new()
348 }
349}
350
351#[derive(Debug)]
352struct MacroDef<'a> {
353 prefix: Option<&'a str>,
354 parameters: Vec<Parameter<'a>>,
355 last_param_brace_delimited: bool,
356 replacement: Vec<ReplacementToken<'a>>,
357}
358
359#[derive(Debug)]
360struct CommandDef<'a> {
361 argument_count: u8,
362 first_arg_default: Option<&'a str>,
363 replacement: Vec<ReplacementToken<'a>>,
364}
365
/// A macro parameter: `Some(delimiter)` when the argument extends up to the given delimiter
/// text, `None` when the parameter has no delimiter.
type Parameter<'a> = Option<&'a str>;
368
/// One element of a parsed macro or command body.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ReplacementToken<'a> {
    /// Reference to the argument with this number; numbering starts at 1.
    Parameter(u8),
    /// Literal text, copied unchanged into the expansion.
    String(&'a str),
}
374
375#[derive(Debug)]
376enum Definition<'a> {
377 Macro(MacroDef<'a>),
378 Alias(Token<'a>),
379 Command(CommandDef<'a>),
380}
381
#[cfg(test)]
mod tests {
    use super::{Definition, MacroContext, MacroDef, ReplacementToken};

    /// Builds a context that contains a single macro defined from the given texts.
    ///
    /// If the definition is rejected, the error is printed before the `unwrap` panics.
    fn context_with(
        name: &'static str,
        parameter_text: &'static str,
        replacement_text: &'static str,
    ) -> MacroContext<'static> {
        let mut ctx = MacroContext::new();
        ctx.define(name, parameter_text, replacement_text)
            .map_err(|e| eprintln!("{e}"))
            .unwrap();
        ctx
    }

    /// Looks up `name`, which must have been registered through `MacroContext::define`.
    fn macro_def<'ctx>(ctx: &'ctx MacroContext<'static>, name: &str) -> &'ctx MacroDef<'static> {
        match ctx.definitions.get(name).unwrap() {
            Definition::Macro(def) => def,
            _ => unreachable!(),
        }
    }

    #[test]
    fn no_params() {
        let ctx = context_with("foo", "", "\\this {} is a ## test");
        let def = macro_def(&ctx, "foo");

        assert_eq!(def.prefix, None);
        assert!(def.parameters.is_empty());

        // Only the literal parts matter here; `##` must have collapsed into a single `#`.
        let mut literal_text = String::new();
        for token in &def.replacement {
            if let ReplacementToken::String(s) = token {
                literal_text.push_str(s);
            }
        }
        assert_eq!(literal_text, "\\this {} is a # test");
    }

    #[test]
    fn with_params() {
        let ctx = context_with("foo", "this#1test#2. should #", "\\this {} is a ## test#1");
        let def = macro_def(&ctx, "foo");

        assert_eq!(def.prefix, Some("this"));
        assert_eq!(def.parameters, vec![Some("test"), Some(". should ")]);
        assert!(def.last_param_brace_delimited);

        let expected = vec![
            ReplacementToken::String("\\this {} is a #"),
            ReplacementToken::String(" test"),
            ReplacementToken::Parameter(1),
        ];
        assert_eq!(def.replacement, expected);
    }

    // Mixes a prefix, an undelimited parameter and two delimited ones; judging by its name,
    // the example presumably comes from The TeXbook.
    #[test]
    fn texbook() {
        let ctx = context_with("cs", r"AB#1#2C$#3\$ ", r"#3{ab#1}#1 c##\x #2");
        let def = macro_def(&ctx, "cs");

        assert_eq!(def.prefix, Some("AB"));
        assert_eq!(def.parameters, vec![None, Some("C$"), Some(r"\$ ")]);

        let expected = vec![
            ReplacementToken::Parameter(3),
            ReplacementToken::String(r"{ab"),
            ReplacementToken::Parameter(1),
            ReplacementToken::String(r"}"),
            ReplacementToken::Parameter(1),
            ReplacementToken::String(r" c#"),
            ReplacementToken::String(r"\x "),
            ReplacementToken::Parameter(2),
        ];
        assert_eq!(def.replacement, expected);
    }

    #[test]
    fn brace_delim_no_text() {
        let ctx = context_with("foo", "#", "2 + 2 = 4");
        let def = macro_def(&ctx, "foo");

        assert_eq!(def.prefix, None);
        assert_eq!(def.parameters, vec![]);
        assert!(def.last_param_brace_delimited);
        assert_eq!(def.replacement, vec![ReplacementToken::String("2 + 2 = 4")]);
    }
}