Skip to main content

ratex_parser/
environments.rs

1use std::collections::HashMap;
2
3use crate::error::{ParseError, ParseResult};
4use crate::parse_node::{AlignSpec, AlignType, Measurement, Mode, ParseNode, StyleStr};
5use crate::parser::Parser;
6
7// ── Environment registry ─────────────────────────────────────────────────
8
9pub struct EnvContext<'a, 'b> {
10    pub mode: Mode,
11    pub env_name: String,
12    pub parser: &'a mut Parser<'b>,
13}
14
15pub type EnvHandler = fn(
16    ctx: &mut EnvContext,
17    args: Vec<ParseNode>,
18    opt_args: Vec<Option<ParseNode>>,
19) -> ParseResult<ParseNode>;
20
21pub struct EnvSpec {
22    pub num_args: usize,
23    pub num_optional_args: usize,
24    pub handler: EnvHandler,
25}
26
27pub static ENVIRONMENTS: std::sync::LazyLock<HashMap<&'static str, EnvSpec>> =
28    std::sync::LazyLock::new(|| {
29        let mut map = HashMap::new();
30        register_array(&mut map);
31        register_matrix(&mut map);
32        register_cases(&mut map);
33        register_align(&mut map);
34        register_gathered(&mut map);
35        register_equation(&mut map);
36        register_smallmatrix(&mut map);
37        register_alignat(&mut map);
38        register_subarray(&mut map);
39        map
40    });
41
42// ── ArrayConfig ──────────────────────────────────────────────────────────
43
44#[derive(Default)]
45pub struct ArrayConfig {
46    pub hskip_before_and_after: Option<bool>,
47    pub add_jot: Option<bool>,
48    pub cols: Option<Vec<AlignSpec>>,
49    pub arraystretch: Option<f64>,
50    pub col_separation_type: Option<String>,
51    pub single_row: bool,
52    pub empty_single_row: bool,
53    pub max_num_cols: Option<usize>,
54    pub leqno: Option<bool>,
55}
56
57
58// ── parseArray ───────────────────────────────────────────────────────────
59
60fn get_hlines(parser: &mut Parser) -> ParseResult<Vec<bool>> {
61    let mut hline_info = Vec::new();
62    parser.consume_spaces()?;
63
64    let mut nxt = parser.fetch()?.text.clone();
65    if nxt == "\\relax" {
66        parser.consume();
67        parser.consume_spaces()?;
68        nxt = parser.fetch()?.text.clone();
69    }
70    while nxt == "\\hline" || nxt == "\\hdashline" {
71        parser.consume();
72        hline_info.push(nxt == "\\hdashline");
73        parser.consume_spaces()?;
74        nxt = parser.fetch()?.text.clone();
75    }
76    Ok(hline_info)
77}
78
79fn d_cell_style(env_name: &str) -> Option<StyleStr> {
80    if env_name.starts_with('d') {
81        Some(StyleStr::Display)
82    } else {
83        Some(StyleStr::Text)
84    }
85}
86
87pub fn parse_array(
88    parser: &mut Parser,
89    config: ArrayConfig,
90    style: Option<StyleStr>,
91) -> ParseResult<ParseNode> {
92    parser.gullet.begin_group();
93
94    if !config.single_row {
95        parser
96            .gullet
97            .set_text_macro("\\cr", "\\\\\\relax");
98    }
99
100    let arraystretch = config.arraystretch.unwrap_or(1.0);
101
102    parser.gullet.begin_group();
103
104    let mut row: Vec<ParseNode> = Vec::new();
105    let mut body: Vec<Vec<ParseNode>> = Vec::new();
106    let mut row_gaps: Vec<Option<Measurement>> = Vec::new();
107    let mut hlines_before_row: Vec<Vec<bool>> = Vec::new();
108
109    hlines_before_row.push(get_hlines(parser)?);
110
111    loop {
112        let break_token = if config.single_row { "\\end" } else { "\\\\" };
113        let cell_body = parser.parse_expression(false, Some(break_token))?;
114        parser.gullet.end_group();
115        parser.gullet.begin_group();
116
117        let mut cell = ParseNode::OrdGroup {
118            mode: parser.mode,
119            body: cell_body,
120            semisimple: None,
121            loc: None,
122        };
123
124        if let Some(s) = style {
125            cell = ParseNode::Styling {
126                mode: parser.mode,
127                style: s,
128                body: vec![cell],
129                loc: None,
130            };
131        }
132
133        row.push(cell.clone());
134        let next = parser.fetch()?.text.clone();
135
136        if next == "&" {
137            if let Some(max) = config.max_num_cols {
138                if row.len() >= max {
139                    return Err(ParseError::msg("Too many tab characters: &"));
140                }
141            }
142            parser.consume();
143        } else if next == "\\end" {
144            // Check for trailing empty row and remove it
145            let is_empty_trailing = if let Some(s) = style {
146                if s == StyleStr::Text || s == StyleStr::Display {
147                    if let ParseNode::Styling { body: ref sb, .. } = cell {
148                        if let Some(ParseNode::OrdGroup {
149                            body: ref ob, ..
150                        }) = sb.first()
151                        {
152                            ob.is_empty()
153                        } else {
154                            false
155                        }
156                    } else {
157                        false
158                    }
159                } else {
160                    false
161                }
162            } else if let ParseNode::OrdGroup { body: ref ob, .. } = cell {
163                ob.is_empty()
164            } else {
165                false
166            };
167
168            body.push(row);
169
170            if is_empty_trailing
171                && (body.len() > 1 || !config.empty_single_row)
172            {
173                body.pop();
174            }
175
176            if hlines_before_row.len() < body.len() + 1 {
177                hlines_before_row.push(vec![]);
178            }
179            break;
180        } else if next == "\\\\" {
181            parser.consume();
182            let size = if parser.gullet.future().text != " " {
183                parser.parse_size_group(true)?
184            } else {
185                None
186            };
187            let gap = size.and_then(|s| {
188                if let ParseNode::Size { value, .. } = s {
189                    Some(value)
190                } else {
191                    None
192                }
193            });
194            row_gaps.push(gap);
195
196            body.push(row);
197            hlines_before_row.push(get_hlines(parser)?);
198            row = Vec::new();
199        } else {
200            return Err(ParseError::msg(format!(
201                "Expected & or \\\\ or \\cr or \\end, got '{}'",
202                next
203            )));
204        }
205    }
206
207    parser.gullet.end_group();
208    parser.gullet.end_group();
209
210    Ok(ParseNode::Array {
211        mode: parser.mode,
212        body,
213        row_gaps,
214        hlines_before_row,
215        cols: config.cols,
216        col_separation_type: config.col_separation_type,
217        hskip_before_and_after: config.hskip_before_and_after,
218        add_jot: config.add_jot,
219        arraystretch,
220        tags: None,
221        leqno: config.leqno,
222        is_cd: None,
223        loc: None,
224    })
225}
226
227// ── array / darray ───────────────────────────────────────────────────────
228
229fn register_array(map: &mut HashMap<&'static str, EnvSpec>) {
230    fn handle_array(
231        ctx: &mut EnvContext,
232        args: Vec<ParseNode>,
233        _opt_args: Vec<Option<ParseNode>>,
234    ) -> ParseResult<ParseNode> {
235        let colalign = match &args[0] {
236            ParseNode::OrdGroup { body, .. } => body.clone(),
237            other if other.is_symbol_node() => vec![other.clone()],
238            _ => return Err(ParseError::msg("Invalid column alignment for array")),
239        };
240
241        let mut cols = Vec::new();
242        for nde in &colalign {
243            let ca = nde
244                .symbol_text()
245                .ok_or_else(|| ParseError::msg("Expected column alignment character"))?;
246            match ca {
247                "l" | "c" | "r" => cols.push(AlignSpec {
248                    align_type: AlignType::Align,
249                    align: Some(ca.to_string()),
250                    pregap: None,
251                    postgap: None,
252                }),
253                "|" => cols.push(AlignSpec {
254                    align_type: AlignType::Separator,
255                    align: Some("|".to_string()),
256                    pregap: None,
257                    postgap: None,
258                }),
259                ":" => cols.push(AlignSpec {
260                    align_type: AlignType::Separator,
261                    align: Some(":".to_string()),
262                    pregap: None,
263                    postgap: None,
264                }),
265                _ => {
266                    return Err(ParseError::msg(format!(
267                        "Unknown column alignment: {}",
268                        ca
269                    )))
270                }
271            }
272        }
273
274        let max_num_cols = cols.len();
275        let config = ArrayConfig {
276            cols: Some(cols),
277            hskip_before_and_after: Some(true),
278            max_num_cols: Some(max_num_cols),
279            ..Default::default()
280        };
281        parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))
282    }
283
284    for name in &["array", "darray"] {
285        map.insert(
286            name,
287            EnvSpec {
288                num_args: 1,
289                num_optional_args: 0,
290                handler: handle_array,
291            },
292        );
293    }
294}
295
296// ── matrix variants ──────────────────────────────────────────────────────
297
298fn register_matrix(map: &mut HashMap<&'static str, EnvSpec>) {
299    fn handle_matrix(
300        ctx: &mut EnvContext,
301        _args: Vec<ParseNode>,
302        _opt_args: Vec<Option<ParseNode>>,
303    ) -> ParseResult<ParseNode> {
304        let base_name = ctx.env_name.replace('*', "");
305        let delimiters: Option<(&str, &str)> = match base_name.as_str() {
306            "matrix" => None,
307            "pmatrix" => Some(("(", ")")),
308            "bmatrix" => Some(("[", "]")),
309            "Bmatrix" => Some(("\\{", "\\}")),
310            "vmatrix" => Some(("|", "|")),
311            "Vmatrix" => Some(("\\Vert", "\\Vert")),
312            _ => None,
313        };
314
315        let mut col_align = "c".to_string();
316
317        // mathtools starred matrix: parse optional [l|c|r] alignment
318        if ctx.env_name.ends_with('*') {
319            ctx.parser.gullet.consume_spaces();
320            if ctx.parser.gullet.future().text == "[" {
321                ctx.parser.gullet.pop_token();
322                ctx.parser.gullet.consume_spaces();
323                let align_tok = ctx.parser.gullet.pop_token();
324                if !"lcr".contains(align_tok.text.as_str()) {
325                    return Err(ParseError::new(
326                        "Expected l or c or r".to_string(),
327                        Some(&align_tok),
328                    ));
329                }
330                col_align = align_tok.text.clone();
331                ctx.parser.gullet.consume_spaces();
332                let close = ctx.parser.gullet.pop_token();
333                if close.text != "]" {
334                    return Err(ParseError::new(
335                        "Expected ]".to_string(),
336                        Some(&close),
337                    ));
338                }
339            }
340        }
341
342        let config = ArrayConfig {
343            hskip_before_and_after: Some(false),
344            cols: Some(vec![AlignSpec {
345                align_type: AlignType::Align,
346                align: Some(col_align.clone()),
347                pregap: None,
348                postgap: None,
349            }]),
350            ..Default::default()
351        };
352
353        let mut res = parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))?;
354
355        // Fix cols to match actual number of columns
356        if let ParseNode::Array {
357            ref body,
358            ref mut cols,
359            ..
360        } = res
361        {
362            let num_cols = body.iter().map(|r| r.len()).max().unwrap_or(0);
363            *cols = Some(
364                (0..num_cols)
365                    .map(|_| AlignSpec {
366                        align_type: AlignType::Align,
367                        align: Some(col_align.to_string()),
368                        pregap: None,
369                        postgap: None,
370                    })
371                    .collect(),
372            );
373        }
374
375        match delimiters {
376            Some((left, right)) => Ok(ParseNode::LeftRight {
377                mode: ctx.mode,
378                body: vec![res],
379                left: left.to_string(),
380                right: right.to_string(),
381                right_color: None,
382                loc: None,
383            }),
384            None => Ok(res),
385        }
386    }
387
388    for name in &[
389        "matrix", "pmatrix", "bmatrix", "Bmatrix", "vmatrix", "Vmatrix",
390        "matrix*", "pmatrix*", "bmatrix*", "Bmatrix*", "vmatrix*", "Vmatrix*",
391    ] {
392        map.insert(
393            name,
394            EnvSpec {
395                num_args: 0,
396                num_optional_args: 0,
397                handler: handle_matrix,
398            },
399        );
400    }
401}
402
403// ── cases / dcases / rcases / drcases ────────────────────────────────────
404
405fn register_cases(map: &mut HashMap<&'static str, EnvSpec>) {
406    fn handle_cases(
407        ctx: &mut EnvContext,
408        _args: Vec<ParseNode>,
409        _opt_args: Vec<Option<ParseNode>>,
410    ) -> ParseResult<ParseNode> {
411        let config = ArrayConfig {
412            arraystretch: Some(1.2),
413            cols: Some(vec![
414                AlignSpec {
415                    align_type: AlignType::Align,
416                    align: Some("l".to_string()),
417                    pregap: Some(0.0),
418                    postgap: Some(1.0),
419                },
420                AlignSpec {
421                    align_type: AlignType::Align,
422                    align: Some("l".to_string()),
423                    pregap: Some(0.0),
424                    postgap: Some(0.0),
425                },
426            ]),
427            ..Default::default()
428        };
429
430        let res = parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))?;
431
432        let (left, right) = if ctx.env_name.contains('r') {
433            (".", "\\}")
434        } else {
435            ("\\{", ".")
436        };
437
438        Ok(ParseNode::LeftRight {
439            mode: ctx.mode,
440            body: vec![res],
441            left: left.to_string(),
442            right: right.to_string(),
443            right_color: None,
444            loc: None,
445        })
446    }
447
448    for name in &["cases", "dcases", "rcases", "drcases"] {
449        map.insert(
450            name,
451            EnvSpec {
452                num_args: 0,
453                num_optional_args: 0,
454                handler: handle_cases,
455            },
456        );
457    }
458}
459
460// ── align / align* / aligned / split / alignat / alignat* / alignedat ────
461
462fn handle_aligned(
463    ctx: &mut EnvContext,
464    args: Vec<ParseNode>,
465    _opt_args: Vec<Option<ParseNode>>,
466) -> ParseResult<ParseNode> {
467        let is_split = ctx.env_name == "split";
468        let is_alignat = ctx.env_name.contains("at");
469        let sep_type = if is_alignat { "alignat" } else { "align" };
470
471        let config = ArrayConfig {
472            add_jot: Some(true),
473            empty_single_row: true,
474            col_separation_type: Some(sep_type.to_string()),
475            max_num_cols: if is_split { Some(2) } else { None },
476            ..Default::default()
477        };
478
479        let mut res = parse_array(ctx.parser, config, Some(StyleStr::Display))?;
480
481        // Extract explicit column count from first arg (alignat only)
482        let mut num_maths = 0usize;
483        let mut explicit_cols = 0usize;
484        if let Some(ParseNode::OrdGroup { body, .. }) = args.first() {
485            let mut arg_str = String::new();
486            for node in body {
487                if let Some(t) = node.symbol_text() {
488                    arg_str.push_str(t);
489                }
490            }
491            if let Ok(n) = arg_str.parse::<usize>() {
492                num_maths = n;
493                explicit_cols = n * 2;
494            }
495        }
496        let is_aligned = explicit_cols == 0;
497
498        // Determine actual number of columns
499        let mut num_cols = if let ParseNode::Array { ref body, .. } = res {
500            body.iter().map(|r| r.len()).max().unwrap_or(0)
501        } else {
502            0
503        };
504
505        if let ParseNode::Array {
506            body: ref mut array_body,
507            ..
508        } = res
509        {
510            for row in array_body.iter_mut() {
511                // Prepend empty group at every even-indexed cell (2nd, 4th, ...)
512                let mut i = 1;
513                while i < row.len() {
514                    if let ParseNode::Styling {
515                        body: ref mut styling_body,
516                        ..
517                    } = row[i]
518                    {
519                        if let Some(ParseNode::OrdGroup {
520                            body: ref mut og_body,
521                            ..
522                        }) = styling_body.first_mut()
523                        {
524                            og_body.insert(
525                                0,
526                                ParseNode::OrdGroup {
527                                    mode: ctx.mode,
528                                    body: vec![],
529                                    semisimple: None,
530                                    loc: None,
531                                },
532                            );
533                        }
534                    }
535                    i += 2;
536                }
537
538                if !is_aligned {
539                    let cur_maths = row.len() / 2;
540                    if num_maths < cur_maths {
541                        return Err(ParseError::msg(format!(
542                            "Too many math in a row: expected {}, but got {}",
543                            num_maths, cur_maths
544                        )));
545                    }
546                } else if num_cols < row.len() {
547                    num_cols = row.len();
548                }
549            }
550        }
551
552        if !is_aligned {
553            num_cols = explicit_cols;
554        }
555
556        let mut cols = Vec::new();
557        for i in 0..num_cols {
558            let (align, pregap) = if i % 2 == 1 {
559                ("l", 0.0)
560            } else if i > 0 && is_aligned {
561                ("r", 1.0)
562            } else {
563                ("r", 0.0)
564            };
565            cols.push(AlignSpec {
566                align_type: AlignType::Align,
567                align: Some(align.to_string()),
568                pregap: Some(pregap),
569                postgap: Some(0.0),
570            });
571        }
572
573        if let ParseNode::Array {
574            cols: ref mut array_cols,
575            col_separation_type: ref mut array_sep_type,
576            ..
577        } = res
578        {
579            *array_cols = Some(cols);
580            *array_sep_type = Some(
581                if is_aligned { "align" } else { "alignat" }.to_string(),
582            );
583        }
584
585    Ok(res)
586}
587
588fn register_align(map: &mut HashMap<&'static str, EnvSpec>) {
589    for name in &["align", "align*", "aligned", "split"] {
590        map.insert(
591            name,
592            EnvSpec {
593                num_args: 0,
594                num_optional_args: 0,
595                handler: handle_aligned,
596            },
597        );
598    }
599}
600
601// ── gathered / gather / gather* ──────────────────────────────────────────
602
603fn register_gathered(map: &mut HashMap<&'static str, EnvSpec>) {
604    fn handle_gathered(
605        ctx: &mut EnvContext,
606        _args: Vec<ParseNode>,
607        _opt_args: Vec<Option<ParseNode>>,
608    ) -> ParseResult<ParseNode> {
609        let config = ArrayConfig {
610            cols: Some(vec![AlignSpec {
611                align_type: AlignType::Align,
612                align: Some("c".to_string()),
613                pregap: None,
614                postgap: None,
615            }]),
616            add_jot: Some(true),
617            col_separation_type: Some("gather".to_string()),
618            empty_single_row: true,
619            ..Default::default()
620        };
621        parse_array(ctx.parser, config, Some(StyleStr::Display))
622    }
623
624    for name in &["gathered", "gather", "gather*"] {
625        map.insert(
626            name,
627            EnvSpec {
628                num_args: 0,
629                num_optional_args: 0,
630                handler: handle_gathered,
631            },
632        );
633    }
634}
635
636// ── equation / equation* ─────────────────────────────────────────────────
637
638fn register_equation(map: &mut HashMap<&'static str, EnvSpec>) {
639    fn handle_equation(
640        ctx: &mut EnvContext,
641        _args: Vec<ParseNode>,
642        _opt_args: Vec<Option<ParseNode>>,
643    ) -> ParseResult<ParseNode> {
644        let config = ArrayConfig {
645            empty_single_row: true,
646            single_row: true,
647            max_num_cols: Some(1),
648            ..Default::default()
649        };
650        parse_array(ctx.parser, config, Some(StyleStr::Display))
651    }
652
653    for name in &["equation", "equation*"] {
654        map.insert(
655            name,
656            EnvSpec {
657                num_args: 0,
658                num_optional_args: 0,
659                handler: handle_equation,
660            },
661        );
662    }
663}
664
665// ── smallmatrix ──────────────────────────────────────────────────────────
666
667fn register_smallmatrix(map: &mut HashMap<&'static str, EnvSpec>) {
668    fn handle_smallmatrix(
669        ctx: &mut EnvContext,
670        _args: Vec<ParseNode>,
671        _opt_args: Vec<Option<ParseNode>>,
672    ) -> ParseResult<ParseNode> {
673        let config = ArrayConfig {
674            arraystretch: Some(0.5),
675            ..Default::default()
676        };
677        let mut res = parse_array(ctx.parser, config, Some(StyleStr::Script))?;
678        if let ParseNode::Array {
679            ref mut col_separation_type,
680            ..
681        } = res
682        {
683            *col_separation_type = Some("small".to_string());
684        }
685        Ok(res)
686    }
687
688    map.insert(
689        "smallmatrix",
690        EnvSpec {
691            num_args: 0,
692            num_optional_args: 0,
693            handler: handle_smallmatrix,
694        },
695    );
696}
697
698// ── alignat / alignat* / alignedat ──────────────────────────────────────
699
700fn register_alignat(map: &mut HashMap<&'static str, EnvSpec>) {
701    for name in &["alignat", "alignat*", "alignedat"] {
702        map.insert(
703            name,
704            EnvSpec {
705                num_args: 1,
706                num_optional_args: 0,
707                handler: handle_aligned,
708            },
709        );
710    }
711}
712
713// ── subarray ────────────────────────────────────────────────────────────
714
715fn register_subarray(map: &mut HashMap<&'static str, EnvSpec>) {
716    fn handle_subarray(
717        ctx: &mut EnvContext,
718        args: Vec<ParseNode>,
719        _opt_args: Vec<Option<ParseNode>>,
720    ) -> ParseResult<ParseNode> {
721        let colalign = match &args[0] {
722            ParseNode::OrdGroup { body, .. } => body.clone(),
723            other if other.is_symbol_node() => vec![other.clone()],
724            _ => return Err(ParseError::msg("Invalid column alignment for subarray")),
725        };
726
727        let mut cols = Vec::new();
728        for nde in &colalign {
729            let ca = nde
730                .symbol_text()
731                .ok_or_else(|| ParseError::msg("Expected column alignment character"))?;
732            match ca {
733                "l" | "c" => cols.push(AlignSpec {
734                    align_type: AlignType::Align,
735                    align: Some(ca.to_string()),
736                    pregap: None,
737                    postgap: None,
738                }),
739                _ => {
740                    return Err(ParseError::msg(format!(
741                        "Unknown column alignment: {}",
742                        ca
743                    )))
744                }
745            }
746        }
747
748        if cols.len() > 1 {
749            return Err(ParseError::msg("{subarray} can contain only one column"));
750        }
751
752        let config = ArrayConfig {
753            cols: Some(cols),
754            hskip_before_and_after: Some(false),
755            arraystretch: Some(0.5),
756            ..Default::default()
757        };
758
759        let res = parse_array(ctx.parser, config, Some(StyleStr::Script))?;
760
761        if let ParseNode::Array { ref body, .. } = res {
762            if !body.is_empty() && body[0].len() > 1 {
763                return Err(ParseError::msg("{subarray} can contain only one column"));
764            }
765        }
766
767        Ok(res)
768    }
769
770    map.insert(
771        "subarray",
772        EnvSpec {
773            num_args: 1,
774            num_optional_args: 0,
775            handler: handle_subarray,
776        },
777    );
778}