use std::marker::PhantomData;
use std::ops::{Add, Mul, Sub};
use num_traits::{CheckedDiv, NumOps, Zero};
use crate::Node;
use crate::error::RoplatError;
/// Binary addition node: for an input pair `(lhs, rhs)` it emits `lhs + rhs`.
///
/// The node is stateless (it only carries a `PhantomData` marker) and always
/// returns `Ok`. No overflow checking is performed, so for primitive integers an
/// overflowing sum follows Rust's normal rules: it panics when overflow checks are
/// enabled (e.g. debug builds) and wraps around otherwise.
pub struct OpAdd<T>(PhantomData<T>);

// Written by hand instead of `#[derive(Default)]`, which would add an
// unnecessary `T: Default` bound.
impl<T> Default for OpAdd<T> {
    fn default() -> Self {
        OpAdd(PhantomData::default())
    }
}

impl<T> Node for OpAdd<T>
where
    T: Add<Output = T> + Send + Sync + Copy + 'static,
{
    type Input = (T, T);
    type Output = Result<T, RoplatError>;
    type Error = RoplatError;

    async fn process(&mut self, input: Self::Input) -> Result<T, RoplatError> {
        let (lhs, rhs) = input;
        let sum = lhs + rhs;
        Ok(sum)
    }
}
/// Binary subtraction node: for an input pair `(lhs, rhs)` it emits `lhs - rhs`.
///
/// Stateless and always returns `Ok`. Overflow is not checked; for primitive
/// integers an overflowing difference (including unsigned underflow) panics when
/// overflow checks are enabled and wraps around otherwise.
pub struct OpSub<T>(PhantomData<T>);

// Manual impl: deriving `Default` would require `T: Default` for no reason.
impl<T> Default for OpSub<T> {
    fn default() -> Self {
        OpSub(PhantomData::default())
    }
}

impl<T> Node for OpSub<T>
where
    T: Sub<Output = T> + Send + Sync + Copy + 'static,
{
    type Input = (T, T);
    type Output = Result<T, RoplatError>;
    type Error = RoplatError;

    async fn process(&mut self, input: Self::Input) -> Result<T, RoplatError> {
        let (minuend, subtrahend) = input;
        let difference = minuend - subtrahend;
        Ok(difference)
    }
}
/// Binary multiplication node: for an input pair `(lhs, rhs)` it emits `lhs * rhs`.
///
/// Stateless and always returns `Ok`. Overflow is not checked; for primitive
/// integers an overflowing product panics when overflow checks are enabled and
/// wraps around otherwise.
pub struct OpMul<T>(PhantomData<T>);

// Hand-written so that no `T: Default` bound is introduced.
impl<T> Default for OpMul<T> {
    fn default() -> Self {
        OpMul(PhantomData::default())
    }
}

impl<T> Node for OpMul<T>
where
    T: Mul<Output = T> + Send + Sync + Copy + 'static,
{
    type Input = (T, T);
    type Output = Result<T, RoplatError>;
    type Error = RoplatError;

    async fn process(&mut self, input: Self::Input) -> Result<T, RoplatError> {
        let (lhs, rhs) = input;
        let product = lhs * rhs;
        Ok(product)
    }
}
/// Binary checked-division node: for an input pair `(lhs, rhs)` it emits `lhs / rhs`.
///
/// Division goes through `CheckedDiv`, so it is only available for types that
/// implement that trait (e.g. the primitive integers).
///
/// # Errors
///
/// Returns `RoplatError::Arithmetic` when
/// - `rhs` is zero (message `"除零错误"`, "division by zero"), or
/// - the quotient is not representable in `T`, such as `i32::MIN / -1`
///   (message `"除法溢出"`, "division overflow").
pub struct OpDiv<T>(PhantomData<T>);

// Manual impl: `#[derive(Default)]` would impose a needless `T: Default` bound.
impl<T> Default for OpDiv<T> {
    fn default() -> Self {
        Self(PhantomData)
    }
}

