Skip to main content

provable_contracts/
latex.rs

//! LaTeX conversion utilities for contract math notation.
//!
//! Converts Unicode math notation found in YAML contracts into
//! LaTeX math mode suitable for rendering with `KaTeX` or full LaTeX.
5
/// Escape special LaTeX characters in plain text.
///
/// Each of `\ & % $ # _ { } ~ ^` is replaced by its text-mode escape:
/// `\` becomes `\textbackslash{}`, `~` becomes `\textasciitilde{}`,
/// `^` becomes `\textasciicircum{}`, and the remaining characters are
/// prefixed with a backslash. All other characters are copied unchanged.
///
/// The input is processed in a single pass so that the braces emitted as
/// part of an escape (for example the `{}` in `\textbackslash{}`) are never
/// escaped a second time.
pub fn latex_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => escaped.push_str("\\textbackslash{}"),
            '&' => escaped.push_str("\\&"),
            '%' => escaped.push_str("\\%"),
            '$' => escaped.push_str("\\$"),
            '#' => escaped.push_str("\\#"),
            '_' => escaped.push_str("\\_"),
            '{' => escaped.push_str("\\{"),
            '}' => escaped.push_str("\\}"),
            '~' => escaped.push_str("\\textasciitilde{}"),
            '^' => escaped.push_str("\\textasciicircum{}"),
            other => escaped.push(other),
        }
    }
    escaped
}
19
20/// Convert contract math notation to LaTeX math mode.
21///
22/// Handles common patterns found in our YAML contracts:
23/// - Greek letters (ε, σ, α, etc.)
24/// - Subscripts (`x_i`, `A_{ij}`)
25/// - Superscripts (x^T, ℝ^n)
26/// - Operators (Σ, ∈, ≈, ≤, ≥, ∀, →)
27/// - Special sets (ℝ, ℤ)
28/// - Functions (sqrt, exp, log, softmax, etc.)
29pub fn math_to_latex(s: &str) -> String {
30    let mut out = s.to_string();
31
32    // Unicode → LaTeX replacements. Each entry: (unicode, latex_cmd, is_command).
33    // When is_command is true, a trailing space is inserted before the next
34    // alphabetic character to prevent `\foralli` instead of `\forall i`.
35    let replacements: &[(&str, &str, bool)] = &[
36        // Greek letters
37        ("α", "\\alpha", true),
38        ("β", "\\beta", true),
39        ("γ", "\\gamma", true),
40        ("δ", "\\delta", true),
41        ("ε", "\\varepsilon", true),
42        ("θ", "\\theta", true),
43        ("λ", "\\lambda", true),
44        ("σ", "\\sigma", true),
45        ("τ", "\\tau", true),
46        ("Σ", "\\sum", true),
47        ("Φ", "\\Phi", true),
48        ("π", "\\pi", true),
49        // Operators
50        ("∈", "\\in", true),
51        ("∉", "\\notin", true),
52        ("≈", "\\approx", true),
53        ("≤", "\\leq", true),
54        ("≥", "\\geq", true),
55        ("≠", "\\neq", true),
56        ("∀", "\\forall", true),
57        ("∃", "\\exists", true),
58        ("→", "\\to", true),
59        ("←", "\\leftarrow", true),
60        ("⊗", "\\otimes", true),
61        ("⁺", "^{+}", false),
62        // Special sets
63        ("ℝ", "\\mathbb{R}", false),
64        ("ℤ", "\\mathbb{Z}", false),
65    ];
66    for &(uni, tex, is_cmd) in replacements {
67        if is_cmd {
68            out = replace_unicode_cmd(&out, uni, tex);
69        } else {
70            out = out.replace(uni, tex);
71        }
72    }
73
74    // sqrt(...) → \sqrt{...}
75    out = replace_func(&out, "sqrt", "\\sqrt");
76
77    // exp(...) → \exp(...)
78    out = out.replace("exp(", "\\exp(");
79
80    // log(...) → \log(...)
81    out = out.replace("log(", "\\log(");
82
83    // Escape % and # in formulas
84    out = out.replace('%', "\\%");
85    out = out.replace('#', "\\#");
86
87    // Fix double superscripts: ^{+}^X → ^{+X}
88    // This arises when ⁺ (→ ^{+}) is followed by ^N in the source.
89    while let Some(pos) = out.find("^{+}^") {
90        let after = &out[pos + 5..]; // after "^{+}^"
91        // Collect the next superscript content: either {braced} or a single char
92        if after.starts_with('{') {
93            if let Some(close) = after.find('}') {
94                let inner = &after[1..close];
95                let replacement = format!("^{{+{inner}}}");
96                out = format!("{}{}{}", &out[..pos], replacement, &after[close + 1..]);
97            }
98        } else if let Some(ch) = after.chars().next() {
99            let replacement = format!("^{{+{ch}}}");
100            out = format!("{}{}{}", &out[..pos], replacement, &after[ch.len_utf8()..]);
101        }
102    }
103
104    out
105}
106
/// Replace a Unicode symbol with a LaTeX command, inserting a trailing
/// space when the next character is alphabetic (prevents `\foralli`).
///
/// Every occurrence of `uni` in `s` is replaced by `tex`. If the character
/// immediately following an occurrence is an ASCII letter, a single space
/// is appended after `tex` so the letter is not parsed as part of the
/// command name.
///
/// An empty `uni` matches nothing: the input is returned unchanged. Without
/// this guard the search would match at offset 0 forever.
pub fn replace_unicode_cmd(s: &str, uni: &str, tex: &str) -> String {
    if uni.is_empty() {
        return s.to_string();
    }

    let mut result = String::with_capacity(s.len());
    let mut rest = s;
    while let Some((before, after)) = rest.split_once(uni) {
        result.push_str(before);
        result.push_str(tex);
        // Insert space before next alphabetic char to keep LaTeX happy
        if after.starts_with(|c: char| c.is_ascii_alphabetic()) {
            result.push(' ');
        }
        rest = after;
    }
    result.push_str(rest);
    result
}
125
/// Rewrite every call `func(...)` in `s` as `cmd{...}`.
///
/// The closing parenthesis is located by tracking nesting depth, so
/// arguments may themselves contain parenthesised groups. The argument text
/// is processed recursively, which turns `sqrt(a + sqrt(b))` into
/// `\sqrt{a + \sqrt{b}}`. When an opening `func(` has no matching `)`, it is
/// copied to the output verbatim and scanning continues after it.
pub fn replace_func(s: &str, func: &str, cmd: &str) -> String {
    let needle = format!("{func}(");
    let mut out = String::with_capacity(s.len());
    let mut remaining = s;

    while let Some((before, tail)) = remaining.split_once(needle.as_str()) {
        out.push_str(before);

        // Byte offset of the `)` that balances the `(` consumed by `needle`.
        let mut open = 1usize;
        let close_at = tail.char_indices().find_map(|(idx, c)| {
            match c {
                '(' => open += 1,
                ')' => {
                    open -= 1;
                    if open == 0 {
                        return Some(idx);
                    }
                }
                _ => {}
            }
            None
        });

        match close_at {
            Some(idx) => {
                // Nested calls inside the argument are converted first.
                let body = replace_func(&tail[..idx], func, cmd);
                out.push_str(cmd);
                out.push('{');
                out.push_str(&body);
                out.push('}');
                remaining = &tail[idx + 1..];
            }
            None => {
                // Unbalanced call: keep the text as written.
                out.push_str(&needle);
                remaining = tail;
            }
        }
    }

    out.push_str(remaining);
    out
}
170
#[cfg(test)]
mod tests {
    use super::{latex_escape, math_to_latex, replace_func};

