cute_dsp/
linear.rs

1//! Linear algebra and expression template system
2//!
3//! This module provides a flexible expression template system for efficient vector operations,
4//! supporting both real and complex numbers with automatic optimization and SIMD acceleration
5//! where available.
6
7#![allow(unused_imports)]
8
9use num_traits::{Float, FromPrimitive, Zero, One};
10use num_complex::Complex;
11
12#[cfg(feature = "alloc")]
13use alloc::{vec::Vec, boxed::Box};
14
/// Raw read-only pointer alias used for real-valued data.
pub type ConstRealPointer<T> = *const T;
/// Raw writable pointer alias used for real-valued data.
pub type RealPointer<T> = *mut T;
18
19/// Pointer types for complex data
20pub type ConstComplexPointer<T> = *const Complex<T>;
21pub type ComplexPointer<T> = *mut Complex<T>;
22
23/// Split pointer for separate real/imaginary parts
24#[derive(Copy, Clone, Debug)]
25pub struct ConstSplitPointer<T> {
26    pub real: ConstRealPointer<T>,
27    pub imag: ConstRealPointer<T>,
28}
29
30impl<T> ConstSplitPointer<T> {
31    pub fn new(real: ConstRealPointer<T>, imag: ConstRealPointer<T>) -> Self {
32        Self { real, imag }
33    }
34
35    /// Array-like access for convenience
36    pub unsafe fn get(&self, i: usize) -> Complex<T>
37    where
38        T: Copy,
39    {
40        Complex::new(*self.real.add(i), *self.imag.add(i))
41    }
42}
43
44/// Mutable split pointer for separate real/imaginary parts
45#[derive(Copy, Clone, Debug)]
46pub struct SplitPointer<T> {
47    pub real: RealPointer<T>,
48    pub imag: RealPointer<T>,
49}
50
51impl<T> SplitPointer<T> {
52    pub fn new(real: RealPointer<T>, imag: RealPointer<T>) -> Self {
53        Self { real, imag }
54    }
55
56    /// Convert to const split pointer
57    pub fn as_const(&self) -> ConstSplitPointer<T> {
58        ConstSplitPointer::new(self.real, self.imag)
59    }
60
61    /// Array-like access for convenience
62    pub unsafe fn get(&self, i: usize) -> Complex<T>
63    where
64        T: Copy,
65    {
66        Complex::new(*self.real.add(i), *self.imag.add(i))
67    }
68
69    /// Mutable array-like access
70    pub unsafe fn get_mut(&mut self, i: usize) -> SplitValue<T> {
71        SplitValue::new(self.real.add(i), self.imag.add(i))
72    }
73}
74
/// Handle to a single complex element whose real and imaginary parts live
/// at two independent addresses.
///
/// The accessors are safe to call but dereference the stored raw pointers;
/// their validity is the responsibility of whoever invoked the unsafe
/// constructor.
pub struct SplitValue<T> {
    real_ptr: *mut T,
    imag_ptr: *mut T,
}

impl<T> SplitValue<T> {
    /// Stores the two element addresses without touching them.
    ///
    /// # Safety
    ///
    /// Both pointers must be valid for reads and writes for as long as the
    /// returned value is used.
    unsafe fn new(real_ptr: *mut T, imag_ptr: *mut T) -> Self {
        SplitValue { real_ptr, imag_ptr }
    }

    /// Returns a copy of the real component.
    pub fn real(&self) -> T
    where
        T: Copy,
    {
        // SAFETY: validity of `real_ptr` was promised to `new`.
        unsafe { self.real_ptr.read() }
    }

    /// Overwrites the real component with `value`.
    pub fn set_real(&mut self, value: T)
    where
        T: Copy,
    {
        // SAFETY: validity of `real_ptr` was promised to `new`; `T: Copy`
        // means there is no destructor to skip on the overwritten value.
        unsafe { self.real_ptr.write(value) }
    }

