use crate::CalculatorError;
use crate::CalculatorFloat;
use num_complex::Complex;
#[cfg(feature = "json_schema")]
use schemars::schema::*;
use serde::de::Deserialize;
use serde::de::Error;
use serde::de::{SeqAccess, Visitor};
use serde::ser::SerializeTuple;
use serde::Serialize;
use std::convert::TryFrom;
use std::fmt;
use std::ops;
/// Complex number whose real and imaginary parts are [`CalculatorFloat`] values,
/// so each part can independently be numeric (`Float`) or symbolic (`Str`).
#[derive(Debug, Clone, PartialEq)]
pub struct CalculatorComplex {
/// Real part.
pub re: CalculatorFloat,
/// Imaginary part.
pub im: CalculatorFloat,
}
#[cfg(feature = "json_schema")]
impl schemars::JsonSchema for CalculatorComplex {
    /// Name under which this schema is registered: `"CalculatorComplex"`.
    fn schema_name() -> String {
        String::from("CalculatorComplex")
    }

    /// Reuses the schema of a `(CalculatorFloat, CalculatorFloat)` tuple, mirroring the
    /// two-element tuple written by the `Serialize` implementation.
    fn json_schema(generator: &mut schemars::gen::SchemaGenerator) -> Schema {
        <(CalculatorFloat, CalculatorFloat) as schemars::JsonSchema>::json_schema(generator)
    }
}
impl Serialize for CalculatorComplex {
    /// Serializes the value as a two-element tuple `(re, im)`; each element is written
    /// with `CalculatorFloat`'s own `Serialize` implementation.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut seq = serializer.serialize_tuple(2)?;
        for part in [&self.re, &self.im] {
            seq.serialize_element(part)?;
        }
        seq.end()
    }
}
impl<'de> Deserialize<'de> for CalculatorComplex {
    /// Deserializes a `CalculatorComplex` from a two-element tuple `(re, im)` whose
    /// elements are `CalculatorFloat` values (numbers or strings).
    ///
    /// # Errors
    /// Returns a custom deserializer error when the sequence ends before the real or the
    /// imaginary part is read, and propagates any error from deserializing an element.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        /// Visitor reading the real part first, then the imaginary part.
        struct PartsVisitor;

        impl<'de> Visitor<'de> for PartsVisitor {
            type Value = CalculatorComplex;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("Tuple of two CalculatorFloat values (float or string)")
            }

            fn visit_seq<M>(self, mut access: M) -> Result<Self::Value, M::Error>
            where
                M: SeqAccess<'de>,
            {
                let re: CalculatorFloat = access
                    .next_element()?
                    .ok_or_else(|| M::Error::custom("Missing real part".to_string()))?;
                let im: CalculatorFloat = access
                    .next_element()?
                    .ok_or_else(|| M::Error::custom("Missing imaginary part".to_string()))?;
                Ok(CalculatorComplex::new(re, im))
            }
        }

        deserializer.deserialize_tuple(2, PartsVisitor)
    }
}
impl Default for CalculatorComplex {
fn default() -> Self {
CalculatorComplex {
re: CalculatorFloat::Float(0.0),
im: CalculatorFloat::Float(0.0),
}
}
}
impl<'a> From<&'a CalculatorComplex> for CalculatorComplex {
    /// Produces an owned `CalculatorComplex` by cloning the borrowed one.
    fn from(item: &'a CalculatorComplex) -> Self {
        item.clone()
    }
}
impl<T1, T2> From<(T1, T2)> for CalculatorComplex
where
T1: Into<CalculatorFloat>,
T2: Into<CalculatorFloat>,
{
fn from(input: (T1, T2)) -> Self {
CalculatorComplex {
re: input.0.into(),
im: input.1.into(),
}
}
}
impl<T> From<T> for CalculatorComplex
where
    CalculatorFloat: From<T>,
{
    /// Converts a real value into a `CalculatorComplex` whose imaginary part is the
    /// numeric zero `CalculatorFloat::Float(0.0)`.
    fn from(item: T) -> Self {
        let real = CalculatorFloat::from(item);
        Self {
            re: real,
            im: CalculatorFloat::Float(0.0),
        }
    }
}
impl From<Complex<f64>> for CalculatorComplex {
fn from(item: Complex<f64>) -> Self {
Self {
re: CalculatorFloat::from(item.re),
im: CalculatorFloat::from(item.im),
}
}
}
impl TryFrom<CalculatorComplex> for f64 {
    type Error = CalculatorError;

