mathlex 0.4.1

Mathematical expression parser for LaTeX and plain text notation, producing a language-agnostic AST
Documentation
use super::trait_def::{ToLatex, KNOWN_FUNCTIONS};
use crate::ast::precedence::needs_parens;
use crate::ast::{ExprKind, Expression};

/// Names that `variable_to_latex` turns into a LaTeX control sequence by prefixing a backslash.
///
/// Holds the lowercase Greek letter names followed by a set of capitalised names. Not every
/// capitalised name is present (for example there is no `"Alpha"` entry), and lookups are
/// exact, case-sensitive string comparisons.
const GREEK_LETTERS: &[&str] = &[
    "alpha", "beta", "gamma", "delta", "epsilon", "zeta", "eta", "theta", "iota", "kappa",
    "lambda", "mu", "nu", "xi", "omicron", "pi", "rho", "sigma", "tau", "upsilon", "phi", "chi",
    "psi", "omega", "Gamma", "Delta", "Theta", "Lambda", "Xi", "Pi", "Sigma", "Upsilon", "Phi",
    "Psi", "Omega",
];

/// Converts a variable name into its LaTeX spelling.
///
/// * A name listed in `GREEK_LETTERS` becomes the matching control sequence (`alpha` → `\alpha`).
/// * A name containing `_` is split at the first underscore: the part before it is the base
///   (Greek-converted if applicable) and the rest is the subscript. A subscript that is exactly
///   one byte long is appended bare (`x_1`); any other subscript is wrapped in braces
///   (`x_{12}`).
/// * Anything else is returned unchanged.
pub(super) fn variable_to_latex(name: &str) -> String {
    // Shared rule for the whole name and for the base of a subscripted name.
    let render_base = |ident: &str| -> String {
        if GREEK_LETTERS.contains(&ident) {
            format!(r"\{}", ident)
        } else {
            ident.to_string()
        }
    };

    match name.split_once('_') {
        None => render_base(name),
        Some((base, subscript)) => {
            let base_latex = render_base(base);
            // Byte length, not character count: only a single-byte subscript goes unbraced.
            match subscript.len() {
                1 => format!("{}_{}", base_latex, subscript),
                _ => format!("{}_{{{}}}", base_latex, subscript),
            }
        }
    }
}

/// Renders a literal-like expression (numbers, variables, constants) as LaTeX.
///
/// Handled kinds and their output shape:
/// * `Integer` / `Float` — the value's `Display` form.
/// * `Rational` — `\frac{numerator}{denominator}`.
/// * `Complex` — `real + imaginaryi`.
/// * `Quaternion` — `real + i\mathbf{i} + j\mathbf{j} + k\mathbf{k}`.
/// * `Variable` — delegated to `variable_to_latex`.
/// * `Constant` — the constant's own `to_latex`.
///
/// # Panics
/// Panics via `unreachable!` when called with any other expression kind.
pub(super) fn to_latex_literal(expr: &Expression) -> String {
    match &expr.kind {
        ExprKind::Variable(name) => variable_to_latex(name),
        ExprKind::Constant(c) => c.to_latex(),
        ExprKind::Integer(n) => n.to_string(),
        ExprKind::Float(x) => x.to_string(),
        ExprKind::Rational {
            numerator,
            denominator,
        } => {
            let top = numerator.to_latex();
            let bottom = denominator.to_latex();
            format!(r"\frac{{{}}}{{{}}}", top, bottom)
        }
        ExprKind::Complex { real, imaginary } => {
            let re = real.to_latex();
            let im = imaginary.to_latex();
            format!("{} + {}i", re, im)
        }
        ExprKind::Quaternion { real, i, j, k } => {
            // Each imaginary component carries its bold unit; the terms are then summed.
            let terms = [
                real.to_latex(),
                format!("{}\\mathbf{{i}}", i.to_latex()),
                format!("{}\\mathbf{{j}}", j.to_latex()),
                format!("{}\\mathbf{{k}}", k.to_latex()),
            ];
            terms.join(" + ")
        }
        _ => unreachable!("to_latex_literal called on non-literal"),
    }
}

