1use std::{
2 fmt::{Debug, Display},
3 marker::PhantomData,
4 ops::{Add, Mul, Sub},
5 vec,
6};
7
8use bytemuck::{box_bytes_of, from_box_bytes, BoxBytes, Pod};
9use derive_more::derive::{AsMut, AsRef, Deref, DerefMut, IntoIterator};
10use hybrid_array::{Array, ArraySize};
11use rayon::iter::IntoParallelIterator;
12use serde::{Deserialize, Serialize};
13use typenum::{Diff, Prod, Sum, U1, U2, U3};
14
15use crate::{errors::PrimitiveError, types::Positive};
16
/// Heap-allocated array whose intended length is carried by the type parameter `M`.
///
/// The elements live in a boxed slice; `M` appears only in a zero-sized marker, so the
/// length is a convention that each constructor is responsible for upholding
/// (`M::USIZE` elements) rather than something the struct layout enforces.
///
/// The derives forward `Deref`/`DerefMut`/`AsRef`/`AsMut` and the three `IntoIterator`
/// flavours (owned, `&`, `&mut`) to the `data` field.
// NOTE(review): `Eq` is derived here but no `PartialEq` derive is visible in this
// file — presumably it is implemented elsewhere in the module; confirm.
#[derive(Deref, DerefMut, Clone, IntoIterator, AsRef, AsMut, Eq)]
#[into_iterator(owned, ref, ref_mut)]
pub struct HeapArray<T: Sized, M: Positive> {
    /// Backing storage for the elements.
    #[deref]
    #[deref_mut]
    #[as_ref(forward)]
    #[as_mut(forward)]
    pub(super) data: Box<[T]>,
    /// Type-level length marker; `fn() -> M` means no value of `M` is stored or owned.
    #[into_iterator(ignore)]
    pub(super) _len: PhantomData<fn() -> M>,
}
31
32impl<T: Sized, M: Positive> HeapArray<T, M> {
33 pub(super) fn new(data: Box<[T]>) -> Self {
34 Self {
35 data,
36 _len: PhantomData,
37 }
38 }
39}
40
41impl<T: Pod, M: Positive> HeapArray<T, M> {
43 pub fn into_box_bytes(self) -> BoxBytes {
44 box_bytes_of(self.data)
45 }
46
47 pub fn from_box_bytes(buf: BoxBytes) -> Self {
48 Self {
49 data: from_box_bytes(buf),
50 _len: PhantomData,
51 }
52 }
53}
54
55impl<T: Sized, M: Positive> HeapArray<T, M> {
56 pub fn map<F, U>(self, f: F) -> HeapArray<U, M>
57 where
58 F: FnMut(T) -> U,
59 {
60 self.into_iter().map(f).collect()
61 }
62}
63
64impl<T: Sized + Default, M: Positive> HeapArray<T, M> {
65 pub fn from_single_value(val: T) -> Self {
66 Self {
67 data: vec![val].into_boxed_slice(),
68 _len: PhantomData,
69 }
70 }
71}
72
73impl<T: Sized + Default, M: Positive> Default for HeapArray<T, M> {
74 fn default() -> Self {
75 Self {
76 data: (0..M::USIZE).map(|_| T::default()).collect(),
77 _len: PhantomData,
78 }
79 }
80}
81
82impl<T: Sized + Debug, M: Positive> Debug for HeapArray<T, M> {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 f.debug_struct(format!("HeapArray<{}>", M::USIZE).as_str())
85 .field("data", &self.data)
86 .finish()
87 }
88}
89
90impl<T: Sized + Serialize, M: Positive> Serialize for HeapArray<T, M> {
91 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
92 self.data.serialize(serializer)
93 }
94}
95
96impl<'de, T: Sized + Deserialize<'de>, M: Positive> Deserialize<'de> for HeapArray<T, M> {
97 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
98 let data = Box::<[T]>::deserialize(deserializer)?;
99
100 if data.len() != M::USIZE {
101 return Err(serde::de::Error::custom(format!(
102 "Expected array of length {}, got {}",
103 M::USIZE,
104 data.len()
105 )));
106 }
107
108 Ok(Self {
109 data,
110 _len: PhantomData,
111 })
112 }
113}
114
115impl<T: Sized, M: Positive> FromIterator<T> for HeapArray<T, M> {
116 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
117 let data = iter.into_iter().collect::<Box<_>>();
118 assert_eq!(data.len(), M::USIZE,);
119 Self {
120 data,
121 _len: PhantomData,
122 }
123 }
124}
125
126impl<T: Sized + Send, M: Positive> IntoParallelIterator for HeapArray<T, M> {
127 type Item = T;
128 type Iter = rayon::vec::IntoIter<T>;
129
130 fn into_par_iter(self) -> Self::Iter {
131 self.data.into_par_iter()
132 }
133}
134
135impl<T: Sized + Copy, M: Positive> HeapArray<T, M> {
140 pub fn split<M1, M2>(&self) -> (HeapArray<T, M1>, HeapArray<T, M2>)
141 where
142 M1: Positive,
143 M2: Positive + Add<M1, Output = M>,
144 {
145 let (m1, m2) = self.split_at(M1::USIZE);
146 (
147 HeapArray::<T, M1> {
148 data: m1.into(),
149 _len: PhantomData,
150 },
151 HeapArray::<T, M2> {
152 data: m2.into(),
153 _len: PhantomData,
154 },
155 )
156 }
157
158 pub fn split_halves<MDiv2>(&self) -> (HeapArray<T, MDiv2>, HeapArray<T, MDiv2>)
159 where
160 MDiv2: Positive + Mul<U2, Output = M>,
161 {
162 let (m1, m2) = self.split_at(MDiv2::USIZE);
163 (
164 HeapArray::<T, MDiv2> {
165 data: m1.into(),
166 _len: PhantomData,
167 },
168 HeapArray::<T, MDiv2> {
169 data: m2.into(),
170 _len: PhantomData,
171 },
172 )
173 }
174
175 pub fn merge_halves(this: Self, other: Self) -> HeapArray<T, Prod<M, U2>>
176 where
177 M: Mul<U2, Output: Positive>,
178 {
179 let mut vec = this.data.into_vec();
180 vec.extend(other.data.into_vec());
181 HeapArray::<T, Prod<M, U2>> {
182 data: vec.into_boxed_slice(),
183 _len: PhantomData,
184 }
185 }
186
187 pub fn split3<M1, M2, M3>(&self) -> (HeapArray<T, M1>, HeapArray<T, M2>, HeapArray<T, M3>)
188 where
189 M1: Positive,
190 M2: Positive + Add<M1>,
191 M3: Positive + Add<Sum<M2, M1>, Output = M>,
192 {
193 let (m1, m_rest) = self.split_at(M1::USIZE);
194 let (m2, m3) = m_rest.split_at(M2::USIZE);
195 (
196 HeapArray::<T, M1> {
197 data: m1.into(),
198 _len: PhantomData,
199 },
200 HeapArray::<T, M2> {
201 data: m2.into(),
202 _len: PhantomData,
203 },
204 HeapArray::<T, M3> {
205 data: m3.into(),
206 _len: PhantomData,
207 },
208 )
209 }
210
211 pub fn split_thirds<MDiv3>(
212 &self,
213 ) -> (
214 HeapArray<T, MDiv3>,
215 HeapArray<T, MDiv3>,
216 HeapArray<T, MDiv3>,
217 )
218 where
219 MDiv3: Positive + Mul<U3, Output = M>,
220 {
221 let (m1, m_rest) = self.split_at(MDiv3::USIZE);
222 let (m2, m3) = m_rest.split_at(MDiv3::USIZE);
223 (
224 HeapArray::<T, MDiv3> {
225 data: m1.into(),
226 _len: PhantomData,
227 },
228 HeapArray::<T, MDiv3> {
229 data: m2.into(),
230 _len: PhantomData,
231 },
232 HeapArray::<T, MDiv3> {
233 data: m3.into(),
234 _len: PhantomData,
235 },
236 )
237 }
238
239 pub fn merge_thirds(first: Self, second: Self, third: Self) -> HeapArray<T, Prod<M, U3>>
240 where
241 M: Mul<U3, Output: Positive>,
242 {
243 let mut vec = first.data.into_vec();
244 vec.extend(second.data.into_vec());
245 vec.extend(third.data.into_vec());
246 HeapArray::<T, Prod<M, U3>> {
247 data: vec.into_boxed_slice(),
248 _len: PhantomData,
249 }
250 }
251}
252
/// Pair of [`HeapArray`]s that share the same type-level length `M` but may have
/// different element types.
///
/// It can be collected from an iterator of `(T1, T2)` tuples.
pub struct HeapArrayTuple<T1: Sized, T2: Sized, M: Positive>(
    /// Array holding the first component of every pair.
    pub HeapArray<T1, M>,
    /// Array holding the second component of every pair.
    pub HeapArray<T2, M>,
);
257
258impl<T1: Sized, T2: Sized, M: Positive> FromIterator<(T1, T2)> for HeapArrayTuple<T1, T2, M> {
259 fn from_iter<I: IntoIterator<Item = (T1, T2)>>(iter: I) -> Self {
260 let (data1, data2): (Vec<_>, Vec<_>) = iter.into_iter().unzip();
261
262 assert_eq!(data1.len(), M::USIZE);
263 assert_eq!(data2.len(), M::USIZE);
264 HeapArrayTuple(
265 HeapArray::<T1, M> {
266 data: data1.into_boxed_slice(),
267 _len: PhantomData,
268 },
269 HeapArray::<T2, M> {
270 data: data2.into_boxed_slice(),
271 _len: PhantomData,
272 },
273 )
274 }
275}
276
277impl<T: Sized, M: Positive> TryFrom<Vec<T>> for HeapArray<T, M> {
278 type Error = PrimitiveError;
279
280 fn try_from(data: Vec<T>) -> Result<Self, PrimitiveError> {
281 if data.len() == M::USIZE {
282 Ok(Self {
283 data: data.into_boxed_slice(),
284 _len: PhantomData,
285 })
286 } else {
287 Err(PrimitiveError::InvalidSize(M::USIZE, data.len()))
288 }
289 }
290}
291
292impl<T: Sized> From<T> for HeapArray<T, U1> {
293 fn from(element: T) -> Self {
294 Self {
295 data: Box::new([element]),
296 _len: PhantomData,
297 }
298 }
299}
300
301impl<T: Sized, M: Positive> From<HeapArray<T, M>> for Vec<T> {
302 fn from(array: HeapArray<T, M>) -> Self {
303 array.data.into_vec()
304 }
305}
306
307impl<T: Sized + Clone, M: Positive + ArraySize> From<Array<T, M>> for HeapArray<T, M> {
308 fn from(array: Array<T, M>) -> Self {
309 Self {
310 data: array.to_vec().into_boxed_slice(),
311 _len: PhantomData,
312 }
313 }
314}
315
316impl<T: Sized, M: Positive> HeapArray<T, M> {
317 pub fn split_last<N: Positive>(self) -> (HeapArray<T, Diff<M, N>>, HeapArray<T, N>)
318 where
319 M: Sub<N, Output: Positive>,
320 {
321 let mut vec = self.data.into_vec();
322 let last_n = vec.split_off(M::USIZE - N::USIZE);
323
324 (
325 HeapArray {
326 data: vec.into_boxed_slice(),
327 _len: PhantomData,
328 },
329 HeapArray {
330 data: last_n.into_boxed_slice(),
331 _len: PhantomData,
332 },
333 )
334 }
335
336 pub fn from_fn(f: impl FnMut(usize) -> T) -> Self {
337 Self {
338 data: (0..M::USIZE).map(f).collect::<Box<_>>(),
339 _len: PhantomData,
340 }
341 }
342
343 pub fn from_constant(c: T) -> Self
344 where
345 T: Copy,
346 {
347 Self {
348 data: (0..M::USIZE).map(|_| c).collect::<Box<_>>(),
349 _len: PhantomData,
350 }
351 }
352}
353
354impl<T: Sized, M: Positive> Display for HeapArray<T, M>
355where
356 T: Display,
357{
358 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
359 write!(f, "[")?;
360 for (i, item) in self.data.iter().enumerate() {
361 if i != 0 {
362 write!(f, ", ")?;
363 }
364 write!(f, "{item}")?;
365 }
366 write!(f, "]")
367 }
368}
369
/// Unit tests for `HeapArray` construction, conversion, (de)serialization and the
/// split/merge helpers.
#[cfg(test)]
pub mod tests {
    use hybrid_array::sizes::{U2, U3, U6};
    use typenum::{U1, U4};

