#![no_std]
#![forbid(unsafe_code)]
#![warn(missing_docs)]
use core::fmt;
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use embedded_f32_sqrt::sqrt;
/// Errors reported by the fallible operations on [`Complex`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ComplexError {
    /// The divisor (or the value being inverted) has both components equal to zero.
    DivisionByZero,
    /// The underlying square-root routine rejected its argument.
    NegativeInput,
    /// An input component was NaN, so no meaningful result exists.
    Undefined,
}
impl fmt::Display for ComplexError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ComplexError::DivisionByZero => write!(f, "ComplexError: division by zero"),
ComplexError::NegativeInput => write!(f, "ComplexError: negative input"),
ComplexError::Undefined => write!(f, "ComplexError: undefined (NaN)"),
}
}
}
/// A complex number `re + im·i` stored as two `f32` Cartesian components.
///
/// Equality is the derived component-wise `f32` comparison, so a value with a
/// NaN component never compares equal to itself.
#[derive(Copy, Clone, PartialEq)]
pub struct Complex {
    // Real part.
    re: f32,
    // Imaginary part.
    im: f32,
}
impl Complex {
    /// The additive identity `0 + 0i`.
    pub const ZERO: Self = Self { re: 0.0, im: 0.0 };
    /// The multiplicative identity `1 + 0i`.
    pub const ONE: Self = Self { re: 1.0, im: 0.0 };
    /// The imaginary unit `i` (`0 + 1i`).
    pub const I: Self = Self { re: 0.0, im: 1.0 };

    /// Builds `re + im·i` from its real and imaginary parts.
    #[inline]
    pub const fn new(re: f32, im: f32) -> Self {
        Self { re, im }
    }

    /// Returns the real part.
    #[inline]
    pub const fn re(self) -> f32 {
        self.re
    }

    /// Returns the imaginary part.
    #[inline]
    pub const fn im(self) -> f32 {
        self.im
    }

    /// Returns `true` if at least one component is NaN.
    #[inline]
    pub fn is_nan(self) -> bool {
        let Self { re, im } = self;
        re.is_nan() || im.is_nan()
    }

    /// Returns `true` if at least one component is positive or negative infinity.
    #[inline]
    pub fn is_infinite(self) -> bool {
        let Self { re, im } = self;
        re.is_infinite() || im.is_infinite()
    }

    /// Returns `true` only when both components are finite (neither NaN nor infinite).
    #[inline]
    pub fn is_finite(self) -> bool {
        let Self { re, im } = self;
        re.is_finite() && im.is_finite()
    }
}
impl Complex {
    /// `|x|` for an `f32`, computed by clearing the IEEE 754 sign bit.
    ///
    /// Done with bit operations so it only needs `core` (this crate is
    /// `#![no_std]`); the result is exact and NaN stays NaN.
    #[inline]
    fn abs_f32(x: f32) -> f32 {
        f32::from_bits(x.to_bits() & 0x7fff_ffff)
    }

    /// Returns the complex conjugate `re - im·i`.
    #[inline]
    pub fn conj(self) -> Self {
        Self::new(self.re, -self.im)
    }

    /// Returns the modulus (absolute value) `|z| = sqrt(re² + im²)`.
    ///
    /// When `re² + im²` is a finite, normal `f32` the textbook formula is used
    /// directly. Otherwise the squares have overflowed to infinity or
    /// underflowed to zero, so the components are first scaled by the larger
    /// magnitude. That keeps the result accurate for any finite components.
    ///
    /// Returns NaN if either component is NaN. Returns `+∞` if either
    /// component is infinite (and neither is NaN). Also returns NaN if the
    /// square-root routine reports an error.
    pub fn norm(self) -> f32 {
        if self.is_nan() {
            return f32::NAN;
        }
        let sum_sq = self.norm_sq();
        if sum_sq.is_finite() && sum_sq >= f32::MIN_POSITIVE {
            // Fast path: no overflow or underflow happened.
            return sqrt(sum_sq).unwrap_or(f32::NAN);
        }
        let a = Self::abs_f32(self.re);
        let b = Self::abs_f32(self.im);
        if a.is_infinite() || b.is_infinite() {
            return f32::INFINITY;
        }
        let big = a.max(b);
        if big == 0.0 {
            return 0.0;
        }
        // ratio lies in [0, 1], so 1 + ratio² lies in [1, 2]: no overflow/underflow.
        let ratio = a.min(b) / big;
        big * sqrt(1.0 + ratio * ratio).unwrap_or(f32::NAN)
    }

