//! Environment registry (`\begin{...}` handlers) and array-body parsing.
//!
//! Source path: `ratex_parser/environments.rs`.

1use std::collections::HashMap;
2
3use crate::error::{ParseError, ParseResult};
4use crate::macro_expander::MacroDefinition;
5use crate::parse_node::{AlignSpec, AlignType, Measurement, Mode, ParseNode, StyleStr};
6use crate::parser::Parser;
7
8// ── Environment registry ─────────────────────────────────────────────────
9
10pub struct EnvContext<'a, 'b> {
11    pub mode: Mode,
12    pub env_name: String,
13    pub parser: &'a mut Parser<'b>,
14}
15
16pub type EnvHandler = fn(
17    ctx: &mut EnvContext,
18    args: Vec<ParseNode>,
19    opt_args: Vec<Option<ParseNode>>,
20) -> ParseResult<ParseNode>;
21
22pub struct EnvSpec {
23    pub num_args: usize,
24    pub num_optional_args: usize,
25    pub handler: EnvHandler,
26}
27
28pub static ENVIRONMENTS: std::sync::LazyLock<HashMap<&'static str, EnvSpec>> =
29    std::sync::LazyLock::new(|| {
30        let mut map = HashMap::new();
31        register_array(&mut map);
32        register_matrix(&mut map);
33        register_cases(&mut map);
34        register_align(&mut map);
35        register_gathered(&mut map);
36        register_equation(&mut map);
37        register_smallmatrix(&mut map);
38        register_alignat(&mut map);
39        register_subarray(&mut map);
40        map
41    });
42
43// ── ArrayConfig ──────────────────────────────────────────────────────────
44
45#[derive(Default)]
46pub struct ArrayConfig {
47    pub hskip_before_and_after: Option<bool>,
48    pub add_jot: Option<bool>,
49    pub cols: Option<Vec<AlignSpec>>,
50    pub arraystretch: Option<f64>,
51    pub col_separation_type: Option<String>,
52    pub single_row: bool,
53    pub empty_single_row: bool,
54    pub max_num_cols: Option<usize>,
55    pub leqno: Option<bool>,
56}
57
58
59// ── parseArray ───────────────────────────────────────────────────────────
60
61fn get_hlines(parser: &mut Parser) -> ParseResult<Vec<bool>> {
62    let mut hline_info = Vec::new();
63    parser.consume_spaces()?;
64
65    let mut nxt = parser.fetch()?.text.clone();
66    if nxt == "\\relax" {
67        parser.consume();
68        parser.consume_spaces()?;
69        nxt = parser.fetch()?.text.clone();
70    }
71    while nxt == "\\hline" || nxt == "\\hdashline" {
72        parser.consume();
73        hline_info.push(nxt == "\\hdashline");
74        parser.consume_spaces()?;
75        nxt = parser.fetch()?.text.clone();
76    }
77    Ok(hline_info)
78}
79
80fn d_cell_style(env_name: &str) -> Option<StyleStr> {
81    if env_name.starts_with('d') {
82        Some(StyleStr::Display)
83    } else {
84        Some(StyleStr::Text)
85    }
86}
87
88pub fn parse_array(
89    parser: &mut Parser,
90    config: ArrayConfig,
91    style: Option<StyleStr>,
92) -> ParseResult<ParseNode> {
93    parser.gullet.begin_group();
94
95    if !config.single_row {
96        parser
97            .gullet
98            .set_text_macro("\\cr", "\\\\\\relax");
99    }
100
101    let arraystretch = config.arraystretch.unwrap_or_else(|| {
102        // Check if \arraystretch is defined as a macro (e.g., via \def\arraystretch{1.5})
103        if let Some(def) = parser.gullet.get_macro("\\arraystretch") {
104            let s = match def {
105                MacroDefinition::Text(s) => s.clone(),
106                MacroDefinition::Tokens { tokens, .. } => {
107                    // Tokens are stored in reverse order (stack convention for expansion)
108                    tokens.iter().rev().map(|t| t.text.as_str()).collect::<String>()
109                }
110                MacroDefinition::Function(_) => String::new(),
111            };
112            s.parse::<f64>().unwrap_or(1.0)
113        } else {
114            1.0
115        }
116    });
117
118    parser.gullet.begin_group();
119
120    let mut row: Vec<ParseNode> = Vec::new();
121    let mut body: Vec<Vec<ParseNode>> = Vec::new();
122    let mut row_gaps: Vec<Option<Measurement>> = Vec::new();
123    let mut hlines_before_row: Vec<Vec<bool>> = Vec::new();
124
125    hlines_before_row.push(get_hlines(parser)?);
126
127    loop {
128        let break_token = if config.single_row { "\\end" } else { "\\\\" };
129        let cell_body = parser.parse_expression(false, Some(break_token))?;
130        parser.gullet.end_group();
131        parser.gullet.begin_group();
132
133        let mut cell = ParseNode::OrdGroup {
134            mode: parser.mode,
135            body: cell_body,
136            semisimple: None,
137            loc: None,
138        };
139
140        if let Some(s) = style {
141            cell = ParseNode::Styling {
142                mode: parser.mode,
143                style: s,
144                body: vec![cell],
145                loc: None,
146            };
147        }
148
149        row.push(cell.clone());
150        let next = parser.fetch()?.text.clone();
151
152        if next == "&" {
153            if let Some(max) = config.max_num_cols {
154                if row.len() >= max {
155                    return Err(ParseError::msg("Too many tab characters: &"));
156                }
157            }
158            parser.consume();
159        } else if next == "\\end" {
160            // Check for trailing empty row and remove it
161            let is_empty_trailing = if let Some(s) = style {
162                if s == StyleStr::Text || s == StyleStr::Display {
163                    if let ParseNode::Styling { body: ref sb, .. } = cell {
164                        if let Some(ParseNode::OrdGroup {
165                            body: ref ob, ..
166                        }) = sb.first()
167                        {
168                            ob.is_empty()
169                        } else {
170                            false
171                        }
172                    } else {
173                        false
174                    }
175                } else {
176                    false
177                }
178            } else if let ParseNode::OrdGroup { body: ref ob, .. } = cell {
179                ob.is_empty()
180            } else {
181                false
182            };
183
184            body.push(row);
185
186            if is_empty_trailing
187                && (body.len() > 1 || !config.empty_single_row)
188            {
189                body.pop();
190            }
191
192            if hlines_before_row.len() < body.len() + 1 {
193                hlines_before_row.push(vec![]);
194            }
195            break;
196        } else if next == "\\\\" {
197            parser.consume();
198            let size = if parser.gullet.future().text != " " {
199                parser.parse_size_group(true)?
200            } else {
201                None
202            };
203            let gap = size.and_then(|s| {
204                if let ParseNode::Size { value, .. } = s {
205                    Some(value)
206                } else {
207                    None
208                }
209            });
210            row_gaps.push(gap);
211
212            body.push(row);
213            hlines_before_row.push(get_hlines(parser)?);
214            row = Vec::new();
215        } else {
216            return Err(ParseError::msg(format!(
217                "Expected & or \\\\ or \\cr or \\end, got '{}'",
218                next
219            )));
220        }
221    }
222
223    parser.gullet.end_group();
224    parser.gullet.end_group();
225
226    Ok(ParseNode::Array {
227        mode: parser.mode,
228        body,
229        row_gaps,
230        hlines_before_row,
231        cols: config.cols,
232        col_separation_type: config.col_separation_type,
233        hskip_before_and_after: config.hskip_before_and_after,
234        add_jot: config.add_jot,
235        arraystretch,
236        tags: None,
237        leqno: config.leqno,
238        is_cd: None,
239        loc: None,
240    })
241}
242
243// ── array / darray ───────────────────────────────────────────────────────
244
245fn register_array(map: &mut HashMap<&'static str, EnvSpec>) {
246    fn handle_array(
247        ctx: &mut EnvContext,
248        args: Vec<ParseNode>,
249        _opt_args: Vec<Option<ParseNode>>,
250    ) -> ParseResult<ParseNode> {
251        let colalign = match &args[0] {
252            ParseNode::OrdGroup { body, .. } => body.clone(),
253            other if other.is_symbol_node() => vec![other.clone()],
254            _ => return Err(ParseError::msg("Invalid column alignment for array")),
255        };
256
257        let mut cols = Vec::new();
258        for nde in &colalign {
259            let ca = nde
260                .symbol_text()
261                .ok_or_else(|| ParseError::msg("Expected column alignment character"))?;
262            match ca {
263                "l" | "c" | "r" => cols.push(AlignSpec {
264                    align_type: AlignType::Align,
265                    align: Some(ca.to_string()),
266                    pregap: None,
267                    postgap: None,
268                }),
269                "|" => cols.push(AlignSpec {
270                    align_type: AlignType::Separator,
271                    align: Some("|".to_string()),
272                    pregap: None,
273                    postgap: None,
274                }),
275                ":" => cols.push(AlignSpec {
276                    align_type: AlignType::Separator,
277                    align: Some(":".to_string()),
278                    pregap: None,
279                    postgap: None,
280                }),
281                _ => {
282                    return Err(ParseError::msg(format!(
283                        "Unknown column alignment: {}",
284                        ca
285                    )))
286                }
287            }
288        }
289
290        let max_num_cols = cols.len();
291        let config = ArrayConfig {
292            cols: Some(cols),
293            hskip_before_and_after: Some(true),
294            max_num_cols: Some(max_num_cols),
295            ..Default::default()
296        };
297        parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))
298    }
299
300    for name in &["array", "darray"] {
301        map.insert(
302            name,
303            EnvSpec {
304                num_args: 1,
305                num_optional_args: 0,
306                handler: handle_array,
307            },
308        );
309    }
310}
311
312// ── matrix variants ──────────────────────────────────────────────────────
313
314fn register_matrix(map: &mut HashMap<&'static str, EnvSpec>) {
315    fn handle_matrix(
316        ctx: &mut EnvContext,
317        _args: Vec<ParseNode>,
318        _opt_args: Vec<Option<ParseNode>>,
319    ) -> ParseResult<ParseNode> {
320        let base_name = ctx.env_name.replace('*', "");
321        let delimiters: Option<(&str, &str)> = match base_name.as_str() {
322            "matrix" => None,
323            "pmatrix" => Some(("(", ")")),
324            "bmatrix" => Some(("[", "]")),
325            "Bmatrix" => Some(("\\{", "\\}")),
326            "vmatrix" => Some(("|", "|")),
327            "Vmatrix" => Some(("\\Vert", "\\Vert")),
328            _ => None,
329        };
330
331        let mut col_align = "c".to_string();
332
333        // mathtools starred matrix: parse optional [l|c|r] alignment
334        if ctx.env_name.ends_with('*') {
335            ctx.parser.gullet.consume_spaces();
336            if ctx.parser.gullet.future().text == "[" {
337                ctx.parser.gullet.pop_token();
338                ctx.parser.gullet.consume_spaces();
339                let align_tok = ctx.parser.gullet.pop_token();
340                if !"lcr".contains(align_tok.text.as_str()) {
341                    return Err(ParseError::new(
342                        "Expected l or c or r".to_string(),
343                        Some(&align_tok),
344                    ));
345                }
346                col_align = align_tok.text.clone();
347                ctx.parser.gullet.consume_spaces();
348                let close = ctx.parser.gullet.pop_token();
349                if close.text != "]" {
350                    return Err(ParseError::new(
351                        "Expected ]".to_string(),
352                        Some(&close),
353                    ));
354                }
355            }
356        }
357
358        let config = ArrayConfig {
359            hskip_before_and_after: Some(false),
360            cols: Some(vec![AlignSpec {
361                align_type: AlignType::Align,
362                align: Some(col_align.clone()),
363                pregap: None,
364                postgap: None,
365            }]),
366            ..Default::default()
367        };
368
369        let mut res = parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))?;
370
371        // Fix cols to match actual number of columns
372        if let ParseNode::Array {
373            ref body,
374            ref mut cols,
375            ..
376        } = res
377        {
378            let num_cols = body.iter().map(|r| r.len()).max().unwrap_or(0);
379            *cols = Some(
380                (0..num_cols)
381                    .map(|_| AlignSpec {
382                        align_type: AlignType::Align,
383                        align: Some(col_align.to_string()),
384                        pregap: None,
385                        postgap: None,
386                    })
387                    .collect(),
388            );
389        }
390
391        match delimiters {
392            Some((left, right)) => Ok(ParseNode::LeftRight {
393                mode: ctx.mode,
394                body: vec![res],
395                left: left.to_string(),
396                right: right.to_string(),
397                right_color: None,
398                loc: None,
399            }),
400            None => Ok(res),
401        }
402    }
403
404    for name in &[
405        "matrix", "pmatrix", "bmatrix", "Bmatrix", "vmatrix", "Vmatrix",
406        "matrix*", "pmatrix*", "bmatrix*", "Bmatrix*", "vmatrix*", "Vmatrix*",
407    ] {
408        map.insert(
409            name,
410            EnvSpec {
411                num_args: 0,
412                num_optional_args: 0,
413                handler: handle_matrix,
414            },
415        );
416    }
417}
418
419// ── cases / dcases / rcases / drcases ────────────────────────────────────
420
421fn register_cases(map: &mut HashMap<&'static str, EnvSpec>) {
422    fn handle_cases(
423        ctx: &mut EnvContext,
424        _args: Vec<ParseNode>,
425        _opt_args: Vec<Option<ParseNode>>,
426    ) -> ParseResult<ParseNode> {
427        let config = ArrayConfig {
428            arraystretch: Some(1.2),
429            cols: Some(vec![
430                AlignSpec {
431                    align_type: AlignType::Align,
432                    align: Some("l".to_string()),
433                    pregap: Some(0.0),
434                    postgap: Some(1.0),
435                },
436                AlignSpec {
437                    align_type: AlignType::Align,
438                    align: Some("l".to_string()),
439                    pregap: Some(0.0),
440                    postgap: Some(0.0),
441                },
442            ]),
443            ..Default::default()
444        };
445
446        let res = parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))?;
447
448        let (left, right) = if ctx.env_name.contains('r') {
449            (".", "\\}")
450        } else {
451            ("\\{", ".")
452        };
453
454        Ok(ParseNode::LeftRight {
455            mode: ctx.mode,
456            body: vec![res],
457            left: left.to_string(),
458            right: right.to_string(),
459            right_color: None,
460            loc: None,
461        })
462    }
463
464    for name in &["cases", "dcases", "rcases", "drcases"] {
465        map.insert(
466            name,
467            EnvSpec {
468                num_args: 0,
469                num_optional_args: 0,
470                handler: handle_cases,
471            },
472        );
473    }
474}
475
476// ── align / align* / aligned / split / alignat / alignat* / alignedat ────
477
478fn handle_aligned(
479    ctx: &mut EnvContext,
480    args: Vec<ParseNode>,
481    _opt_args: Vec<Option<ParseNode>>,
482) -> ParseResult<ParseNode> {
483        let is_split = ctx.env_name == "split";
484        let is_alignat = ctx.env_name.contains("at");
485        let sep_type = if is_alignat { "alignat" } else { "align" };
486
487        let config = ArrayConfig {
488            add_jot: Some(true),
489            empty_single_row: true,
490            col_separation_type: Some(sep_type.to_string()),
491            max_num_cols: if is_split { Some(2) } else { None },
492            ..Default::default()
493        };
494
495        let mut res = parse_array(ctx.parser, config, Some(StyleStr::Display))?;
496
497        // Extract explicit column count from first arg (alignat only)
498        let mut num_maths = 0usize;
499        let mut explicit_cols = 0usize;
500        if let Some(ParseNode::OrdGroup { body, .. }) = args.first() {
501            let mut arg_str = String::new();
502            for node in body {
503                if let Some(t) = node.symbol_text() {
504                    arg_str.push_str(t);
505                }
506            }
507            if let Ok(n) = arg_str.parse::<usize>() {
508                num_maths = n;
509                explicit_cols = n * 2;
510            }
511        }
512        let is_aligned = explicit_cols == 0;
513
514        // Determine actual number of columns
515        let mut num_cols = if let ParseNode::Array { ref body, .. } = res {
516            body.iter().map(|r| r.len()).max().unwrap_or(0)
517        } else {
518            0
519        };
520
521        if let ParseNode::Array {
522            body: ref mut array_body,
523            ..
524        } = res
525        {
526            for row in array_body.iter_mut() {
527                // Prepend empty group at every even-indexed cell (2nd, 4th, ...)
528                let mut i = 1;
529                while i < row.len() {
530                    if let ParseNode::Styling {
531                        body: ref mut styling_body,
532                        ..
533                    } = row[i]
534                    {
535                        if let Some(ParseNode::OrdGroup {
536                            body: ref mut og_body,
537                            ..
538                        }) = styling_body.first_mut()
539                        {
540                            og_body.insert(
541                                0,
542                                ParseNode::OrdGroup {
543                                    mode: ctx.mode,
544                                    body: vec![],
545                                    semisimple: None,
546                                    loc: None,
547                                },
548                            );
549                        }
550                    }
551                    i += 2;
552                }
553
554                if !is_aligned {
555                    let cur_maths = row.len() / 2;
556                    if num_maths < cur_maths {
557                        return Err(ParseError::msg(format!(
558                            "Too many math in a row: expected {}, but got {}",
559                            num_maths, cur_maths
560                        )));
561                    }
562                } else if num_cols < row.len() {
563                    num_cols = row.len();
564                }
565            }
566        }
567
568        if !is_aligned {
569            num_cols = explicit_cols;
570        }
571
572        let mut cols = Vec::new();
573        for i in 0..num_cols {
574            let (align, pregap) = if i % 2 == 1 {
575                ("l", 0.0)
576            } else if i > 0 && is_aligned {
577                ("r", 1.0)
578            } else {
579                ("r", 0.0)
580            };
581            cols.push(AlignSpec {
582                align_type: AlignType::Align,
583                align: Some(align.to_string()),
584                pregap: Some(pregap),
585                postgap: Some(0.0),
586            });
587        }
588
589        if let ParseNode::Array {
590            cols: ref mut array_cols,
591            col_separation_type: ref mut array_sep_type,
592            ..
593        } = res
594        {
595            *array_cols = Some(cols);
596            *array_sep_type = Some(
597                if is_aligned { "align" } else { "alignat" }.to_string(),
598            );
599        }
600
601    Ok(res)
602}
603
604fn register_align(map: &mut HashMap<&'static str, EnvSpec>) {
605    for name in &["align", "align*", "aligned", "split"] {
606        map.insert(
607            name,
608            EnvSpec {
609                num_args: 0,
610                num_optional_args: 0,
611                handler: handle_aligned,
612            },
613        );
614    }
615}
616
617// ── gathered / gather / gather* ──────────────────────────────────────────
618
619fn register_gathered(map: &mut HashMap<&'static str, EnvSpec>) {
620    fn handle_gathered(
621        ctx: &mut EnvContext,
622        _args: Vec<ParseNode>,
623        _opt_args: Vec<Option<ParseNode>>,
624    ) -> ParseResult<ParseNode> {
625        let config = ArrayConfig {
626            cols: Some(vec![AlignSpec {
627                align_type: AlignType::Align,
628                align: Some("c".to_string()),
629                pregap: None,
630                postgap: None,
631            }]),
632            add_jot: Some(true),
633            col_separation_type: Some("gather".to_string()),
634            empty_single_row: true,
635            ..Default::default()
636        };
637        parse_array(ctx.parser, config, Some(StyleStr::Display))
638    }
639
640    for name in &["gathered", "gather", "gather*"] {
641        map.insert(
642            name,
643            EnvSpec {
644                num_args: 0,
645                num_optional_args: 0,
646                handler: handle_gathered,
647            },
648        );
649    }
650}
651
652// ── equation / equation* ─────────────────────────────────────────────────
653
654fn register_equation(map: &mut HashMap<&'static str, EnvSpec>) {
655    fn handle_equation(
656        ctx: &mut EnvContext,
657        _args: Vec<ParseNode>,
658        _opt_args: Vec<Option<ParseNode>>,
659    ) -> ParseResult<ParseNode> {
660        let config = ArrayConfig {
661            empty_single_row: true,
662            single_row: true,
663            max_num_cols: Some(1),
664            ..Default::default()
665        };
666        parse_array(ctx.parser, config, Some(StyleStr::Display))
667    }
668
669    for name in &["equation", "equation*"] {
670        map.insert(
671            name,
672            EnvSpec {
673                num_args: 0,
674                num_optional_args: 0,
675                handler: handle_equation,
676            },
677        );
678    }
679}
680
681// ── smallmatrix ──────────────────────────────────────────────────────────
682
683fn register_smallmatrix(map: &mut HashMap<&'static str, EnvSpec>) {
684    fn handle_smallmatrix(
685        ctx: &mut EnvContext,
686        _args: Vec<ParseNode>,
687        _opt_args: Vec<Option<ParseNode>>,
688    ) -> ParseResult<ParseNode> {
689        let config = ArrayConfig {
690            arraystretch: Some(0.5),
691            ..Default::default()
692        };
693        let mut res = parse_array(ctx.parser, config, Some(StyleStr::Script))?;
694        if let ParseNode::Array {
695            ref mut col_separation_type,
696            ..
697        } = res
698        {
699            *col_separation_type = Some("small".to_string());
700        }
701        Ok(res)
702    }
703
704    map.insert(
705        "smallmatrix",
706        EnvSpec {
707            num_args: 0,
708            num_optional_args: 0,
709            handler: handle_smallmatrix,
710        },
711    );
712}
713
714// ── alignat / alignat* / alignedat ──────────────────────────────────────
715
716fn register_alignat(map: &mut HashMap<&'static str, EnvSpec>) {
717    for name in &["alignat", "alignat*", "alignedat"] {
718        map.insert(
719            name,
720            EnvSpec {
721                num_args: 1,
722                num_optional_args: 0,
723                handler: handle_aligned,
724            },
725        );
726    }
727}
728
729// ── subarray ────────────────────────────────────────────────────────────
730
731fn register_subarray(map: &mut HashMap<&'static str, EnvSpec>) {
732    fn handle_subarray(
733        ctx: &mut EnvContext,
734        args: Vec<ParseNode>,
735        _opt_args: Vec<Option<ParseNode>>,
736    ) -> ParseResult<ParseNode> {
737        let colalign = match &args[0] {
738            ParseNode::OrdGroup { body, .. } => body.clone(),
739            other if other.is_symbol_node() => vec![other.clone()],
740            _ => return Err(ParseError::msg("Invalid column alignment for subarray")),
741        };
742
743        let mut cols = Vec::new();
744        for nde in &colalign {
745            let ca = nde
746                .symbol_text()
747                .ok_or_else(|| ParseError::msg("Expected column alignment character"))?;
748            match ca {
749                "l" | "c" => cols.push(AlignSpec {
750                    align_type: AlignType::Align,
751                    align: Some(ca.to_string()),
752                    pregap: None,
753                    postgap: None,
754                }),
755                _ => {
756                    return Err(ParseError::msg(format!(
757                        "Unknown column alignment: {}",
758                        ca
759                    )))
760                }
761            }
762        }
763
764        if cols.len() > 1 {
765            return Err(ParseError::msg("{subarray} can contain only one column"));
766        }
767
768        let config = ArrayConfig {
769            cols: Some(cols),
770            hskip_before_and_after: Some(false),
771            arraystretch: Some(0.5),
772            ..Default::default()
773        };
774
775        let res = parse_array(ctx.parser, config, Some(StyleStr::Script))?;
776
777        if let ParseNode::Array { ref body, .. } = res {
778            if !body.is_empty() && body[0].len() > 1 {
779                return Err(ParseError::msg("{subarray} can contain only one column"));
780            }
781        }
782
783        Ok(res)
784    }
785
786    map.insert(
787        "subarray",
788        EnvSpec {
789            num_args: 1,
790            num_optional_args: 0,
791            handler: handle_subarray,
792        },
793    );
794}