use std::collections::HashMap;
use std::f64::consts;
use std::rc::Rc;
use std::fmt;
use crate::term::Term;
use crate::func::Func;
use crate::num::Num;
/// Evaluation environment for expressions over the number type `N`: named variables,
/// named functions, and a [`Config`].
///
/// Cloning copies the variable map and the function map; the function objects themselves
/// are shared through `Rc` rather than duplicated.
#[derive(Clone)]
pub struct Context<N: Num> {
/// Variables available to expressions, keyed by name.
pub vars: HashMap<String, Term<N>>,
/// Functions available to expressions, keyed by the name they are called with.
pub funcs: HashMap<String, Rc<dyn Func<N>>>,
/// Options for this context (see [`Config`]).
pub cfg: Config,
}
/// Options attached to a [`Context`].
#[derive(Debug, Clone)]
pub struct Config {
/// Presumably whether juxtaposition such as `2x` is read as multiplication. TODO: confirm in the parser.
pub implicit_multiplication: bool,
/// Precision setting; `Config::new` uses 53, the number of significand bits of an `f64`.
/// Presumably consumed by arbitrary-precision `Num` implementations. TODO: confirm.
pub precision: u32,
/// Presumably makes `sqrt` yield both the positive and the negative root. TODO: confirm in the `Num::sqrt` implementations.
pub sqrt_both: bool,
}
impl<N: Num + 'static> Context<N> {
    /// Creates a context pre-populated with the built-in constants and functions.
    ///
    /// Variables: `pi`, `e` and the imaginary unit `i`.
    ///
    /// Functions: `sin`, `cos`, `tan`, `asin`, `acos`, `atan`, `atant` (backed by `Atan2`),
    /// `max`, `min`, `sqrt`, `nrt`, `abs`, `floor`, `ceil`, `round` and `log`.
    ///
    /// # Panics
    ///
    /// Panics if converting one of the built-in constants into `N` returns an error.
    pub fn new() -> Self {
        use self::funcs::*;

        let mut ctx: Context<N> = Context::empty();
        // The constants are converted against a separate, empty context.
        let empty = Context::empty();
        ctx.set_var("pi", N::from_f64(consts::PI, &empty).unwrap());
        ctx.set_var("e", N::from_f64(consts::E, &empty).unwrap());
        ctx.set_var("i", N::from_f64_complex((0.0, 1.0), &empty).unwrap());

        ctx.set_func("sin", Sin);
        ctx.set_func("cos", Cos);
        ctx.set_func("max", Max);
        ctx.set_func("min", Min);
        ctx.set_func("sqrt", Sqrt);
        ctx.set_func("nrt", Nrt);
        ctx.set_func("tan", Tan);
        ctx.set_func("abs", Abs);
        ctx.set_func("asin", Asin);
        ctx.set_func("acos", Acos);
        ctx.set_func("atan", Atan);
        ctx.set_func("atant", Atan2);
        ctx.set_func("floor", Floor);
        // `Ceil` is implemented in `funcs` alongside `Floor`/`Round` and must be reachable too.
        ctx.set_func("ceil", Ceil);
        ctx.set_func("round", Round);
        ctx.set_func("log", Log);
        ctx
    }

    /// Sets (or replaces) the variable `name` to `val`.
    pub fn set_var<T: Into<Term<N>>>(&mut self, name: &str, val: T) {
        self.vars.insert(name.to_string(), val.into());
    }

    /// Registers (or replaces) the function `func` under `name`.
    pub fn set_func<F: Func<N> + 'static>(&mut self, name: &str, func: F) {
        self.funcs.insert(name.to_string(), Rc::new(func));
    }

    /// Creates a context with no variables, no functions and the `Config::new()` options.
    pub fn empty() -> Self {
        Context {
            vars: HashMap::new(),
            funcs: HashMap::new(),
            cfg: Config::new(),
        }
    }
}
impl Config {
pub fn new() -> Self {
Config {
implicit_multiplication: true,
precision: 53,
sqrt_both: true,
}
}
}
impl Default for Config {
fn default() -> Self {
Self::new()
}
}
impl<N: Num + 'static> Default for Context<N> {
fn default() -> Self {
Self::new()
}
}
impl<N: Num> fmt::Debug for Context<N> {
    /// Prints the variables via their `Debug` output. For functions only the registered
    /// names are printed, comma-separated, inside braces; the function objects themselves
    /// are not printed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let names: Vec<&str> = self.funcs.keys().map(|name| name.as_str()).collect();
        write!(
            f,
            "Context {{ vars: {:?}, funcs: {{{}}} }}",
            self.vars,
            names.join(", ")
        )
    }
}
pub(in crate::context) mod funcs {
    //! Built-in function implementations for [`Context`].