    /// Returns a copy of the imaginary component.
    pub fn imag(&self) -> T
    where
        T: Copy,
    {
        // SAFETY: validity of `imag_ptr` was promised to `new`.
        unsafe { self.imag_ptr.read() }
    }

    /// Overwrites the imaginary component with `value`.
    pub fn set_imag(&mut self, value: T)
    where
        T: Copy,
    {
        // SAFETY: validity of `imag_ptr` was promised to `new`; `T: Copy`
        // means there is no destructor to skip on the overwritten value.
        unsafe { self.imag_ptr.write(value) }
    }
}
114
115impl<T> From<SplitValue<T>> for Complex<T>
116where
117    T: Copy,
118{
119    fn from(value: SplitValue<T>) -> Self {
120        Complex::new(value.real(), value.imag())
121    }
122}
123
/// Common interface of every expression node: something that can produce a
/// value for a given element index.
pub trait ExpressionBase {
    /// Type of the value produced per element.
    type Output;

    /// Produces the value for element index `i`.
    fn get(&self, i: usize) -> Self::Output;
}
129
130/// Constant expression
131pub struct ConstantExpr<T> {
132    pub value: T,
133}
134
135impl<T: Copy> ExpressionBase for ConstantExpr<T> {
136    type Output = T;
137    fn get(&self, _i: usize) -> T {
138        self.value
139    }
140}
141
142/// Readable real expression
143pub struct ReadableReal<T> {
144    pub pointer: ConstRealPointer<T>,
145}
146
147impl<T: Copy> ExpressionBase for ReadableReal<T> {
148    type Output = T;
149    fn get(&self, i: usize) -> T {
150        unsafe { *self.pointer.add(i) }
151    }
152}
153
154/// Readable complex expression
155pub struct ReadableComplex<T> {
156    pub pointer: ConstComplexPointer<T>,
157}
158
159impl<T: Copy> ExpressionBase for ReadableComplex<T> {
160    type Output = Complex<T>;
161    fn get(&self, i: usize) -> Complex<T> {
162        unsafe { *self.pointer.add(i) }
163    }
164}
165
166/// Readable split expression
167pub struct ReadableSplit<T> {
168    pub pointer: ConstSplitPointer<T>,
169}
170
171impl<T: Copy> ExpressionBase for ReadableSplit<T> {
172    type Output = Complex<T>;
173    fn get(&self, i: usize) -> Complex<T> {
174        unsafe { self.pointer.get(i) }
175    }
176}
177
178/// Expression wrapper
179pub struct Expression<E: ExpressionBase> {
180    expr: E,
181}
182
183impl<E: ExpressionBase> Expression<E> {
184    pub fn new(expr: E) -> Self {
185        Self { expr }
186    }
187
188    pub fn get(&self, i: usize) -> E::Output {
189        self.expr.get(i)
190    }
191}
192
193/// Writable expression wrapper
194pub struct WritableExpression<E: ExpressionBase> {
195    expr: E,
196    pointer: *mut E::Output,
197}
198
199impl<E: ExpressionBase> WritableExpression<E> {
200    pub fn new(expr: E, pointer: *mut E::Output, _size: usize) -> Self {
201        Self { expr, pointer }
202    }
203
204    pub fn get(&self, i: usize) -> E::Output {
205        self.expr.get(i)
206    }
207
208    pub unsafe fn get_mut(&mut self, i: usize) -> *mut E::Output {
209        self.pointer.add(i)
210    }
211}
212
213/// Linear algebra implementation
214pub struct Linear {
215    #[cfg(feature = "alloc")]
216    cached_results: Option<CachedResults>,
217}
218
219impl Linear {
220    pub fn new() -> Self {
221        Self {
222            #[cfg(feature = "alloc")]
223            cached_results: None,
224        }
225    }
226
227    /// Wrap a real pointer as an expression
228    pub fn wrap_real<T: Copy>(&self, pointer: ConstRealPointer<T>) -> Expression<ReadableReal<T>> {
229        Expression::new(ReadableReal { pointer })
230    }
231
232    /// Wrap a complex pointer as an expression
233    pub fn wrap_complex<T: Copy>(&self, pointer: ConstComplexPointer<T>) -> Expression<ReadableComplex<T>> {
234        Expression::new(ReadableComplex { pointer })
235    }
236
237    /// Wrap a split pointer as an expression
238    pub fn wrap_split<T: Copy>(&self, pointer: ConstSplitPointer<T>) -> Expression<ReadableSplit<T>> {
239        Expression::new(ReadableSplit { pointer })
240    }
241
242    /// Wrap a mutable real pointer as a writable expression
243    pub fn wrap_real_mut<T: Copy>(&self, pointer: RealPointer<T>, size: usize) -> WritableExpression<ReadableReal<T>> {
244        WritableExpression::new(ReadableReal { pointer }, pointer as *mut T, size)
245    }
246
247    /// Wrap a mutable complex pointer as a writable expression
248    pub fn wrap_complex_mut<T: Copy>(&self, pointer: ComplexPointer<T>, size: usize) -> WritableExpression<ReadableComplex<T>> {
249        WritableExpression::new(ReadableComplex { pointer }, pointer as *mut Complex<T>, size)
250    }
251
252    /// Wrap a mutable split pointer as a writable expression
253    pub fn wrap_split_mut<T: Copy>(&self, pointer: SplitPointer<T>, size: usize) -> WritableExpression<ReadableSplit<T>> {
254        WritableExpression::new(ReadableSplit { pointer: pointer.as_const() }, pointer.real as *mut Complex<T>, size)
255    }
256
257    /// Fill a real array with values from an expression
258    pub fn fill_real<T, E>(&self, pointer: RealPointer<T>, expr: &Expression<E>, size: usize)
259    where
260        E: ExpressionBase<Output = T>,
261        T: Copy,
262    {
263        for i in 0..size {
264            unsafe {
265                *pointer.add(i) = expr.get(i);
266            }
267        }
268    }
269
270    /// Fill a complex array with values from an expression
271    pub fn fill_complex<T, E>(&self, pointer: ComplexPointer<T>, expr: &Expression<E>, size: usize)
272    where
273        E: ExpressionBase<Output = Complex<T>>,
274        T: Copy,
275    {
276        for i in 0..size {
277            unsafe {
278                *pointer.add(i) = expr.get(i);
279            }
280        }
281    }
282
283    /// Fill a split array with values from an expression
284    pub fn fill_split<T, E>(&self, pointer: SplitPointer<T>, expr: &Expression<E>, size: usize)
285    where
286        E: ExpressionBase<Output = Complex<T>>,
287        T: Copy,
288    {
289        for i in 0..size {
290            let value = expr.get(i);
291            unsafe {
292                *pointer.real.add(i) = value.re;
293                *pointer.imag.add(i) = value.im;
294            }
295        }
296    }
297
298    /// Reserve temporary storage
299    pub fn reserve<T>(&mut self, _size: usize) {
300        // Implementation would depend on cached results
301    }
302}
303
/// Temporary storage for intermediate calculations
///
/// A growable scratch buffer that hands out consecutive chunks. `start` is
/// the offset of the next chunk and `end` tracks the buffer length.
#[cfg(feature = "alloc")]
pub struct Temporary<T> {
    buffer: Vec<T>,
    start: usize,
    end: usize,
}

