use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use crate::input_reader::InputReader;
use crate::interpret::{interpret, interpret_with_definitions};
pub(crate) mod lex;
pub(crate) mod input_reader;
pub(crate) mod postfix;
pub(crate) mod interpret;
pub(crate) mod operator;
/// Every failure the expression evaluator can report.
///
/// Each variant carries just enough context to render a readable message via
/// its `Display` implementation.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Error {
    /// A division whose divisor is zero.
    DivByZero,
    /// An attempt to raise a value to a negative power.
    NegativeExponent,
    /// A character that is not valid in the input.
    InvalidCharacter {
        /// The offending character.
        c: char,
    },
    /// Text that could not be read as a number.
    InvalidNumber {
        /// The text that failed to parse.
        found: String,
    },
    /// Something specific was expected but something else was encountered.
    Expected {
        /// What the evaluator was looking for.
        expected: String,
        /// What it actually saw.
        found: String,
    },
    /// The input ended before the expression was complete.
    UnexpectedEOI,
    /// An operand that is not valid in its position.
    InvalidOperand {
        /// Textual form of the operand.
        op: String,
    },
    /// An operator that is not valid in its position.
    InvalidOperator {
        /// Textual form of the operator.
        op: String,
    },
    /// The expression is malformed.
    InvalidExpression {
        /// Human-readable explanation of what is wrong.
        reason: String,
    },
    /// A variable name with no registered value.
    UndefinedVariable {
        /// The unresolved name.
        name: String,
    },
    /// A function name with no registered implementation.
    UndefinedFunction {
        /// The unresolved name.
        name: String,
    },
    /// A function was called with the wrong number of arguments.
    InvalidArgumentCount {
        /// The function's name.
        name: String,
        /// How many arguments the function requires.
        expected: usize,
        /// How many arguments were supplied.
        got: usize,
    },
    /// A function received an argument value it cannot accept.
    InvalidArgument {
        /// The function's name.
        name: String,
        /// Textual form of the rejected value.
        value: String,
    },
    /// An operator that is not allowed in leading position.
    InvalidLeadingOperator {
        /// Textual form of the operator.
        op: String,
    },
    /// An operator was required but missing.
    MissingOperator,
    /// A bracket without its matching counterpart.
    MismatchedParentheses {
        /// The bracket that was present.
        found: char,
        /// The counterpart that could not be found.
        missing: char,
    },
    /// Free-form error; the message is displayed verbatim.
    Other(String),
}
impl Error {
pub fn arg_count<S: Into<String>>(name: S, expected: usize, got: usize) -> Error {
Error::InvalidArgumentCount {
name: name.into(),
expected,
got
}
}
}
impl Display for Error {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Error::DivByZero => write!(f, "Can't divide by zero"),
Error::NegativeExponent => write!(f, "Can't raise a value to a negative power"),
Error::InvalidCharacter { c } => write!(f, "Invalid character: {}", c),
Error::InvalidNumber { found } => write!(f, "Invalid number: {}", found),
Error::Expected { expected, found } => write!(f, "Expected '{}', found '{}'", expected, found),
Error::UnexpectedEOI => write!(f, "Unexpected end of input"),
Error::InvalidOperand { op } => write!(f, "Invalid operand: {}", op),
Error::InvalidOperator { op } => write!(f, "Invalid operator: {}", op),
Error::InvalidExpression { reason } => write!(f, "Invalid expression: {}", reason),
Error::UndefinedVariable { name } => write!(f, "Undefined variable: {}", name),
Error::UndefinedFunction { name } => write!(f, "Undefined function: {}", name),
Error::InvalidArgumentCount { name, expected, got } => write!(f, "Invalid argument count for function '{}': expected {}, got {}", name, expected, got),
Error::InvalidArgument { name, value } => write!(f, "Invalid argument for function '{}': {}", name, value),
Error::InvalidLeadingOperator { op } => write!(f, "Invalid leading operator: {}", op),
Error::MissingOperator => write!(f, "Missing operator"),
Error::MismatchedParentheses { found, missing } => write!(f, "Mismatched parentheses: found '{}', missing '{}'", found, missing),
Error::Other(s) => write!(f, "{}", s),
}
}
}
/// A table of named numeric values (variables) available to expressions.
///
/// Names map to `f64` values. Registering an existing name overwrites its
/// previous value.
#[derive(Debug, Clone, Default)]
pub struct Definitions {
    pub(crate) map: HashMap<String, f64>,
}

impl Definitions {
    /// Creates an empty table; equivalent to [`Definitions::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Binds `name` to `value`, replacing any value previously bound to `name`.
    ///
    /// Any type convertible into `f64` (e.g. `i32`, `f32`, `u32`) is accepted.
    pub fn register<S: Into<String>, N: Into<f64>>(&mut self, name: S, value: N) {
        self.map.insert(name.into(), value.into());
    }

    /// Returns `true` if a value has been registered under `ident`.
    pub fn exists<S: Into<String>>(&self, ident: S) -> bool {
        self.map.contains_key(ident.into().as_str())
    }