    /// Converts a `CalculatorComplex` into an `f64` when it is a purely real number.
    ///
    /// # Errors
    /// * `ComplexSymbolicNotConvertable` if the imaginary part is symbolic, or if the
    ///   imaginary part is numeric zero and the real part is symbolic.
    /// * `ComplexCanNotBeConvertedToFloat` if the imaginary part is a non-zero number
    ///   (including NaN). The imaginary part is checked first, so this error is returned
    ///   even when the real part is symbolic.
    fn try_from(value: CalculatorComplex) -> Result<Self, Self::Error> {
        // Exhaustive match (no `_` arm) so a new `CalculatorFloat` variant is a compile error
        // here instead of being silently treated as symbolic.
        match value.im {
            CalculatorFloat::Str(_) => {
                Err(CalculatorError::ComplexSymbolicNotConvertable { val: value })
            }
            CalculatorFloat::Float(im) if im != 0.0 => {
                Err(CalculatorError::ComplexCanNotBeConvertedToFloat { val: value })
            }
            CalculatorFloat::Float(_) => match value.re {
                CalculatorFloat::Float(re) => Ok(re),
                CalculatorFloat::Str(_) => {
                    Err(CalculatorError::ComplexSymbolicNotConvertable { val: value })
                }
            },
        }
    }
}
impl TryFrom<CalculatorComplex> for Complex<f64> {
    type Error = CalculatorError;

    /// Converts a `CalculatorComplex` into a `num_complex::Complex<f64>` when both parts
    /// are numeric.
    ///
    /// # Errors
    /// Returns `ComplexSymbolicNotConvertable` if either the real or the imaginary part
    /// is symbolic.
    fn try_from(value: CalculatorComplex) -> Result<Self, CalculatorError> {
        // Both matches are exhaustive over `CalculatorFloat` (no `_` arm), so adding a
        // variant forces this conversion to be revisited.
        let im = match value.im {
            CalculatorFloat::Float(x) => x,
            CalculatorFloat::Str(_) => {
                return Err(CalculatorError::ComplexSymbolicNotConvertable { val: value })
            }
        };
        let re = match value.re {
            CalculatorFloat::Float(x) => x,
            CalculatorFloat::Str(_) => {
                return Err(CalculatorError::ComplexSymbolicNotConvertable { val: value })
            }
        };
        Ok(Complex::new(re, im))
    }
}
impl fmt::Display for CalculatorComplex {
    /// Formats as `(<re> + i * <im>)`, using the `Display` output of each part.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { re, im } = self;
        write!(f, "({re} + i * {im})")
    }
}
impl CalculatorComplex {
    /// The complex number `0 + 0i`, both parts numeric.
    pub const ZERO: CalculatorComplex = CalculatorComplex {
        re: CalculatorFloat::Float(0.0),
        im: CalculatorFloat::Float(0.0),
    };

    /// The complex number `1 + 0i`, both parts numeric.
    pub const ONE: CalculatorComplex = CalculatorComplex {
        re: CalculatorFloat::Float(1.0),
        im: CalculatorFloat::Float(0.0),
    };

    /// The imaginary unit `0 + 1i`, both parts numeric.
    pub const I: CalculatorComplex = CalculatorComplex {
        re: CalculatorFloat::Float(0.0),
        im: CalculatorFloat::Float(1.0),
    };

