Skip to main content

facet_cbor/
deserialize.rs

1use facet::Facet;
2use facet_core::{Def, ScalarType, StructKind, Type, UserType};
3use facet_reflect::Partial;
4
5use crate::decode;
6use crate::error::CborError;
7
8/// Deserialize a CBOR byte slice into a value of type `T`.
9pub fn from_slice<T: Facet<'static>>(bytes: &[u8]) -> Result<T, CborError> {
10    let partial =
11        Partial::alloc_owned::<T>().map_err(|e| CborError::ReflectError(e.to_string()))?;
12    let mut offset = 0;
13    let partial = deserialize_into(partial, bytes, &mut offset)?;
14    let heap_value = partial
15        .build()
16        .map_err(|e| CborError::ReflectError(e.to_string()))?;
17    heap_value
18        .materialize()
19        .map_err(|e| CborError::ReflectError(e.to_string()))
20}
21
22/// Recursively deserialize CBOR data into a Partial, dispatching on the shape.
23fn deserialize_into<'facet>(
24    partial: Partial<'facet, false>,
25    input: &[u8],
26    offset: &mut usize,
27) -> Result<Partial<'facet, false>, CborError> {
28    let shape = partial.shape();
29
30    // Unwrap transparent wrappers (newtypes, NonZero, etc.) to match serialization
31    if shape.is_transparent() {
32        let re = |e: facet_reflect::ReflectError| CborError::ReflectError(e.to_string());
33        let partial = partial.begin_inner().map_err(re)?;
34        let partial = deserialize_into(partial, input, offset)?;
35        return partial.end().map_err(re);
36    }
37
38    // Check for scalar types first
39    if let Some(scalar_type) = shape.scalar_type() {
40        return deserialize_scalar(partial, scalar_type, input, offset);
41    }
42
43    // Check def-based types (Option, List, Map, etc.) before user types,
44    // mirroring the serialization order where Def::Option is checked before UserType::Enum.
45    match shape.def {
46        Def::Option(_) => {
47            return deserialize_option(partial, input, offset);
48        }
49        Def::List(list_def) => {
50            // Special case: Vec<u8> → byte string
51            if list_def.t().is_type::<u8>() {
52                return deserialize_byte_list(partial, input, offset);
53            }
54            return deserialize_list(partial, input, offset);
55        }
56        Def::Array(array_def) => {
57            return deserialize_array(partial, array_def.n, input, offset);
58        }
59        Def::Map(_) => {
60            return deserialize_map(partial, input, offset);
61        }
62        Def::Pointer(_) => {
63            return deserialize_pointer(partial, input, offset);
64        }
65        _ => {}
66    }
67
68    // Try struct/enum (user types)
69    match shape.ty {
70        Type::User(UserType::Struct(struct_type)) => match struct_type.kind {
71            StructKind::Struct => deserialize_struct(partial, input, offset),
72            StructKind::TupleStruct | StructKind::Tuple => {
73                deserialize_tuple(partial, struct_type.fields.len(), input, offset)
74            }
75            StructKind::Unit => {
76                decode::read_null(input, offset)?;
77                Ok(partial)
78            }
79        },
80        Type::User(UserType::Enum(_)) => {
81            if shape.tag.is_some() {
82                deserialize_enum_internally_tagged(partial, input, offset)
83            } else {
84                deserialize_enum(partial, input, offset)
85            }
86        }
87        _ => Err(CborError::UnsupportedType(format!("{}", shape))),
88    }
89}
90
91fn deserialize_scalar<'facet>(
92    partial: Partial<'facet, false>,
93    scalar_type: ScalarType,
94    input: &[u8],
95    offset: &mut usize,
96) -> Result<Partial<'facet, false>, CborError> {
97    let re = |e: facet_reflect::ReflectError| CborError::ReflectError(e.to_string());
98    match scalar_type {
99        ScalarType::Unit => {
100            decode::read_null(input, offset)?;
101            partial.set(()).map_err(re)
102        }
103        ScalarType::Bool => {
104            let v = decode::read_bool(input, offset)?;
105            partial.set(v).map_err(re)
106        }
107        ScalarType::Char => {
108            let s = decode::read_text(input, offset)?;
109            let c = s
110                .chars()
111                .next()
112                .ok_or_else(|| CborError::InvalidCbor("empty text string for char".into()))?;
113            partial.set(c).map_err(re)
114        }
115        ScalarType::U8 => {
116            let v = decode::read_int_as_u64(input, offset)?;
117            partial.set(v as u8).map_err(re)
118        }
119        ScalarType::U16 => {
120            let v = decode::read_int_as_u64(input, offset)?;
121            partial.set(v as u16).map_err(re)
122        }
123        ScalarType::U32 => {
124            let v = decode::read_int_as_u64(input, offset)?;
125            partial.set(v as u32).map_err(re)
126        }
127        ScalarType::U64 => {
128            let v = decode::read_int_as_u64(input, offset)?;
129            partial.set(v).map_err(re)
130        }
131        ScalarType::USize => {
132            let v = decode::read_int_as_u64(input, offset)?;
133            partial.set(v as usize).map_err(re)
134        }
135        ScalarType::I8 => {
136            let v = decode::read_int_as_i64(input, offset)?;
137            partial.set(v as i8).map_err(re)
138        }
139        ScalarType::I16 => {
140            let v = decode::read_int_as_i64(input, offset)?;
141            partial.set(v as i16).map_err(re)
142        }
143        ScalarType::I32 => {
144            let v = decode::read_int_as_i64(input, offset)?;
145            partial.set(v as i32).map_err(re)
146        }
147        ScalarType::I64 => {
148            let v = decode::read_int_as_i64(input, offset)?;
149            partial.set(v).map_err(re)
150        }
151        ScalarType::ISize => {
152            let v = decode::read_int_as_i64(input, offset)?;
153            partial.set(v as isize).map_err(re)
154        }
155        ScalarType::F32 => {
156            let v = decode::read_f32(input, offset)?;
157            partial.set(v).map_err(re)
158        }
159        ScalarType::F64 => {
160            let v = decode::read_f64(input, offset)?;
161            partial.set(v).map_err(re)
162        }
163        ScalarType::String => {
164            let s = decode::read_text(input, offset)?;
165            partial.set(s.to_owned()).map_err(re)
166        }
167        ScalarType::Str => {
168            // &str can't be deserialized into owned Partial (would need borrowed)
169            let s = decode::read_text(input, offset)?;
170            partial.set(s.to_owned()).map_err(re)
171        }
172        ScalarType::CowStr => {
173            let s = decode::read_text(input, offset)?;
174            partial
175                .set(std::borrow::Cow::<'static, str>::Owned(s.to_owned()))
176                .map_err(re)
177        }
178        _ => Err(CborError::UnsupportedType(format!(
179            "scalar type {scalar_type:?}"
180        ))),
181    }
182}
183
184fn deserialize_option<'facet>(
185    partial: Partial<'facet, false>,
186    input: &[u8],
187    offset: &mut usize,
188) -> Result<Partial<'facet, false>, CborError> {
189    let re = |e: facet_reflect::ReflectError| CborError::ReflectError(e.to_string());
190    if decode::is_null(input, *offset) {
191        // Consume the null byte, leave Option as None
192        *offset += 1;
193        Ok(partial)
194    } else {
195        // begin_some, deserialize inner, end
196        let partial = partial.begin_some().map_err(re)?;
197        let partial = deserialize_into(partial, input, offset)?;
198        partial.end().map_err(re)
199    }
200}
201
202fn deserialize_list<'facet>(
203    partial: Partial<'facet, false>,
204    input: &[u8],
205    offset: &mut usize,
206) -> Result<Partial<'facet, false>, CborError> {
207    let re = |e: facet_reflect::ReflectError| CborError::ReflectError(e.to_string());
208    let len = decode::read_array_header(input, offset)? as usize;
209    let mut partial = partial.init_list_with_capacity(len).map_err(re)?;
210    for _ in 0..len {
211        partial = partial.begin_list_item().map_err(re)?;
212        partial = deserialize_into(partial, input, offset)?;
213        partial = partial.end().map_err(re)?;
214    }
215    Ok(partial)
216}
217
218fn deserialize_byte_list<'facet>(
219    partial: Partial<'facet, false>,
220    input: &[u8],
221    offset: &mut usize,
222) -> Result<Partial<'facet, false>, CborError> {
223    let re = |e: facet_reflect::ReflectError| CborError::ReflectError(e.to_string());
224    let bytes = decode::read_bytes(input, offset)?;
225    // Build Vec<u8> from the byte string and set it directly
226    let vec: Vec<u8> = bytes.to_vec();
227    partial.set(vec).map_err(re)
228}
229
230fn deserialize_array<'facet>(
231    partial: Partial<'facet, false>,
232    expected_len: usize,
233    input: &[u8],
234    offset: &mut usize,
235) -> Result<Partial<'facet, false>, CborError> {
236    let re = |e: facet_reflect::ReflectError| CborError::ReflectError(e.to_string());
237    let len = decode::read_array_header(input, offset)? as usize;
238    if len != expected_len {
239        return Err(CborError::TypeMismatch {
240            expected: format!("array of length {expected_len}"),
241            got: format!("array of length {len}"),
242        });
243    }
244    // Fixed-size arrays use begin_nth_field like tuples
245    let mut partial = partial;
246    for i in 0..len {
247        partial = partial.begin_nth_field(i).map_err(re)?;
248        partial = deserialize_into(partial, input, offset)?;
249        partial = partial.end().map_err(re)?;
250    }
251    Ok(partial)
252}
253
254fn deserialize_map<'facet>(
255    partial: Partial<'facet, false>,
256    input: &[u8],
257    offset: &mut usize,
258) -> Result<Partial<'facet, false>, CborError> {
259    let re = |e: facet_reflect::ReflectError| CborError::ReflectError(e.to_string());
260    let len = decode::read_map_header(input, offset)? as usize;
261    let mut partial = partial.init_map().map_err(re)?;
262    for _ in 0..len {
263        // key
264        partial = partial.begin_key().map_err(re)?;
265        partial = deserialize_into(partial, input, offset)?;
266        partial = partial.end().map_err(re)?;
267        // value
268        partial = partial.begin_value().map_err(re)?;
269        partial = deserialize_into(partial, input, offset)?;
270        partial = partial.end().map_err(re)?;
271    }
272    Ok(partial)
273}
274
275fn deserialize_struct<'facet>(
276    partial: Partial<'facet, false>,
277    input: &[u8],
278    offset: &mut usize,
279) -> Result<Partial<'facet, false>, CborError> {
280    let re = |e: facet_reflect::ReflectError| CborError::ReflectError(e.to_string());
281    let len = decode::read_map_header(input, offset)? as usize;
282    let mut partial = partial;
283    for _ in 0..len {
284        let key = decode::read_text(input, offset)?;
285        // Try to find the field; if unknown, skip the value
286        if partial.field_index(key).is_some() {
287            partial = partial.begin_field(key).map_err(re)?;
288            partial = deserialize_into(partial, input, offset)?;
289            partial = partial.end().map_err(re)?;
290        } else {
291            decode::skip_value(input, offset)?;
292        }
293    }
294    Ok(partial)
295}
296
297fn deserialize_tuple<'facet>(
298    partial: Partial<'facet, false>,
299    field_count: usize,
300    input: &[u8],
301    offset: &mut usize,
302) -> Result<Partial<'facet, false>, CborError> {
303    let re = |e: facet_reflect::ReflectError| CborError::ReflectError(e.to_string());
304    let len = decode::read_array_header(input, offset)? as usize;
305    if len != field_count {
306        return Err(CborError::TypeMismatch {
307            expected: format!("array of length {field_count}"),
308            got: format!("array of length {len}"),
309        });
310    }
311    let mut partial = partial;
312    for i in 0..len {
313        partial = partial.begin_nth_field(i).map_err(re)?;
314        partial = deserialize_into(partial, input, offset)?;
315        partial = partial.end().map_err(re)?;
316    }
317    Ok(partial)
318}
319
320fn deserialize_enum<'facet>(
321    partial: Partial<'facet, false>,
322    input: &[u8],
323    offset: &mut usize,
324) -> Result<Partial<'facet, false>, CborError> {
325    let re = |e: facet_reflect::ReflectError| CborError::ReflectError(e.to_string());
326
327    // Encoded as a map with 1 entry: variant_name → payload
328    let map_len = decode::read_map_header(input, offset)?;
329    if map_len != 1 {
330        return Err(CborError::InvalidCbor(format!(
331            "expected map with 1 entry for enum, got {map_len}"
332        )));
333    }
334
335    let variant_name = decode::read_text(input, offset)?;
336
337    // Find the variant by name and get its info before selecting
338    let (_, variant) = partial
339        .find_variant(variant_name)
340        .ok_or_else(|| CborError::InvalidCbor(format!("unknown enum variant: {variant_name}")))?;
341    let kind = variant.data.kind;
342    let field_count = variant.data.fields.len();
343
344    let mut partial = partial.select_variant_named(variant_name).map_err(re)?;
345
346    match kind {
347        StructKind::Unit => {
348            // Unit variant: payload is null
349            decode::read_null(input, offset)?;
350        }
351        StructKind::TupleStruct => {
352            if field_count == 1 {
353                // Newtype variant: payload is the single value directly
354                partial = partial.begin_nth_field(0).map_err(re)?;
355                partial = deserialize_into(partial, input, offset)?;
356                partial = partial.end().map_err(re)?;
357            } else {
358                // Tuple variant: payload is an array
359                let arr_len = decode::read_array_header(input, offset)? as usize;
360                if arr_len != field_count {
361                    return Err(CborError::TypeMismatch {
362                        expected: format!("array of length {field_count}"),
363                        got: format!("array of length {arr_len}"),
364                    });
365                }
366                for i in 0..field_count {
367                    partial = partial.begin_nth_field(i).map_err(re)?;
368                    partial = deserialize_into(partial, input, offset)?;
369                    partial = partial.end().map_err(re)?;
370                }
371            }
372        }
373        StructKind::Tuple => {
374            // Tuple variant: payload is an array
375            let arr_len = decode::read_array_header(input, offset)? as usize;
376            if arr_len != field_count {
377                return Err(CborError::TypeMismatch {
378                    expected: format!("array of length {field_count}"),
379                    got: format!("array of length {arr_len}"),
380                });
381            }
382            for i in 0..field_count {
383                partial = partial.begin_nth_field(i).map_err(re)?;
384                partial = deserialize_into(partial, input, offset)?;
385                partial = partial.end().map_err(re)?;
386            }
387        }
388        StructKind::Struct => {
389            // Struct variant: payload is a map
390            let map_len = decode::read_map_header(input, offset)? as usize;
391            for _ in 0..map_len {
392                let field_name = decode::read_text(input, offset)?;
393                if partial.field_index(field_name).is_some() {
394                    partial = partial.begin_field(field_name).map_err(re)?;
395                    partial = deserialize_into(partial, input, offset)?;
396                    partial = partial.end().map_err(re)?;
397                } else {
398                    decode::skip_value(input, offset)?;
399                }
400            }
401        }
402    }
403
404    Ok(partial)
405}
406
407/// Deserialize an internally-tagged enum.
408///
409/// When `#[facet(tag = "...")]` is set:
410/// - If the CBOR value is a text string → unit variant (the string is the variant name)
411/// - If the CBOR value is a map → read the tag field to find variant name, rest are struct fields
412fn deserialize_enum_internally_tagged<'facet>(
413    partial: Partial<'facet, false>,
414    input: &[u8],
415    offset: &mut usize,
416) -> Result<Partial<'facet, false>, CborError> {
417    let re = |e: facet_reflect::ReflectError| CborError::ReflectError(e.to_string());
418    let tag_key = partial
419        .shape()
420        .tag
421        .expect("internally-tagged enum must have tag");
422
423    if *offset >= input.len() {
424        return Err(CborError::InvalidCbor("unexpected end of input".into()));
425    }
426
427    let major = input[*offset] >> 5;
428    if major == 3 {
429        // Text string → unit variant
430        let variant_name = decode::read_text(input, offset)?;
431        let partial = partial.select_variant_named(variant_name).map_err(re)?;
432        Ok(partial)
433    } else if major == 5 {
434        // Map → struct variant with tag field
435        let map_len = decode::read_map_header(input, offset)? as usize;
436
437        // First, find the tag field to determine which variant we're deserializing.
438        // We need to scan through map entries to find the tag key.
439        // For efficiency, we expect the tag to be the first entry.
440        let mut variant_name: Option<&str> = None;
441        let mut saved_offset = *offset;
442
443        // Read the first key — it should be the tag
444        let first_key = decode::read_text(input, offset)?;
445        if first_key == tag_key {
446            variant_name = Some(decode::read_text(input, offset)?);
447        } else {
448            // Tag wasn't first; scan the whole map from the start
449            *offset = saved_offset;
450            let scan_offset = &mut saved_offset;
451            *scan_offset = *offset;
452            for _ in 0..map_len {
453                let key = decode::read_text(input, scan_offset)?;
454                if key == tag_key {
455                    variant_name = Some(decode::read_text(input, scan_offset)?);
456                    break;
457                }
458                decode::skip_value(input, scan_offset)?;
459            }
460            // Reset to after the first key we already read
461        }
462
463        let variant_name = variant_name.ok_or_else(|| {
464            CborError::InvalidCbor(format!(
465                "internally-tagged enum map missing '{}' field",
466                tag_key
467            ))
468        })?;
469
470        // Find variant info before selecting
471        let (_, variant) = partial.find_variant(variant_name).ok_or_else(|| {
472            CborError::InvalidCbor(format!("unknown enum variant: {variant_name}"))
473        })?;
474        let kind = variant.data.kind;
475
476        if kind != StructKind::Struct {
477            return Err(CborError::InvalidCbor(format!(
478                "internally-tagged enum variant '{}' must be a struct variant",
479                variant_name
480            )));
481        }
482
483        let mut partial = partial.select_variant_named(variant_name).map_err(re)?;
484
485        // Now re-read the map from the beginning, skipping the tag field,
486        // and deserializing all other fields as struct fields.
487        // We need to re-parse from after the map header.
488        // Actually, we've already consumed the first key. Let's handle this properly.
489        // Reset offset to after map header and re-read all entries.
490        // But we already consumed some bytes... Let me restructure.
491
492        // We consumed: map_header + first_key("tag") + first_value(variant_name)
493        // Now read remaining map_len - 1 entries as struct fields
494        for _ in 1..map_len {
495            let field_name = decode::read_text(input, offset)?;
496            if field_name == tag_key {
497                // Skip duplicate tag field
498                decode::skip_value(input, offset)?;
499            } else if partial.field_index(field_name).is_some() {
500                partial = partial.begin_field(field_name).map_err(re)?;
501                partial = deserialize_into(partial, input, offset)?;
502                partial = partial.end().map_err(re)?;
503            } else {
504                decode::skip_value(input, offset)?;
505            }
506        }
507
508        Ok(partial)
509    } else {
510        Err(CborError::InvalidCbor(format!(
511            "internally-tagged enum expected text string or map, got major type {}",
512            major
513        )))
514    }
515}
516
517fn deserialize_pointer<'facet>(
518    partial: Partial<'facet, false>,
519    input: &[u8],
520    offset: &mut usize,
521) -> Result<Partial<'facet, false>, CborError> {
522    let re = |e: facet_reflect::ReflectError| CborError::ReflectError(e.to_string());
523    if decode::is_null(input, *offset) {
524        *offset += 1;
525        Ok(partial)
526    } else {
527        let partial = partial.begin_smart_ptr().map_err(re)?;
528        let partial = deserialize_into(partial, input, offset)?;
529        partial.end().map_err(re)
530    }
531}