#[cfg(test)]
mod tests {
    //! Unit tests for `CalculusEngine`: symbolic differentiation rules and the
    //! `is_constant_with_respect_to` dependency check.
    //!
    //! Many tests only check the top-level operator of the derivative. They do
    //! not pin the engine's full, unsimplified output tree. Tests that do pin
    //! subtrees (e.g. `1 + 0`, `cos(x) * 1`) assert them explicitly.

    use crate::core::{BinaryOperator, Expression, MathConstant, Number, UnaryOperator};
    use crate::engine::calculus::CalculusEngine;
    use num_bigint::BigInt;
    use num_rational::BigRational;

    /// Builds a fresh engine for each test so tests share no state.
    fn create_engine() -> CalculusEngine {
        CalculusEngine::new()
    }

    /// Shorthand for `Expression::Variable(name)`.
    fn var(name: &str) -> Expression {
        Expression::Variable(name.to_string())
    }

    /// Shorthand for an integer literal expression.
    fn int(value: i64) -> Expression {
        Expression::Number(Number::Integer(BigInt::from(value)))
    }

    /// Shorthand for a rational literal expression `num / den`.
    ///
    /// # Panics
    /// Panics if `den == 0` (via `BigRational::new`).
    fn rational(num: i64, den: i64) -> Expression {
        Expression::Number(Number::Rational(BigRational::new(
            BigInt::from(num),
            BigInt::from(den),
        )))
    }

