1#![allow(unused_imports)]
8
9use num_traits::{Float, FromPrimitive, Zero, One};
10use num_complex::Complex;
11
12#[cfg(feature = "alloc")]
13use alloc::{vec::Vec, boxed::Box};
14
15pub type ConstRealPointer<T> = *const T;
17pub type RealPointer<T> = *mut T;
18
19pub type ConstComplexPointer<T> = *const Complex<T>;
21pub type ComplexPointer<T> = *mut Complex<T>;
22
23#[derive(Copy, Clone, Debug)]
25pub struct ConstSplitPointer<T> {
26 pub real: ConstRealPointer<T>,
27 pub imag: ConstRealPointer<T>,
28}
29
30impl<T> ConstSplitPointer<T> {
31 pub fn new(real: ConstRealPointer<T>, imag: ConstRealPointer<T>) -> Self {
32 Self { real, imag }
33 }
34
35 pub unsafe fn get(&self, i: usize) -> Complex<T>
37 where
38 T: Copy,
39 {
40 Complex::new(*self.real.add(i), *self.imag.add(i))
41 }
42}
43
44#[derive(Copy, Clone, Debug)]
46pub struct SplitPointer<T> {
47 pub real: RealPointer<T>,
48 pub imag: RealPointer<T>,
49}
50
51impl<T> SplitPointer<T> {
52 pub fn new(real: RealPointer<T>, imag: RealPointer<T>) -> Self {
53 Self { real, imag }
54 }
55
56 pub fn as_const(&self) -> ConstSplitPointer<T> {
58 ConstSplitPointer::new(self.real, self.imag)
59 }
60
61 pub unsafe fn get(&self, i: usize) -> Complex<T>
63 where
64 T: Copy,
65 {
66 Complex::new(*self.real.add(i), *self.imag.add(i))
67 }
68
69 pub unsafe fn get_mut(&mut self, i: usize) -> SplitValue<T> {
71 SplitValue::new(self.real.add(i), self.imag.add(i))
72 }
73}
74
75pub struct SplitValue<T> {
77 real_ptr: *mut T,
78 imag_ptr: *mut T,
79}
80
81impl<T> SplitValue<T> {
82 unsafe fn new(real_ptr: *mut T, imag_ptr: *mut T) -> Self {
83 Self { real_ptr, imag_ptr }
84 }
85
86 pub fn real(&self) -> T
87 where
88 T: Copy,
89 {
90 unsafe { *self.real_ptr }
91 }
92
93 pub fn set_real(&mut self, value: T)
94 where
95 T: Copy,
96 {
97 unsafe { *self.real_ptr = value }
98 }
99
100 pub fn imag(&self) -> T
101 where
102 T: Copy,
103 {
104 unsafe { *self.imag_ptr }
105 }
106
107 pub fn set_imag(&mut self, value: T)
108 where
109 T: Copy,
110 {
111 unsafe { *self.imag_ptr = value }
112 }
113}
114
115impl<T> From<SplitValue<T>> for Complex<T>
116where
117 T: Copy,
118{
119 fn from(value: SplitValue<T>) -> Self {
120 Complex::new(value.real(), value.imag())
121 }
122}
123
/// A source that yields a value for any index `i`.
///
/// The implementations in this module either ignore the index (a constant) or
/// read element `i` from a raw buffer without bounds checks.
pub trait ExpressionBase {
    /// Element type produced by [`ExpressionBase::get`].
    type Output;

