#![allow(unused_imports)]
use num_traits::{Float, FromPrimitive, Zero, One};
use num_complex::Complex;
#[cfg(feature = "alloc")]
use alloc::{vec::Vec, boxed::Box};
pub type ConstRealPointer<T> = *const T;
pub type RealPointer<T> = *mut T;
pub type ConstComplexPointer<T> = *const Complex<T>;
pub type ComplexPointer<T> = *mut Complex<T>;
#[derive(Copy, Clone, Debug)]
pub struct ConstSplitPointer<T> {
pub real: ConstRealPointer<T>,
pub imag: ConstRealPointer<T>,
}
impl<T> ConstSplitPointer<T> {
pub fn new(real: ConstRealPointer<T>, imag: ConstRealPointer<T>) -> Self {
Self { real, imag }
}
pub unsafe fn get(&self, i: usize) -> Complex<T>
where
T: Copy,
{
Complex::new(*self.real.add(i), *self.imag.add(i))
}
}
#[derive(Copy, Clone, Debug)]
pub struct SplitPointer<T> {
pub real: RealPointer<T>,
pub imag: RealPointer<T>,
}
impl<T> SplitPointer<T> {
pub fn new(real: RealPointer<T>, imag: RealPointer<T>) -> Self {
Self { real, imag }
}
pub fn as_const(&self) -> ConstSplitPointer<T> {
ConstSplitPointer::new(self.real, self.imag)
}
pub unsafe fn get(&self, i: usize) -> Complex<T>
where
T: Copy,
{
Complex::new(*self.real.add(i), *self.imag.add(i))
}
pub unsafe fn get_mut(&mut self, i: usize) -> SplitValue<T> {
SplitValue::new(self.real.add(i), self.imag.add(i))
}
}
pub struct SplitValue<T> {
real_ptr: *mut T,
imag_ptr: *mut T,
}
impl<T> SplitValue<T> {
unsafe fn new(real_ptr: *mut T, imag_ptr: *mut T) -> Self {
Self { real_ptr, imag_ptr }
}
pub fn real(&self) -> T
where
T: Copy,
{
unsafe { *self.real_ptr }
}
pub fn set_real(&mut self, value: T)
where
T: Copy,
{
unsafe { *self.real_ptr = value }
}
pub fn imag(&self) -> T
where
T: Copy,
{
unsafe { *self.imag_ptr }
}
pub fn set_imag(&mut self, value: T)
where
T: Copy,
{
unsafe { *self.imag_ptr = value }
}
}
impl<T> From<SplitValue<T>> for Complex<T>
where
T: Copy,
{
fn from(value: SplitValue<T>) -> Self {
Complex::new(value.real(), value.imag())
}
}
pub trait ExpressionBase {
type Output;
fn get(&self, i: usize) -> Self::Output;
}
pub struct ConstantExpr<T> {
pub value: T,
}
impl<T: Copy> ExpressionBase for ConstantExpr<T> {
type Output = T;
fn get(&self, _i: usize) -> T {
self.value
}
}
pub struct ReadableReal<T> {
pub pointer: ConstRealPointer<T>,
}
impl<T: Copy> ExpressionBase for ReadableReal<T> {
type Output = T;
fn get(&self, i: usize) -> T {
unsafe { *self.pointer.add(i) }
}
}
pub struct ReadableComplex<T> {
pub pointer: ConstComplexPointer<T>,
}
impl<T: Copy> ExpressionBase for ReadableComplex<T> {
type Output = Complex<T>;
fn get(&self, i: usize) -> Complex<T> {
unsafe { *self.pointer.add(i) }
}
}
pub struct ReadableSplit<T> {
pub pointer: ConstSplitPointer<T>,
}
impl<T: Copy> ExpressionBase for ReadableSplit<T> {
type Output = Complex<T>;
fn get(&self, i: usize) -> Complex<T> {
unsafe { self.pointer.get(i) }
}
}
pub struct Expression<E: ExpressionBase> {
expr: E,
}
impl<E: ExpressionBase> Expression<E> {
pub fn new(expr: E) -> Self {
Self { expr }
}
pub fn get(&self, i: usize) -> E::Output {
self.expr.get(i)
}
}
pub struct WritableExpression<E: ExpressionBase> {
expr: E,
pointer: *mut E::Output,
}
impl<E: ExpressionBase> WritableExpression<E> {
pub fn new(expr: E, pointer: *mut E::Output, _size: usize) -> Self {
Self { expr, pointer }
}
pub fn get(&self, i: usize) -> E::Output {
self.expr.get(i)
}
pub unsafe fn get_mut(&mut self, i: usize) -> *mut E::Output {
self.pointer.add(i)
}
}
pub struct Linear {
#[cfg(feature = "alloc")]
cached_results: Option<CachedResults>,
}
impl Linear {
pub fn new() -> Self {
Self {
#[cfg(feature = "alloc")]
cached_results: None,
}
}
pub fn wrap_real<T: Copy>(&self, pointer: ConstRealPointer<T>) -> Expression<ReadableReal<T>> {
Expression::new(ReadableReal { pointer })
}
pub fn wrap_complex<T: Copy>(&self, pointer: ConstComplexPointer<T>) -> Expression<ReadableComplex<T>> {
Expression::new(ReadableComplex { pointer })
}
pub fn wrap_split<T: Copy>(&self, pointer: ConstSplitPointer<T>) -> Expression<ReadableSplit<T>> {
Expression::new(ReadableSplit { pointer })
}
pub fn wrap_real_mut<T: Copy>(&self, pointer: RealPointer<T>, size: usize) -> WritableExpression<ReadableReal<T>> {
WritableExpression::new(ReadableReal { pointer }, pointer as *mut T, size)
}
pub fn wrap_complex_mut<T: Copy>(&self, pointer: ComplexPointer<T>, size: usize) -> WritableExpression<ReadableComplex<T>> {
WritableExpression::new(ReadableComplex { pointer }, pointer as *mut Complex<T>, size)
}
pub fn wrap_split_mut<T: Copy>(&self, pointer: SplitPointer<T>, size: usize) -> WritableExpression<ReadableSplit<T>> {
WritableExpression::new(ReadableSplit { pointer: pointer.as_const() }, pointer.real as *mut Complex<T>, size)
}
pub fn fill_real<T, E>(&self, pointer: RealPointer<T>, expr: &Expression<E>, size: usize)
where
E: ExpressionBase<Output = T>,
T: Copy,
{
for i in 0..size {
unsafe {
*pointer.add(i) = expr.get(i);
}
}
}
pub fn fill_complex<T, E>(&self, pointer: ComplexPointer<T>, expr: &Expression<E>, size: usize)
where
E: ExpressionBase<Output = Complex<T>>,
T: Copy,
{
for i in 0..size {
unsafe {
*pointer.add(i) = expr.get(i);
}
}
}
pub fn fill_split<T, E>(&self, pointer: SplitPointer<T>, expr: &Expression<E>, size: usize)
where
E: ExpressionBase<Output = Complex<T>>,
T: Copy,
{
for i in 0..size {
let value = expr.get(i);
unsafe {
*pointer.real.add(i) = value.re;
*pointer.imag.add(i) = value.im;
}
}
}
pub fn reserve<T>(&mut self, _size: usize) {
}
}
#[cfg(feature = "alloc")]
pub struct Temporary<T> {
buffer: Vec<T>,
start: usize,
end: usize,
}
#[cfg(feature = "alloc")]
impl<T> Temporary<T> {
pub fn new() -> Self {
Self {
buffer: Vec::new(),
start: 0,
end: 0,
}
}
pub fn reserve(&mut self, size: usize) {
self.buffer.resize(size, unsafe { std::mem::zeroed() });
self.start = 0;
self.end = size;
}
pub fn clear(&mut self) {
self.start = 0;
}
pub fn get_chunk(&mut self, size: usize) -> &mut [T] {
if self.start + size > self.end {
self.buffer.resize(self.end + size, unsafe { std::mem::zeroed() });
self.end += size;
}
let chunk = &mut self.buffer[self.start..self.start + size];
self.start += size;
chunk
}
}
#[cfg(feature = "alloc")]
pub struct CachedResults {
floats: Temporary<f32>,
doubles: Temporary<f64>,
}
#[cfg(feature = "alloc")]
impl CachedResults {
pub fn new() -> Self {
Self {
floats: Temporary::new(),
doubles: Temporary::new(),
}
}
pub fn reserve_floats(&mut self, size: usize) {
self.floats.reserve(size);
}
pub fn reserve_doubles(&mut self, size: usize) {
self.doubles.reserve(size);
}
}
pub trait MathOps<T> {
fn abs(&self) -> Self;
fn norm(&self) -> Self;
fn exp(&self) -> Self;
fn log(&self) -> Self;
fn sqrt(&self) -> Self;
fn conj(&self) -> Self;
fn real(&self) -> Self;
fn imag(&self) -> Self;
}
impl<T: Float + FromPrimitive> MathOps<T> for Expression<ConstantExpr<T>> {
fn abs(&self) -> Self {
Expression::new(ConstantExpr { value: self.expr.value.abs() })
}
fn norm(&self) -> Self {
Expression::new(ConstantExpr { value: self.expr.value * self.expr.value })
}
fn exp(&self) -> Self {
Expression::new(ConstantExpr { value: self.expr.value.exp() })
}
fn log(&self) -> Self {
Expression::new(ConstantExpr { value: self.expr.value.ln() })
}
fn sqrt(&self) -> Self {
Expression::new(ConstantExpr { value: self.expr.value.sqrt() })
}
fn conj(&self) -> Self {
Expression::new(ConstantExpr { value: self.expr.value })
}
fn real(&self) -> Self {
Expression::new(ConstantExpr { value: self.expr.value })
}
fn imag(&self) -> Self {
Expression::new(ConstantExpr { value: T::zero() })
}
}
pub trait BinaryOps<T> {
fn add(&self, other: &Self) -> Self;
fn sub(&self, other: &Self) -> Self;
fn mul(&self, other: &Self) -> Self;
fn div(&self, other: &Self) -> Self;
}
impl<T: Float + FromPrimitive> BinaryOps<T> for Expression<ConstantExpr<T>> {
fn add(&self, other: &Self) -> Self {
Expression::new(ConstantExpr { value: self.expr.value + other.expr.value })
}
fn sub(&self, other: &Self) -> Self {
Expression::new(ConstantExpr { value: self.expr.value - other.expr.value })
}
fn mul(&self, other: &Self) -> Self {
Expression::new(ConstantExpr { value: self.expr.value * other.expr.value })
}
fn div(&self, other: &Self) -> Self {
Expression::new(ConstantExpr { value: self.expr.value / other.expr.value })
}
}
pub fn cheap_energy_crossfade<T: Float + FromPrimitive>(x: T) -> (T, T) {
let to_coeff = x;
let from_coeff = T::one() - x;
(to_coeff, from_coeff)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_constant_expression() {
let expr = Expression::new(ConstantExpr { value: 2.5f32 });
assert_eq!(expr.get(0), 2.5);
assert_eq!(expr.get(100), 2.5); }
#[test]
fn test_math_ops() {
let expr = Expression::new(ConstantExpr { value: 4.0f32 });
let abs_expr = expr.abs();
assert_eq!(abs_expr.get(0), 4.0);
let sqrt_expr = expr.sqrt();
assert_eq!(sqrt_expr.get(0), 2.0);
let norm_expr = expr.norm();
assert_eq!(norm_expr.get(0), 16.0);
}
#[test]
fn test_binary_ops() {
let expr1 = Expression::new(ConstantExpr { value: 3.0f32 });
let expr2 = Expression::new(ConstantExpr { value: 2.0f32 });
let add_expr = expr1.add(&expr2);
assert_eq!(add_expr.get(0), 5.0);
let mul_expr = expr1.mul(&expr2);
assert_eq!(mul_expr.get(0), 6.0);
}
#[test]
fn test_cheap_energy_crossfade() {
let (to_coeff, from_coeff) = cheap_energy_crossfade(0.5f32);
assert!((to_coeff - 0.5).abs() < 1e-6);
assert!((from_coeff - 0.5).abs() < 1e-6);
assert!((to_coeff + from_coeff - 1.0).abs() < 1e-6);
}
#[test]
fn test_split_pointer() {
let mut real_data = [1.0f32, 2.0, 3.0];
let mut imag_data = [4.0f32, 5.0, 6.0];
let split_ptr = SplitPointer::new(real_data.as_mut_ptr(), imag_data.as_mut_ptr());
unsafe {
let complex_val = split_ptr.get(1);
assert_eq!(complex_val.re, 2.0);
assert_eq!(complex_val.im, 5.0);
}
}
#[test]
fn test_linear_fill() {
let linear = Linear::new();
let expr = Expression::new(ConstantExpr { value: 2.5f32 });
let mut data = [0.0f32; 4];
linear.fill_real(data.as_mut_ptr(), &expr, 4);
for &value in &data {
assert_eq!(value, 2.5);
}
}
}