1use std::{
2 fmt::{Debug, Display},
3 marker::PhantomData,
4 ops::{Add, Mul, Sub},
5 vec,
6};
7
8use bytemuck::{box_bytes_of, from_box_bytes, BoxBytes, Pod};
9use derive_more::derive::{AsMut, AsRef, Deref, DerefMut, IntoIterator};
10use hybrid_array::{Array, ArraySize};
11use rayon::iter::IntoParallelIterator;
12use serde::{Deserialize, Serialize};
13use typenum::{Diff, Prod, Sum, U1, U2, U3};
14
15use crate::{errors::PrimitiveError, types::Positive};
16
/// A heap-allocated array of `T` whose expected length is carried by the type parameter `M`.
///
/// The elements live in a boxed slice; `M` only appears in a zero-sized `PhantomData`
/// marker, so the struct layout itself does not enforce the length. The checked
/// constructors in this file (`TryFrom<Vec<T>>`, `FromIterator`, `Deserialize`) compare the
/// slice length against `M::USIZE`.
///
/// `Deref`/`DerefMut`, `AsRef`/`AsMut` and the owned / `&` / `&mut` `IntoIterator` impls are
/// derived with `derive_more` and all go through the `data` field.
// NOTE(review): `Eq` is derived here, which needs a `PartialEq` impl; none is visible in this
// part of the file — confirm it is provided elsewhere.
#[derive(Deref, DerefMut, Clone, IntoIterator, AsRef, AsMut, Eq)]
#[into_iterator(owned, ref, ref_mut)]
pub struct HeapArray<T: Sized, M: Positive> {
    /// The stored elements.
    #[deref]
    #[deref_mut]
    #[as_ref(forward)]
    #[as_mut(forward)]
    pub(super) data: Box<[T]>,
    /// Zero-sized marker tying the array to its length type `M`; excluded from iteration.
    #[into_iterator(ignore)]
    pub(super) _len: PhantomData<fn() -> M>,
}
31
32impl<T: Sized, M: Positive> HeapArray<T, M> {
33 pub(super) fn new(data: Box<[T]>) -> Self {
34 Self {
35 data,
36 _len: PhantomData,
37 }
38 }
39}
40
41impl<T: Pod, M: Positive> HeapArray<T, M> {
43 pub fn into_box_bytes(self) -> BoxBytes {
44 box_bytes_of(self.data)
45 }
46
47 pub fn from_box_bytes(buf: BoxBytes) -> Self {
48 Self {
49 data: from_box_bytes(buf),
50 _len: PhantomData,
51 }
52 }
53}
54
55impl<T: Sized + Default, M: Positive> HeapArray<T, M> {
56 pub fn from_single_value(val: T) -> Self {
57 Self {
58 data: vec![val].into_boxed_slice(),
59 _len: PhantomData,
60 }
61 }
62}
63
64impl<T: Sized + Default, M: Positive> Default for HeapArray<T, M> {
65 fn default() -> Self {
66 Self {
67 data: (0..M::USIZE).map(|_| T::default()).collect(),
68 _len: PhantomData,
69 }
70 }
71}
72
73impl<T: Sized + Debug, M: Positive> Debug for HeapArray<T, M> {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 f.debug_struct(format!("HeapArray<{}>", M::USIZE).as_str())
76 .field("data", &self.data)
77 .finish()
78 }
79}
80
81impl<T: Sized + Serialize, M: Positive> Serialize for HeapArray<T, M> {
82 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
83 self.data.serialize(serializer)
84 }
85}
86
87impl<'de, T: Sized + Deserialize<'de>, M: Positive> Deserialize<'de> for HeapArray<T, M> {
88 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
89 let data = Box::<[T]>::deserialize(deserializer)?;
90
91 if data.len() != M::USIZE {
92 return Err(serde::de::Error::custom(format!(
93 "Expected array of length {}, got {}",
94 M::USIZE,
95 data.len()
96 )));
97 }
98
99 Ok(Self {
100 data,
101 _len: PhantomData,
102 })
103 }
104}
105
106impl<T: Sized, M: Positive> FromIterator<T> for HeapArray<T, M> {
107 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
108 let data = iter.into_iter().collect::<Box<_>>();
109 assert_eq!(data.len(), M::USIZE,);
110 Self {
111 data,
112 _len: PhantomData,
113 }
114 }
115}
116
117impl<T: Sized + Send, M: Positive> IntoParallelIterator for HeapArray<T, M> {
118 type Item = T;
119 type Iter = rayon::vec::IntoIter<T>;
120
121 fn into_par_iter(self) -> Self::Iter {
122 self.data.into_par_iter()
123 }
124}
125
126impl<T: Sized + Copy, M: Positive> HeapArray<T, M> {
131 pub fn split<M1, M2>(&self) -> (HeapArray<T, M1>, HeapArray<T, M2>)
132 where
133 M1: Positive,
134 M2: Positive + Add<M1, Output = M>,
135 {
136 let (m1, m2) = self.split_at(M1::USIZE);
137 (
138 HeapArray::<T, M1> {
139 data: m1.into(),
140 _len: PhantomData,
141 },
142 HeapArray::<T, M2> {
143 data: m2.into(),
144 _len: PhantomData,
145 },
146 )
147 }
148
149 pub fn split_halves<MDiv2>(&self) -> (HeapArray<T, MDiv2>, HeapArray<T, MDiv2>)
150 where
151 MDiv2: Positive + Mul<U2, Output = M>,
152 {
153 let (m1, m2) = self.split_at(MDiv2::USIZE);
154 (
155 HeapArray::<T, MDiv2> {
156 data: m1.into(),
157 _len: PhantomData,
158 },
159 HeapArray::<T, MDiv2> {
160 data: m2.into(),
161 _len: PhantomData,
162 },
163 )
164 }
165
166 pub fn merge_halves(this: Self, other: Self) -> HeapArray<T, Prod<M, U2>>
167 where
168 M: Mul<U2, Output: Positive>,
169 {
170 let mut vec = this.data.into_vec();
171 vec.extend(other.data.into_vec());
172 HeapArray::<T, Prod<M, U2>> {
173 data: vec.into_boxed_slice(),
174 _len: PhantomData,
175 }
176 }
177
178 pub fn split3<M1, M2, M3>(&self) -> (HeapArray<T, M1>, HeapArray<T, M2>, HeapArray<T, M3>)
179 where
180 M1: Positive,
181 M2: Positive + Add<M1>,
182 M3: Positive + Add<Sum<M2, M1>, Output = M>,
183 {
184 let (m1, m_rest) = self.split_at(M1::USIZE);
185 let (m2, m3) = m_rest.split_at(M2::USIZE);
186 (
187 HeapArray::<T, M1> {
188 data: m1.into(),
189 _len: PhantomData,
190 },
191 HeapArray::<T, M2> {
192 data: m2.into(),
193 _len: PhantomData,
194 },
195 HeapArray::<T, M3> {
196 data: m3.into(),
197 _len: PhantomData,
198 },
199 )
200 }
201
202 pub fn split_thirds<MDiv3>(
203 &self,
204 ) -> (
205 HeapArray<T, MDiv3>,
206 HeapArray<T, MDiv3>,
207 HeapArray<T, MDiv3>,
208 )
209 where
210 MDiv3: Positive + Mul<U3, Output = M>,
211 {
212 let (m1, m_rest) = self.split_at(MDiv3::USIZE);
213 let (m2, m3) = m_rest.split_at(MDiv3::USIZE);
214 (
215 HeapArray::<T, MDiv3> {
216 data: m1.into(),
217 _len: PhantomData,
218 },
219 HeapArray::<T, MDiv3> {
220 data: m2.into(),
221 _len: PhantomData,
222 },
223 HeapArray::<T, MDiv3> {
224 data: m3.into(),
225 _len: PhantomData,
226 },
227 )
228 }
229
230 pub fn merge_thirds(first: Self, second: Self, third: Self) -> HeapArray<T, Prod<M, U3>>
231 where
232 M: Mul<U3, Output: Positive>,
233 {
234 let mut vec = first.data.into_vec();
235 vec.extend(second.data.into_vec());
236 vec.extend(third.data.into_vec());
237 HeapArray::<T, Prod<M, U3>> {
238 data: vec.into_boxed_slice(),
239 _len: PhantomData,
240 }
241 }
242}
243
/// A pair of [`HeapArray`]s sharing the same length type `M`, one holding `T1` elements
/// and the other `T2` elements.
///
/// It can be collected from an iterator of `(T1, T2)` tuples through its `FromIterator`
/// impl.
pub struct HeapArrayTuple<T1: Sized, T2: Sized, M: Positive>(
    pub HeapArray<T1, M>,
    pub HeapArray<T2, M>,
);
248
249impl<T1: Sized, T2: Sized, M: Positive> FromIterator<(T1, T2)> for HeapArrayTuple<T1, T2, M> {
250 fn from_iter<I: IntoIterator<Item = (T1, T2)>>(iter: I) -> Self {
251 let (data1, data2): (Vec<_>, Vec<_>) = iter.into_iter().unzip();
252
253 assert_eq!(data1.len(), M::USIZE);
254 assert_eq!(data2.len(), M::USIZE);
255 HeapArrayTuple(
256 HeapArray::<T1, M> {
257 data: data1.into_boxed_slice(),
258 _len: PhantomData,
259 },
260 HeapArray::<T2, M> {
261 data: data2.into_boxed_slice(),
262 _len: PhantomData,
263 },
264 )
265 }
266}
267
268impl<T: Sized, M: Positive> TryFrom<Vec<T>> for HeapArray<T, M> {
269 type Error = PrimitiveError;
270
271 fn try_from(data: Vec<T>) -> Result<Self, PrimitiveError> {
272 if data.len() == M::USIZE {
273 Ok(Self {
274 data: data.into_boxed_slice(),
275 _len: PhantomData,
276 })
277 } else {
278 Err(PrimitiveError::SizeError(M::USIZE, data.len()))
279 }
280 }
281}
282
283impl<T: Sized> From<T> for HeapArray<T, U1> {
284 fn from(element: T) -> Self {
285 Self {
286 data: Box::new([element]),
287 _len: PhantomData,
288 }
289 }
290}
291
292impl<T: Sized, M: Positive> From<HeapArray<T, M>> for Vec<T> {
293 fn from(array: HeapArray<T, M>) -> Self {
294 array.data.into_vec()
295 }
296}
297
298impl<T: Sized + Clone, M: Positive + ArraySize> From<Array<T, M>> for HeapArray<T, M> {
299 fn from(array: Array<T, M>) -> Self {
300 Self {
301 data: array.to_vec().into_boxed_slice(),
302 _len: PhantomData,
303 }
304 }
305}
306
307impl<T: Sized, M: Positive> HeapArray<T, M> {
308 pub fn split_last<N: Positive>(self) -> (HeapArray<T, Diff<M, N>>, HeapArray<T, N>)
309 where
310 M: Sub<N, Output: Positive>,
311 {
312 let mut vec = self.data.into_vec();
313 let last_n = vec.split_off(M::USIZE - N::USIZE);
314
315 (
316 HeapArray {
317 data: vec.into_boxed_slice(),
318 _len: PhantomData,
319 },
320 HeapArray {
321 data: last_n.into_boxed_slice(),
322 _len: PhantomData,
323 },
324 )
325 }
326
327 pub fn from_fn(f: impl FnMut(usize) -> T) -> Self {
328 Self {
329 data: (0..M::USIZE).map(f).collect::<Box<_>>(),
330 _len: PhantomData,
331 }
332 }
333
334 pub fn from_constant(c: T) -> Self
335 where
336 T: Copy,
337 {
338 Self {
339 data: (0..M::USIZE).map(|_| c).collect::<Box<_>>(),
340 _len: PhantomData,
341 }
342 }
343}
344
345impl<T: Sized, M: Positive> Display for HeapArray<T, M>
346where
347 T: Display,
348{
349 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350 write!(f, "[")?;
351 for (i, item) in self.data.iter().enumerate() {
352 if i != 0 {
353 write!(f, ", ")?;
354 }
355 write!(f, "{item}")?;
356 }
357 write!(f, "]")
358 }
359}
360
#[cfg(test)]
pub mod tests {
    use hybrid_array::sizes::{U2, U3, U6};
    use typenum::{U1, U4};