    /// Returns the squared modulus `re² + im²` without taking a square root.
    ///
    /// May overflow to `+∞` or underflow to `0` for extreme components; use
    /// [`Complex::norm`] when the modulus itself is needed.
    #[inline]
    pub fn norm_sq(self) -> f32 {
        self.re * self.re + self.im * self.im
    }

    /// Divides `self` by `rhs`, reporting a zero divisor as an error.
    ///
    /// Uses Smith's algorithm: the operands are scaled by the ratio of the
    /// divisor's components instead of forming `|rhs|²`. The quotient
    /// therefore stays accurate (and a tiny non-zero divisor is not mistaken
    /// for zero) when `|rhs|²` would overflow or underflow `f32`.
    ///
    /// A NaN divisor is not an error; the result then has NaN components.
    ///
    /// # Errors
    ///
    /// Returns [`ComplexError::DivisionByZero`] if both components of `rhs`
    /// are zero (of either sign).
    pub fn checked_div(self, rhs: Self) -> Result<Self, ComplexError> {
        if rhs.re == 0.0 && rhs.im == 0.0 {
            return Err(ComplexError::DivisionByZero);
        }
        let (a, b) = (self.re, self.im);
        let (c, d) = (rhs.re, rhs.im);
        if Self::abs_f32(c) >= Self::abs_f32(d) {
            // Divide numerator and denominator by c.
            let ratio = d / c;
            let denom = c + d * ratio;
            Ok(Self::new((a + b * ratio) / denom, (b - a * ratio) / denom))
        } else {
            // Divide numerator and denominator by d.
            let ratio = c / d;
            let denom = c * ratio + d;
            Ok(Self::new((a * ratio + b) / denom, (b * ratio - a) / denom))
        }
    }

    /// Returns the reciprocal `1 / self`.
    ///
    /// # Errors
    ///
    /// Returns [`ComplexError::DivisionByZero`] if `self` is zero.
    pub fn inv(self) -> Result<Self, ComplexError> {
        Self::ONE.checked_div(self)
    }

    /// Returns the principal square root of `self`, i.e. the root whose real part is non-negative.
    ///
    /// With `t = sqrt((|re| + |z|) / 2)`:
    /// * for `re >= 0` the result is `t + (im / 2t)·i`;
    /// * otherwise it is `|im| / 2t ± t·i`, with the sign of `im`.
    ///
    /// Unlike half-angle formulas, this never subtracts nearly equal
    /// quantities, so the small component survives for inputs close to the
    /// real axis. Inputs whose components are both subnormal, or close to
    /// `f32::MAX`, are rescaled by exact powers of two first.
    ///
    /// For negative `re`, the imaginary result is negative only when
    /// `im < 0.0`, so `-0.0` is treated like `+0.0` (`-4 - 0i` gives `2i`).
    ///
    /// Special cases (checked in this order):
    /// * any NaN component gives an error (see below);
    /// * zero (either sign) gives [`Complex::ZERO`];
    /// * an infinite `im` gives `+∞ + im·i`;
    /// * `+∞ + y·i` gives `+∞ ± 0·i`, and `-∞ + y·i` gives `0 ± ∞·i`, with the sign from `y < 0`.
    ///
    /// # Errors
    ///
    /// * [`ComplexError::Undefined`] if either component is NaN.
    /// * [`ComplexError::NegativeInput`] if the square-root routine rejects its
    ///   argument. This is not expected, because only positive values are passed.
    pub fn csqrt(self) -> Result<Self, ComplexError> {
        if self.is_nan() {
            return Err(ComplexError::Undefined);
        }
        let (re, im) = (self.re, self.im);
        if re == 0.0 && im == 0.0 {
            return Ok(Self::ZERO);
        }
        if im.is_infinite() {
            return Ok(Self::new(f32::INFINITY, im));
        }
        if re.is_infinite() {
            let root = if re > 0.0 {
                Self::new(f32::INFINITY, if im < 0.0 { -0.0 } else { 0.0 })
            } else {
                Self::new(0.0, if im < 0.0 { f32::NEG_INFINITY } else { f32::INFINITY })
            };
            return Ok(root);
        }

        let ax = Self::abs_f32(re);
        let ay = Self::abs_f32(im);
        if ax < f32::MIN_POSITIVE && ay < f32::MIN_POSITIVE {
            // Both components are subnormal. Scale up by the even power of two
            // 2^64 so the intermediates are normal, then undo it with
            // 2^-32 = sqrt(2^-64). Power-of-two scaling here is exact.
            const TWO_POW_32: f32 = 4_294_967_296.0;
            let up = TWO_POW_32 * TWO_POW_32;
            let w = Self::new(re * up, im * up).csqrt()?;
            return Ok(Self::new(w.re / TWO_POW_32, w.im / TWO_POW_32));
        }
        if ax > f32::MAX / 4.0 || ay > f32::MAX / 4.0 {
            // Near the top of the range |z| (and |re| + |z|) could overflow.
            // Scale down by 4 (exact), take the root, then scale back up by 2.
            let w = Self::new(re * 0.25, im * 0.25).csqrt()?;
            return Ok(Self::new(w.re * 2.0, w.im * 2.0));
        }

        // Both components are now at most f32::MAX / 4, so |z| and |re| + |z| stay finite.
        let modulus = self.norm();
        let t = sqrt((ax + modulus) * 0.5).map_err(|_| ComplexError::NegativeInput)?;
        if re >= 0.0 {
            Ok(Self::new(t, im / (2.0 * t)))
        } else {
            let im_root = if im < 0.0 { -t } else { t };
            Ok(Self::new(ay / (2.0 * t), im_root))
        }
    }

