use rustfft::num_complex::Complex;
use rustfft::num_traits::Zero;
use std::{cell::RefCell, rc::Rc};
use crate::float::Float;
/// Selects which part of a complex number a copy routine reads from or writes to.
///
/// Derives `Copy` so a single selector value can be passed to several calls
/// without being moved, and `Debug`/`PartialEq` so it can be printed and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComplexComponent {
    /// The real part (`re`).
    Re,
    /// The imaginary part (`im`).
    Im,
}
/// Allocates a real-valued buffer of `size` elements, all initialised to zero.
pub fn new_real_buffer<T: Float>(size: usize) -> Vec<T> {
    std::iter::repeat(T::zero()).take(size).collect()
}
/// Allocates a complex-valued buffer of `size` elements, all initialised to `0 + 0i`.
pub fn new_complex_buffer<T: Float>(size: usize) -> Vec<Complex<T>> {
    (0..size).map(|_| Complex::zero()).collect()
}
/// Copies each real value in `input` into one component of the matching
/// complex slot in `output`.
///
/// `component` chooses whether a value lands in the real or the imaginary
/// part; the other part of that slot is set to zero. Every slot of `output`
/// past `input.len()` is reset to `0 + 0i`.
///
/// # Panics
///
/// Panics if `input` is longer than `output`.
pub fn copy_real_to_complex<T: Float>(
    input: &[T],
    output: &mut [Complex<T>],
    component: ComplexComponent,
) {
    assert!(input.len() <= output.len());
    let (head, tail) = output.split_at_mut(input.len());
    for (dst, &value) in head.iter_mut().zip(input) {
        *dst = match component {
            ComplexComponent::Re => Complex::new(value, T::zero()),
            ComplexComponent::Im => Complex::new(T::zero(), value),
        };
    }
    for dst in tail {
        *dst = Complex::zero();
    }
}
/// Extracts one component of each complex value in `input` into `output`.
///
/// `component` chooses whether the real or the imaginary part is taken.
/// Every element of `output` past `input.len()` is set to zero.
///
/// # Panics
///
/// Panics if `input` is longer than `output`.
pub fn copy_complex_to_real<T: Float>(
    input: &[Complex<T>],
    output: &mut [T],
    component: ComplexComponent,
) {
    assert!(input.len() <= output.len());
    let (head, tail) = output.split_at_mut(input.len());
    for (dst, src) in head.iter_mut().zip(input) {
        *dst = match component {
            ComplexComponent::Re => src.re,
            ComplexComponent::Im => src.im,
        };
    }
    for dst in tail {
        *dst = T::zero();
    }
}
/// Replaces every element of `arr` with its squared magnitude, in place.
///
/// After the call each element holds `re*re + im*im` in its real part and
/// zero in its imaginary part. No square root is taken.
pub fn modulus_squared<T: Float>(arr: &mut [Complex<T>]) {
    for s in arr.iter_mut() {
        // `s.re` is read on the right-hand side before it is overwritten.
        s.re = s.re * s.re + s.im * s.im;
        s.im = T::zero();
    }
}
/// Returns the sum of the squares of every element in `arr`.
pub fn square_sum<T>(arr: &[T]) -> T
where
    T: Float + std::iter::Sum,
{
    arr.iter().copied().map(|x| x * x).sum()
}
/// A pool of reusable, fixed-size buffers of real and complex values.
///
/// Buffers are handed out as shared `Rc<RefCell<...>>` handles, and the pool
/// keeps its own reference to each one. A buffer is treated as free again once
/// the pool's reference is the only strong reference left.
#[derive(Debug)]
pub struct BufferPool<T> {
    // Every real buffer allocated so far, whether currently handed out or free.
    real_buffers: Vec<Rc<RefCell<Vec<T>>>>,
    // Every complex buffer allocated so far, whether currently handed out or free.
    complex_buffers: Vec<Rc<RefCell<Vec<Complex<T>>>>>,
    /// Length used when the pool allocates a new buffer. Changing it does not
    /// resize buffers that already exist in the pool.
    pub buffer_size: usize,
}
impl<T: Float> BufferPool<T> {
    /// Creates an empty pool whose buffers will each hold `buffer_size` elements.
    ///
    /// No buffers are allocated here; they are created lazily on the first
    /// request that finds no free buffer.
    pub fn new(buffer_size: usize) -> Self {
        BufferPool {
            real_buffers: vec![],
            complex_buffers: vec![],
            buffer_size,
        }
    }

    /// Allocates a zero-filled real buffer of `buffer_size` elements.
    ///
    /// Stores one reference in the pool and returns a second handle to the caller.
    fn add_real_buffer(&mut self) -> Rc<RefCell<Vec<T>>> {
        let buffer = Rc::new(RefCell::new(new_real_buffer::<T>(self.buffer_size)));
        self.real_buffers.push(Rc::clone(&buffer));
        buffer
    }

    /// Allocates a zero-filled complex buffer of `buffer_size` elements.
    ///
    /// Stores one reference in the pool and returns a second handle to the caller.
    fn add_complex_buffer(&mut self) -> Rc<RefCell<Vec<Complex<T>>>> {
        let buffer = Rc::new(RefCell::new(new_complex_buffer::<T>(self.buffer_size)));
        self.complex_buffers.push(Rc::clone(&buffer));
        buffer
    }

    /// Returns a real buffer that nobody else currently holds.
    ///
    /// If every pooled buffer is in use, a new one is allocated. A buffer
    /// counts as free when the pool holds the only strong reference to it,
    /// i.e. every handle previously returned for it has been dropped.
    ///
    /// Reused buffers are **not** cleared: they still contain whatever the
    /// previous holder wrote. Only newly allocated buffers start zero-filled.
    pub fn get_real_buffer(&mut self) -> Rc<RefCell<Vec<T>>> {
        self.real_buffers
            .iter()
            .find(|&buf| Rc::strong_count(buf) == 1)
            .map(Rc::clone)
            .unwrap_or_else(|| self.add_real_buffer())
    }

    /// Returns a complex buffer that nobody else currently holds.
    ///
    /// If every pooled buffer is in use, a new one is allocated. A buffer
    /// counts as free when the pool holds the only strong reference to it.
    ///
    /// Reused buffers are **not** cleared: they still contain whatever the
    /// previous holder wrote. Only newly allocated buffers start zero-filled.
    pub fn get_complex_buffer(&mut self) -> Rc<RefCell<Vec<Complex<T>>>> {
        self.complex_buffers
            .iter()
            .find(|&buf| Rc::strong_count(buf) == 1)
            .map(Rc::clone)
            .unwrap_or_else(|| self.add_complex_buffer())
    }
}
/// Checks two pool behaviours.
///
/// A held buffer is never handed out twice. A released buffer is recycled
/// with its previous contents intact.
#[test]
fn test_buffers() {
    let mut pool = BufferPool::new(3);

    // Kept alive for the whole test, so the pool must not hand it out again.
    let held = pool.get_real_buffer();
    held.borrow_mut()[0] = 5.5;

    // `held` is still alive, so this allocates a second buffer. The returned
    // handle is a temporary and is dropped at the end of the statement.
    pool.get_real_buffer().borrow_mut()[1] = 6.6;

    // The second buffer is free again, so it is reused without being cleared.
    pool.get_real_buffer().borrow_mut()[2] = 7.7;

    assert_eq!(&pool.real_buffers[0].borrow()[..], &[5.5, 0., 0.]);
    assert_eq!(&pool.real_buffers[1].borrow()[..], &[0.0, 6.6, 7.7]);
}