1use std::{
2 fmt::{Debug, Display},
3 marker::PhantomData,
4 ops::{Add, Mul, Sub},
5 vec,
6};
7
8use bytemuck::{box_bytes_of, from_box_bytes, BoxBytes, Pod};
9use derive_more::derive::{AsMut, AsRef, Deref, DerefMut, IntoIterator};
10use hybrid_array::{Array, ArraySize};
11use rayon::iter::IntoParallelIterator;
12use serde::{Deserialize, Serialize};
13use typenum::{Diff, Prod, Sum, U1, U2, U3};
14
15use crate::{errors::PrimitiveError, types::Positive};
16
/// A fixed-length, heap-allocated array whose length `M` is tracked at the type level.
///
/// The elements live in a `Box<[T]>` and the array dereferences to `[T]`. Code in this file
/// generally maintains the invariant `data.len() == M::USIZE`: `FromIterator` asserts it, and
/// `TryFrom<Vec<T>>` / `Deserialize` reject other lengths.
///
/// `IntoIterator` is derived for `HeapArray`, `&HeapArray` and `&mut HeapArray`, iterating over
/// the elements of `data`.
// NOTE(review): `Eq` is derived but no `PartialEq` impl is visible in this file; presumably it
// is provided elsewhere in the crate. Confirm.
#[derive(Deref, DerefMut, Clone, IntoIterator, AsRef, AsMut, Eq)]
#[into_iterator(owned, ref, ref_mut)]
pub struct HeapArray<T: Sized, M: Positive> {
    /// The elements; expected to hold exactly `M::USIZE` items.
    #[deref]
    #[deref_mut]
    #[as_ref(forward)]
    #[as_mut(forward)]
    pub(super) data: Box<[T]>,
    /// Zero-sized marker carrying the type-level length `M`; excluded from iteration.
    // NOTE(review): `fn() -> M` is presumably used so the marker does not make `HeapArray`
    // "own" an `M` (auto traits / drop check). Confirm the intent.
    #[into_iterator(ignore)]
    pub(super) _len: PhantomData<fn() -> M>,
}
31
32impl<T: Pod, M: Positive> HeapArray<T, M> {
34 pub fn into_box_bytes(self) -> BoxBytes {
35 box_bytes_of(self.data)
36 }
37
38 pub fn from_box_bytes(buf: BoxBytes) -> Self {
39 Self {
40 data: from_box_bytes(buf),
41 _len: PhantomData,
42 }
43 }
44}
45
46impl<T: Sized + Default, M: Positive> HeapArray<T, M> {
47 pub fn from_single_value(val: T) -> Self {
48 Self {
49 data: vec![val].into_boxed_slice(),
50 _len: PhantomData,
51 }
52 }
53}
54
55impl<T: Sized + Default, M: Positive> Default for HeapArray<T, M> {
56 fn default() -> Self {
57 Self {
58 data: (0..M::USIZE).map(|_| T::default()).collect(),
59 _len: PhantomData,
60 }
61 }
62}
63
64impl<T: Sized + Debug, M: Positive> Debug for HeapArray<T, M> {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct(format!("HeapArray<{}>", M::USIZE).as_str())
67 .field("data", &self.data)
68 .finish()
69 }
70}
71
72impl<T: Sized + Serialize, M: Positive> Serialize for HeapArray<T, M> {
73 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
74 self.data.serialize(serializer)
75 }
76}
77
78impl<'de, T: Sized + Deserialize<'de>, M: Positive> Deserialize<'de> for HeapArray<T, M> {
79 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
80 let data = Box::<[T]>::deserialize(deserializer)?;
81
82 if data.len() != M::USIZE {
83 return Err(serde::de::Error::custom(format!(
84 "Expected array of length {}, got {}",
85 M::USIZE,
86 data.len()
87 )));
88 }
89
90 Ok(Self {
91 data,
92 _len: PhantomData,
93 })
94 }
95}
96
97impl<T: Sized, M: Positive> FromIterator<T> for HeapArray<T, M> {
98 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
99 let data = iter.into_iter().collect::<Box<_>>();
100 assert_eq!(data.len(), M::USIZE,);
101 Self {
102 data,
103 _len: PhantomData,
104 }
105 }
106}
107
108impl<T: Sized + Send, M: Positive> IntoParallelIterator for HeapArray<T, M> {
109 type Item = T;
110 type Iter = rayon::vec::IntoIter<T>;
111
112 fn into_par_iter(self) -> Self::Iter {
113 self.data.into_par_iter()
114 }
115}
116
117impl<T: Sized + Copy, M: Positive> HeapArray<T, M> {
122 pub fn split<M1, M2>(&self) -> (HeapArray<T, M1>, HeapArray<T, M2>)
123 where
124 M1: Positive,
125 M2: Positive + Add<M1, Output = M>,
126 {
127 let (m1, m2) = self.split_at(M1::USIZE);
128 (
129 HeapArray::<T, M1> {
130 data: m1.into(),
131 _len: PhantomData,
132 },
133 HeapArray::<T, M2> {
134 data: m2.into(),
135 _len: PhantomData,
136 },
137 )
138 }
139
140 pub fn split_halves<MDiv2>(&self) -> (HeapArray<T, MDiv2>, HeapArray<T, MDiv2>)
141 where
142 MDiv2: Positive + Mul<U2, Output = M>,
143 {
144 let (m1, m2) = self.split_at(MDiv2::USIZE);
145 (
146 HeapArray::<T, MDiv2> {
147 data: m1.into(),
148 _len: PhantomData,
149 },
150 HeapArray::<T, MDiv2> {
151 data: m2.into(),
152 _len: PhantomData,
153 },
154 )
155 }
156
157 pub fn merge_halves(this: Self, other: Self) -> HeapArray<T, Prod<M, U2>>
158 where
159 M: Mul<U2, Output: Positive>,
160 {
161 let mut vec = this.data.into_vec();
162 vec.extend(other.data.into_vec());
163 HeapArray::<T, Prod<M, U2>> {
164 data: vec.into_boxed_slice(),
165 _len: PhantomData,
166 }
167 }
168
169 pub fn split3<M1, M2, M3>(&self) -> (HeapArray<T, M1>, HeapArray<T, M2>, HeapArray<T, M3>)
170 where
171 M1: Positive,
172 M2: Positive + Add<M1>,
173 M3: Positive + Add<Sum<M2, M1>, Output = M>,
174 {
175 let (m1, m_rest) = self.split_at(M1::USIZE);
176 let (m2, m3) = m_rest.split_at(M2::USIZE);
177 (
178 HeapArray::<T, M1> {
179 data: m1.into(),
180 _len: PhantomData,
181 },
182 HeapArray::<T, M2> {
183 data: m2.into(),
184 _len: PhantomData,
185 },
186 HeapArray::<T, M3> {
187 data: m3.into(),
188 _len: PhantomData,
189 },
190 )
191 }
192
193 pub fn split_thirds<MDiv3>(
194 &self,
195 ) -> (
196 HeapArray<T, MDiv3>,
197 HeapArray<T, MDiv3>,
198 HeapArray<T, MDiv3>,
199 )
200 where
201 MDiv3: Positive + Mul<U3, Output = M>,
202 {
203 let (m1, m_rest) = self.split_at(MDiv3::USIZE);
204 let (m2, m3) = m_rest.split_at(MDiv3::USIZE);
205 (
206 HeapArray::<T, MDiv3> {
207 data: m1.into(),
208 _len: PhantomData,
209 },
210 HeapArray::<T, MDiv3> {
211 data: m2.into(),
212 _len: PhantomData,
213 },
214 HeapArray::<T, MDiv3> {
215 data: m3.into(),
216 _len: PhantomData,
217 },
218 )
219 }
220
221 pub fn merge_thirds(first: Self, second: Self, third: Self) -> HeapArray<T, Prod<M, U3>>
222 where
223 M: Mul<U3, Output: Positive>,
224 {
225 let mut vec = first.data.into_vec();
226 vec.extend(second.data.into_vec());
227 vec.extend(third.data.into_vec());
228 HeapArray::<T, Prod<M, U3>> {
229 data: vec.into_boxed_slice(),
230 _len: PhantomData,
231 }
232 }
233}
234
/// A pair of heap arrays sharing the same type-level length `M`.
///
/// Built by unzipping an iterator of `(T1, T2)` pairs through its `FromIterator` impl.
pub struct HeapArrayTuple<T1: Sized, T2: Sized, M: Positive>(
    pub HeapArray<T1, M>,
    pub HeapArray<T2, M>,
);
239
240impl<T1: Sized, T2: Sized, M: Positive> FromIterator<(T1, T2)> for HeapArrayTuple<T1, T2, M> {
241 fn from_iter<I: IntoIterator<Item = (T1, T2)>>(iter: I) -> Self {
242 let (data1, data2): (Vec<_>, Vec<_>) = iter.into_iter().unzip();
243
244 assert_eq!(data1.len(), M::USIZE);
245 assert_eq!(data2.len(), M::USIZE);
246 HeapArrayTuple(
247 HeapArray::<T1, M> {
248 data: data1.into_boxed_slice(),
249 _len: PhantomData,
250 },
251 HeapArray::<T2, M> {
252 data: data2.into_boxed_slice(),
253 _len: PhantomData,
254 },
255 )
256 }
257}
258
259impl<T: Sized, M: Positive> TryFrom<Vec<T>> for HeapArray<T, M> {
260 type Error = PrimitiveError;
261
262 fn try_from(data: Vec<T>) -> Result<Self, PrimitiveError> {
263 if data.len() == M::USIZE {
264 Ok(Self {
265 data: data.into_boxed_slice(),
266 _len: PhantomData,
267 })
268 } else {
269 Err(PrimitiveError::SizeError(M::USIZE, data.len()))
270 }
271 }
272}
273
274impl<T: Sized> From<T> for HeapArray<T, U1> {
275 fn from(element: T) -> Self {
276 Self {
277 data: Box::new([element]),
278 _len: PhantomData,
279 }
280 }
281}
282
283impl<T: Sized, M: Positive> From<HeapArray<T, M>> for Vec<T> {
284 fn from(array: HeapArray<T, M>) -> Self {
285 array.data.into_vec()
286 }
287}
288
289impl<T: Sized + Clone, M: Positive + ArraySize> From<Array<T, M>> for HeapArray<T, M> {
290 fn from(array: Array<T, M>) -> Self {
291 Self {
292 data: array.to_vec().into_boxed_slice(),
293 _len: PhantomData,
294 }
295 }
296}
297
298impl<T: Sized, M: Positive> HeapArray<T, M> {
299 pub fn split_last<N: Positive>(self) -> (HeapArray<T, Diff<M, N>>, HeapArray<T, N>)
300 where
301 M: Sub<N, Output: Positive>,
302 {
303 let mut vec = self.data.into_vec();
304 let last_n = vec.split_off(M::USIZE - N::USIZE);
305
306 (
307 HeapArray {
308 data: vec.into_boxed_slice(),
309 _len: PhantomData,
310 },
311 HeapArray {
312 data: last_n.into_boxed_slice(),
313 _len: PhantomData,
314 },
315 )
316 }
317
318 pub fn from_fn(f: impl FnMut(usize) -> T) -> Self {
319 Self {
320 data: (0..M::USIZE).map(f).collect::<Box<_>>(),
321 _len: PhantomData,
322 }
323 }
324
325 pub fn from_constant(c: T) -> Self
326 where
327 T: Copy,
328 {
329 Self {
330 data: (0..M::USIZE).map(|_| c).collect::<Box<_>>(),
331 _len: PhantomData,
332 }
333 }
334}
335
336impl<T: Sized, M: Positive> Display for HeapArray<T, M>
337where
338 T: Display,
339{
340 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341 write!(f, "[")?;
342 for (i, item) in self.data.iter().enumerate() {
343 if i != 0 {
344 write!(f, ", ")?;
345 }
346 write!(f, "{item}")?;
347 }
348 write!(f, "]")
349 }
350}
351
#[cfg(test)]
pub mod tests {
    use hybrid_array::sizes::{U2, U3, U6};
    use typenum::{U1, U4};

