1use ndarray::{Array1, ArrayView, ArrayViewMut, Dimension};
2use num_complex::Complex;
3use std::any::TypeId;
4use std::ffi::c_void;
5use std::mem::size_of;
6use std::cell::UnsafeCell;
7
/// Plain-data description of an n-dimensional array that is handed across
/// the C ABI to `fft_c2c_`.
///
/// The struct is `#[repr(C)]`, so the field order below is part of the binary
/// contract and must not be rearranged.
/// NOTE(review): the matching definition on the foreign side is not visible
/// in this file — keep both in sync.
#[repr(C)]
struct RustArrayDescriptor {
    /// Extent of each axis; entries at index `ndim` and beyond are zero.
    shape: [u64; 10],
    /// Stride of each axis as reported by ndarray's `strides()`; entries at
    /// index `ndim` and beyond are zero.
    stride: [i64; 10],
    /// Type-erased pointer to the first element of the array.
    data: *mut c_void,
    /// Number of leading entries of `shape`/`stride` that are meaningful.
    ndim: u8,
    /// Element type code as produced by `type2typeid`.
    dtype: u8,
}
33
/// Copies the axis extents in `ndinp` into a fixed-size, zero-padded array
/// of ten `u64` values suitable for a `RustArrayDescriptor`.
///
/// # Panics
///
/// Panics if `ndinp` holds more than ten extents, because the descriptor has
/// no room for them.
fn format_shape(ndinp: &[usize]) -> [u64; 10] {
    let mut res = [0u64; 10];
    // Fail up front with a readable message instead of an index-out-of-bounds
    // panic halfway through the copy.
    assert!(
        ndinp.len() <= res.len(),
        "array rank {} exceeds the maximum of {} supported dimensions",
        ndinp.len(),
        res.len()
    );
    for (slot, &extent) in res.iter_mut().zip(ndinp) {
        // Lossless wherever `usize` is at most 64 bits wide.
        *slot = extent as u64;
    }
    res
}
41
/// Copies the axis strides in `ndinp` into a fixed-size, zero-padded array
/// of ten `i64` values suitable for a `RustArrayDescriptor`.
///
/// # Panics
///
/// Panics if `ndinp` holds more than ten strides, because the descriptor has
/// no room for them.
fn format_stride(ndinp: &[isize]) -> [i64; 10] {
    let mut res = [0i64; 10];
    // Fail up front with a readable message instead of an index-out-of-bounds
    // panic halfway through the copy.
    assert!(
        ndinp.len() <= res.len(),
        "array rank {} exceeds the maximum of {} supported dimensions",
        ndinp.len(),
        res.len()
    );
    for (slot, &step) in res.iter_mut().zip(ndinp) {
        // Lossless wherever `isize` is at most 64 bits wide.
        *slot = step as i64;
    }
    res
}
49
50fn type2typeid<A: 'static>() -> u8 {
51 if TypeId::of::<A>() == TypeId::of::<f64>() {
52 7
53 } else if TypeId::of::<A>() == TypeId::of::<f32>() {
54 3
55 } else if TypeId::of::<A>() == TypeId::of::<Complex<f64>>() {
56 7 + 64
57 } else if TypeId::of::<A>() == TypeId::of::<Complex<f32>>() {
58 3 + 64
59 } else if TypeId::of::<A>() == TypeId::of::<usize>() {
60 (size_of::<A>() - 1 + 32) as u8
61 } else {
62 println!("{}", std::any::type_name::<A>());
63 panic!("typeid not supported");
64 }
65}
66
67fn mutslice2arrdesc<'a, A: 'static, D: Dimension>(
69 slc: ArrayViewMut<'a, A, D>,
70) -> RustArrayDescriptor {
71 RustArrayDescriptor {
72 ndim: slc.shape().len() as u8,
73 dtype: type2typeid::<A>(),
74 shape: format_shape(slc.shape()),
75 stride: format_stride(slc.strides()),
76 data: slc.as_ptr() as *mut c_void,
77 }
78}
79
80fn slice2arrdesc<'a, A: 'static, D: Dimension>(slc: ArrayView<'a, A, D>) -> RustArrayDescriptor {
81 RustArrayDescriptor {
82 ndim: slc.shape().len() as u8,
83 dtype: type2typeid::<A>(),
84 shape: format_shape(slc.shape()),
85 stride: format_stride(slc.strides()),
86 data: slc.as_ptr() as *mut c_void,
87 }
88}
89extern "C" {
93 fn fft_c2c_(
94 inp: &RustArrayDescriptor,
95 out: &mut RustArrayDescriptor,
96 axes: &RustArrayDescriptor,
97 forward: bool,
98 fct: f64,
99 nthreads: usize,
100 );
101}
102
103pub fn fft_c2c<A: 'static, D: ndarray::Dimension>(
118 inp: ArrayView<Complex<A>, D>,
119 out: ArrayViewMut<Complex<A>, D>,
120 axes: &Vec<usize>,
121 forward: bool,
122 fct: f64,
123 nthreads: usize,
124) {
125 let inp2 = slice2arrdesc(inp);
126 let mut out2 = mutslice2arrdesc(out);
127 let axes2 = Array1::from_vec(axes.to_vec());
128 let axes3 = slice2arrdesc(axes2.view());
129 unsafe {
130 fft_c2c_(&inp2, &mut out2, &axes3, forward, fct, nthreads);
131 }
132}
133
134pub fn fft_c2c_inplace<A: 'static, D: ndarray::Dimension>(
138 inpout: ArrayViewMut<Complex<A>, D>,
139 axes: &Vec<usize>,
140 forward: bool,
141 fct: f64,
142 nthreads: usize,
143) {
144 let inpout2 = UnsafeCell::new(mutslice2arrdesc(inpout));
145 let axes2 = Array1::from_vec(axes.to_vec());
146 let axes3 = slice2arrdesc(axes2.view());
147 unsafe {
148 fft_c2c_(
149 &*inpout2.get(),
150 &mut *inpout2.get(),
151 &axes3,
152 forward,
153 fct,
154 nthreads,
155 );
156 }
157}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array;

    /// Asserts that `actual` equals the real number `expected` up to rounding.
    fn assert_close(actual: Complex<f64>, expected: f64) {
        let err = (actual - Complex::new(expected, 0.)).norm();
        assert!(err < 1e-9, "expected {} but got {}", expected, actual);
    }

    /// Transforms a constant array out of place and then transforms the
    /// result again in place, checking both against the analytically known
    /// values.
    #[test]
    fn fft_test() {
        let shape = (2, 3, 3);

        let b = Array::from_elem(shape, Complex::<f64>::new(12., 0.));
        let mut c = Array::from_elem(shape, Complex::<f64>::new(0., 0.));
        println!("{:8.4}", b);
        let axes = vec![0, 2];
        fft_c2c(b.view(), c.view_mut(), &axes, true, 1., 1);
        println!("{:8.4}", c);

        // An unscaled transform of a constant over axes 0 and 2 (2 * 3 = 6
        // points) gathers everything in the zero-frequency bin: 12 * 6 = 72.
        // Axis 1 is not transformed, so this holds for every index along it.
        for ((i, _j, k), &value) in c.indexed_iter() {
            let expected = if i == 0 && k == 0 { 72. } else { 0. };
            assert_close(value, expected);
        }

        fft_c2c_inplace(c.view_mut(), &axes, true, 1., 1);

        // Transforming a single spike at the origin spreads its amplitude
        // evenly over all bins.
        for &value in c.iter() {
            assert_close(value, 72.);
        }
    }
}