    /// Check `math_to_latex` against a table of `(input, expected)` pairs.
    fn assert_math(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(math_to_latex(input), expected);
        }
    }

    #[test]
    fn test_math_to_latex_greek() {
        assert_math(&[
            ("ε > 0", r"\varepsilon > 0"),
            ("α_t", r"\alpha_t"),
        ]);
    }

    #[test]
    fn test_math_to_latex_operators() {
        assert_math(&[
            ("x ∈ ℝ^n", r"x \in \mathbb{R}^n"),
            ("a ≈ b", r"a \approx b"),
            ("∀i: x_i ≥ 0", r"\forall i: x_i \geq 0"),
        ]);
    }

    #[test]
    fn test_math_to_latex_sqrt() {
        assert_math(&[(
            "Q / sqrt(mean(Q²) + ε)",
            r"Q / \sqrt{mean(Q²) + \varepsilon}",
        )]);
    }

    #[test]
    fn test_math_to_latex_exp() {
        assert_math(&[("exp(x_i - max(x))", r"\exp(x_i - max(x))")]);
    }

    #[test]
    fn test_replace_func_nested() {
        let converted = replace_func("sqrt(a + sqrt(b))", "sqrt", r"\sqrt");
        assert_eq!(converted, r"\sqrt{a + \sqrt{b}}");
    }

    #[test]
    fn test_latex_escape() {
        for &(input, expected) in &[("a_b", r"a\_b"), ("100%", r"100\%")] {
            assert_eq!(latex_escape(input), expected);
        }
    }
}