    /// Returns the value at index `i`.
    fn get(&self, i: usize) -> Self::Output;
}
129
130pub struct ConstantExpr<T> {
132 pub value: T,
133}
134
135impl<T: Copy> ExpressionBase for ConstantExpr<T> {
136 type Output = T;
137 fn get(&self, _i: usize) -> T {
138 self.value
139 }
140}
141
142pub struct ReadableReal<T> {
144 pub pointer: ConstRealPointer<T>,
145}
146
147impl<T: Copy> ExpressionBase for ReadableReal<T> {
148 type Output = T;
149 fn get(&self, i: usize) -> T {
150 unsafe { *self.pointer.add(i) }
151 }
152}
153
154pub struct ReadableComplex<T> {
156 pub pointer: ConstComplexPointer<T>,
157}
158
159impl<T: Copy> ExpressionBase for ReadableComplex<T> {
160 type Output = Complex<T>;
161 fn get(&self, i: usize) -> Complex<T> {
162 unsafe { *self.pointer.add(i) }
163 }
164}
165
166pub struct ReadableSplit<T> {
168 pub pointer: ConstSplitPointer<T>,
169}
170
171impl<T: Copy> ExpressionBase for ReadableSplit<T> {
172 type Output = Complex<T>;
173 fn get(&self, i: usize) -> Complex<T> {
174 unsafe { self.pointer.get(i) }
175 }
176}
177
178pub struct Expression<E: ExpressionBase> {
180 expr: E,
181}
182
183impl<E: ExpressionBase> Expression<E> {
184 pub fn new(expr: E) -> Self {
185 Self { expr }
186 }
187
188 pub fn get(&self, i: usize) -> E::Output {
189 self.expr.get(i)
190 }
191}
192
193pub struct WritableExpression<E: ExpressionBase> {
195 expr: E,
196 pointer: *mut E::Output,
197}
198
199impl<E: ExpressionBase> WritableExpression<E> {
200 pub fn new(expr: E, pointer: *mut E::Output, _size: usize) -> Self {
201 Self { expr, pointer }
202 }
203
204 pub fn get(&self, i: usize) -> E::Output {
205 self.expr.get(i)
206 }
207
208 pub unsafe fn get_mut(&mut self, i: usize) -> *mut E::Output {
209 self.pointer.add(i)
210 }
211}
212
213pub struct Linear {
215 #[cfg(feature = "alloc")]
216 cached_results: Option<CachedResults>,
217}
218
219impl Linear {
220 pub fn new() -> Self {
221 Self {
222 #[cfg(feature = "alloc")]
223 cached_results: None,
224 }
225 }
226
227 pub fn wrap_real<T: Copy>(&self, pointer: ConstRealPointer<T>) -> Expression<ReadableReal<T>> {
229 Expression::new(ReadableReal { pointer })
230 }
231
232 pub fn wrap_complex<T: Copy>(&self, pointer: ConstComplexPointer<T>) -> Expression<ReadableComplex<T>> {
234 Expression::new(ReadableComplex { pointer })
235 }
236
237 pub fn wrap_split<T: Copy>(&self, pointer: ConstSplitPointer<T>) -> Expression<ReadableSplit<T>> {
239 Expression::new(ReadableSplit { pointer })
240 }
241
242 pub fn wrap_real_mut<T: Copy>(&self, pointer: RealPointer<T>, size: usize) -> WritableExpression<ReadableReal<T>> {
244 WritableExpression::new(ReadableReal { pointer }, pointer as *mut T, size)
245 }
246
247 pub fn wrap_complex_mut<T: Copy>(&self, pointer: ComplexPointer<T>, size: usize) -> WritableExpression<ReadableComplex<T>> {
249 WritableExpression::new(ReadableComplex { pointer }, pointer as *mut Complex<T>, size)
250 }
251
252 pub fn wrap_split_mut<T: Copy>(&self, pointer: SplitPointer<T>, size: usize) -> WritableExpression<ReadableSplit<T>> {
254 WritableExpression::new(ReadableSplit { pointer: pointer.as_const() }, pointer.real as *mut Complex<T>, size)
255 }
256
257 pub fn fill_real<T, E>(&self, pointer: RealPointer<T>, expr: &Expression<E>, size: usize)
259 where
260 E: ExpressionBase<Output = T>,
261 T: Copy,
262 {
263 for i in 0..size {
264 unsafe {
265 *pointer.add(i) = expr.get(i);
266 }
267 }
268 }
269
270 pub fn fill_complex<T, E>(&self, pointer: ComplexPointer<T>, expr: &Expression<E>, size: usize)
272 where
273 E: ExpressionBase<Output = Complex<T>>,
274 T: Copy,
275 {
276 for i in 0..size {
277 unsafe {
278 *pointer.add(i) = expr.get(i);
279 }
280 }
281 }
282
283 pub fn fill_split<T, E>(&self, pointer: SplitPointer<T>, expr: &Expression<E>, size: usize)
285 where
286 E: ExpressionBase<Output = Complex<T>>,
287 T: Copy,
288 {
289 for i in 0..size {
290 let value = expr.get(i);
291 unsafe {
292 *pointer.real.add(i) = value.re;
293 *pointer.imag.add(i) = value.im;
294 }
295 }
296 }
297
298 pub fn reserve<T>(&mut self, _size: usize) {
300 }
302}
303
/// Growable scratch buffer that hands out consecutive, non-overlapping chunks.
///
/// `start` is the offset of the next chunk to hand out. `end` is the number of
/// initialised elements in `buffer` (it always equals `buffer.len()`).
#[cfg(feature = "alloc")]
pub struct Temporary<T> {
    buffer: Vec<T>,
    start: usize,
    end: usize,
}