    /// Creates a `CalculatorComplex` from a real and an imaginary part.
    ///
    /// Each argument may be anything convertible into a [`CalculatorFloat`], so numeric
    /// and symbolic (string) parts can be mixed.
    ///
    /// # Arguments
    /// * `re` - Real part.
    /// * `im` - Imaginary part.
    pub fn new<T1, T2>(re: T1, im: T2) -> Self
    where
        T1: Into<CalculatorFloat>,
        T2: Into<CalculatorFloat>,
    {
        Self {
            re: re.into(),
            im: im.into(),
        }
    }

    /// Returns the argument (phase angle), computed as `atan2(im, re)`.
    ///
    /// With symbolic parts the result is a symbolic `atan2(...)` expression.
    pub fn arg(&self) -> CalculatorFloat {
        self.im.atan2(&self.re)
    }

    /// Returns the squared norm `re * re + im * im`.
    pub fn norm_sqr(&self) -> CalculatorFloat {
        (self.re.clone() * &self.re) + (self.im.clone() * &self.im)
    }

    /// Returns the norm `sqrt(re * re + im * im)`.
    pub fn norm(&self) -> CalculatorFloat {
        // Delegate so the sum-of-squares expression is defined in exactly one place.
        self.norm_sqr().sqrt()
    }

    /// Returns the absolute value; identical to [`CalculatorComplex::norm`].
    pub fn abs(&self) -> CalculatorFloat {
        self.norm()
    }

    /// Returns the complex conjugate `re - i * im`.
    pub fn conj(&self) -> CalculatorComplex {
        Self {
            re: self.re.clone(),
            im: -self.im.clone(),
        }
    }

    /// Returns `true` when both the real and the imaginary parts are close to the
    /// corresponding parts of `other`, as judged by `CalculatorFloat::isclose`.
    ///
    /// `other` may be any value convertible into a `CalculatorComplex`, e.g. a float,
    /// a `Complex<f64>` or a `(re, im)` tuple.
    pub fn isclose<T>(&self, other: T) -> bool
    where
        T: Into<CalculatorComplex>,
    {
        let other_from: CalculatorComplex = other.into();
        self.re.isclose(other_from.re) && self.im.isclose(other_from.im)
    }
}
impl<T> ops::Add<T> for CalculatorComplex
where
    T: Into<CalculatorComplex>,
{
    type Output = Self;

    /// Adds `other` component-wise after converting it into a `CalculatorComplex`.
    fn add(self, other: T) -> Self {
        let rhs: CalculatorComplex = other.into();
        let CalculatorComplex { re: rhs_re, im: rhs_im } = rhs;
        let CalculatorComplex { re: lhs_re, im: lhs_im } = self;
        CalculatorComplex {
            re: lhs_re + rhs_re,
            im: lhs_im + rhs_im,
        }
    }
}
impl std::iter::Sum for CalculatorComplex {
    /// Adds up all items of the iterator, starting from `CalculatorComplex::new(0, 0)`.
    ///
    /// An empty iterator therefore yields that zero value.
    fn sum<I: Iterator<Item = CalculatorComplex>>(iter: I) -> Self {
        iter.fold(CalculatorComplex::new(0, 0), |mut total, item| {
            total += item;
            total
        })
    }
}
impl<T> ops::AddAssign<T> for CalculatorComplex
where
T: Into<CalculatorComplex>,
{
fn add_assign(&mut self, other: T) {
let other_from: CalculatorComplex = other.into();
*self = CalculatorComplex {
re: &self.re + other_from.re,
im: &self.im + other_from.im,
}
}
}
impl<T> ops::Sub<T> for CalculatorComplex
where
    T: Into<CalculatorComplex>,
{
    type Output = Self;

    /// Subtracts `other` component-wise after converting it into a `CalculatorComplex`.
    fn sub(self, other: T) -> Self {
        let rhs: CalculatorComplex = other.into();
        let CalculatorComplex { re: rhs_re, im: rhs_im } = rhs;
        let CalculatorComplex { re: lhs_re, im: lhs_im } = self;
        CalculatorComplex {
            re: lhs_re - rhs_re,
            im: lhs_im - rhs_im,
        }
    }
}
impl<T> ops::SubAssign<T> for CalculatorComplex
where
T: Into<CalculatorComplex>,
{
fn sub_assign(&mut self, other: T) {
let other_from: CalculatorComplex = other.into();
*self = CalculatorComplex {
re: self.re.clone() - other_from.re,
im: self.im.clone() - other_from.im,
}
}
}
impl ops::Neg for CalculatorComplex {
    type Output = CalculatorComplex;