pub(super) fn to_latex_binary(
    op: &crate::ast::BinaryOp,
    left: &Expression,
    right: &Expression,
) -> String {
    match op {
        crate::ast::BinaryOp::Div => {
            format!(r"\frac{{{}}}{{{}}}", left.to_latex(), right.to_latex())
        }
        crate::ast::BinaryOp::Pow => {
            let left_str = if needs_parens(left, *op, false) {
                format!(r"\left({}\right)", left.to_latex())
            } else {
                left.to_latex()
            };
            let right_str = if needs_parens(right, *op, true) {
                format!(r"\left({}\right)", right.to_latex())
            } else {
                right.to_latex()
            };
            format!("{}^{{{}}}", left_str, right_str)
        }
        crate::ast::BinaryOp::Mod => {
            format!("{} \\bmod {}", left.to_latex(), right.to_latex())
        }
        _ => {
            let left_str = if needs_parens(left, *op, false) {
                format!(r"\left({}\right)", left.to_latex())
            } else {
                left.to_latex()
            };
            let right_str = if needs_parens(right, *op, true) {
                format!(r"\left({}\right)", right.to_latex())
            } else {
                right.to_latex()
            };
            format!("{} {} {}", left_str, op.to_latex(), right_str)
        }
    }
}

pub(super) fn to_latex_unary(op: &crate::ast::UnaryOp, operand: &Expression) -> String {
    let is_binary = matches!(operand.kind, ExprKind::Binary { .. });
    match op {
        crate::ast::UnaryOp::Neg => {
            if is_binary {
                format!(r"-\left({}\right)", operand.to_latex())
            } else {
                format!("-{}", operand.to_latex())
            }
        }
        crate::ast::UnaryOp::Pos => {
            if is_binary {
                format!(r"+\left({}\right)", operand.to_latex())
            } else {
                format!("+{}", operand.to_latex())
            }
        }
        crate::ast::UnaryOp::Factorial => {
            if is_binary {
                format!(r"\left({}\right)!", operand.to_latex())
            } else {
                format!("{}!", operand.to_latex())
            }
        }
        crate::ast::UnaryOp::Transpose => {
            if is_binary {
                format!(r"\left({}\right)^T", operand.to_latex())
            } else {
                format!("{}^T", operand.to_latex())
            }
        }
    }
}

/// Renders a function application as LaTeX.
///
/// Special forms, selected by name and argument count:
/// * `sqrt` with one argument — `\sqrt{a0}`.
/// * `sqrt` with two arguments — `\sqrt[a0]{a1}` (first argument in the optional bracket).
/// * `log` with two arguments — `\log_{a1}{a0}` (second argument is the base).
/// * `floor` / `ceil` with one argument — `\lfloor a0 \rfloor` / `\lceil a0 \rceil`.
///
/// Everything else is rendered generically: the name becomes `\name` when it appears in
/// `KNOWN_FUNCTIONS` and `\operatorname{name}` otherwise, followed by the comma-separated
/// arguments inside `\left( … \right)` (omitted entirely when there are no arguments).
///
/// `sqrt` with any other argument count also takes the generic path, always with
/// `\operatorname{sqrt}`. Previously that case returned the bare operator name and silently
/// dropped every argument. The native `\sqrt` command is not used there because it would
/// consume the following token as its radicand.
pub(super) fn to_latex_function(name: &str, args: &[Expression]) -> String {
    match (name, args) {
        ("sqrt", [radicand]) => return format!(r"\sqrt{{{}}}", radicand.to_latex()),
        ("sqrt", [index, radicand]) => {
            return format!(r"\sqrt[{}]{{{}}}", index.to_latex(), radicand.to_latex());
        }
        // log with base: first argument is the value, second is the base.
        ("log", [argument, base]) => {
            return format!(r"\log_{{{}}}{{{}}}", base.to_latex(), argument.to_latex());
        }
        // floor/ceil use their dedicated delimiters.
        ("floor", [inner]) => return format!(r"\lfloor {} \rfloor", inner.to_latex()),
        ("ceil", [inner]) => return format!(r"\lceil {} \rceil", inner.to_latex()),
        _ => {}
    }

    let func_prefix = if name != "sqrt" && KNOWN_FUNCTIONS.contains(&name) {
        format!(r"\{}", name)
    } else {
        format!(r"\operatorname{{{}}}", name)
    };

    if args.is_empty() {
        return func_prefix;
    }

    let args_str = args
        .iter()
        .map(|arg| arg.to_latex())
        .collect::<Vec<_>>()
        .join(", ");
    format!(r"{}\left({}\right)", func_prefix, args_str)
}