impl<T> Node for OpDiv<T>
where
    T: CheckedDiv + Zero + Send + Sync + Copy + 'static,
{
    type Input = (T, T);
    type Output = Result<T, RoplatError>;
    type Error = RoplatError;

    async fn process(&mut self, input: Self::Input) -> Result<T, RoplatError> {
        let (lhs, rhs) = input;
        // `checked_div` returns `None` both for a zero divisor and for overflow.
        // Rejecting zero first means a `None` below can only mean overflow, so the
        // two failures get distinct, accurate messages.
        if rhs.is_zero() {
            return Err(RoplatError::Arithmetic("除零错误".into()));
        }
        lhs.checked_div(&rhs)
            .ok_or_else(|| RoplatError::Arithmetic("除法溢出".into()))
    }
}
/// Binary remainder node: for an input pair `(lhs, rhs)` it emits `lhs % rhs`.
///
/// For primitive types the result has the sign of `lhs`, following Rust's `%`.
/// The case `T::MIN % -1` for signed primitive integers would panic with the plain
/// `%` operator. Here it returns the mathematically correct `0` instead.
///
/// # Errors
///
/// Returns `RoplatError::Arithmetic` (`"模零错误"`, "modulo by zero") when
/// `rhs` is zero.
pub struct OpRem<T>(PhantomData<T>);

// Manual impl: `#[derive(Default)]` would impose a needless `T: Default` bound.
impl<T> Default for OpRem<T> {
    fn default() -> Self {
        Self(PhantomData)
    }
}

impl<T> Node for OpRem<T>
where
    T: NumOps + Zero + PartialEq + Copy + Send + Sync + 'static,
{
    type Input = (T, T);
    type Output = Result<T, RoplatError>;
    type Error = RoplatError;

    async fn process(&mut self, input: Self::Input) -> Result<T, RoplatError> {
        let (lhs, rhs) = input;
        if rhs == T::zero() {
            return Err(RoplatError::Arithmetic("模零错误".into()));
        }
        // For signed primitive integers `MIN % -1` overflows and panics in every build
        // mode, although the true remainder is 0. A remainder depends only on the
        // magnitude of `rhs`, so when `rhs` is +1 or -1 we divide by +1 instead, which
        // cannot overflow (and gives identical results for floats).
        // `rhs / rhs` yields 1 without needing a `One` bound (rhs is non-zero here).
        // For primitive integers and floats, `rhs == 1 / rhs` holds only for ±1.
        let one = rhs / rhs;
        let divisor = if rhs == one / rhs { one } else { rhs };
        Ok(lhs % divisor)
    }
}
/// Unary negation node: emits `-input`.
///
/// Stateless and always returns `Ok`. For signed primitive integers, negating
/// `MIN` overflows and follows Rust's normal rules: it panics when overflow checks
/// are enabled and wraps around otherwise.
pub struct OpNeg<T>(PhantomData<T>);

// Manual impl keeps `Default` available without a `T: Default` bound.
impl<T> Default for OpNeg<T> {
    fn default() -> Self {
        OpNeg(PhantomData)
    }
}

impl<T> Node for OpNeg<T>
where
    T: std::ops::Neg<Output = T> + Send + Sync + Copy + 'static,
{
    type Input = T;
    type Output = Result<T, RoplatError>;
    type Error = RoplatError;

    async fn process(&mut self, input: Self::Input) -> Result<T, RoplatError> {
        let negated = std::ops::Neg::neg(input);
        Ok(negated)
    }
}
/// Unary absolute-value node: emits `num_traits::Signed::abs(&input)`.
///
/// Stateless and always returns `Ok`.
/// NOTE(review): for signed primitive integers `T::MIN` has no positive
/// counterpart; the outcome then depends on how `Signed::abs` handles that
/// overflow (likely a panic with overflow checks on, wrap-around otherwise).
/// Confirm before relying on it.
pub struct OpAbs<T>(PhantomData<T>);

// Manual impl so that `Default` does not require `T: Default`.
impl<T> Default for OpAbs<T> {
    fn default() -> Self {
        OpAbs(PhantomData)
    }
}

impl<T> Node for OpAbs<T>
where
    T: num_traits::Signed + Send + Sync + Copy + 'static,
{
    type Input = T;
    type Output = Result<T, RoplatError>;
    type Error = RoplatError;

    async fn process(&mut self, input: Self::Input) -> Result<T, RoplatError> {
        let magnitude = num_traits::Signed::abs(&input);
        Ok(magnitude)
    }
}
/// Binary maximum node: for an input pair `(lhs, rhs)` it emits the larger value.
///
/// Uses `std::cmp::max`, which returns the second argument (`rhs`) when the two
/// compare equal. Stateless and always returns `Ok`.
pub struct OpMax<T>(PhantomData<T>);

// Manual impl avoids the `T: Default` bound a derive would add.
impl<T> Default for OpMax<T> {
    fn default() -> Self {
        OpMax(PhantomData)
    }
}

