1use std::ops::{Deref, DerefMut};
4use std::os::raw::c_void;
5use std::slice::{from_raw_parts, from_raw_parts_mut};
6
7use num_traits::Zero;
8
9use crate::types::*;
10
/// Fixed-length buffer of `T` whose storage comes from FFTW's aligned allocator.
///
/// Storage is obtained through [`AlignedAllocable::alloc`] in [`AlignedVec::new`], viewed as a
/// slice through `Deref`/`DerefMut`, and released with `ffi::fftw_free` when dropped.
#[derive(Debug)]
pub struct AlignedVec<T> {
    /// Number of `T` elements in the buffer.
    n: usize,
    /// Start of the buffer; uniquely owned by this value (`Clone` allocates a new buffer,
    /// `Drop` frees this one).
    data: *mut T,
}
19
20pub trait AlignedAllocable: Zero + Clone + Copy + Sized {
22 #[allow(clippy::missing_safety_doc)]
24 unsafe fn alloc(n: usize) -> *mut Self;
25}
26
27impl AlignedAllocable for f64 {
28 unsafe fn alloc(n: usize) -> *mut Self {
29 ffi::fftw_alloc_real(n as u64)
30 }
31}
32
33impl AlignedAllocable for f32 {
34 unsafe fn alloc(n: usize) -> *mut Self {
35 ffi::fftwf_alloc_real(n as u64)
36 }
37}
38
39impl AlignedAllocable for c64 {
40 unsafe fn alloc(n: usize) -> *mut Self {
41 ffi::fftw_alloc_complex(n as u64) as *mut _
42 }
43}
44
45impl AlignedAllocable for c32 {
46 unsafe fn alloc(n: usize) -> *mut Self {
47 ffi::fftwf_alloc_complex(n as u64) as *mut c32
48 }
49}
50
51impl<T> AlignedVec<T> {
52 pub fn as_slice(&self) -> &[T] {
53 unsafe { from_raw_parts(self.data, self.n) }
54 }
55
56 pub fn as_slice_mut(&mut self) -> &mut [T] {
57 unsafe { from_raw_parts_mut(self.data, self.n) }
58 }
59}
60
61impl<T> Deref for AlignedVec<T> {
62 type Target = [T];
63 fn deref(&self) -> &[T] {
64 self.as_slice()
65 }
66}
67
68impl<T> DerefMut for AlignedVec<T> {
69 fn deref_mut(&mut self) -> &mut [T] {
70 self.as_slice_mut()
71 }
72}
73
74impl<T> AlignedVec<T>
75where
76 T: AlignedAllocable,
77{
78 pub fn new(n: usize) -> Self {
80 let ptr = excall! { T::alloc(n) };
81 let mut vec = AlignedVec { n, data: ptr };
82 for v in vec.iter_mut() {
83 *v = T::zero();
84 }
85 vec
86 }
87}
88
89impl<T> Drop for AlignedVec<T> {
90 fn drop(&mut self) {
91 excall! { ffi::fftw_free(self.data as *mut c_void) };
92 }
93}
94
95impl<T> Clone for AlignedVec<T>
96where
97 T: AlignedAllocable,
98{
99 fn clone(&self) -> Self {
100 let mut new_vec = Self::new(self.n);
101 new_vec.copy_from_slice(self);
102 new_vec
103 }
104}
105
106impl<T> PartialEq for AlignedVec<T>
107where
108 T: PartialEq,
109{
110 fn eq(&self, other: &Self) -> bool {
111 if self.len() != other.len() {
112 return false;
113 }
114 self.iter().zip(other.iter()).all(|(a, b)| a == b)
115 }
116}
117
118unsafe impl<T: Send> Send for AlignedVec<T> {}
119unsafe impl<T: Sync> Sync for AlignedVec<T> {}
120
/// Alignment class of an address, as returned by FFTW's `fftw_alignment_of`
/// (see [`alignment_of`]).
pub type Alignment = i32;
122
123pub fn alignment_of<T>(a: &[T]) -> Alignment {
131 unsafe { ffi::fftw_alignment_of(a.as_ptr() as *mut _) }
132}
133
#[cfg(feature = "serialize")]
mod serde {
    //! `serde` support for [`AlignedVec`]: it is (de)serialized as a plain sequence of its
    //! elements.

    use std::fmt;
    use std::marker::PhantomData;
    use std::mem;

    use serde::de::{SeqAccess, Visitor};
    use serde::ser::{Serialize, SerializeSeq, Serializer};
    use serde::{Deserialize, Deserializer};

    use super::{AlignedAllocable, AlignedVec};

    /// Upper bound, in bytes, on the buffer preallocated from a format's length hint, so that a
    /// bogus hint in untrusted input cannot force a huge allocation before any element is read.
    const MAX_PREALLOC_BYTES: usize = 1024 * 1024;

