use crate::core::{Expression, Number, MathConstant, BinaryOperator, UnaryOperator};
use super::{ParseError, Parser, lexer::{Lexer, Token}};
use num_bigint::BigInt;
use num_rational::BigRational;
use bigdecimal::BigDecimal;
use std::str::FromStr;
/// Recursive-descent parser that consumes tokens from a [`Lexer`] and builds an
/// [`Expression`], keeping exactly one token of lookahead.
pub struct SyntaxParser {
// Token source; asked for the next token every time the parser advances.
lexer: Lexer,
// Lookahead: the most recently read token that has not been consumed yet.
current_token: Token,
}
impl SyntaxParser {
/// Builds a parser over `input` and primes the lookahead with the first token.
///
/// # Errors
///
/// Propagates the error produced by the lexer while reading the first token.
pub fn new(input: String) -> Result<Self, ParseError> {
    let mut lexer = Lexer::new(input);
    let first_token = lexer.next_token()?;
    let parser = Self {
        lexer,
        current_token: first_token,
    };
    Ok(parser)
}
/// Parses the whole input as a single expression.
///
/// # Errors
///
/// * `ParseError::EmptyExpression` when the input contains no tokens at all.
/// * A syntax error when tokens remain after a complete expression was read.
/// * Any error raised by the sub-parsers or the lexer.
pub fn parse(&mut self) -> Result<Expression, ParseError> {
    if let Token::EndOfInput = self.current_token {
        return Err(ParseError::EmptyExpression);
    }

    let expr = self.parse_expression()?;

    // A complete expression must be followed by the end of the input.
    match self.current_token {
        Token::EndOfInput => Ok(expr),
        _ => Err(ParseError::syntax(
            self.lexer.position(),
            format!("意外的标记: {:?}", self.current_token),
        )),
    }
}
/// Entry point of the expression grammar: delegates to the logical-OR level,
/// which is the first (outermost) level of the operator chain.
fn parse_expression(&mut self) -> Result<Expression, ParseError> {
    self.parse_logical_or()
}
/// Parses a left-associative chain of `||` operators whose operands are
/// logical-AND expressions.
fn parse_logical_or(&mut self) -> Result<Expression, ParseError> {
    let mut acc = self.parse_logical_and()?;
    loop {
        let at_or = matches!(&self.current_token, Token::Operator(sym) if sym.as_str() == "||");
        if !at_or {
            return Ok(acc);
        }
        self.advance()?;
        let rhs = self.parse_logical_and()?;
        acc = Expression::binary_op(BinaryOperator::Or, acc, rhs);
    }
}
/// Parses a left-associative chain of `&&` operators whose operands are
/// equality expressions.
fn parse_logical_and(&mut self) -> Result<Expression, ParseError> {
    let mut acc = self.parse_equality()?;
    while matches!(&self.current_token, Token::Operator(sym) if sym.as_str() == "&&") {
        self.advance()?;
        let rhs = self.parse_equality()?;
        acc = Expression::binary_op(BinaryOperator::And, acc, rhs);
    }
    Ok(acc)
}
/// Parses a left-associative chain of `==` / `!=` operators whose operands
/// are comparison expressions.
fn parse_equality(&mut self) -> Result<Expression, ParseError> {
    let mut acc = self.parse_comparison()?;
    loop {
        let operator = match &self.current_token {
            Token::Operator(sym) if sym.as_str() == "==" => BinaryOperator::Equal,
            Token::Operator(sym) if sym.as_str() == "!=" => BinaryOperator::NotEqual,
            _ => break,
        };
        self.advance()?;
        let rhs = self.parse_comparison()?;
        acc = Expression::binary_op(operator, acc, rhs);
    }
    Ok(acc)
}
/// Parses a left-associative chain of `<`, `<=`, `>` and `>=` operators whose
/// operands are additive terms.
fn parse_comparison(&mut self) -> Result<Expression, ParseError> {
    let mut acc = self.parse_term()?;
    loop {
        let operator = match &self.current_token {
            Token::Operator(symbol) => match symbol.as_str() {
                "<" => BinaryOperator::Less,
                "<=" => BinaryOperator::LessEqual,
                ">" => BinaryOperator::Greater,
                ">=" => BinaryOperator::GreaterEqual,
                _ => break,
            },
            _ => break,
        };
        self.advance()?;
        let rhs = self.parse_term()?;
        acc = Expression::binary_op(operator, acc, rhs);
    }
    Ok(acc)
}
/// Parses a left-associative chain of binary `+` / `-` operators whose
/// operands are multiplicative factors.
fn parse_term(&mut self) -> Result<Expression, ParseError> {
    let mut sum = self.parse_factor()?;
    loop {
        let operator = match &self.current_token {
            Token::Operator(symbol) => match symbol.as_str() {
                "+" => BinaryOperator::Add,
                "-" => BinaryOperator::Subtract,
                _ => break,
            },
            _ => break,
        };
        self.advance()?;
        let rhs = self.parse_factor()?;
        sum = Expression::binary_op(operator, sum, rhs);
    }
    Ok(sum)
}
/// Parses a left-associative chain of `*`, `/` and `%` operators whose
/// operands are power expressions.
fn parse_factor(&mut self) -> Result<Expression, ParseError> {
    let mut product = self.parse_power()?;
    loop {
        // Anything that is not an operator token ends the chain.
        let symbol = match &self.current_token {
            Token::Operator(text) => text.as_str(),
            _ => break,
        };
        let operator = match symbol {
            "*" => BinaryOperator::Multiply,
            "/" => BinaryOperator::Divide,
            "%" => BinaryOperator::Modulo,
            _ => break,
        };
        self.advance()?;
        let rhs = self.parse_power()?;
        product = Expression::binary_op(operator, product, rhs);
    }
    Ok(product)
}
/// Parses a unary-level operand optionally followed by `^` or `**` and an
/// exponent.
///
/// The exponent is parsed by recursing into this same function, which makes
/// the operator right-associative: `a ^ b ^ c` becomes `a ^ (b ^ c)`.
fn parse_power(&mut self) -> Result<Expression, ParseError> {
    let base = self.parse_unary()?;

    let at_power_operator = match &self.current_token {
        Token::Operator(symbol) => matches!(symbol.as_str(), "^" | "**"),
        _ => false,
    };
    if !at_power_operator {
        return Ok(base);
    }

    self.advance()?;
    let exponent = self.parse_power()?;
    Ok(Expression::binary_op(BinaryOperator::Power, base, exponent))
}
/// Parses an optional prefix operator (`-`, `+` or `!`) followed by its
/// operand; without a prefix operator this falls through to a primary
/// expression.
///
/// The operand of a sign operator (`-` / `+`) is parsed at power level, so a
/// sign binds less tightly than exponentiation on its right: `-2^2` is
/// `-(2^2)`, matching the usual mathematical convention, while `2^-3` is
/// still `2^(-3)`. Logical `!` keeps binding directly to the following unary
/// expression.
fn parse_unary(&mut self) -> Result<Expression, ParseError> {
    // Decide which prefix operator (if any) is at the cursor, and whether its
    // operand may contain a power operator.
    let prefix = match &self.current_token {
        Token::Operator(op) => match op.as_str() {
            "-" => Some((UnaryOperator::Negate, true)),
            "+" => Some((UnaryOperator::Plus, true)),
            "!" => Some((UnaryOperator::Not, false)),
            _ => None,
        },
        _ => None,
    };

    let (unary_op, operand_includes_power) = match prefix {
        Some(found) => found,
        None => return self.parse_primary(),
    };

    self.advance()?;
    let operand = if operand_includes_power {
        // `parse_power` starts with `parse_unary`, so stacked prefixes such
        // as `--x` are still accepted.
        self.parse_power()?
    } else {
        self.parse_unary()?
    };
    Ok(Expression::unary_op(unary_op, operand))
}
/// Parses the innermost grammar elements: a number literal, an identifier
/// (function call, known constant or variable), a parenthesised expression,
/// or a bracketed vector/matrix literal.
///
/// # Errors
///
/// * `ParseError::UnexpectedEndOfInput` when the input ends here.
/// * An unmatched-parenthesis error when `(` is not closed by `)`.
/// * A syntax error for any other token.
fn parse_primary(&mut self) -> Result<Expression, ParseError> {
    // Work on an owned copy so the parser can advance while the token's
    // payload is still in use.
    let token = self.current_token.clone();
    match token {
        Token::Number(literal) => {
            self.advance()?;
            self.parse_number(&literal)
        }
        Token::Identifier(name) => {
            self.advance()?;
            // An identifier directly followed by '(' is a function call.
            if let Token::LeftParen = self.current_token {
                return self.parse_function_call(name);
            }
            let known_constant = MathConstant::from_str(&name);
            match known_constant {
                Some(constant) => Ok(Expression::constant(constant)),
                None => Ok(Expression::variable(name)),
            }
        }
        Token::LeftParen => {
            self.advance()?;
            let inner = self.parse_expression()?;
            if let Token::RightParen = self.current_token {
                self.advance()?;
                Ok(inner)
            } else {
                Err(ParseError::unmatched_parenthesis(self.lexer.position()))
            }
        }
        Token::LeftBracket => self.parse_matrix_or_vector(),
        Token::EndOfInput => Err(ParseError::UnexpectedEndOfInput),
        _ => Err(ParseError::syntax(
            self.lexer.position(),
            format!("意外的标记: {:?}", self.current_token),
        )),
    }
}
fn parse_number(&self, num_str: &str) -> Result<Expression, ParseError> {
if !num_str.contains('.') && !num_str.contains('e') && !num_str.contains('E') {
if let Ok(int_val) = BigInt::from_str(num_str) {
return Ok(Expression::number(Number::Integer(int_val)));
}
}
if num_str.contains('.') && !num_str.contains('e') && !num_str.contains('E') {
if let Ok(decimal) = BigDecimal::from_str(num_str) {
if let Some(rational) = self.decimal_to_rational(&decimal) {
return Ok(Expression::number(Number::Rational(rational)));
} else {
return Ok(Expression::number(Number::Real(decimal)));
}
}
}
if let Ok(decimal) = BigDecimal::from_str(num_str) {
Ok(Expression::number(Number::Real(decimal)))
} else {
Err(ParseError::invalid_number(num_str.to_string()))
}
}
/// Converts a decimal into an exact rational using its textual form.
///
/// The text is split at the decimal point; with `k` fractional digits the
/// result is `(integer * 10^k + fraction) / 10^k`, with the sign applied to
/// the whole value.
///
/// Returns `None` when the text has no decimal point or when either side of
/// the point is not a plain integer (for example exponent notation).
fn decimal_to_rational(&self, decimal: &BigDecimal) -> Option<BigRational> {
    let text = decimal.to_string();

    // Handle the sign separately: it applies to the whole number, not only to
    // the integer part. Leaving it on the integer part would turn "-1.5" into
    // (-10 + 5) / 10 and "-0.5" into +5 / 10.
    let negative = text.starts_with('-');
    let magnitude = if negative { &text[1..] } else { text.as_str() };

    let dot_pos = magnitude.find('.')?;
    let integer_digits = &magnitude[..dot_pos];
    let fraction_digits = &magnitude[dot_pos + 1..];

    let int_part = BigInt::from_str(integer_digits).ok()?;
    let frac_part = BigInt::from_str(fraction_digits).ok()?;

    // The digit count of a number literal is far below u32::MAX in practice.
    let denominator = BigInt::from(10).pow(fraction_digits.len() as u32);
    let unsigned_numerator = int_part * &denominator + frac_part;
    let numerator = if negative {
        -unsigned_numerator
    } else {
        unsigned_numerator
    };
    Some(BigRational::new(numerator, denominator))
}
/// Parses the parenthesised, comma-separated argument list of a call to the
/// function `name`; the lookahead must be the opening `(`.
///
/// An empty argument list `()` is accepted.
///
/// # Errors
///
/// Returns a syntax error when the list does not start with `(` or when an
/// argument is followed by something other than `,` or `)`.
fn parse_function_call(&mut self, name: String) -> Result<Expression, ParseError> {
    if let Token::LeftParen = self.current_token {
        self.advance()?;
    } else {
        return Err(ParseError::syntax(
            self.lexer.position(),
            "期望 '(' 开始函数参数列表".to_string(),
        ));
    }

    let mut arguments = Vec::new();
    if !matches!(self.current_token, Token::RightParen) {
        loop {
            arguments.push(self.parse_expression()?);
            let more_arguments = match self.current_token {
                Token::Comma => true,
                Token::RightParen => false,
                _ => {
                    return Err(ParseError::syntax(
                        self.lexer.position(),
                        "期望 ',' 或 ')' 在函数参数列表中".to_string(),
                    ));
                }
            };
            if !more_arguments {
                break;
            }
            // Step over the ',' separating two arguments.
            self.advance()?;
        }
    }

    // The lookahead is the closing ')' on every path that reaches this point.
    self.advance()?;
    Ok(Expression::function(name, arguments))
}
/// Parses a bracketed literal; the lookahead must be the opening `[`.
///
/// * `[]` yields an empty vector.
/// * If the first element starts with `[`, the literal is a matrix whose
///   rows are each parsed by `parse_matrix_row`.
/// * Otherwise the literal is a vector of expressions.
///
/// In both forms a trailing comma before the closing `]` is accepted.
///
/// # Errors
///
/// Returns a syntax error when the closing `]` is missing or when
/// `Expression::matrix` rejects the collected rows.
fn parse_matrix_or_vector(&mut self) -> Result<Expression, ParseError> {
    // Step over the opening '['.
    self.advance()?;

    if let Token::RightBracket = self.current_token {
        self.advance()?;
        return Ok(Expression::Vector(Vec::new()));
    }

    let is_matrix = matches!(self.current_token, Token::LeftBracket);
    if !is_matrix {
        let mut items = vec![self.parse_expression()?];
        loop {
            if !matches!(self.current_token, Token::Comma) {
                break;
            }
            self.advance()?;
            // Trailing comma: the list ends right after it.
            if matches!(self.current_token, Token::RightBracket) {
                break;
            }
            items.push(self.parse_expression()?);
        }
        return match self.current_token {
            Token::RightBracket => {
                self.advance()?;
                Ok(Expression::Vector(items))
            }
            _ => Err(ParseError::syntax(
                self.lexer.position(),
                "期望 ']' 结束向量".to_string(),
            )),
        };
    }

    let mut rows = vec![self.parse_matrix_row()?];
    loop {
        if !matches!(self.current_token, Token::Comma) {
            break;
        }
        self.advance()?;
        // Trailing comma: the list ends right after it.
        if matches!(self.current_token, Token::RightBracket) {
            break;
        }
        rows.push(self.parse_matrix_row()?);
    }
    if let Token::RightBracket = self.current_token {
        self.advance()?;
    } else {
        return Err(ParseError::syntax(
            self.lexer.position(),
            "期望 ']' 结束矩阵".to_string(),
        ));
    }

    match Expression::matrix(rows) {
        Ok(matrix) => Ok(matrix),
        Err(reason) => Err(ParseError::syntax(self.lexer.position(), reason)),
    }
}
/// Parses one bracketed matrix row `[e1, e2, ...]` and returns its elements.
///
/// An empty row `[]` is accepted and yields an empty list.
///
/// # Errors
///
/// Returns a syntax error when the row does not start with `[` or when an
/// element is followed by something other than `,` or `]`.
fn parse_matrix_row(&mut self) -> Result<Vec<Expression>, ParseError> {
    if let Token::LeftBracket = self.current_token {
        self.advance()?;
    } else {
        return Err(ParseError::syntax(
            self.lexer.position(),
            "期望 '[' 开始矩阵行".to_string(),
        ));
    }

    let mut row = Vec::new();
    if !matches!(self.current_token, Token::RightBracket) {
        loop {
            row.push(self.parse_expression()?);
            let more_elements = match self.current_token {
                Token::Comma => true,
                Token::RightBracket => false,
                _ => {
                    return Err(ParseError::syntax(
                        self.lexer.position(),
                        "期望 ',' 或 ']' 在矩阵行中".to_string(),
                    ));
                }
            };
            if !more_elements {
                break;
            }
            // Step over the ',' separating two elements.
            self.advance()?;
        }
    }

    // The lookahead is the closing ']' on every path that reaches this point.
    self.advance()?;
    Ok(row)
}
/// Replaces the lookahead with the next token from the lexer.
///
/// On a lexer error the current lookahead is left untouched and the error is
/// propagated.
fn advance(&mut self) -> Result<(), ParseError> {
    let next = self.lexer.next_token()?;
    self.current_token = next;
    Ok(())
}
}
/// Stateless front end for expression parsing; it carries no data.
pub struct ExpressionParser;

impl ExpressionParser {
    /// Creates a new expression parser.
    pub fn new() -> Self {
        ExpressionParser
    }
}

impl Default for ExpressionParser {
    /// Equivalent to [`ExpressionParser::new`].
    fn default() -> Self {
        ExpressionParser::new()
    }
}
impl Parser for ExpressionParser {
    /// Parses `input` into an [`Expression`] using a fresh [`SyntaxParser`].
    ///
    /// # Errors
    ///
    /// Returns the error from creating the syntax parser or from parsing.
    fn parse(&self, input: &str) -> Result<Expression, ParseError> {
        SyntaxParser::new(input.to_owned()).and_then(|mut syntax_parser| syntax_parser.parse())
    }

    /// Checks that `input` parses successfully, discarding the resulting
    /// expression.
    ///
    /// # Errors
    ///
    /// Returns the same error [`Parser::parse`] would return for `input`.
    fn validate(&self, input: &str) -> Result<(), ParseError> {
        self.parse(input).map(|_expr| ())
    }
}