1#![no_std]
47#![forbid(unsafe_code)]
48#![warn(missing_docs)]
49
50use core::fmt;
51use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
52
53use embedded_f32_sqrt::sqrt;
54
/// Errors reported by the fallible operations on [`Complex`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComplexError {
    /// The divisor (or the value being inverted) was zero.
    ///
    /// Returned by [`Complex::checked_div`], [`Complex::inv`] and, for negative
    /// exponents, [`Complex::powi`].
    DivisionByZero,
    /// The real square-root routine rejected its argument
    /// (reported by [`Complex::csqrt`]).
    NegativeInput,
    /// An input had a NaN component, so no meaningful result exists
    /// (reported by [`Complex::csqrt`]).
    Undefined,
}

impl fmt::Display for ComplexError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::DivisionByZero => "ComplexError: division by zero",
            Self::NegativeInput => "ComplexError: negative input",
            Self::Undefined => "ComplexError: undefined (NaN)",
        };
        f.write_str(message)
    }
}
77
/// A complex number `re + im·i` with `f32` components.
///
/// The fields are private; read them with [`Complex::re`] and [`Complex::im`].
#[derive(Clone, Copy, PartialEq)]
pub struct Complex {
    re: f32,
    im: f32,
}

impl Complex {
    /// The additive identity, `0 + 0i`.
    pub const ZERO: Self = Self { re: 0.0, im: 0.0 };

    /// The multiplicative identity, `1 + 0i`.
    pub const ONE: Self = Self { re: 1.0, im: 0.0 };

    /// The imaginary unit, `0 + 1i`.
    pub const I: Self = Self { re: 0.0, im: 1.0 };

    /// Builds `re + im·i` from its two components.
    #[inline]
    pub const fn new(re: f32, im: f32) -> Self {
        Self { re, im }
    }

    /// Returns the real component.
    #[inline]
    pub const fn re(self) -> f32 {
        self.re
    }

    /// Returns the imaginary component.
    #[inline]
    pub const fn im(self) -> f32 {
        self.im
    }

    /// Returns `true` if either component is NaN.
    #[inline]
    pub fn is_nan(self) -> bool {
        [self.re, self.im].iter().any(|part| part.is_nan())
    }

    /// Returns `true` if either component is positive or negative infinity.
    #[inline]
    pub fn is_infinite(self) -> bool {
        [self.re, self.im].iter().any(|part| part.is_infinite())
    }