    impl<T> Serialize for AlignedVec<T>
    where
        T: Serialize,
    {
        /// Serializes the elements in order as a sequence of known length.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let mut seq = serializer.serialize_seq(Some(self.len()))?;
            for e in self.iter() {
                seq.serialize_element(e)?;
            }
            seq.end()
        }
    }

    /// Visitor that rebuilds an [`AlignedVec`] from a sequence of elements.
    struct AlignedVecVisitor<T>(PhantomData<T>);

    impl<'de, T> Visitor<'de> for AlignedVecVisitor<T>
    where
        T: AlignedAllocable + Deserialize<'de>,
    {
        type Value = AlignedVec<T>;

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            write!(formatter, "AlignedVec<T>")
        }

        /// Reads every element of the sequence, then copies them into a freshly allocated
        /// aligned buffer.
        ///
        /// `SeqAccess::size_hint` is only advisory and formats are free to return `None`, so
        /// the length comes from the elements actually read. The hint, when present, is used
        /// only to preallocate (capped by `MAX_PREALLOC_BYTES`).
        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, <A as SeqAccess<'de>>::Error>
        where
            A: SeqAccess<'de>,
        {
            let max_prealloc = MAX_PREALLOC_BYTES / mem::size_of::<T>().max(1);
            let mut elements = Vec::with_capacity(seq.size_hint().unwrap_or(0).min(max_prealloc));
            while let Some(value) = seq.next_element::<T>()? {
                elements.push(value);
            }
            let mut output = AlignedVec::new(elements.len());
            output.copy_from_slice(&elements);
            Ok(output)
        }
    }

    impl<'de, T> Deserialize<'de> for AlignedVec<T>
    where
        T: AlignedAllocable + Deserialize<'de>,
    {
        /// Deserializes from any sequence of `T`, with or without a known length.
        fn deserialize<D>(deserializer: D) -> Result<Self, <D as Deserializer<'de>>::Error>
        where
            D: Deserializer<'de>,
        {
            deserializer.deserialize_seq(AlignedVecVisitor(PhantomData))
        }
    }

    #[cfg(test)]
    mod test {
        use serde_test::{assert_de_tokens, assert_tokens, Token};

        use crate::types::{c32, c64};

        use super::AlignedVec;

        #[test]
        fn test_ser_de_empty_c32() {
            let vec: AlignedVec<c32> = AlignedVec::new(0);

            assert_tokens(&vec, &[Token::Seq { len: Some(0) }, Token::SeqEnd]);
        }

        #[test]
        fn test_ser_de_empty_c64() {
            let vec: AlignedVec<c64> = AlignedVec::new(0);

            assert_tokens(&vec, &[Token::Seq { len: Some(0) }, Token::SeqEnd]);
        }

        #[test]
        fn test_ser_de_c32() {
            let mut vec = AlignedVec::new(3);
            vec[0] = c32::new(1., 2.);
            vec[1] = c32::new(3., 4.);
            vec[2] = c32::new(5., 6.);

            assert_tokens(
                &vec,
                &[
                    Token::Seq { len: Some(3) },
                    Token::Tuple { len: 2 },
                    Token::F32(1.),
                    Token::F32(2.),
                    Token::TupleEnd,
                    Token::Tuple { len: 2 },
                    Token::F32(3.),
                    Token::F32(4.),
                    Token::TupleEnd,
                    Token::Tuple { len: 2 },
                    Token::F32(5.),
                    Token::F32(6.),
                    Token::TupleEnd,
                    Token::SeqEnd,
                ],
            );
        }

        #[test]
        fn test_ser_de_c64() {
            let mut vec = AlignedVec::new(3);
            vec[0] = c64::new(1., 2.);
            vec[1] = c64::new(3., 4.);
            vec[2] = c64::new(5., 6.);

            assert_tokens(
                &vec,
                &[
                    Token::Seq { len: Some(3) },
                    Token::Tuple { len: 2 },
                    Token::F64(1.),
                    Token::F64(2.),
                    Token::TupleEnd,
                    Token::Tuple { len: 2 },
                    Token::F64(3.),
                    Token::F64(4.),
                    Token::TupleEnd,
                    Token::Tuple { len: 2 },
                    Token::F64(5.),
                    Token::F64(6.),
                    Token::TupleEnd,
                    Token::SeqEnd,
                ],
            );
        }

        /// Formats that do not report the sequence length up front (`len: None`) must still
        /// deserialize.
        #[test]
        fn test_de_unknown_length_c64() {
            let mut vec = AlignedVec::new(2);
            vec[0] = c64::new(1., 2.);
            vec[1] = c64::new(3., 4.);

            assert_de_tokens(
                &vec,
                &[
                    Token::Seq { len: None },
                    Token::Tuple { len: 2 },
                    Token::F64(1.),
                    Token::F64(2.),
                    Token::TupleEnd,
                    Token::Tuple { len: 2 },
                    Token::F64(3.),
                    Token::F64(4.),
                    Token::TupleEnd,
                    Token::SeqEnd,
                ],
            );
        }
    }
}