Skip to main content

term_maths/
latex_renderer.rs

1//! LaTeX renderer — serialises an `EqNode` AST back to a LaTeX string.
2
3use rust_latex_parser::{AccentKind, EqNode, MathFontKind, MatrixKind};
4
5use crate::renderer::MathRenderer;
6
/// Serialises an `EqNode` back to a LaTeX math string.
///
/// A stateless unit struct: all of the work happens in its [`MathRenderer`]
/// implementation. It is therefore cheap to copy and has a trivial default.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct LatexRenderer;
9
10impl MathRenderer for LatexRenderer {
11    type Output = String;
12
13    fn render(&self, node: &EqNode) -> String {
14        node_to_latex(node)
15    }
16}
17
18fn node_to_latex(node: &EqNode) -> String {
19    match node {
20        EqNode::Text(s) => latex_escape_text(s),
21        EqNode::Space(pts) => space_to_latex(*pts),
22        EqNode::Seq(children) => children.iter().map(node_to_latex).collect(),
23        EqNode::Frac(num, den) => {
24            format!(r"\frac{{{}}}{{{}}}", node_to_latex(num), node_to_latex(den))
25        }
26        EqNode::Sup(base, sup) => {
27            format!("{}^{{{}}}", node_to_latex(base), node_to_latex(sup))
28        }
29        EqNode::Sub(base, sub) => {
30            format!("{}_{{{}}} ", node_to_latex(base), node_to_latex(sub))
31        }
32        EqNode::SupSub(base, sup, sub) => {
33            format!(
34                "{}^{{{}}}_{{{}}}",
35                node_to_latex(base),
36                node_to_latex(sup),
37                node_to_latex(sub)
38            )
39        }
40        EqNode::Sqrt(body) => format!(r"\sqrt{{{}}}", node_to_latex(body)),
41        EqNode::BigOp {
42            symbol,
43            lower,
44            upper,
45        } => {
46            let sym = unicode_to_latex_op(symbol);
47            let mut s = sym;
48            if let Some(lo) = lower {
49                s.push_str(&format!("_{{{}}}", node_to_latex(lo)));
50            }
51            if let Some(up) = upper {
52                s.push_str(&format!("^{{{}}}", node_to_latex(up)));
53            }
54            s
55        }
56        EqNode::Accent(body, kind) => {
57            let cmd = match kind {
58                AccentKind::Hat => r"\hat",
59                AccentKind::Bar => r"\overline",
60                AccentKind::Dot => r"\dot",
61                AccentKind::DoubleDot => r"\ddot",
62                AccentKind::Tilde => r"\tilde",
63                AccentKind::Vec => r"\vec",
64            };
65            format!("{}{{{}}}", cmd, node_to_latex(body))
66        }
67        EqNode::Limit { name, lower } => {
68            let latex_name = format!(r"\{}", name);
69            if let Some(lo) = lower {
70                format!("{}_{{{}}}", latex_name, node_to_latex(lo))
71            } else {
72                latex_name
73            }
74        }
75        EqNode::TextBlock(s) => format!(r"\text{{{}}}", s),
76        EqNode::MathFont { kind, content } => {
77            let cmd = match kind {
78                MathFontKind::Bold => r"\mathbf",
79                MathFontKind::Blackboard => r"\mathbb",
80                MathFontKind::Calligraphic => r"\mathcal",
81                MathFontKind::Roman => r"\mathrm",
82                MathFontKind::Fraktur => r"\mathfrak",
83                MathFontKind::SansSerif => r"\mathsf",
84                MathFontKind::Monospace => r"\mathtt",
85            };
86            format!("{}{{{}}}", cmd, node_to_latex(content))
87        }
88        EqNode::Delimited {
89            left,
90            right,
91            content,
92        } => {
93            format!(
94                r"\left{} {} \right{}",
95                latex_delim(left),
96                node_to_latex(content),
97                latex_delim(right)
98            )
99        }
100        EqNode::Matrix { kind, rows } => {
101            let env = match kind {
102                MatrixKind::Plain => "matrix",
103                MatrixKind::Paren => "pmatrix",
104                MatrixKind::Bracket => "bmatrix",
105                MatrixKind::Brace => "Bmatrix",
106                MatrixKind::VBar => "vmatrix",
107                MatrixKind::DoubleVBar => "Vmatrix",
108            };
109            let rows_str: Vec<String> = rows
110                .iter()
111                .map(|row| {
112                    row.iter()
113                        .map(node_to_latex)
114                        .collect::<Vec<_>>()
115                        .join(" & ")
116                })
117                .collect();
118            format!(
119                r"\begin{{{}}} {} \end{{{}}}",
120                env,
121                rows_str.join(r" \\ "),
122                env
123            )
124        }
125        EqNode::Cases { rows } => {
126            let rows_str: Vec<String> = rows
127                .iter()
128                .map(|(val, cond)| {
129                    if let Some(c) = cond {
130                        format!("{} & {}", node_to_latex(val), node_to_latex(c))
131                    } else {
132                        node_to_latex(val)
133                    }
134                })
135                .collect();
136            format!(r"\begin{{cases}} {} \end{{cases}}", rows_str.join(r" \\ "))
137        }
138        EqNode::Binom(top, bottom) => {
139            format!(
140                r"\binom{{{}}}{{{}}}",
141                node_to_latex(top),
142                node_to_latex(bottom)
143            )
144        }
145        EqNode::Brace {
146            content,
147            label,
148            over,
149        } => {
150            let cmd = if *over { r"\overbrace" } else { r"\underbrace" };
151            let mut s = format!("{}{{{}}}", cmd, node_to_latex(content));
152            if let Some(lbl) = label {
153                if *over {
154                    s.push_str(&format!("^{{{}}}", node_to_latex(lbl)));
155                } else {
156                    s.push_str(&format!("_{{{}}}", node_to_latex(lbl)));
157                }
158            }
159            s
160        }
161        EqNode::StackRel {
162            base,
163            annotation,
164            over,
165        } => {
166            let cmd = if *over { r"\overset" } else { r"\underset" };
167            format!(
168                "{}{{{}}}{{{}}}",
169                cmd,
170                node_to_latex(annotation),
171                node_to_latex(base)
172            )
173        }
174    }
175}
176
/// Converts a run of math-mode text back into LaTeX source.
///
/// - Unicode symbols with a standard LaTeX command (Greek letters, common
///   operators, relations and arrows) are replaced by that command plus a
///   trailing space. The space stops a following letter from extending the
///   command name (`\alpha x`, not `\alphax`).
/// - `%`, `#` and `$` are backslash-escaped unless the input already escapes
///   them. Raw, they would start a comment, a macro parameter, or end math mode.
/// - Every other character is copied through unchanged.
fn latex_escape_text(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut prev: Option<char> = None;
    for ch in s.chars() {
        if let Some(cmd) = unicode_symbol_command(ch) {
            result.push_str(cmd);
            result.push(' ');
        } else if matches!(ch, '%' | '#' | '$') && prev != Some('\\') {
            result.push('\\');
            result.push(ch);
        } else {
            result.push(ch);
        }
        prev = Some(ch);
    }
    result
}

