Skip to main content

spl_pod/
primitives.rs

1//! primitive types that can be used in `Pod`s
2#[cfg(feature = "borsh")]
3use borsh::{BorshDeserialize, BorshSchema, BorshSerialize};
4use bytemuck_derive::{Pod, Zeroable};
5#[cfg(feature = "serde-traits")]
6use serde::{Deserialize, Serialize};
7#[cfg(feature = "wincode")]
8use wincode::{SchemaRead, SchemaWrite};
9
10/// The standard `bool` is not a `Pod`, define a replacement that is
11#[cfg_attr(feature = "wincode", derive(SchemaRead, SchemaWrite))]
12#[cfg_attr(feature = "wincode", wincode(assert_zero_copy))]
13#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
14#[cfg_attr(feature = "serde-traits", serde(from = "bool", into = "bool"))]
15#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
16#[repr(transparent)]
17pub struct PodBool(pub u8);
18impl PodBool {
19    pub const fn from_bool(b: bool) -> Self {
20        Self(if b { 1 } else { 0 })
21    }
22}
23
24impl From<bool> for PodBool {
25    fn from(b: bool) -> Self {
26        Self::from_bool(b)
27    }
28}
29
30impl From<&bool> for PodBool {
31    fn from(b: &bool) -> Self {
32        Self(if *b { 1 } else { 0 })
33    }
34}
35
36impl From<&PodBool> for bool {
37    fn from(b: &PodBool) -> Self {
38        b.0 != 0
39    }
40}
41
42impl From<PodBool> for bool {
43    fn from(b: PodBool) -> Self {
44        b.0 != 0
45    }
46}
47
/// Implements conversions between a `Pod*` byte-array wrapper and the
/// standard integer type it represents.
///
/// `$P` must be a tuple struct whose single field is the little-endian byte
/// array of `$I` (i.e. `[u8; size_of::<$I>()]`). The macro generates:
///
/// * `$P::from_primitive`, a `const fn` constructor,
/// * `From<$I> for $P`, and
/// * `From<$P> for $I`.
///
/// Standard integer types carry alignment requirements that can cause issues
/// inside a `Pod`; byte-array wrappers have alignment 1, so they are usable in
/// all `Pod`s.
///
/// Trait paths are fully qualified (`::core::convert::From`) because this macro
/// is exported and its paths resolve at the invocation site, where `From` may
/// be shadowed.
#[macro_export]
macro_rules! impl_int_conversion {
    ($P:ty, $I:ty) => {
        impl $P {
            /// Builds the value from a primitive integer, storing it in
            /// little-endian byte order.
            pub const fn from_primitive(n: $I) -> Self {
                Self(n.to_le_bytes())
            }
        }
        impl ::core::convert::From<$I> for $P {
            fn from(n: $I) -> Self {
                Self::from_primitive(n)
            }
        }
        impl ::core::convert::From<$P> for $I {
            fn from(pod: $P) -> Self {
                Self::from_le_bytes(pod.0)
            }
        }
    };
}
73
74/// `u16` type that can be used in `Pod`s
75#[cfg_attr(feature = "wincode", derive(SchemaRead, SchemaWrite))]
76#[cfg_attr(feature = "wincode", wincode(assert_zero_copy))]
77#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
78#[cfg_attr(feature = "serde-traits", serde(from = "u16", into = "u16"))]
79#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
80#[repr(transparent)]
81pub struct PodU16(pub [u8; 2]);
82impl_int_conversion!(PodU16, u16);
83
84/// `i16` type that can be used in Pods
85#[cfg_attr(feature = "wincode", derive(SchemaRead, SchemaWrite))]
86#[cfg_attr(feature = "wincode", wincode(assert_zero_copy))]
87#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
88#[cfg_attr(feature = "serde-traits", serde(from = "i16", into = "i16"))]
89#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
90#[repr(transparent)]
91pub struct PodI16(pub [u8; 2]);
92impl_int_conversion!(PodI16, i16);
93
94/// `u32` type that can be used in `Pod`s
95#[cfg_attr(feature = "wincode", derive(SchemaRead, SchemaWrite))]
96#[cfg_attr(feature = "wincode", wincode(assert_zero_copy))]
97#[cfg_attr(
98    feature = "borsh",
99    derive(BorshDeserialize, BorshSerialize, BorshSchema)
100)]
101#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
102#[cfg_attr(feature = "serde-traits", serde(from = "u32", into = "u32"))]
103#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
104#[repr(transparent)]
105pub struct PodU32(pub [u8; 4]);
106impl_int_conversion!(PodU32, u32);
107
108/// `u64` type that can be used in Pods
109#[cfg_attr(feature = "wincode", derive(SchemaRead, SchemaWrite))]
110#[cfg_attr(feature = "wincode", wincode(assert_zero_copy))]
111#[cfg_attr(
112    feature = "borsh",
113    derive(BorshDeserialize, BorshSerialize, BorshSchema)
114)]
115#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
116#[cfg_attr(feature = "serde-traits", serde(from = "u64", into = "u64"))]
117#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
118#[repr(transparent)]
119pub struct PodU64(pub [u8; 8]);
120impl_int_conversion!(PodU64, u64);
121
122/// `i64` type that can be used in Pods
123#[cfg_attr(feature = "wincode", derive(SchemaRead, SchemaWrite))]
124#[cfg_attr(feature = "wincode", wincode(assert_zero_copy))]
125#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
126#[cfg_attr(feature = "serde-traits", serde(from = "i64", into = "i64"))]
127#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
128#[repr(transparent)]
129pub struct PodI64([u8; 8]);
130impl_int_conversion!(PodI64, i64);
131
132/// `u128` type that can be used in Pods
133#[cfg_attr(feature = "wincode", derive(SchemaRead, SchemaWrite))]
134#[cfg_attr(feature = "wincode", wincode(assert_zero_copy))]
135#[cfg_attr(
136    feature = "borsh",
137    derive(BorshDeserialize, BorshSerialize, BorshSchema)
138)]
139#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
140#[cfg_attr(feature = "serde-traits", serde(from = "u128", into = "u128"))]
141#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
142#[repr(transparent)]
143pub struct PodU128(pub [u8; 16]);
144impl_int_conversion!(PodU128, u128);
145
#[cfg(test)]
mod tests {
    use {super::*, crate::bytemuck::pod_from_bytes};

