1use std::collections::HashMap;
2
3use ratex_lexer::token::Token;
4
5use crate::error::{ParseError, ParseResult};
6use crate::macro_expander::MacroDefinition;
7use crate::parse_node::{
8 AlignSpec, AlignType, ArrayTag, Measurement, Mode, ParseNode, ProofBranch, ProofLineStyle,
9 StyleStr,
10};
11use crate::parser::Parser;
12
13pub struct EnvContext<'a, 'b> {
16 pub mode: Mode,
17 pub env_name: String,
18 pub parser: &'a mut Parser<'b>,
19}
20
21pub type EnvHandler = fn(
22 ctx: &mut EnvContext,
23 args: Vec<ParseNode>,
24 opt_args: Vec<Option<ParseNode>>,
25) -> ParseResult<ParseNode>;
26
27pub struct EnvSpec {
28 pub num_args: usize,
29 pub num_optional_args: usize,
30 pub handler: EnvHandler,
31}
32
33pub static ENVIRONMENTS: std::sync::LazyLock<HashMap<&'static str, EnvSpec>> =
34 std::sync::LazyLock::new(|| {
35 let mut map = HashMap::new();
36 register_array(&mut map);
37 register_matrix(&mut map);
38 register_cases(&mut map);
39 register_align(&mut map);
40 register_gathered(&mut map);
41 register_equation(&mut map);
42 register_smallmatrix(&mut map);
43 register_alignat(&mut map);
44 register_subarray(&mut map);
45 register_cd(&mut map);
46 register_prooftree(&mut map);
47 map
48 });
49
50#[derive(Default)]
53pub struct ArrayConfig {
54 pub hskip_before_and_after: Option<bool>,
55 pub add_jot: Option<bool>,
56 pub cols: Option<Vec<AlignSpec>>,
57 pub arraystretch: Option<f64>,
58 pub col_separation_type: Option<String>,
59 pub single_row: bool,
60 pub empty_single_row: bool,
61 pub max_num_cols: Option<usize>,
62 pub leqno: Option<bool>,
63 pub auto_number: bool,
64}
65
66
67fn extract_trailing_tag_from_last_cell(row: &mut [ParseNode], auto_number: bool) -> ParseResult<ArrayTag> {
73 let default_tag = if auto_number { ArrayTag::Auto(true) } else { ArrayTag::Auto(false) };
74 let Some(last) = row.last_mut() else {
75 return Ok(default_tag);
76 };
77
78 let inner: &mut ParseNode = match last {
79 ParseNode::Styling { body, .. } => {
80 if body.len() != 1 {
81 return Ok(default_tag);
82 }
83 &mut body[0]
84 }
85 _ => last,
86 };
87
88 let obody = match inner {
89 ParseNode::OrdGroup { body, .. } => body,
90 _ => return Ok(default_tag),
91 };
92
93 let tag_indices: Vec<usize> = obody
95 .iter()
96 .enumerate()
97 .filter(|(_, n)| matches!(n, ParseNode::Tag { .. }))
98 .map(|(i, _)| i)
99 .collect();
100
101 let nonumber_indices: Vec<usize> = obody
103 .iter()
104 .enumerate()
105 .filter(|(_, n)| matches!(n, ParseNode::NoNumber { .. }))
106 .map(|(i, _)| i)
107 .collect();
108
109 if !tag_indices.is_empty() && !nonumber_indices.is_empty() {
111 return Err(ParseError::msg(
112 "Cannot use both \\tag and \\nonumber in the same row",
113 ));
114 }
115
116 if !tag_indices.is_empty() {
118 if tag_indices.len() > 1 {
119 return Err(ParseError::msg("Multiple \\tag in a row"));
120 }
121 let idx = tag_indices[0];
122 if idx != obody.len() - 1 {
123 return Err(ParseError::msg(
124 "\\tag must appear at the end of the row after the equation body",
125 ));
126 }
127 match obody.pop() {
128 Some(ParseNode::Tag { tag, .. }) => {
129 if tag.is_empty() {
130 Ok(ArrayTag::Auto(false))
131 } else {
132 Ok(ArrayTag::Explicit(tag))
133 }
134 }
135 _ => Ok(default_tag),
136 }
137 } else if !nonumber_indices.is_empty() {
138 if nonumber_indices.len() > 1 {
140 return Err(ParseError::msg("Multiple \\nonumber in a row"));
141 }
142 let idx = nonumber_indices[0];
143 if idx != obody.len() - 1 {
144 return Err(ParseError::msg(
145 "\\nonumber must appear at the end of the row",
146 ));
147 }
148 obody.pop(); Ok(ArrayTag::Auto(false))
150 } else {
151 Ok(default_tag)
153 }
154}
155
156fn get_hlines(parser: &mut Parser) -> ParseResult<Vec<bool>> {
157 let mut hline_info = Vec::new();
158 parser.consume_spaces()?;
159
160 let mut nxt = parser.fetch()?.text.clone();
161 if nxt == "\\relax" {
162 parser.consume();
163 parser.consume_spaces()?;
164 nxt = parser.fetch()?.text.clone();
165 }
166 while nxt == "\\hline" || nxt == "\\hdashline" {
167 parser.consume();
168 hline_info.push(nxt == "\\hdashline");
169 parser.consume_spaces()?;
170 nxt = parser.fetch()?.text.clone();
171 }
172 Ok(hline_info)
173}
174
175fn d_cell_style(env_name: &str) -> Option<StyleStr> {
176 if env_name.starts_with('d') {
177 Some(StyleStr::Display)
178 } else {
179 Some(StyleStr::Text)
180 }
181}
182
183pub fn parse_array(
184 parser: &mut Parser,
185 config: ArrayConfig,
186 style: Option<StyleStr>,
187) -> ParseResult<ParseNode> {
188 parser.gullet.begin_group();
189
190 if !config.single_row {
191 parser
192 .gullet
193 .set_text_macro("\\cr", "\\\\\\relax");
194 }
195
196 let arraystretch = config.arraystretch.unwrap_or_else(|| {
197 if let Some(def) = parser.gullet.get_macro("\\arraystretch") {
199 let s = match def {
200 MacroDefinition::Text(s) => s.clone(),
201 MacroDefinition::Tokens { tokens, .. } => {
202 tokens.iter().rev().map(|t| t.text.as_str()).collect::<String>()
204 }
205 MacroDefinition::Function(_) => String::new(),
206 };
207 s.parse::<f64>().unwrap_or(1.0)
208 } else {
209 1.0
210 }
211 });
212
213 parser.gullet.begin_group();
214
215 let mut row: Vec<ParseNode> = Vec::new();
216 let mut body: Vec<Vec<ParseNode>> = Vec::new();
217 let mut row_tags: Vec<ArrayTag> = Vec::new();
218 let mut row_gaps: Vec<Option<Measurement>> = Vec::new();
219 let mut hlines_before_row: Vec<Vec<bool>> = Vec::new();
220
221 hlines_before_row.push(get_hlines(parser)?);
222
223 loop {
224 let break_token = if config.single_row { "\\end" } else { "\\\\" };
225 let cell_body = parser.parse_expression(false, Some(break_token))?;
226 parser.gullet.end_group();
227 parser.gullet.begin_group();
228
229 let mut cell = ParseNode::OrdGroup {
230 mode: parser.mode,
231 body: cell_body,
232 semisimple: None,
233 loc: None,
234 };
235
236 if let Some(s) = style {
237 cell = ParseNode::Styling {
238 mode: parser.mode,
239 style: s,
240 body: vec![cell],
241 loc: None,
242 };
243 }
244
245 row.push(cell.clone());
246 let next = parser.fetch()?.text.clone();
247
248 if next == "&" {
249 if let Some(max) = config.max_num_cols {
250 if row.len() >= max {
251 return Err(ParseError::msg("Too many tab characters: &"));
252 }
253 }
254 parser.consume();
255 } else if next == "\\end" {
256 let is_empty_trailing = if let Some(s) = style {
258 if s == StyleStr::Text || s == StyleStr::Display {
259 if let ParseNode::Styling { body: ref sb, .. } = cell {
260 if let Some(ParseNode::OrdGroup {
261 body: ref ob, ..
262 }) = sb.first()
263 {
264 ob.is_empty()
265 } else {
266 false
267 }
268 } else {
269 false
270 }
271 } else {
272 false
273 }
274 } else if let ParseNode::OrdGroup { body: ref ob, .. } = cell {
275 ob.is_empty()
276 } else {
277 false
278 };
279
280 let row_tag = extract_trailing_tag_from_last_cell(&mut row, config.auto_number)?;
281 row_tags.push(row_tag);
282 body.push(row);
283
284 if is_empty_trailing
285 && (body.len() > 1 || !config.empty_single_row)
286 {
287 body.pop();
288 row_tags.pop();
289 }
290
291 if hlines_before_row.len() < body.len() + 1 {
292 hlines_before_row.push(vec![]);
293 }
294 break;
295 } else if next == "\\\\" {
296 parser.consume();
297 let size = if parser.gullet.future().text != " " {
298 parser.parse_size_group(true)?
299 } else {
300 None
301 };
302 let gap = size.and_then(|s| {
303 if let ParseNode::Size { value, .. } = s {
304 Some(value)
305 } else {
306 None
307 }
308 });
309 row_gaps.push(gap);
310
311 let row_tag = extract_trailing_tag_from_last_cell(&mut row, config.auto_number)?;
312 row_tags.push(row_tag);
313 body.push(row);
314 hlines_before_row.push(get_hlines(parser)?);
315 row = Vec::new();
316 } else {
317 return Err(ParseError::msg(format!(
318 "Expected & or \\\\ or \\cr or \\end, got '{}'",
319 next
320 )));
321 }
322 }
323
324 parser.gullet.end_group();
325 parser.gullet.end_group();
326
327 let tags = if config.auto_number {
329 let mut processed: Vec<ArrayTag> = Vec::with_capacity(row_tags.len());
330 let mut any_visible = false;
331 for raw_tag in &row_tags {
332 match raw_tag {
333 ArrayTag::Explicit(nodes) if !nodes.is_empty() => {
334 parser.equation_counter += 1;
336 processed.push(ArrayTag::Explicit(nodes.clone()));
337 any_visible = true;
338 }
339 ArrayTag::Explicit(_) => {
340 processed.push(ArrayTag::Auto(false));
342 }
343 ArrayTag::Auto(true) => {
344 parser.equation_counter += 1;
346 let num_str = parser.equation_counter.to_string();
347 let tag_nodes = vec![
348 ParseNode::MathOrd {
349 mode: Mode::Math,
350 text: "(".to_string(),
351 loc: None,
352 },
353 ParseNode::MathOrd {
354 mode: Mode::Math,
355 text: num_str,
356 loc: None,
357 },
358 ParseNode::MathOrd {
359 mode: Mode::Math,
360 text: ")".to_string(),
361 loc: None,
362 },
363 ];
364 processed.push(ArrayTag::Explicit(tag_nodes));
365 any_visible = true;
366 }
367 ArrayTag::Auto(false) => {
368 processed.push(ArrayTag::Auto(false));
370 }
371 }
372 }
373 if any_visible { Some(processed) } else { None }
374 } else {
375 if row_tags.iter().any(|t| {
377 matches!(t, ArrayTag::Explicit(nodes) if !nodes.is_empty())
378 }) {
379 Some(row_tags)
380 } else {
381 None
382 }
383 };
384
385 Ok(ParseNode::Array {
386 mode: parser.mode,
387 body,
388 row_gaps,
389 hlines_before_row,
390 cols: config.cols,
391 col_separation_type: config.col_separation_type,
392 hskip_before_and_after: config.hskip_before_and_after,
393 add_jot: config.add_jot,
394 arraystretch,
395 tags,
396 leqno: config.leqno,
397 is_cd: None,
398 loc: None,
399 })
400}
401
402fn register_array(map: &mut HashMap<&'static str, EnvSpec>) {
405 fn handle_array(
406 ctx: &mut EnvContext,
407 args: Vec<ParseNode>,
408 _opt_args: Vec<Option<ParseNode>>,
409 ) -> ParseResult<ParseNode> {
410 let colalign = match &args[0] {
411 ParseNode::OrdGroup { body, .. } => body.clone(),
412 other if other.is_symbol_node() => vec![other.clone()],
413 _ => return Err(ParseError::msg("Invalid column alignment for array")),
414 };
415
416 let mut cols = Vec::new();
417 for nde in &colalign {
418 let ca = nde
419 .symbol_text()
420 .ok_or_else(|| ParseError::msg("Expected column alignment character"))?;
421 match ca {
422 "l" | "c" | "r" => cols.push(AlignSpec {
423 align_type: AlignType::Align,
424 align: Some(ca.to_string()),
425 pregap: None,
426 postgap: None,
427 }),
428 "|" => cols.push(AlignSpec {
429 align_type: AlignType::Separator,
430 align: Some("|".to_string()),
431 pregap: None,
432 postgap: None,
433 }),
434 ":" => cols.push(AlignSpec {
435 align_type: AlignType::Separator,
436 align: Some(":".to_string()),
437 pregap: None,
438 postgap: None,
439 }),
440 _ => {
441 return Err(ParseError::msg(format!(
442 "Unknown column alignment: {}",
443 ca
444 )))
445 }
446 }
447 }
448
449 let max_num_cols = cols.len();
450 let config = ArrayConfig {
451 cols: Some(cols),
452 hskip_before_and_after: Some(true),
453 max_num_cols: Some(max_num_cols),
454 ..Default::default()
455 };
456 parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))
457 }
458
459 for name in &["array", "darray"] {
460 map.insert(
461 name,
462 EnvSpec {
463 num_args: 1,
464 num_optional_args: 0,
465 handler: handle_array,
466 },
467 );
468 }
469}
470
471fn register_matrix(map: &mut HashMap<&'static str, EnvSpec>) {
474 fn handle_matrix(
475 ctx: &mut EnvContext,
476 _args: Vec<ParseNode>,
477 _opt_args: Vec<Option<ParseNode>>,
478 ) -> ParseResult<ParseNode> {
479 let base_name = ctx.env_name.replace('*', "");
480 let delimiters: Option<(&str, &str)> = match base_name.as_str() {
481 "matrix" => None,
482 "pmatrix" => Some(("(", ")")),
483 "bmatrix" => Some(("[", "]")),
484 "Bmatrix" => Some(("\\{", "\\}")),
485 "vmatrix" => Some(("|", "|")),
486 "Vmatrix" => Some(("\\Vert", "\\Vert")),
487 _ => None,
488 };
489
490 let mut col_align = "c".to_string();
491
492 if ctx.env_name.ends_with('*') {
494 ctx.parser.gullet.consume_spaces();
495 if ctx.parser.gullet.future().text == "[" {
496 ctx.parser.gullet.pop_token();
497 ctx.parser.gullet.consume_spaces();
498 let align_tok = ctx.parser.gullet.pop_token();
499 if !"lcr".contains(align_tok.text.as_str()) {
500 return Err(ParseError::new(
501 "Expected l or c or r".to_string(),
502 Some(&align_tok),
503 ));
504 }
505 col_align = align_tok.text.clone();
506 ctx.parser.gullet.consume_spaces();
507 let close = ctx.parser.gullet.pop_token();
508 if close.text != "]" {
509 return Err(ParseError::new(
510 "Expected ]".to_string(),
511 Some(&close),
512 ));
513 }
514 }
515 }
516
517 let config = ArrayConfig {
518 hskip_before_and_after: Some(false),
519 cols: Some(vec![AlignSpec {
520 align_type: AlignType::Align,
521 align: Some(col_align.clone()),
522 pregap: None,
523 postgap: None,
524 }]),
525 ..Default::default()
526 };
527
528 let mut res = parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))?;
529
530 if let ParseNode::Array {
532 ref body,
533 ref mut cols,
534 ..
535 } = res
536 {
537 let num_cols = body.iter().map(|r| r.len()).max().unwrap_or(0);
538 *cols = Some(
539 (0..num_cols)
540 .map(|_| AlignSpec {
541 align_type: AlignType::Align,
542 align: Some(col_align.to_string()),
543 pregap: None,
544 postgap: None,
545 })
546 .collect(),
547 );
548 }
549
550 match delimiters {
551 Some((left, right)) => Ok(ParseNode::LeftRight {
552 mode: ctx.mode,
553 body: vec![res],
554 left: left.to_string(),
555 right: right.to_string(),
556 right_color: None,
557 loc: None,
558 }),
559 None => Ok(res),
560 }
561 }
562
563 for name in &[
564 "matrix", "pmatrix", "bmatrix", "Bmatrix", "vmatrix", "Vmatrix",
565 "matrix*", "pmatrix*", "bmatrix*", "Bmatrix*", "vmatrix*", "Vmatrix*",
566 ] {
567 map.insert(
568 name,
569 EnvSpec {
570 num_args: 0,
571 num_optional_args: 0,
572 handler: handle_matrix,
573 },
574 );
575 }
576}
577
578fn register_cases(map: &mut HashMap<&'static str, EnvSpec>) {
581 fn handle_cases(
582 ctx: &mut EnvContext,
583 _args: Vec<ParseNode>,
584 _opt_args: Vec<Option<ParseNode>>,
585 ) -> ParseResult<ParseNode> {
586 let config = ArrayConfig {
587 arraystretch: Some(1.2),
588 cols: Some(vec![
589 AlignSpec {
590 align_type: AlignType::Align,
591 align: Some("l".to_string()),
592 pregap: Some(0.0),
593 postgap: Some(1.0),
594 },
595 AlignSpec {
596 align_type: AlignType::Align,
597 align: Some("l".to_string()),
598 pregap: Some(0.0),
599 postgap: Some(0.0),
600 },
601 ]),
602 ..Default::default()
603 };
604
605 let res = parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))?;
606
607 let (left, right) = if ctx.env_name.contains('r') {
608 (".", "\\}")
609 } else {
610 ("\\{", ".")
611 };
612
613 Ok(ParseNode::LeftRight {
614 mode: ctx.mode,
615 body: vec![res],
616 left: left.to_string(),
617 right: right.to_string(),
618 right_color: None,
619 loc: None,
620 })
621 }
622
623 for name in &["cases", "dcases", "rcases", "drcases"] {
624 map.insert(
625 name,
626 EnvSpec {
627 num_args: 0,
628 num_optional_args: 0,
629 handler: handle_cases,
630 },
631 );
632 }
633}
634
635fn handle_aligned(
638 ctx: &mut EnvContext,
639 args: Vec<ParseNode>,
640 _opt_args: Vec<Option<ParseNode>>,
641) -> ParseResult<ParseNode> {
642 let is_split = ctx.env_name == "split";
643 let is_alignat = ctx.env_name.contains("at");
644 let sep_type = if is_alignat { "alignat" } else { "align" };
645 let auto_number = !ctx.env_name.ends_with('*')
646 && !is_split
647 && ctx.env_name != "aligned"
648 && ctx.env_name != "alignedat";
649
650 let config = ArrayConfig {
651 add_jot: Some(true),
652 empty_single_row: true,
653 col_separation_type: Some(sep_type.to_string()),
654 max_num_cols: if is_split { Some(2) } else { None },
655 auto_number,
656 ..Default::default()
657 };
658
659 let mut res = parse_array(ctx.parser, config, Some(StyleStr::Display))?;
660
661 let mut num_maths = 0usize;
663 let mut explicit_cols = 0usize;
664 if let Some(ParseNode::OrdGroup { body, .. }) = args.first() {
665 let mut arg_str = String::new();
666 for node in body {
667 if let Some(t) = node.symbol_text() {
668 arg_str.push_str(t);
669 }
670 }
671 if let Ok(n) = arg_str.parse::<usize>() {
672 num_maths = n;
673 explicit_cols = n * 2;
674 }
675 }
676 let is_aligned = explicit_cols == 0;
677
678 let mut num_cols = if let ParseNode::Array { ref body, .. } = res {
680 body.iter().map(|r| r.len()).max().unwrap_or(0)
681 } else {
682 0
683 };
684
685 if let ParseNode::Array {
686 body: ref mut array_body,
687 ..
688 } = res
689 {
690 for row in array_body.iter_mut() {
691 let mut i = 1;
693 while i < row.len() {
694 if let ParseNode::Styling {
695 body: ref mut styling_body,
696 ..
697 } = row[i]
698 {
699 if let Some(ParseNode::OrdGroup {
700 body: ref mut og_body,
701 ..
702 }) = styling_body.first_mut()
703 {
704 og_body.insert(
705 0,
706 ParseNode::OrdGroup {
707 mode: ctx.mode,
708 body: vec![],
709 semisimple: None,
710 loc: None,
711 },
712 );
713 }
714 }
715 i += 2;
716 }
717
718 if !is_aligned {
719 let cur_maths = row.len() / 2;
720 if num_maths < cur_maths {
721 return Err(ParseError::msg(format!(
722 "Too many math in a row: expected {}, but got {}",
723 num_maths, cur_maths
724 )));
725 }
726 } else if num_cols < row.len() {
727 num_cols = row.len();
728 }
729 }
730 }
731
732 if !is_aligned {
733 num_cols = explicit_cols;
734 }
735
736 let mut cols = Vec::new();
737 for i in 0..num_cols {
738 let (align, pregap) = if i % 2 == 1 {
739 ("l", 0.0)
740 } else if i > 0 && is_aligned {
741 ("r", 1.0)
742 } else {
743 ("r", 0.0)
744 };
745 cols.push(AlignSpec {
746 align_type: AlignType::Align,
747 align: Some(align.to_string()),
748 pregap: Some(pregap),
749 postgap: Some(0.0),
750 });
751 }
752
753 if let ParseNode::Array {
754 cols: ref mut array_cols,
755 col_separation_type: ref mut array_sep_type,
756 ..
757 } = res
758 {
759 *array_cols = Some(cols);
760 *array_sep_type = Some(
761 if is_aligned { "align" } else { "alignat" }.to_string(),
762 );
763 }
764
765 Ok(res)
766}
767
768fn register_align(map: &mut HashMap<&'static str, EnvSpec>) {
769 for name in &["align", "align*", "aligned", "split"] {
770 map.insert(
771 name,
772 EnvSpec {
773 num_args: 0,
774 num_optional_args: 0,
775 handler: handle_aligned,
776 },
777 );
778 }
779}
780
781fn register_gathered(map: &mut HashMap<&'static str, EnvSpec>) {
784 fn handle_gathered(
785 ctx: &mut EnvContext,
786 _args: Vec<ParseNode>,
787 _opt_args: Vec<Option<ParseNode>>,
788 ) -> ParseResult<ParseNode> {
789 let auto_number = !ctx.env_name.ends_with('*') && ctx.env_name != "gathered";
790 let config = ArrayConfig {
791 cols: Some(vec![AlignSpec {
792 align_type: AlignType::Align,
793 align: Some("c".to_string()),
794 pregap: None,
795 postgap: None,
796 }]),
797 add_jot: Some(true),
798 col_separation_type: Some("gather".to_string()),
799 empty_single_row: true,
800 auto_number,
801 ..Default::default()
802 };
803 parse_array(ctx.parser, config, Some(StyleStr::Display))
804 }
805
806 for name in &["gathered", "gather", "gather*"] {
807 map.insert(
808 name,
809 EnvSpec {
810 num_args: 0,
811 num_optional_args: 0,
812 handler: handle_gathered,
813 },
814 );
815 }
816}
817
818fn register_equation(map: &mut HashMap<&'static str, EnvSpec>) {
821 fn handle_equation(
822 ctx: &mut EnvContext,
823 _args: Vec<ParseNode>,
824 _opt_args: Vec<Option<ParseNode>>,
825 ) -> ParseResult<ParseNode> {
826 let auto_number = !ctx.env_name.ends_with('*');
827 let config = ArrayConfig {
828 empty_single_row: true,
829 single_row: true,
830 max_num_cols: Some(1),
831 auto_number,
832 ..Default::default()
833 };
834 parse_array(ctx.parser, config, Some(StyleStr::Display))
835 }
836
837 for name in &["equation", "equation*"] {
838 map.insert(
839 name,
840 EnvSpec {
841 num_args: 0,
842 num_optional_args: 0,
843 handler: handle_equation,
844 },
845 );
846 }
847}
848
849fn register_smallmatrix(map: &mut HashMap<&'static str, EnvSpec>) {
852 fn handle_smallmatrix(
853 ctx: &mut EnvContext,
854 _args: Vec<ParseNode>,
855 _opt_args: Vec<Option<ParseNode>>,
856 ) -> ParseResult<ParseNode> {
857 let config = ArrayConfig {
858 arraystretch: Some(0.5),
859 ..Default::default()
860 };
861 let mut res = parse_array(ctx.parser, config, Some(StyleStr::Script))?;
862 if let ParseNode::Array {
863 ref mut col_separation_type,
864 ..
865 } = res
866 {
867 *col_separation_type = Some("small".to_string());
868 }
869 Ok(res)
870 }
871
872 map.insert(
873 "smallmatrix",
874 EnvSpec {
875 num_args: 0,
876 num_optional_args: 0,
877 handler: handle_smallmatrix,
878 },
879 );
880}
881
882fn register_alignat(map: &mut HashMap<&'static str, EnvSpec>) {
885 for name in &["alignat", "alignat*", "alignedat"] {
886 map.insert(
887 name,
888 EnvSpec {
889 num_args: 1,
890 num_optional_args: 0,
891 handler: handle_aligned,
892 },
893 );
894 }
895}
896
897fn register_cd(map: &mut HashMap<&'static str, EnvSpec>) {
900 fn handle_cd(
901 ctx: &mut EnvContext,
902 _args: Vec<ParseNode>,
903 _opt_args: Vec<Option<ParseNode>>,
904 ) -> ParseResult<ParseNode> {
905 let mut raw: Vec<Token> = Vec::new();
907 loop {
908 let tok = ctx.parser.gullet.future().clone();
909 if tok.text == "\\end" || tok.text == "EOF" {
910 break;
911 }
912 ctx.parser.gullet.pop_token();
913 raw.push(tok);
914 }
915
916 let rows = cd_split_rows(raw);
918
919 let mut body: Vec<Vec<ParseNode>> = Vec::new();
920 let mut row_gaps: Vec<Option<Measurement>> = Vec::new();
921 let mut hlines_before_row: Vec<Vec<bool>> = Vec::new();
922 hlines_before_row.push(vec![]);
923
924 for row_toks in rows {
925 if row_toks.iter().all(|t| t.text == " ") {
927 continue;
928 }
929 let cells = cd_parse_row(ctx.parser, row_toks)?;
930 if !cells.is_empty() {
931 body.push(cells);
932 row_gaps.push(None);
933 hlines_before_row.push(vec![]);
934 }
935 }
936
937 if body.is_empty() {
938 body.push(vec![]);
939 hlines_before_row.push(vec![]);
940 }
941
942 Ok(ParseNode::Array {
943 mode: ctx.mode,
944 body,
945 row_gaps,
946 hlines_before_row,
947 cols: None,
948 col_separation_type: Some("CD".to_string()),
949 hskip_before_and_after: Some(false),
950 add_jot: None,
951 arraystretch: 1.0,
952 tags: None,
953 leqno: None,
954 is_cd: Some(true),
955 loc: None,
956 })
957 }
958
959 map.insert(
960 "CD",
961 EnvSpec {
962 num_args: 0,
963 num_optional_args: 0,
964 handler: handle_cd,
965 },
966 );
967}
968
969fn cd_split_rows(tokens: Vec<Token>) -> Vec<Vec<Token>> {
971 let mut rows: Vec<Vec<Token>> = Vec::new();
972 let mut current: Vec<Token> = Vec::new();
973 for tok in tokens {
974 if tok.text == "\\\\" || tok.text == "\\cr" {
975 rows.push(current);
976 current = Vec::new();
977 } else {
978 current.push(tok);
979 }
980 }
981 if !current.is_empty() {
982 rows.push(current);
983 }
984 rows
985}
986
987fn cd_collect_until(tokens: &[Token], start: usize, delimiter: &str) -> (Vec<Token>, usize) {
991 let mut result = Vec::new();
992 let mut i = start;
993 while i < tokens.len() {
994 if tokens[i].text == delimiter {
995 i += 1; break;
997 }
998 result.push(tokens[i].clone());
999 i += 1;
1000 }
1001 (result, i - start)
1002}
1003
1004fn cd_collect_until_at(tokens: &[Token], start: usize) -> (Vec<Token>, usize) {
1006 let mut result = Vec::new();
1007 let mut i = start;
1008 while i < tokens.len() && tokens[i].text != "@" {
1009 result.push(tokens[i].clone());
1010 i += 1;
1011 }
1012 (result, i - start)
1013}
1014
1015fn cd_parse_tokens(parser: &mut Parser, tokens: Vec<Token>) -> ParseResult<ParseNode> {
1018 let has_content = tokens.iter().any(|t| t.text != " ");
1020 if !has_content {
1021 return Ok(ParseNode::OrdGroup {
1022 mode: parser.mode,
1023 body: vec![],
1024 semisimple: None,
1025 loc: None,
1026 });
1027 }
1028 let mut rev = tokens;
1030 rev.reverse();
1031 let body = parser.subparse(rev)?;
1032 Ok(ParseNode::OrdGroup {
1033 mode: parser.mode,
1034 body,
1035 semisimple: None,
1036 loc: None,
1037 })
1038}
1039
1040fn cd_parse_row(parser: &mut Parser, row_tokens: Vec<Token>) -> ParseResult<Vec<ParseNode>> {
1043 let toks = &row_tokens;
1044 let n = toks.len();
1045 let mut cells: Vec<ParseNode> = Vec::new();
1046 let mut i = 0usize;
1047
1048 while i < n {
1049 while i < n && toks[i].text == " " {
1051 i += 1;
1052 }
1053 if i >= n {
1054 break;
1055 }
1056
1057 if toks[i].text == "@" {
1058 i += 1; if i >= n {
1060 return Err(ParseError::msg("Unexpected end of CD row after @"));
1061 }
1062 let dir = toks[i].text.clone();
1063 i += 1; let mode = parser.mode;
1066 let arrow = match dir.as_str() {
1067 ">" | "<" => {
1068 let (above_toks, c1) = cd_collect_until(toks, i, &dir);
1069 i += c1;
1070 let (below_toks, c2) = cd_collect_until(toks, i, &dir);
1071 i += c2;
1072 let label_above = cd_parse_tokens(parser, above_toks)?;
1073 let label_below = cd_parse_tokens(parser, below_toks)?;
1074 ParseNode::CdArrow {
1075 mode,
1076 direction: if dir == ">" { "right" } else { "left" }.to_string(),
1077 label_above: Some(Box::new(label_above)),
1078 label_below: Some(Box::new(label_below)),
1079 loc: None,
1080 }
1081 }
1082 "V" | "A" => {
1083 let (left_toks, c1) = cd_collect_until(toks, i, &dir);
1084 i += c1;
1085 let (right_toks, c2) = cd_collect_until(toks, i, &dir);
1086 i += c2;
1087 let label_above = cd_parse_tokens(parser, left_toks)?;
1088 let label_below = cd_parse_tokens(parser, right_toks)?;
1089 ParseNode::CdArrow {
1090 mode,
1091 direction: if dir == "V" { "down" } else { "up" }.to_string(),
1092 label_above: Some(Box::new(label_above)),
1093 label_below: Some(Box::new(label_below)),
1094 loc: None,
1095 }
1096 }
1097 "=" => ParseNode::CdArrow {
1098 mode,
1099 direction: "horiz_eq".to_string(),
1100 label_above: None,
1101 label_below: None,
1102 loc: None,
1103 },
1104 "|" => ParseNode::CdArrow {
1105 mode,
1106 direction: "vert_eq".to_string(),
1107 label_above: None,
1108 label_below: None,
1109 loc: None,
1110 },
1111 "." => ParseNode::CdArrow {
1112 mode,
1113 direction: "none".to_string(),
1114 label_above: None,
1115 label_below: None,
1116 loc: None,
1117 },
1118 _ => return Err(ParseError::msg(format!("Unknown CD directive: @{}", dir))),
1119 };
1120 cells.push(arrow);
1121 } else {
1122 let (obj_toks, consumed) = cd_collect_until_at(toks, i);
1124 i += consumed;
1125 let obj = cd_parse_tokens(parser, obj_toks)?;
1126 cells.push(obj);
1127 }
1128 }
1129
1130 Ok(cd_structure_row(cells, parser.mode))
1132}
1133
1134fn cd_structure_row(cells: Vec<ParseNode>, mode: Mode) -> Vec<ParseNode> {
1140 let is_arrow_row = cells.iter().all(|c| match c {
1142 ParseNode::CdArrow { .. } => true,
1143 ParseNode::OrdGroup { body, .. } => body.is_empty(),
1144 _ => false,
1145 }) && cells.iter().any(|c| matches!(c, ParseNode::CdArrow { .. }));
1146
1147 if is_arrow_row {
1148 let arrows: Vec<ParseNode> = cells
1149 .into_iter()
1150 .filter(|c| matches!(c, ParseNode::CdArrow { .. }))
1151 .collect();
1152
1153 if arrows.is_empty() {
1154 return vec![];
1155 }
1156
1157 let empty = || ParseNode::OrdGroup {
1158 mode,
1159 body: vec![],
1160 semisimple: None,
1161 loc: None,
1162 };
1163
1164 let mut result = Vec::with_capacity(arrows.len() * 2 - 1);
1165 for (idx, arrow) in arrows.into_iter().enumerate() {
1166 if idx > 0 {
1167 result.push(empty());
1168 }
1169 result.push(arrow);
1170 }
1171 result
1172 } else {
1173 cells
1175 }
1176}
1177
1178fn register_subarray(map: &mut HashMap<&'static str, EnvSpec>) {
1181 fn handle_subarray(
1182 ctx: &mut EnvContext,
1183 args: Vec<ParseNode>,
1184 _opt_args: Vec<Option<ParseNode>>,
1185 ) -> ParseResult<ParseNode> {
1186 let colalign = match &args[0] {
1187 ParseNode::OrdGroup { body, .. } => body.clone(),
1188 other if other.is_symbol_node() => vec![other.clone()],
1189 _ => return Err(ParseError::msg("Invalid column alignment for subarray")),
1190 };
1191
1192 let mut cols = Vec::new();
1193 for nde in &colalign {
1194 let ca = nde
1195 .symbol_text()
1196 .ok_or_else(|| ParseError::msg("Expected column alignment character"))?;
1197 match ca {
1198 "l" | "c" => cols.push(AlignSpec {
1199 align_type: AlignType::Align,
1200 align: Some(ca.to_string()),
1201 pregap: None,
1202 postgap: None,
1203 }),
1204 _ => {
1205 return Err(ParseError::msg(format!(
1206 "Unknown column alignment: {}",
1207 ca
1208 )))
1209 }
1210 }
1211 }
1212
1213 if cols.len() > 1 {
1214 return Err(ParseError::msg("{subarray} can contain only one column"));
1215 }
1216
1217 let config = ArrayConfig {
1218 cols: Some(cols),
1219 hskip_before_and_after: Some(false),
1220 arraystretch: Some(0.5),
1221 ..Default::default()
1222 };
1223
1224 let res = parse_array(ctx.parser, config, Some(StyleStr::Script))?;
1225
1226 if let ParseNode::Array { ref body, .. } = res {
1227 if !body.is_empty() && body[0].len() > 1 {
1228 return Err(ParseError::msg("{subarray} can contain only one column"));
1229 }
1230 }
1231
1232 Ok(res)
1233 }
1234
1235 map.insert(
1236 "subarray",
1237 EnvSpec {
1238 num_args: 1,
1239 num_optional_args: 0,
1240 handler: handle_subarray,
1241 },
1242 );
1243}
1244
1245fn register_prooftree(map: &mut HashMap<&'static str, EnvSpec>) {
1256 fn handle_prooftree(
1257 ctx: &mut EnvContext,
1258 _args: Vec<ParseNode>,
1259 _opt_args: Vec<Option<ParseNode>>,
1260 ) -> ParseResult<ParseNode> {
1261 parse_prooftree(ctx.parser)
1262 }
1263
1264 map.insert(
1265 "prooftree",
1266 EnvSpec {
1267 num_args: 0,
1268 num_optional_args: 0,
1269 handler: handle_prooftree,
1270 },
1271 );
1272}
1273
/// Number of premises consumed by a bussproofs inference command, or `None`
/// if `name` is not one of the supported inference commands.
fn proof_command_arity(name: &str) -> Option<usize> {
    const INFERENCES: [(usize, &[&str]); 5] = [
        (1, &["\\UnaryInfC", "\\UnaryInf", "\\UIC"]),
        (2, &["\\BinaryInfC", "\\BinaryInf", "\\BIC"]),
        (3, &["\\TrinaryInfC", "\\TrinaryInf", "\\TIC"]),
        (4, &["\\QuaternaryInfC", "\\QuaternaryInf"]),
        (5, &["\\QuinaryInfC", "\\QuinaryInf"]),
    ];
    INFERENCES
        .iter()
        .find(|(_, names)| names.iter().any(|&candidate| candidate == name))
        .map(|&(arity, _)| arity)
}
1284
1285fn parse_prooftree_arg(parser: &mut Parser, command: &str) -> ParseResult<Vec<ParseNode>> {
1286 let arg = parser.parse_argument_group(false, None)?.ok_or_else(|| {
1287 ParseError::msg(format!("Expected argument for {}", command))
1288 })?;
1289 Ok(ParseNode::ord_argument(arg))
1290}
1291
1292fn parse_prooftree(parser: &mut Parser) -> ParseResult<ParseNode> {
1293 let mut stack: Vec<ProofBranch> = Vec::new();
1294 let mut left_label: Option<Vec<ParseNode>> = None;
1295 let mut right_label: Option<Vec<ParseNode>> = None;
1296 let mut next_line_style = ProofLineStyle::Solid;
1297 let mut default_line_style = ProofLineStyle::Solid;
1298 let mut next_root_at_top = false;
1299 let mut default_root_at_top = false;
1300
1301 loop {
1302 parser.consume_spaces()?;
1303 let token = parser.fetch()?;
1304 let command = token.text.clone();
1305
1306 if command == "\\end" {
1307 break;
1308 }
1309 parser.consume();
1310
1311 match command.as_str() {
1312 "\\AxiomC" | "\\Axiom" | "\\AXC" => {
1313 let conclusion = parse_prooftree_arg(parser, &command)?;
1314 stack.push(ProofBranch {
1315 conclusion,
1316 premises: Vec::new(),
1317 left_label: None,
1318 right_label: None,
1319 line_style: ProofLineStyle::None,
1320 root_at_top: false,
1321 });
1322 }
1323 "\\LeftLabel" | "\\LL" => {
1324 left_label = Some(parse_prooftree_arg(parser, &command)?);
1325 }
1326 "\\RightLabel" | "\\RL" => {
1327 right_label = Some(parse_prooftree_arg(parser, &command)?);
1328 }
1329 "\\singleLine" | "\\solidLine" => {
1330 next_line_style = ProofLineStyle::Solid;
1331 }
1332 "\\dashedLine" => {
1333 next_line_style = ProofLineStyle::Dashed;
1334 }
1335 "\\noLine" => {
1336 next_line_style = ProofLineStyle::None;
1337 }
1338 "\\alwaysSingleLine" | "\\alwaysSolidLine" => {
1339 default_line_style = ProofLineStyle::Solid;
1340 next_line_style = ProofLineStyle::Solid;
1341 }
1342 "\\alwaysDashedLine" => {
1343 default_line_style = ProofLineStyle::Dashed;
1344 next_line_style = ProofLineStyle::Dashed;
1345 }
1346 "\\alwaysNoLine" => {
1347 default_line_style = ProofLineStyle::None;
1348 next_line_style = ProofLineStyle::None;
1349 }
1350 "\\rootAtTop" => {
1351 next_root_at_top = true;
1352 }
1353 "\\rootAtBottom" => {
1354 next_root_at_top = false;
1355 }
1356 "\\alwaysRootAtTop" => {
1357 default_root_at_top = true;
1358 next_root_at_top = true;
1359 }
1360 "\\alwaysRootAtBottom" => {
1361 default_root_at_top = false;
1362 next_root_at_top = false;
1363 }
1364 name if proof_command_arity(name).is_some() => {
1365 let arity = proof_command_arity(name).unwrap();
1366 if stack.len() < arity {
1367 return Err(ParseError::msg(format!(
1368 "{} needs {} premise(s), but only {} available",
1369 name,
1370 arity,
1371 stack.len()
1372 )));
1373 }
1374 let conclusion = parse_prooftree_arg(parser, name)?;
1375 let start = stack.len() - arity;
1376 let premises = stack.split_off(start);
1377 stack.push(ProofBranch {
1378 conclusion,
1379 premises,
1380 left_label: left_label.take(),
1381 right_label: right_label.take(),
1382 line_style: next_line_style.clone(),
1383 root_at_top: next_root_at_top,
1384 });
1385 next_line_style = default_line_style.clone();
1386 next_root_at_top = default_root_at_top;
1387 }
1388 _ => {
1389 return Err(ParseError::msg(format!(
1390 "{} valid only as a supported bussproofs command within prooftree",
1391 command
1392 )));
1393 }
1394 }
1395 }
1396
1397 if stack.len() != 1 {
1398 return Err(ParseError::msg(format!(
1399 "prooftree ended with {} proof stack item(s), expected 1",
1400 stack.len()
1401 )));
1402 }
1403
1404 if left_label.is_some() || right_label.is_some() {
1405 return Err(ParseError::msg(
1406 "prooftree has a \\LeftLabel or \\RightLabel without a following inference command",
1407 ));
1408 }
1409
1410 Ok(ParseNode::ProofTree {
1411 mode: parser.mode,
1412 tree: stack.pop().unwrap(),
1413 loc: None,
1414 })
1415}