fn to_latex_derivative(expr: &Expression) -> String {
    match &expr.kind {
        ExprKind::Derivative { expr, var, order } => {
            if *order == 1 {
                format!(r"\frac{{d}}{{d{}}}{}", var, expr.to_latex())
            } else {
                format!(
                    r"\frac{{d^{{{}}}}}{{d{}^{{{}}}}}{}",
                    order,
                    var,
                    order,
                    expr.to_latex()
                )
            }
        }
        ExprKind::PartialDerivative { expr, var, order } => {
            if *order == 1 {
                format!(r"\frac{{\partial}}{{\partial {}}}{}", var, expr.to_latex())
            } else {
                format!(
                    r"\frac{{\partial^{{{}}}}}{{\partial {}^{{{}}}}}{}",
                    order,
                    var,
                    order,
                    expr.to_latex()
                )
            }
        }
        _ => unreachable!("to_latex_derivative called on non-derivative expression"),
    }
}

fn to_latex_multiple_integral(
    dimension: u8,
    integrand: &Expression,
    bounds: Option<&crate::ast::MultipleBounds>,
    vars: &[String],
) -> String {
    let int_cmd = match dimension {
        2 => r"\iint",
        3 => r"\iiint",
        4 => r"\iiiint",
        _ => r"\int\cdots\int",
    };
    let vars_str = vars
        .iter()
        .map(|v| format!("d{}", v))
        .collect::<Vec<_>>()
        .join(" \\, ");
    if let Some(b) = bounds {
        let bounds_latex: Vec<String> = b
            .bounds
            .iter()
            .map(|ib| format!("_{{{}}}^{{{}}}", ib.lower.to_latex(), ib.upper.to_latex()))
            .collect();
        format!(
            "{}{} {} \\, {}",
            int_cmd,
            bounds_latex.join(""),
            integrand.to_latex(),
            vars_str
        )
    } else {
        format!("{} {} \\, {}", int_cmd, integrand.to_latex(), vars_str)
    }
}

fn to_latex_integral(expr: &Expression) -> String {
    match &expr.kind {
        ExprKind::Integral {
            integrand,
            var,
            bounds,
        } => {
            if let Some(b) = bounds {
                format!(
                    r"\int_{{{}}}^{{{}}} {} d{}",
                    b.lower.to_latex(),
                    b.upper.to_latex(),
                    integrand.to_latex(),
                    var
                )
            } else {
                format!(r"\int {} d{}", integrand.to_latex(), var)
            }
        }
        ExprKind::MultipleIntegral {
            dimension,
            integrand,
            bounds,
            vars,
        } => to_latex_multiple_integral(*dimension, integrand, bounds.as_ref(), vars),
        ExprKind::ClosedIntegral {
            dimension,
            integrand,
            surface,
            var,
        } => {
            let int_cmd = match dimension {
                1 => r"\oint",
                2 => r"\oiint",
                3 => r"\oiiint",
                _ => r"\oint",
            };
            if let Some(s) = surface {
                format!(
                    "{}_{{{}}} {} \\, d{}",
                    int_cmd,
                    s,
                    integrand.to_latex(),
                    var
                )
            } else {
                format!("{} {} \\, d{}", int_cmd, integrand.to_latex(), var)
            }
        }
        _ => unreachable!("to_latex_integral called on non-integral expression"),
    }
}

pub(super) fn to_latex_calculus(expr: &Expression) -> String {
    match &expr.kind {
        ExprKind::Derivative { .. } | ExprKind::PartialDerivative { .. } => {
            to_latex_derivative(expr)
        }
        ExprKind::Integral { .. }
        | ExprKind::MultipleIntegral { .. }
        | ExprKind::ClosedIntegral { .. } => to_latex_integral(expr),
        ExprKind::Limit {
            expr,
            var,
            to,
            direction,
        } => format!(
            r"\lim_{{{} \to {}{}}}{}",
            var,
            to.to_latex(),
            direction.to_latex(),
            expr.to_latex()
        ),
        ExprKind::Sum {
            index,
            lower,
            upper,
            body,
        } => format!(
            r"\sum_{{{}={}}}^{{{}}}{}",
            index,
            lower.to_latex(),
            upper.to_latex(),
            body.to_latex()
        ),
        ExprKind::Product {
            index,
            lower,
            upper,
            body,
        } => format!(
            r"\prod_{{{}={}}}^{{{}}}{}",
            index,
            lower.to_latex(),
            upper.to_latex(),
            body.to_latex()
        ),
        _ => unreachable!("to_latex_calculus called on non-calculus expression"),
    }
}