1use crate::Nullable;
10#[cfg(feature = "bytemuck")]
11use bytemuck::{Pod, Zeroable};
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14#[cfg(feature = "wincode")]
15use wincode::{SchemaRead, SchemaWrite};
16#[cfg(feature = "borsh")]
17use {
18 alloc::format,
19 borsh::{BorshDeserialize, BorshSchema, BorshSerialize},
20};
21
/// A `T` whose [`Nullable::NONE`] value stands for "no value", giving an
/// `Option`-like API without any extra storage.
///
/// `#[repr(transparent)]` keeps the in-memory layout identical to `T`. With
/// the `borsh` and `wincode` derives the wrapper is encoded exactly like the
/// inner `T` (there is no separate `Option` tag byte). Use [`MaybeNull::get`]
/// or `Option::from` to read it as an `Option<T>`.
#[repr(transparent)]
#[cfg_attr(
    feature = "borsh",
    derive(BorshDeserialize, BorshSerialize, BorshSchema)
)]
#[cfg_attr(feature = "wincode", derive(SchemaRead, SchemaWrite))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct MaybeNull<T: Nullable>(T);
35
#[cfg(feature = "wincode")]
// SAFETY: `MaybeNull<T>` is `#[repr(transparent)]` over its single field `T`,
// so it has exactly `T`'s size, alignment and byte representation. It adds no
// validity invariant of its own: any `T` (including the `NONE` sentinel) is a
// valid `MaybeNull<T>`. The `ZeroCopy<C>` guarantees required of `T` by the
// bound below therefore carry over unchanged to the wrapper.
unsafe impl<T, C> wincode::config::ZeroCopy<C> for MaybeNull<T>
where
    C: wincode::config::ConfigCore,
    T: Nullable + wincode::config::ZeroCopy<C>,
{
}
46
47impl<T: Nullable> Default for MaybeNull<T> {
48 fn default() -> Self {
49 Self(T::NONE)
50 }
51}
52
53impl<T: Nullable> MaybeNull<T> {
54 #[inline]
56 pub fn get(self) -> Option<T> {
57 if self.0.is_none() {
58 None
59 } else {
60 Some(self.0)
61 }
62 }
63
64 #[inline]
66 pub fn as_ref(&self) -> Option<&T> {
67 if self.0.is_none() {
68 None
69 } else {
70 Some(&self.0)
71 }
72 }
73
74 #[inline]
76 pub fn as_mut(&mut self) -> Option<&mut T> {
77 if self.0.is_none() {
78 None
79 } else {
80 Some(&mut self.0)
81 }
82 }
83
84 #[inline]
86 pub fn copied(&self) -> Option<T>
87 where
88 T: Copy,
89 {
90 self.as_ref().copied()
91 }
92
93 #[inline]
95 pub fn cloned(&self) -> Option<T>
96 where
97 T: Clone,
98 {
99 self.as_ref().cloned()
100 }
101}
102
103impl<T: Nullable> From<T> for MaybeNull<T> {
104 fn from(value: T) -> Self {
105 MaybeNull(value)
106 }
107}
108
109impl<T: Nullable> From<MaybeNull<T>> for Option<T> {
110 fn from(value: MaybeNull<T>) -> Self {
111 value.get()
112 }
113}
114
115impl<T: Nullable> TryFrom<Option<T>> for MaybeNull<T> {
116 type Error = MaybeNullError;
117
118 fn try_from(value: Option<T>) -> Result<Self, Self::Error> {
119 match value {
120 Some(value) if value.is_none() => Err(MaybeNullError::NoneValueInSome),
121 Some(value) => Ok(MaybeNull(value)),
122 None => Ok(MaybeNull(T::NONE)),
123 }
124 }
125}
126
/// Errors produced when converting into a [`MaybeNull`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MaybeNullError {
    /// `Some(value)` was supplied, but `value` equals the `Nullable::NONE`
    /// sentinel that `MaybeNull` reserves to represent `None`.
    NoneValueInSome,
}
133
134impl core::fmt::Display for MaybeNullError {
135 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
136 match self {
137 Self::NoneValueInSome => {
138 write!(f, "cannot wrap None-equivalent value in Some")
139 }
140 }
141 }
142}
143
#[cfg(feature = "bytemuck")]
// SAFETY: `MaybeNull<T>` is `#[repr(transparent)]` over `T`, so its size,
// alignment and valid bit patterns are exactly those of `T`. Because `T: Pod`,
// every bit pattern is a valid `MaybeNull<T>` and there is no padding. The
// wrapper is `Copy` through its derived `Copy` impl (which applies when
// `T: Copy`, implied by `Pod`), and `'static` whenever `T` is.
unsafe impl<T: Nullable + Pod> Pod for MaybeNull<T> {}
150
#[cfg(feature = "bytemuck")]
// SAFETY: `MaybeNull<T>` is `#[repr(transparent)]` over `T`, and `T: Zeroable`
// guarantees the all-zero bit pattern is a valid `T`, hence also a valid
// `MaybeNull<T>`. Note: whether a zeroed wrapper reads as "null" depends on
// whether `T::NONE` is all zeros, which this impl does not require.
unsafe impl<T: Nullable + Zeroable> Zeroable for MaybeNull<T> {}
157
#[cfg(feature = "serde")]
impl<T> Serialize for MaybeNull<T>
where
    T: Nullable + Serialize,
{
    /// Serializes like an `Option<T>`: the `T::NONE` sentinel is emitted via
    /// `serialize_none`, any other value via `serialize_some`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match self.as_ref() {
            Some(value) => serializer.serialize_some(value),
            None => serializer.serialize_none(),
        }
    }
}
174
#[cfg(feature = "serde")]
impl<'de, T> Deserialize<'de> for MaybeNull<T>
where
    T: Nullable + Deserialize<'de>,
{
    /// Deserializes an `Option<T>` and converts it through
    /// `TryFrom<Option<T>>`.
    ///
    /// `None` becomes the `T::NONE` sentinel. `Some(value)` whose value equals
    /// the sentinel is rejected with a custom error, because that encoding is
    /// ambiguous.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let decoded: Option<T> = Deserialize::deserialize(deserializer)?;
        Self::try_from(decoded).map_err(|_| {
            <D::Error as serde::de::Error>::custom(
                "Invalid MaybeNull encoding: Some(value) cannot equal the None marker.",
            )
        })
    }
}
194
// Unit tests for `MaybeNull`, plus feature-gated encoding tests for each
// optional serialization backend.
#[cfg(test)]
mod tests {
    use super::*;