impl<T> Node for OpMax<T>
where
    T: Ord + Send + Sync + Copy + 'static,
{
    type Input = (T, T);
    type Output = Result<T, RoplatError>;
    type Error = RoplatError;

    async fn process(&mut self, input: Self::Input) -> Result<T, RoplatError> {
        let (lhs, rhs) = input;
        Ok(std::cmp::max(lhs, rhs))
    }
}
/// Binary minimum node: for an input pair `(lhs, rhs)` it emits the smaller value.
///
/// Uses `std::cmp::min`, which returns the first argument (`lhs`) when the two
/// compare equal. Stateless and always returns `Ok`.
pub struct OpMin<T>(PhantomData<T>);

// Manual impl avoids the `T: Default` bound a derive would add.
impl<T> Default for OpMin<T> {
    fn default() -> Self {
        OpMin(PhantomData)
    }
}

impl<T> Node for OpMin<T>
where
    T: Ord + Send + Sync + Copy + 'static,
{
    type Input = (T, T);
    type Output = Result<T, RoplatError>;
    type Error = RoplatError;

    async fn process(&mut self, input: Self::Input) -> Result<T, RoplatError> {
        let (lhs, rhs) = input;
        Ok(std::cmp::min(lhs, rhs))
    }
}
/// Floating-point power node: for an input pair `(base, exponent)` it emits
/// `base.powf(exponent)`.
///
/// Stateless and always returns `Ok`. No domain checking is done. For `f32`/`f64`,
/// a negative base with a non-integer exponent therefore yields NaN wrapped in
/// `Ok`, rather than an error.
pub struct OpPow<T>(PhantomData<T>);

// Manual impl avoids the `T: Default` bound a derive would add.
impl<T> Default for OpPow<T> {
    fn default() -> Self {
        OpPow(PhantomData)
    }
}

impl<T> Node for OpPow<T>
where
    T: num_traits::Float + Send + Sync + Copy + 'static,
{
    type Input = (T, T);
    type Output = Result<T, RoplatError>;
    type Error = RoplatError;

    async fn process(&mut self, input: Self::Input) -> Result<T, RoplatError> {
        let (base, exponent) = input;
        let power = base.powf(exponent);
        Ok(power)
    }
}
/// Floating-point square-root node: emits `input.sqrt()`.
///
/// # Errors
///
/// Returns `RoplatError::Arithmetic` (`"负数无法开平方根"`, "cannot take the
/// square root of a negative number") when `input < 0`. Values for which that
/// comparison is false are passed to `sqrt` unchanged. This includes `-0.0` and
/// NaN.
pub struct OpSqrt<T>(PhantomData<T>);

// Manual impl avoids the `T: Default` bound a derive would add.
impl<T> Default for OpSqrt<T> {
    fn default() -> Self {
        OpSqrt(PhantomData)
    }
}

impl<T> Node for OpSqrt<T>
where
    T: num_traits::Float + Send + Sync + Copy + 'static,
{
    type Input = T;
    type Output = Result<T, RoplatError>;
    type Error = RoplatError;

    async fn process(&mut self, input: Self::Input) -> Result<T, RoplatError> {
        // Keep the strict `<` test: NaN and -0.0 must not be rejected.
        let is_negative = input < T::zero();
        if is_negative {
            Err(RoplatError::Arithmetic("负数无法开平方根".into()))
        } else {
            Ok(input.sqrt())
        }
    }
}
#[cfg(test)]
mod tests {
    //! Async unit tests for the arithmetic nodes. Each test drives `process`
    //! directly on a tokio test runtime.
    use super::*;

