1use std::fmt;
2use super::{Expr, E};
3
4fn precedence(e: &Expr) -> u8 {
6 match e {
7 Expr::Add(..) | Expr::Sub(..) => 1,
8 Expr::Mul(..) | Expr::Div(..) => 2,
9 Expr::Neg(..) => 3,
10 Expr::Pow(..) => 4,
11 _ => 10, }
13}
14
15fn fmt_child(f: &mut fmt::Formatter<'_>, child: &Expr, parent_prec: u8, right_assoc: bool) -> fmt::Result {
16 let child_prec = precedence(child);
17 let needs_parens = if right_assoc {
18 child_prec < parent_prec || (child_prec == parent_prec && parent_prec <= 2)
19 } else {
20 child_prec < parent_prec
21 };
22 if needs_parens {
23 write!(f, "(")?;
24 fmt::Display::fmt(child, f)?;
25 write!(f, ")")
26 } else {
27 fmt::Display::fmt(child, f)
28 }
29}
30
31fn fmt_unary(f: &mut fmt::Formatter<'_>, name: &str, arg: &Expr) -> fmt::Result {
32 write!(f, "{name}(")?;
33 fmt::Display::fmt(arg, f)?;
34 write!(f, ")")
35}
36
37fn fmt_binary_fn(f: &mut fmt::Formatter<'_>, name: &str, a: &Expr, b: &Expr) -> fmt::Result {
38 write!(f, "{name}(")?;
39 fmt::Display::fmt(a, f)?;
40 write!(f, ", ")?;
41 fmt::Display::fmt(b, f)?;
42 write!(f, ")")
43}
44
/// Writes a numeric constant for the plain-text rendering.
///
/// Integral values with magnitude below `1e15` go through an `i64` cast, so `3.0` prints
/// as `3`. Everything else (fractions, larger magnitudes, NaN, infinities) uses `f64`'s
/// own `Display`.
fn fmt_const(f: &mut fmt::Formatter<'_>, v: f64) -> fmt::Result {
    let small_integer = v.abs() < 1e15 && v.floor() == v;
    if !small_integer {
        return write!(f, "{v}");
    }
    write!(f, "{}", v as i64)
}
52
53impl fmt::Display for Expr {
54 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55 match self {
56 Expr::Sym(name) => write!(f, "{name}"),
57 Expr::Const(v) => fmt_const(f, *v),
58 Expr::NamedConst { name, .. } => write!(f, "{name}"),
59 Expr::Neg(a) => {
60 write!(f, "-")?;
61 let needs_parens = matches!(a.as_ref(), Expr::Add(..) | Expr::Sub(..) | Expr::Neg(_));
62 if needs_parens {
63 write!(f, "(")?;
64 fmt::Display::fmt(a.as_ref(), f)?;
65 write!(f, ")")
66 } else {
67 fmt::Display::fmt(a.as_ref(), f)
68 }
69 }
70 Expr::Add(a, b) => {
71 let p = precedence(self);
72 fmt_child(f, a, p, false)?;
73 write!(f, " + ")?;
74 fmt_child(f, b, p, false)
75 }
76 Expr::Sub(a, b) => {
77 let p = precedence(self);
78 fmt_child(f, a, p, false)?;
79 write!(f, " - ")?;
80 fmt_child(f, b, p, true)
81 }
82 Expr::Mul(a, b) => {
83 let p = precedence(self);
84 fmt_child(f, a, p, false)?;
85 write!(f, " * ")?;
86 fmt_child(f, b, p, false)
87 }
88 Expr::Div(a, b) => {
89 let p = precedence(self);
90 fmt_child(f, a, p, false)?;
91 write!(f, " / ")?;
92 fmt_child(f, b, p, true)
93 }
94 Expr::Pow(a, b) => {
95 let base_needs = precedence(a) < precedence(self);
96 if base_needs {
97 write!(f, "(")?;
98 fmt::Display::fmt(a.as_ref(), f)?;
99 write!(f, ")")?;
100 } else {
101 fmt::Display::fmt(a.as_ref(), f)?;
102 }
103 write!(f, "^")?;
104 let exp_needs = precedence(b) < 10;
105 if exp_needs {
106 write!(f, "(")?;
107 fmt::Display::fmt(b.as_ref(), f)?;
108 write!(f, ")")
109 } else {
110 fmt::Display::fmt(b.as_ref(), f)
111 }
112 }
113 Expr::Sin(a) => fmt_unary(f, "sin", a),
114 Expr::Cos(a) => fmt_unary(f, "cos", a),
115 Expr::Tan(a) => fmt_unary(f, "tan", a),
116 Expr::Asin(a) => fmt_unary(f, "asin", a),
117 Expr::Acos(a) => fmt_unary(f, "acos", a),
118 Expr::Atan(a) => fmt_unary(f, "atan", a),
119 Expr::Atan2(y, x) => fmt_binary_fn(f, "atan2", y, x),
120 Expr::Sinh(a) => fmt_unary(f, "sinh", a),
121 Expr::Cosh(a) => fmt_unary(f, "cosh", a),
122 Expr::Tanh(a) => fmt_unary(f, "tanh", a),
123 Expr::Exp(a) => fmt_unary(f, "exp", a),
124 Expr::Ln(a) => fmt_unary(f, "ln", a),
125 Expr::Log2(a) => fmt_unary(f, "log2", a),
126 Expr::Log10(a) => fmt_unary(f, "log10", a),
127 Expr::Sqrt(a) => fmt_unary(f, "sqrt", a),
128 Expr::Abs(a) => fmt_unary(f, "abs", a),
129 Expr::Heaviside(a) => fmt_unary(f, "H", a),
130 Expr::Clamp(val, lo, hi) => {
131 write!(f, "clamp(")?;
132 fmt::Display::fmt(val.as_ref(), f)?;
133 write!(f, ", ")?;
134 fmt::Display::fmt(lo.as_ref(), f)?;
135 write!(f, ", ")?;
136 fmt::Display::fmt(hi.as_ref(), f)?;
137 write!(f, ")")
138 }
139 Expr::Func { name, args, .. } => {
140 write!(f, "{name}(")?;
141 for (i, arg) in args.iter().enumerate() {
142 if i > 0 { write!(f, ", ")?; }
143 fmt::Display::fmt(arg.as_ref(), f)?;
144 }
145 write!(f, ")")
146 }
147 }
148 }
149}
150
151impl fmt::Display for E {
152 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
153 fmt::Display::fmt(self.as_ref(), f)
154 }
155}
156
157impl fmt::Debug for E {
158 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159 fmt::Display::fmt(self.as_ref(), f)
161 }
162}
163
164impl Expr {
167 pub fn to_latex(&self) -> String {
172 let mut buf = String::new();
173 self.write_latex(&mut buf);
174 buf
175 }
176
177 fn write_latex(&self, buf: &mut String) {
178 match self {
179 Expr::Sym(name) => buf.push_str(name),
180 Expr::Const(v) => {
181 if *v == v.floor() && v.abs() < 1e15 {
182 buf.push_str(&format!("{}", *v as i64));
183 } else {
184 buf.push_str(&format!("{v}"));
185 }
186 }
187 Expr::NamedConst { latex, .. } => buf.push_str(latex),
188 Expr::Neg(a) => {
189 buf.push('-');
190 let needs_parens = matches!(a.as_ref(), Expr::Add(..) | Expr::Sub(..));
191 if needs_parens {
192 buf.push_str("\\left(");
193 a.write_latex(buf);
194 buf.push_str("\\right)");
195 } else {
196 a.write_latex(buf);
197 }
198 }
199 Expr::Add(a, b) => {
200 a.write_latex(buf);
201 buf.push_str(" + ");
202 b.write_latex(buf);
203 }
204 Expr::Sub(a, b) => {
205 a.write_latex(buf);
206 buf.push_str(" - ");
207 let needs_parens = matches!(b.as_ref(), Expr::Add(..) | Expr::Sub(..));
208 if needs_parens {
209 buf.push_str("\\left(");
210 b.write_latex(buf);
211 buf.push_str("\\right)");
212 } else {
213 b.write_latex(buf);
214 }
215 }
216 Expr::Mul(a, b) => {
217 let a_needs = matches!(a.as_ref(), Expr::Add(..) | Expr::Sub(..));
218 if a_needs {
219 buf.push_str("\\left(");
220 a.write_latex(buf);
221 buf.push_str("\\right)");
222 } else {
223 a.write_latex(buf);
224 }
225 buf.push_str(" \\cdot ");
226 let b_needs = matches!(b.as_ref(), Expr::Add(..) | Expr::Sub(..));
227 if b_needs {
228 buf.push_str("\\left(");
229 b.write_latex(buf);
230 buf.push_str("\\right)");
231 } else {
232 b.write_latex(buf);
233 }
234 }
235 Expr::Div(a, b) => {
236 buf.push_str("\\frac{");
237 a.write_latex(buf);
238 buf.push_str("}{");
239 b.write_latex(buf);
240 buf.push('}');
241 }
242 Expr::Pow(a, b) => {
243 let needs_parens = matches!(
244 a.as_ref(),
245 Expr::Add(..) | Expr::Sub(..) | Expr::Mul(..) | Expr::Div(..) | Expr::Neg(..)
246 );
247 if needs_parens {
248 buf.push_str("\\left(");
249 a.write_latex(buf);
250 buf.push_str("\\right)");
251 } else {
252 a.write_latex(buf);
253 }
254 buf.push_str("^{");
255 b.write_latex(buf);
256 buf.push('}');
257 }
258 Expr::Sin(a) => Self::write_latex_fn(buf, "\\sin", a),
259 Expr::Cos(a) => Self::write_latex_fn(buf, "\\cos", a),
260 Expr::Tan(a) => Self::write_latex_fn(buf, "\\tan", a),
261 Expr::Asin(a) => Self::write_latex_fn(buf, "\\arcsin", a),
262 Expr::Acos(a) => Self::write_latex_fn(buf, "\\arccos", a),
263 Expr::Atan(a) => Self::write_latex_fn(buf, "\\arctan", a),
264 Expr::Atan2(y, x) => {
265 buf.push_str("\\operatorname{atan2}\\left(");
266 y.write_latex(buf);
267 buf.push_str(", ");
268 x.write_latex(buf);
269 buf.push_str("\\right)");
270 }
271 Expr::Sinh(a) => Self::write_latex_fn(buf, "\\sinh", a),
272 Expr::Cosh(a) => Self::write_latex_fn(buf, "\\cosh", a),
273 Expr::Tanh(a) => Self::write_latex_fn(buf, "\\tanh", a),
274 Expr::Exp(a) => {
275 buf.push_str("e^{");
276 a.write_latex(buf);
277 buf.push('}');
278 }
279 Expr::Ln(a) => Self::write_latex_fn(buf, "\\ln", a),
280 Expr::Log2(a) => Self::write_latex_fn(buf, "\\log_2", a),
281 Expr::Log10(a) => Self::write_latex_fn(buf, "\\log_{10}", a),
282 Expr::Sqrt(a) => {
283 buf.push_str("\\sqrt{");
284 a.write_latex(buf);
285 buf.push('}');
286 }
287 Expr::Abs(a) => {
288 buf.push_str("\\left|");
289 a.write_latex(buf);
290 buf.push_str("\\right|");
291 }
292 Expr::Heaviside(a) => Self::write_latex_fn(buf, "H", a),
293 Expr::Clamp(val, lo, hi) => {
294 buf.push_str("\\operatorname{clamp}\\left(");
295 val.write_latex(buf);
296 buf.push_str(", ");
297 lo.write_latex(buf);
298 buf.push_str(", ");
299 hi.write_latex(buf);
300 buf.push_str("\\right)");
301 }
302 Expr::Func { name, args, .. } => {
303 let escaped = name.replace('_', "\\_");
304 buf.push_str(&format!("\\operatorname{{{escaped}}}\\left("));
305 for (i, arg) in args.iter().enumerate() {
306 if i > 0 { buf.push_str(", "); }
307 arg.write_latex(buf);
308 }
309 buf.push_str("\\right)");
310 }
311 }
312 }
313
314 fn write_latex_fn(buf: &mut String, name: &str, arg: &Expr) {
315 buf.push_str(name);
316 buf.push_str("\\left(");
317 arg.write_latex(buf);
318 buf.push_str("\\right)");
319 }
320
321 pub fn to_rust(&self, float_type: &str) -> String {
328 let mut buf = String::new();
329 self.write_rust(&mut buf, float_type, 0);
330 buf
331 }
332
333 fn prec(&self) -> u8 {
340 match self {
341 Expr::Add(_, _) | Expr::Sub(_, _) => 5,
342 Expr::Mul(_, _) | Expr::Div(_, _) => 6,
343 Expr::Neg(_) => 7,
344 _ => 8,
345 }
346 }
347
348 fn write_rust(&self, buf: &mut String, ft: &str, parent_prec: u8) {
349 let my_prec = self.prec();
350 let need_parens = my_prec < parent_prec;
353 if need_parens { buf.push('('); }
354
355 match self {
356 Expr::Sym(name) => buf.push_str(name),
357 Expr::Const(v) => {
358 if ft.is_empty() {
359 if *v == v.floor() && v.abs() < 1e15 {
360 buf.push_str(&format!("{}.0", *v as i64));
361 } else {
362 buf.push_str(&format!("{v}"));
363 }
364 } else if *v == v.floor() && v.abs() < 1e15 {
365 buf.push_str(&format!("{}.0_{ft}", *v as i64));
366 } else {
367 buf.push_str(&format!("{v}_{ft}"));
368 }
369 }
370 Expr::NamedConst { rust_f32, rust_f64, .. } => {
371 buf.push_str(if ft == "f32" { rust_f32 } else { rust_f64 });
372 }
373 Expr::Neg(a) => {
374 buf.push('-');
375 a.write_rust(buf, ft, 7);
376 }
377 Expr::Add(a, b) => {
378 a.write_rust(buf, ft, 5);
379 buf.push_str(" + ");
380 b.write_rust(buf, ft, 6); }
382 Expr::Sub(a, b) => {
383 a.write_rust(buf, ft, 5);
384 buf.push_str(" - ");
385 b.write_rust(buf, ft, 6); }
387 Expr::Mul(a, b) => {
388 a.write_rust(buf, ft, 6);
389 buf.push_str(" * ");
390 b.write_rust(buf, ft, 7); }
392 Expr::Div(a, b) => {
393 a.write_rust(buf, ft, 6);
394 buf.push_str(" / ");
395 b.write_rust(buf, ft, 7); }
397 Expr::Pow(a, b) => {
398 a.write_rust(buf, ft, 8);
399 buf.push_str(".powf(");
400 b.write_rust(buf, ft, 0);
401 buf.push(')');
402 }
403 Expr::Sin(a) => Self::write_rust_method(buf, ft, a, "sin"),
404 Expr::Cos(a) => Self::write_rust_method(buf, ft, a, "cos"),
405 Expr::Tan(a) => Self::write_rust_method(buf, ft, a, "tan"),
406 Expr::Asin(a) => Self::write_rust_method(buf, ft, a, "asin"),
407 Expr::Acos(a) => Self::write_rust_method(buf, ft, a, "acos"),
408 Expr::Atan(a) => Self::write_rust_method(buf, ft, a, "atan"),
409 Expr::Atan2(y, x) => {
410 y.write_rust(buf, ft, 8);
411 buf.push_str(".atan2(");
412 x.write_rust(buf, ft, 0);
413 buf.push(')');
414 }
415 Expr::Sinh(a) => Self::write_rust_method(buf, ft, a, "sinh"),
416 Expr::Cosh(a) => Self::write_rust_method(buf, ft, a, "cosh"),
417 Expr::Tanh(a) => Self::write_rust_method(buf, ft, a, "tanh"),
418 Expr::Exp(a) => Self::write_rust_method(buf, ft, a, "exp"),
419 Expr::Ln(a) => Self::write_rust_method(buf, ft, a, "ln"),
420 Expr::Log2(a) => Self::write_rust_method(buf, ft, a, "log2"),
421 Expr::Log10(a) => Self::write_rust_method(buf, ft, a, "log10"),
422 Expr::Sqrt(a) => Self::write_rust_method(buf, ft, a, "sqrt"),
423 Expr::Abs(a) => Self::write_rust_method(buf, ft, a, "abs"),
424 Expr::Heaviside(a) => Self::write_rust_method(buf, ft, a, "heaviside"),
425 Expr::Clamp(val, lo, hi) => {
426 val.write_rust(buf, ft, 8);
427 buf.push_str(".clamp(");
428 lo.write_rust(buf, ft, 0);
429 buf.push_str(", ");
430 hi.write_rust(buf, ft, 0);
431 buf.push(')');
432 }
433 Expr::Func { name, params, kind, args } => {
434 if let Some(body) = kind.body() {
435 let prec = if name == "identity" { 8 } else { parent_prec };
438 crate::expand_func(params, body, args).write_rust(buf, ft, prec);
439 } else if let crate::FuncKind::Extern { call_path, .. } = kind {
440 buf.push_str(call_path);
442 buf.push('(');
443 for (i, arg) in args.iter().enumerate() {
444 if i > 0 { buf.push_str(", "); }
445 arg.write_rust(buf, ft, 0);
446 }
447 buf.push(')');
448 }
449 return; }
451 }
452
453 if need_parens { buf.push(')'); }
454 }
455
456 fn write_rust_method(buf: &mut String, ft: &str, arg: &Expr, method: &str) {
457 arg.write_rust(buf, ft, 8);
458 buf.push('.');
459 buf.push_str(method);
460 buf.push_str("()");
461 }
462}