// ducc0/lib.rs

1use ndarray::{Array1, ArrayView, ArrayViewMut, Dimension};
2use num_complex::Complex;
3use std::any::TypeId;
4use std::ffi::c_void;
5use std::mem::size_of;
6use std::cell::UnsafeCell;
7
// Support -march=native
// Support -ffast-math (only for building ducc)

// 1) c2c, genuine_hartley
// 2a) c2r, r2c, separable_hartley, dct, dst
// 2b) nfft
// n-1) sht
// n) radio response
16
17
18// Debugging
19// fn print_type_of<T>(_: &T) {
20//     println!("{}", std::any::type_name::<T>())
21// }
22// /Debugging
23
24// Related to RustArrayDescriptor
/// Maximum number of dimensions a [`RustArrayDescriptor`] can describe.
const MAX_NDIM: usize = 10;

/// Plain-data description of a strided array, handed across the `extern "C"` boundary.
///
/// `#[repr(C)]` fixes the field layout. It presumably mirrors a struct on the C/C++ side
/// (NOTE(review): confirm), so field order and field types must not be changed.
#[repr(C)]
#[derive(Debug)]
struct RustArrayDescriptor {
    /// Extent of each axis. Only the first `ndim` entries are meaningful; the constructors in
    /// this file zero-fill the remainder.
    // TODO: the capacity is fixed at MAX_NDIM; the other side presumably hard-codes the same
    // value, so keep the two in sync if it is ever changed.
    shape: [u64; MAX_NDIM],
    /// Stride of each axis as reported by ndarray (units of elements, may be negative).
    stride: [i64; MAX_NDIM],
    /// Pointer to the array data. It carries no lifetime: a descriptor must not be used after
    /// the borrow it was created from has ended.
    data: *mut c_void,
    /// Number of meaningful entries in `shape` and `stride`.
    ndim: u8,
    /// Element type code, as produced by `type2typeid`.
    dtype: u8,
}
33
/// Packs an array shape into the fixed-size `shape` field of a `RustArrayDescriptor`.
///
/// Entry `i` of the result is `ndinp[i]`; entries past `ndinp.len()` are zero.
///
/// # Panics
///
/// Panics with a descriptive message if `ndinp` has more than 10 entries (the descriptor's
/// capacity).
fn format_shape(ndinp: &[usize]) -> [u64; 10] {
    let mut res = [0u64; 10];
    assert!(
        ndinp.len() <= res.len(),
        "arrays with {} dimensions are not supported (maximum is {})",
        ndinp.len(),
        res.len()
    );
    for (dst, &extent) in res.iter_mut().zip(ndinp) {
        // usize is at most 64 bits wide on every Rust target, so this cast is lossless.
        *dst = extent as u64;
    }
    res
}
41
/// Packs array strides into the fixed-size `stride` field of a `RustArrayDescriptor`.
///
/// Entry `i` of the result is `ndinp[i]` (negative strides are preserved); entries past
/// `ndinp.len()` are zero.
///
/// # Panics
///
/// Panics with a descriptive message if `ndinp` has more than 10 entries (the descriptor's
/// capacity).
fn format_stride(ndinp: &[isize]) -> [i64; 10] {
    let mut res = [0i64; 10];
    assert!(
        ndinp.len() <= res.len(),
        "arrays with {} dimensions are not supported (maximum is {})",
        ndinp.len(),
        res.len()
    );
    for (dst, &step) in res.iter_mut().zip(ndinp) {
        // isize is at most 64 bits wide on every Rust target, so this cast is lossless.
        *dst = step as i64;
    }
    res
}
49
50fn type2typeid<A: 'static>() -> u8 {
51    if TypeId::of::<A>() == TypeId::of::<f64>() {
52        7
53    } else if TypeId::of::<A>() == TypeId::of::<f32>() {
54        3
55    } else if TypeId::of::<A>() == TypeId::of::<Complex<f64>>() {
56        7 + 64
57    } else if TypeId::of::<A>() == TypeId::of::<Complex<f32>>() {
58        3 + 64
59    } else if TypeId::of::<A>() == TypeId::of::<usize>() {
60        (size_of::<A>() - 1 + 32) as u8
61    } else {
62        println!("{}", std::any::type_name::<A>());
63        panic!("typeid not supported");
64    }
65}
66
67// TODO unify mutslice2arrdesc and slice2arrdesc?
68fn mutslice2arrdesc<'a, A: 'static, D: Dimension>(
69    slc: ArrayViewMut<'a, A, D>,
70) -> RustArrayDescriptor {
71    RustArrayDescriptor {
72        ndim: slc.shape().len() as u8,
73        dtype: type2typeid::<A>(),
74        shape: format_shape(slc.shape()),
75        stride: format_stride(slc.strides()),
76        data: slc.as_ptr() as *mut c_void,
77    }
78}
79
80fn slice2arrdesc<'a, A: 'static, D: Dimension>(slc: ArrayView<'a, A, D>) -> RustArrayDescriptor {
81    RustArrayDescriptor {
82        ndim: slc.shape().len() as u8,
83        dtype: type2typeid::<A>(),
84        shape: format_shape(slc.shape()),
85        stride: format_stride(slc.strides()),
86        data: slc.as_ptr() as *mut c_void,
87    }
88}
89// /Related to RustArrayDescriptor
90
91// Interface
92extern "C" {
93    fn fft_c2c_(
94        inp: &RustArrayDescriptor,
95        out: &mut RustArrayDescriptor,
96        axes: &RustArrayDescriptor,
97        forward: bool,
98        fct: f64,
99        nthreads: usize,
100    );
101}
102
103/// Complex-to-complex Fast Fourier Transform
104///
105/// This executes a Fast Fourier Transform on `inp` and stores the result in `out`.
106///
107/// # Arguments
108///
109/// * `inp` - View to the input array
110/// * `out` - Mutable view to the output array
111/// * `axes` - Specifies the axes over which the transform is carried out
112/// * `forward` - If `true`, a minus sign will be used in the exponent
113/// * `fct` - No normalization factors will be applied by default; if multiplication by a constant
114/// is desired, it can be supplied here.
115/// * `nthreads` - If the underlying array has more than one dimension, the computation will be
116/// distributed over `nthreads` threads.
117pub fn fft_c2c<A: 'static, D: ndarray::Dimension>(
118    inp: ArrayView<Complex<A>, D>,
119    out: ArrayViewMut<Complex<A>, D>,
120    axes: &Vec<usize>,
121    forward: bool,
122    fct: f64,
123    nthreads: usize,
124) {
125    let inp2 = slice2arrdesc(inp);
126    let mut out2 = mutslice2arrdesc(out);
127    let axes2 = Array1::from_vec(axes.to_vec());
128    let axes3 = slice2arrdesc(axes2.view());
129    unsafe {
130        fft_c2c_(&inp2, &mut out2, &axes3, forward, fct, nthreads);
131    }
132}
133
134/// Inplace complex-to-complex Fast Fourier Transform
135///
136/// Usage analogous to [`fft_c2c`].
137pub fn fft_c2c_inplace<A: 'static, D: ndarray::Dimension>(
138    inpout: ArrayViewMut<Complex<A>, D>,
139    axes: &Vec<usize>,
140    forward: bool,
141    fct: f64,
142    nthreads: usize,
143) {
144    let inpout2 = UnsafeCell::new(mutslice2arrdesc(inpout));
145    let axes2 = Array1::from_vec(axes.to_vec());
146    let axes3 = slice2arrdesc(axes2.view());
147    unsafe {
148        fft_c2c_(
149            &*inpout2.get(),
150            &mut *inpout2.get(),
151            &axes3,
152            forward,
153            fct,
154            nthreads,
155        );
156    }
157}
158// /Interface
159
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array;

    // TODO Write tests that go through all combinations of axes for 1d-3d, do FFT of arrays that
    // contain only ones, check if sums are consistent

    /// Asserts that two complex numbers agree to within a small absolute tolerance.
    fn assert_close(actual: Complex<f64>, expected: Complex<f64>) {
        assert!(
            (actual - expected).norm() < 1e-10,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn fft_test() {
        let shape = (2, 3, 3);
        let axes = vec![0, 2];

        // A constant transformed over axes 0 and 2 (2 * 3 = 6 points) collapses into the
        // zero-frequency bin: 12 * 6 = 72 where i == 0 && k == 0, zero elsewhere.
        let b = Array::from_elem(shape, Complex::<f64>::new(12., 0.));
        let mut c = Array::from_elem(shape, Complex::<f64>::new(0., 0.));
        fft_c2c(b.view(), c.view_mut(), &axes, true, 1., 1);
        for ((i, _, k), &v) in c.indexed_iter() {
            let expected = if i == 0 && k == 0 { 72. } else { 0. };
            assert_close(v, Complex::new(expected, 0.));
        }

        // Transforming that spike again spreads it uniformly: every entry becomes 72.
        fft_c2c_inplace(c.view_mut(), &axes, true, 1., 1);
        for &v in c.iter() {
            assert_close(v, Complex::new(72., 0.));
        }
    }

    #[test]
    fn fft_roundtrip() {
        let shape = (2, 3, 4);
        let axes = vec![0, 2];
        // Number of points transformed per output element: product of the lengths along `axes`.
        let n = (shape.0 * shape.2) as f64;

        let a = Array::from_shape_fn(shape, |(i, j, k)| {
            Complex::<f64>::new(i as f64 + 0.5 * j as f64, k as f64 - 1.)
        });
        let mut freq = Array::from_elem(shape, Complex::<f64>::new(0., 0.));
        fft_c2c(a.view(), freq.view_mut(), &axes, true, 1., 1);
        // The transforms are unnormalised, so the backward pass needs a 1/N factor to undo the
        // forward pass.
        fft_c2c_inplace(freq.view_mut(), &axes, false, 1. / n, 1);
        for (&x, &y) in freq.iter().zip(a.iter()) {
            assert_close(x, y);
        }
    }
}