    /// Raises `self` to the integer power `n` using binary exponentiation.
    ///
    /// A negative `n` first inverts `self`, then raises the reciprocal to `|n|`
    /// (`i32::MIN` is handled via `unsigned_abs`). `powi(0)` is
    /// [`Complex::ONE`] for every input, including zero.
    ///
    /// # Errors
    ///
    /// Returns [`ComplexError::DivisionByZero`] if `n < 0` and `self` is zero.
    pub fn powi(self, n: i32) -> Result<Self, ComplexError> {
        let mut base = if n < 0 { self.inv()? } else { self };
        let mut exp = n.unsigned_abs();
        let mut result = Self::ONE;
        while exp > 0 {
            if exp & 1 == 1 {
                result = result * base;
            }
            exp >>= 1;
            // Skip squaring after the last bit: that value would be discarded.
            if exp > 0 {
                base = base * base;
            }
        }
        Ok(result)
    }
}
impl Add for Complex {
    type Output = Self;

    /// Component-wise sum `(a + bi) + (c + di) = (a + c) + (b + d)i`.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        Self {
            re: self.re + rhs.re,
            im: self.im + rhs.im,
        }
    }
}
impl AddAssign for Complex {
#[inline]
fn add_assign(&mut self, rhs: Self) { *self = *self + rhs; }
}
impl Sub for Complex {
    type Output = Self;

    /// Component-wise difference `(a + bi) - (c + di) = (a - c) + (b - d)i`.
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        Self {
            re: self.re - rhs.re,
            im: self.im - rhs.im,
        }
    }
}
impl SubAssign for Complex {
#[inline]
fn sub_assign(&mut self, rhs: Self) { *self = *self - rhs; }
}
impl Mul for Complex {
    type Output = Self;

    /// Complex product `(a + bi)(c + di) = (ac - bd) + (ad + bc)i`.
    #[inline]
    fn mul(self, rhs: Self) -> Self {
        let (a, b) = (self.re, self.im);
        let (c, d) = (rhs.re, rhs.im);
        Self {
            re: a * c - b * d,
            im: a * d + b * c,
        }
    }
}
impl MulAssign for Complex {
#[inline]
fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; }
}
impl Div for Complex {
    type Output = Self;

    /// Complex quotient via [`Complex::checked_div`].
    ///
    /// A zero divisor yields `NaN + NaN·i` instead of an error.
    fn div(self, rhs: Self) -> Self {
        match self.checked_div(rhs) {
            Ok(quotient) => quotient,
            Err(_) => Self::new(f32::NAN, f32::NAN),
        }
    }
}
impl DivAssign for Complex {
#[inline]
fn div_assign(&mut self, rhs: Self) { *self = *self / rhs; }
}
impl Neg for Complex {
    type Output = Self;

    /// Negates both components: `-(a + bi) = -a - bi`.
    #[inline]
    fn neg(self) -> Self {
        Self {
            re: -self.re,
            im: -self.im,
        }
    }
}
impl Add<f32> for Complex {
    type Output = Self;