/// Returns the LaTeX command (without a trailing space) for a Unicode math
/// symbol, or `None` if the character has no mapping.
fn unicode_symbol_command(ch: char) -> Option<&'static str> {
    let cmd = match ch {
        // Lowercase Greek
        'α' => r"\alpha",
        'β' => r"\beta",
        'γ' => r"\gamma",
        'δ' => r"\delta",
        'ε' => r"\epsilon",
        'ζ' => r"\zeta",
        'η' => r"\eta",
        'θ' => r"\theta",
        'ι' => r"\iota",
        'κ' => r"\kappa",
        'λ' => r"\lambda",
        'μ' => r"\mu",
        'ν' => r"\nu",
        'ξ' => r"\xi",
        'π' => r"\pi",
        'ρ' => r"\rho",
        'σ' => r"\sigma",
        'τ' => r"\tau",
        'υ' => r"\upsilon",
        'φ' => r"\phi",
        'χ' => r"\chi",
        'ψ' => r"\psi",
        'ω' => r"\omega",
        // Uppercase Greek (only those that differ from Latin capitals)
        'Γ' => r"\Gamma",
        'Δ' => r"\Delta",
        'Θ' => r"\Theta",
        'Λ' => r"\Lambda",
        'Ξ' => r"\Xi",
        'Π' => r"\Pi",
        'Σ' => r"\Sigma",
        'Υ' => r"\Upsilon",
        'Φ' => r"\Phi",
        'Ψ' => r"\Psi",
        'Ω' => r"\Omega",
        // Large operators and miscellaneous symbols
        '∞' => r"\infty",
        '∑' => r"\sum",
        '∏' => r"\prod",
        '∫' => r"\int",
        '∬' => r"\iint",
        '∮' => r"\oint",
        '∂' => r"\partial",
        '∇' => r"\nabla",
        '∅' => r"\emptyset",
        'ℏ' => r"\hbar",
        'ℓ' => r"\ell",
        '…' => r"\ldots",
        '⋯' => r"\cdots",
        // Binary operators
        '±' => r"\pm",
        '∓' => r"\mp",
        '·' | '⋅' => r"\cdot",
        '×' => r"\times",
        '÷' => r"\div",
        '∘' => r"\circ",
        '∪' => r"\cup",
        '∩' => r"\cap",
        '∖' => r"\setminus",
        '∧' => r"\wedge",
        '∨' => r"\vee",
        '¬' => r"\neg",
        // Relations
        '≤' => r"\leq",
        '≥' => r"\geq",
        '≠' => r"\neq",
        '≈' => r"\approx",
        '≡' => r"\equiv",
        '∼' => r"\sim",
        '∝' => r"\propto",
        '∈' => r"\in",
        '∉' => r"\notin",
        '⊂' => r"\subset",
        '⊃' => r"\supset",
        '⊆' => r"\subseteq",
        '⊇' => r"\supseteq",
        // Quantifiers
        '∀' => r"\forall",
        '∃' => r"\exists",
        // Arrows
        '→' => r"\rightarrow",
        '←' => r"\leftarrow",
        '↔' => r"\leftrightarrow",
        '⇒' => r"\Rightarrow",
        '⇐' => r"\Leftarrow",
        '⇔' => r"\Leftrightarrow",
        '↦' => r"\mapsto",
        _ => return None,
    };
    Some(cmd)
}
227
/// Maps a horizontal space of width `pts` to a LaTeX spacing command.
///
/// NOTE(review): the thresholds assume `pts` is measured in math units
/// (1 mu = 1/18 em), where `\!` = -3, `\,` = 3, `\:` = 4, `\;` = 5,
/// `\quad` = 18 and `\qquad` = 36. This is consistent with the `>= 18` test
/// for `\quad`; confirm against the values the parser stores in `Space`.
/// Bounds are inclusive so that each standard width maps back to its own
/// command. Widths between `\;` and `\quad` fall back to a plain space.
fn space_to_latex(pts: f32) -> String {
    let cmd = if pts < 0.0 {
        r"\!"
    } else if pts <= 3.0 {
        r"\,"
    } else if pts <= 4.0 {
        r"\:"
    } else if pts <= 5.0 {
        r"\;"
    } else if pts >= 36.0 {
        // Trailing space keeps `\qquad` from absorbing following letters.
        r"\qquad "
    } else if pts >= 18.0 {
        r"\quad "
    } else {
        " "
    };
    cmd.to_string()
}
241
/// Maps a Unicode large-operator symbol to its LaTeX command.
///
/// Both the binary (⊕, ⊗) and the n-ary (⨁, ⨂) code points map to the
/// big-operator commands. Symbols without a mapping (including ones that are
/// already LaTeX commands) are returned unchanged.
fn unicode_to_latex_op(symbol: &str) -> String {
    let cmd = match symbol {
        "∑" => r"\sum",
        "∏" => r"\prod",
        "∐" => r"\coprod",
        "∫" => r"\int",
        "∬" => r"\iint",
        "∭" => r"\iiint",
        "∮" => r"\oint",
        "⋃" => r"\bigcup",
        "⋂" => r"\bigcap",
        "⨆" => r"\bigsqcup",
        "⋁" => r"\bigvee",
        "⋀" => r"\bigwedge",
        "⊕" | "⨁" => r"\bigoplus",
        "⊗" | "⨂" => r"\bigotimes",
        "⨀" => r"\bigodot",
        "⨄" => r"\biguplus",
        _ => symbol,
    };
    cmd.to_string()
}
256
/// Converts a delimiter as stored in the AST into the LaTeX token that
/// follows `\left` / `\right`.
///
/// - Braces become the control symbols `\{` / `\}`.
/// - Double bars become `\|`.
/// - Unicode angle, floor and ceil brackets become their commands. A
///   trailing space is added so that `\right\rangle` cannot absorb letters
///   that follow it.
/// - Everything else, including `.`, `(`, `)`, `[`, `]` and `|`, is returned
///   unchanged.
fn latex_delim(d: &str) -> String {
    let token = match d {
        "{" => r"\{",
        "}" => r"\}",
        "‖" | "∥" => r"\|",
        "⟨" | "〈" => r"\langle ",
        "⟩" | "〉" => r"\rangle ",
        "⌊" => r"\lfloor ",
        "⌋" => r"\rfloor ",
        "⌈" => r"\lceil ",
        "⌉" => r"\rceil ",
        _ => d,
    };
    token.to_string()
}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::renderer::MathRenderer;
    use rust_latex_parser::parse_equation;

    /// Parses `src` and serialises the resulting AST straight back to LaTeX.
    fn roundtrip(src: &str) -> String {
        let ast = parse_equation(src);
        LatexRenderer.render(&ast)
    }

    #[test]
    fn test_simple_fraction_roundtrip() {
        let out = roundtrip(r"\frac{a}{b}");
        assert!(out.contains(r"\frac"));
        assert!(out.contains('a') && out.contains('b'));
    }

    #[test]
    fn test_superscript_roundtrip() {
        let out = roundtrip(r"x^2");
        assert!(out.contains("x^"));
        assert!(out.contains('2'));
    }

    #[test]
    fn test_matrix_roundtrip() {
        let out = roundtrip(r"\begin{pmatrix} a & b \\ c & d \end{pmatrix}");
        assert!(out.contains("pmatrix"));
        assert!(out.contains('&'));
    }
}