1use std::collections::HashMap;
2
3use crate::error::{ParseError, ParseResult};
4use crate::macro_expander::MacroDefinition;
5use crate::parse_node::{AlignSpec, AlignType, Measurement, Mode, ParseNode, StyleStr};
6use crate::parser::Parser;
7
8pub struct EnvContext<'a, 'b> {
11 pub mode: Mode,
12 pub env_name: String,
13 pub parser: &'a mut Parser<'b>,
14}
15
16pub type EnvHandler = fn(
17 ctx: &mut EnvContext,
18 args: Vec<ParseNode>,
19 opt_args: Vec<Option<ParseNode>>,
20) -> ParseResult<ParseNode>;
21
22pub struct EnvSpec {
23 pub num_args: usize,
24 pub num_optional_args: usize,
25 pub handler: EnvHandler,
26}
27
28pub static ENVIRONMENTS: std::sync::LazyLock<HashMap<&'static str, EnvSpec>> =
29 std::sync::LazyLock::new(|| {
30 let mut map = HashMap::new();
31 register_array(&mut map);
32 register_matrix(&mut map);
33 register_cases(&mut map);
34 register_align(&mut map);
35 register_gathered(&mut map);
36 register_equation(&mut map);
37 register_smallmatrix(&mut map);
38 register_alignat(&mut map);
39 register_subarray(&mut map);
40 map
41 });
42
43#[derive(Default)]
46pub struct ArrayConfig {
47 pub hskip_before_and_after: Option<bool>,
48 pub add_jot: Option<bool>,
49 pub cols: Option<Vec<AlignSpec>>,
50 pub arraystretch: Option<f64>,
51 pub col_separation_type: Option<String>,
52 pub single_row: bool,
53 pub empty_single_row: bool,
54 pub max_num_cols: Option<usize>,
55 pub leqno: Option<bool>,
56}
57
58
59fn get_hlines(parser: &mut Parser) -> ParseResult<Vec<bool>> {
62 let mut hline_info = Vec::new();
63 parser.consume_spaces()?;
64
65 let mut nxt = parser.fetch()?.text.clone();
66 if nxt == "\\relax" {
67 parser.consume();
68 parser.consume_spaces()?;
69 nxt = parser.fetch()?.text.clone();
70 }
71 while nxt == "\\hline" || nxt == "\\hdashline" {
72 parser.consume();
73 hline_info.push(nxt == "\\hdashline");
74 parser.consume_spaces()?;
75 nxt = parser.fetch()?.text.clone();
76 }
77 Ok(hline_info)
78}
79
80fn d_cell_style(env_name: &str) -> Option<StyleStr> {
81 if env_name.starts_with('d') {
82 Some(StyleStr::Display)
83 } else {
84 Some(StyleStr::Text)
85 }
86}
87
88pub fn parse_array(
89 parser: &mut Parser,
90 config: ArrayConfig,
91 style: Option<StyleStr>,
92) -> ParseResult<ParseNode> {
93 parser.gullet.begin_group();
94
95 if !config.single_row {
96 parser
97 .gullet
98 .set_text_macro("\\cr", "\\\\\\relax");
99 }
100
101 let arraystretch = config.arraystretch.unwrap_or_else(|| {
102 if let Some(def) = parser.gullet.get_macro("\\arraystretch") {
104 let s = match def {
105 MacroDefinition::Text(s) => s.clone(),
106 MacroDefinition::Tokens { tokens, .. } => {
107 tokens.iter().rev().map(|t| t.text.as_str()).collect::<String>()
109 }
110 MacroDefinition::Function(_) => String::new(),
111 };
112 s.parse::<f64>().unwrap_or(1.0)
113 } else {
114 1.0
115 }
116 });
117
118 parser.gullet.begin_group();
119
120 let mut row: Vec<ParseNode> = Vec::new();
121 let mut body: Vec<Vec<ParseNode>> = Vec::new();
122 let mut row_gaps: Vec<Option<Measurement>> = Vec::new();
123 let mut hlines_before_row: Vec<Vec<bool>> = Vec::new();
124
125 hlines_before_row.push(get_hlines(parser)?);
126
127 loop {
128 let break_token = if config.single_row { "\\end" } else { "\\\\" };
129 let cell_body = parser.parse_expression(false, Some(break_token))?;
130 parser.gullet.end_group();
131 parser.gullet.begin_group();
132
133 let mut cell = ParseNode::OrdGroup {
134 mode: parser.mode,
135 body: cell_body,
136 semisimple: None,
137 loc: None,
138 };
139
140 if let Some(s) = style {
141 cell = ParseNode::Styling {
142 mode: parser.mode,
143 style: s,
144 body: vec![cell],
145 loc: None,
146 };
147 }
148
149 row.push(cell.clone());
150 let next = parser.fetch()?.text.clone();
151
152 if next == "&" {
153 if let Some(max) = config.max_num_cols {
154 if row.len() >= max {
155 return Err(ParseError::msg("Too many tab characters: &"));
156 }
157 }
158 parser.consume();
159 } else if next == "\\end" {
160 let is_empty_trailing = if let Some(s) = style {
162 if s == StyleStr::Text || s == StyleStr::Display {
163 if let ParseNode::Styling { body: ref sb, .. } = cell {
164 if let Some(ParseNode::OrdGroup {
165 body: ref ob, ..
166 }) = sb.first()
167 {
168 ob.is_empty()
169 } else {
170 false
171 }
172 } else {
173 false
174 }
175 } else {
176 false
177 }
178 } else if let ParseNode::OrdGroup { body: ref ob, .. } = cell {
179 ob.is_empty()
180 } else {
181 false
182 };
183
184 body.push(row);
185
186 if is_empty_trailing
187 && (body.len() > 1 || !config.empty_single_row)
188 {
189 body.pop();
190 }
191
192 if hlines_before_row.len() < body.len() + 1 {
193 hlines_before_row.push(vec![]);
194 }
195 break;
196 } else if next == "\\\\" {
197 parser.consume();
198 let size = if parser.gullet.future().text != " " {
199 parser.parse_size_group(true)?
200 } else {
201 None
202 };
203 let gap = size.and_then(|s| {
204 if let ParseNode::Size { value, .. } = s {
205 Some(value)
206 } else {
207 None
208 }
209 });
210 row_gaps.push(gap);
211
212 body.push(row);
213 hlines_before_row.push(get_hlines(parser)?);
214 row = Vec::new();
215 } else {
216 return Err(ParseError::msg(format!(
217 "Expected & or \\\\ or \\cr or \\end, got '{}'",
218 next
219 )));
220 }
221 }
222
223 parser.gullet.end_group();
224 parser.gullet.end_group();
225
226 Ok(ParseNode::Array {
227 mode: parser.mode,
228 body,
229 row_gaps,
230 hlines_before_row,
231 cols: config.cols,
232 col_separation_type: config.col_separation_type,
233 hskip_before_and_after: config.hskip_before_and_after,
234 add_jot: config.add_jot,
235 arraystretch,
236 tags: None,
237 leqno: config.leqno,
238 is_cd: None,
239 loc: None,
240 })
241}
242
243fn register_array(map: &mut HashMap<&'static str, EnvSpec>) {
246 fn handle_array(
247 ctx: &mut EnvContext,
248 args: Vec<ParseNode>,
249 _opt_args: Vec<Option<ParseNode>>,
250 ) -> ParseResult<ParseNode> {
251 let colalign = match &args[0] {
252 ParseNode::OrdGroup { body, .. } => body.clone(),
253 other if other.is_symbol_node() => vec![other.clone()],
254 _ => return Err(ParseError::msg("Invalid column alignment for array")),
255 };
256
257 let mut cols = Vec::new();
258 for nde in &colalign {
259 let ca = nde
260 .symbol_text()
261 .ok_or_else(|| ParseError::msg("Expected column alignment character"))?;
262 match ca {
263 "l" | "c" | "r" => cols.push(AlignSpec {
264 align_type: AlignType::Align,
265 align: Some(ca.to_string()),
266 pregap: None,
267 postgap: None,
268 }),
269 "|" => cols.push(AlignSpec {
270 align_type: AlignType::Separator,
271 align: Some("|".to_string()),
272 pregap: None,
273 postgap: None,
274 }),
275 ":" => cols.push(AlignSpec {
276 align_type: AlignType::Separator,
277 align: Some(":".to_string()),
278 pregap: None,
279 postgap: None,
280 }),
281 _ => {
282 return Err(ParseError::msg(format!(
283 "Unknown column alignment: {}",
284 ca
285 )))
286 }
287 }
288 }
289
290 let max_num_cols = cols.len();
291 let config = ArrayConfig {
292 cols: Some(cols),
293 hskip_before_and_after: Some(true),
294 max_num_cols: Some(max_num_cols),
295 ..Default::default()
296 };
297 parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))
298 }
299
300 for name in &["array", "darray"] {
301 map.insert(
302 name,
303 EnvSpec {
304 num_args: 1,
305 num_optional_args: 0,
306 handler: handle_array,
307 },
308 );
309 }
310}
311
312fn register_matrix(map: &mut HashMap<&'static str, EnvSpec>) {
315 fn handle_matrix(
316 ctx: &mut EnvContext,
317 _args: Vec<ParseNode>,
318 _opt_args: Vec<Option<ParseNode>>,
319 ) -> ParseResult<ParseNode> {
320 let base_name = ctx.env_name.replace('*', "");
321 let delimiters: Option<(&str, &str)> = match base_name.as_str() {
322 "matrix" => None,
323 "pmatrix" => Some(("(", ")")),
324 "bmatrix" => Some(("[", "]")),
325 "Bmatrix" => Some(("\\{", "\\}")),
326 "vmatrix" => Some(("|", "|")),
327 "Vmatrix" => Some(("\\Vert", "\\Vert")),
328 _ => None,
329 };
330
331 let mut col_align = "c".to_string();
332
333 if ctx.env_name.ends_with('*') {
335 ctx.parser.gullet.consume_spaces();
336 if ctx.parser.gullet.future().text == "[" {
337 ctx.parser.gullet.pop_token();
338 ctx.parser.gullet.consume_spaces();
339 let align_tok = ctx.parser.gullet.pop_token();
340 if !"lcr".contains(align_tok.text.as_str()) {
341 return Err(ParseError::new(
342 "Expected l or c or r".to_string(),
343 Some(&align_tok),
344 ));
345 }
346 col_align = align_tok.text.clone();
347 ctx.parser.gullet.consume_spaces();
348 let close = ctx.parser.gullet.pop_token();
349 if close.text != "]" {
350 return Err(ParseError::new(
351 "Expected ]".to_string(),
352 Some(&close),
353 ));
354 }
355 }
356 }
357
358 let config = ArrayConfig {
359 hskip_before_and_after: Some(false),
360 cols: Some(vec![AlignSpec {
361 align_type: AlignType::Align,
362 align: Some(col_align.clone()),
363 pregap: None,
364 postgap: None,
365 }]),
366 ..Default::default()
367 };
368
369 let mut res = parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))?;
370
371 if let ParseNode::Array {
373 ref body,
374 ref mut cols,
375 ..
376 } = res
377 {
378 let num_cols = body.iter().map(|r| r.len()).max().unwrap_or(0);
379 *cols = Some(
380 (0..num_cols)
381 .map(|_| AlignSpec {
382 align_type: AlignType::Align,
383 align: Some(col_align.to_string()),
384 pregap: None,
385 postgap: None,
386 })
387 .collect(),
388 );
389 }
390
391 match delimiters {
392 Some((left, right)) => Ok(ParseNode::LeftRight {
393 mode: ctx.mode,
394 body: vec![res],
395 left: left.to_string(),
396 right: right.to_string(),
397 right_color: None,
398 loc: None,
399 }),
400 None => Ok(res),
401 }
402 }
403
404 for name in &[
405 "matrix", "pmatrix", "bmatrix", "Bmatrix", "vmatrix", "Vmatrix",
406 "matrix*", "pmatrix*", "bmatrix*", "Bmatrix*", "vmatrix*", "Vmatrix*",
407 ] {
408 map.insert(
409 name,
410 EnvSpec {
411 num_args: 0,
412 num_optional_args: 0,
413 handler: handle_matrix,
414 },
415 );
416 }
417}
418
419fn register_cases(map: &mut HashMap<&'static str, EnvSpec>) {
422 fn handle_cases(
423 ctx: &mut EnvContext,
424 _args: Vec<ParseNode>,
425 _opt_args: Vec<Option<ParseNode>>,
426 ) -> ParseResult<ParseNode> {
427 let config = ArrayConfig {
428 arraystretch: Some(1.2),
429 cols: Some(vec![
430 AlignSpec {
431 align_type: AlignType::Align,
432 align: Some("l".to_string()),
433 pregap: Some(0.0),
434 postgap: Some(1.0),
435 },
436 AlignSpec {
437 align_type: AlignType::Align,
438 align: Some("l".to_string()),
439 pregap: Some(0.0),
440 postgap: Some(0.0),
441 },
442 ]),
443 ..Default::default()
444 };
445
446 let res = parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))?;
447
448 let (left, right) = if ctx.env_name.contains('r') {
449 (".", "\\}")
450 } else {
451 ("\\{", ".")
452 };
453
454 Ok(ParseNode::LeftRight {
455 mode: ctx.mode,
456 body: vec![res],
457 left: left.to_string(),
458 right: right.to_string(),
459 right_color: None,
460 loc: None,
461 })
462 }
463
464 for name in &["cases", "dcases", "rcases", "drcases"] {
465 map.insert(
466 name,
467 EnvSpec {
468 num_args: 0,
469 num_optional_args: 0,
470 handler: handle_cases,
471 },
472 );
473 }
474}
475
476fn handle_aligned(
479 ctx: &mut EnvContext,
480 args: Vec<ParseNode>,
481 _opt_args: Vec<Option<ParseNode>>,
482) -> ParseResult<ParseNode> {
483 let is_split = ctx.env_name == "split";
484 let is_alignat = ctx.env_name.contains("at");
485 let sep_type = if is_alignat { "alignat" } else { "align" };
486
487 let config = ArrayConfig {
488 add_jot: Some(true),
489 empty_single_row: true,
490 col_separation_type: Some(sep_type.to_string()),
491 max_num_cols: if is_split { Some(2) } else { None },
492 ..Default::default()
493 };
494
495 let mut res = parse_array(ctx.parser, config, Some(StyleStr::Display))?;
496
497 let mut num_maths = 0usize;
499 let mut explicit_cols = 0usize;
500 if let Some(ParseNode::OrdGroup { body, .. }) = args.first() {
501 let mut arg_str = String::new();
502 for node in body {
503 if let Some(t) = node.symbol_text() {
504 arg_str.push_str(t);
505 }
506 }
507 if let Ok(n) = arg_str.parse::<usize>() {
508 num_maths = n;
509 explicit_cols = n * 2;
510 }
511 }
512 let is_aligned = explicit_cols == 0;
513
514 let mut num_cols = if let ParseNode::Array { ref body, .. } = res {
516 body.iter().map(|r| r.len()).max().unwrap_or(0)
517 } else {
518 0
519 };
520
521 if let ParseNode::Array {
522 body: ref mut array_body,
523 ..
524 } = res
525 {
526 for row in array_body.iter_mut() {
527 let mut i = 1;
529 while i < row.len() {
530 if let ParseNode::Styling {
531 body: ref mut styling_body,
532 ..
533 } = row[i]
534 {
535 if let Some(ParseNode::OrdGroup {
536 body: ref mut og_body,
537 ..
538 }) = styling_body.first_mut()
539 {
540 og_body.insert(
541 0,
542 ParseNode::OrdGroup {
543 mode: ctx.mode,
544 body: vec![],
545 semisimple: None,
546 loc: None,
547 },
548 );
549 }
550 }
551 i += 2;
552 }
553
554 if !is_aligned {
555 let cur_maths = row.len() / 2;
556 if num_maths < cur_maths {
557 return Err(ParseError::msg(format!(
558 "Too many math in a row: expected {}, but got {}",
559 num_maths, cur_maths
560 )));
561 }
562 } else if num_cols < row.len() {
563 num_cols = row.len();
564 }
565 }
566 }
567
568 if !is_aligned {
569 num_cols = explicit_cols;
570 }
571
572 let mut cols = Vec::new();
573 for i in 0..num_cols {
574 let (align, pregap) = if i % 2 == 1 {
575 ("l", 0.0)
576 } else if i > 0 && is_aligned {
577 ("r", 1.0)
578 } else {
579 ("r", 0.0)
580 };
581 cols.push(AlignSpec {
582 align_type: AlignType::Align,
583 align: Some(align.to_string()),
584 pregap: Some(pregap),
585 postgap: Some(0.0),
586 });
587 }
588
589 if let ParseNode::Array {
590 cols: ref mut array_cols,
591 col_separation_type: ref mut array_sep_type,
592 ..
593 } = res
594 {
595 *array_cols = Some(cols);
596 *array_sep_type = Some(
597 if is_aligned { "align" } else { "alignat" }.to_string(),
598 );
599 }
600
601 Ok(res)
602}
603
604fn register_align(map: &mut HashMap<&'static str, EnvSpec>) {
605 for name in &["align", "align*", "aligned", "split"] {
606 map.insert(
607 name,
608 EnvSpec {
609 num_args: 0,
610 num_optional_args: 0,
611 handler: handle_aligned,
612 },
613 );
614 }
615}
616
617fn register_gathered(map: &mut HashMap<&'static str, EnvSpec>) {
620 fn handle_gathered(
621 ctx: &mut EnvContext,
622 _args: Vec<ParseNode>,
623 _opt_args: Vec<Option<ParseNode>>,
624 ) -> ParseResult<ParseNode> {
625 let config = ArrayConfig {
626 cols: Some(vec![AlignSpec {
627 align_type: AlignType::Align,
628 align: Some("c".to_string()),
629 pregap: None,
630 postgap: None,
631 }]),
632 add_jot: Some(true),
633 col_separation_type: Some("gather".to_string()),
634 empty_single_row: true,
635 ..Default::default()
636 };
637 parse_array(ctx.parser, config, Some(StyleStr::Display))
638 }
639
640 for name in &["gathered", "gather", "gather*"] {
641 map.insert(
642 name,
643 EnvSpec {
644 num_args: 0,
645 num_optional_args: 0,
646 handler: handle_gathered,
647 },
648 );
649 }
650}
651
652fn register_equation(map: &mut HashMap<&'static str, EnvSpec>) {
655 fn handle_equation(
656 ctx: &mut EnvContext,
657 _args: Vec<ParseNode>,
658 _opt_args: Vec<Option<ParseNode>>,
659 ) -> ParseResult<ParseNode> {
660 let config = ArrayConfig {
661 empty_single_row: true,
662 single_row: true,
663 max_num_cols: Some(1),
664 ..Default::default()
665 };
666 parse_array(ctx.parser, config, Some(StyleStr::Display))
667 }
668
669 for name in &["equation", "equation*"] {
670 map.insert(
671 name,
672 EnvSpec {
673 num_args: 0,
674 num_optional_args: 0,
675 handler: handle_equation,
676 },
677 );
678 }
679}
680
681fn register_smallmatrix(map: &mut HashMap<&'static str, EnvSpec>) {
684 fn handle_smallmatrix(
685 ctx: &mut EnvContext,
686 _args: Vec<ParseNode>,
687 _opt_args: Vec<Option<ParseNode>>,
688 ) -> ParseResult<ParseNode> {
689 let config = ArrayConfig {
690 arraystretch: Some(0.5),
691 ..Default::default()
692 };
693 let mut res = parse_array(ctx.parser, config, Some(StyleStr::Script))?;
694 if let ParseNode::Array {
695 ref mut col_separation_type,
696 ..
697 } = res
698 {
699 *col_separation_type = Some("small".to_string());
700 }
701 Ok(res)
702 }
703
704 map.insert(
705 "smallmatrix",
706 EnvSpec {
707 num_args: 0,
708 num_optional_args: 0,
709 handler: handle_smallmatrix,
710 },
711 );
712}
713
714fn register_alignat(map: &mut HashMap<&'static str, EnvSpec>) {
717 for name in &["alignat", "alignat*", "alignedat"] {
718 map.insert(
719 name,
720 EnvSpec {
721 num_args: 1,
722 num_optional_args: 0,
723 handler: handle_aligned,
724 },
725 );
726 }
727}
728
729fn register_subarray(map: &mut HashMap<&'static str, EnvSpec>) {
732 fn handle_subarray(
733 ctx: &mut EnvContext,
734 args: Vec<ParseNode>,
735 _opt_args: Vec<Option<ParseNode>>,
736 ) -> ParseResult<ParseNode> {
737 let colalign = match &args[0] {
738 ParseNode::OrdGroup { body, .. } => body.clone(),
739 other if other.is_symbol_node() => vec![other.clone()],
740 _ => return Err(ParseError::msg("Invalid column alignment for subarray")),
741 };
742
743 let mut cols = Vec::new();
744 for nde in &colalign {
745 let ca = nde
746 .symbol_text()
747 .ok_or_else(|| ParseError::msg("Expected column alignment character"))?;
748 match ca {
749 "l" | "c" => cols.push(AlignSpec {
750 align_type: AlignType::Align,
751 align: Some(ca.to_string()),
752 pregap: None,
753 postgap: None,
754 }),
755 _ => {
756 return Err(ParseError::msg(format!(
757 "Unknown column alignment: {}",
758 ca
759 )))
760 }
761 }
762 }
763
764 if cols.len() > 1 {
765 return Err(ParseError::msg("{subarray} can contain only one column"));
766 }
767
768 let config = ArrayConfig {
769 cols: Some(cols),
770 hskip_before_and_after: Some(false),
771 arraystretch: Some(0.5),
772 ..Default::default()
773 };
774
775 let res = parse_array(ctx.parser, config, Some(StyleStr::Script))?;
776
777 if let ParseNode::Array { ref body, .. } = res {
778 if !body.is_empty() && body[0].len() > 1 {
779 return Err(ParseError::msg("{subarray} can contain only one column"));
780 }
781 }
782
783 Ok(res)
784 }
785
786 map.insert(
787 "subarray",
788 EnvSpec {
789 num_args: 1,
790 num_optional_args: 0,
791 handler: handle_subarray,
792 },
793 );
794}