    // Test-only sentinel: `0` plays the role of `None` for `u64`.
    impl Nullable for u64 {
        const NONE: Self = 0;
    }

    #[test]
    fn test_try_from_option() {
        let some = Some(42u64);
        assert_eq!(MaybeNull::try_from(some).unwrap(), MaybeNull(42u64));

        let none: Option<u64> = None;
        assert_eq!(MaybeNull::try_from(none).unwrap(), MaybeNull::from(0u64));

        let invalid = Some(0u64);
        assert_eq!(
            MaybeNull::try_from(invalid).unwrap_err(),
            MaybeNullError::NoneValueInSome,
        );
    }

    #[test]
    fn test_option_roundtrip() {
        // Every valid `Option<u64>` must survive `Option -> MaybeNull -> Option`.
        for &option in &[None, Some(1u64), Some(u64::MAX)] {
            let maybe = MaybeNull::<u64>::try_from(option).unwrap();
            assert_eq!(Option::<u64>::from(maybe), option);
        }
    }

    #[test]
    fn test_from_maybe_null() {
        let some = MaybeNull::from(42u64);
        let none = MaybeNull::from(0u64);

        assert_eq!(Option::<u64>::from(some), Some(42));
        assert_eq!(Option::<u64>::from(none), None);
    }

    #[test]
    fn test_default() {
        let def = MaybeNull::<u64>::default();
        assert_eq!(def, MaybeNull(0u64));
        assert_eq!(def.get(), None);
    }

    #[test]
    fn test_copied() {
        let some = MaybeNull::from(42u64);
        assert_eq!(some.copied(), Some(42));

        let none = MaybeNull::from(0u64);
        assert_eq!(none.copied(), None);
    }

    #[test]
    fn test_nullable_predicates() {
        assert!(u64::NONE.is_none());
        assert!(!u64::NONE.is_some());
        assert!(8u64.is_some());
        assert!(!8u64.is_none());
    }

    #[test]
    fn test_as_ref() {
        let some = MaybeNull::from(8u64);
        assert_eq!(some.as_ref(), Some(&8u64));

        let none = MaybeNull::from(u64::NONE);
        assert_eq!(none.as_ref(), None);
    }

    #[test]
    fn test_as_mut() {
        let mut some = MaybeNull::from(3u64);
        assert!(some.as_mut().is_some());
        *some.as_mut().unwrap() = 4;
        assert_eq!(some.get(), Some(4));

        let mut none = MaybeNull::from(0u64);
        assert!(none.as_mut().is_none());
    }

    #[test]
    fn test_as_mut_can_store_none_marker() {
        // Writing the sentinel through `as_mut` turns the wrapper into "null".
        let mut value = MaybeNull::from(5u64);
        *value.as_mut().unwrap() = u64::NONE;
        assert_eq!(value.get(), None);
        assert_eq!(value, MaybeNull::default());
    }

    #[derive(Clone, Debug, PartialEq)]
    struct TestNonCopyNullable([u8; 4]);

    impl Nullable for TestNonCopyNullable {
        const NONE: Self = Self([0u8; 4]);
    }

    #[test]
    fn test_cloned_with_non_copy_nullable() {
        let some = MaybeNull::from(TestNonCopyNullable([1, 2, 3, 4]));
        assert_eq!(some.cloned(), Some(TestNonCopyNullable([1, 2, 3, 4])));

        let none = MaybeNull::from(TestNonCopyNullable::NONE);
        assert_eq!(none.cloned(), None);
    }

    #[cfg(feature = "borsh")]
    mod borsh_tests {
        use {super::*, alloc::vec};

