#![allow(unused_imports)]
use std::arch::asm;
#[cfg(feature = "std")]
use std::marker::PhantomData;
#[cfg(not(feature = "std"))]
use core::marker::PhantomData;
use num_complex::Complex;
#[inline]
pub fn mul<T>(a: Complex<T>, b: Complex<T>) -> Complex<T>
where
T: num_traits::Float,
{
Complex::new(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re)
}
#[inline]
pub fn mul_conj<T>(a: Complex<T>, b: Complex<T>) -> Complex<T>
where
T: num_traits::Float,
{
Complex::new(b.re * a.re + b.im * a.im, b.re * a.im - b.im * a.re)
}
/// Scope guard that disables denormal (subnormal) floating-point handling
/// while it is alive.
///
/// With the `std` feature on x86/x86_64, constructing the guard sets the FTZ
/// and DAZ bits (`0x8040`) of the MXCSR register and dropping it writes the
/// previously saved MXCSR value back. Without the `std` feature the guard is
/// a no-op.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[must_use = "the previous floating-point mode is restored as soon as the guard is dropped"]
#[derive(Debug)]
pub struct StopDenormals {
    /// MXCSR value captured by `new`, written back on drop.
    #[cfg(feature = "std")]
    control_status_register: u32,
    #[cfg(not(feature = "std"))]
    _phantom: PhantomData<()>,
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
impl StopDenormals {
    /// Saves the current MXCSR and enables flush-to-zero (FTZ, bit 15) and
    /// denormals-are-zero (DAZ, bit 6) until the returned guard is dropped.
    ///
    /// Uses `stmxcsr`/`ldmxcsr` inline assembly because the
    /// `_mm_getcsr`/`_mm_setcsr` intrinsics are deprecated (Rust 1.75+).
    #[cfg(feature = "std")]
    pub fn new() -> Self {
        // FTZ (bit 15) | DAZ (bit 6).
        const FTZ_DAZ: u32 = 0x8040;

        let mut csr: u32 = 0;
        // SAFETY: `stmxcsr` stores the 32-bit MXCSR into `csr`, a valid, aligned,
        // writable `u32` local; it touches no other memory and no flags.
        unsafe {
            asm!(
                "stmxcsr dword ptr [{0}]",
                in(reg) &mut csr as *mut u32,
                options(nostack, preserves_flags),
            );
        }

        let new_csr = csr | FTZ_DAZ;
        // SAFETY: `new_csr` differs from the live MXCSR only in the FTZ/DAZ mode
        // bits, which are valid to set. NOTE: the compiler does not model the FP
        // environment, so float ops are not formally ordered against this write.
        unsafe {
            asm!(
                "ldmxcsr dword ptr [{0}]",
                in(reg) &new_csr as *const u32,
                options(nostack),
            );
        }

        Self {
            control_status_register: csr,
        }
    }

    /// Without the `std` feature the guard does not modify MXCSR.
    #[cfg(not(feature = "std"))]
    pub fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
impl Drop for StopDenormals {
    /// Writes the MXCSR value saved by `new` back into the register.
    #[cfg(feature = "std")]
    fn drop(&mut self) {
        // SAFETY: `control_status_register` is a value previously read from
        // MXCSR by `stmxcsr`, so loading it back is a valid MXCSR state; the
        // pointer refers to a live, aligned `u32` field of `self`.
        unsafe {
            asm!(
                "ldmxcsr dword ptr [{0}]",
                in(reg) &self.control_status_register as *const u32,
                options(nostack),
            );
        }
    }

