Skip to main content

aeon_tk/
array.rs

//! This module contains several "hacky" utilities to work around the current lack of
//! `generic_const_exprs` in stable Rust.
//!
//! In several parts of the codebase we have to create arrays whose length is determined by an
//! associated constant of a trait. Since this is impossible in stable Rust, we instead use the
//! following pattern:
//!
//! ```rust
//! # use aeon::array::ArrayLike;
//! trait MyTrait {
//!     type Weights: ArrayLike<usize, Elem = f64>;
//! }
//!
//! struct MyStruct;
//!
//! impl MyTrait for MyStruct {
//!     type Weights = [f64; 10];
//! }
//! ```
//!
//! This module also implements utilities for serializing and deserializing arrays of arbitrary
//! length. `serde` itself only supports arrays of up to 32 elements, because switching to a
//! const-generic implementation would break backwards compatibility for zero-length arrays.
24
25use serde::de::{Deserialize, Deserializer, Error as _};
26use serde::ser::{Serialize, SerializeTuple as _, Serializer};
27use std::marker::PhantomData;
28use std::mem::MaybeUninit;
29use std::{
30    fmt::Debug,
31    ops::{Index, IndexMut},
32};
33
/// Abstraction over array types whose length is fixed at compile time and whose elements can
/// be read and written by index.
///
/// This stands in for `generic_const_exprs`: instead of declaring an array whose length depends
/// on another associated constant, a trait declares an associated type bounded by `ArrayLike`.
/// The trait can be removed if `generic_const_exprs` is ever stabilized.
pub trait ArrayLike<Idx>:
    Index<Idx, Output = Self::Elem> + IndexMut<Idx, Output = Self::Elem>
{
    /// Number of elements, fixed at compile time.
    const LEN: usize;

    /// Element type stored in the array.
    type Elem;

    /// Builds an array by calling `cb` for each index and storing the element it returns.
    fn from_fn<F: FnMut(Idx) -> Self::Elem>(cb: F) -> Self;
}

impl<T, const N: usize> ArrayLike<usize> for [T; N] {
    type Elem = T;

    const LEN: usize = N;

    /// Delegates to `std::array::from_fn`, which calls `cb` with indices `0..N` in order.
    fn from_fn<F>(cb: F) -> Self
    where
        F: FnMut(usize) -> T,
    {
        std::array::from_fn(cb)
    }
}
59
/// A transparent newtype around a fixed-size array `[T; N]`.
///
/// The wrapper lets trait impls be attached to arrays of any length `N`. It derives `Clone` and
/// `Debug`, implements `Default` and `From<[T; N]>`, and implements `Serialize`/`Deserialize`
/// whenever `T` does.
///
/// `#[repr(transparent)]` guarantees that `ArrayWrap<T, N>` has the same layout as `[T; N]`.
#[repr(transparent)]
#[derive(Clone, Debug)]
pub struct ArrayWrap<T, const N: usize>(pub [T; N]);

impl<T: Default, const N: usize> Default for ArrayWrap<T, N> {
    /// Builds an array in which every element is `T::default()`.
    ///
    /// Unlike the standard library's `Default for [T; N]` (only provided for `N <= 32`), this
    /// works for every `N`.
    fn default() -> Self {
        ArrayWrap(std::array::from_fn(|_| T::default()))
    }
}