    /// Returns `true` if both components are finite (neither infinite nor NaN).
    #[inline]
    pub fn is_finite(self) -> bool {
        [self.re, self.im].iter().all(|part| part.is_finite())
    }
}
144
145
146impl Complex {
149 #[inline]
157 pub fn conj(self) -> Self {
158 Self::new(self.re, -self.im)
159 }
160
161 pub fn norm(self) -> f32 {
171 sqrt(self.re * self.re + self.im * self.im).unwrap_or(f32::NAN)
172 }
173
174 #[inline]
182 pub fn norm_sq(self) -> f32 {
183 self.re * self.re + self.im * self.im
184 }
185
186 pub fn checked_div(self, rhs: Self) -> Result<Self, ComplexError> {
194 let denom = rhs.norm_sq();
195 if denom == 0.0 {
196 return Err(ComplexError::DivisionByZero);
197 }
198 Ok(Self::new(
199 (self.re * rhs.re + self.im * rhs.im) / denom,
200 (self.im * rhs.re - self.re * rhs.im) / denom,
201 ))
202 }
203
204 pub fn inv(self) -> Result<Self, ComplexError> {
214 Self::ONE.checked_div(self)
215 }
216
217 pub fn csqrt(self) -> Result<Self, ComplexError> {
231 if self.is_nan() {
232 return Err(ComplexError::Undefined);
233 }
234 let r = self.norm();
235 let sqrt_r = sqrt(r).map_err(|_| ComplexError::NegativeInput)?;
236
237 if sqrt_r == 0.0 {
238 return Ok(Self::ZERO);
239 }
240
241 let cos_theta = self.re / r;
243 let half_cos = sqrt(((1.0 + cos_theta) / 2.0).max(0.0))
244 .map_err(|_| ComplexError::NegativeInput)?;
245 let half_sin_abs = sqrt(((1.0 - cos_theta) / 2.0).max(0.0))
246 .map_err(|_| ComplexError::NegativeInput)?;
247 let half_sin = if self.im < 0.0 { -half_sin_abs } else { half_sin_abs };
248
249 Ok(Self::new(sqrt_r * half_cos, sqrt_r * half_sin))
250 }
251
252 pub fn powi(self, n: i32) -> Result<Self, ComplexError> {
263 let base = if n < 0 { self.inv()? } else { self };
264 let mut exp = n.unsigned_abs();
265 let mut result = Self::ONE;
266 let mut b = base;
267 while exp > 0 {
268 if exp & 1 == 1 { result = result * b; }
269 b = b * b;
270 exp >>= 1;
271 }
272 Ok(result)
273 }
274}
275
276impl Add for Complex {
279 type Output = Self;
280 #[inline]
281 fn add(self, rhs: Self) -> Self { Self::new(self.re + rhs.re, self.im + rhs.im) }
282}
283impl AddAssign for Complex {
284 #[inline]
285 fn add_assign(&mut self, rhs: Self) { *self = *self + rhs; }
286}
287
288impl Sub for Complex {
289 type Output = Self;
290 #[inline]
291 fn sub(self, rhs: Self) -> Self { Self::new(self.re - rhs.re, self.im - rhs.im) }
292}
293impl SubAssign for Complex {
294 #[inline]
295 fn sub_assign(&mut self, rhs: Self) { *self = *self - rhs; }
296}
297
298impl Mul for Complex {
300 type Output = Self;
301 #[inline]
302 fn mul(self, rhs: Self) -> Self {
303 Self::new(
304 self.re * rhs.re - self.im * rhs.im,
305 self.re * rhs.im + self.im * rhs.re,
306 )
307 }
308}
309impl MulAssign for Complex {
310 #[inline]
311 fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; }
312}
313
314impl Div for Complex {
319 type Output = Self;
320 fn div(self, rhs: Self) -> Self {
321 self.checked_div(rhs).unwrap_or(Self::new(f32::NAN, f32::NAN))
322 }
323}
324impl DivAssign for Complex {
325 #[inline]
326 fn div_assign(&mut self, rhs: Self) { *self = *self / rhs; }
327}
328
329impl Neg for Complex {
330 type Output = Self;
331 #[inline]
332 fn neg(self) -> Self { Self::new(-self.re, -self.im) }
333}
334
335impl Add<f32> for Complex {
337 type Output = Self;
338 #[inline]
339 fn add(self, rhs: f32) -> Self { Self::new(self.re + rhs, self.im) }
340}
341impl Sub<f32> for Complex {
342 type Output = Self;
343 #[inline]
344 fn sub(self, rhs: f32) -> Self { Self::new(self.re - rhs, self.im) }
345}
346impl Mul<f32> for Complex {
347 type Output = Self;
348 #[inline]
349 fn mul(self, rhs: f32) -> Self { Self::new(self.re * rhs, self.im * rhs) }
350}
351impl Div<f32> for Complex {
352 type Output = Self;
353 #[inline]
354 fn div(self, rhs: f32) -> Self { Self::new(self.re / rhs, self.im / rhs) }
355}
356
357impl fmt::Display for Complex {
360 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361 if self.im >= 0.0 || self.im.is_nan() {
362 write!(f, "{} + {}i", self.re, self.im)
363 } else {
364 write!(f, "{} - {}i", self.re, -self.im)
365 }
366 }
367}
368
369impl fmt::Debug for Complex {
370 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
371 write!(f, "Complex {{ re: {}, im: {} }}", self.re, self.im)
372 }
373}
374
375impl From<f32> for Complex {
378 #[inline]
380 fn from(x: f32) -> Self { Self::new(x, 0.0) }
381}
382
383impl From<(f32, f32)> for Complex {
384 #[inline]
385 fn from((re, im): (f32, f32)) -> Self { Self::new(re, im) }
386}
387
388impl From<Complex> for (f32, f32) {
389 #[inline]
390 fn from(z: Complex) -> Self { (z.re, z.im) }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for every approximate float comparison below.
    const EPS: f32 = 1e-4;

    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < EPS
    }

    fn complex_approx_eq(a: Complex, b: Complex) -> bool {
        approx_eq(a.re, b.re) && approx_eq(a.im, b.im)
    }

    #[test]
    fn constants() {
        assert_eq!(Complex::ZERO, Complex::new(0.0, 0.0));
        assert_eq!(Complex::ONE, Complex::new(1.0, 0.0));
        assert_eq!(Complex::I, Complex::new(0.0, 1.0));
    }

    #[test]
    fn addition() {
        assert_eq!(Complex::new(1.0, 2.0) + Complex::new(3.0, -1.0), Complex::new(4.0, 1.0));
    }

    #[test]
    fn subtraction() {
        assert_eq!(Complex::new(5.0, 3.0) - Complex::new(2.0, 1.0), Complex::new(3.0, 2.0));
    }

    #[test]
    fn multiplication() {
        // (1 + i)(1 - i) = 1 - i² = 2
        let r = Complex::new(1.0, 1.0) * Complex::new(1.0, -1.0);
        assert!(complex_approx_eq(r, Complex::new(2.0, 0.0)));
    }

    #[test]
    fn i_squared_is_minus_one() {
        assert!(complex_approx_eq(Complex::I * Complex::I, Complex::new(-1.0, 0.0)));
    }

    #[test]
    fn division() {
        // (4 + 2i) / (1 + i) = (4 + 2i)(1 - i) / 2 = 3 - i
        let r = Complex::new(4.0, 2.0) / Complex::new(1.0, 1.0);
        assert!(complex_approx_eq(r, Complex::new(3.0, -1.0)));
    }

    #[test]
    fn division_by_zero_returns_nan() {
        let r = Complex::ONE / Complex::ZERO;
        assert!(r.re().is_nan());
        assert!(r.im().is_nan());
    }

    #[test]
    fn checked_div_by_zero_returns_err() {
        assert_eq!(Complex::ONE.checked_div(Complex::ZERO), Err(ComplexError::DivisionByZero));
        assert_eq!(
            Complex::ONE.checked_div(Complex::new(-0.0, 0.0)),
            Err(ComplexError::DivisionByZero)
        );
    }

    #[test]
    fn negation() {
        assert_eq!(-Complex::new(1.0, -2.0), Complex::new(-1.0, 2.0));
    }

    #[test]
    fn scalar_ops() {
        let z = Complex::new(3.0, 4.0);
        assert_eq!(z * 2.0, Complex::new(6.0, 8.0));
        assert_eq!(z / 2.0, Complex::new(1.5, 2.0));
        assert_eq!(z + 1.0, Complex::new(4.0, 4.0));
        assert_eq!(z - 1.0, Complex::new(2.0, 4.0));
    }

    #[test]
    fn assign_ops() {
        let mut z = Complex::new(1.0, 2.0);
        z += Complex::new(0.5, 0.5);
        assert!(complex_approx_eq(z, Complex::new(1.5, 2.5)));
        z *= Complex::new(2.0, 0.0);
        assert!(complex_approx_eq(z, Complex::new(3.0, 5.0)));
        z -= Complex::new(1.0, 4.0);
        assert!(complex_approx_eq(z, Complex::new(2.0, 1.0)));
        // (2 + i) / (1 + i) = (3 - i) / 2
        z /= Complex::new(1.0, 1.0);
        assert!(complex_approx_eq(z, Complex::new(1.5, -0.5)));
    }

    #[test]
    fn norm_pythagorean() {
        assert!(approx_eq(Complex::new(3.0, 4.0).norm(), 5.0));
        assert!(approx_eq(Complex::new(5.0, 12.0).norm(), 13.0));
    }

    #[test]
    fn norm_sq() {
        assert!(approx_eq(Complex::new(3.0, 4.0).norm_sq(), 25.0));
    }

    #[test]
    fn conjugate() {
        let z = Complex::new(3.0, -4.0);
        let c = z.conj();
        assert_eq!(c.re(), 3.0);
        assert_eq!(c.im(), 4.0);
        // z · conj(z) = |z|², a positive real.
        let prod = z * c;
        assert!(complex_approx_eq(prod, Complex::new(25.0, 0.0)));
    }

    #[test]
    fn csqrt_real_positive() {
        let r = Complex::new(9.0, 0.0).csqrt().unwrap();
        assert!(complex_approx_eq(r, Complex::new(3.0, 0.0)));
    }

    #[test]
    fn csqrt_minus_one_gives_i() {
        let r = Complex::new(-1.0, 0.0).csqrt().unwrap();
        assert!(complex_approx_eq(r, Complex::I));
    }

    #[test]
    fn csqrt_general() {
        let z = Complex::new(3.0, 4.0);
        let root = z.csqrt().unwrap();
        // The principal root of 3 + 4i is 2 + i.
        assert!(complex_approx_eq(root, Complex::new(2.0, 1.0)));
        assert!(complex_approx_eq(root * root, z));
    }

    #[test]
    fn csqrt_lower_half_plane() {
        let root = Complex::new(3.0, -4.0).csqrt().unwrap();
        assert!(complex_approx_eq(root, Complex::new(2.0, -1.0)));
    }

    #[test]
    fn csqrt_zero_and_nan() {
        assert_eq!(Complex::ZERO.csqrt(), Ok(Complex::ZERO));
        assert_eq!(Complex::new(f32::NAN, 1.0).csqrt(), Err(ComplexError::Undefined));
    }

    #[test]
    fn powi_zero_exp() {
        assert!(complex_approx_eq(Complex::new(5.0, 3.0).powi(0).unwrap(), Complex::ONE));
    }

    #[test]
    fn powi_i4_is_one() {
        assert!(complex_approx_eq(Complex::I.powi(4).unwrap(), Complex::ONE));
    }

    #[test]
    fn powi_square() {
        // (1 + i)² = 2i
        let r = Complex::new(1.0, 1.0).powi(2).unwrap();
        assert!(complex_approx_eq(r, Complex::new(0.0, 2.0)));
    }

    #[test]
    fn powi_negative_exp() {
        let r = Complex::new(2.0, 0.0).powi(-1).unwrap();
        assert!(complex_approx_eq(r, Complex::new(0.5, 0.0)));
    }

    #[test]
    fn powi_negative_exp_of_zero_is_err() {
        assert_eq!(Complex::ZERO.powi(-2), Err(ComplexError::DivisionByZero));
    }

    #[test]
    fn inv_real() {
        let r = Complex::new(4.0, 0.0).inv().unwrap();
        assert!(complex_approx_eq(r, Complex::new(0.25, 0.0)));
    }

    #[test]
    fn inv_i_is_minus_i() {
        assert!(complex_approx_eq(Complex::I.inv().unwrap(), Complex::new(0.0, -1.0)));
    }

    #[test]
    fn inv_zero_returns_err() {
        assert_eq!(Complex::ZERO.inv(), Err(ComplexError::DivisionByZero));
    }

    #[test]
    fn from_f32() {
        assert_eq!(Complex::from(3.0f32), Complex::new(3.0, 0.0));
    }

    #[test]
    fn from_tuple() {
        let z: Complex = (2.0f32, -1.0f32).into();
        assert_eq!(z, Complex::new(2.0, -1.0));
    }

    #[test]
    fn into_tuple() {
        let (re, im): (f32, f32) = Complex::new(7.0, -3.0).into();
        assert_eq!(re, 7.0);
        assert_eq!(im, -3.0);
    }

    #[test]
    fn nan_propagation() {
        let z = Complex::new(f32::NAN, 0.0);
        assert!(z.is_nan());
        assert!((z + Complex::ONE).is_nan());
        assert!((z * Complex::ONE).is_nan());
    }

    #[test]
    fn finite_check() {
        assert!(Complex::new(1.0, 2.0).is_finite());
        assert!(!Complex::new(f32::INFINITY, 0.0).is_finite());
        assert!(Complex::new(0.0, f32::NEG_INFINITY).is_infinite());
        assert!(!Complex::new(1.0, 2.0).is_infinite());
    }
}