1#![warn(missing_docs)]
2#![forbid(unsafe_code)]
3#![doc = include_str!("../README.md")]
4
5use std::io::Write;
6
7use facet_core::{Def, Facet, NumericType, PrimitiveType, StructKind, TextualType, Type, UserType};
8use facet_reflect::{HasFields, HeapValue, Partial, Peek, ScalarType};
9
/// Errors that can occur while serializing a value to XDR.
#[derive(Debug)]
pub enum XdrSerError {
    /// The underlying writer returned an I/O error.
    Io(std::io::Error),
    /// A byte string, string, or variable-length array has more than `u32::MAX`
    /// elements, so its length cannot be written as an XDR `unsigned int`.
    TooManyBytes,
    /// An enum discriminant is negative or does not fit in a `u32`.
    TooManyVariants,
    /// The value (or something nested in it) has no XDR encoding in this implementation.
    UnsupportedType,
}
22
23impl core::fmt::Display for XdrSerError {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 match self {
26 XdrSerError::Io(error) => write!(f, "IO error: {error}"),
27 XdrSerError::TooManyBytes => write!(f, "Too many bytes for field"),
28 XdrSerError::TooManyVariants => write!(f, "Enum variant discriminant too large"),
29 XdrSerError::UnsupportedType => write!(f, "Unsupported type"),
30 }
31 }
32}
33
34impl core::error::Error for XdrSerError {
35 fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
36 match self {
37 XdrSerError::Io(error) => Some(error),
38 _ => None,
39 }
40 }
41}
42
43pub fn to_vec<'f, F: Facet<'f>>(value: &'f F) -> Result<Vec<u8>, XdrSerError> {
45 let mut buffer = Vec::new();
46 let peek = Peek::new(value);
47 serialize_value(peek, &mut buffer)?;
48 Ok(buffer)
49}
50
51fn serialize_value<W: Write>(peek: Peek<'_, '_>, writer: &mut W) -> Result<(), XdrSerError> {
52 match (peek.shape().def, peek.shape().ty) {
53 (Def::Scalar, _) => {
54 let peek = peek.innermost_peek();
55 serialize_scalar(peek, writer)
56 }
57 (Def::List(ld), _) => {
58 if ld.t().is_type::<u8>() && peek.shape().is_type::<Vec<u8>>() {
60 let bytes = peek.get::<Vec<u8>>().unwrap();
61 serialize_bytes(bytes, writer)
62 } else {
63 let list = peek.into_list_like().unwrap();
64 let items: Vec<_> = list.iter().collect();
65 serialize_array(items, writer)
66 }
67 }
68 (Def::Array(ad), _) => {
69 if ad.t().is_type::<u8>() {
70 let bytes: Vec<u8> = peek
72 .into_list_like()
73 .unwrap()
74 .iter()
75 .map(|p| *p.get::<u8>().unwrap())
76 .collect();
77 writer.write_all(&bytes).map_err(XdrSerError::Io)?;
79 let pad_len = bytes.len() % 4;
80 if pad_len != 0 {
81 let pad = vec![0u8; 4 - pad_len];
82 writer.write_all(&pad).map_err(XdrSerError::Io)?;
83 }
84 Ok(())
85 } else {
86 let list = peek.into_list_like().unwrap();
87 let items: Vec<_> = list.iter().collect();
88 for item in items {
90 serialize_value(item, writer)?;
91 }
92 Ok(())
93 }
94 }
95 (Def::Slice(sd), _) => {
96 if sd.t().is_type::<u8>() {
97 let bytes = peek.get::<[u8]>().unwrap();
98 serialize_bytes(bytes, writer)
99 } else {
100 let list = peek.into_list_like().unwrap();
101 let items: Vec<_> = list.iter().collect();
102 serialize_array(items, writer)
103 }
104 }
105 (Def::Option(_), _) => {
106 let opt = peek.into_option().unwrap();
107 if let Some(inner) = opt.value() {
108 writer
110 .write_all(&1u32.to_be_bytes())
111 .map_err(XdrSerError::Io)?;
112 serialize_value(inner, writer)
113 } else {
114 writer
116 .write_all(&0u32.to_be_bytes())
117 .map_err(XdrSerError::Io)?;
118 Ok(())
119 }
120 }
121 (Def::Pointer(_), _) => {
122 let ptr = peek.into_pointer().unwrap();
123 if let Some(inner) = ptr.borrow_inner() {
124 serialize_value(inner, writer)
125 } else {
126 Err(XdrSerError::UnsupportedType)
127 }
128 }
129 (_, Type::User(UserType::Struct(sd))) => match sd.kind {
130 StructKind::Unit => {
131 Ok(())
133 }
134 StructKind::Tuple | StructKind::TupleStruct => {
135 let ps = peek.into_struct().unwrap();
136 for (_, field_value) in ps.fields_for_serialize() {
137 serialize_value(field_value, writer)?;
138 }
139 Ok(())
140 }
141 StructKind::Struct => {
142 let ps = peek.into_struct().unwrap();
143 for (_, field_value) in ps.fields_for_serialize() {
144 serialize_value(field_value, writer)?;
145 }
146 Ok(())
147 }
148 },
149 (_, Type::User(UserType::Enum(et))) => {
150 let pe = peek.into_enum().unwrap();
151 let variant = pe.active_variant().expect("Failed to get active variant");
152
153 let variant_index = et
155 .variants
156 .iter()
157 .position(|v| v.name == variant.name)
158 .unwrap_or(0);
159 let discriminant = variant.discriminant.unwrap_or(variant_index as i64);
160 if discriminant < 0 || discriminant > u32::MAX as i64 {
161 return Err(XdrSerError::TooManyVariants);
162 }
163
164 writer
166 .write_all(&(discriminant as u32).to_be_bytes())
167 .map_err(XdrSerError::Io)?;
168
169 for (_, field_value) in pe.fields_for_serialize() {
171 serialize_value(field_value, writer)?;
172 }
173 Ok(())
174 }
175 (_, Type::Pointer(_)) => {
176 if let Some(s) = peek.as_str() {
178 serialize_str(s, writer)
179 } else if let Some(bytes) = peek.as_bytes() {
180 serialize_bytes(bytes, writer)
181 } else {
182 let innermost = peek.innermost_peek();
183 if innermost.shape() != peek.shape() {
184 serialize_value(innermost, writer)
185 } else {
186 Err(XdrSerError::UnsupportedType)
187 }
188 }
189 }
190 _ => Err(XdrSerError::UnsupportedType),
191 }
192}
193
194fn serialize_scalar<W: Write>(peek: Peek<'_, '_>, writer: &mut W) -> Result<(), XdrSerError> {
195 match peek.scalar_type() {
196 Some(ScalarType::Unit) => Ok(()),
197 Some(ScalarType::Bool) => {
198 let v = *peek.get::<bool>().unwrap();
199 let val: u32 = if v { 1 } else { 0 };
200 writer
201 .write_all(&val.to_be_bytes())
202 .map_err(XdrSerError::Io)
203 }
204 Some(ScalarType::Char) => {
205 let c = *peek.get::<char>().unwrap();
206 writer
207 .write_all(&(c as u32).to_be_bytes())
208 .map_err(XdrSerError::Io)
209 }
210 Some(ScalarType::Str) => {
211 let s = peek.get::<str>().unwrap();
212 serialize_str(s, writer)
213 }
214 Some(ScalarType::String) => {
215 let s = peek.get::<String>().unwrap();
216 serialize_str(s, writer)
217 }
218 Some(ScalarType::CowStr) => {
219 let s = peek.get::<std::borrow::Cow<'_, str>>().unwrap();
220 serialize_str(s, writer)
221 }
222 Some(ScalarType::F32) => {
223 let v = *peek.get::<f32>().unwrap();
224 writer.write_all(&v.to_be_bytes()).map_err(XdrSerError::Io)
225 }
226 Some(ScalarType::F64) => {
227 let v = *peek.get::<f64>().unwrap();
228 writer.write_all(&v.to_be_bytes()).map_err(XdrSerError::Io)
229 }
230 Some(ScalarType::U8) => {
231 let v = *peek.get::<u8>().unwrap();
232 writer
233 .write_all(&(v as u32).to_be_bytes())
234 .map_err(XdrSerError::Io)
235 }
236 Some(ScalarType::U16) => {
237 let v = *peek.get::<u16>().unwrap();
238 writer
239 .write_all(&(v as u32).to_be_bytes())
240 .map_err(XdrSerError::Io)
241 }
242 Some(ScalarType::U32) => {
243 let v = *peek.get::<u32>().unwrap();
244 writer.write_all(&v.to_be_bytes()).map_err(XdrSerError::Io)
245 }
246 Some(ScalarType::U64) => {
247 let v = *peek.get::<u64>().unwrap();
248 writer.write_all(&v.to_be_bytes()).map_err(XdrSerError::Io)
249 }
250 Some(ScalarType::U128) => Err(XdrSerError::UnsupportedType),
251 Some(ScalarType::USize) => {
252 let v = *peek.get::<usize>().unwrap();
253 writer
254 .write_all(&(v as u64).to_be_bytes())
255 .map_err(XdrSerError::Io)
256 }
257 Some(ScalarType::I8) => {
258 let v = *peek.get::<i8>().unwrap();
259 writer
260 .write_all(&(v as i32).to_be_bytes())
261 .map_err(XdrSerError::Io)
262 }
263 Some(ScalarType::I16) => {
264 let v = *peek.get::<i16>().unwrap();
265 writer
266 .write_all(&(v as i32).to_be_bytes())
267 .map_err(XdrSerError::Io)
268 }
269 Some(ScalarType::I32) => {
270 let v = *peek.get::<i32>().unwrap();
271 writer.write_all(&v.to_be_bytes()).map_err(XdrSerError::Io)
272 }
273 Some(ScalarType::I64) => {
274 let v = *peek.get::<i64>().unwrap();
275 writer.write_all(&v.to_be_bytes()).map_err(XdrSerError::Io)
276 }
277 Some(ScalarType::I128) => Err(XdrSerError::UnsupportedType),
278 Some(ScalarType::ISize) => {
279 let v = *peek.get::<isize>().unwrap();
280 writer
281 .write_all(&(v as i64).to_be_bytes())
282 .map_err(XdrSerError::Io)
283 }
284 Some(_) | None => Err(XdrSerError::UnsupportedType),
285 }
286}
287
288fn serialize_str<W: Write>(s: &str, writer: &mut W) -> Result<(), XdrSerError> {
289 serialize_bytes(s.as_bytes(), writer)
290}
291
292fn serialize_bytes<W: Write>(bytes: &[u8], writer: &mut W) -> Result<(), XdrSerError> {
293 if bytes.len() > u32::MAX as usize {
294 return Err(XdrSerError::TooManyBytes);
295 }
296 let len = bytes.len() as u32;
297 writer
298 .write_all(&len.to_be_bytes())
299 .map_err(XdrSerError::Io)?;
300 writer.write_all(bytes).map_err(XdrSerError::Io)?;
301 let pad_len = bytes.len() % 4;
302 if pad_len != 0 {
303 let pad = vec![0u8; 4 - pad_len];
304 writer.write_all(&pad).map_err(XdrSerError::Io)?;
305 }
306 Ok(())
307}
308
309fn serialize_array<W: Write>(items: Vec<Peek<'_, '_>>, writer: &mut W) -> Result<(), XdrSerError> {
310 if items.len() > u32::MAX as usize {
311 return Err(XdrSerError::TooManyBytes);
312 }
313 writer
314 .write_all(&(items.len() as u32).to_be_bytes())
315 .map_err(XdrSerError::Io)?;
316 for item in items {
317 serialize_value(item, writer)?;
318 }
319 Ok(())
320}
321
/// Errors that can occur while deserializing XDR input.
#[derive(Debug)]
pub enum XdrDeserError {
    /// A numeric type whose size has no XDR decoding here, for example a float
    /// that is neither 32- nor 64-bit.
    UnsupportedNumericType,
    /// The target type has no XDR decoding in this implementation.
    UnsupportedType,
    /// The input ended before a complete value could be read.
    UnexpectedEof,
    /// A boolean was encoded as a word other than 0 or 1.
    InvalidBoolean {
        /// Byte offset in the input of the offending 4-byte word.
        position: usize,
    },
    /// An optional-data discriminant was a word other than 0 or 1.
    InvalidOptional {
        /// Byte offset in the input of the offending 4-byte word.
        position: usize,
    },
    /// An enum discriminant did not match any variant of the target enum.
    InvalidVariant {
        /// Byte offset in the input of the offending 4-byte word.
        position: usize,
    },
    /// String data was not valid UTF-8.
    InvalidString {
        /// Byte offset in the input reported for the invalid string data.
        position: usize,
        /// The underlying UTF-8 decoding error.
        source: core::str::Utf8Error,
    },
}
354
355impl core::fmt::Display for XdrDeserError {
356 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
357 match self {
358 XdrDeserError::UnsupportedNumericType => write!(f, "Unsupported numeric type"),
359 XdrDeserError::UnsupportedType => write!(f, "Unsupported type"),
360 XdrDeserError::UnexpectedEof => {
361 write!(f, "Unexpected end of input")
362 }
363 XdrDeserError::InvalidBoolean { position } => {
364 write!(f, "Invalid boolean at byte {position}")
365 }
366 XdrDeserError::InvalidOptional { position } => {
367 write!(f, "Invalid discriminant for optional at byte {position}")
368 }
369 XdrDeserError::InvalidVariant { position } => {
370 write!(f, "Invalid enum discriminant at byte {position}")
371 }
372 XdrDeserError::InvalidString { position, .. } => {
373 write!(f, "Invalid string at byte {position}")
374 }
375 }
376 }
377}
378
379impl core::error::Error for XdrDeserError {
380 fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
381 match self {
382 XdrDeserError::InvalidString { source, .. } => Some(source),
383 _ => None,
384 }
385 }
386}
387
/// Why a `DeserializeTask::Pop` was scheduled, i.e. which kind of frame it closes.
#[derive(PartialEq, Debug)]
enum PopReason {
    /// Closes the root value. Handling it builds the result and ends decoding.
    TopLevel,
    /// Closes a struct/enum field frame or a list/array element frame.
    ObjectOrListVal,
    /// Closes the inner value of an optional that was encoded as present.
    Some,
}
394
395#[derive(Debug)]
396enum DeserializeTask {
397 Value,
398 Field(usize),
399 ListItem,
400 Pop(PopReason),
401}
402
403struct XdrDeserializerStack<'input> {
404 input: &'input [u8],
405 pos: usize,
406 stack: Vec<DeserializeTask>,
407}
408
409impl<'input> XdrDeserializerStack<'input> {
410 fn next_u32(&mut self) -> Result<u32, XdrDeserError> {
411 assert_eq!(self.pos % 4, 0);
412 if self.input[self.pos..].len() < 4 {
413 return Err(XdrDeserError::UnexpectedEof);
414 }
415 let bytes = &self.input[self.pos..self.pos + 4];
416 self.pos += 4;
417 Ok(u32::from_be_bytes(bytes.try_into().unwrap()))
418 }
419
420 fn next_u64(&mut self) -> Result<u64, XdrDeserError> {
421 assert_eq!(self.pos % 4, 0);
422 if self.input[self.pos..].len() < 8 {
423 return Err(XdrDeserError::UnexpectedEof);
424 }
425 let bytes = &self.input[self.pos..self.pos + 8];
426 self.pos += 8;
427 Ok(u64::from_be_bytes(bytes.try_into().unwrap()))
428 }
429
430 fn next_data(&mut self, expected_len: Option<u32>) -> Result<&'input [u8], XdrDeserError> {
431 let len = self.next_u32()? as usize;
432 if let Some(expected_len) = expected_len {
433 assert_eq!(len, expected_len as usize);
434 }
435 self.pos += len;
436 let pad_len = len % 4;
437 let data = &self.input[self.pos - len..self.pos];
438 if pad_len != 0 {
439 self.pos += 4 - pad_len;
440 }
441 Ok(data)
442 }
443
444 fn next<'f>(&mut self, wip: Partial<'f>) -> Result<Partial<'f>, XdrDeserError> {
445 match (wip.shape().def, wip.shape().ty) {
446 (Def::Scalar, Type::Primitive(PrimitiveType::Numeric(numeric_type))) => {
447 let size = wip.shape().layout.sized_layout().unwrap().size();
448 match numeric_type {
449 NumericType::Integer { signed: false } => match size {
450 1 => {
451 let value = self.next_u32()? as u8;
452 let wip = wip.set(value).unwrap();
453 Ok(wip)
454 }
455 2 => {
456 let value = self.next_u32()? as u16;
457 let wip = wip.set(value).unwrap();
458 Ok(wip)
459 }
460 4 => {
461 let value = self.next_u32()?;
462 let wip = wip.set(value).unwrap();
463 Ok(wip)
464 }
465 8 => {
466 let value = self.next_u64()?;
467 let wip = wip.set(value).unwrap();
468 Ok(wip)
469 }
470 _ => {
471 let value = self.next_u64()? as usize;
473 let wip = wip.set(value).unwrap();
474 Ok(wip)
475 }
476 },
477 NumericType::Integer { signed: true } => match size {
478 1 => {
479 let value = self.next_u32()? as i8;
480 let wip = wip.set(value).unwrap();
481 Ok(wip)
482 }
483 2 => {
484 let value = self.next_u32()? as i16;
485 let wip = wip.set(value).unwrap();
486 Ok(wip)
487 }
488 4 => {
489 let value = self.next_u32()? as i32;
490 let wip = wip.set(value).unwrap();
491 Ok(wip)
492 }
493 8 => {
494 let value = self.next_u64()? as i64;
495 let wip = wip.set(value).unwrap();
496 Ok(wip)
497 }
498 _ => {
499 let value = self.next_u64()? as isize;
501 let wip = wip.set(value).unwrap();
502 Ok(wip)
503 }
504 },
505 NumericType::Float => match size {
506 4 => {
507 let bits = self.next_u32()?;
508 let float = f32::from_bits(bits);
509 let wip = wip.set(float).unwrap();
510 Ok(wip)
511 }
512 8 => {
513 let bits = self.next_u64()?;
514 let float = f64::from_bits(bits);
515 let wip = wip.set(float).unwrap();
516 Ok(wip)
517 }
518 _ => Err(XdrDeserError::UnsupportedNumericType),
519 },
520 }
521 }
522 (Def::Scalar, Type::Primitive(PrimitiveType::Textual(TextualType::Str))) => {
523 let string = core::str::from_utf8(self.next_data(None)?).map_err(|e| {
524 XdrDeserError::InvalidString {
525 position: self.pos - 1,
526 source: e,
527 }
528 })?;
529 let wip = wip.set(string.to_owned()).unwrap();
530 Ok(wip)
531 }
532 (Def::Scalar, Type::Primitive(PrimitiveType::Boolean)) => match self.next_u32()? {
533 0 => {
534 let wip = wip.set(false).unwrap();
535 Ok(wip)
536 }
537 1 => {
538 let wip = wip.set(true).unwrap();
539 Ok(wip)
540 }
541 _ => Err(XdrDeserError::InvalidBoolean {
542 position: self.pos - 4,
543 }),
544 },
545 (Def::Scalar, Type::Primitive(PrimitiveType::Textual(TextualType::Char))) => {
546 let value = self.next_u32()?;
547 let wip = wip.set(char::from_u32(value).unwrap()).unwrap();
548 Ok(wip)
549 }
550 (Def::Scalar, _) => {
551 let string = core::str::from_utf8(self.next_data(None)?).map_err(|e| {
553 XdrDeserError::InvalidString {
554 position: self.pos - 1,
555 source: e,
556 }
557 })?;
558 let wip = wip.set(string.to_owned()).unwrap();
559 Ok(wip)
560 }
561 (Def::List(ld), _) => {
562 if ld.t().is_type::<u8>() {
563 let data = self.next_data(None)?;
564 let wip = wip.set(data.to_vec()).unwrap();
565 Ok(wip)
566 } else {
567 let len = self.next_u32()?;
568 let wip = wip.begin_list().unwrap();
569 if len == 0 {
570 Ok(wip)
571 } else {
572 for _ in 0..len {
573 self.stack.push(DeserializeTask::ListItem);
574 }
575 Ok(wip)
576 }
577 }
578 }
579 (Def::Array(ad), _) => {
580 let len = ad.n;
581 if ad.t().is_type::<u8>() {
582 self.pos += len;
583 let pad_len = len % 4;
584 let mut wip = wip;
585 for byte in &self.input[self.pos - len..self.pos] {
586 wip = wip.begin_list_item().unwrap();
587 wip = wip.set(*byte).unwrap();
588 wip = wip.end().unwrap();
589 }
590 if pad_len != 0 {
591 self.pos += 4 - pad_len;
592 }
593 Ok(wip)
594 } else {
595 for _ in 0..len {
596 self.stack.push(DeserializeTask::ListItem);
597 }
598 Ok(wip)
599 }
600 }
601 (Def::Slice(sd), _) => {
602 if sd.t().is_type::<u8>() {
603 let data = self.next_data(None)?;
604 let wip = wip.set(data.to_vec()).unwrap();
605 Ok(wip)
606 } else {
607 let len = self.next_u32()?;
608 for _ in 0..len {
609 self.stack.push(DeserializeTask::ListItem);
610 }
611 Ok(wip)
612 }
613 }
614 (Def::Option(_), _) => match self.next_u32()? {
615 0 => {
616 let wip = wip.set_default().unwrap();
617 Ok(wip)
618 }
619 1 => {
620 self.stack.push(DeserializeTask::Pop(PopReason::Some));
621 self.stack.push(DeserializeTask::Value);
622 let wip = wip.select_variant(1).unwrap();
623 Ok(wip)
624 }
625 _ => Err(XdrDeserError::InvalidOptional {
626 position: self.pos - 4,
627 }),
628 },
629 (_, Type::User(ut)) => match ut {
630 UserType::Struct(st) => {
631 if st.kind == StructKind::Tuple {
632 for _field in st.fields.iter() {
634 self.stack.push(DeserializeTask::ListItem);
635 }
636 Ok(wip)
637 } else {
638 for (index, _field) in st.fields.iter().enumerate().rev() {
640 if !wip.is_field_set(index).unwrap() {
641 self.stack.push(DeserializeTask::Field(index));
642 }
643 }
644 Ok(wip)
645 }
646 }
647 UserType::Enum(et) => {
648 let discriminant = self.next_u32()?;
649 if let Some(variant) = et
650 .variants
651 .iter()
652 .find(|v| v.discriminant == Some(discriminant as i64))
653 .or(et.variants.get(discriminant as usize))
654 {
655 for (index, _field) in variant.data.fields.iter().enumerate().rev() {
656 self.stack.push(DeserializeTask::Field(index));
657 }
658 let wip = wip.select_variant(discriminant as i64).unwrap();
659 Ok(wip)
660 } else {
661 Err(XdrDeserError::InvalidVariant {
662 position: self.pos - 4,
663 })
664 }
665 }
666 _ => Err(XdrDeserError::UnsupportedType),
667 },
668 _ => Err(XdrDeserError::UnsupportedType),
669 }
670 }
671}
672
673pub fn deserialize_wip<'facet>(
675 input: &[u8],
676 mut wip: Partial<'facet>,
677) -> Result<HeapValue<'facet>, XdrDeserError> {
678 let mut runner = XdrDeserializerStack {
679 input,
680 pos: 0,
681 stack: vec![
682 DeserializeTask::Pop(PopReason::TopLevel),
683 DeserializeTask::Value,
684 ],
685 };
686
687 loop {
688 match runner.stack.pop() {
692 Some(DeserializeTask::Pop(reason)) => {
693 if reason == PopReason::TopLevel {
694 return Ok(wip.build().unwrap());
695 } else {
696 wip = wip.end().unwrap();
697 }
698 }
699 Some(DeserializeTask::Value) => {
700 wip = runner.next(wip)?;
701 }
702 Some(DeserializeTask::Field(index)) => {
703 runner
704 .stack
705 .push(DeserializeTask::Pop(PopReason::ObjectOrListVal));
706 runner.stack.push(DeserializeTask::Value);
707 wip = wip.begin_nth_field(index).unwrap();
708 }
709 Some(DeserializeTask::ListItem) => {
710 runner
711 .stack
712 .push(DeserializeTask::Pop(PopReason::ObjectOrListVal));
713 runner.stack.push(DeserializeTask::Value);
714 wip = wip.begin_list_item().unwrap();
715 }
716 None => unreachable!("Instruction stack is empty"),
717 }
718 }
719}
720
721pub fn deserialize<'f, F: facet_core::Facet<'f>>(input: &[u8]) -> Result<F, XdrDeserError> {
723 let v = deserialize_wip(input, Partial::alloc_shape(F::SHAPE).unwrap())?;
724 let f: F = v.materialize().unwrap();
725 Ok(f)
726}