impl<T, const N: usize> From<[T; N]> for ArrayWrap<T, N> {
    /// Wraps `array` without copying or reordering its elements.
    fn from(array: [T; N]) -> Self {
        ArrayWrap(array)
    }
}
65
66impl<T: Serialize, const N: usize> Serialize for ArrayWrap<T, N> {
67    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
68    where
69        S: serde::Serializer,
70    {
71        let mut seq = serializer.serialize_tuple(N)?;
72
73        for i in 0..N {
74            seq.serialize_element(&self.0[i])?;
75        }
76
77        seq.end()
78    }
79}
80
81impl<'de, T, const N: usize> Deserialize<'de> for ArrayWrap<T, N>
82where
83    T: Deserialize<'de>,
84{
85    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
86    where
87        D: serde::Deserializer<'de>,
88    {
89        /// Visitor for deserializing ArrayWrap<T, N>
90        struct Visitor<T, const N: usize>(PhantomData<[T; N]>);
91
92        impl<'de, T: Deserialize<'de>, const N: usize> serde::de::Visitor<'de> for Visitor<T, N> {
93            type Value = ArrayWrap<T, N>;
94
95            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
96                formatter.write_fmt(format_args!("Array of Length {}", N))
97            }
98
99            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
100            where
101                A: serde::de::SeqAccess<'de>,
102            {
103                let mut arr = [const { MaybeUninit::<T>::uninit() }; N];
104
105                let mut i = 0;
106
107                let err = loop {
108                    if i >= N {
109                        break None;
110                    }
111
112                    let elem = seq.next_element::<T>();
113
114                    match elem {
115                        Ok(Some(val)) => arr[i] = MaybeUninit::new(val),
116                        Ok(None) => {
117                            break Some(A::Error::custom::<String>(String::from(
118                                "Sequence length does not match array length",
119                            )));
120                        }
121                        Err(e) => break Some(e),
122                    }
123
124                    i += 1;
125                };
126
127                if let Some(e) = err {
128                    for item in arr.iter_mut().take(i) {
129                        unsafe {
130                            item.assume_init_drop();
131                        }
132                    }
133
134                    return Err(e);
135                }
136
137                Ok(ArrayWrap(unsafe {
138                    std::mem::transmute_copy::<_, [T; N]>(&arr)
139                }))
140            }
141        }
142
143        deserializer.deserialize_tuple(N, Visitor::<T, N>(PhantomData))
144    }
145}
146
147pub fn serialize<S, T, const N: usize>(data: &[T; N], ser: S) -> Result<S::Ok, S::Error>
148where
149    S: Serializer,
150    T: Serialize,
151{
152    let arr: &ArrayWrap<T, N> = unsafe { std::mem::transmute(data) };
153    arr.serialize(ser)
154}
155
156/// Deserialize const generic or arbitrarily-large arrays
157///
158/// For any array up to length `usize::MAX`, this function will allow Serde to properly deserialize
159/// it, provided the type `T` itself is deserializable.
160///
161/// This implementation is adapted from the [Serde documentation][deserialize_map].
162///
163/// [deserialize_map]: https://serde.rs/deserialize-map.html
164pub fn deserialize<'de, D, T, const N: usize>(deserialize: D) -> Result<[T; N], D::Error>
165where
166    D: Deserializer<'de>,
167    T: Deserialize<'de>,
168{
169    ArrayWrap::<T, N>::deserialize(deserialize).map(|val| val.0)
170}
171
172/// Contains methods for working with vecs of arrays.
173pub mod vec {
174    use super::ArrayWrap;
175    use serde::{Deserialize, Deserializer, Serialize, Serializer, de, ser::SerializeSeq};
176    use std::{fmt, marker::PhantomData};
177
178    /// Serialize vectors of const generic arrays.
179    pub fn serialize<S, T, const N: usize>(data: &Vec<[T; N]>, ser: S) -> Result<S::Ok, S::Error>
180    where
181        S: Serializer,
182        T: Serialize,
183    {
184        // See: https://serde.rs/impl-serialize.html#serializing-a-tuple
185        let mut s = ser.serialize_seq(Some(data.len()))?;
186        for array in data {
187            let array = unsafe { std::mem::transmute::<&[T; N], &ArrayWrap<T, N>>(array) };
188            s.serialize_element(array)?;
189        }
190        s.end()
191    }
192
193    /// Deserialize vectors of const generic arrays.
194    ///
195    /// For any array up to length `usize::MAX`, this function will allow Serde to properly deserialize
196    /// it, provided the type `T` itself is deserializable.
197    ///
198    /// This implementation is adapted from the [Serde documentation][deserialize_map].
199    ///
200    /// [deserialize_map]: https://serde.rs/deserialize-map.html
201    pub fn deserialize<'de, D, T, const N: usize>(deserialize: D) -> Result<Vec<[T; N]>, D::Error>
202    where
203        D: Deserializer<'de>,
204        T: Deserialize<'de>,
205    {
206        /// A Serde Deserializer `Visitor` for Vec<[T; N]> arrays
207        struct Visitor<T, const N: usize> {
208            _marker: PhantomData<T>,
209        }
210
211        impl<'de, T, const N: usize> de::Visitor<'de> for Visitor<T, N>
212        where
213            T: Deserialize<'de>,
214        {
215            type Value = Vec<[T; N]>;
216
217            /// Format a message stating we expect an array of size `N`
218            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
219                write!(formatter, "a vector of arrays of size {}", N)
220            }
221
222            /// Process a sequence into an array
223            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
224            where
225                A: de::SeqAccess<'de>,
226            {
227                let mut arr: Vec<[T; N]> = Vec::new();
228
229                if let Some(size) = seq.size_hint() {
230                    arr.reserve(size);
231                }
232
233                loop {
234                    match seq.next_element() {
235                        Ok(Some(ArrayWrap(val))) => arr.push(val),
236                        Ok(None) => break,
237                        Err(e) => return Err(e),
238                    }
239                }
240
241                Ok(arr)
242            }
243        }
244
245        deserialize.deserialize_seq(Visitor::<T, N> {
246            _marker: PhantomData,
247        })
248    }
249}