facet_xdr/
lib.rs

1#![warn(missing_docs)]
2#![forbid(unsafe_code)]
3#![doc = include_str!("../README.md")]
4
5use std::io::Write;
6
7use facet_core::{
8    Def, Facet, IntegerSize, NumberBits, ScalarAffinity, Signedness, StructKind, Type, UserType,
9};
10use facet_reflect::{HeapValue, Partial, Peek};
11use facet_serialize::{Serializer, serialize_iterative};
12
/// Errors that can occur while encoding a value as XDR bytes.
#[derive(Debug)]
pub enum XdrSerError {
    /// Writing the encoded bytes to the output failed.
    Io(std::io::Error),
    /// A byte string, string, or array length does not fit in the 32-bit XDR length prefix.
    TooManyBytes,
    /// An enum discriminant does not fit in the 32-bit XDR discriminant.
    TooManyVariants,
    /// The value contains a type this encoder cannot represent in XDR.
    UnsupportedType,
}
25
26impl core::fmt::Display for XdrSerError {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        match self {
29            XdrSerError::Io(error) => write!(f, "IO error: {}", error),
30            XdrSerError::TooManyBytes => write!(f, "Too many bytes for field"),
31            XdrSerError::TooManyVariants => write!(f, "Enum variant discriminant too large"),
32            XdrSerError::UnsupportedType => write!(f, "Unsupported type"),
33        }
34    }
35}
36
37impl core::error::Error for XdrSerError {
38    fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
39        match self {
40            XdrSerError::Io(error) => Some(error),
41            _ => None,
42        }
43    }
44}
45
46/// Serialize any Facet type to XDR bytes
47pub fn to_vec<'f, F: Facet<'f>>(value: &'f F) -> Result<Vec<u8>, XdrSerError> {
48    let mut buffer = Vec::new();
49    let peek = Peek::new(value);
50    let mut serializer = XdrSerializer {
51        writer: &mut buffer,
52    };
53    serialize_iterative(peek, &mut serializer)?;
54    Ok(buffer)
55}
56
/// XDR encoder state: the sink that encoded bytes are appended to.
struct XdrSerializer<'w, W: Write> {
    /// Destination for the encoded output.
    writer: &'w mut W,
}
60
61impl<'shape, W: Write> Serializer<'shape> for XdrSerializer<'_, W> {
62    type Error = XdrSerError;
63
64    fn serialize_u32(&mut self, value: u32) -> Result<(), Self::Error> {
65        self.writer
66            .write_all(&value.to_be_bytes())
67            .map_err(Self::Error::Io)
68    }
69
70    fn serialize_u64(&mut self, value: u64) -> Result<(), Self::Error> {
71        self.writer
72            .write_all(&value.to_be_bytes())
73            .map_err(Self::Error::Io)
74    }
75
76    fn serialize_u128(&mut self, _value: u128) -> Result<(), Self::Error> {
77        Err(Self::Error::UnsupportedType)
78    }
79
80    fn serialize_i32(&mut self, value: i32) -> Result<(), Self::Error> {
81        self.writer
82            .write_all(&value.to_be_bytes())
83            .map_err(Self::Error::Io)
84    }
85
86    fn serialize_i64(&mut self, value: i64) -> Result<(), Self::Error> {
87        self.writer
88            .write_all(&value.to_be_bytes())
89            .map_err(Self::Error::Io)
90    }
91
92    fn serialize_i128(&mut self, _value: i128) -> Result<(), Self::Error> {
93        Err(Self::Error::UnsupportedType)
94    }
95
96    fn serialize_f32(&mut self, value: f32) -> Result<(), Self::Error> {
97        self.writer
98            .write_all(&value.to_be_bytes())
99            .map_err(Self::Error::Io)
100    }
101
102    fn serialize_f64(&mut self, value: f64) -> Result<(), Self::Error> {
103        self.writer
104            .write_all(&value.to_be_bytes())
105            .map_err(Self::Error::Io)
106    }
107
108    fn serialize_bool(&mut self, value: bool) -> Result<(), Self::Error> {
109        if value {
110            self.writer.write_all(&1u32.to_be_bytes())
111        } else {
112            self.writer.write_all(&0u32.to_be_bytes())
113        }
114        .map_err(Self::Error::Io)
115    }
116
117    fn serialize_char(&mut self, value: char) -> Result<(), Self::Error> {
118        self.serialize_u32(value as u32)
119    }
120
121    fn serialize_str(&mut self, value: &str) -> Result<(), Self::Error> {
122        let bytes = value.as_bytes();
123        self.serialize_bytes(bytes)
124    }
125
126    fn serialize_bytes(&mut self, value: &[u8]) -> Result<(), Self::Error> {
127        if value.len() > u32::MAX as usize {
128            return Err(Self::Error::TooManyBytes);
129        }
130        let len = value.len() as u32;
131        self.writer
132            .write_all(&len.to_be_bytes())
133            .map_err(Self::Error::Io)?;
134        let pad_len = value.len() % 4;
135        self.writer.write_all(value).map_err(Self::Error::Io)?;
136        if pad_len != 0 {
137            let pad = vec![0u8; 4 - pad_len];
138            self.writer.write_all(&pad).map_err(Self::Error::Io)?;
139        }
140        Ok(())
141    }
142
143    fn serialize_none(&mut self) -> Result<(), Self::Error> {
144        Ok(())
145    }
146
147    fn serialize_unit(&mut self) -> Result<(), Self::Error> {
148        Ok(())
149    }
150
151    fn serialize_unit_variant(
152        &mut self,
153        _variant_index: usize,
154        _variant_name: &'shape str,
155    ) -> Result<(), Self::Error> {
156        Ok(())
157    }
158
159    fn start_object(&mut self, _len: Option<usize>) -> Result<(), Self::Error> {
160        Ok(())
161    }
162
163    fn serialize_field_name(&mut self, _name: &'shape str) -> Result<(), Self::Error> {
164        Ok(())
165    }
166
167    fn start_array(&mut self, len: Option<usize>) -> Result<(), Self::Error> {
168        if let Some(len) = len {
169            if len > u32::MAX as usize {
170                return Err(Self::Error::TooManyBytes);
171            }
172            self.writer
173                .write_all(&(len as u32).to_be_bytes())
174                .map_err(Self::Error::Io)
175        } else {
176            panic!("array length missing");
177        }
178    }
179
180    fn start_map(&mut self, _len: Option<usize>) -> Result<(), Self::Error> {
181        Ok(())
182    }
183
184    fn start_enum_variant(&mut self, discriminant: u64) -> Result<(), Self::Error> {
185        if discriminant > u32::MAX as u64 {
186            return Err(Self::Error::TooManyVariants);
187        }
188        self.writer
189            .write_all(&(discriminant as u32).to_be_bytes())
190            .map_err(Self::Error::Io)
191    }
192}
193
/// Errors that can occur while decoding XDR bytes into a value.
#[derive(Debug)]
pub enum XdrDeserError {
    /// The target has a numeric type this decoder does not handle (integer
    /// widths other than 8/16/32/64 bits or pointer-sized, or floats that are
    /// not 32 or 64 bits wide).
    UnsupportedNumericType,
    /// The target has a type this decoder does not handle.
    UnsupportedType,
    /// The input ended before a complete value could be read.
    UnexpectedEof,
    /// A boolean was encoded as something other than 0 or 1.
    InvalidBoolean {
        /// Position of this error in the input, in bytes
        position: usize,
    },
    /// An optional's discriminant was something other than 0 or 1.
    InvalidOptional {
        /// Position of this error in the input, in bytes
        position: usize,
    },
    /// An enum discriminant did not match any variant of the target type.
    InvalidVariant {
        /// Position of this error in the input, in bytes
        position: usize,
    },
    /// String data was not valid UTF-8.
    InvalidString {
        /// Position of this error in the input, in bytes
        position: usize,
        /// The UTF-8 validation failure
        source: core::str::Utf8Error,
    },
}
226
227impl core::fmt::Display for XdrDeserError {
228    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229        match self {
230            XdrDeserError::UnsupportedNumericType => write!(f, "Unsupported numeric type"),
231            XdrDeserError::UnsupportedType => write!(f, "Unsupported type"),
232            XdrDeserError::UnexpectedEof => {
233                write!(f, "Unexpected end of input")
234            }
235            XdrDeserError::InvalidBoolean { position } => {
236                write!(f, "Invalid boolean at byte {}", position)
237            }
238            XdrDeserError::InvalidOptional { position } => {
239                write!(f, "Invalid discriminant for optional at byte {}", position)
240            }
241            XdrDeserError::InvalidVariant { position } => {
242                write!(f, "Invalid enum discriminant at byte {}", position)
243            }
244            XdrDeserError::InvalidString { position, .. } => {
245                write!(f, "Invalid string at byte {}", position)
246            }
247        }
248    }
249}
250
251impl core::error::Error for XdrDeserError {
252    fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
253        match self {
254            XdrDeserError::InvalidString { source, .. } => Some(source),
255            _ => None,
256        }
257    }
258}
259
/// Why a frame is being closed when a `Pop` task runs.
#[derive(Debug, PartialEq)]
enum PopReason {
    /// The root value is complete; it is built and returned.
    TopLevel,
    /// A struct/variant field or a list item is complete; its frame is ended.
    ObjectOrListVal,
    /// The payload of a `Some` is complete; its frame is ended.
    Some,
}
266
267#[derive(Debug)]
268enum DeserializeTask {
269    Value,
270    Field(usize),
271    ListItem,
272    Pop(PopReason),
273}
274
275struct XdrDeserializerStack<'input> {
276    input: &'input [u8],
277    pos: usize,
278    stack: Vec<DeserializeTask>,
279}
280
281impl<'shape, 'input> XdrDeserializerStack<'input> {
282    fn next_u32(&mut self) -> Result<u32, XdrDeserError> {
283        assert_eq!(self.pos % 4, 0);
284        if self.input[self.pos..].len() < 4 {
285            return Err(XdrDeserError::UnexpectedEof);
286        }
287        let bytes = &self.input[self.pos..self.pos + 4];
288        self.pos += 4;
289        Ok(u32::from_be_bytes(bytes.try_into().unwrap()))
290    }
291
292    fn next_u64(&mut self) -> Result<u64, XdrDeserError> {
293        assert_eq!(self.pos % 4, 0);
294        if self.input[self.pos..].len() < 8 {
295            return Err(XdrDeserError::UnexpectedEof);
296        }
297        let bytes = &self.input[self.pos..self.pos + 8];
298        self.pos += 8;
299        Ok(u64::from_be_bytes(bytes.try_into().unwrap()))
300    }
301
302    fn next_data(&mut self, expected_len: Option<u32>) -> Result<&'input [u8], XdrDeserError> {
303        let len = self.next_u32()? as usize;
304        if let Some(expected_len) = expected_len {
305            assert_eq!(len, expected_len as usize);
306        }
307        self.pos += len;
308        let pad_len = len % 4;
309        let data = &self.input[self.pos - len..self.pos];
310        if pad_len != 0 {
311            self.pos += 4 - pad_len;
312        }
313        Ok(data)
314    }
315
316    fn next<'f>(
317        &mut self,
318        mut wip: Partial<'f, 'shape>,
319    ) -> Result<Partial<'f, 'shape>, XdrDeserError> {
320        match (wip.shape().def, wip.shape().ty) {
321            (Def::Scalar(sd), _) => match sd.affinity {
322                ScalarAffinity::Number(na) => match na.bits {
323                    NumberBits::Integer { size, sign } => match (size, sign) {
324                        (IntegerSize::Fixed(8), Signedness::Unsigned) => {
325                            let value = self.next_u32()? as u8;
326                            wip.set(value).unwrap();
327                            Ok(wip)
328                        }
329                        (IntegerSize::Fixed(16), Signedness::Unsigned) => {
330                            let value = self.next_u32()? as u16;
331                            wip.set(value).unwrap();
332                            Ok(wip)
333                        }
334                        (IntegerSize::Fixed(32), Signedness::Unsigned) => {
335                            let value = self.next_u32()?;
336                            wip.set(value).unwrap();
337                            Ok(wip)
338                        }
339                        (IntegerSize::Fixed(64), Signedness::Unsigned) => {
340                            let value = self.next_u64()?;
341                            wip.set(value).unwrap();
342                            Ok(wip)
343                        }
344                        (IntegerSize::Fixed(8), Signedness::Signed) => {
345                            let value = self.next_u32()? as i8;
346                            wip.set(value).unwrap();
347                            Ok(wip)
348                        }
349                        (IntegerSize::Fixed(16), Signedness::Signed) => {
350                            let value = self.next_u32()? as i16;
351                            wip.set(value).unwrap();
352                            Ok(wip)
353                        }
354                        (IntegerSize::Fixed(32), Signedness::Signed) => {
355                            let value = self.next_u32()? as i32;
356                            wip.set(value).unwrap();
357                            Ok(wip)
358                        }
359                        (IntegerSize::Fixed(64), Signedness::Signed) => {
360                            let value = self.next_u64()? as i64;
361                            wip.set(value).unwrap();
362                            Ok(wip)
363                        }
364                        (IntegerSize::PointerSized, Signedness::Unsigned) => {
365                            // Handle usize - use 64-bit on most platforms
366                            let value = self.next_u64()? as usize;
367                            wip.set(value).unwrap();
368                            Ok(wip)
369                        }
370                        (IntegerSize::PointerSized, Signedness::Signed) => {
371                            // Handle isize - use 64-bit on most platforms
372                            let value = self.next_u64()? as isize;
373                            wip.set(value).unwrap();
374                            Ok(wip)
375                        }
376                        _ => Err(XdrDeserError::UnsupportedNumericType),
377                    },
378                    NumberBits::Float {
379                        sign_bits,
380                        exponent_bits,
381                        mantissa_bits,
382                        ..
383                    } => {
384                        let bits = sign_bits + exponent_bits + mantissa_bits;
385                        if bits == 32 {
386                            let bits = self.next_u32()?;
387                            let float = f32::from_bits(bits);
388                            wip.set(float).unwrap();
389                            Ok(wip)
390                        } else if bits == 64 {
391                            let bits = self.next_u64()?;
392                            let float = f64::from_bits(bits);
393                            wip.set(float).unwrap();
394                            Ok(wip)
395                        } else {
396                            Err(XdrDeserError::UnsupportedNumericType)
397                        }
398                    }
399                    _ => Err(XdrDeserError::UnsupportedNumericType),
400                },
401                ScalarAffinity::String(_) => {
402                    let string = core::str::from_utf8(self.next_data(None)?).map_err(|e| {
403                        XdrDeserError::InvalidString {
404                            position: self.pos - 1,
405                            source: e,
406                        }
407                    })?;
408                    wip.set(string.to_owned()).unwrap();
409                    Ok(wip)
410                }
411                ScalarAffinity::Boolean(_) => match self.next_u32()? {
412                    0 => {
413                        wip.set(false).unwrap();
414                        Ok(wip)
415                    }
416                    1 => {
417                        wip.set(true).unwrap();
418                        Ok(wip)
419                    }
420                    _ => Err(XdrDeserError::InvalidBoolean {
421                        position: self.pos - 4,
422                    }),
423                },
424                ScalarAffinity::Char(_) => {
425                    let value = self.next_u32()?;
426                    wip.set(char::from_u32(value).unwrap()).unwrap();
427                    Ok(wip)
428                }
429                _ => Err(XdrDeserError::UnsupportedType),
430            },
431            (Def::List(ld), _) => {
432                if ld.t().is_type::<u8>() {
433                    let data = self.next_data(None)?;
434                    wip.set(data.to_vec()).unwrap();
435                    Ok(wip)
436                } else {
437                    let len = self.next_u32()?;
438                    wip.begin_list().unwrap();
439                    if len == 0 {
440                        Ok(wip)
441                    } else {
442                        for _ in 0..len {
443                            self.stack.push(DeserializeTask::ListItem);
444                        }
445                        Ok(wip)
446                    }
447                }
448            }
449            (Def::Array(ad), _) => {
450                let len = ad.n;
451                if ad.t().is_type::<u8>() {
452                    self.pos += len;
453                    let pad_len = len % 4;
454                    for byte in &self.input[self.pos - len..self.pos] {
455                        wip.begin_list_item().unwrap();
456                        wip.set(*byte).unwrap();
457                        wip.end().unwrap();
458                    }
459                    if pad_len != 0 {
460                        self.pos += 4 - pad_len;
461                    }
462                    Ok(wip)
463                } else {
464                    for _ in 0..len {
465                        self.stack.push(DeserializeTask::ListItem);
466                    }
467                    Ok(wip)
468                }
469            }
470            (Def::Slice(sd), _) => {
471                if sd.t().is_type::<u8>() {
472                    let data = self.next_data(None)?;
473                    wip.set(data.to_vec()).unwrap();
474                    Ok(wip)
475                } else {
476                    let len = self.next_u32()?;
477                    for _ in 0..len {
478                        self.stack.push(DeserializeTask::ListItem);
479                    }
480                    Ok(wip)
481                }
482            }
483            (Def::Option(_), _) => match self.next_u32()? {
484                0 => {
485                    wip.set_default().unwrap();
486                    Ok(wip)
487                }
488                1 => {
489                    self.stack.push(DeserializeTask::Pop(PopReason::Some));
490                    self.stack.push(DeserializeTask::Value);
491                    wip.select_variant(1).unwrap();
492                    Ok(wip)
493                }
494                _ => Err(XdrDeserError::InvalidOptional {
495                    position: self.pos - 4,
496                }),
497            },
498            (_, Type::User(ut)) => match ut {
499                UserType::Struct(st) => {
500                    if st.kind == StructKind::Tuple {
501                        // Handle tuple structs
502                        for _field in st.fields.iter() {
503                            self.stack.push(DeserializeTask::ListItem);
504                        }
505                        Ok(wip)
506                    } else {
507                        // Handle regular structs
508                        for (index, _field) in st.fields.iter().enumerate().rev() {
509                            if !wip.is_field_set(index).unwrap() {
510                                self.stack.push(DeserializeTask::Field(index));
511                            }
512                        }
513                        Ok(wip)
514                    }
515                }
516                UserType::Enum(et) => {
517                    let discriminant = self.next_u32()?;
518                    if let Some(variant) = et
519                        .variants
520                        .iter()
521                        .find(|v| v.discriminant == Some(discriminant as i64))
522                        .or(et.variants.get(discriminant as usize))
523                    {
524                        for (index, _field) in variant.data.fields.iter().enumerate().rev() {
525                            self.stack.push(DeserializeTask::Field(index));
526                        }
527                        wip.select_variant(discriminant as i64).unwrap();
528                        Ok(wip)
529                    } else {
530                        Err(XdrDeserError::InvalidVariant {
531                            position: self.pos - 4,
532                        })
533                    }
534                }
535                _ => Err(XdrDeserError::UnsupportedType),
536            },
537            _ => Err(XdrDeserError::UnsupportedType),
538        }
539    }
540}
541
542/// Deserialize an XDR slice given some some [`Partial`] into a [`HeapValue`]
543pub fn deserialize_wip<'facet, 'shape>(
544    input: &[u8],
545    mut wip: Partial<'facet, 'shape>,
546) -> Result<HeapValue<'facet, 'shape>, XdrDeserError> {
547    let mut runner = XdrDeserializerStack {
548        input,
549        pos: 0,
550        stack: vec![
551            DeserializeTask::Pop(PopReason::TopLevel),
552            DeserializeTask::Value,
553        ],
554    };
555
556    loop {
557        // We no longer have access to frames_count
558        // The frame count assertion has been removed as it's an internal implementation detail
559
560        match runner.stack.pop() {
561            Some(DeserializeTask::Pop(reason)) => {
562                if reason == PopReason::TopLevel {
563                    return Ok(wip.build().unwrap());
564                } else {
565                    wip.end().unwrap();
566                }
567            }
568            Some(DeserializeTask::Value) => {
569                wip = runner.next(wip)?;
570            }
571            Some(DeserializeTask::Field(index)) => {
572                runner
573                    .stack
574                    .push(DeserializeTask::Pop(PopReason::ObjectOrListVal));
575                runner.stack.push(DeserializeTask::Value);
576                wip.begin_nth_field(index).unwrap();
577            }
578            Some(DeserializeTask::ListItem) => {
579                runner
580                    .stack
581                    .push(DeserializeTask::Pop(PopReason::ObjectOrListVal));
582                runner.stack.push(DeserializeTask::Value);
583                wip.begin_list_item().unwrap();
584            }
585            None => unreachable!("Instruction stack is empty"),
586        }
587    }
588}
589
590/// Deserialize a slice of XDR bytes into any Facet type
591pub fn deserialize<'f, F: facet_core::Facet<'f>>(input: &[u8]) -> Result<F, XdrDeserError> {
592    let v = deserialize_wip(input, Partial::alloc_shape(F::SHAPE).unwrap())?;
593    let f: F = v.materialize().unwrap();
594    Ok(f)
595}