    /// Negates both the real and the imaginary part.
    fn neg(self) -> Self {
        let CalculatorComplex { re, im } = self;
        CalculatorComplex { re: -re, im: -im }
    }
}
impl<T> ops::Mul<T> for CalculatorComplex
where
    T: Into<CalculatorComplex>,
{
    type Output = Self;

    /// Multiplies by `other` (converted into a `CalculatorComplex`) using
    /// `(a + bi)(c + di) = (ac - bd) + (ad + bc)i`.
    fn mul(self, other: T) -> Self {
        let rhs: CalculatorComplex = other.into();
        let CalculatorComplex { re: a, im: b } = self;
        let real = a.clone() * &rhs.re - (b.clone() * &rhs.im);
        let imaginary = a * &rhs.im + (b * &rhs.re);
        CalculatorComplex {
            re: real,
            im: imaginary,
        }
    }
}
impl<T> ops::MulAssign<T> for CalculatorComplex
where
T: Into<CalculatorComplex>,
{
fn mul_assign(&mut self, other: T) {
let other_from: CalculatorComplex = other.into();
*self = CalculatorComplex {
re: self.re.clone() * &other_from.re - (self.im.clone() * &other_from.im),
im: self.re.clone() * &other_from.im + (self.im.clone() * &other_from.re),
}
}
}
impl<T> ops::Div<T> for CalculatorComplex
where
    T: Into<CalculatorComplex>,
{
    type Output = Self;

    /// Divides by `other` (converted into a `CalculatorComplex`) using
    /// `(a + bi) / (c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)`.
    ///
    /// No zero check is done; the division by the squared norm is left to
    /// `CalculatorFloat`'s `/` operator.
    fn div(self, other: T) -> Self {
        let rhs: CalculatorComplex = other.into();
        let denominator = rhs.norm_sqr();
        let CalculatorComplex { re: a, im: b } = self;
        let real = (a.clone() * &rhs.re + (b.clone() * &rhs.im)) / &denominator;
        let imaginary = (-a * &rhs.im + (b * &rhs.re)) / &denominator;
        CalculatorComplex {
            re: real,
            im: imaginary,
        }
    }
}
impl<T> ops::DivAssign<T> for CalculatorComplex
where
T: Into<CalculatorComplex>,
{
fn div_assign(&mut self, other: T) {
let other_from: CalculatorComplex = other.into();
let norm = other_from.norm_sqr();
*self = CalculatorComplex {
re: (self.re.clone() * &other_from.re + (self.im.clone() * &other_from.im)) / &norm,
im: (-self.re.clone() * &other_from.im + (self.im.clone() * &other_from.re)) / &norm,
}
}
}
impl CalculatorComplex {
    /// Returns the multiplicative inverse `1 / self`, computed as `conj(self) / |self|^2`.
    ///
    /// No zero check is done; the division by the squared norm is left to
    /// `CalculatorFloat`'s `/` operator.
    pub fn recip(&self) -> CalculatorComplex {
        let denominator = self.norm_sqr();
        let real = self.re.clone() / &denominator;
        let imaginary = (-self.im.clone()) / &denominator;
        CalculatorComplex {
            re: real,
            im: imaginary,
        }
    }
}
// Unit tests for CalculatorComplex: construction, conversions, serde round-trips,
// JSON schema, arithmetic operators, norms and formatting.
#[cfg(test)]
mod tests {
use super::CalculatorComplex;
use super::CalculatorFloat;
use num_complex::Complex;
#[cfg(feature = "json_schema")]
use schemars::schema_for;
use serde_test::assert_tokens;
use serde_test::Configure;
use serde_test::Token;
use std::convert::TryFrom;
use std::ops::Neg;
// An integer becomes a numeric real part with a numeric zero imaginary part.
#[test]
fn from_int() {
let x = CalculatorComplex::from(3);
assert_eq!(x.re, CalculatorFloat::from(3));
assert_eq!(x.im, CalculatorFloat::from(0));
assert_eq!(
x,
CalculatorComplex {
re: CalculatorFloat::from(3),
im: CalculatorFloat::from(0)
}
);
assert_eq!(f64::try_from(x).unwrap(), 3.0)
}
// assert_tokens checks both serialization and deserialization against the token stream:
// a CalculatorComplex is a two-element tuple of numeric or string parts.
#[test]
fn serde_readable() {
let complex_float = CalculatorComplex::new(0.1, 0.3);
assert_tokens(
&complex_float.readable(),
&[
Token::Tuple { len: 2 },
Token::F64(0.1),
Token::F64(0.3),
Token::TupleEnd,
],
);
let complex_str = CalculatorComplex::new("a", "b");
assert_tokens(
&complex_str.readable(),
&[
Token::Tuple { len: 2 },
Token::Str("a"),
Token::Str("b"),
Token::TupleEnd,
],
);
let complex_mixed = CalculatorComplex::new("a", -0.3);
assert_tokens(
&complex_mixed.readable(),
&[
Token::Tuple { len: 2 },
Token::Str("a"),
Token::F64(-0.3),
Token::TupleEnd,
],
);
// num_complex::Complex<f64> is expected to use the same tuple token layout.
let complex_num_complex = Complex::<f64>::new(0.0, -3.0);
assert_tokens(
&complex_num_complex.readable(),
&[
Token::Tuple { len: 2 },
Token::F64(0.0),
Token::F64(-3.0),
Token::TupleEnd,
],
);
}
// The schema must describe a fixed-length array of two CalculatorFloat entries.
#[cfg(feature = "json_schema")]
#[test]
fn test_json_schema_support() {
let schema = schema_for!(CalculatorComplex);
let serialized = serde_json::to_string(&schema).unwrap();
assert_eq!(serialized.as_str(), "{\"$schema\":\"http://json-schema.org/draft-07/schema#\",\"title\":\"CalculatorComplex\",\"type\":\"array\",\"items\":[{\"$ref\":\"#/definitions/CalculatorFloat\"},{\"$ref\":\"#/definitions/CalculatorFloat\"}],\"maxItems\":2,\"minItems\":2,\"definitions\":{\"CalculatorFloat\":{\"oneOf\":[{\"type\":\"number\",\"format\":\"double\"},{\"type\":\"string\"}]}}}");
}
#[test]
fn from_float() {
let x = CalculatorComplex::from(3.1);
assert_eq!(x.re, CalculatorFloat::from(3.1));
assert_eq!(x.im, CalculatorFloat::from(0));
assert_eq!(
x,
CalculatorComplex {
re: CalculatorFloat::from(3.1),
im: CalculatorFloat::from(0)
}
);
}
// A string becomes a symbolic real part; the imaginary part stays numeric zero.
#[test]
fn from_str() {
let x = CalculatorComplex::from("3.1");
assert_eq!(x.re, CalculatorFloat::from("3.1"));
assert_eq!(x.im, CalculatorFloat::from(0));
assert_eq!(
x,
CalculatorComplex {
re: CalculatorFloat::from("3.1"),
im: CalculatorFloat::from(0)
}
);
}
#[test]
fn from_complex() {
let x = CalculatorComplex::from(Complex::new(1.0, 2.0));
assert_eq!(x.re, CalculatorFloat::from(1.0));
assert_eq!(x.im, CalculatorFloat::from(2.00));
assert_eq!(
x,
CalculatorComplex {
re: CalculatorFloat::from(1.0),
im: CalculatorFloat::from(2.0)
}
);
}
#[test]
fn from_calculator_complex() {
let x = CalculatorComplex::new(1, 1);
assert_eq!(CalculatorComplex::from(&x), x);
}
#[test]
fn default() {
let x = CalculatorComplex::default();
assert_eq!(x.re, CalculatorFloat::from(0.0));
assert_eq!(x.im, CalculatorFloat::from(0.0));
assert_eq!(x, CalculatorComplex::new(0, 0));
}
// Conversion to f64 succeeds only for numeric values with a zero imaginary part.
#[test]
fn try_from_float() {
let x = CalculatorComplex::new(1.0, 0.0);
assert_eq!(<f64>::try_from(x).unwrap(), 1.0);
let x = CalculatorComplex::new(0.0, 1.0);
assert!(f64::try_from(x).is_err());
let x = CalculatorComplex::new("x", 0.0);
assert!(f64::try_from(x).is_err());
let x = CalculatorComplex::new(1.0, "x");
assert!(f64::try_from(x).is_err());
}
// Conversion to Complex<f64> fails if either part is symbolic.
#[test]
fn try_from_complex() {
let x = CalculatorComplex::new(1, 1);
assert_eq!(Complex::<f64>::try_from(x).unwrap(), Complex::new(1.0, 1.0));
let x = CalculatorComplex::new(0.0, "x");
assert!(Complex::<f64>::try_from(x).is_err());
let x = CalculatorComplex::new("x", 0.0);
assert!(Complex::<f64>::try_from(x).is_err());
}
// Each part is rendered with CalculatorFloat's Display (here "-3e0" and "2e0").
#[test]
fn display() {
let x = CalculatorComplex::new(-3, 2);
let x_formatted = format!("{x}");
assert_eq!(x_formatted, "(-3e0 + i * 2e0)");
}
// Mixing numeric and symbolic parts yields a symbolic expression string.
#[test]
fn try_add() {
let x = CalculatorComplex::new(1, 1);
let y = CalculatorComplex::new(2, "test");
assert_eq!(x + y, CalculatorComplex::new(3.0, "(1e0 + test)"));
}
#[test]
fn try_iadd() {
let mut x = CalculatorComplex::new(1, 1);
let y = CalculatorComplex::new(2, "test");
x += y;
assert_eq!(x, CalculatorComplex::new(3.0, "(1e0 + test)"));
}
#[test]
fn try_sub() {
let x = CalculatorComplex::new(1, 1);
let y = CalculatorComplex::new(2, "test");
assert_eq!(x - y, CalculatorComplex::new(-1.0, "(1e0 - test)"));
}
#[test]
fn try_isub() {
let mut x = CalculatorComplex::new(1, 1);
let y = CalculatorComplex::new(2, "test");
x -= y;
assert_eq!(x, CalculatorComplex::new(-1.0, "(1e0 - test)"));
}
// (1 + i) * (2 + 2i) = 0 + 4i
#[test]
fn try_mul() {
let x = CalculatorComplex::new(1, 1);
let y = CalculatorComplex::new(2, 2);
assert_eq!(x * y, CalculatorComplex::new(0.0, 4.0));
}
#[test]
fn try_imul() {
let mut x = CalculatorComplex::new(1, 1);
let y = CalculatorComplex::new(2, 2);
x *= y;
assert_eq!(x, CalculatorComplex::new(0.0, 4.0));
}
// (1 + i) / (3 + 4i) = (7 - i) / 25
#[test]
fn try_div() {
let x = CalculatorComplex::new(1, 1);
let y = CalculatorComplex::new(3, 4);
assert_eq!(x / y, CalculatorComplex::new(7.0 / 25.0, -1.0 / 25.0));
}
#[test]
fn try_idiv() {
let mut x = CalculatorComplex::new(1, 1);
let y = CalculatorComplex::new(3, 4);
x /= y;
assert_eq!(x, CalculatorComplex::new(7.0 / 25.0, -1.0 / 25.0));
}
// Numeric results must match num_complex; symbolic parts produce "atan2(im, re)" strings.
#[test]
fn arg() {
let x = CalculatorComplex::new(1, 2);
let y = Complex::new(1.0, 2.0);
assert_eq!(x.arg(), CalculatorFloat::from(y.arg()));
let x = CalculatorComplex::new("x", 2);
assert_eq!(x.arg(), CalculatorFloat::from("atan2(2e0, x)"));
let x = CalculatorComplex::new(1, "2x");
assert_eq!(x.arg(), CalculatorFloat::from("atan2(2x, 1e0)"));
let x = CalculatorComplex::new("x", "2t");
assert_eq!(x.arg(), CalculatorFloat::from("atan2(2t, x)"));
}
#[test]
fn norm_sqr() {
let x = CalculatorComplex::new(1, 2);
let y = Complex::new(1.0, 2.0);
assert_eq!(x.norm_sqr(), CalculatorFloat::from(y.norm_sqr()));
}
#[test]
fn norm() {
let x = CalculatorComplex::new(1, 2);
let y = Complex::new(1.0, 2.0);
assert_eq!(x.norm(), CalculatorFloat::from(y.norm()));
}
#[test]
fn abs() {
let x = CalculatorComplex::new(1, 2);
let y = Complex::new(1.0, 2.0);
assert_eq!(x.abs(), CalculatorFloat::from(y.norm()));
}
#[test]
fn conj() {
let x = CalculatorComplex::new(1, 2);
let y = Complex::new(1.0, 2.0);
assert_eq!(x.conj(), CalculatorComplex::new(y.conj().re, y.conj().im));
}
// 1 + 2i is not close to the real number 1.0 because the imaginary parts differ.
#[test]
fn is_close() {
let x = CalculatorComplex::new(1, 2);
let y = Complex::new(1.0, 2.0);
assert!(x.isclose(y));
let y = 1.0;
assert!(!x.isclose(y));
}
#[test]
fn neg() {
let x = CalculatorComplex::new(1, 2);
assert_eq!(x.neg(), CalculatorComplex::new(-1, -2));
}
// 1 / (3 + 4i) = (3 - 4i) / 25 = 0.12 - 0.16i
#[test]
fn inv() {
let x = CalculatorComplex::new(3, 4);
assert_eq!(x.recip(), CalculatorComplex::new(0.12, -0.16));
}
// Pins the derived Debug layout (field order re, im).
#[test]
fn debug() {
let x = CalculatorComplex::from(3.0);
assert_eq!(
format!("{x:?}"),
"CalculatorComplex { re: Float(3.0), im: Float(0.0) }"
);
let xs = CalculatorComplex::from("3x");
assert_eq!(
format!("{xs:?}"),
"CalculatorComplex { re: Str(\"3x\"), im: Float(0.0) }"
);
}
#[test]
fn clone_trait() {
let x = CalculatorComplex::from(3.0);
assert_eq!(x.clone(), x);
let xs = CalculatorComplex::from("3x");
assert_eq!(xs.clone(), xs);
}
#[test]
fn partial_eq() {
let x1 = CalculatorComplex::from(3.0);
let x2 = CalculatorComplex::from(3.0);
assert!(x1 == x2);
assert!(x2 == x1);
let x1s = CalculatorComplex::from("3x");
let x2s = CalculatorComplex::from("3x");
assert!(x1s == x2s);
assert!(x2s == x1s);
}
}