    /// Returns the value registered under `ident`, or `None` if there is none.
    pub(crate) fn get<S: Into<String>>(&self, ident: S) -> Option<&f64> {
        self.map.get(ident.into().as_str())
    }
}
pub struct Functions<'a> {
pub(crate) functions: HashMap<String, Box<dyn Fn(Vec<f64>) -> Result<f64, Error> + 'a>>,
}
impl<'a> Functions<'a> {
pub fn new() -> Self {
Self {
functions: HashMap::new(),
}
}
pub fn register<S: Into<String>, F: Fn(Vec<f64>) -> Result<f64, Error> + 'a + Copy>(&mut self, name: S, f: F) {
self.functions.insert(name.into(), Box::new(f));
}
pub fn exists<S: Into<String>>(&self, ident: S) -> bool {
self.functions.contains_key(ident.into().as_str())
}
pub(crate) fn get<S: Into<String>>(&self, ident: S) -> Option<&Box<dyn Fn(Vec<f64>) -> Result<f64, Error> + 'a>> {
let ident = ident.into();
if !self.functions.contains_key(&ident) {
return None;
}
self.functions.get(&ident)
}
}
impl Default for Functions<'_> {
fn default() -> Self {
let mut funcs = Functions::new();
funcs.register("log", |args| {
if args.len() != 2 {
return Err(Error::arg_count("log", 2, args.len()));
}
Ok(args[1].log(args[0]))
});
funcs.register("sqrt", |args| {
if args.len() != 1 {
return Err(Error::arg_count("sqrt", 1, args.len()));
}
Ok(args[0].sqrt())
});
funcs.register("sin", |args| {
if args.len() != 1 {
return Err(Error::arg_count("sin", 1, args.len()));
}
Ok(args[0].sin())
});
funcs.register("cos", |args| {
if args.len() != 1 {
return Err(Error::arg_count("cos", 1, args.len()));
}
Ok(args[0].cos())
});
funcs.register("tan", |args| {
if args.len() != 1 {
return Err(Error::arg_count("tan", 1, args.len()));
}
Ok(args[0].tan())
});
funcs
}
}
/// Evaluates `input` as an expression without any user-supplied variables or
/// functions.
///
/// The input is lexed, converted to postfix order with the shunting-yard
/// algorithm, and then interpreted.
///
/// # Errors
///
/// Returns the [`Error`] reported by whichever stage (lexing, postfix
/// conversion, interpretation) fails first.
pub fn evaluate<S: Into<String>>(input: S) -> Result<f64, Error> {
    let mut reader = InputReader::new(input.into());
    // NOTE(review): `false` presumably disables identifier handling in the lexer;
    // `evaluate_with_defined` passes `true` only when names are supplied.
    let mut infix = lex::lex(&mut reader, false)?;
    let mut rpn = postfix::shunting_yard(&mut infix)?;
    interpret(&mut rpn)
}
pub fn evaluate_with_defined<S: Into<String>>(input: S, definitions: Option<&Definitions>, functions: Option<&Functions>) -> Result<f64, Error> {
let mut input = InputReader::new(input.into());
let mut tokens = lex::lex(&mut input, definitions.is_some() || functions.is_some())?;
let mut shunted = postfix::shunting_yard(&mut tokens)?;
interpret_with_definitions(&mut shunted, definitions, functions)
}
#[cfg(test)]
mod test {
    use super::*;

    /// Tolerance for comparing floating-point results.
    const EPSILON: f64 = 1e-9;

    /// Precedence and left-associativity of `+ - * /` with parentheses.
    #[test]
    fn test1() {
        let expression = "(2 + 1) - 50 * 12 / 18 - (3 + 1) * 5";
        let value = evaluate(expression)
            .unwrap_or_else(|e| panic!("Encountered an error evaluating: {}", e));
        // (2 + 1) - (50 * 12) / 18 - (3 + 1) * 5 = 3 - 600/18 - 20 = -151/3
        let expected = -151.0 / 3.0;
        assert!(
            (value - expected).abs() < EPSILON,
            "expected {}, got {}",
            expected,
            value
        );
    }

    #[test]
    fn arg_count_builds_displayable_error() {
        let err = Error::arg_count("sqrt", 1, 2);
        assert_eq!(
            err,
            Error::InvalidArgumentCount {
                name: "sqrt".to_string(),
                expected: 1,
                got: 2
            }
        );
        assert_eq!(
            err.to_string(),
            "Invalid argument count for function 'sqrt': expected 1, got 2"
        );
    }

    #[test]
    fn definitions_register_and_lookup() {
        let mut defs = Definitions::new();
        assert!(!defs.exists("x"));
        defs.register("x", 2.5f64);
        assert!(defs.exists("x"));
        assert_eq!(defs.get("x"), Some(&2.5));
        assert_eq!(defs.get("y"), None);
    }

    #[test]
    fn default_functions_are_registered_and_check_arity() {
        let funcs = Functions::default();
        for name in &["log", "sqrt", "sin", "cos", "tan"] {
            assert!(funcs.exists(*name), "missing built-in {}", name);
        }
        assert!(!funcs.exists("nope"));

        let sqrt = funcs.get("sqrt").expect("sqrt is a built-in");
        assert_eq!(sqrt(vec![9.0]), Ok(3.0));
        assert_eq!(sqrt(vec![]), Err(Error::arg_count("sqrt", 1, 0)));

        // The first argument of `log` is the base.
        let log = funcs.get("log").expect("log is a built-in");
        let value = log(vec![2.0, 8.0]).expect("two arguments are valid");
        assert!((value - 3.0).abs() < EPSILON);
        assert_eq!(log(vec![8.0]), Err(Error::arg_count("log", 2, 1)));
    }
}