    #[test]
    fn test_pod_bool() {
        assert!(pod_from_bytes::<PodBool>(&[]).is_err());
        assert!(pod_from_bytes::<PodBool>(&[0, 0]).is_err());

        // Every non-zero byte must decode as `true`, zero as `false`.
        for i in 0..=u8::MAX {
            assert_eq!(i != 0, bool::from(pod_from_bytes::<PodBool>(&[i]).unwrap()));
        }
    }

    #[cfg(feature = "serde-traits")]
    #[test]
    fn test_pod_bool_serde() {
        let pod_false: PodBool = false.into();
        let pod_true: PodBool = true.into();

        let serialized_false = serde_json::to_string(&pod_false).unwrap();
        let serialized_true = serde_json::to_string(&pod_true).unwrap();
        assert_eq!(&serialized_false, "false");
        assert_eq!(&serialized_true, "true");

        let deserialized_false = serde_json::from_str::<PodBool>(&serialized_false).unwrap();
        let deserialized_true = serde_json::from_str::<PodBool>(&serialized_true).unwrap();
        assert_eq!(pod_false, deserialized_false);
        assert_eq!(pod_true, deserialized_true);
    }

    #[test]
    fn test_pod_u16() {
        assert!(pod_from_bytes::<PodU16>(&[]).is_err());
        assert_eq!(1u16, u16::from(*pod_from_bytes::<PodU16>(&[1, 0]).unwrap()));
    }

    #[cfg(feature = "serde-traits")]
    #[test]
    fn test_pod_u16_serde() {
        let pod_u16: PodU16 = u16::MAX.into();

        let serialized = serde_json::to_string(&pod_u16).unwrap();
        assert_eq!(&serialized, "65535");

        let deserialized = serde_json::from_str::<PodU16>(&serialized).unwrap();
        assert_eq!(pod_u16, deserialized);
    }

    #[test]
    fn test_pod_i16() {
        assert!(pod_from_bytes::<PodI16>(&[]).is_err());
        assert_eq!(
            -1i16,
            i16::from(*pod_from_bytes::<PodI16>(&[255, 255]).unwrap())
        );
    }

    #[cfg(feature = "serde-traits")]
    #[test]
    fn test_pod_i16_serde() {
        let pod_i16: PodI16 = i16::MAX.into();

        let serialized = serde_json::to_string(&pod_i16).unwrap();
        assert_eq!(&serialized, "32767");

        let deserialized = serde_json::from_str::<PodI16>(&serialized).unwrap();
        assert_eq!(pod_i16, deserialized);
    }

