use serde::{Deserialize, Serialize};
use std::fmt;
/// A mathematical expression tree paired with a confidence score.
///
/// `Eq` is intentionally not derived: `confidence` is an `f32`, which only
/// implements `PartialEq`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MathExpr {
    /// Root node of the expression tree.
    pub root: MathNode,
    /// Confidence score attached to the expression.
    ///
    /// NOTE(review): presumably a recognition confidence in `[0.0, 1.0]`;
    /// nothing in this type enforces a range — confirm against the producer.
    pub confidence: f32,
}
impl MathExpr {
pub fn new(root: MathNode, confidence: f32) -> Self {
Self { root, confidence }
}
pub fn accept<V: MathVisitor>(&self, visitor: &mut V) {
self.root.accept(visitor);
}
}
/// A node in a mathematical expression tree.
///
/// Child nodes are owned through `Box`/`Vec`, so a `MathNode` value owns its
/// entire subtree. Variant and field names are part of the derived serde
/// representation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MathNode {
    /// A symbol, stored as text.
    Symbol {
        /// Textual form of the symbol.
        value: String,
        /// Optional single-character form.
        /// NOTE(review): presumably the Unicode code point for the symbol — confirm.
        unicode: Option<char>,
    },
    /// A numeric literal, kept as text rather than parsed into a number type.
    Number {
        /// The literal's text.
        value: String,
        /// NOTE(review): presumably set when the literal has a fractional part — confirm.
        is_decimal: bool,
    },
    /// A binary operation applied to a left and a right operand.
    Binary {
        op: BinaryOp,
        left: Box<MathNode>,
        right: Box<MathNode>,
    },
    /// A unary operation applied to a single operand.
    Unary { op: UnaryOp, operand: Box<MathNode> },
    /// A fraction with a numerator and a denominator subtree.
    Fraction {
        numerator: Box<MathNode>,
        denominator: Box<MathNode>,
    },
    /// A radical (root) expression.
    Radical {
        /// Optional root index.
        /// NOTE(review): `None` presumably means a plain square root — confirm.
        index: Option<Box<MathNode>>,
        /// The expression under the radical sign.
        radicand: Box<MathNode>,
    },
    /// A base with an optional subscript and/or superscript.
    Script {
        base: Box<MathNode>,
        subscript: Option<Box<MathNode>>,
        superscript: Option<Box<MathNode>>,
    },
    /// A named function applied to a single argument subtree.
    Function {
        name: String,
        argument: Box<MathNode>,
    },
    /// A matrix stored as a list of rows, each row a list of cells.
    ///
    /// Rows are independent `Vec`s; nothing here requires them to have equal length.
    Matrix {
        rows: Vec<Vec<MathNode>>,
        bracket_type: BracketType,
    },
    /// A sub-expression enclosed in a pair of delimiters.
    Group {
        content: Box<MathNode>,
        bracket_type: BracketType,
    },
    /// A large operator with optional lower/upper limits and a body.
    LargeOp {
        op_type: LargeOpType,
        lower: Option<Box<MathNode>>,
        upper: Option<Box<MathNode>>,
        content: Box<MathNode>,
    },
    /// An ordered list of sibling nodes.
    Sequence { elements: Vec<MathNode> },
    /// Plain text content.
    Text { content: String },
    /// A node with no content.
    Empty,
}
impl MathNode {
    /// Walks this subtree in pre-order, calling `visitor.visit` on every node.
    ///
    /// The node itself is visited first, then its children in this order:
    ///
    /// * `Binary`: left, right
    /// * `Unary`: operand
    /// * `Fraction`: numerator, denominator
    /// * `Radical`: index (if present), radicand
    /// * `Script`: base, subscript (if present), superscript (if present)
    /// * `Function`: argument
    /// * `Matrix`: every cell, row by row
    /// * `Group`: content
    /// * `LargeOp`: lower (if present), upper (if present), content
    /// * `Sequence`: every element in order
    ///
    /// `Symbol`, `Number`, `Text` and `Empty` have no children.
    ///
    /// The walk is recursive, so the call depth grows with the depth of the tree.
    pub fn accept<V: MathVisitor>(&self, visitor: &mut V) {
        visitor.visit(self);
        match self {
            MathNode::Binary { left, right, .. } => {
                left.accept(visitor);
                right.accept(visitor);
            }
            MathNode::Unary { operand, .. } => operand.accept(visitor),
            MathNode::Fraction {
                numerator,
                denominator,
            } => {
                numerator.accept(visitor);
                denominator.accept(visitor);
            }
            MathNode::Radical { index, radicand } => {
                if let Some(root_index) = index {
                    root_index.accept(visitor);
                }
                radicand.accept(visitor);
            }
            MathNode::Script {
                base,
                subscript,
                superscript,
            } => {
                base.accept(visitor);
                // Subscript first, then superscript; absent ones are skipped.
                for script in subscript.iter().chain(superscript.iter()) {
                    script.accept(visitor);
                }
            }
            MathNode::Function { argument, .. } => argument.accept(visitor),
            MathNode::Matrix { rows, .. } => {
                for cell in rows.iter().flatten() {
                    cell.accept(visitor);
                }
            }
            MathNode::Group { content, .. } => content.accept(visitor),
            MathNode::LargeOp {
                lower,
                upper,
                content,
                ..
            } => {
                // Lower limit first, then upper; absent ones are skipped.
                for limit in lower.iter().chain(upper.iter()) {
                    limit.accept(visitor);
                }
                content.accept(visitor);
            }
            MathNode::Sequence { elements } => {
                for element in elements {
                    element.accept(visitor);
                }
            }
            // Leaves are listed explicitly instead of using `_` so that adding
            // a new variant fails to compile here until its traversal is
            // decided, rather than silently skipping its children.
            MathNode::Symbol { .. }
            | MathNode::Number { .. }
            | MathNode::Text { .. }
            | MathNode::Empty => {}
        }
    }
}
/// Binary operators and relations that can appear in a `MathNode::Binary` node.
///
/// The symbol noted on each variant is the one produced by this type's
/// `Display` implementation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum BinaryOp {
    /// Addition, displayed as `+`.
    Add,
    /// Subtraction, displayed as `-`.
    Subtract,
    /// Multiplication, displayed as `×`.
    Multiply,
    /// Division, displayed as `÷`.
    Divide,
    /// Exponentiation, displayed as `^`.
    Power,
    /// Equality relation, displayed as `=`.
    Equal,
    /// Inequality relation, displayed as `≠`.
    NotEqual,
    /// Strict less-than, displayed as `<`.
    Less,
    /// Strict greater-than, displayed as `>`.
    Greater,
    /// Less-than-or-equal, displayed as `≤`.
    LessEqual,
    /// Greater-than-or-equal, displayed as `≥`.
    GreaterEqual,
    /// Approximate equality, displayed as `≈`.
    ApproxEqual,
    /// Equivalence, displayed as `≡`.
    Equivalent,
    /// Similarity, displayed as `∼`.
    Similar,
    /// Congruence, displayed as `≅`.
    Congruent,
    /// Proportionality, displayed as `∝`.
    Proportional,
    /// Any other operator; the contained string is displayed verbatim.
    Custom(String),
}
impl BinaryOp {
pub fn precedence(&self) -> u8 {
match self {
BinaryOp::Power => 60,
BinaryOp::Multiply | BinaryOp::Divide => 50,
BinaryOp::Add | BinaryOp::Subtract => 40,
BinaryOp::Equal
| BinaryOp::NotEqual
| BinaryOp::Less
| BinaryOp::Greater
| BinaryOp::LessEqual
| BinaryOp::GreaterEqual
| BinaryOp::ApproxEqual
| BinaryOp::Equivalent
| BinaryOp::Similar
| BinaryOp::Congruent
| BinaryOp::Proportional => 30,
BinaryOp::Custom(_) => 35,
}
}
pub fn is_left_associative(&self) -> bool {
!matches!(self, BinaryOp::Power)
}
}
impl fmt::Display for BinaryOp {
    /// Writes the operator's symbol; `Custom` writes its string verbatim.
    ///
    /// The symbol is emitted through `Formatter::pad`, so the caller's width,
    /// fill, alignment and precision flags are applied (e.g. `{:>3}` pads on
    /// the left). With a plain `{}` the output is exactly the symbol.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let symbol = match self {
            BinaryOp::Add => "+",
            BinaryOp::Subtract => "-",
            BinaryOp::Multiply => "×",
            BinaryOp::Divide => "÷",
            BinaryOp::Power => "^",
            BinaryOp::Equal => "=",
            BinaryOp::NotEqual => "≠",
            BinaryOp::Less => "<",
            BinaryOp::Greater => ">",
            BinaryOp::LessEqual => "≤",
            BinaryOp::GreaterEqual => "≥",
            BinaryOp::ApproxEqual => "≈",
            BinaryOp::Equivalent => "≡",
            BinaryOp::Similar => "∼",
            BinaryOp::Congruent => "≅",
            BinaryOp::Proportional => "∝",
            BinaryOp::Custom(text) => text.as_str(),
        };
        f.pad(symbol)
    }
}
/// Unary operators that can appear in a `MathNode::Unary` node.
///
/// The symbol noted on each variant is the one produced by this type's
/// `Display` implementation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnaryOp {
    /// Unary plus, displayed as `+`.
    Plus,
    /// Unary minus, displayed as `-`.
    Minus,
    /// Logical negation, displayed as `¬`.
    Not,
    /// Any other operator; the contained string is displayed verbatim.
    Custom(String),
}
impl fmt::Display for UnaryOp {
    /// Writes the operator's symbol; `Custom` writes its string verbatim.
    ///
    /// The symbol is emitted through `Formatter::pad`, so the caller's width,
    /// fill, alignment and precision flags are applied. With a plain `{}` the
    /// output is exactly the symbol.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let symbol = match self {
            UnaryOp::Plus => "+",
            UnaryOp::Minus => "-",
            UnaryOp::Not => "¬",
            UnaryOp::Custom(text) => text.as_str(),
        };
        f.pad(symbol)
    }
}
/// Kinds of large operator that can head a `MathNode::LargeOp` node.
///
/// The symbol noted on each variant is the one produced by this type's
/// `Display` implementation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum LargeOpType {
    /// Summation, displayed as `∑`.
    Sum,
    /// Product, displayed as `∏`.
    Product,
    /// Integral, displayed as `∫`.
    Integral,
    /// Double integral, displayed as `∬`.
    DoubleIntegral,
    /// Triple integral, displayed as `∭`.
    TripleIntegral,
    /// Contour integral, displayed as `∮`.
    ContourIntegral,
    /// N-ary union, displayed as `⋃`.
    Union,
    /// N-ary intersection, displayed as `⋂`.
    Intersection,
    /// Coproduct, displayed as `∐`.
    Coproduct,
    /// Direct sum, displayed as `⊕`.
    DirectSum,
    /// Any other large operator; the contained string is displayed verbatim.
    Custom(String),
}
impl fmt::Display for LargeOpType {
    /// Writes the operator's symbol; `Custom` writes its string verbatim.
    ///
    /// The symbol is emitted through `Formatter::pad`, so the caller's width,
    /// fill, alignment and precision flags are applied. With a plain `{}` the
    /// output is exactly the symbol.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let symbol = match self {
            LargeOpType::Sum => "∑",
            LargeOpType::Product => "∏",
            LargeOpType::Integral => "∫",
            LargeOpType::DoubleIntegral => "∬",
            LargeOpType::TripleIntegral => "∭",
            LargeOpType::ContourIntegral => "∮",
            LargeOpType::Union => "⋃",
            LargeOpType::Intersection => "⋂",
            LargeOpType::Coproduct => "∐",
            LargeOpType::DirectSum => "⊕",
            LargeOpType::Custom(text) => text.as_str(),
        };
        f.pad(symbol)
    }
}
/// Delimiter pairs used by `MathNode::Group` and `MathNode::Matrix`.
///
/// The characters noted on each variant are the ones returned by
/// `BracketType::opening` and `BracketType::closing`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BracketType {
    /// Round brackets: `(` and `)`.
    Parentheses,
    /// Square brackets: `[` and `]`.
    Brackets,
    /// Curly braces: `{` and `}`.
    Braces,
    /// Angle brackets: `⟨` and `⟩`.
    AngleBrackets,
    /// Single vertical bars: `|` on both sides.
    Vertical,
    /// Double vertical bars: `‖` on both sides.
    DoubleVertical,
    /// Floor brackets: `⌊` and `⌋`.
    Floor,
    /// Ceiling brackets: `⌈` and `⌉`.
    Ceiling,
    /// No visible delimiter; both sides are the empty string.
    None,
}
impl BracketType {
    /// Returns the opening delimiter for this bracket type.
    ///
    /// The result is a string literal, so it is returned as `&'static str`
    /// and may outlive the `BracketType` value it was obtained from.
    /// `BracketType::None` yields the empty string.
    pub fn opening(&self) -> &'static str {
        match self {
            BracketType::Parentheses => "(",
            BracketType::Brackets => "[",
            BracketType::Braces => "{",
            BracketType::AngleBrackets => "⟨",
            BracketType::Vertical => "|",
            BracketType::DoubleVertical => "‖",
            BracketType::Floor => "⌊",
            BracketType::Ceiling => "⌈",
            BracketType::None => "",
        }
    }

    /// Returns the closing delimiter for this bracket type.
    ///
    /// The result is a string literal, so it is returned as `&'static str`
    /// and may outlive the `BracketType` value it was obtained from.
    /// `BracketType::None` yields the empty string.
    pub fn closing(&self) -> &'static str {
        match self {
            BracketType::Parentheses => ")",
            BracketType::Brackets => "]",
            BracketType::Braces => "}",
            BracketType::AngleBrackets => "⟩",
            BracketType::Vertical => "|",
            BracketType::DoubleVertical => "‖",
            BracketType::Floor => "⌋",
            BracketType::Ceiling => "⌉",
            BracketType::None => "",
        }
    }
}
/// Callback interface driven by `MathNode::accept` and `MathExpr::accept`.
///
/// The traversal itself lives in `MathNode::accept`; an implementor only
/// reacts to each node and, since `visit` returns nothing, cannot skip
/// subtrees or stop the walk early.
pub trait MathVisitor {
    /// Called once for each node reached during a traversal, with a parent
    /// always delivered before its children.
    fn visit(&mut self, node: &MathNode);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-decimal `Number` leaf from its digit string.
    fn integer(digits: &str) -> MathNode {
        MathNode::Number {
            value: digits.to_string(),
            is_decimal: false,
        }
    }