        #[test]
        fn test_borsh_roundtrip_u64() {
            let some = MaybeNull::from(42u64);
            let none = MaybeNull::from(0u64);

            let some_bytes = borsh::to_vec(&some).unwrap();
            let none_bytes = borsh::to_vec(&none).unwrap();

            assert_eq!(some_bytes, 42u64.to_le_bytes().to_vec());
            assert_eq!(none_bytes, vec![0; 8]);
            assert_eq!(
                borsh::from_slice::<MaybeNull<u64>>(&some_bytes).unwrap(),
                some
            );
            assert_eq!(
                borsh::from_slice::<MaybeNull<u64>>(&none_bytes).unwrap(),
                none
            );
            assert!(borsh::from_slice::<MaybeNull<u64>>(&[]).is_err());
        }
    }

    #[cfg(feature = "wincode")]
    mod wincode_tests {
        use {super::*, wincode::ZeroCopy};

        #[test]
        fn test_wincode_maybe_null_roundtrip_and_size() {
            let some = MaybeNull::from(9u64);
            let none = MaybeNull::from(0u64);

            let some_bytes = wincode::serialize(&some).unwrap();
            let none_bytes = wincode::serialize(&none).unwrap();

            assert_eq!(some_bytes.len(), core::mem::size_of::<u64>());
            assert_eq!(none_bytes.len(), core::mem::size_of::<u64>());
            assert_eq!(some_bytes.as_slice(), &9u64.to_le_bytes());
            assert_eq!(none_bytes.as_slice(), &0u64.to_le_bytes());

            let some_roundtrip: MaybeNull<u64> = wincode::deserialize(&some_bytes).unwrap();
            let none_roundtrip: MaybeNull<u64> = wincode::deserialize(&none_bytes).unwrap();
            assert_eq!(some_roundtrip, some);
            assert_eq!(none_roundtrip, none);

            let some_zero_copy = MaybeNull::<u64>::from_bytes(&some_bytes).unwrap();
            let none_zero_copy = MaybeNull::<u64>::from_bytes(&none_bytes).unwrap();
            assert_eq!(some_zero_copy, &some);
            assert_eq!(none_zero_copy, &none);
        }

        #[test]
        fn test_wincode_maybe_null_rejects_truncated_input() {
            assert!(wincode::deserialize::<MaybeNull<u64>>(&[]).is_err());
            assert!(wincode::deserialize::<MaybeNull<u64>>(&[0; 7]).is_err());
        }
    }

    #[cfg(feature = "serde")]
    mod serde_tests {
        use {super::*, alloc::string::ToString};

        #[test]
        fn test_serde_u64_some() {
            let some = MaybeNull::from(7u64);
            let serialized = serde_json::to_string(&some).unwrap();
            assert_eq!(serialized, "7");
            let deserialized = serde_json::from_str::<MaybeNull<u64>>(&serialized).unwrap();
            assert_eq!(deserialized, some);
        }

        #[test]
        fn test_serde_u64_none() {
            let deserialized = serde_json::from_str::<MaybeNull<u64>>("null").unwrap();
            assert_eq!(deserialized, MaybeNull::from(0));
        }

        #[test]
        fn test_serde_u64_none_serializes_as_null() {
            // The null state must serialize as `null` and round-trip back.
            let none = MaybeNull::from(0u64);
            let serialized = serde_json::to_string(&none).unwrap();
            assert_eq!(serialized, "null");
            let deserialized = serde_json::from_str::<MaybeNull<u64>>(&serialized).unwrap();
            assert_eq!(deserialized, none);
        }

        #[test]
        fn test_serde_u64_none_marker_error_message() {
            let err = serde_json::from_str::<MaybeNull<u64>>("0").unwrap_err();
            let message = err.to_string();
            assert!(message.contains("MaybeNull encoding"));
            assert!(message.contains("None marker"));
        }

        #[test]
        fn test_serde_u64_reject_invalid_input() {
            assert!(serde_json::from_str::<MaybeNull<u64>>("\"abc\"").is_err());
            assert!(serde_json::from_str::<MaybeNull<u64>>("{}").is_err());
        }
    }

    #[cfg(feature = "bytemuck")]
    mod bytemuck_tests {
        use super::*;

        #[test]
        fn test_maybe_null_u64() {
            let some = MaybeNull::from(42u64);
            assert_eq!(some.get(), Some(42));

            let none = MaybeNull::from(0u64);
            assert_eq!(none.get(), None);

            let bytes = 42u64.to_le_bytes();
            let value: &MaybeNull<u64> = bytemuck::from_bytes(&bytes);
            assert_eq!(*value, MaybeNull::from(42u64));

            let zero_bytes = 0u64.to_le_bytes();
            let value: &MaybeNull<u64> = bytemuck::from_bytes(&zero_bytes);
            assert_eq!(*value, MaybeNull::from(0u64));
            assert_eq!(value.get(), None);
        }

        #[test]
        fn test_maybe_null_from_bytes_errors() {
            assert!(bytemuck::try_from_bytes::<MaybeNull<u64>>(&[]).is_err());
            assert!(bytemuck::try_from_bytes::<MaybeNull<u64>>(&[0; 1]).is_err());
        }
    }
}