1use std::{
2 fmt::{Debug, Display},
3 marker::PhantomData,
4 ops::{Add, Mul, Sub},
5 vec,
6};
7
8use bytemuck::{box_bytes_of, from_box_bytes, BoxBytes, Pod};
9use derive_more::derive::{AsMut, AsRef, Deref, DerefMut, IntoIterator};
10use hybrid_array::{Array, ArraySize};
11use rayon::iter::IntoParallelIterator;
12use serde::{Deserialize, Serialize};
13use typenum::{Diff, Prod, Sum, U1, U2, U3};
14
15use crate::{errors::PrimitiveError, types::Positive};
16
17#[derive(Deref, DerefMut, Clone, IntoIterator, AsRef, AsMut, Eq)]
19#[into_iterator(owned, ref, ref_mut)]
20pub struct HeapArray<T: Sized, M: Positive> {
21 #[deref]
22 #[deref_mut]
23 #[as_ref(forward)]
24 #[as_mut(forward)]
25 pub(super) data: Box<[T]>,
26 #[into_iterator(ignore)]
27 pub(super) _len: PhantomData<fn() -> M>,
30}
31
32impl<T: Sized, M: Positive> HeapArray<T, M> {
33 pub(super) fn new(data: Box<[T]>) -> Self {
34 Self {
35 data,
36 _len: PhantomData,
37 }
38 }
39}
40
41impl<T: Pod, M: Positive> HeapArray<T, M> {
43 pub fn into_box_bytes(self) -> BoxBytes {
44 box_bytes_of(self.data)
45 }
46
47 pub fn from_box_bytes(buf: BoxBytes) -> Self {
48 Self {
49 data: from_box_bytes(buf),
50 _len: PhantomData,
51 }
52 }
53}
54
55impl<T: Sized, M: Positive> HeapArray<T, M> {
56 pub fn map<F, U>(self, f: F) -> HeapArray<U, M>
57 where
58 F: FnMut(T) -> U,
59 {
60 self.into_iter().map(f).collect()
61 }
62}
63
64impl<T: Sized + Default, M: Positive> HeapArray<T, M> {
65 pub fn from_single_value(val: T) -> Self {
66 Self {
67 data: vec![val].into_boxed_slice(),
68 _len: PhantomData,
69 }
70 }
71}
72
73impl<T: Sized + Default, M: Positive> Default for HeapArray<T, M> {
74 fn default() -> Self {
75 Self {
76 data: (0..M::USIZE).map(|_| T::default()).collect(),
77 _len: PhantomData,
78 }
79 }
80}
81
82impl<T: Sized + Debug, M: Positive> Debug for HeapArray<T, M> {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 f.debug_struct(format!("HeapArray<{}>", M::USIZE).as_str())
85 .field("data", &self.data)
86 .finish()
87 }
88}
89
90impl<T: Sized + Serialize, M: Positive> Serialize for HeapArray<T, M> {
91 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
92 use serde::ser::SerializeTuple;
93 let mut tuple = serializer.serialize_tuple(M::USIZE)?;
94 for element in self.data.iter() {
95 tuple.serialize_element(element)?;
96 }
97 tuple.end()
98 }
99}
100
101impl<'de, T: Sized + Deserialize<'de>, M: Positive> Deserialize<'de> for HeapArray<T, M> {
102 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
103 struct HeapArrayVisitor<T, M: Positive> {
104 _phantom: PhantomData<(T, M)>,
105 }
106
107 impl<'de, T: Deserialize<'de>, M: Positive> serde::de::Visitor<'de> for HeapArrayVisitor<T, M> {
108 type Value = HeapArray<T, M>;
109
110 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
111 write!(formatter, "a tuple of {} elements", M::USIZE)
112 }
113
114 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
115 where
116 A: serde::de::SeqAccess<'de>,
117 {
118 let mut data = Vec::with_capacity(M::USIZE);
119 for i in 0..M::USIZE {
120 let element = seq
121 .next_element()?
122 .ok_or_else(|| serde::de::Error::invalid_length(i, &self))?;
123 data.push(element);
124 }
125
126 Ok(HeapArray {
127 data: data.into_boxed_slice(),
128 _len: PhantomData,
129 })
130 }
131 }
132
133 deserializer.deserialize_tuple(
134 M::USIZE,
135 HeapArrayVisitor {
136 _phantom: PhantomData,
137 },
138 )
139 }
140}
141
142impl<T: Sized, M: Positive> FromIterator<T> for HeapArray<T, M> {
143 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
144 let data = iter.into_iter().collect::<Box<_>>();
145 assert_eq!(data.len(), M::USIZE,);
146 Self {
147 data,
148 _len: PhantomData,
149 }
150 }
151}
152
153impl<T: Sized + Send, M: Positive> IntoParallelIterator for HeapArray<T, M> {
154 type Item = T;
155 type Iter = rayon::vec::IntoIter<T>;
156
157 fn into_par_iter(self) -> Self::Iter {
158 self.data.into_par_iter()
159 }
160}
161
162impl<T: Sized + Copy, M: Positive> HeapArray<T, M> {
167 pub fn split<M1, M2>(&self) -> (HeapArray<T, M1>, HeapArray<T, M2>)
168 where
169 M1: Positive,
170 M2: Positive + Add<M1, Output = M>,
171 {
172 let (m1, m2) = self.split_at(M1::USIZE);
173 (
174 HeapArray::<T, M1> {
175 data: m1.into(),
176 _len: PhantomData,
177 },
178 HeapArray::<T, M2> {
179 data: m2.into(),
180 _len: PhantomData,
181 },
182 )
183 }
184
185 pub fn split_halves<MDiv2>(&self) -> (HeapArray<T, MDiv2>, HeapArray<T, MDiv2>)
186 where
187 MDiv2: Positive + Mul<U2, Output = M>,
188 {
189 let (m1, m2) = self.split_at(MDiv2::USIZE);
190 (
191 HeapArray::<T, MDiv2> {
192 data: m1.into(),
193 _len: PhantomData,
194 },
195 HeapArray::<T, MDiv2> {
196 data: m2.into(),
197 _len: PhantomData,
198 },
199 )
200 }
201
202 pub fn merge_halves(this: Self, other: Self) -> HeapArray<T, Prod<M, U2>>
203 where
204 M: Mul<U2, Output: Positive>,
205 {
206 let mut vec = this.data.into_vec();
207 vec.extend(other.data.into_vec());
208 HeapArray::<T, Prod<M, U2>> {
209 data: vec.into_boxed_slice(),
210 _len: PhantomData,
211 }
212 }
213
214 pub fn split3<M1, M2, M3>(&self) -> (HeapArray<T, M1>, HeapArray<T, M2>, HeapArray<T, M3>)
215 where
216 M1: Positive,
217 M2: Positive + Add<M1>,
218 M3: Positive + Add<Sum<M2, M1>, Output = M>,
219 {
220 let (m1, m_rest) = self.split_at(M1::USIZE);
221 let (m2, m3) = m_rest.split_at(M2::USIZE);
222 (
223 HeapArray::<T, M1> {
224 data: m1.into(),
225 _len: PhantomData,
226 },
227 HeapArray::<T, M2> {
228 data: m2.into(),
229 _len: PhantomData,
230 },
231 HeapArray::<T, M3> {
232 data: m3.into(),
233 _len: PhantomData,
234 },
235 )
236 }
237
238 pub fn split_thirds<MDiv3>(
239 &self,
240 ) -> (
241 HeapArray<T, MDiv3>,
242 HeapArray<T, MDiv3>,
243 HeapArray<T, MDiv3>,
244 )
245 where
246 MDiv3: Positive + Mul<U3, Output = M>,
247 {
248 let (m1, m_rest) = self.split_at(MDiv3::USIZE);
249 let (m2, m3) = m_rest.split_at(MDiv3::USIZE);
250 (
251 HeapArray::<T, MDiv3> {
252 data: m1.into(),
253 _len: PhantomData,
254 },
255 HeapArray::<T, MDiv3> {
256 data: m2.into(),
257 _len: PhantomData,
258 },
259 HeapArray::<T, MDiv3> {
260 data: m3.into(),
261 _len: PhantomData,
262 },
263 )
264 }
265
266 pub fn merge_thirds(first: Self, second: Self, third: Self) -> HeapArray<T, Prod<M, U3>>
267 where
268 M: Mul<U3, Output: Positive>,
269 {
270 let mut vec = first.data.into_vec();
271 vec.extend(second.data.into_vec());
272 vec.extend(third.data.into_vec());
273 HeapArray::<T, Prod<M, U3>> {
274 data: vec.into_boxed_slice(),
275 _len: PhantomData,
276 }
277 }
278}
279
280pub struct HeapArrayTuple<T1: Sized, T2: Sized, M: Positive>(
281 pub HeapArray<T1, M>,
282 pub HeapArray<T2, M>,
283);
284
285impl<T1: Sized, T2: Sized, M: Positive> FromIterator<(T1, T2)> for HeapArrayTuple<T1, T2, M> {
286 fn from_iter<I: IntoIterator<Item = (T1, T2)>>(iter: I) -> Self {
287 let (data1, data2): (Vec<_>, Vec<_>) = iter.into_iter().unzip();
288
289 assert_eq!(data1.len(), M::USIZE);
290 assert_eq!(data2.len(), M::USIZE);
291 HeapArrayTuple(
292 HeapArray::<T1, M> {
293 data: data1.into_boxed_slice(),
294 _len: PhantomData,
295 },
296 HeapArray::<T2, M> {
297 data: data2.into_boxed_slice(),
298 _len: PhantomData,
299 },
300 )
301 }
302}
303
304impl<T: Sized, M: Positive> TryFrom<Vec<T>> for HeapArray<T, M> {
305 type Error = PrimitiveError;
306
307 fn try_from(data: Vec<T>) -> Result<Self, PrimitiveError> {
308 if data.len() == M::USIZE {
309 Ok(Self {
310 data: data.into_boxed_slice(),
311 _len: PhantomData,
312 })
313 } else {
314 Err(PrimitiveError::InvalidSize(M::USIZE, data.len()))
315 }
316 }
317}
318
319impl<T: Sized> From<T> for HeapArray<T, U1> {
320 fn from(element: T) -> Self {
321 Self {
322 data: Box::new([element]),
323 _len: PhantomData,
324 }
325 }
326}
327
328impl<T: Sized, M: Positive> From<HeapArray<T, M>> for Vec<T> {
329 fn from(array: HeapArray<T, M>) -> Self {
330 array.data.into_vec()
331 }
332}
333
334impl<T: Sized + Clone, M: Positive + ArraySize> From<Array<T, M>> for HeapArray<T, M> {
335 fn from(array: Array<T, M>) -> Self {
336 Self {
337 data: array.to_vec().into_boxed_slice(),
338 _len: PhantomData,
339 }
340 }
341}
342
343impl<T: Sized, M: Positive> HeapArray<T, M> {
344 pub fn split_last<N: Positive>(self) -> (HeapArray<T, Diff<M, N>>, HeapArray<T, N>)
345 where
346 M: Sub<N, Output: Positive>,
347 {
348 let mut vec = self.data.into_vec();
349 let last_n = vec.split_off(M::USIZE - N::USIZE);
350
351 (
352 HeapArray {
353 data: vec.into_boxed_slice(),
354 _len: PhantomData,
355 },
356 HeapArray {
357 data: last_n.into_boxed_slice(),
358 _len: PhantomData,
359 },
360 )
361 }
362
363 pub fn from_fn(f: impl FnMut(usize) -> T) -> Self {
364 Self {
365 data: (0..M::USIZE).map(f).collect::<Box<_>>(),
366 _len: PhantomData,
367 }
368 }
369
370 pub fn from_constant(c: T) -> Self
371 where
372 T: Copy,
373 {
374 Self {
375 data: (0..M::USIZE).map(|_| c).collect::<Box<_>>(),
376 _len: PhantomData,
377 }
378 }
379}
380
381impl<T: Sized, M: Positive> Display for HeapArray<T, M>
382where
383 T: Display,
384{
385 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386 write!(f, "[")?;
387 for (i, item) in self.data.iter().enumerate() {
388 if i != 0 {
389 write!(f, ", ")?;
390 }
391 write!(f, "{item}")?;
392 }
393 write!(f, "]")
394 }
395}
396
#[cfg(test)]
pub mod tests {
    use hybrid_array::sizes::{U2, U3, U6};
    use typenum::{U1, U4};