#[cfg(feature = "alloc")]
impl<T> Default for Temporary<T> {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "alloc")]
impl<T> Temporary<T> {
    /// Creates an empty pool without allocating.
    pub fn new() -> Self {
        Self {
            buffer: Vec::new(),
            start: 0,
            end: 0,
        }
    }

    /// Rewinds so the next chunk starts at offset 0. Contents are kept.
    pub fn clear(&mut self) {
        self.start = 0;
    }
}

// Filling new slots needs `T: Clone` (required by `Vec::resize`) and a value
// to fill with. The previous code used `mem::zeroed()`: that is UB for
// non-zeroable `T` and left the `Clone` bound unmet, so this impl did not
// compile. For `f32`/`f64`, `T::default()` is the same zero value.
#[cfg(feature = "alloc")]
impl<T: Clone + Default> Temporary<T> {
    /// Resizes the buffer to exactly `size` elements and rewinds to offset 0.
    ///
    /// The buffer shrinks if it was larger. New slots hold `T::default()`.
    pub fn reserve(&mut self, size: usize) {
        self.buffer.resize(size, T::default());
        self.start = 0;
        self.end = size;
    }

    /// Returns the next `size` elements and advances past them.
    ///
    /// If fewer than `size` elements remain after `start`, the buffer first
    /// grows by `size` elements filled with `T::default()`.
    pub fn get_chunk(&mut self, size: usize) -> &mut [T] {
        if self.start + size > self.end {
            self.buffer.resize(self.end + size, T::default());
            self.end += size;
        }
        let begin = self.start;
        let finish = begin + size;
        self.start = finish;
        &mut self.buffer[begin..finish]
    }
}
343
/// Scratch storage with one `Temporary` pool for `f32` and one for `f64`.
#[cfg(feature = "alloc")]
pub struct CachedResults {
    floats: Temporary<f32>,
    doubles: Temporary<f64>,
}

#[cfg(feature = "alloc")]
impl Default for CachedResults {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "alloc")]
impl CachedResults {
    /// Creates both pools empty.
    pub fn new() -> Self {
        Self {
            floats: Temporary::new(),
            doubles: Temporary::new(),
        }
    }

    /// Forwards to `Temporary::reserve` on the `f32` pool.
    pub fn reserve_floats(&mut self, size: usize) {
        self.floats.reserve(size);
    }