    #[test]
    fn test_binary_op_precedence() {
        // Each tier must bind strictly tighter than the one after it.
        let power = BinaryOp::Power.precedence();
        let multiply = BinaryOp::Multiply.precedence();
        let add = BinaryOp::Add.precedence();
        let equal = BinaryOp::Equal.precedence();

        assert!(power > multiply);
        assert!(multiply > add);
        assert!(add > equal);
    }

    #[test]
    fn test_binary_op_associativity() {
        for op in &[BinaryOp::Add, BinaryOp::Multiply] {
            assert!(op.is_left_associative());
        }
        assert!(!BinaryOp::Power.is_left_associative());
    }

    #[test]
    fn test_bracket_delimiters() {
        let round = BracketType::Parentheses;
        assert_eq!(round.opening(), "(");
        assert_eq!(round.closing(), ")");

        assert_eq!(BracketType::Brackets.opening(), "[");
        assert_eq!(BracketType::Braces.closing(), "}");
    }

    #[test]
    fn test_math_expr_creation() {
        let expr = MathExpr::new(integer("42"), 0.95);
        assert_eq!(expr.confidence, 0.95);
    }

    #[test]
    fn test_visitor_pattern() {
        /// Counts how many nodes the traversal hands to the visitor.
        struct NodeCounter {
            seen: usize,
        }

        impl MathVisitor for NodeCounter {
            fn visit(&mut self, _: &MathNode) {
                self.seen += 1;
            }
        }

        let sum = MathNode::Binary {
            op: BinaryOp::Add,
            left: Box::new(integer("1")),
            right: Box::new(integer("2")),
        };
        let expr = MathExpr::new(sum, 1.0);

        let mut counter = NodeCounter { seen: 0 };
        expr.accept(&mut counter);

        // One `Binary` node plus its two `Number` operands.
        assert_eq!(counter.seen, 3);
    }
}