concrete_fftw/
array.rs

1//! Array with SIMD alignment
2
3use std::ops::{Deref, DerefMut};
4use std::os::raw::c_void;
5use std::slice::{from_raw_parts, from_raw_parts_mut};
6
7use num_traits::Zero;
8
9use crate::types::*;
10
/// A RAII-wrapper of `fftw_alloc` and `fftw_free` with the [SIMD alignment].
///
/// The buffer is requested from FFTW's allocator when the value is created and is handed
/// back to `fftw_free` when the value is dropped. The wrapper dereferences to a slice
/// `[T]`, so the usual slice API (indexing, iteration, `len`, ...) is available on it.
///
/// Note that the derived `Debug` output shows the length and the raw pointer, not the
/// elements.
///
/// [SIMD alignment]: http://www.fftw.org/fftw3_doc/SIMD-alignment-and-fftw_005fmalloc.html
#[derive(Debug)]
pub struct AlignedVec<T> {
    // Number of `T` elements the buffer holds.
    n: usize,
    // Start of the buffer obtained from the FFTW allocator.
    data: *mut T,
}
19
20/// Allocate SIMD-aligned memory of Real/Complex type
21pub trait AlignedAllocable: Zero + Clone + Copy + Sized {
22    /// Allocate SIMD-aligned memory
23    #[allow(clippy::missing_safety_doc)]
24    unsafe fn alloc(n: usize) -> *mut Self;
25}
26
27impl AlignedAllocable for f64 {
28    unsafe fn alloc(n: usize) -> *mut Self {
29        ffi::fftw_alloc_real(n as u64)
30    }
31}
32
33impl AlignedAllocable for f32 {
34    unsafe fn alloc(n: usize) -> *mut Self {
35        ffi::fftwf_alloc_real(n as u64)
36    }
37}
38
39impl AlignedAllocable for c64 {
40    unsafe fn alloc(n: usize) -> *mut Self {
41        ffi::fftw_alloc_complex(n as u64) as *mut _
42    }
43}
44
45impl AlignedAllocable for c32 {
46    unsafe fn alloc(n: usize) -> *mut Self {
47        ffi::fftwf_alloc_complex(n as u64) as *mut c32
48    }
49}
50
51impl<T> AlignedVec<T> {
52    pub fn as_slice(&self) -> &[T] {
53        unsafe { from_raw_parts(self.data, self.n) }
54    }
55
56    pub fn as_slice_mut(&mut self) -> &mut [T] {
57        unsafe { from_raw_parts_mut(self.data, self.n) }
58    }
59}
60
61impl<T> Deref for AlignedVec<T> {
62    type Target = [T];
63    fn deref(&self) -> &[T] {
64        self.as_slice()
65    }
66}
67
68impl<T> DerefMut for AlignedVec<T> {
69    fn deref_mut(&mut self) -> &mut [T] {
70        self.as_slice_mut()
71    }
72}
73
74impl<T> AlignedVec<T>
75where
76    T: AlignedAllocable,
77{
78    /// Create array with `fftw_malloc` (`fftw_free` will be automatically called by `Drop` trait)
79    pub fn new(n: usize) -> Self {
80        let ptr = excall! { T::alloc(n) };
81        let mut vec = AlignedVec { n, data: ptr };
82        for v in vec.iter_mut() {
83            *v = T::zero();
84        }
85        vec
86    }
87}
88
89impl<T> Drop for AlignedVec<T> {
90    fn drop(&mut self) {
91        excall! { ffi::fftw_free(self.data as *mut c_void) };
92    }
93}
94
95impl<T> Clone for AlignedVec<T>
96where
97    T: AlignedAllocable,
98{
99    fn clone(&self) -> Self {
100        let mut new_vec = Self::new(self.n);
101        new_vec.copy_from_slice(self);
102        new_vec
103    }
104}
105
106impl<T> PartialEq for AlignedVec<T>
107where
108    T: PartialEq,
109{
110    fn eq(&self, other: &Self) -> bool {
111        if self.len() != other.len() {
112            return false;
113        }
114        self.iter().zip(other.iter()).all(|(a, b)| a == b)
115    }
116}
117
118unsafe impl<T: Send> Send for AlignedVec<T> {}
119unsafe impl<T: Sync> Sync for AlignedVec<T> {}
120
/// Integer result of `alignment_of`; `0` means the checked slice is aligned.
pub type Alignment = i32;
122
123/// Check the alignment of slice
124///
125/// ```
126/// # use concrete_fftw::array::*;
127/// let a = AlignedVec::<f32>::new(123);
128/// assert_eq!(alignment_of(&a), 0);  // aligned
129/// ```
130pub fn alignment_of<T>(a: &[T]) -> Alignment {
131    unsafe { ffi::fftw_alignment_of(a.as_ptr() as *mut _) }
132}
133
134#[cfg(feature = "serialize")]
135mod serde {
136    use std::fmt;
137    use std::marker::PhantomData;
138
139    use serde::de::{Error, SeqAccess, Visitor};
140    use serde::ser::{Serialize, SerializeSeq, Serializer};
141    use serde::{Deserialize, Deserializer};
142
143    use crate::array::AlignedAllocable;
144
145    use super::AlignedVec;
146
147    impl<T> Serialize for AlignedVec<T>
148    where
149        T: Serialize,
150    {
151        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
152        where
153            S: Serializer,
154        {
155            let mut seq = serializer.serialize_seq(Some(self.len()))?;
156            for e in self.iter() {
157                seq.serialize_element(e)?;
158            }
159            seq.end()
160        }
161    }
162
163    struct AlignedVecVisitor<T>(PhantomData<T>);
164
165    impl<'de, T> Visitor<'de> for AlignedVecVisitor<T>
166    where
167        T: AlignedAllocable + Deserialize<'de>,
168    {
169        type Value = AlignedVec<T>;
170
171        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
172            write!(formatter, "AlignedVec<T>")
173        }
174
175        fn visit_seq<A>(self, seq: A) -> Result<Self::Value, <A as SeqAccess<'de>>::Error>
176        where
177            A: SeqAccess<'de>,
178        {
179            let mut seq = seq;
180            let mut output = AlignedVec::new(seq.size_hint().ok_or(A::Error::custom(
181                "Failed to retrieve the size of the AlignedVec.",
182            ))?);
183            for val in output.iter_mut() {
184                *val = seq
185                    .next_element()?
186                    .ok_or(A::Error::custom("Failed to retrieve the next element"))?
187            }
188            Ok(output)
189        }
190    }
191
192    impl<'de, T> Deserialize<'de> for AlignedVec<T>
193    where
194        T: AlignedAllocable + Deserialize<'de>,
195    {
196        fn deserialize<D>(deserializer: D) -> Result<Self, <D as Deserializer<'de>>::Error>
197        where
198            D: Deserializer<'de>,
199        {
200            deserializer.deserialize_seq(AlignedVecVisitor(PhantomData))
201        }
202    }
203
204    #[cfg(test)]
205    mod test {
206        use serde_test::{assert_tokens, Token};
207
208        use crate::types::{c32, c64};
209
210        use super::AlignedVec;
211
212        #[test]
213        fn test_ser_de_empty_c32() {
214            let vec: AlignedVec<c32> = AlignedVec::new(0);
215
216            assert_tokens(&vec, &[Token::Seq { len: Some(0) }, Token::SeqEnd]);
217        }
218
219        #[test]
220        fn test_ser_de_empty_c64() {
221            let vec: AlignedVec<c64> = AlignedVec::new(0);
222
223            assert_tokens(&vec, &[Token::Seq { len: Some(0) }, Token::SeqEnd]);
224        }
225
226        #[test]
227        fn test_ser_de_c32() {
228            let mut vec = AlignedVec::new(3);
229            vec[0] = c32::new(1., 2.);
230            vec[1] = c32::new(3., 4.);
231            vec[2] = c32::new(5., 6.);
232
233            assert_tokens(
234                &vec,
235                &[
236                    Token::Seq { len: Some(3) },
237                    Token::Tuple { len: 2 },
238                    Token::F32(1.),
239                    Token::F32(2.),
240                    Token::TupleEnd,
241                    Token::Tuple { len: 2 },
242                    Token::F32(3.),
243                    Token::F32(4.),
244                    Token::TupleEnd,
245                    Token::Tuple { len: 2 },
246                    Token::F32(5.),
247                    Token::F32(6.),
248                    Token::TupleEnd,
249                    Token::SeqEnd,
250                ],
251            );
252        }
253
254        #[test]
255        fn test_ser_de_c64() {
256            let mut vec = AlignedVec::new(3);
257            vec[0] = c64::new(1., 2.);
258            vec[1] = c64::new(3., 4.);
259            vec[2] = c64::new(5., 6.);
260
261            assert_tokens(
262                &vec,
263                &[
264                    Token::Seq { len: Some(3) },
265                    Token::Tuple { len: 2 },
266                    Token::F64(1.),
267                    Token::F64(2.),
268                    Token::TupleEnd,
269                    Token::Tuple { len: 2 },
270                    Token::F64(3.),
271                    Token::F64(4.),
272                    Token::TupleEnd,
273                    Token::Tuple { len: 2 },
274                    Token::F64(5.),
275                    Token::F64(6.),
276                    Token::TupleEnd,
277                    Token::SeqEnd,
278                ],
279            );
280        }
281    }
282}