Skip to main content

midnight_base_crypto/fab/
serialize.rs

1// This file is part of midnight-ledger.
2// Copyright (C) 2025 Midnight Foundation
3// SPDX-License-Identifier: Apache-2.0
4// Licensed under the Apache License, Version 2.0 (the "License");
5// You may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7// http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14use super::{AlignedValue, Alignment, AlignmentAtom, AlignmentSegment, Value, ValueAtom};
15#[cfg(feature = "proptest")]
16use serialize::randomised_serialization_test;
17use serialize::{Deserializable, ReadExt, Serializable, VecExt};
18use std::io::{self, Read, Write};
19
/// Largest integer representable by the one-byte flagged-int form (5 value bits).
const ONE_BYTE_LIMIT: u32 = 0x1f;
/// Smallest integer that requires the two-byte flagged-int form.
const TWO_BYTE_START: u32 = ONE_BYTE_LIMIT + 1;
/// Largest integer representable by the two-byte flagged-int form (5 + 7 value bits).
const TWO_BYTE_LIMIT: u32 = 0xfff;
/// Smallest integer that requires the three-byte flagged-int form.
const THREE_BYTE_START: u32 = TWO_BYTE_LIMIT + 1;
/// Largest integer representable by the three-byte flagged-int form (5 + 7 + 7 value bits).
const THREE_BYTE_LIMIT: u32 = 0x7_ffff;
25
26pub(super) fn write_flagged_int<W: Write>(
27    writer: &mut W,
28    x: bool,
29    y: bool,
30    int: u32,
31) -> io::Result<()> {
32    let flag_u8 = ((x as u8) << 7) | ((y as u8) << 6);
33    match int {
34        0..=ONE_BYTE_LIMIT => writer.write_all(&[flag_u8 | int as u8][..]),
35        TWO_BYTE_START..=TWO_BYTE_LIMIT => {
36            writer.write_all(&[flag_u8 | 0x20 | (int % 0x20) as u8, (int >> 5) as u8])
37        }
38        THREE_BYTE_START..=THREE_BYTE_LIMIT => writer.write_all(&[
39            flag_u8 | 0x20 | (int % 0x20) as u8,
40            0x80 | ((int >> 5) % 0x80) as u8,
41            (int >> 12) as u8,
42        ]),
43        _ => Err(io::Error::new(
44            io::ErrorKind::InvalidInput,
45            format!("integer out of three-byte limit: {}", int),
46        )),
47    }
48}
49
50pub(super) fn flagged_int_size(int: u32) -> usize {
51    match int {
52        0..=ONE_BYTE_LIMIT => 1,
53        TWO_BYTE_START..=TWO_BYTE_LIMIT => 2,
54        THREE_BYTE_START..=THREE_BYTE_LIMIT => 3,
55        // Ideally we'd error, but that's not sensible for a size *hint*.
56        _ => 1000,
57    }
58}
59
/// Reads one flagged integer from `reader`, returning `(x, y, value)`.
///
/// The first byte carries the `x` flag in bit 7, the `y` flag in bit 6, a
/// continuation marker in bit 5 and the low 5 value bits. A second byte adds
/// 7 more value bits and its own continuation marker in bit 7; a third byte
/// adds the final 7 value bits and must have bit 7 clear.
///
/// # Errors
/// * Any I/O error from `reader` (including a premature end of input).
/// * `InvalidInput` if a two- or three-byte form is used although its last
///   byte contributes no value bits (i.e. a shorter form would have sufficed).
/// * `InvalidInput` if bit 7 of the third byte is set.
// NOTE(review): the other decoders in this file report malformed data as
// `InvalidData`; this function uses `InvalidInput` — confirm whether that
// difference is intentional before changing it.
fn read_flagged_int<R: Read>(reader: &mut R) -> io::Result<(bool, bool, u32)> {
    /// Pulls exactly one byte from the reader.
    fn next_byte<R: Read>(reader: &mut R) -> io::Result<u8> {
        let mut buf = [0u8; 1];
        reader.read_exact(&mut buf)?;
        Ok(buf[0])
    }

    /// Error raised when a longer form than necessary was used.
    fn overlong() -> io::Error {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            "use of longer encoding than necessary for flagged int",
        )
    }

    let first = next_byte(reader)?;
    let x = (first & 0x80) != 0;
    let y = (first & 0x40) != 0;
    let low = u32::from(first & 0x1f);
    if (first & 0x20) == 0 {
        return Ok((x, y, low));
    }

    let second = next_byte(reader)?;
    let mid = u32::from(second & 0x7f);
    if (second & 0x80) == 0 {
        if mid == 0 {
            return Err(overlong());
        }
        return Ok((x, y, low | (mid << 5)));
    }

    let third = next_byte(reader)?;
    let high = u32::from(third & 0x7f);
    if (third & 0x80) != 0 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "use of reserved flag in three-byte flagged int encoding",
        ));
    }
    if high == 0 {
        return Err(overlong());
    }
    Ok((x, y, low | (mid << 5) | (high << 12)))
}
99
100impl Serializable for Value {
101    fn serialize(&self, writer: &mut impl Write) -> std::io::Result<()> {
102        if self.0.len() == 1 {
103            <ValueAtom as Serializable>::serialize(&self.0[0], writer)
104        } else {
105            write_flagged_int(writer, true, false, self.0.len() as u32)?;
106            for atom in self.0.iter() {
107                atom.serialize(writer)?;
108            }
109            Ok(())
110        }
111    }
112
113    fn serialized_size(&self) -> usize {
114        if self.0.len() == 1 {
115            self.0[0].serialized_size()
116        } else {
117            flagged_int_size(self.0.len() as u32)
118                + self
119                    .0
120                    .iter()
121                    .map(Serializable::serialized_size)
122                    .sum::<usize>()
123        }
124    }
125}
126
127impl Deserializable for Value {
128    fn deserialize(
129        reader: &mut impl std::io::Read,
130        mut recursion_depth: u32,
131    ) -> Result<Self, std::io::Error> {
132        Self::check_rec(&mut recursion_depth)?;
133        let (x, y, int) = read_flagged_int(reader)?;
134        match (x, y) {
135            (false, _) => Ok(Value(vec![ValueAtom::deserialize_with_flagged_int(
136                x, y, int, reader,
137            )?])),
138            (true, false) => {
139                let mut res = Vec::new();
140                for _ in 0..int {
141                    res.push(<ValueAtom as Deserializable>::deserialize(
142                        reader,
143                        recursion_depth,
144                    )?);
145                }
146                Ok(Value(res))
147            }
148            (true, true) => Err(io::Error::new(
149                io::ErrorKind::InvalidData,
150                "Attempted to decode Value with reserved flags '11'",
151            )),
152        }
153    }
154}
155
/// Returns the number of bytes (1 to 4) needed to hold `int`.
///
/// Zero is reported as needing one byte.
///
/// # Panics
/// Panics if `int` is larger than `0xffff_ffff`.
pub fn int_size(int: usize) -> usize {
    // (largest value that fits, byte width), in ascending order.
    const WIDTHS: [(usize, usize); 4] = [
        (0xff, 1),
        (0xffff, 2),
        (0xff_ffff, 3),
        (0xffff_ffff, 4),
    ];

    WIDTHS
        .iter()
        .find(|&&(limit, _)| int <= limit)
        .map(|&(_, width)| width)
        .unwrap_or_else(|| unreachable!("invalid fab length"))
}
166
167impl Serializable for Alignment {
168    fn serialize(&self, writer: &mut impl Write) -> io::Result<()> {
169        if self.0.len() == 1 {
170            self.0[0].serialize(writer)
171        } else {
172            write_flagged_int(writer, true, true, self.0.len() as u32)?;
173            for segment in self.0.iter() {
174                segment.serialize(writer)?;
175            }
176            Ok(())
177        }
178    }
179
180    fn serialized_size(&self) -> usize {
181        if self.0.len() == 1 {
182            self.0[0].serialized_size()
183        } else {
184            flagged_int_size(self.0.len() as u32)
185                + self
186                    .0
187                    .iter()
188                    .map(Serializable::serialized_size)
189                    .sum::<usize>()
190        }
191    }
192}
193
194#[cfg(feature = "proptest")]
195randomised_serialization_test!(Alignment);
196
197impl Deserializable for Alignment {
198    fn deserialize(
199        reader: &mut impl Read,
200        mut recursion_depth: u32,
201    ) -> Result<Self, std::io::Error> {
202        Self::check_rec(&mut recursion_depth)?;
203        let (x, y, int) = read_flagged_int(reader)?;
204        Alignment::deserialize_with_flagged_int(x, y, int, reader, recursion_depth)
205    }
206}
207
208impl Serializable for AlignmentSegment {
209    fn serialize(&self, writer: &mut impl Write) -> Result<(), std::io::Error> {
210        match self {
211            AlignmentSegment::Atom(atom) => atom.serialize(writer),
212            AlignmentSegment::Option(branches) => {
213                write_flagged_int(writer, true, false, branches.len() as u32)?;
214                for branch in branches.iter() {
215                    branch.serialize(writer)?;
216                }
217                Ok(())
218            }
219        }
220    }
221
222    fn serialized_size(&self) -> usize {
223        match self {
224            AlignmentSegment::Atom(atom) => atom.serialized_size(),
225            AlignmentSegment::Option(branches) => {
226                flagged_int_size(branches.len() as u32)
227                    + branches
228                        .iter()
229                        .map(Serializable::serialized_size)
230                        .sum::<usize>()
231            }
232        }
233    }
234}
235
236#[cfg(feature = "proptest")]
237randomised_serialization_test!(AlignmentSegment);
238
239impl Deserializable for AlignmentSegment {
240    fn deserialize(
241        reader: &mut impl Read,
242        mut recursion_depth: u32,
243    ) -> Result<Self, std::io::Error> {
244        Self::check_rec(&mut recursion_depth)?;
245        let (x, y, int) = read_flagged_int(reader)?;
246        AlignmentSegment::deserialize_with_flagged_int(x, y, int, reader, recursion_depth)
247    }
248}
249
250impl AlignmentSegment {
251    fn deserialize_with_flagged_int<R: Read>(
252        x: bool,
253        y: bool,
254        int: u32,
255        reader: &mut R,
256        recursion_depth: u32,
257    ) -> io::Result<Self> {
258        match (x, y) {
259            (false, _) => AlignmentAtom::deserialize_with_flagged_int(x, y, int, reader)
260                .map(AlignmentSegment::Atom),
261            (true, false) => {
262                let mut branches = Vec::with_bounded_capacity(int as usize);
263                for _ in 0..int {
264                    branches.push(<Alignment as Deserializable>::deserialize(
265                        reader,
266                        recursion_depth,
267                    )?);
268                }
269                Ok(AlignmentSegment::Option(branches))
270            }
271            (true, true) => Err(io::Error::new(
272                io::ErrorKind::InvalidData,
273                "Use of reserved flag '11' in AlignmentSegment",
274            )),
275        }
276    }
277}
278
279impl Serializable for AlignedValue {
280    fn serialize(&self, writer: &mut impl Write) -> Result<(), std::io::Error> {
281        self.value.serialize(writer)?;
282        self.alignment.serialize(writer)
283    }
284
285    fn serialized_size(&self) -> usize {
286        self.value.serialized_size() + self.alignment.serialized_size()
287    }
288}
289
290#[cfg(feature = "proptest")]
291randomised_serialization_test!(AlignedValue);
292
293impl Deserializable for AlignedValue {
294    fn deserialize(
295        reader: &mut impl Read,
296        mut recursion_depth: u32,
297    ) -> Result<Self, std::io::Error> {
298        Self::check_rec(&mut recursion_depth)?;
299        let value: Value = Deserializable::deserialize(reader, recursion_depth)?;
300        let alignment: Alignment = Deserializable::deserialize(reader, recursion_depth)?;
301        Ok(AlignedValue { value, alignment })
302    }
303}
304
305impl Serializable for ValueAtom {
306    fn serialize(&self, writer: &mut impl Write) -> Result<(), std::io::Error> {
307        if self.is_in_normal_form() {
308            if self.0.len() == 1 && self.0[0] < 32 {
309                write_flagged_int(writer, false, false, self.0[0] as u32)?;
310            } else {
311                write_flagged_int(writer, false, true, self.0.len() as u32)?;
312                writer.write_all(&self.0[..])?;
313            }
314            Ok(())
315        } else {
316            self.clone().normalize().serialize(writer)
317        }
318    }
319
320    fn serialized_size(&self) -> usize {
321        if self.is_in_normal_form() {
322            if self.0.len() == 1 && self.0[0] < 32 {
323                flagged_int_size(self.0[0] as u32)
324            } else {
325                flagged_int_size(self.0.len() as u32) + self.0.len()
326            }
327        } else {
328            self.clone().normalize().serialized_size()
329        }
330    }
331}
332
333impl Deserializable for ValueAtom {
334    fn deserialize(reader: &mut impl Read, _recursion_depth: u32) -> Result<Self, std::io::Error> {
335        let (x, y, int) = read_flagged_int(reader)?;
336        Self::deserialize_with_flagged_int(x, y, int, reader)
337    }
338}
339
340impl Serializable for AlignmentAtom {
341    fn serialize(&self, writer: &mut impl Write) -> Result<(), std::io::Error> {
342        match self {
343            AlignmentAtom::Compress => write_flagged_int(writer, false, true, 0),
344            AlignmentAtom::Field => write_flagged_int(writer, false, true, 1),
345            AlignmentAtom::Bytes { length } => write_flagged_int(writer, false, false, *length),
346        }
347    }
348
349    fn serialized_size(&self) -> usize {
350        match self {
351            AlignmentAtom::Bytes { length } => flagged_int_size(*length),
352            AlignmentAtom::Compress | AlignmentAtom::Field => 1,
353        }
354    }
355}
356
357#[cfg(feature = "proptest")]
358randomised_serialization_test!(AlignmentAtom);
359
360impl Deserializable for AlignmentAtom {
361    fn deserialize(reader: &mut impl Read, _recursion_depth: u32) -> Result<Self, std::io::Error> {
362        let (x, y, int) = read_flagged_int(reader)?;
363        AlignmentAtom::deserialize_with_flagged_int(x, y, int, reader)
364    }
365}
366
367impl Alignment {
368    fn deserialize_with_flagged_int<R: Read>(
369        x: bool,
370        y: bool,
371        int: u32,
372        reader: &mut R,
373        recursion_depth: u32,
374    ) -> io::Result<Self> {
375        if x && y {
376            if int == 1 {
377                return Err(io::Error::new(
378                    io::ErrorKind::InvalidData,
379                    "singleton alignment encoded as multi-entry alignment",
380                ));
381            }
382            let mut res = Vec::with_bounded_capacity(int as usize);
383            for _ in 0..int {
384                res.push(Deserializable::deserialize(reader, recursion_depth)?);
385            }
386            Ok(Alignment(res))
387        } else {
388            Ok(Alignment(vec![
389                AlignmentSegment::deserialize_with_flagged_int(x, y, int, reader, recursion_depth)?,
390            ]))
391        }
392    }
393}
394
395impl ValueAtom {
396    fn deserialize_with_flagged_int<R: Read>(
397        x: bool,
398        y: bool,
399        int: u32,
400        reader: &mut R,
401    ) -> io::Result<Self> {
402        if x {
403            return Err(io::Error::new(
404                io::ErrorKind::InvalidData,
405                "x-flag may not be 1 for value atom",
406            ));
407        }
408        if y {
409            let res = reader.read_exact_to_vec(int as usize)?;
410            if int > 0 && res[int as usize - 1] == 0 {
411                Err(io::Error::new(
412                    io::ErrorKind::InvalidData,
413                    "ValueAtom ended with zero byte",
414                ))
415            } else {
416                Ok(ValueAtom(res))
417            }
418        } else if int < 32 && int > 0 {
419            Ok(ValueAtom(vec![int as u8]))
420        } else {
421            Err(io::Error::new(
422                io::ErrorKind::InvalidData,
423                format!("singleton ValueAtom out of range: {}", int),
424            ))
425        }
426    }
427}
428
429impl AlignmentAtom {
430    fn deserialize_with_flagged_int<R: Read>(
431        x: bool,
432        y: bool,
433        int: u32,
434        _reader: &mut R,
435    ) -> io::Result<Self> {
436        match (x, y, int) {
437            (false, false, length) => Ok(AlignmentAtom::Bytes { length }),
438            (false, true, 0) => Ok(AlignmentAtom::Compress),
439            (false, true, 1) => Ok(AlignmentAtom::Field),
440            _ => Err(io::Error::new(
441                io::ErrorKind::InvalidData,
442                "illegal value for alignment atom",
443            )),
444        }
445    }
446}