    use super::*;

    #[test]
    fn test_heap_array() {
        let array = HeapArray::<_, U3>::from_fn(|i| i);
        assert_eq!(array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    #[test]
    fn test_default() {
        let array = HeapArray::<usize, U3>::default();
        assert_eq!(array.len(), 3);
    }

    #[test]
    fn test_heap_array_split_last() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, last) = array.split_last::<U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3]);
        assert_eq!(last.into_iter().collect::<Vec<_>>(), vec![4, 5]);
    }

    #[test]
    fn test_heap_array_from_array() {
        let array = Array::<_, U3>::from_fn(|i| i);
        let heap_array = HeapArray::<_, U3>::from(array);
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    #[test]
    fn test_heap_array_from_vec() {
        let vec = vec![0, 1, 2];
        let heap_array = HeapArray::<_, U3>::try_from(vec).unwrap();
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);

        let vec = vec![0, 1];
        let heap_array = HeapArray::<_, U3>::try_from(vec);
        assert!(heap_array.is_err());
    }

    #[test]
    fn test_heap_array_from_iter() {
        // Exercise `FromIterator` both explicitly and through `collect`.
        let heap_array = HeapArray::<_, U3>::from_iter(0..3);
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);

        let collected: HeapArray<_, U3> = (10..13).collect();
        assert_eq!(collected.to_vec(), vec![10, 11, 12]);
    }

    #[test]
    #[should_panic]
    fn test_heap_array_from_iter_wrong_size() {
        HeapArray::<_, U2>::from_iter(0..3);
    }

    #[test]
    fn test_heap_array_map() {
        let array = HeapArray::<_, U3>::from_fn(|i| i);
        let doubled = array.map(|x| x * 2);
        assert_eq!(doubled.to_vec(), vec![0, 2, 4]);
    }

    #[test]
    fn test_heap_array_display() {
        let array = HeapArray::<_, U3>::from_fn(|i| i);
        assert_eq!(array.to_string(), "[0, 1, 2]");
    }

    #[test]
    fn test_heap_array_box_bytes_round_trip() {
        let array = HeapArray::<u32, U3>::from_fn(|i| u32::try_from(i * 7).unwrap());
        let restored = HeapArray::<u32, U3>::from_box_bytes(array.into_box_bytes());
        assert_eq!(restored.to_vec(), vec![0, 7, 14]);
    }

    #[test]
    fn test_heap_array_deserialize() {
        let array = HeapArray::<usize, U6>::from_fn(|i| i);
        let serialized = bincode::serialize(&array).unwrap();
        let round_trip = bincode::deserialize::<HeapArray<usize, U6>>(&serialized).unwrap();
        assert_eq!(round_trip.to_vec(), array.to_vec());

        use bincode::Options;
        // Reading only three of the six serialized elements leaves trailing
        // bytes, which this configuration rejects.
        let config = bincode::DefaultOptions::new()
            .with_fixint_encoding()
            .reject_trailing_bytes();

        let wrong_deserialize = config.deserialize::<HeapArray<usize, U3>>(&serialized);
        assert!(wrong_deserialize.is_err());
    }

    #[test]
    fn test_heap_array_split() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second) = array.split::<U4, U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![4, 5]);

        let (first, second) = array.split_halves::<U3>();
        let merged = HeapArray::merge_halves(first.clone(), second.clone());
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![3, 4, 5]);
        assert_eq!(merged.into_iter().collect::<Vec<_>>(), array.to_vec());
    }

    #[test]
    fn test_heap_array_split3() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second, third) = array.split3::<U3, U2, U1>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![3, 4]);
        assert_eq!(third.into_iter().collect::<Vec<_>>(), vec![5]);

        let (first, second, third) = array.split_thirds::<U2>();
        let merged = HeapArray::merge_thirds(first.clone(), second.clone(), third.clone());
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![2, 3]);
        assert_eq!(third.into_iter().collect::<Vec<_>>(), vec![4, 5]);
        assert_eq!(merged.into_iter().collect::<Vec<_>>(), array.to_vec());
    }
}