    use super::*;

    /// `from_fn` calls the closure with ascending indices.
    #[test]
    fn test_heap_array() {
        let array = HeapArray::<_, U3>::from_fn(|i| i);
        assert_eq!(array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    /// `Default` yields `M::USIZE` elements.
    #[test]
    fn test_default() {
        let array = HeapArray::<usize, U3>::default();
        assert_eq!(array.len(), 3);
    }

    /// `split_last::<N>` returns the leading `M - N` and trailing `N` elements.
    #[test]
    fn test_heap_array_split_last() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, last) = array.split_last::<U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3]);
        assert_eq!(last.into_iter().collect::<Vec<_>>(), vec![4, 5]);
    }

    /// Conversion from a `hybrid_array::Array` keeps element order.
    #[test]
    fn test_heap_array_from_array() {
        let array = Array::<_, U3>::from_fn(|i| i);
        let heap_array = HeapArray::<_, U3>::from(array);
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    /// `TryFrom<Vec<T>>` accepts the exact length and rejects any other.
    #[test]
    fn test_heap_array_from_vec() {
        let vec = vec![0, 1, 2];
        let heap_array = HeapArray::<_, U3>::try_from(vec).unwrap();
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);

        let vec = vec![0, 1];
        let heap_array = HeapArray::<_, U3>::try_from(vec);
        assert!(heap_array.is_err());
    }

    /// Collecting an iterator of the right length goes through `FromIterator`.
    #[test]
    fn test_heap_array_from_iter() {
        let heap_array = (0..3).collect::<HeapArray<_, U3>>();
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    /// Collecting an iterator of the wrong length panics.
    #[test]
    #[should_panic]
    fn test_heap_array_from_iter_wrong_size() {
        HeapArray::<_, U2>::from_iter(0..3);
    }

    /// Deserialization round-trips at the right length and fails at a wrong one.
    #[test]
    fn test_heap_array_deserialize() {
        let array = HeapArray::<usize, U6>::from_fn(|i| i);
        let serialized = bincode::serialize(&array).unwrap();
        bincode::deserialize::<HeapArray<usize, U6>>(&serialized).unwrap();

        let wrong_deserialize = bincode::deserialize::<HeapArray<usize, U3>>(&serialized);
        assert!(wrong_deserialize.is_err());
    }

    /// `split` and `split_halves` cut at the expected index.
    #[test]
    fn test_heap_array_split() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second) = array.split::<U4, U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![4, 5]);

        let (first, second) = array.split_halves::<U3>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![3, 4, 5]);
    }

    /// `merge_halves` undoes `split_halves`.
    #[test]
    fn test_heap_array_merge_halves() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second) = array.split_halves::<U3>();
        let merged = HeapArray::merge_halves(first, second);
        assert_eq!(
            merged.into_iter().collect::<Vec<_>>(),
            vec![0, 1, 2, 3, 4, 5]
        );
    }

    /// `split3` and `split_thirds` cut at the expected indices.
    #[test]
    fn test_heap_array_split3() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second, third) = array.split3::<U3, U2, U1>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![3, 4]);
        assert_eq!(third.into_iter().collect::<Vec<_>>(), vec![5]);

        let (first, second, third) = array.split_thirds::<U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![2, 3]);
        assert_eq!(third.into_iter().collect::<Vec<_>>(), vec![4, 5]);
    }

    /// `merge_thirds` undoes `split_thirds`.
    #[test]
    fn test_heap_array_merge_thirds() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second, third) = array.split_thirds::<U2>();
        let merged = HeapArray::merge_thirds(first, second, third);
        assert_eq!(
            merged.into_iter().collect::<Vec<_>>(),
            vec![0, 1, 2, 3, 4, 5]
        );
    }
}