    /// Adds a real scalar to the real part; the imaginary part is unchanged.
    #[inline]
    fn add(self, rhs: f32) -> Self {
        Self {
            re: self.re + rhs,
            im: self.im,
        }
    }
}
impl Sub<f32> for Complex {
    type Output = Self;

    /// Subtracts a real scalar from the real part; the imaginary part is unchanged.
    #[inline]
    fn sub(self, rhs: f32) -> Self {
        Self {
            re: self.re - rhs,
            im: self.im,
        }
    }
}
impl Mul<f32> for Complex {
    type Output = Self;

    /// Scales both components by a real factor.
    #[inline]
    fn mul(self, rhs: f32) -> Self {
        Self {
            re: self.re * rhs,
            im: self.im * rhs,
        }
    }
}
impl Div<f32> for Complex {
    type Output = Self;

    /// Divides both components by a real divisor.
    ///
    /// Follows plain `f32` semantics, so a zero divisor yields infinities or NaN.
    #[inline]
    fn div(self, rhs: f32) -> Self {
        Self {
            re: self.re / rhs,
            im: self.im / rhs,
        }
    }
}
impl fmt::Display for Complex {
    /// Formats as `"{re} + {im}i"`, or as `"{re} - {|im|}i"` when `im` is negative.
    ///
    /// Any imaginary part with the sign bit set uses the ` - ` form,
    /// including `-0.0` (printed `"1 - 0i"` rather than `"1 + -0i"`). A NaN
    /// imaginary part always uses ` + `. A precision given to the formatter
    /// (e.g. `{:.3}`) is applied to both components.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (sep, im) = if self.im.is_sign_negative() && !self.im.is_nan() {
            ("-", -self.im)
        } else {
            ("+", self.im)
        };
        match f.precision() {
            Some(p) => write!(f, "{:.*} {} {:.*}i", p, self.re, sep, p, im),
            None => write!(f, "{} {} {}i", self.re, sep, im),
        }
    }
}
impl fmt::Debug for Complex {
    // Struct-like layout, but each component is printed with its `Display`
    // form (e.g. `1` rather than `1.0`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Complex {{ re: {re}, im: {im} }}",
            re = self.re,
            im = self.im
        )
    }
}
impl From<f32> for Complex {
    /// Embeds a real number as `x + 0i`.
    #[inline]
    fn from(x: f32) -> Self {
        Self { re: x, im: 0.0 }
    }
}
impl From<(f32, f32)> for Complex {
    /// Builds a complex number from a `(re, im)` pair.
    #[inline]
    fn from(pair: (f32, f32)) -> Self {
        let (re, im) = pair;
        Self { re, im }
    }
}
impl From<Complex> for (f32, f32) {
#[inline]
fn from(z: Complex) -> Self { (z.re, z.im) }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Absolute tolerance used by the approximate comparisons below.
    const TOL: f32 = 1e-4;

    // True when `x` and `y` differ by less than TOL.
    fn near(x: f32, y: f32) -> bool {
        (x - y).abs() < TOL
    }

    // Component-wise `near` for complex values.
    fn near_z(x: Complex, y: Complex) -> bool {
        near(x.re, y.re) && near(x.im, y.im)
    }

    #[test]
    fn constants() {
        let expected = [
            (Complex::ZERO, Complex::new(0.0, 0.0)),
            (Complex::ONE, Complex::new(1.0, 0.0)),
            (Complex::I, Complex::new(0.0, 1.0)),
        ];
        for &(got, want) in expected.iter() {
            assert_eq!(got, want);
        }
    }

    #[test]
    fn addition() {
        let sum = Complex::new(1.0, 2.0) + Complex::new(3.0, -1.0);
        assert_eq!(sum, Complex::new(4.0, 1.0));
    }

    #[test]
    fn subtraction() {
        let diff = Complex::new(5.0, 3.0) - Complex::new(2.0, 1.0);
        assert_eq!(diff, Complex::new(3.0, 2.0));
    }

    #[test]
    fn multiplication() {
        let product = Complex::new(1.0, 1.0) * Complex::new(1.0, -1.0);
        assert!(near_z(product, Complex::new(2.0, 0.0)));
    }

    #[test]
    fn i_squared_is_minus_one() {
        assert!(near_z(Complex::I * Complex::I, Complex::new(-1.0, 0.0)));
    }

