// midnight_serialize/deserializable.rs

// This file is part of midnight-ledger.
// Copyright (C) 2025 Midnight Foundation
// SPDX-License-Identifier: Apache-2.0
// Licensed under the Apache License, Version 2.0 (the "License");
// You may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14use crate::VecExt;
15use crate::serializable::GLOBAL_TAG;
16use crate::tagged::Tagged;
17use std::borrow::Cow;
18use std::io::Read;
19use std::marker::PhantomData;
20use std::sync::Arc;
21use std::{collections::HashMap, collections::HashSet, hash::Hash};
22
/// Maximum nesting depth accepted by `Deserializable::check_rec`; once a depth
/// counter exceeds this value, decoding fails with `ErrorKind::InvalidData`.
///
/// Debug builds use a smaller limit (50) than release builds (250).
/// NOTE(review): presumably because unoptimised builds spend more stack per nesting
/// level — confirm.
pub const RECURSION_LIMIT: u32 = if cfg!(debug_assertions) { 50 } else { 250 };
27
28// Top-level deserialization function
29pub fn tagged_deserialize<T: Deserializable + Tagged>(mut reader: impl Read) -> std::io::Result<T> {
30    let tag_expected = format!("{GLOBAL_TAG}{}:", T::tag());
31    let mut read_tag = vec![0u8; tag_expected.len()];
32    let mut remaining_tag_buf = &mut read_tag[..];
33    while !remaining_tag_buf.is_empty() {
34        let read = reader.read(remaining_tag_buf)?;
35        if read == 0 {
36            let rem = remaining_tag_buf.len();
37            let len = read_tag.len() - rem;
38            read_tag.truncate(len);
39            break;
40        }
41        remaining_tag_buf = &mut remaining_tag_buf[read..];
42    }
43    if read_tag != tag_expected.as_bytes() {
44        let sanitised = String::from_utf8_lossy(&read_tag).replace(
45            |c: char| -> bool { !c.is_ascii_alphanumeric() && !":_-()[],".contains(c) },
46            "�",
47        );
48        return Err(std::io::Error::new(
49            std::io::ErrorKind::InvalidData,
50            format!("expected header tag '{tag_expected}', got '{sanitised}'"),
51        ));
52    }
53    let value = <T as Deserializable>::deserialize(&mut reader, 0)?;
54
55    #[allow(clippy::unbuffered_bytes)] // we can permit a potentally inefficient count here, as in
56    let count = reader.bytes().count(); // the happy path it should be 0
57
58    if count == 0 {
59        return Ok(value);
60    }
61
62    Err(std::io::Error::new(
63        std::io::ErrorKind::InvalidData,
64        format!(
65            "Not all bytes read deserializing '{}'; {} bytes remaining",
66            tag_expected, count
67        ),
68    ))
69}
70
71pub trait Deserializable
72where
73    Self: Sized,
74{
75    const LIMIT_RECURSION: bool = true;
76
77    fn deserialize(reader: &mut impl Read, recursion_depth: u32) -> std::io::Result<Self>;
78
79    fn check_rec(depth: &mut u32) -> std::io::Result<()> {
80        if Self::LIMIT_RECURSION {
81            *depth += 1;
82            if *depth > RECURSION_LIMIT {
83                return Err(std::io::Error::new(
84                    std::io::ErrorKind::InvalidData,
85                    "exceeded recursion depth deserializing",
86                ));
87            }
88        }
89        Ok(())
90    }
91}
92
93impl<T: Deserializable> Deserializable for Vec<T> {
94    fn deserialize(reader: &mut impl Read, mut recursion_depth: u32) -> std::io::Result<Self> {
95        Self::check_rec(&mut recursion_depth)?;
96        let len = <u32 as Deserializable>::deserialize(reader, recursion_depth)?;
97        let mut result = Vec::with_bounded_capacity(len as usize);
98        for _ in 0..len {
99            result.push(<T as Deserializable>::deserialize(reader, recursion_depth)?);
100        }
101        Ok(result)
102    }
103}
104
105impl<K: Deserializable + PartialOrd + Hash + Eq, V: Deserializable> Deserializable
106    for HashMap<K, V>
107{
108    fn deserialize(reader: &mut impl Read, mut recursion_depth: u32) -> std::io::Result<Self> {
109        Self::check_rec(&mut recursion_depth)?;
110        let len = <u32 as Deserializable>::deserialize(reader, recursion_depth)?;
111        let mut result = HashMap::new();
112        for _ in 0..len {
113            let k = <K as Deserializable>::deserialize(reader, recursion_depth)?;
114            let v = <V as Deserializable>::deserialize(reader, recursion_depth)?;
115            result.insert(k, v);
116        }
117        Ok(result)
118    }
119}
120
121impl<T: Deserializable + Hash + Eq> Deserializable for HashSet<T> {
122    fn deserialize(reader: &mut impl Read, mut recursion_depth: u32) -> std::io::Result<Self> {
123        Self::check_rec(&mut recursion_depth)?;
124        let len = <u32 as Deserializable>::deserialize(reader, recursion_depth)?;
125        let mut result = HashSet::new();
126        for _ in 0..len {
127            result.insert(<T as Deserializable>::deserialize(reader, recursion_depth)?);
128        }
129        Ok(result)
130    }
131}
132
133impl<T: Deserializable> Deserializable for Option<T> {
134    fn deserialize(reader: &mut impl Read, mut recursion_depth: u32) -> std::io::Result<Self> {
135        Self::check_rec(&mut recursion_depth)?;
136        let some = <u8 as Deserializable>::deserialize(reader, recursion_depth)?;
137        match some {
138            0 => Ok(None),
139            1 => Ok(Some(<T as Deserializable>::deserialize(
140                reader,
141                recursion_depth,
142            )?)),
143            _ => Err(std::io::Error::new(
144                std::io::ErrorKind::InvalidData,
145                format!("Invalid discriminant: {}.", some),
146            )),
147        }
148    }
149}
150
151impl<T: Deserializable> Deserializable for Arc<T> {
152    fn deserialize(
153        reader: &mut impl Read,
154        mut recursion_depth: u32,
155    ) -> Result<Self, std::io::Error> {
156        Self::check_rec(&mut recursion_depth)?;
157        Ok(Arc::new(T::deserialize(reader, recursion_depth)?))
158    }
159}
160
161impl<const N: usize> Deserializable for [u8; N] {
162    fn deserialize(reader: &mut impl Read, _recursion_depth: u32) -> std::io::Result<Self> {
163        let mut res = [0u8; N];
164        reader.read_exact(&mut res[..])?;
165        Ok(res)
166    }
167}
168
169impl<T: ?Sized> Deserializable for PhantomData<T> {
170    fn deserialize(_reader: &mut impl Read, _recursion_depth: u32) -> std::io::Result<Self> {
171        Ok(PhantomData)
172    }
173}
174
175impl<T: Deserializable> Deserializable for Box<T> {
176    fn deserialize(reader: &mut impl Read, recursion_depth: u32) -> std::io::Result<Self> {
177        T::deserialize(reader, recursion_depth).map(Box::new)
178    }
179}
180
181impl<'a, T: ToOwned + ?Sized> Deserializable for Cow<'a, T>
182where
183    T::Owned: Deserializable,
184{
185    fn deserialize(reader: &mut impl Read, recursion_depth: u32) -> std::io::Result<Self> {
186        <T::Owned>::deserialize(reader, recursion_depth).map(Cow::Owned)
187    }
188}