    /// Without the `std` feature `new` changes nothing, so there is nothing to restore.
    #[cfg(not(feature = "std"))]
    fn drop(&mut self) {}
}
/// Scope guard that disables denormal (subnormal) floating-point handling
/// while it is alive.
///
/// With the `std` feature on ARM/AArch64, constructing the guard sets the
/// flush-to-zero bit (`0x0100_0000`) of the floating-point control register
/// and dropping it writes the previously saved register value back. Without
/// the `std` feature the guard is a no-op.
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[must_use = "the previous floating-point mode is restored as soon as the guard is dropped"]
#[derive(Debug)]
pub struct StopDenormals {
    /// Floating-point control register value captured by `new`, written back on drop.
    #[cfg(feature = "std")]
    status: usize,
    #[cfg(not(feature = "std"))]
    _phantom: PhantomData<()>,
}
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
impl StopDenormals {
    /// Saves the floating-point control register and sets its flush-to-zero
    /// bit until the returned guard is dropped.
    ///
    /// AArch64 exposes this register as `FPCR` (`mrs`/`msr`); 32-bit ARM
    /// exposes it as `FPSCR` (`vmrs`/`vmsr`). The previous code used the
    /// AArch64 form on both, which does not assemble for `target_arch = "arm"`.
    #[cfg(feature = "std")]
    pub fn new() -> Self {
        // FZ (flush-to-zero) is bit 24 in both AArch64 FPCR and AArch32 FPSCR.
        const FLUSH_TO_ZERO: usize = 0x0100_0000;

        let status: usize;
        // SAFETY: reading the FP control register only copies it into a
        // general-purpose register; no memory is accessed.
        unsafe {
            #[cfg(target_arch = "aarch64")]
            asm!("mrs {0}, fpcr", out(reg) status);
            #[cfg(target_arch = "arm")]
            asm!("vmrs {0}, fpscr", out(reg) status);
        }

        let new_status = status | FLUSH_TO_ZERO;
        // SAFETY: `new_status` differs from the live register only in the FZ
        // bit. NOTE: the compiler does not model the FP environment, so float
        // ops are not formally ordered against this write.
        unsafe {
            #[cfg(target_arch = "aarch64")]
            asm!("msr fpcr, {0}", in(reg) new_status);
            #[cfg(target_arch = "arm")]
            asm!("vmsr fpscr, {0}", in(reg) new_status);
        }

        Self { status }
    }

    /// Without the `std` feature the guard does not touch the FP control register.
    #[cfg(not(feature = "std"))]
    pub fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
impl Drop for StopDenormals {
    /// Writes the FP control register value saved by `new` back, using the
    /// register name that matches the architecture (`FPCR` vs `FPSCR`).
    #[cfg(feature = "std")]
    fn drop(&mut self) {
        // SAFETY: `self.status` is the exact value read from this register in
        // `new`, so writing it back restores a valid FP control state.
        unsafe {
            #[cfg(target_arch = "aarch64")]
            asm!("msr fpcr, {0}", in(reg) self.status);
            #[cfg(target_arch = "arm")]
            asm!("vmsr fpscr, {0}", in(reg) self.status);
        }
    }

    /// Without the `std` feature `new` changes nothing, so there is nothing to restore.
    #[cfg(not(feature = "std"))]
    fn drop(&mut self) {}
}
/// Scope guard for disabling denormal floating-point handling.
///
/// On architectures other than x86/x86_64/arm/aarch64 this module provides no
/// denormal control, so the guard is a no-op placeholder that keeps the API
/// available everywhere.
#[cfg(not(any(
    target_arch = "x86",
    target_arch = "x86_64",
    target_arch = "arm",
    target_arch = "aarch64"
)))]
#[must_use = "the previous floating-point mode is restored as soon as the guard is dropped"]
#[derive(Debug)]
pub struct StopDenormals {
    _phantom: PhantomData<()>,
}
#[cfg(not(any(
    target_arch = "x86",
    target_arch = "x86_64",
    target_arch = "arm",
    target_arch = "aarch64"
)))]
impl StopDenormals {
    /// Builds the placeholder guard for this architecture.
    ///
    /// No floating-point state is read or modified here.
    pub fn new() -> Self {
        StopDenormals { _phantom: PhantomData }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `mul` must agree with `num_complex`'s own `Mul` implementation.
    #[test]
    fn test_complex_mul() {
        let a = Complex::new(1.0, 2.0);
        let b = Complex::new(3.0, 4.0);
        let result = mul(a, b);
        let expected = a * b;
        assert_eq!(result, expected);
    }

    /// `mul_conj(a, b)` must equal `a * conj(b)`.
    #[test]
    fn test_complex_mul_conj() {
        let a = Complex::new(1.0, 2.0);
        let b = Complex::new(3.0, 4.0);
        let result = mul_conj(a, b);
        let expected = a * b.conj();
        assert_eq!(result, expected);
    }

    /// Nested `StopDenormals` guards must unwind in LIFO order and leave the
    /// original floating-point mode in place once both are dropped.
    #[test]
    fn test_stop_denormals_restores_mode() {
        {
            let _outer = StopDenormals::new();
            let _inner = StopDenormals::new();
        }
        // f32::MIN_POSITIVE / 2 is an exact subnormal; if flush-to-zero were
        // still enabled it would become 0.0. `black_box` keeps it a runtime op.
        let halved =
            core::hint::black_box(f32::MIN_POSITIVE) / core::hint::black_box(2.0_f32);
        assert!(
            halved > 0.0,
            "subnormal result was flushed to zero after the guards were dropped"
        );
    }
}