1use std::collections::HashMap;
2
3use crate::error::{ParseError, ParseResult};
4use crate::parse_node::{AlignSpec, AlignType, Measurement, Mode, ParseNode, StyleStr};
5use crate::parser::Parser;
6
/// State handed to an environment handler while it processes one environment.
pub struct EnvContext<'a, 'b> {
    /// Mode associated with the environment; handlers in this file copy it
    /// into the nodes they build around the parsed array.
    pub mode: Mode,
    /// Name of the environment being handled. Handlers inspect it (leading
    /// `d`, trailing `*`, ...) to select per-variant behaviour.
    pub env_name: String,
    /// Parser the handler drives to read the environment body.
    pub parser: &'a mut Parser<'b>,
}
14
/// Signature shared by every environment handler stored in an [`EnvSpec`].
///
/// A handler receives the environment context together with two argument
/// lists (`args` and `opt_args`) and returns the node built for the
/// environment, or a parse error.
pub type EnvHandler = fn(
    ctx: &mut EnvContext,
    args: Vec<ParseNode>,
    opt_args: Vec<Option<ParseNode>>,
) -> ParseResult<ParseNode>;
20
/// Registration record for one environment name.
pub struct EnvSpec {
    /// Declared argument count. Handlers registered with `1` in this file
    /// read `args[0]`; presumably the caller parses this many arguments
    /// before invoking the handler — confirm at the call site.
    pub num_args: usize,
    /// Declared optional-argument count (always `0` for the environments
    /// registered in this file).
    pub num_optional_args: usize,
    /// Function that parses the environment body and builds its node.
    pub handler: EnvHandler,
}
26
27pub static ENVIRONMENTS: std::sync::LazyLock<HashMap<&'static str, EnvSpec>> =
28 std::sync::LazyLock::new(|| {
29 let mut map = HashMap::new();
30 register_array(&mut map);
31 register_matrix(&mut map);
32 register_cases(&mut map);
33 register_align(&mut map);
34 register_gathered(&mut map);
35 register_equation(&mut map);
36 register_smallmatrix(&mut map);
37 register_alignat(&mut map);
38 register_subarray(&mut map);
39 map
40 });
41
/// Per-environment options consumed by [`parse_array`].
///
/// The type derives `Default`, so handlers set only the fields they need and
/// fill the rest with `..Default::default()`.
#[derive(Default)]
pub struct ArrayConfig {
    /// Copied unchanged into the resulting `ParseNode::Array`.
    pub hskip_before_and_after: Option<bool>,
    /// Copied unchanged into the resulting `ParseNode::Array`.
    pub add_jot: Option<bool>,
    /// Column specifications; copied unchanged into the resulting array node.
    pub cols: Option<Vec<AlignSpec>>,
    /// Row stretch factor; `parse_array` substitutes `1.0` when this is `None`.
    pub arraystretch: Option<f64>,
    /// Copied unchanged into the resulting `ParseNode::Array`.
    pub col_separation_type: Option<String>,
    /// When `true`, cells are parsed up to `\end` instead of `\\`, and the
    /// `\cr` macro is not defined for the body.
    pub single_row: bool,
    /// When `true`, an empty trailing row is kept if it is the only row.
    pub empty_single_row: bool,
    /// Upper bound on cells per row; a `&` seen once the current row already
    /// holds this many cells is reported as an error.
    pub max_num_cols: Option<usize>,
    /// Copied unchanged into the resulting `ParseNode::Array`.
    pub leqno: Option<bool>,
}
56
57
58fn get_hlines(parser: &mut Parser) -> ParseResult<Vec<bool>> {
61 let mut hline_info = Vec::new();
62 parser.consume_spaces()?;
63
64 let mut nxt = parser.fetch()?.text.clone();
65 if nxt == "\\relax" {
66 parser.consume();
67 parser.consume_spaces()?;
68 nxt = parser.fetch()?.text.clone();
69 }
70 while nxt == "\\hline" || nxt == "\\hdashline" {
71 parser.consume();
72 hline_info.push(nxt == "\\hdashline");
73 parser.consume_spaces()?;
74 nxt = parser.fetch()?.text.clone();
75 }
76 Ok(hline_info)
77}
78
79fn d_cell_style(env_name: &str) -> Option<StyleStr> {
80 if env_name.starts_with('d') {
81 Some(StyleStr::Display)
82 } else {
83 Some(StyleStr::Text)
84 }
85}
86
87pub fn parse_array(
88 parser: &mut Parser,
89 config: ArrayConfig,
90 style: Option<StyleStr>,
91) -> ParseResult<ParseNode> {
92 parser.gullet.begin_group();
93
94 if !config.single_row {
95 parser
96 .gullet
97 .set_text_macro("\\cr", "\\\\\\relax");
98 }
99
100 let arraystretch = config.arraystretch.unwrap_or(1.0);
101
102 parser.gullet.begin_group();
103
104 let mut row: Vec<ParseNode> = Vec::new();
105 let mut body: Vec<Vec<ParseNode>> = Vec::new();
106 let mut row_gaps: Vec<Option<Measurement>> = Vec::new();
107 let mut hlines_before_row: Vec<Vec<bool>> = Vec::new();
108
109 hlines_before_row.push(get_hlines(parser)?);
110
111 loop {
112 let break_token = if config.single_row { "\\end" } else { "\\\\" };
113 let cell_body = parser.parse_expression(false, Some(break_token))?;
114 parser.gullet.end_group();
115 parser.gullet.begin_group();
116
117 let mut cell = ParseNode::OrdGroup {
118 mode: parser.mode,
119 body: cell_body,
120 semisimple: None,
121 loc: None,
122 };
123
124 if let Some(s) = style {
125 cell = ParseNode::Styling {
126 mode: parser.mode,
127 style: s,
128 body: vec![cell],
129 loc: None,
130 };
131 }
132
133 row.push(cell.clone());
134 let next = parser.fetch()?.text.clone();
135
136 if next == "&" {
137 if let Some(max) = config.max_num_cols {
138 if row.len() >= max {
139 return Err(ParseError::msg("Too many tab characters: &"));
140 }
141 }
142 parser.consume();
143 } else if next == "\\end" {
144 let is_empty_trailing = if let Some(s) = style {
146 if s == StyleStr::Text || s == StyleStr::Display {
147 if let ParseNode::Styling { body: ref sb, .. } = cell {
148 if let Some(ParseNode::OrdGroup {
149 body: ref ob, ..
150 }) = sb.first()
151 {
152 ob.is_empty()
153 } else {
154 false
155 }
156 } else {
157 false
158 }
159 } else {
160 false
161 }
162 } else if let ParseNode::OrdGroup { body: ref ob, .. } = cell {
163 ob.is_empty()
164 } else {
165 false
166 };
167
168 body.push(row);
169
170 if is_empty_trailing
171 && (body.len() > 1 || !config.empty_single_row)
172 {
173 body.pop();
174 }
175
176 if hlines_before_row.len() < body.len() + 1 {
177 hlines_before_row.push(vec![]);
178 }
179 break;
180 } else if next == "\\\\" {
181 parser.consume();
182 let size = if parser.gullet.future().text != " " {
183 parser.parse_size_group(true)?
184 } else {
185 None
186 };
187 let gap = size.and_then(|s| {
188 if let ParseNode::Size { value, .. } = s {
189 Some(value)
190 } else {
191 None
192 }
193 });
194 row_gaps.push(gap);
195
196 body.push(row);
197 hlines_before_row.push(get_hlines(parser)?);
198 row = Vec::new();
199 } else {
200 return Err(ParseError::msg(format!(
201 "Expected & or \\\\ or \\cr or \\end, got '{}'",
202 next
203 )));
204 }
205 }
206
207 parser.gullet.end_group();
208 parser.gullet.end_group();
209
210 Ok(ParseNode::Array {
211 mode: parser.mode,
212 body,
213 row_gaps,
214 hlines_before_row,
215 cols: config.cols,
216 col_separation_type: config.col_separation_type,
217 hskip_before_and_after: config.hskip_before_and_after,
218 add_jot: config.add_jot,
219 arraystretch,
220 tags: None,
221 leqno: config.leqno,
222 is_cd: None,
223 loc: None,
224 })
225}
226
227fn register_array(map: &mut HashMap<&'static str, EnvSpec>) {
230 fn handle_array(
231 ctx: &mut EnvContext,
232 args: Vec<ParseNode>,
233 _opt_args: Vec<Option<ParseNode>>,
234 ) -> ParseResult<ParseNode> {
235 let colalign = match &args[0] {
236 ParseNode::OrdGroup { body, .. } => body.clone(),
237 other if other.is_symbol_node() => vec![other.clone()],
238 _ => return Err(ParseError::msg("Invalid column alignment for array")),
239 };
240
241 let mut cols = Vec::new();
242 for nde in &colalign {
243 let ca = nde
244 .symbol_text()
245 .ok_or_else(|| ParseError::msg("Expected column alignment character"))?;
246 match ca {
247 "l" | "c" | "r" => cols.push(AlignSpec {
248 align_type: AlignType::Align,
249 align: Some(ca.to_string()),
250 pregap: None,
251 postgap: None,
252 }),
253 "|" => cols.push(AlignSpec {
254 align_type: AlignType::Separator,
255 align: Some("|".to_string()),
256 pregap: None,
257 postgap: None,
258 }),
259 ":" => cols.push(AlignSpec {
260 align_type: AlignType::Separator,
261 align: Some(":".to_string()),
262 pregap: None,
263 postgap: None,
264 }),
265 _ => {
266 return Err(ParseError::msg(format!(
267 "Unknown column alignment: {}",
268 ca
269 )))
270 }
271 }
272 }
273
274 let max_num_cols = cols.len();
275 let config = ArrayConfig {
276 cols: Some(cols),
277 hskip_before_and_after: Some(true),
278 max_num_cols: Some(max_num_cols),
279 ..Default::default()
280 };
281 parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))
282 }
283
284 for name in &["array", "darray"] {
285 map.insert(
286 name,
287 EnvSpec {
288 num_args: 1,
289 num_optional_args: 0,
290 handler: handle_array,
291 },
292 );
293 }
294}
295
296fn register_matrix(map: &mut HashMap<&'static str, EnvSpec>) {
299 fn handle_matrix(
300 ctx: &mut EnvContext,
301 _args: Vec<ParseNode>,
302 _opt_args: Vec<Option<ParseNode>>,
303 ) -> ParseResult<ParseNode> {
304 let base_name = ctx.env_name.replace('*', "");
305 let delimiters: Option<(&str, &str)> = match base_name.as_str() {
306 "matrix" => None,
307 "pmatrix" => Some(("(", ")")),
308 "bmatrix" => Some(("[", "]")),
309 "Bmatrix" => Some(("\\{", "\\}")),
310 "vmatrix" => Some(("|", "|")),
311 "Vmatrix" => Some(("\\Vert", "\\Vert")),
312 _ => None,
313 };
314
315 let mut col_align = "c".to_string();
316
317 if ctx.env_name.ends_with('*') {
319 ctx.parser.gullet.consume_spaces();
320 if ctx.parser.gullet.future().text == "[" {
321 ctx.parser.gullet.pop_token();
322 ctx.parser.gullet.consume_spaces();
323 let align_tok = ctx.parser.gullet.pop_token();
324 if !"lcr".contains(align_tok.text.as_str()) {
325 return Err(ParseError::new(
326 "Expected l or c or r".to_string(),
327 Some(&align_tok),
328 ));
329 }
330 col_align = align_tok.text.clone();
331 ctx.parser.gullet.consume_spaces();
332 let close = ctx.parser.gullet.pop_token();
333 if close.text != "]" {
334 return Err(ParseError::new(
335 "Expected ]".to_string(),
336 Some(&close),
337 ));
338 }
339 }
340 }
341
342 let config = ArrayConfig {
343 hskip_before_and_after: Some(false),
344 cols: Some(vec![AlignSpec {
345 align_type: AlignType::Align,
346 align: Some(col_align.clone()),
347 pregap: None,
348 postgap: None,
349 }]),
350 ..Default::default()
351 };
352
353 let mut res = parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))?;
354
355 if let ParseNode::Array {
357 ref body,
358 ref mut cols,
359 ..
360 } = res
361 {
362 let num_cols = body.iter().map(|r| r.len()).max().unwrap_or(0);
363 *cols = Some(
364 (0..num_cols)
365 .map(|_| AlignSpec {
366 align_type: AlignType::Align,
367 align: Some(col_align.to_string()),
368 pregap: None,
369 postgap: None,
370 })
371 .collect(),
372 );
373 }
374
375 match delimiters {
376 Some((left, right)) => Ok(ParseNode::LeftRight {
377 mode: ctx.mode,
378 body: vec![res],
379 left: left.to_string(),
380 right: right.to_string(),
381 right_color: None,
382 loc: None,
383 }),
384 None => Ok(res),
385 }
386 }
387
388 for name in &[
389 "matrix", "pmatrix", "bmatrix", "Bmatrix", "vmatrix", "Vmatrix",
390 "matrix*", "pmatrix*", "bmatrix*", "Bmatrix*", "vmatrix*", "Vmatrix*",
391 ] {
392 map.insert(
393 name,
394 EnvSpec {
395 num_args: 0,
396 num_optional_args: 0,
397 handler: handle_matrix,
398 },
399 );
400 }
401}
402
403fn register_cases(map: &mut HashMap<&'static str, EnvSpec>) {
406 fn handle_cases(
407 ctx: &mut EnvContext,
408 _args: Vec<ParseNode>,
409 _opt_args: Vec<Option<ParseNode>>,
410 ) -> ParseResult<ParseNode> {
411 let config = ArrayConfig {
412 arraystretch: Some(1.2),
413 cols: Some(vec![
414 AlignSpec {
415 align_type: AlignType::Align,
416 align: Some("l".to_string()),
417 pregap: Some(0.0),
418 postgap: Some(1.0),
419 },
420 AlignSpec {
421 align_type: AlignType::Align,
422 align: Some("l".to_string()),
423 pregap: Some(0.0),
424 postgap: Some(0.0),
425 },
426 ]),
427 ..Default::default()
428 };
429
430 let res = parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))?;
431
432 let (left, right) = if ctx.env_name.contains('r') {
433 (".", "\\}")
434 } else {
435 ("\\{", ".")
436 };
437
438 Ok(ParseNode::LeftRight {
439 mode: ctx.mode,
440 body: vec![res],
441 left: left.to_string(),
442 right: right.to_string(),
443 right_color: None,
444 loc: None,
445 })
446 }
447
448 for name in &["cases", "dcases", "rcases", "drcases"] {
449 map.insert(
450 name,
451 EnvSpec {
452 num_args: 0,
453 num_optional_args: 0,
454 handler: handle_cases,
455 },
456 );
457 }
458}
459
460fn handle_aligned(
463 ctx: &mut EnvContext,
464 args: Vec<ParseNode>,
465 _opt_args: Vec<Option<ParseNode>>,
466) -> ParseResult<ParseNode> {
467 let is_split = ctx.env_name == "split";
468 let is_alignat = ctx.env_name.contains("at");
469 let sep_type = if is_alignat { "alignat" } else { "align" };
470
471 let config = ArrayConfig {
472 add_jot: Some(true),
473 empty_single_row: true,
474 col_separation_type: Some(sep_type.to_string()),
475 max_num_cols: if is_split { Some(2) } else { None },
476 ..Default::default()
477 };
478
479 let mut res = parse_array(ctx.parser, config, Some(StyleStr::Display))?;
480
481 let mut num_maths = 0usize;
483 let mut explicit_cols = 0usize;
484 if let Some(ParseNode::OrdGroup { body, .. }) = args.first() {
485 let mut arg_str = String::new();
486 for node in body {
487 if let Some(t) = node.symbol_text() {
488 arg_str.push_str(t);
489 }
490 }
491 if let Ok(n) = arg_str.parse::<usize>() {
492 num_maths = n;
493 explicit_cols = n * 2;
494 }
495 }
496 let is_aligned = explicit_cols == 0;
497
498 let mut num_cols = if let ParseNode::Array { ref body, .. } = res {
500 body.iter().map(|r| r.len()).max().unwrap_or(0)
501 } else {
502 0
503 };
504
505 if let ParseNode::Array {
506 body: ref mut array_body,
507 ..
508 } = res
509 {
510 for row in array_body.iter_mut() {
511 let mut i = 1;
513 while i < row.len() {
514 if let ParseNode::Styling {
515 body: ref mut styling_body,
516 ..
517 } = row[i]
518 {
519 if let Some(ParseNode::OrdGroup {
520 body: ref mut og_body,
521 ..
522 }) = styling_body.first_mut()
523 {
524 og_body.insert(
525 0,
526 ParseNode::OrdGroup {
527 mode: ctx.mode,
528 body: vec![],
529 semisimple: None,
530 loc: None,
531 },
532 );
533 }
534 }
535 i += 2;
536 }
537
538 if !is_aligned {
539 let cur_maths = row.len() / 2;
540 if num_maths < cur_maths {
541 return Err(ParseError::msg(format!(
542 "Too many math in a row: expected {}, but got {}",
543 num_maths, cur_maths
544 )));
545 }
546 } else if num_cols < row.len() {
547 num_cols = row.len();
548 }
549 }
550 }
551
552 if !is_aligned {
553 num_cols = explicit_cols;
554 }
555
556 let mut cols = Vec::new();
557 for i in 0..num_cols {
558 let (align, pregap) = if i % 2 == 1 {
559 ("l", 0.0)
560 } else if i > 0 && is_aligned {
561 ("r", 1.0)
562 } else {
563 ("r", 0.0)
564 };
565 cols.push(AlignSpec {
566 align_type: AlignType::Align,
567 align: Some(align.to_string()),
568 pregap: Some(pregap),
569 postgap: Some(0.0),
570 });
571 }
572
573 if let ParseNode::Array {
574 cols: ref mut array_cols,
575 col_separation_type: ref mut array_sep_type,
576 ..
577 } = res
578 {
579 *array_cols = Some(cols);
580 *array_sep_type = Some(
581 if is_aligned { "align" } else { "alignat" }.to_string(),
582 );
583 }
584
585 Ok(res)
586}
587
588fn register_align(map: &mut HashMap<&'static str, EnvSpec>) {
589 for name in &["align", "align*", "aligned", "split"] {
590 map.insert(
591 name,
592 EnvSpec {
593 num_args: 0,
594 num_optional_args: 0,
595 handler: handle_aligned,
596 },
597 );
598 }
599}
600
601fn register_gathered(map: &mut HashMap<&'static str, EnvSpec>) {
604 fn handle_gathered(
605 ctx: &mut EnvContext,
606 _args: Vec<ParseNode>,
607 _opt_args: Vec<Option<ParseNode>>,
608 ) -> ParseResult<ParseNode> {
609 let config = ArrayConfig {
610 cols: Some(vec![AlignSpec {
611 align_type: AlignType::Align,
612 align: Some("c".to_string()),
613 pregap: None,
614 postgap: None,
615 }]),
616 add_jot: Some(true),
617 col_separation_type: Some("gather".to_string()),
618 empty_single_row: true,
619 ..Default::default()
620 };
621 parse_array(ctx.parser, config, Some(StyleStr::Display))
622 }
623
624 for name in &["gathered", "gather", "gather*"] {
625 map.insert(
626 name,
627 EnvSpec {
628 num_args: 0,
629 num_optional_args: 0,
630 handler: handle_gathered,
631 },
632 );
633 }
634}
635
636fn register_equation(map: &mut HashMap<&'static str, EnvSpec>) {
639 fn handle_equation(
640 ctx: &mut EnvContext,
641 _args: Vec<ParseNode>,
642 _opt_args: Vec<Option<ParseNode>>,
643 ) -> ParseResult<ParseNode> {
644 let config = ArrayConfig {
645 empty_single_row: true,
646 single_row: true,
647 max_num_cols: Some(1),
648 ..Default::default()
649 };
650 parse_array(ctx.parser, config, Some(StyleStr::Display))
651 }
652
653 for name in &["equation", "equation*"] {
654 map.insert(
655 name,
656 EnvSpec {
657 num_args: 0,
658 num_optional_args: 0,
659 handler: handle_equation,
660 },
661 );
662 }
663}
664
665fn register_smallmatrix(map: &mut HashMap<&'static str, EnvSpec>) {
668 fn handle_smallmatrix(
669 ctx: &mut EnvContext,
670 _args: Vec<ParseNode>,
671 _opt_args: Vec<Option<ParseNode>>,
672 ) -> ParseResult<ParseNode> {
673 let config = ArrayConfig {
674 arraystretch: Some(0.5),
675 ..Default::default()
676 };
677 let mut res = parse_array(ctx.parser, config, Some(StyleStr::Script))?;
678 if let ParseNode::Array {
679 ref mut col_separation_type,
680 ..
681 } = res
682 {
683 *col_separation_type = Some("small".to_string());
684 }
685 Ok(res)
686 }
687
688 map.insert(
689 "smallmatrix",
690 EnvSpec {
691 num_args: 0,
692 num_optional_args: 0,
693 handler: handle_smallmatrix,
694 },
695 );
696}
697
698fn register_alignat(map: &mut HashMap<&'static str, EnvSpec>) {
701 for name in &["alignat", "alignat*", "alignedat"] {
702 map.insert(
703 name,
704 EnvSpec {
705 num_args: 1,
706 num_optional_args: 0,
707 handler: handle_aligned,
708 },
709 );
710 }
711}
712
713fn register_subarray(map: &mut HashMap<&'static str, EnvSpec>) {
716 fn handle_subarray(
717 ctx: &mut EnvContext,
718 args: Vec<ParseNode>,
719 _opt_args: Vec<Option<ParseNode>>,
720 ) -> ParseResult<ParseNode> {
721 let colalign = match &args[0] {
722 ParseNode::OrdGroup { body, .. } => body.clone(),
723 other if other.is_symbol_node() => vec![other.clone()],
724 _ => return Err(ParseError::msg("Invalid column alignment for subarray")),
725 };
726
727 let mut cols = Vec::new();
728 for nde in &colalign {
729 let ca = nde
730 .symbol_text()
731 .ok_or_else(|| ParseError::msg("Expected column alignment character"))?;
732 match ca {
733 "l" | "c" => cols.push(AlignSpec {
734 align_type: AlignType::Align,
735 align: Some(ca.to_string()),
736 pregap: None,
737 postgap: None,
738 }),
739 _ => {
740 return Err(ParseError::msg(format!(
741 "Unknown column alignment: {}",
742 ca
743 )))
744 }
745 }
746 }
747
748 if cols.len() > 1 {
749 return Err(ParseError::msg("{subarray} can contain only one column"));
750 }
751
752 let config = ArrayConfig {
753 cols: Some(cols),
754 hskip_before_and_after: Some(false),
755 arraystretch: Some(0.5),
756 ..Default::default()
757 };
758
759 let res = parse_array(ctx.parser, config, Some(StyleStr::Script))?;
760
761 if let ParseNode::Array { ref body, .. } = res {
762 if !body.is_empty() && body[0].len() > 1 {
763 return Err(ParseError::msg("{subarray} can contain only one column"));
764 }
765 }
766
767 Ok(res)
768 }
769
770 map.insert(
771 "subarray",
772 EnvSpec {
773 num_args: 1,
774 num_optional_args: 0,
775 handler: handle_subarray,
776 },
777 );
778}