// cas_parser/parser/fmt.rs

use std::fmt::{Display, Formatter, Result};
use super::{ast::expr::Expr, token::op::BinOpKind};

/// A trait for types that can be rendered as LaTeX source.
///
/// This mirrors [`Display`]: implementors write their LaTeX representation into a
/// [`Formatter`]. Call [`Latex::as_display`] to get a value that works with `format!`,
/// `write!`, `to_string`, and anything else that accepts a [`Display`] type.
pub trait Latex {
    /// Writes the LaTeX representation of `self` into `f`.
    ///
    /// # Errors
    ///
    /// Returns an error if writing to the formatter fails.
    fn fmt_latex(&self, f: &mut Formatter) -> Result;

    /// Borrows `self` as a [`LatexFormatter`], whose [`Display`] implementation writes
    /// the output of [`Latex::fmt_latex`].
    ///
    /// Nothing is formatted here; the work happens only when the wrapper is displayed.
    fn as_display(&self) -> LatexFormatter<'_, Self> {
        LatexFormatter(self)
    }
}

/// A borrowed wrapper that implements [`Display`] for any type implementing [`Latex`].
///
/// Created by [`Latex::as_display`]. Displaying it forwards the formatter straight to
/// [`Latex::fmt_latex`], so `format!("{}", x.as_display())` yields the LaTeX form of `x`.
/// The wrapper itself applies no formatting flags (width, fill, precision, ...); the
/// output is exactly what `fmt_latex` writes.
#[must_use = "a `LatexFormatter` does nothing unless it is displayed"]
pub struct LatexFormatter<'a, T: ?Sized>(&'a T);

impl<T: ?Sized> Display for LatexFormatter<'_, T>
where
    T: Latex,
{
    fn fmt(&self, f: &mut Formatter) -> Result {
        // Forward the caller's formatter unchanged; no intermediate String is built.
        self.0.fmt_latex(f)
    }
}

27/// Helper to format powers.
28pub fn fmt_pow(f: &mut Formatter, left: Option<&Expr>, right: Option<&Expr>) -> Result {
29    if let Some(left) = left {
30        let left = left.innermost();
31        let mut insert_with_paren = || {
32            write!(f, "\\left(")?;
33            left.fmt_latex(f)?;
34            write!(f, "\\right)")
35        };
36
37        // all of these are separate match arms instead of a single match arm with multiple
38        // patterns, because apparently that can't be parsed correctly
39        match left {
40            Expr::Unary(unary)
41                if unary.op.precedence() <= BinOpKind::Exp.precedence() => insert_with_paren(),
42            // NOTE: exp is the highest precedence binary operator, so this check is not necessary,
43            // but is just here for completeness
44            Expr::Binary(binary)
45                if binary.op.precedence() <= BinOpKind::Exp.precedence() => insert_with_paren(),
46            Expr::Call(call) if call.name.name == "pow" => insert_with_paren(),
47            _ => left.fmt_latex(f),
48        }?
49    }
50    write!(f, "^{{")?;
51    if let Some(right) = right {
52        right.innermost().fmt_latex(f)?;
53    }
54    write!(f, "}}")
55}
56
#[cfg(test)]
mod tests {
    use super::*;

    use crate::parser::Parser;

    /// Parses `source` as a complete [`Expr`], panicking if parsing fails.
    fn parse(source: &str) -> Expr {
        let mut parser = Parser::new(source);
        parser
            .try_parse_full::<Expr>()
            .expect("test input should parse as a full expression")
    }

    #[test]
    fn fmt_display() {
        assert_eq!(parse("3x + 6").to_string(), "3x+6");
    }

    #[test]
    fn fmt_display_2() {
        assert_eq!(parse("f(x) = x^2 + 5x + 6").to_string(), "f(x) = x^2+5x+6");
    }

    #[test]
    fn fmt_display_3() {
        assert_eq!(parse("x^(3(x + 6))^9").to_string(), "x^(3(x+6))^9");
    }

    #[test]
    fn fmt_latex() {
        assert_eq!(parse("sqrt(3x)^2").as_display().to_string(), "\\sqrt{3x}^{2}");
    }

    #[test]
    fn fmt_latex_2() {
        assert_eq!(
            parse("f(x) = 1/x + 5/x^2 + 6/x^3").as_display().to_string(),
            "\\mathrm{ f } \\left(x\\right) = \\frac{1}{x}+\\frac{5}{x^{2}}+\\frac{6}{x^{3}}"
        );
    }
}