    use super::*;

    /// `from_fn` fills the array with `f(0), f(1), ...` in index order.
    #[test]
    fn test_heap_array() {
        let array = HeapArray::<_, U3>::from_fn(|i| i);
        assert_eq!(array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    /// `Default` yields an array with the length encoded in the type.
    #[test]
    fn test_default() {
        let array = HeapArray::<usize, U3>::default();
        assert_eq!(array.len(), 3);
    }

    /// `split_last` detaches the trailing `N` elements and keeps the order of both parts.
    #[test]
    fn test_heap_array_split_last() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, last) = array.split_last::<U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3]);
        assert_eq!(last.into_iter().collect::<Vec<_>>(), vec![4, 5]);
    }

    /// Converting from a `hybrid_array::Array` preserves the elements.
    #[test]
    fn test_heap_array_from_array() {
        let array = Array::<_, U3>::from_fn(|i| i);
        let heap_array = HeapArray::<_, U3>::from(array);
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    /// `TryFrom<Vec<T>>` accepts a vector of the right length and rejects any other.
    #[test]
    fn test_heap_array_from_vec() {
        let vec = vec![0, 1, 2];
        let heap_array = HeapArray::<_, U3>::try_from(vec).unwrap();
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);

        let vec = vec![0, 1];
        let heap_array = HeapArray::<_, U3>::try_from(vec);
        assert!(heap_array.is_err());
    }

    /// `FromIterator` collects an iterator that yields exactly the expected number of items.
    #[test]
    fn test_heap_array_from_iter() {
        // Go through `from_iter` itself; building the array with `from_fn` would leave the
        // `FromIterator` success path untested.
        let heap_array = HeapArray::<_, U3>::from_iter(0..3);
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    /// `FromIterator` panics when the iterator length does not match the type-level length.
    #[test]
    #[should_panic]
    fn test_heap_array_from_iter_wrong_size() {
        HeapArray::<_, U2>::from_iter(0..3);
    }

    /// A serialized array round-trips at the same length and is rejected at another one.
    #[test]
    fn test_heap_array_deserialize() {
        let array = HeapArray::<usize, U6>::from_fn(|i| i);
        let serialized = bincode::serialize(&array).unwrap();
        bincode::deserialize::<HeapArray<usize, U6>>(&serialized).unwrap();

        let wrong_deserialize = bincode::deserialize::<HeapArray<usize, U3>>(&serialized);
        assert!(wrong_deserialize.is_err());
    }

    /// `split` and `split_halves` copy out contiguous parts without consuming the source.
    #[test]
    fn test_heap_array_split() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second) = array.split::<U4, U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![4, 5]);

        let (first, second) = array.split_halves::<U3>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![3, 4, 5]);
    }

    /// `split3` and `split_thirds` copy out three contiguous parts in order.
    #[test]
    fn test_heap_array_split3() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second, third) = array.split3::<U3, U2, U1>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![3, 4]);
        assert_eq!(third.into_iter().collect::<Vec<_>>(), vec![5]);

        let (first, second, third) = array.split_thirds::<U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![2, 3]);
        assert_eq!(third.into_iter().collect::<Vec<_>>(), vec![4, 5]);
    }
}