    /// Shorthand for a boxed binary-operation node.
    fn binop(op: BinaryOperator, left: Expression, right: Expression) -> Expression {
        Expression::BinaryOp {
            op,
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    /// Shorthand for a boxed unary-operation node.
    fn unop(op: UnaryOperator, operand: Expression) -> Expression {
        Expression::UnaryOp {
            op,
            operand: Box::new(operand),
        }
    }

    /// Numeric and named constants differentiate to the integer 0.
    #[test]
    fn test_differentiate_constants() {
        let engine = create_engine();

        let result = engine.differentiate(&int(5), "x").unwrap();
        assert_eq!(result, int(0));

        let pi = Expression::Constant(MathConstant::Pi);
        let result = engine.differentiate(&pi, "x").unwrap();
        assert_eq!(result, int(0));
    }

    /// d/dx x = 1; any other variable is treated as a constant (derivative 0).
    #[test]
    fn test_differentiate_variables() {
        let engine = create_engine();

        let result = engine.differentiate(&var("x"), "x").unwrap();
        assert_eq!(result, int(1));

        let result = engine.differentiate(&var("y"), "x").unwrap();
        assert_eq!(result, int(0));
    }

    /// d/dx (x + 5) is returned unsimplified as `1 + 0`.
    #[test]
    fn test_differentiate_addition() {
        let engine = create_engine();
        let expr = binop(BinaryOperator::Add, var("x"), int(5));
        let result = engine.differentiate(&expr, "x").unwrap();
        match result {
            Expression::BinaryOp { op: BinaryOperator::Add, left, right } => {
                assert_eq!(*left, int(1));
                assert_eq!(*right, int(0));
            }
            _ => panic!("期望得到加法表达式"),
        }
    }

    /// d/dx (x - 3) is returned unsimplified as `1 - 0`.
    #[test]
    fn test_differentiate_subtraction() {
        let engine = create_engine();
        let expr = binop(BinaryOperator::Subtract, var("x"), int(3));
        let result = engine.differentiate(&expr, "x").unwrap();
        match result {
            Expression::BinaryOp { op: BinaryOperator::Subtract, left, right } => {
                assert_eq!(*left, int(1));
                assert_eq!(*right, int(0));
            }
            _ => panic!("期望得到减法表达式"),
        }
    }

    /// The product rule `u'v + uv'` yields a sum at the top level.
    #[test]
    fn test_differentiate_multiplication() {
        let engine = create_engine();
        let expr = binop(BinaryOperator::Multiply, var("x"), int(3));
        let result = engine.differentiate(&expr, "x").unwrap();
        assert!(
            matches!(result, Expression::BinaryOp { op: BinaryOperator::Add, .. }),
            "期望得到加法表达式(乘法法则)"
        );
    }

    /// The quotient rule yields a division at the top level.
    #[test]
    fn test_differentiate_division() {
        let engine = create_engine();
        let expr = binop(BinaryOperator::Divide, var("x"), int(2));
        let result = engine.differentiate(&expr, "x").unwrap();
        assert!(
            matches!(result, Expression::BinaryOp { op: BinaryOperator::Divide, .. }),
            "期望得到除法表达式(除法法则)"
        );
    }

    /// Power rule for a constant exponent (x^2) yields a product at the top level.
    #[test]
    fn test_differentiate_power_constant_exponent() {
        let engine = create_engine();
        let expr = binop(BinaryOperator::Power, var("x"), int(2));
        let result = engine.differentiate(&expr, "x").unwrap();
        assert!(
            matches!(result, Expression::BinaryOp { op: BinaryOperator::Multiply, .. }),
            "期望得到乘法表达式(幂函数法则)"
        );
    }

    /// General power rule for a variable exponent (x^x) yields a product at the top level.
    #[test]
    fn test_differentiate_power_variable_exponent() {
        let engine = create_engine();
        let expr = binop(BinaryOperator::Power, var("x"), var("x"));
        let result = engine.differentiate(&expr, "x").unwrap();
        assert!(
            matches!(result, Expression::BinaryOp { op: BinaryOperator::Multiply, .. }),
            "期望得到乘法表达式(一般幂函数法则)"
        );
    }

    /// sin/cos derivatives are `outer'(x) * inner'(x)` with inner' = 1; tan yields a division.
    #[test]
    fn test_differentiate_trigonometric_functions() {
        let engine = create_engine();

        let sin_x = unop(UnaryOperator::Sin, var("x"));
        let result = engine.differentiate(&sin_x, "x").unwrap();
        match result {
            Expression::BinaryOp { op: BinaryOperator::Multiply, left, right } => {
                // `left`/`right` are owned `Box<Expression>`; a single deref reaches the node.
                match (&*left, &*right) {
                    (
                        Expression::UnaryOp { op: UnaryOperator::Cos, .. },
                        Expression::Number(Number::Integer(n)),
                    ) => {
                        assert_eq!(*n, BigInt::from(1));
                    }
                    _ => panic!("期望得到 cos(x) * 1"),
                }
            }
            _ => panic!("期望得到乘法表达式"),
        }

        let cos_x = unop(UnaryOperator::Cos, var("x"));
        let result = engine.differentiate(&cos_x, "x").unwrap();
        match result {
            Expression::BinaryOp { op: BinaryOperator::Multiply, left, right } => {
                match (&*left, &*right) {
                    (
                        Expression::UnaryOp { op: UnaryOperator::Negate, .. },
                        Expression::Number(Number::Integer(n)),
                    ) => {
                        assert_eq!(*n, BigInt::from(1));
                    }
                    _ => panic!("期望得到 -sin(x) * 1"),
                }
            }
            _ => panic!("期望得到乘法表达式"),
        }

        let tan_x = unop(UnaryOperator::Tan, var("x"));
        let result = engine.differentiate(&tan_x, "x").unwrap();
        assert!(
            matches!(result, Expression::BinaryOp { op: BinaryOperator::Divide, .. }),
            "期望得到除法表达式"
        );
    }

    /// d/dx ln(x) is exactly `1 / x`; log10 yields a division at the top level.
    #[test]
    fn test_differentiate_logarithmic_functions() {
        let engine = create_engine();

        let ln_x = unop(UnaryOperator::Ln, var("x"));
        let result = engine.differentiate(&ln_x, "x").unwrap();
        match result {
            Expression::BinaryOp { op: BinaryOperator::Divide, left, right } => {
                assert_eq!(*left, int(1));
                assert_eq!(*right, var("x"));
            }
            _ => panic!("期望得到 1/x"),
        }

        let log10_x = unop(UnaryOperator::Log10, var("x"));
        let result = engine.differentiate(&log10_x, "x").unwrap();
        assert!(
            matches!(result, Expression::BinaryOp { op: BinaryOperator::Divide, .. }),
            "期望得到除法表达式"
        );
    }

    /// d/dx e^x is returned as `exp(x) * 1` (chain-rule factor left unsimplified).
    #[test]
    fn test_differentiate_exponential_function() {
        let engine = create_engine();
        let exp_x = unop(UnaryOperator::Exp, var("x"));
        let result = engine.differentiate(&exp_x, "x").unwrap();
        match result {
            Expression::BinaryOp { op: BinaryOperator::Multiply, left, right } => {
                match (&*left, &*right) {
                    (
                        Expression::UnaryOp { op: UnaryOperator::Exp, .. },
                        Expression::Number(Number::Integer(n)),
                    ) => {
                        assert_eq!(*n, BigInt::from(1));
                    }
                    _ => panic!("期望得到 e^x * 1"),
                }
            }
            _ => panic!("期望得到乘法表达式"),
        }
    }

    /// d/dx sqrt(x) yields a division at the top level.
    #[test]
    fn test_differentiate_sqrt() {
        let engine = create_engine();
        let sqrt_x = unop(UnaryOperator::Sqrt, var("x"));
        let result = engine.differentiate(&sqrt_x, "x").unwrap();
        assert!(
            matches!(result, Expression::BinaryOp { op: BinaryOperator::Divide, .. }),
            "期望得到除法表达式"
        );
    }

    /// Chain rule: d/dx sin(x^2) is an outer-derivative times inner-derivative product.
    #[test]
    fn test_differentiate_chain_rule() {
        let engine = create_engine();
        let x_squared = binop(BinaryOperator::Power, var("x"), int(2));
        let sin_x_squared = unop(UnaryOperator::Sin, x_squared);
        let result = engine.differentiate(&sin_x_squared, "x").unwrap();
        assert!(
            matches!(result, Expression::BinaryOp { op: BinaryOperator::Multiply, .. }),
            "期望得到乘法表达式(链式法则)"
        );
    }

    /// d/dx (x^2 + 3x + 1) keeps the sum structure at the top level.
    #[test]
    fn test_differentiate_complex_expression() {
        let engine = create_engine();
        let x_squared = binop(BinaryOperator::Power, var("x"), int(2));
        let three_x = binop(BinaryOperator::Multiply, int(3), var("x"));
        let temp = binop(BinaryOperator::Add, x_squared, three_x);
        let expr = binop(BinaryOperator::Add, temp, int(1));
        let result = engine.differentiate(&expr, "x").unwrap();
        assert!(
            matches!(result, Expression::BinaryOp { op: BinaryOperator::Add, .. }),
            "期望得到加法表达式"
        );
    }

    /// An expression is constant w.r.t. `x` iff it does not mention `x`.
    #[test]
    fn test_is_constant_with_respect_to() {
        let engine = create_engine();

        assert!(engine.is_constant_with_respect_to(&int(5), "x"));
        assert!(engine.is_constant_with_respect_to(&rational(1, 2), "x"));

        let pi = Expression::Constant(MathConstant::Pi);
        assert!(engine.is_constant_with_respect_to(&pi, "x"));

        assert!(!engine.is_constant_with_respect_to(&var("x"), "x"));
        assert!(engine.is_constant_with_respect_to(&var("y"), "x"));

        let expr = binop(BinaryOperator::Add, var("x"), int(1));
        assert!(!engine.is_constant_with_respect_to(&expr, "x"));

        let expr = binop(BinaryOperator::Add, var("y"), int(1));
        assert!(engine.is_constant_with_respect_to(&expr, "x"));
    }
}