#[cfg(feature = "alloc")]
impl<T> Default for Temporary<T> {
    /// Equivalent to [`Temporary::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "alloc")]
impl<T> Temporary<T> {
    /// Creates an empty store; nothing is allocated until space is requested.
    pub fn new() -> Self {
        Self {
            buffer: Vec::new(),
            start: 0,
            end: 0,
        }
    }

    /// Resizes the buffer to exactly `size` elements and rewinds the chunk
    /// cursor to the beginning.
    ///
    /// Newly added elements are `T::default()`. Earlier code produced them
    /// with `mem::zeroed()`, which is not valid for arbitrary `T` and needed
    /// a `Clone` bound the impl did not have. For `f32`/`f64` the default is
    /// the same all-zero value.
    pub fn reserve(&mut self, size: usize)
    where
        T: Default,
    {
        self.buffer.resize_with(size, T::default);
        self.start = 0;
        self.end = size;
    }

    /// Rewinds the chunk cursor so existing storage is handed out again.
    ///
    /// The buffer contents are left untouched.
    pub fn clear(&mut self) {
        self.start = 0;
    }

    /// Returns the next `size` elements of scratch space, growing the buffer
    /// when the remaining space is too small.
    ///
    /// Elements reused after [`Temporary::clear`] keep whatever was last
    /// written to them; only newly grown elements are `T::default()`.
    pub fn get_chunk(&mut self, size: usize) -> &mut [T]
    where
        T: Default,
    {
        let from = self.start;
        let to = from + size;
        if to > self.end {
            // Grow only as far as this chunk needs.
            self.buffer.resize_with(to, T::default);
            self.end = to;
        }
        self.start = to;
        &mut self.buffer[from..to]
    }
}
343
/// Pair of scratch stores, one per supported sample type.
#[cfg(feature = "alloc")]
pub struct CachedResults {
    /// Scratch space for `f32` data.
    floats: Temporary<f32>,
    /// Scratch space for `f64` data.
    doubles: Temporary<f64>,
}

#[cfg(feature = "alloc")]
impl CachedResults {
    /// Creates both scratch stores via `Temporary::new`.
    pub fn new() -> Self {
        let floats = Temporary::new();
        let doubles = Temporary::new();
        CachedResults { floats, doubles }
    }

