Skip to main content

arael_sym/
fmt.rs

1use std::fmt;
2use super::{Expr, E};
3
4// Operator precedence for minimal parenthesization
5fn precedence(e: &Expr) -> u8 {
6    match e {
7        Expr::Add(..) | Expr::Sub(..) => 1,
8        Expr::Mul(..) | Expr::Div(..) => 2,
9        Expr::Neg(..) => 3,
10        Expr::Pow(..) => 4,
11        _ => 10, // atoms and functions
12    }
13}
14
15fn fmt_child(f: &mut fmt::Formatter<'_>, child: &Expr, parent_prec: u8, right_assoc: bool) -> fmt::Result {
16    let child_prec = precedence(child);
17    let needs_parens = if right_assoc {
18        child_prec < parent_prec || (child_prec == parent_prec && parent_prec <= 2)
19    } else {
20        child_prec < parent_prec
21    };
22    if needs_parens {
23        write!(f, "(")?;
24        fmt::Display::fmt(child, f)?;
25        write!(f, ")")
26    } else {
27        fmt::Display::fmt(child, f)
28    }
29}
30
31fn fmt_unary(f: &mut fmt::Formatter<'_>, name: &str, arg: &Expr) -> fmt::Result {
32    write!(f, "{name}(")?;
33    fmt::Display::fmt(arg, f)?;
34    write!(f, ")")
35}
36
37fn fmt_binary_fn(f: &mut fmt::Formatter<'_>, name: &str, a: &Expr, b: &Expr) -> fmt::Result {
38    write!(f, "{name}(")?;
39    fmt::Display::fmt(a, f)?;
40    write!(f, ", ")?;
41    fmt::Display::fmt(b, f)?;
42    write!(f, ")")
43}
44
/// Writes a numeric constant.
///
/// Values with no fractional part and magnitude below `1e15` are printed as
/// plain integers (`3`, not `3.0`); everything else uses the default `f64`
/// formatting.
fn fmt_const(f: &mut fmt::Formatter<'_>, v: f64) -> fmt::Result {
    match v {
        whole if whole.floor() == whole && whole.abs() < 1e15 => {
            let as_int = whole as i64;
            write!(f, "{as_int}")
        }
        other => write!(f, "{other}"),
    }
}
52
53impl fmt::Display for Expr {
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55        match self {
56            Expr::Sym(name) => write!(f, "{name}"),
57            Expr::Const(v) => fmt_const(f, *v),
58            Expr::NamedConst { name, .. } => write!(f, "{name}"),
59            Expr::Neg(a) => {
60                write!(f, "-")?;
61                let needs_parens = matches!(a.as_ref(), Expr::Add(..) | Expr::Sub(..) | Expr::Neg(_));
62                if needs_parens {
63                    write!(f, "(")?;
64                    fmt::Display::fmt(a.as_ref(), f)?;
65                    write!(f, ")")
66                } else {
67                    fmt::Display::fmt(a.as_ref(), f)
68                }
69            }
70            Expr::Add(a, b) => {
71                let p = precedence(self);
72                fmt_child(f, a, p, false)?;
73                write!(f, " + ")?;
74                fmt_child(f, b, p, false)
75            }
76            Expr::Sub(a, b) => {
77                let p = precedence(self);
78                fmt_child(f, a, p, false)?;
79                write!(f, " - ")?;
80                fmt_child(f, b, p, true)
81            }
82            Expr::Mul(a, b) => {
83                let p = precedence(self);
84                fmt_child(f, a, p, false)?;
85                write!(f, " * ")?;
86                fmt_child(f, b, p, false)
87            }
88            Expr::Div(a, b) => {
89                let p = precedence(self);
90                fmt_child(f, a, p, false)?;
91                write!(f, " / ")?;
92                fmt_child(f, b, p, true)
93            }
94            Expr::Pow(a, b) => {
95                let base_needs = precedence(a) < precedence(self);
96                if base_needs {
97                    write!(f, "(")?;
98                    fmt::Display::fmt(a.as_ref(), f)?;
99                    write!(f, ")")?;
100                } else {
101                    fmt::Display::fmt(a.as_ref(), f)?;
102                }
103                write!(f, "^")?;
104                let exp_needs = precedence(b) < 10;
105                if exp_needs {
106                    write!(f, "(")?;
107                    fmt::Display::fmt(b.as_ref(), f)?;
108                    write!(f, ")")
109                } else {
110                    fmt::Display::fmt(b.as_ref(), f)
111                }
112            }
113            Expr::Sin(a) => fmt_unary(f, "sin", a),
114            Expr::Cos(a) => fmt_unary(f, "cos", a),
115            Expr::Tan(a) => fmt_unary(f, "tan", a),
116            Expr::Asin(a) => fmt_unary(f, "asin", a),
117            Expr::Acos(a) => fmt_unary(f, "acos", a),
118            Expr::Atan(a) => fmt_unary(f, "atan", a),
119            Expr::Atan2(y, x) => fmt_binary_fn(f, "atan2", y, x),
120            Expr::Sinh(a) => fmt_unary(f, "sinh", a),
121            Expr::Cosh(a) => fmt_unary(f, "cosh", a),
122            Expr::Tanh(a) => fmt_unary(f, "tanh", a),
123            Expr::Exp(a) => fmt_unary(f, "exp", a),
124            Expr::Ln(a) => fmt_unary(f, "ln", a),
125            Expr::Log2(a) => fmt_unary(f, "log2", a),
126            Expr::Log10(a) => fmt_unary(f, "log10", a),
127            Expr::Sqrt(a) => fmt_unary(f, "sqrt", a),
128            Expr::Abs(a) => fmt_unary(f, "abs", a),
129            Expr::Heaviside(a) => fmt_unary(f, "H", a),
130            Expr::Clamp(val, lo, hi) => {
131                write!(f, "clamp(")?;
132                fmt::Display::fmt(val.as_ref(), f)?;
133                write!(f, ", ")?;
134                fmt::Display::fmt(lo.as_ref(), f)?;
135                write!(f, ", ")?;
136                fmt::Display::fmt(hi.as_ref(), f)?;
137                write!(f, ")")
138            }
139            Expr::Func { name, args, .. } => {
140                write!(f, "{name}(")?;
141                for (i, arg) in args.iter().enumerate() {
142                    if i > 0 { write!(f, ", ")?; }
143                    fmt::Display::fmt(arg.as_ref(), f)?;
144                }
145                write!(f, ")")
146            }
147        }
148    }
149}
150
151impl fmt::Display for E {
152    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
153        fmt::Display::fmt(self.as_ref(), f)
154    }
155}
156
157impl fmt::Debug for E {
158    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159        // Use Display for Debug too — more readable
160        fmt::Display::fmt(self.as_ref(), f)
161    }
162}
163
164// --- LaTeX output ---
165
166impl Expr {
167    /// Format the expression as LaTeX math notation.
168    ///
169    /// Produces a string suitable for embedding in LaTeX documents, using
170    /// `\frac`, `\sqrt`, `\sin`, etc.
171    pub fn to_latex(&self) -> String {
172        let mut buf = String::new();
173        self.write_latex(&mut buf);
174        buf
175    }
176
177    fn write_latex(&self, buf: &mut String) {
178        match self {
179            Expr::Sym(name) => buf.push_str(name),
180            Expr::Const(v) => {
181                if *v == v.floor() && v.abs() < 1e15 {
182                    buf.push_str(&format!("{}", *v as i64));
183                } else {
184                    buf.push_str(&format!("{v}"));
185                }
186            }
187            Expr::NamedConst { latex, .. } => buf.push_str(latex),
188            Expr::Neg(a) => {
189                buf.push('-');
190                let needs_parens = matches!(a.as_ref(), Expr::Add(..) | Expr::Sub(..));
191                if needs_parens {
192                    buf.push_str("\\left(");
193                    a.write_latex(buf);
194                    buf.push_str("\\right)");
195                } else {
196                    a.write_latex(buf);
197                }
198            }
199            Expr::Add(a, b) => {
200                a.write_latex(buf);
201                buf.push_str(" + ");
202                b.write_latex(buf);
203            }
204            Expr::Sub(a, b) => {
205                a.write_latex(buf);
206                buf.push_str(" - ");
207                let needs_parens = matches!(b.as_ref(), Expr::Add(..) | Expr::Sub(..));
208                if needs_parens {
209                    buf.push_str("\\left(");
210                    b.write_latex(buf);
211                    buf.push_str("\\right)");
212                } else {
213                    b.write_latex(buf);
214                }
215            }
216            Expr::Mul(a, b) => {
217                let a_needs = matches!(a.as_ref(), Expr::Add(..) | Expr::Sub(..));
218                if a_needs {
219                    buf.push_str("\\left(");
220                    a.write_latex(buf);
221                    buf.push_str("\\right)");
222                } else {
223                    a.write_latex(buf);
224                }
225                buf.push_str(" \\cdot ");
226                let b_needs = matches!(b.as_ref(), Expr::Add(..) | Expr::Sub(..));
227                if b_needs {
228                    buf.push_str("\\left(");
229                    b.write_latex(buf);
230                    buf.push_str("\\right)");
231                } else {
232                    b.write_latex(buf);
233                }
234            }
235            Expr::Div(a, b) => {
236                buf.push_str("\\frac{");
237                a.write_latex(buf);
238                buf.push_str("}{");
239                b.write_latex(buf);
240                buf.push('}');
241            }
242            Expr::Pow(a, b) => {
243                let needs_parens = matches!(
244                    a.as_ref(),
245                    Expr::Add(..) | Expr::Sub(..) | Expr::Mul(..) | Expr::Div(..) | Expr::Neg(..)
246                );
247                if needs_parens {
248                    buf.push_str("\\left(");
249                    a.write_latex(buf);
250                    buf.push_str("\\right)");
251                } else {
252                    a.write_latex(buf);
253                }
254                buf.push_str("^{");
255                b.write_latex(buf);
256                buf.push('}');
257            }
258            Expr::Sin(a) => Self::write_latex_fn(buf, "\\sin", a),
259            Expr::Cos(a) => Self::write_latex_fn(buf, "\\cos", a),
260            Expr::Tan(a) => Self::write_latex_fn(buf, "\\tan", a),
261            Expr::Asin(a) => Self::write_latex_fn(buf, "\\arcsin", a),
262            Expr::Acos(a) => Self::write_latex_fn(buf, "\\arccos", a),
263            Expr::Atan(a) => Self::write_latex_fn(buf, "\\arctan", a),
264            Expr::Atan2(y, x) => {
265                buf.push_str("\\operatorname{atan2}\\left(");
266                y.write_latex(buf);
267                buf.push_str(", ");
268                x.write_latex(buf);
269                buf.push_str("\\right)");
270            }
271            Expr::Sinh(a) => Self::write_latex_fn(buf, "\\sinh", a),
272            Expr::Cosh(a) => Self::write_latex_fn(buf, "\\cosh", a),
273            Expr::Tanh(a) => Self::write_latex_fn(buf, "\\tanh", a),
274            Expr::Exp(a) => {
275                buf.push_str("e^{");
276                a.write_latex(buf);
277                buf.push('}');
278            }
279            Expr::Ln(a) => Self::write_latex_fn(buf, "\\ln", a),
280            Expr::Log2(a) => Self::write_latex_fn(buf, "\\log_2", a),
281            Expr::Log10(a) => Self::write_latex_fn(buf, "\\log_{10}", a),
282            Expr::Sqrt(a) => {
283                buf.push_str("\\sqrt{");
284                a.write_latex(buf);
285                buf.push('}');
286            }
287            Expr::Abs(a) => {
288                buf.push_str("\\left|");
289                a.write_latex(buf);
290                buf.push_str("\\right|");
291            }
292            Expr::Heaviside(a) => Self::write_latex_fn(buf, "H", a),
293            Expr::Clamp(val, lo, hi) => {
294                buf.push_str("\\operatorname{clamp}\\left(");
295                val.write_latex(buf);
296                buf.push_str(", ");
297                lo.write_latex(buf);
298                buf.push_str(", ");
299                hi.write_latex(buf);
300                buf.push_str("\\right)");
301            }
302            Expr::Func { name, args, .. } => {
303                let escaped = name.replace('_', "\\_");
304                buf.push_str(&format!("\\operatorname{{{escaped}}}\\left("));
305                for (i, arg) in args.iter().enumerate() {
306                    if i > 0 { buf.push_str(", "); }
307                    arg.write_latex(buf);
308                }
309                buf.push_str("\\right)");
310            }
311        }
312    }
313
314    fn write_latex_fn(buf: &mut String, name: &str, arg: &Expr) {
315        buf.push_str(name);
316        buf.push_str("\\left(");
317        arg.write_latex(buf);
318        buf.push_str("\\right)");
319    }
320
321    // --- Rust code output ---
322
323    /// Generate Rust source code for this expression.
324    ///
325    /// The `float_type` parameter (e.g. `"f64"`) controls numeric literal
326    /// suffixes. Pass an empty string to omit type suffixes.
327    pub fn to_rust(&self, float_type: &str) -> String {
328        let mut buf = String::new();
329        self.write_rust(&mut buf, float_type, 0);
330        buf
331    }
332
333    // Precedence levels (matching Rust):
334    // 0 = top level
335    // 5 = Add, Sub
336    // 6 = Mul, Div
337    // 7 = Unary Neg
338    // 8 = Atoms, method calls (never need parens)
339    fn prec(&self) -> u8 {
340        match self {
341            Expr::Add(_, _) | Expr::Sub(_, _) => 5,
342            Expr::Mul(_, _) | Expr::Div(_, _) => 6,
343            Expr::Neg(_) => 7,
344            _ => 8,
345        }
346    }
347
348    fn write_rust(&self, buf: &mut String, ft: &str, parent_prec: u8) {
349        let my_prec = self.prec();
350        // Need parens when our precedence is lower than parent's,
351        // or for Sub/Div right-hand side at same precedence (non-associative)
352        let need_parens = my_prec < parent_prec;
353        if need_parens { buf.push('('); }
354
355        match self {
356            Expr::Sym(name) => buf.push_str(name),
357            Expr::Const(v) => {
358                if ft.is_empty() {
359                    if *v == v.floor() && v.abs() < 1e15 {
360                        buf.push_str(&format!("{}.0", *v as i64));
361                    } else {
362                        buf.push_str(&format!("{v}"));
363                    }
364                } else if *v == v.floor() && v.abs() < 1e15 {
365                    buf.push_str(&format!("{}.0_{ft}", *v as i64));
366                } else {
367                    buf.push_str(&format!("{v}_{ft}"));
368                }
369            }
370            Expr::NamedConst { rust_f32, rust_f64, .. } => {
371                buf.push_str(if ft == "f32" { rust_f32 } else { rust_f64 });
372            }
373            Expr::Neg(a) => {
374                buf.push('-');
375                a.write_rust(buf, ft, 7);
376            }
377            Expr::Add(a, b) => {
378                a.write_rust(buf, ft, 5);
379                buf.push_str(" + ");
380                b.write_rust(buf, ft, 6); // right side: need parens for sub at same level
381            }
382            Expr::Sub(a, b) => {
383                a.write_rust(buf, ft, 5);
384                buf.push_str(" - ");
385                b.write_rust(buf, ft, 6); // right side of sub: parens for add/sub
386            }
387            Expr::Mul(a, b) => {
388                a.write_rust(buf, ft, 6);
389                buf.push_str(" * ");
390                b.write_rust(buf, ft, 7); // right side: need parens for div at same level
391            }
392            Expr::Div(a, b) => {
393                a.write_rust(buf, ft, 6);
394                buf.push_str(" / ");
395                b.write_rust(buf, ft, 7); // right side of div: parens for mul/div
396            }
397            Expr::Pow(a, b) => {
398                a.write_rust(buf, ft, 8);
399                buf.push_str(".powf(");
400                b.write_rust(buf, ft, 0);
401                buf.push(')');
402            }
403            Expr::Sin(a) => Self::write_rust_method(buf, ft, a, "sin"),
404            Expr::Cos(a) => Self::write_rust_method(buf, ft, a, "cos"),
405            Expr::Tan(a) => Self::write_rust_method(buf, ft, a, "tan"),
406            Expr::Asin(a) => Self::write_rust_method(buf, ft, a, "asin"),
407            Expr::Acos(a) => Self::write_rust_method(buf, ft, a, "acos"),
408            Expr::Atan(a) => Self::write_rust_method(buf, ft, a, "atan"),
409            Expr::Atan2(y, x) => {
410                y.write_rust(buf, ft, 8);
411                buf.push_str(".atan2(");
412                x.write_rust(buf, ft, 0);
413                buf.push(')');
414            }
415            Expr::Sinh(a) => Self::write_rust_method(buf, ft, a, "sinh"),
416            Expr::Cosh(a) => Self::write_rust_method(buf, ft, a, "cosh"),
417            Expr::Tanh(a) => Self::write_rust_method(buf, ft, a, "tanh"),
418            Expr::Exp(a) => Self::write_rust_method(buf, ft, a, "exp"),
419            Expr::Ln(a) => Self::write_rust_method(buf, ft, a, "ln"),
420            Expr::Log2(a) => Self::write_rust_method(buf, ft, a, "log2"),
421            Expr::Log10(a) => Self::write_rust_method(buf, ft, a, "log10"),
422            Expr::Sqrt(a) => Self::write_rust_method(buf, ft, a, "sqrt"),
423            Expr::Abs(a) => Self::write_rust_method(buf, ft, a, "abs"),
424            Expr::Heaviside(a) => Self::write_rust_method(buf, ft, a, "heaviside"),
425            Expr::Clamp(val, lo, hi) => {
426                val.write_rust(buf, ft, 8);
427                buf.push_str(".clamp(");
428                lo.write_rust(buf, ft, 0);
429                buf.push_str(", ");
430                hi.write_rust(buf, ft, 0);
431                buf.push(')');
432            }
433            Expr::Func { name, params, kind, args } => {
434                if let Some(body) = kind.body() {
435                    // Symbolic: inline the expanded body.
436                    // Identity functions force parentheses to preserve eval order.
437                    let prec = if name == "identity" { 8 } else { parent_prec };
438                    crate::expand_func(params, body, args).write_rust(buf, ft, prec);
439                } else if let crate::FuncKind::Extern { call_path, .. } = kind {
440                    // Extern: emit function call
441                    buf.push_str(call_path);
442                    buf.push('(');
443                    for (i, arg) in args.iter().enumerate() {
444                        if i > 0 { buf.push_str(", "); }
445                        arg.write_rust(buf, ft, 0);
446                    }
447                    buf.push(')');
448                }
449                return; // already wrote, skip trailing paren logic
450            }
451        }
452
453        if need_parens { buf.push(')'); }
454    }
455
456    fn write_rust_method(buf: &mut String, ft: &str, arg: &Expr, method: &str) {
457        arg.write_rust(buf, ft, 8);
458        buf.push('.');
459        buf.push_str(method);
460        buf.push_str("()");
461    }
462}