    #[test]
    fn division() {
        let quotient = Complex::new(4.0, 2.0) / Complex::new(1.0, 1.0);
        assert!(near_z(quotient, Complex::new(3.0, -1.0)));
    }

    #[test]
    fn division_by_zero_returns_nan() {
        assert!((Complex::ONE / Complex::ZERO).is_nan());
    }

    #[test]
    fn checked_div_by_zero_returns_err() {
        let outcome = Complex::ONE.checked_div(Complex::ZERO);
        assert_eq!(outcome, Err(ComplexError::DivisionByZero));
    }

    #[test]
    fn negation() {
        let z = Complex::new(1.0, -2.0);
        assert_eq!(-z, Complex::new(-1.0, 2.0));
    }

    #[test]
    fn scalar_ops() {
        let z = Complex::new(3.0, 4.0);
        let cases = [
            (z * 2.0, Complex::new(6.0, 8.0)),
            (z / 2.0, Complex::new(1.5, 2.0)),
            (z + 1.0, Complex::new(4.0, 4.0)),
            (z - 1.0, Complex::new(2.0, 4.0)),
        ];
        for &(got, want) in cases.iter() {
            assert_eq!(got, want);
        }
    }

    #[test]
    fn assign_ops() {
        let mut z = Complex::new(1.0, 2.0);
        z += Complex::new(0.5, 0.5);
        assert!(near(z.re, 1.5));
        z *= Complex::new(2.0, 0.0);
        assert!(near(z.re, 3.0));
    }

    #[test]
    fn norm_pythagorean() {
        for &(re, im, hyp) in [(3.0, 4.0, 5.0), (5.0, 12.0, 13.0)].iter() {
            assert!(near(Complex::new(re, im).norm(), hyp));
        }
    }

    #[test]
    fn norm_sq() {
        assert!(near(Complex::new(3.0, 4.0).norm_sq(), 25.0));
    }

    #[test]
    fn conjugate() {
        let z = Complex::new(3.0, -4.0);
        let zc = z.conj();
        assert_eq!(zc.im(), 4.0);
        let product = z * zc;
        assert!(near(product.im, 0.0));
        assert!(product.re > 0.0);
    }

    #[test]
    fn csqrt_real_positive() {
        let root = Complex::new(9.0, 0.0).csqrt().unwrap();
        assert!(near_z(root, Complex::new(3.0, 0.0)));
    }

    #[test]
    fn csqrt_minus_one_gives_i() {
        let root = Complex::new(-1.0, 0.0).csqrt().unwrap();
        assert!(near(root.norm(), 1.0));
    }

    #[test]
    fn csqrt_general() {
        let z = Complex::new(3.0, 4.0);
        let root = z.csqrt().unwrap();
        assert!(near_z(root * root, z));
    }

    #[test]
    fn powi_zero_exp() {
        let one = Complex::new(5.0, 3.0).powi(0).unwrap();
        assert!(near_z(one, Complex::ONE));
    }

    #[test]
    fn powi_i4_is_one() {
        assert!(near_z(Complex::I.powi(4).unwrap(), Complex::ONE));
    }

    #[test]
    fn powi_negative_exp() {
        let half = Complex::new(2.0, 0.0).powi(-1).unwrap();
        assert!(near(half.re, 0.5));
    }

    #[test]
    fn inv_real() {
        let quarter = Complex::new(4.0, 0.0).inv().unwrap();
        assert!(near(quarter.re, 0.25));
    }

    #[test]
    fn inv_zero_returns_err() {
        assert_eq!(Complex::ZERO.inv(), Err(ComplexError::DivisionByZero));
    }

    #[test]
    fn from_f32() {
        let z: Complex = 3.0f32.into();
        assert_eq!(z, Complex::new(3.0, 0.0));
    }

    #[test]
    fn from_tuple() {
        assert_eq!(Complex::from((2.0f32, -1.0f32)), Complex::new(2.0, -1.0));
    }

    #[test]
    fn into_tuple() {
        let parts: (f32, f32) = Complex::new(7.0, -3.0).into();
        assert_eq!(parts, (7.0, -3.0));
    }

    #[test]
    fn nan_propagation() {
        let z = Complex::new(f32::NAN, 0.0);
        assert!(z.is_nan());
        assert!((z + Complex::ONE).is_nan());
    }

    #[test]
    fn finite_check() {
        assert!(Complex::new(1.0, 2.0).is_finite());
        assert!(!Complex::new(f32::INFINITY, 0.0).is_finite());
    }
}