1use std::collections::HashMap;
2use std::f64::consts;
3use std::rc::Rc;
4use std::fmt;
5
6use crate::term::Term;
7use crate::func::Func;
8use crate::num::Num;
9
10#[derive(Clone)]
93pub struct Context<N: Num> {
94 pub vars: HashMap<String, Term<N>>,
96 pub funcs: HashMap<String, Rc<dyn Func<N>>>,
98 pub cfg: Config,
100}
101
/// Settings stored in a `Context`.
///
/// Plain data: it can be cloned, printed, and compared for equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Config {
    /// Whether implicit multiplication (e.g. `2x`) is allowed.
    /// NOTE(review): presumably consulted by the parser — confirm at the use site.
    pub implicit_multiplication: bool,
    /// Numeric precision in bits; the default value (53) equals the mantissa width
    /// of an `f64`. NOTE(review): presumably only honoured by arbitrary-precision
    /// `Num` implementations — confirm.
    pub precision: u32,
    /// NOTE(review): presumably makes `sqrt` yield both the positive and negative
    /// root — confirm against the `Num::sqrt` implementations.
    pub sqrt_both: bool,
}
112
113impl<N: Num + 'static> Context<N> {
114 pub fn new() -> Self {
116 use self::funcs::*;
117
118 let mut ctx: Context<N> = Context::empty();
119
120 let empty = Context::empty();
121
122 ctx.set_var("pi", N::from_f64(consts::PI, &empty).unwrap());
123 ctx.set_var("e", N::from_f64(consts::E, &empty).unwrap());
124 ctx.set_var("i", N::from_f64_complex((0.0, 1.0), &empty).unwrap());
125
126 ctx.funcs.insert("sin".to_string(), Rc::new(Sin));
127 ctx.funcs.insert("cos".to_string(), Rc::new(Cos));
128 ctx.funcs.insert("max".to_string(), Rc::new(Max));
129 ctx.funcs.insert("min".to_string(), Rc::new(Min));
130 ctx.funcs.insert("sqrt".to_string(), Rc::new(Sqrt));
131 ctx.funcs.insert("nrt".to_string(), Rc::new(Nrt));
132 ctx.funcs.insert("tan".to_string(), Rc::new(Tan));
133 ctx.funcs.insert("abs".to_string(), Rc::new(Abs));
134 ctx.funcs.insert("asin".to_string(), Rc::new(Asin));
135 ctx.funcs.insert("acos".to_string(), Rc::new(Acos));
136 ctx.funcs.insert("atan".to_string(), Rc::new(Atan));
137 ctx.funcs.insert("atant".to_string(), Rc::new(Atan2));
138 ctx.funcs.insert("floor".to_string(), Rc::new(Floor));
139 ctx.funcs.insert("round".to_string(), Rc::new(Round));
140 ctx.funcs.insert("log".to_string(), Rc::new(Log));
141
142 ctx
143 }
144
145 pub fn set_var<T: Into<Term<N>>>(&mut self, name: &str, val: T) {
147 self.vars.insert(name.to_string(), val.into());
148 }
149
150 pub fn set_func<F: Func<N> + 'static>(&mut self, name: &str, func: F) {
152 self.funcs.insert(name.to_string(), Rc::new(func));
153 }
154
155 pub fn empty() -> Self {
157 Context {
158 vars: HashMap::new(),
159 funcs: HashMap::new(),
160 cfg: Config::new(),
161 }
162 }
163}
164
165impl Config {
166 pub fn new() -> Self {
168 Config {
169 implicit_multiplication: true,
170 precision: 53,
171 sqrt_both: true,
172 }
173 }
174}
175
176impl Default for Config {
177 fn default() -> Self {
178 Self::new()
179 }
180}
181
182impl<N: Num + 'static> Default for Context<N> {
183 fn default() -> Self {
184 Self::new()
185 }
186}
187
188impl<N: Num> fmt::Debug for Context<N> {
189 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
190 write!(f, "Context {{ vars: {:?}, funcs: {{{}}} }}", self.vars, {
191 let mut output = String::new();
192 for (i, key) in self.funcs.keys().enumerate() {
193 output.push_str(key);
194 if i + 1 < self.funcs.len() {
195 output.push_str(", ");
196 }
197 }
198 output
199 })
200 }
201}
202
203pub(in crate::context) mod funcs {
204 use std::cmp::Ordering;
205
206 use crate::context::Context;
207 use crate::term::Term;
208 use crate::errors::MathError;
209 use crate::func::Func;
210 use crate::opers::Calculation;
211 use crate::num::Num;
212 use crate::answer::Answer;
213
214 pub struct Sin;
215 impl<N: Num + 'static> Func<N> for Sin {
216 fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
217 if args.len() != 1 {
218 return Err(MathError::IncorrectArguments);
219 }
220
221 let a = args[0].eval_ctx(ctx)?;
222
223 a.unop(|a| Num::sin(a, ctx))
224 }
225 }
226
227 pub struct Cos;
228 impl<N: Num + 'static> Func<N> for Cos {
229 fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
230 if args.len() != 1 {
231 return Err(MathError::IncorrectArguments);
232 }
233
234 let a = args[0].eval_ctx(ctx)?;
235
236 a.unop(|a| Num::cos(a, ctx))
237 }
238 }
239
240 pub struct Max;
241 impl<N: Num + 'static> Func<N> for Max {
242 fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
243 if args.is_empty() {
244 return Err(MathError::IncorrectArguments);
245 }
246 let mut extra = Vec::new();
247 let mut max = match args[0].eval_ctx(ctx)? {
248 Answer::Single(n) => n,
249 Answer::Multiple(mut ns) => {
250 let one = ns.pop().unwrap();
251 extra = ns;
252 one
253 }
254 };
255
256 let args: Vec<Answer<N>> = args.iter()
258 .map(|term| term.eval_ctx(ctx))
259 .collect::<Result<Vec<Answer<N>>, MathError>>()?;
260 let mut new_args = Vec::new();
261 for a in args {
263 match a {
264 Answer::Single(n) => new_args.push(n),
265 Answer::Multiple(mut ns) => new_args.append(&mut ns),
266 }
267 }
268 for arg in new_args[1..new_args.len()].iter().chain(extra.iter()) {
270 if Num::tryord(arg, &max, ctx)? == Ordering::Greater {
271 max = arg.clone();
272 }
273 }
274 Ok(Answer::Single(max))
275 }
276 }
277
278 pub struct Min;
279 impl<N: Num + 'static> Func<N> for Min {
280 fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
281 if args.is_empty() {
282 return Err(MathError::IncorrectArguments);
283 }
284 let mut extra = Vec::new();
285 let mut min = match args[0].eval_ctx(ctx)? {
286 Answer::Single(n) => n,
287 Answer::Multiple(mut ns) => {
288 let one = ns.pop().unwrap();
289 extra = ns;
290 one
291 }
292 };
293
294 let args: Vec<Answer<N>> = args.iter()
296 .map(|term| term.eval_ctx(ctx))
297 .collect::<Result<Vec<Answer<N>>, MathError>>()?;
298 let mut new_args = Vec::new();
299 for a in args {
301 match a {
302 Answer::Single(n) => new_args.push(n),
303 Answer::Multiple(mut ns) => new_args.append(&mut ns),
304 }
305 }
306 for arg in new_args[1..new_args.len()].iter().chain(extra.iter()) {
308 if Num::tryord(arg, &min, ctx)? == Ordering::Less {
309 min = arg.clone();
310 }
311 }
312 Ok(Answer::Single(min))
313 }
314 }
315
316 pub struct Sqrt;
317 impl<N: Num + 'static> Func<N> for Sqrt {
318 fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
319 if args.len() != 1 {
320 return Err(MathError::IncorrectArguments);
321 }
322
323 let a = args[0].eval_ctx(ctx)?;
324
325 a.unop(|a| Num::sqrt(a, ctx))
326 }
327 }
328
329 pub struct Nrt;
330 impl<N: Num + 'static> Func<N> for Nrt {
331 fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
332 if args.len() != 2 {
333 return Err(MathError::IncorrectArguments);
334 }
335
336 let a = args[0].eval_ctx(ctx)?;
337 let b = args[1].eval_ctx(ctx)?;
338
339 a.op(&b, |a, b| Num::nrt(a, b, ctx))
340 }
341 }
342
343 pub struct Abs;
344 impl<N: Num + 'static> Func<N> for Abs {
345 fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
346 if args.len() != 1 {
347 return Err(MathError::IncorrectArguments);
348 }
349
350 let a = args[0].eval_ctx(ctx)?;
351
352 a.unop(|a| Num::abs(a, ctx))
353 }
354 }
355
356 pub struct Tan;
357 impl<N: Num + 'static> Func<N> for Tan {
358 fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
359 if args.len() != 1 {
360 return Err(MathError::IncorrectArguments);
361 }
362
363 let a = args[0].eval_ctx(ctx)?;
364
365 a.unop(|a| Num::tan(a, ctx))
366 }
367 }
368
369 pub struct Asin;
370 impl<N: Num + 'static> Func<N> for Asin {
371 fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
372 if args.len() != 1 {
373 return Err(MathError::IncorrectArguments);
374 }
375
376 let a = args[0].eval_ctx(ctx)?;
377
378 a.unop(|a| Num::asin(a, ctx))
379 }
380 }
381
382 pub struct Acos;
383 impl<N: Num + 'static> Func<N> for Acos {
384 fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
385 if args.len() != 1 {
386 return Err(MathError::IncorrectArguments);
387 }
388
389 let a = args[0].eval_ctx(ctx)?;
390
391 a.unop(|a| Num::acos(a, ctx))
392 }
393 }
394
395 pub struct Atan;
396 impl<N: Num + 'static> Func<N> for Atan {
397 fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
398 if args.len() != 1 {
399 return Err(MathError::IncorrectArguments);
400 }
401
402 let a = args[0].eval_ctx(ctx)?;
403
404 a.unop(|a| Num::atan(a, ctx))
405 }
406 }
407
408 pub struct Atan2;
409 impl<N: Num + 'static> Func<N> for Atan2 {
410 fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
411 if args.len() != 2 {
412 return Err(MathError::IncorrectArguments);
413 }
414
415 let a = args[0].eval_ctx(ctx)?;
416 let b = args[1].eval_ctx(ctx)?;
417
418 a.op(&b, |a, b| Num::atan2(a, b, ctx))
419 }
420 }
421
422 pub struct Floor;
423 impl<N: Num + 'static> Func<N> for Floor {
424 fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
425 if args.len() != 1 {
426 return Err(MathError::IncorrectArguments);
427 }
428
429 let a = args[0].eval_ctx(ctx)?;
430
431 a.unop(|a| Num::floor(a, ctx))
432 }
433 }
434
435 pub struct Ceil;
436 impl<N: Num + 'static> Func<N> for Ceil {
437 fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
438 if args.len() != 1 {
439 return Err(MathError::IncorrectArguments);
440 }
441
442 let a = args[0].eval_ctx(ctx)?;
443
444 a.unop(|a| Num::ceil(a, ctx))
445 }
446 }
447
448 pub struct Round;
449 impl<N: Num + 'static> Func<N> for Round {
450 fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
451 if args.len() != 1 {
452 return Err(MathError::IncorrectArguments);
453 }
454
455 let a = args[0].eval_ctx(ctx)?;
456
457 a.unop(|a| Num::round(a, ctx))
458 }
459 }
460
461 pub struct Log;
462 impl<N: Num + 'static> Func<N> for Log {
463 fn eval(&self, args: &[Term<N>], ctx: &Context<N>) -> Calculation<N> {
464 if args.len() != 2 {
465 return Err(MathError::IncorrectArguments);
466 }
467
468 let a = args[0].eval_ctx(ctx)?;
469 let b = args[1].eval_ctx(ctx)?;
470
471 a.op(&b, |a, b| Num::log(a, b, ctx))
472 }
473 }
474}