    /// Forwards to `Temporary::reserve` on the `f64` pool.
    pub fn reserve_doubles(&mut self, size: usize) {
        self.doubles.reserve(size);
    }
}
368
/// Unary element-wise operations that produce a new expression of the same type.
///
/// The type parameter `T` is not used by the trait itself. The only
/// implementation in this module is for constant real expressions, where
/// `norm` is `value * value`, `conj` and `real` return the value unchanged,
/// and `imag` returns zero.
pub trait MathOps<T> {
    /// Absolute value.
    fn abs(&self) -> Self;
    /// Squared magnitude (for the constant real implementation: `value²`).
    fn norm(&self) -> Self;
    /// Natural exponential.
    fn exp(&self) -> Self;
    /// Natural logarithm.
    fn log(&self) -> Self;
    /// Square root.
    fn sqrt(&self) -> Self;
    /// Complex conjugate.
    fn conj(&self) -> Self;
    /// Real part.
    fn real(&self) -> Self;
    /// Imaginary part.
    fn imag(&self) -> Self;
}
380
381impl<T: Float + FromPrimitive> MathOps<T> for Expression<ConstantExpr<T>> {
382 fn abs(&self) -> Self {
383 Expression::new(ConstantExpr { value: self.expr.value.abs() })
384 }
385
386 fn norm(&self) -> Self {
387 Expression::new(ConstantExpr { value: self.expr.value * self.expr.value })
388 }
389
390 fn exp(&self) -> Self {
391 Expression::new(ConstantExpr { value: self.expr.value.exp() })
392 }
393
394 fn log(&self) -> Self {
395 Expression::new(ConstantExpr { value: self.expr.value.ln() })
396 }
397
398 fn sqrt(&self) -> Self {
399 Expression::new(ConstantExpr { value: self.expr.value.sqrt() })
400 }
401
402 fn conj(&self) -> Self {
403 Expression::new(ConstantExpr { value: self.expr.value })
404 }
405
406 fn real(&self) -> Self {
407 Expression::new(ConstantExpr { value: self.expr.value })
408 }
409
410 fn imag(&self) -> Self {
411 Expression::new(ConstantExpr { value: T::zero() })
412 }
413}
414
/// Binary element-wise arithmetic between two expressions of the same type.
///
/// The type parameter `T` is not used by the trait itself.
pub trait BinaryOps<T> {
    /// Element-wise sum `self + other`.
    fn add(&self, other: &Self) -> Self;
    /// Element-wise difference `self - other`.
    fn sub(&self, other: &Self) -> Self;
    /// Element-wise product `self * other`.
    fn mul(&self, other: &Self) -> Self;
    /// Element-wise quotient `self / other`.
    fn div(&self, other: &Self) -> Self;
}
422
423impl<T: Float + FromPrimitive> BinaryOps<T> for Expression<ConstantExpr<T>> {
424 fn add(&self, other: &Self) -> Self {
425 Expression::new(ConstantExpr { value: self.expr.value + other.expr.value })
426 }
427
428 fn sub(&self, other: &Self) -> Self {
429 Expression::new(ConstantExpr { value: self.expr.value - other.expr.value })
430 }
431
432 fn mul(&self, other: &Self) -> Self {
433 Expression::new(ConstantExpr { value: self.expr.value * other.expr.value })
434 }
435
436 fn div(&self, other: &Self) -> Self {
437 Expression::new(ConstantExpr { value: self.expr.value / other.expr.value })
438 }
439}
440
441pub fn cheap_energy_crossfade<T: Float + FromPrimitive>(x: T) -> (T, T) {
443 let to_coeff = x;
444 let from_coeff = T::one() - x;
445 (to_coeff, from_coeff)
446}
447
448#[cfg(test)]
449mod tests {
450 use super::*;
451
452 #[test]
453 fn test_constant_expression() {
454 let expr = Expression::new(ConstantExpr { value: 2.5f32 });
455 assert_eq!(expr.get(0), 2.5);
456 assert_eq!(expr.get(100), 2.5); }
458
459 #[test]
460 fn test_math_ops() {
461 let expr = Expression::new(ConstantExpr { value: 4.0f32 });
462
463 let abs_expr = expr.abs();
464 assert_eq!(abs_expr.get(0), 4.0);
465
466 let sqrt_expr = expr.sqrt();
467 assert_eq!(sqrt_expr.get(0), 2.0);
468
469 let norm_expr = expr.norm();
470 assert_eq!(norm_expr.get(0), 16.0);
471 }
472
473 #[test]
474 fn test_binary_ops() {
475 let expr1 = Expression::new(ConstantExpr { value: 3.0f32 });
476 let expr2 = Expression::new(ConstantExpr { value: 2.0f32 });
477
478 let add_expr = expr1.add(&expr2);
479 assert_eq!(add_expr.get(0), 5.0);
480
481 let mul_expr = expr1.mul(&expr2);
482 assert_eq!(mul_expr.get(0), 6.0);
483 }
484
485 #[test]
486 fn test_cheap_energy_crossfade() {
487 let (to_coeff, from_coeff) = cheap_energy_crossfade(0.5f32);
488 assert!((to_coeff - 0.5).abs() < 1e-6);
489 assert!((from_coeff - 0.5).abs() < 1e-6);
490 assert!((to_coeff + from_coeff - 1.0).abs() < 1e-6);
491 }
492
493 #[test]
494 fn test_split_pointer() {
495 let mut real_data = [1.0f32, 2.0, 3.0];
496 let mut imag_data = [4.0f32, 5.0, 6.0];
497
498 let split_ptr = SplitPointer::new(real_data.as_mut_ptr(), imag_data.as_mut_ptr());
499
500 unsafe {
501 let complex_val = split_ptr.get(1);
502 assert_eq!(complex_val.re, 2.0);
503 assert_eq!(complex_val.im, 5.0);
504 }
505 }
506
507 #[test]
508 fn test_linear_fill() {
509 let linear = Linear::new();
510 let expr = Expression::new(ConstantExpr { value: 2.5f32 });
511 let mut data = [0.0f32; 4];
512
513 linear.fill_real(data.as_mut_ptr(), &expr, 4);
514
515 for &value in &data {
516 assert_eq!(value, 2.5);
517 }
518 }
519}