    use super::*;

    #[test]
    fn test_heap_array() {
        let array = HeapArray::<_, U3>::from_fn(|i| i);
        assert_eq!(array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    #[test]
    fn test_default() {
        let array = HeapArray::<usize, U3>::default();
        assert_eq!(array.len(), 3);
    }

    #[test]
    fn test_heap_array_split_last() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, last) = array.split_last::<U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3]);
        assert_eq!(last.into_iter().collect::<Vec<_>>(), vec![4, 5]);
    }

    #[test]
    fn test_heap_array_from_array() {
        let array = Array::<_, U3>::from_fn(|i| i);
        let heap_array = HeapArray::<_, U3>::from(array);
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    #[test]
    fn test_heap_array_from_vec() {
        let vec = vec![0, 1, 2];
        let heap_array = HeapArray::<_, U3>::try_from(vec).unwrap();
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);

        let vec = vec![0, 1];
        let heap_array = HeapArray::<_, U3>::try_from(vec);
        assert!(heap_array.is_err());
    }

    #[test]
    fn test_heap_array_from_iter() {
        // Go through `FromIterator` (via `collect`), which is what this test is named for.
        let heap_array: HeapArray<usize, U3> = (0..3).collect();
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    #[test]
    #[should_panic]
    fn test_heap_array_from_iter_wrong_size() {
        HeapArray::<_, U2>::from_iter(0..3);
    }

    #[test]
    fn test_heap_array_deserialize() {
        let array = HeapArray::<usize, U6>::from_fn(|i| i);
        let serialized = bincode::serialize(&array).unwrap();
        bincode::deserialize::<HeapArray<usize, U6>>(&serialized).unwrap();

        let wrong_deserialize = bincode::deserialize::<HeapArray<usize, U3>>(&serialized);
        assert!(wrong_deserialize.is_err());
    }

    #[test]
    fn test_heap_array_split() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second) = array.split::<U4, U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![4, 5]);

        let (first, second) = array.split_halves::<U3>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![3, 4, 5]);
    }

    #[test]
    fn test_heap_array_merge_halves() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second) = array.split_halves::<U3>();
        let merged = HeapArray::merge_halves(first, second);
        assert_eq!(merged.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_heap_array_split3() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second, third) = array.split3::<U3, U2, U1>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![3, 4]);
        assert_eq!(third.into_iter().collect::<Vec<_>>(), vec![5]);

        let (first, second, third) = array.split_thirds::<U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![2, 3]);
        assert_eq!(third.into_iter().collect::<Vec<_>>(), vec![4, 5]);
    }

    #[test]
    fn test_heap_array_merge_thirds() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second, third) = array.split_thirds::<U2>();
        let merged = HeapArray::merge_thirds(first, second, third);
        assert_eq!(merged.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_heap_array_box_bytes_roundtrip() {
        let array = HeapArray::<u32, U3>::from_fn(|i| i as u32);
        let bytes = array.into_box_bytes();
        let restored = HeapArray::<u32, U3>::from_box_bytes(bytes);
        assert_eq!(restored.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    #[test]
    fn test_heap_array_display_and_debug() {
        let array = HeapArray::<usize, U3>::from_fn(|i| i);
        assert_eq!(array.to_string(), "[0, 1, 2]");
        assert_eq!(format!("{array:?}"), "HeapArray<3> { data: [0, 1, 2] }");
    }
}