    #[tokio::test]
    async fn test_op_add_integer() {
        let mut node = OpAdd::<i32>::default();
        assert_eq!(node.process((5, 3)).await.unwrap(), 8);
        assert_eq!(node.process((-5, 3)).await.unwrap(), -2);
        assert_eq!(node.process((0, 0)).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_op_add_float() {
        let mut node = OpAdd::<f64>::default();
        let result = node.process((1.5, 2.3)).await.unwrap();
        assert!((result - 3.8).abs() < 1e-10);
    }

    #[tokio::test]
    async fn test_op_sub_integer() {
        let mut node = OpSub::<i32>::default();
        assert_eq!(node.process((5, 3)).await.unwrap(), 2);
        assert_eq!(node.process((3, 5)).await.unwrap(), -2);
        assert_eq!(node.process((0, 0)).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_op_mul_integer() {
        let mut node = OpMul::<i32>::default();
        assert_eq!(node.process((5, 3)).await.unwrap(), 15);
        assert_eq!(node.process((-5, 3)).await.unwrap(), -15);
        assert_eq!(node.process((5, 0)).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_op_div_integer_success() {
        let mut node = OpDiv::<i32>::default();
        assert_eq!(node.process((10, 2)).await.unwrap(), 5);
        assert_eq!(node.process((-10, 2)).await.unwrap(), -5);
        assert_eq!(node.process((0, 5)).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_op_div_integer_zero_division() {
        let mut node = OpDiv::<i32>::default();
        let result = node.process((10, 0)).await;
        assert!(result.is_err());
        if let Err(RoplatError::Arithmetic(msg)) = result {
            assert!(msg.contains("除零"));
        } else {
            panic!("Expected Arithmetic error");
        }
    }

    #[tokio::test]
    async fn test_op_rem_integer() {
        let mut node = OpRem::<i32>::default();
        assert_eq!(node.process((10, 3)).await.unwrap(), 1);
        assert_eq!(node.process((10, 2)).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_op_rem_zero_division() {
        let mut node = OpRem::<i32>::default();
        let result = node.process((10, 0)).await;
        assert!(result.is_err());
        if let Err(RoplatError::Arithmetic(msg)) = result {
            assert!(msg.contains("模零"));
        } else {
            panic!("Expected Arithmetic error");
        }
    }

    #[tokio::test]
    async fn test_op_neg_integer() {
        let mut node = OpNeg::<i32>::default();
        assert_eq!(node.process(5).await.unwrap(), -5);
        assert_eq!(node.process(-5).await.unwrap(), 5);
        assert_eq!(node.process(0).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_op_neg_float() {
        let mut node = OpNeg::<f64>::default();
        let result = node.process(std::f64::consts::PI).await.unwrap();
        assert!((result + std::f64::consts::PI).abs() < 1e-10);
    }

    #[tokio::test]
    async fn test_op_abs_integer() {
        let mut node = OpAbs::<i32>::default();
        assert_eq!(node.process(5).await.unwrap(), 5);
        assert_eq!(node.process(-5).await.unwrap(), 5);
        assert_eq!(node.process(0).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_op_abs_float() {
        let mut node = OpAbs::<f64>::default();
        assert_eq!(
            node.process(std::f64::consts::PI).await.unwrap(),
            std::f64::consts::PI
        );
        assert_eq!(
            node.process(-std::f64::consts::PI).await.unwrap(),
            std::f64::consts::PI
        );
    }

    #[tokio::test]
    async fn test_op_max_integer() {
        let mut node = OpMax::<i32>::default();
        assert_eq!(node.process((5, 3)).await.unwrap(), 5);
        assert_eq!(node.process((3, 5)).await.unwrap(), 5);
        assert_eq!(node.process((5, 5)).await.unwrap(), 5);
    }

    #[tokio::test]
    async fn test_op_min_integer() {
        let mut node = OpMin::<i32>::default();
        assert_eq!(node.process((5, 3)).await.unwrap(), 3);
        assert_eq!(node.process((3, 5)).await.unwrap(), 3);
        assert_eq!(node.process((5, 5)).await.unwrap(), 5);
    }

    #[tokio::test]
    async fn test_op_pow_float() {
        let mut node = OpPow::<f64>::default();
        let result = node.process((2.0, 3.0)).await.unwrap();
        assert!((result - 8.0).abs() < 1e-10);
        let result = node.process((10.0, 2.0)).await.unwrap();
        assert!((result - 100.0).abs() < 1e-10);
    }

    #[tokio::test]
    async fn test_op_pow_fractional() {
        let mut node = OpPow::<f64>::default();
        let result = node.process((4.0, 0.5)).await.unwrap();
        assert!((result - 2.0).abs() < 1e-10);
    }

    #[tokio::test]
    async fn test_op_sqrt_positive() {
        let mut node = OpSqrt::<f64>::default();
        let result = node.process(9.0).await.unwrap();
        assert!((result - 3.0).abs() < 1e-10);
        let result = node.process(2.0).await.unwrap();
        assert!((result - std::f64::consts::SQRT_2).abs() < 1e-8);
    }

    #[tokio::test]
    async fn test_op_sqrt_zero() {
        let mut node = OpSqrt::<f64>::default();
        let result = node.process(0.0).await.unwrap();
        assert!((result - 0.0).abs() < 1e-10);
    }

    #[tokio::test]
    async fn test_op_sqrt_negative() {
        let mut node = OpSqrt::<f64>::default();
        let result = node.process(-1.0).await;
        assert!(result.is_err());
        if let Err(RoplatError::Arithmetic(msg)) = result {
            assert!(msg.contains("负数") || msg.contains("平方根"));
        } else {
            panic!("Expected Arithmetic error");
        }
    }

    #[tokio::test]
    async fn test_arithmetic_with_large_numbers() {
        let mut add = OpAdd::<i64>::default();
        let mut mul = OpMul::<i64>::default();
        assert_eq!(
            add.process((1_000_000, 2_000_000)).await.unwrap(),
            3_000_000
        );
        assert_eq!(mul.process((1000, 1000)).await.unwrap(), 1_000_000);
    }

    #[tokio::test]
    async fn test_arithmetic_with_negative_numbers() {
        let mut add = OpAdd::<i32>::default();
        let mut sub = OpSub::<i32>::default();
        let mut mul = OpMul::<i32>::default();
        assert_eq!(add.process((-5, -3)).await.unwrap(), -8);
        assert_eq!(add.process((-5, 3)).await.unwrap(), -2);
        assert_eq!(sub.process((-5, -3)).await.unwrap(), -2);
        assert_eq!(sub.process((5, -3)).await.unwrap(), 8);
        assert_eq!(mul.process((-5, 3)).await.unwrap(), -15);
        assert_eq!(mul.process((-5, -3)).await.unwrap(), 15);
    }

    #[tokio::test]
    async fn test_arithmetic_identity() {
        let mut add = OpAdd::<i32>::default();
        let mut mul = OpMul::<i32>::default();
        assert_eq!(add.process((5, 0)).await.unwrap(), 5);
        assert_eq!(mul.process((5, 1)).await.unwrap(), 5);
    }

    #[tokio::test]
    async fn test_arithmetic_commutativity() {
        let mut add = OpAdd::<i32>::default();
        let mut mul = OpMul::<i32>::default();
        assert_eq!(
            add.process((5, 3)).await.unwrap(),
            add.process((3, 5)).await.unwrap()
        );
        assert_eq!(
            mul.process((5, 3)).await.unwrap(),
            mul.process((3, 5)).await.unwrap()
        );
    }

    #[tokio::test]
    async fn test_float_precision() {
        let mut node = OpAdd::<f32>::default();
        let result = node.process((0.1, 0.2)).await.unwrap();
        assert!((result - 0.3).abs() < 0.0001);
    }

    #[tokio::test]
    async fn test_chained_operations() {
        let mut add = OpAdd::<i32>::default();
        let mut mul = OpMul::<i32>::default();
        let sum = add.process((5, 3)).await.unwrap();
        let result = mul.process((sum, 2)).await.unwrap();
        assert_eq!(result, 16);
    }

    #[tokio::test]
    async fn test_unsigned_operations() {
        let mut add = OpAdd::<u32>::default();
        let mut mul = OpMul::<u32>::default();
        assert_eq!(add.process((5, 3)).await.unwrap(), 8);
        assert_eq!(mul.process((5, 3)).await.unwrap(), 15);
    }

    #[tokio::test]
    async fn test_op_div_integer_truncates_toward_zero() {
        let mut node = OpDiv::<i32>::default();
        assert_eq!(node.process((7, 2)).await.unwrap(), 3);
        assert_eq!(node.process((-7, 2)).await.unwrap(), -3);
    }

    #[tokio::test]
    async fn test_op_div_integer_overflow_is_error() {
        // `i32::MIN / -1` is not representable; it must surface as an error, not a panic.
        let mut node = OpDiv::<i32>::default();
        let result = node.process((i32::MIN, -1)).await;
        assert!(matches!(result, Err(RoplatError::Arithmetic(_))));
    }

    #[tokio::test]
    async fn test_op_rem_sign_follows_dividend() {
        let mut node = OpRem::<i32>::default();
        assert_eq!(node.process((-7, 3)).await.unwrap(), -1);
        assert_eq!(node.process((7, -3)).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_op_rem_float() {
        let mut node = OpRem::<f64>::default();
        let result = node.process((5.5, 2.0)).await.unwrap();
        assert!((result - 1.5).abs() < 1e-10);
        assert!(node.process((5.5, 0.0)).await.is_err());
    }

    #[tokio::test]
    async fn test_op_max_min_negative() {
        let mut max = OpMax::<i32>::default();
        let mut min = OpMin::<i32>::default();
        assert_eq!(max.process((-5, -3)).await.unwrap(), -3);
        assert_eq!(min.process((-5, -3)).await.unwrap(), -5);
    }

    #[tokio::test]
    async fn test_op_pow_zero_exponent() {
        let mut node = OpPow::<f64>::default();
        assert_eq!(node.process((5.0, 0.0)).await.unwrap(), 1.0);
        assert_eq!(node.process((0.0, 0.0)).await.unwrap(), 1.0);
    }
}