1use std::fmt::{Display, Formatter, Result};
2use super::{ast::expr::Expr, token::op::BinOpKind};
3
/// Types that know how to render themselves as LaTeX source.
///
/// This is deliberately separate from [`Display`], which this module uses for
/// the plain-text form of a value. Call [`Latex::as_display`] to get a value
/// that formats the LaTeX form through the normal `{}` machinery.
pub trait Latex {
    /// Writes the LaTeX representation of `self` into `f`.
    fn fmt_latex(&self, f: &mut Formatter) -> Result;

    /// Borrows `self` as a [`LatexFormatter`], whose [`Display`] impl forwards
    /// to [`Latex::fmt_latex`]. This makes the LaTeX form usable with
    /// `format!`, `write!` and `to_string()`.
    fn as_display(&self) -> LatexFormatter<'_, Self> {
        LatexFormatter(self)
    }
}

/// Borrowing adapter that displays a [`Latex`] value in its LaTeX form.
///
/// Obtained through [`Latex::as_display`]. `T` may be unsized (e.g. `str` or a
/// trait object), because only a reference to it is stored.
pub struct LatexFormatter<'a, T: ?Sized>(&'a T);

impl<'a, T> Display for LatexFormatter<'a, T>
where
    T: Latex + ?Sized,
{
    fn fmt(&self, f: &mut Formatter) -> Result {
        // Copy the wrapped reference out, then let the value render itself.
        let &Self(wrapped) = self;
        wrapped.fmt_latex(f)
    }
}
26
27pub fn fmt_pow(f: &mut Formatter, left: Option<&Expr>, right: Option<&Expr>) -> Result {
29 if let Some(left) = left {
30 let left = left.innermost();
31 let mut insert_with_paren = || {
32 write!(f, "\\left(")?;
33 left.fmt_latex(f)?;
34 write!(f, "\\right)")
35 };
36
37 match left {
40 Expr::Unary(unary)
41 if unary.op.precedence() <= BinOpKind::Exp.precedence() => insert_with_paren(),
42 Expr::Binary(binary)
45 if binary.op.precedence() <= BinOpKind::Exp.precedence() => insert_with_paren(),
46 Expr::Call(call) if call.name.name == "pow" => insert_with_paren(),
47 _ => left.fmt_latex(f),
48 }?
49 }
50 write!(f, "^{{")?;
51 if let Some(right) = right {
52 right.innermost().fmt_latex(f)?;
53 }
54 write!(f, "}}")
55}
56
#[cfg(test)]
mod tests {
    use super::*;

    use crate::parser::Parser;

    /// Parses `src` as one complete expression and renders it with the
    /// plain-text `Display` impl. Panics if `src` does not parse.
    fn render_plain(src: &str) -> String {
        let mut parser = Parser::new(src);
        let expr = parser.try_parse_full::<Expr>().unwrap();
        expr.to_string()
    }

    /// Parses `src` as one complete expression and renders it as LaTeX via
    /// `Latex::as_display`. Panics if `src` does not parse.
    fn render_latex(src: &str) -> String {
        let mut parser = Parser::new(src);
        let expr = parser.try_parse_full::<Expr>().unwrap();
        expr.as_display().to_string()
    }

    #[test]
    fn fmt_display() {
        assert_eq!(render_plain("3x + 6"), "3x+6");
    }

    #[test]
    fn fmt_display_2() {
        assert_eq!(render_plain("f(x) = x^2 + 5x + 6"), "f(x) = x^2+5x+6");
    }

    #[test]
    fn fmt_display_3() {
        assert_eq!(render_plain("x^(3(x + 6))^9"), "x^(3(x+6))^9");
    }

    #[test]
    fn fmt_latex() {
        assert_eq!(render_latex("sqrt(3x)^2"), "\\sqrt{3x}^{2}");
    }

    #[test]
    fn fmt_latex_2() {
        assert_eq!(
            render_latex("f(x) = 1/x + 5/x^2 + 6/x^3"),
            "\\mathrm{ f } \\left(x\\right) = \\frac{1}{x}+\\frac{5}{x^{2}}+\\frac{6}{x^{3}}",
        );
    }
}