use std::collections::HashMap;
use mathlex::{BinaryOp, Expression, MathConstant};
use num_complex::Complex;
use mathlex_eval::{EvalInput, NumericResult, compile, eval};
fn main() {
println!("=== sqrt(-1) ===");
let ast = Expression::Function {
name: "sqrt".into(),
args: vec![Expression::Unary {
op: mathlex::UnaryOp::Neg,
operand: Box::new(Expression::Integer(1)),
}],
};
let compiled = compile(&ast, &HashMap::new()).expect("compile failed");
let result = eval(&compiled, HashMap::new())
.expect("eval failed")
.scalar()
.expect("scalar failed");
println!("sqrt(-1) = {:?}", result);
println!("\n=== 1 + 2i ===");
let ast = Expression::Binary {
op: BinaryOp::Add,
left: Box::new(Expression::Integer(1)),
right: Box::new(Expression::Binary {
op: BinaryOp::Mul,
left: Box::new(Expression::Integer(2)),
right: Box::new(Expression::Constant(MathConstant::I)),
}),
};
let compiled = compile(&ast, &HashMap::new()).expect("compile failed");
println!("is_complex: {}", compiled.is_complex());
let result = eval(&compiled, HashMap::new())
.expect("eval failed")
.scalar()
.expect("scalar failed");
println!("1 + 2i = {:?}", result);
println!("\n=== x^2 with x = 1+i ===");
let ast = Expression::Binary {
op: BinaryOp::Pow,
left: Box::new(Expression::Variable("x".into())),
right: Box::new(Expression::Integer(2)),
};
let compiled = compile(&ast, &HashMap::new()).expect("compile failed");
let mut args = HashMap::new();
args.insert("x", EvalInput::Complex(Complex::new(1.0, 1.0)));
let result = eval(&compiled, args)
.expect("eval failed")
.scalar()
.expect("scalar failed");
println!("(1+i)^2 = {:?}", result);
println!("\n=== ln(-1) ===");
let ast = Expression::Function {
name: "ln".into(),
args: vec![Expression::Unary {
op: mathlex::UnaryOp::Neg,
operand: Box::new(Expression::Integer(1)),
}],
};
let compiled = compile(&ast, &HashMap::new()).expect("compile failed");
let result = eval(&compiled, HashMap::new())
.expect("eval failed")
.scalar()
.expect("scalar failed");
println!("ln(-1) = {:?}", result);
if let NumericResult::Complex(c) = result {
println!(
" (re ≈ 0: {}, im ≈ π: {})",
c.re.abs() < 1e-10,
(c.im - std::f64::consts::PI).abs() < 1e-10
);
}
}