    #[test]
    fn test_pod_u32() {
        assert!(pod_from_bytes::<PodU32>(&[]).is_err());
        assert!(pod_from_bytes::<PodU32>(&[0, 0, 0]).is_err());
        assert_eq!(
            1u32,
            u32::from(*pod_from_bytes::<PodU32>(&[1, 0, 0, 0]).unwrap())
        );
    }

    #[cfg(feature = "serde-traits")]
    #[test]
    fn test_pod_u32_serde() {
        let pod_u32: PodU32 = u32::MAX.into();

        let serialized = serde_json::to_string(&pod_u32).unwrap();
        assert_eq!(&serialized, "4294967295");

        let deserialized = serde_json::from_str::<PodU32>(&serialized).unwrap();
        assert_eq!(pod_u32, deserialized);
    }

    #[test]
    fn test_pod_u64() {
        assert!(pod_from_bytes::<PodU64>(&[]).is_err());
        assert_eq!(
            1u64,
            u64::from(*pod_from_bytes::<PodU64>(&[1, 0, 0, 0, 0, 0, 0, 0]).unwrap())
        );
    }

    #[cfg(feature = "serde-traits")]
    #[test]
    fn test_pod_u64_serde() {
        let pod_u64: PodU64 = u64::MAX.into();

        let serialized = serde_json::to_string(&pod_u64).unwrap();
        assert_eq!(&serialized, "18446744073709551615");

        let deserialized = serde_json::from_str::<PodU64>(&serialized).unwrap();
        assert_eq!(pod_u64, deserialized);
    }

    #[test]
    fn test_pod_i64() {
        assert!(pod_from_bytes::<PodI64>(&[]).is_err());
        assert_eq!(
            -1i64,
            i64::from(
                *pod_from_bytes::<PodI64>(&[255, 255, 255, 255, 255, 255, 255, 255]).unwrap()
            )
        );
    }

    #[cfg(feature = "serde-traits")]
    #[test]
    fn test_pod_i64_serde() {
        let pod_i64: PodI64 = i64::MAX.into();

        let serialized = serde_json::to_string(&pod_i64).unwrap();
        assert_eq!(&serialized, "9223372036854775807");

        let deserialized = serde_json::from_str::<PodI64>(&serialized).unwrap();
        assert_eq!(pod_i64, deserialized);
    }

    #[test]
    fn test_pod_u128() {
        assert!(pod_from_bytes::<PodU128>(&[]).is_err());
        assert_eq!(
            1u128,
            u128::from(
                *pod_from_bytes::<PodU128>(&[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
                    .unwrap()
            )
        );
    }

    #[cfg(feature = "serde-traits")]
    #[test]
    fn test_pod_u128_serde() {
        let pod_u128: PodU128 = u128::MAX.into();

        let serialized = serde_json::to_string(&pod_u128).unwrap();
        assert_eq!(&serialized, "340282366920938463463374607431768211455");

        let deserialized = serde_json::from_str::<PodU128>(&serialized).unwrap();
        assert_eq!(pod_u128, deserialized);
    }

    #[cfg(feature = "wincode")]
    mod wincode_tests {
        use {super::*, test_case::test_case};

        #[test_case(PodBool::from_bool(true))]
        #[test_case(PodBool::from_bool(false))]
        #[test_case(PodU16::from_primitive(u16::MAX))]
        #[test_case(PodI16::from_primitive(i16::MIN))]
        #[test_case(PodU32::from_primitive(u32::MAX))]
        #[test_case(PodU64::from_primitive(u64::MAX))]
        #[test_case(PodI64::from_primitive(i64::MIN))]
        #[test_case(PodU128::from_primitive(u128::MAX))]
        fn wincode_roundtrip<
            T: PartialEq
                + std::fmt::Debug
                + for<'de> wincode::SchemaRead<'de, wincode::config::DefaultConfig, Dst = T>
                + wincode::SchemaWrite<wincode::config::DefaultConfig, Src = T>,
        >(
            pod: T,
        ) {
            let bytes = wincode::serialize(&pod).unwrap();
            let deserialized: T = wincode::deserialize(&bytes).unwrap();
            assert_eq!(pod, deserialized);
        }
    }
}