    use std::cmp::Ordering;
    use crate::context::Context;
    use crate::term::Term;
    use crate::errors::MathError;
    use crate::func::Func;
    use crate::opers::Calculation;
    use crate::num::Num;
    use crate::answer::Answer;

    /// Evaluates the only argument of a unary function.
    ///
    /// # Errors
    ///
    /// `MathError::IncorrectArguments` unless exactly one argument is given; otherwise any
    /// error from evaluating that argument.
    fn one_arg<N: Num + 'static>(
        args: &[Term<N>],
        ctx: &Context<N>,
    ) -> Result<Answer<N>, MathError> {
        match args {
            [arg] => arg.eval_ctx(ctx),
            _ => Err(MathError::IncorrectArguments),
        }
    }

    /// Evaluates both arguments of a binary function, first then second.
    ///
    /// # Errors
    ///
    /// `MathError::IncorrectArguments` unless exactly two arguments are given; otherwise the
    /// first error from evaluating them.
    fn two_args<N: Num + 'static>(
        args: &[Term<N>],
        ctx: &Context<N>,
    ) -> Result<(Answer<N>, Answer<N>), MathError> {
        match args {
            [first, second] => Ok((first.eval_ctx(ctx)?, second.eval_ctx(ctx)?)),
            _ => Err(MathError::IncorrectArguments),
        }
    }

    /// Evaluates every argument exactly once and flattens multi-valued answers into a single
    /// list of numbers, preserving argument order.
    fn flatten_args<N: Num + 'static>(
        args: &[Term<N>],
        ctx: &Context<N>,
    ) -> Result<Vec<N>, MathError> {
        let mut values = Vec::with_capacity(args.len());
        for term in args {
            match term.eval_ctx(ctx)? {
                Answer::Single(n) => values.push(n),
                Answer::Multiple(ns) => values.extend(ns),
            }
        }
        Ok(values)
    }

    /// Shared implementation of `Max` and `Min`.
    ///
    /// Folds over all flattened argument values. The running best is replaced whenever a
    /// value compares as `wanted` against it (`Ordering::Greater` gives the maximum,
    /// `Ordering::Less` the minimum). On ties the earliest value is kept.
    ///
    /// # Errors
    ///
    /// - `MathError::IncorrectArguments` if there are no values at all, whether from no
    ///   arguments or only empty multi-valued answers.
    /// - Any error from evaluating an argument or from `Num::tryord`.
    fn extremum<N: Num + 'static>(
        args: &[Term<N>],
        ctx: &Context<N>,
        wanted: Ordering,
    ) -> Result<Answer<N>, MathError> {
        let mut values = flatten_args(args, ctx)?.into_iter();
        let mut best = values.next().ok_or(MathError::IncorrectArguments)?;
        for value in values {
            if Num::tryord(&value, &best, ctx)? == wanted {
                best = value;
            }
        }
        Ok(Answer::Single(best))
    }

    /// Applies `Num::sin` to its single argument.
    pub struct Sin;
    impl<N: Num + 'static> Func<N> for Sin {
        fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
            one_arg(args, ctx)?.unop(|a| Num::sin(a, ctx))
        }
    }

    /// Applies `Num::cos` to its single argument.
    pub struct Cos;
    impl<N: Num + 'static> Func<N> for Cos {
        fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
            one_arg(args, ctx)?.unop(|a| Num::cos(a, ctx))
        }
    }

    /// Largest value across one or more arguments; multi-valued arguments contribute every
    /// value. Each argument is evaluated once.
    pub struct Max;
    impl<N: Num + 'static> Func<N> for Max {
        fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
            extremum(args, ctx, Ordering::Greater)
        }
    }

    /// Smallest value across one or more arguments; multi-valued arguments contribute every
    /// value. Each argument is evaluated once.
    pub struct Min;
    impl<N: Num + 'static> Func<N> for Min {
        fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
            extremum(args, ctx, Ordering::Less)
        }
    }

    /// Applies `Num::sqrt` to its single argument.
    pub struct Sqrt;
    impl<N: Num + 'static> Func<N> for Sqrt {
        fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
            one_arg(args, ctx)?.unop(|a| Num::sqrt(a, ctx))
        }
    }

    /// Applies `Num::nrt` to its two arguments, in order.
    pub struct Nrt;
    impl<N: Num + 'static> Func<N> for Nrt {
        fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
            let (a, b) = two_args(args, ctx)?;
            a.op(&b, |a, b| Num::nrt(a, b, ctx))
        }
    }

    /// Applies `Num::abs` to its single argument.
    pub struct Abs;
    impl<N: Num + 'static> Func<N> for Abs {
        fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
            one_arg(args, ctx)?.unop(|a| Num::abs(a, ctx))
        }
    }

    /// Applies `Num::tan` to its single argument.
    pub struct Tan;
    impl<N: Num + 'static> Func<N> for Tan {
        fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
            one_arg(args, ctx)?.unop(|a| Num::tan(a, ctx))
        }
    }

    /// Applies `Num::asin` to its single argument.
    pub struct Asin;
    impl<N: Num + 'static> Func<N> for Asin {
        fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
            one_arg(args, ctx)?.unop(|a| Num::asin(a, ctx))
        }
    }

    /// Applies `Num::acos` to its single argument.
    pub struct Acos;
    impl<N: Num + 'static> Func<N> for Acos {
        fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
            one_arg(args, ctx)?.unop(|a| Num::acos(a, ctx))
        }
    }

    /// Applies `Num::atan` to its single argument.
    pub struct Atan;
    impl<N: Num + 'static> Func<N> for Atan {
        fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
            one_arg(args, ctx)?.unop(|a| Num::atan(a, ctx))
        }
    }

    /// Applies `Num::atan2` to its two arguments, in order.
    pub struct Atan2;
    impl<N: Num + 'static> Func<N> for Atan2 {
        fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
            let (a, b) = two_args(args, ctx)?;
            a.op(&b, |a, b| Num::atan2(a, b, ctx))
        }
    }

    /// Applies `Num::floor` to its single argument.
    pub struct Floor;
    impl<N: Num + 'static> Func<N> for Floor {
        fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
            one_arg(args, ctx)?.unop(|a| Num::floor(a, ctx))
        }
    }

    /// Applies `Num::ceil` to its single argument.
    pub struct Ceil;
    impl<N: Num + 'static> Func<N> for Ceil {
        fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
            one_arg(args, ctx)?.unop(|a| Num::ceil(a, ctx))
        }
    }

    /// Applies `Num::round` to its single argument.
    pub struct Round;
    impl<N: Num + 'static> Func<N> for Round {
        fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
            one_arg(args, ctx)?.unop(|a| Num::round(a, ctx))
        }
    }

    /// Applies `Num::log` to its two arguments, in order.
    pub struct Log;
    impl<N: Num + 'static> Func<N> for Log {
        fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
            let (a, b) = two_args(args, ctx)?;
            a.op(&b, |a, b| Num::log(a, b, ctx))
        }
    }
}