    /// Forwards `size` to the `f32` store's `reserve`.
    pub fn reserve_floats(&mut self, size: usize) {
        let store = &mut self.floats;
        store.reserve(size);
    }

    /// Forwards `size` to the `f64` store's `reserve`.
    pub fn reserve_doubles(&mut self, size: usize) {
        let store = &mut self.doubles;
        store.reserve(size);
    }
}
368
/// Element-wise mathematical operations on an expression.
///
/// Every method borrows the receiver and returns a new value of the same
/// type.
pub trait MathOps<T> {
    /// Absolute value.
    fn abs(&self) -> Self;
    /// Norm of the value.
    fn norm(&self) -> Self;
    /// Exponential.
    fn exp(&self) -> Self;
    /// Logarithm.
    fn log(&self) -> Self;
    /// Square root.
    fn sqrt(&self) -> Self;
    /// Complex conjugate.
    fn conj(&self) -> Self;
    /// Real part.
    fn real(&self) -> Self;
    /// Imaginary part.
    fn imag(&self) -> Self;
}
380
381impl<T: Float + FromPrimitive> MathOps<T> for Expression<ConstantExpr<T>> {
382    fn abs(&self) -> Self {
383        Expression::new(ConstantExpr { value: self.expr.value.abs() })
384    }
385
386    fn norm(&self) -> Self {
387        Expression::new(ConstantExpr { value: self.expr.value * self.expr.value })
388    }
389
390    fn exp(&self) -> Self {
391        Expression::new(ConstantExpr { value: self.expr.value.exp() })
392    }
393
394    fn log(&self) -> Self {
395        Expression::new(ConstantExpr { value: self.expr.value.ln() })
396    }
397
398    fn sqrt(&self) -> Self {
399        Expression::new(ConstantExpr { value: self.expr.value.sqrt() })
400    }
401
402    fn conj(&self) -> Self {
403        Expression::new(ConstantExpr { value: self.expr.value })
404    }
405
406    fn real(&self) -> Self {
407        Expression::new(ConstantExpr { value: self.expr.value })
408    }
409
410    fn imag(&self) -> Self {
411        Expression::new(ConstantExpr { value: T::zero() })
412    }
413}
414
/// Arithmetic between two expressions of the same type.
///
/// Every method borrows both operands and returns a new value of the same
/// type.
pub trait BinaryOps<T> {
    /// Sum of `self` and `other`.
    fn add(&self, other: &Self) -> Self;
    /// `self` minus `other`.
    fn sub(&self, other: &Self) -> Self;
    /// Product of `self` and `other`.
    fn mul(&self, other: &Self) -> Self;
    /// `self` divided by `other`.
    fn div(&self, other: &Self) -> Self;
}
422
423impl<T: Float + FromPrimitive> BinaryOps<T> for Expression<ConstantExpr<T>> {
424    fn add(&self, other: &Self) -> Self {
425        Expression::new(ConstantExpr { value: self.expr.value + other.expr.value })
426    }
427
428    fn sub(&self, other: &Self) -> Self {
429        Expression::new(ConstantExpr { value: self.expr.value - other.expr.value })
430    }
431
432    fn mul(&self, other: &Self) -> Self {
433        Expression::new(ConstantExpr { value: self.expr.value * other.expr.value })
434    }
435
436    fn div(&self, other: &Self) -> Self {
437        Expression::new(ConstantExpr { value: self.expr.value / other.expr.value })
438    }
439}
440
441/// Cheap energy-preserving crossfade
442pub fn cheap_energy_crossfade<T: Float + FromPrimitive>(x: T) -> (T, T) {
443    let to_coeff = x;
444    let from_coeff = T::one() - x;
445    (to_coeff, from_coeff)
446}
447
448#[cfg(test)]
449mod tests {
450    use super::*;
451
452    #[test]
453    fn test_constant_expression() {
454        let expr = Expression::new(ConstantExpr { value: 2.5f32 });
455        assert_eq!(expr.get(0), 2.5);
456        assert_eq!(expr.get(100), 2.5); // Same value for all indices
457    }
458
459    #[test]
460    fn test_math_ops() {
461        let expr = Expression::new(ConstantExpr { value: 4.0f32 });
462        
463        let abs_expr = expr.abs();
464        assert_eq!(abs_expr.get(0), 4.0);
465        
466        let sqrt_expr = expr.sqrt();
467        assert_eq!(sqrt_expr.get(0), 2.0);
468        
469        let norm_expr = expr.norm();
470        assert_eq!(norm_expr.get(0), 16.0);
471    }
472
473    #[test]
474    fn test_binary_ops() {
475        let expr1 = Expression::new(ConstantExpr { value: 3.0f32 });
476        let expr2 = Expression::new(ConstantExpr { value: 2.0f32 });
477        
478        let add_expr = expr1.add(&expr2);
479        assert_eq!(add_expr.get(0), 5.0);
480        
481        let mul_expr = expr1.mul(&expr2);
482        assert_eq!(mul_expr.get(0), 6.0);
483    }
484
485    #[test]
486    fn test_cheap_energy_crossfade() {
487        let (to_coeff, from_coeff) = cheap_energy_crossfade(0.5f32);
488        assert!((to_coeff - 0.5).abs() < 1e-6);
489        assert!((from_coeff - 0.5).abs() < 1e-6);
490        assert!((to_coeff + from_coeff - 1.0).abs() < 1e-6);
491    }
492
493    #[test]
494    fn test_split_pointer() {
495        let mut real_data = [1.0f32, 2.0, 3.0];
496        let mut imag_data = [4.0f32, 5.0, 6.0];
497        
498        let split_ptr = SplitPointer::new(real_data.as_mut_ptr(), imag_data.as_mut_ptr());
499        
500        unsafe {
501            let complex_val = split_ptr.get(1);
502            assert_eq!(complex_val.re, 2.0);
503            assert_eq!(complex_val.im, 5.0);
504        }
505    }
506
507    #[test]
508    fn test_linear_fill() {
509        let linear = Linear::new();
510        let expr = Expression::new(ConstantExpr { value: 2.5f32 });
511        let mut data = [0.0f32; 4];
512        
513        linear.fill_real(data.as_mut_ptr(), &expr, 4);
514        
515        for &value in &data {
516            assert_eq!(value, 2.5);
517        }
518    }
519}