1use std::hash::{BuildHasher, Hash};
16
/// Bounds every key type must satisfy: hashable, equality-comparable, thread-safe and `'static`.
pub trait Key: Send + Sync + 'static + Hash + Eq {}
/// Bounds every value type must satisfy: thread-safe and `'static`.
pub trait Value: Send + Sync + 'static {}

// Blanket impls: any type meeting the bounds is automatically a `Key` / `Value`.
impl<K> Key for K where K: Send + Sync + 'static + Hash + Eq {}
impl<V> Value for V where V: Send + Sync + 'static {}
24
/// A [`BuildHasher`] that can be shared across threads and stored without borrowing.
pub trait HashBuilder: BuildHasher + Send + Sync + 'static {}
impl<S: BuildHasher + Send + Sync + 'static> HashBuilder for S {}
28
29#[derive(Debug, thiserror::Error)]
31pub enum CodeError {
32    #[error("exceed size limit")]
34    SizeLimit,
35    #[error("io error: {0}")]
37    Io(std::io::Error),
38    #[cfg(feature = "serde")]
39    #[error("bincode error: {0}")]
41    Bincode(bincode::Error),
42    #[error("unrecognized data: {0:?}")]
44    Unrecognized(Vec<u8>),
45    #[error("other error: {0}")]
47    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
48}
49
50pub type CodeResult<T> = std::result::Result<T, CodeError>;
52
53impl From<std::io::Error> for CodeError {
54    fn from(err: std::io::Error) -> Self {
55        match err.kind() {
56            std::io::ErrorKind::WriteZero => Self::SizeLimit,
57            _ => Self::Io(err),
58        }
59    }
60}
61
/// Converts a bincode error into a [`CodeError`].
///
/// bincode's `SizeLimit` maps to [`CodeError::SizeLimit`]; a nested I/O error
/// is routed through the `std::io::Error` conversion (so `WriteZero` also
/// becomes `SizeLimit`); anything else is kept as [`CodeError::Bincode`].
#[cfg(feature = "serde")]
impl From<bincode::Error> for CodeError {
    fn from(err: bincode::Error) -> Self {
        use bincode::ErrorKind;

        match *err {
            ErrorKind::SizeLimit => CodeError::SizeLimit,
            ErrorKind::Io(io_err) => CodeError::from(io_err),
            _ => CodeError::Bincode(err),
        }
    }
}
72
73pub trait StorageKey: Key + Code {}
75impl<T> StorageKey for T where T: Key + Code {}
76
77pub trait StorageValue: Value + 'static + Code {}
79impl<T> StorageValue for T where T: Value + Code {}
80
81pub trait Code {
91    fn encode(&self, writer: &mut impl std::io::Write) -> std::result::Result<(), CodeError>;
93
94    fn decode(reader: &mut impl std::io::Read) -> std::result::Result<Self, CodeError>
96    where
97        Self: Sized;
98
99    fn estimated_size(&self) -> usize;
103}
104
/// Blanket [`Code`] impl for every `Serialize + DeserializeOwned` type, using
/// bincode's default configuration.
#[cfg(feature = "serde")]
impl<T> Code for T
where
    T: serde::Serialize + serde::de::DeserializeOwned,
{
    fn encode(&self, writer: &mut impl std::io::Write) -> std::result::Result<(), CodeError> {
        bincode::serialize_into(writer, self).map_err(CodeError::from)
    }

    fn decode(reader: &mut impl std::io::Read) -> std::result::Result<Self, CodeError> {
        bincode::deserialize_from(reader).map_err(CodeError::from)
    }

    /// Returns the exact bincode-encoded size of `self`.
    ///
    /// bincode cannot size every `Serialize` impl (for example a sequence
    /// whose length is not known up front, or a custom impl that returns an
    /// error). In that case this returns `0` instead of panicking, and the
    /// failure surfaces as an `Err` from [`Code::encode`]. Sizes that do not fit
    /// in `usize` (32-bit targets) saturate to `usize::MAX` rather than being
    /// silently truncated.
    fn estimated_size(&self) -> usize {
        match bincode::serialized_size(self) {
            Ok(size) => usize::try_from(size).unwrap_or(usize::MAX),
            Err(_) => 0,
        }
    }
}
122
/// Implements [`Code`] for each listed primitive numeric type as its
/// fixed-width little-endian byte representation (no length prefix).
///
/// The generated impls are only compiled without the `serde` feature.
///
/// NOTE: pointer-sized types (`usize`, `isize`) are written with
/// `size_of::<usize>()` bytes, so their encoding differs between 32- and
/// 64-bit targets.
macro_rules! impl_serde_for_numeric_types {
    ($($num:ty),*) => {
        $(
            #[cfg(not(feature = "serde"))]
            impl Code for $num {
                fn encode(
                    &self,
                    writer: &mut impl std::io::Write,
                ) -> std::result::Result<(), CodeError> {
                    let bytes = self.to_le_bytes();
                    writer.write_all(&bytes)?;
                    Ok(())
                }

                fn decode(
                    reader: &mut impl std::io::Read,
                ) -> std::result::Result<Self, CodeError> {
                    let mut bytes = [0u8; std::mem::size_of::<$num>()];
                    reader.read_exact(&mut bytes)?;
                    Ok(Self::from_le_bytes(bytes))
                }

                fn estimated_size(&self) -> usize {
                    std::mem::size_of::<Self>()
                }
            }
        )*
    };
}
145
/// Invokes `$callback!` once with every primitive numeric type, in a fixed
/// order: unsigned integers, signed integers, then floats.
macro_rules! for_all_numeric_types {
    ($callback:ident) => {
        $callback! {
            u8, u16, u32, u64, u128, usize,
            i8, i16, i32, i64, i128, isize,
            f32, f64
        }
    };
}
151
152for_all_numeric_types! { impl_serde_for_numeric_types }
153
154#[cfg(not(feature = "serde"))]
155impl Code for bool {
156    fn encode(&self, writer: &mut impl std::io::Write) -> std::result::Result<(), CodeError> {
157        writer
158            .write_all(if *self { &[1u8] } else { &[0u8] })
159            .map_err(CodeError::from)
160    }
161
162    fn decode(reader: &mut impl std::io::Read) -> std::result::Result<Self, CodeError>
163    where
164        Self: Sized,
165    {
166        let mut buf = [0u8; 1];
167        reader.read_exact(&mut buf).map_err(CodeError::from)?;
168        match buf[0] {
169            0 => Ok(false),
170            1 => Ok(true),
171            _ => Err(CodeError::Unrecognized(buf.to_vec())),
172        }
173    }
174
175    fn estimated_size(&self) -> usize {
176        1
177    }
178}
179
180#[cfg(not(feature = "serde"))]
181impl Code for Vec<u8> {
182    fn encode(&self, writer: &mut impl std::io::Write) -> std::result::Result<(), CodeError> {
183        self.len().encode(writer)?;
184        writer.write_all(self).map_err(CodeError::from)
185    }
186
187    #[expect(clippy::uninit_vec)]
188    fn decode(reader: &mut impl std::io::Read) -> std::result::Result<Self, CodeError>
189    where
190        Self: Sized,
191    {
192        let len = usize::decode(reader)?;
193        let mut v = Vec::with_capacity(len);
194        unsafe {
195            v.set_len(len);
196        }
197        reader.read_exact(&mut v).map_err(CodeError::from)?;
198        Ok(v)
199    }
200
201    fn estimated_size(&self) -> usize {
202        std::mem::size_of::<usize>() + self.len()
203    }
204}
205
206#[cfg(not(feature = "serde"))]
207impl Code for String {
208    fn encode(&self, writer: &mut impl std::io::Write) -> std::result::Result<(), CodeError> {
209        self.len().encode(writer)?;
210        writer.write_all(self.as_bytes()).map_err(CodeError::from)
211    }
212
213    #[expect(clippy::uninit_vec)]
214    fn decode(reader: &mut impl std::io::Read) -> std::result::Result<Self, CodeError>
215    where
216        Self: Sized,
217    {
218        let len = usize::decode(reader)?;
219        let mut v = Vec::with_capacity(len);
220        unsafe {
221            v.set_len(len);
222        }
223        reader.read_exact(&mut v).map_err(CodeError::from)?;
224        String::from_utf8(v).map_err(|e| CodeError::Unrecognized(e.into_bytes()))
225    }
226
227    fn estimated_size(&self) -> usize {
228        std::mem::size_of::<usize>() + self.len()
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "serde")]
    mod serde {
        use super::*;

        /// Encoding a `u64` into a 4-byte buffer must surface as `SizeLimit`.
        #[test]
        fn test_encode_overflow() {
            let mut buf = [0u8; 4];
            assert!(matches! {1u64.encode(&mut buf.as_mut()), Err(CodeError::SizeLimit)});
        }
    }

    #[cfg(not(feature = "serde"))]
    mod non_serde {
        use super::*;

        /// Encoding a `u64` into a 4-byte buffer must surface as `SizeLimit`.
        #[test]
        fn test_encode_overflow() {
            let mut buf = [0u8; 4];
            assert!(matches! {1u64.encode(&mut buf.as_mut()), Err(CodeError::SizeLimit)});
        }

        // Round-trips zero, MIN and MAX of every numeric type through a buffer sized
        // exactly by `estimated_size`.
        macro_rules! impl_serde_test_for_numeric_types {
            ($($t:ty),*) => {
                paste::paste! {
                    $(
                        #[test]
                        fn [<test_ $t _serde>]() {
                            for a in [0 as $t, <$t>::MIN, <$t>::MAX] {
                                let mut buf = vec![0xffu8; a.estimated_size()];
                                a.encode(&mut buf.as_mut_slice()).unwrap();
                                let b = <$t>::decode(&mut buf.as_slice()).unwrap();
                                assert_eq!(a, b);
                            }
                        }
                    )*
                }
            };
        }

        for_all_numeric_types! { impl_serde_test_for_numeric_types }

        #[test]
        fn test_bool_serde() {
            for a in [false, true] {
                let mut buf = vec![0xffu8; a.estimated_size()];
                a.encode(&mut buf.as_mut_slice()).unwrap();
                let b = bool::decode(&mut buf.as_slice()).unwrap();
                assert_eq!(a, b);
            }
        }

        #[test]
        fn test_bool_decode_rejects_invalid_byte() {
            let res = bool::decode(&mut [2u8].as_slice());
            assert!(matches!(res, Err(CodeError::Unrecognized(ref v)) if *v == [2u8]));
        }

        #[test]
        fn test_vec_u8_serde() {
            let mut a = vec![0u8; 42];
            rand::fill(&mut a[..]);
            let mut buf = vec![0xffu8; a.estimated_size()];
            a.encode(&mut buf.as_mut_slice()).unwrap();
            let b = Vec::<u8>::decode(&mut buf.as_slice()).unwrap();
            assert_eq!(a, b);
        }

        #[test]
        fn test_vec_u8_serde_empty() {
            let a: Vec<u8> = Vec::new();
            let mut buf = vec![0xffu8; a.estimated_size()];
            a.encode(&mut buf.as_mut_slice()).unwrap();
            let b = Vec::<u8>::decode(&mut buf.as_slice()).unwrap();
            assert_eq!(a, b);
        }

        /// A payload shorter than its length prefix must fail with `UnexpectedEof`.
        #[test]
        fn test_vec_u8_decode_truncated() {
            let a = vec![7u8; 16];
            let mut buf = vec![0u8; a.estimated_size()];
            a.encode(&mut buf.as_mut_slice()).unwrap();
            buf.pop();
            assert!(matches!(
                Vec::<u8>::decode(&mut buf.as_slice()),
                Err(CodeError::Io(ref e)) if e.kind() == std::io::ErrorKind::UnexpectedEof
            ));
        }

        #[test]
        fn test_string_serde() {
            let a = "hello world".to_string();
            let mut buf = vec![0xffu8; a.estimated_size()];
            a.encode(&mut buf.as_mut_slice()).unwrap();
            let b = String::decode(&mut buf.as_slice()).unwrap();
            assert_eq!(a, b);
        }

        /// Non-UTF-8 payloads are returned as `Unrecognized` with the raw bytes.
        #[test]
        fn test_string_decode_rejects_invalid_utf8() {
            let bytes = vec![0xffu8, 0xfe];
            let mut buf = vec![0u8; bytes.estimated_size()];
            bytes.encode(&mut buf.as_mut_slice()).unwrap();
            assert!(matches!(
                String::decode(&mut buf.as_slice()),
                Err(CodeError::Unrecognized(ref v)) if *v == bytes
            ));
        }
    }
}