essential_types/predicate/
encode.rs

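//! # Encode and Decode Predicates
//!
//! All multi-byte integers are encoded big-endian, and the fields below are
//! packed in the order listed with no padding.
//!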
3//! # Encoding
4//! ## Predicate
5//! | Field | Size (bytes) | Description |
6//! | --- | --- | --- |
7//! | number_of_nodes | 2 | The number of nodes in the predicate. |
8//! | nodes | 35 * number_of_nodes | The nodes in the predicate. |
9//! | number_of_edges | 2 | The number of edges in the predicate. |
10//! | edges | 2 * number_of_edges | The edges in the predicate. |
11//!
12//! ## Node
13//! | Field | Size (bytes) | Description |
14//! | --- | --- | --- |
15//! | edge_start | 2 | The index of the first edge in the edge list. |
16//! | program_address | 32 | The address of the program. |
//! | reads | 1 | The type of state this program has access to: `0` for `Reads::Pre`, `1` for `Reads::Post`. |
18//!
19//! ## Edge
20//! | Field | Size (bytes) | Description |
21//! | --- | --- | --- |
22//! | edge | 2 | The index of the node that this edge points to. |
23
24use super::*;
25
26#[cfg(test)]
27mod tests;
28
/// Encoded size of one node: `edge_start` (u16) + `program_address` (32 bytes) + `reads` (1 byte).
const NODE_SIZE_BYTES: usize = core::mem::size_of::<u16>() + 32 + 1;
/// Encoded size of one edge, stored as a big-endian u16.
const EDGE_SIZE_BYTES: usize = core::mem::size_of::<u16>();
/// Encoded size of each length prefix (node count, edge count), stored as a big-endian u16.
const LEN_SIZE_BYTES: usize = core::mem::size_of::<u16>();
32
/// Errors that can occur when decoding a predicate.
#[derive(Debug, PartialEq)]
pub enum PredicateDecodeError {
    /// The bytes end before the whole predicate could be read: the node
    /// count, a node, the edge count, or an edge is truncated.
    BytesTooShort,
}
39
/// Reasons [`encode_predicate`] can refuse to encode a predicate.
#[derive(PartialEq, Debug)]
pub enum PredicateEncodeError {
    /// `predicate.nodes` holds more than `Predicate::MAX_NODES` entries.
    TooManyNodes,
    /// `predicate.edges` holds more than `Predicate::MAX_EDGES` entries.
    TooManyEdges,
}
48
49impl std::error::Error for PredicateDecodeError {}
50
51impl std::error::Error for PredicateEncodeError {}
52
53/// Encode a predicate into bytes.
54pub fn encode_predicate(
55    predicate: &Predicate,
56) -> Result<impl Iterator<Item = u8> + '_, PredicateEncodeError> {
57    let num_nodes = if predicate.nodes.len() <= Predicate::MAX_NODES as usize {
58        predicate.nodes.len() as u16
59    } else {
60        return Err(PredicateEncodeError::TooManyNodes);
61    };
62    let num_edges = if predicate.edges.len() <= Predicate::MAX_EDGES as usize {
63        predicate.edges.len() as u16
64    } else {
65        return Err(PredicateEncodeError::TooManyEdges);
66    };
67    let iter = num_nodes
68        .to_be_bytes()
69        .into_iter()
70        .chain(predicate.nodes.iter().flat_map(|node| {
71            node.edge_start
72                .to_be_bytes()
73                .into_iter()
74                .chain(node.program_address.0.iter().copied())
75                .chain(Some(node.reads as u8))
76        }))
77        .chain(num_edges.to_be_bytes())
78        .chain(predicate.edges.iter().flat_map(|edge| edge.to_be_bytes()));
79    Ok(iter)
80}
81
82/// The size of the encoded predicate.
83pub fn predicate_encoded_size(predicate: &Predicate) -> usize {
84    predicate.nodes.len() * NODE_SIZE_BYTES + predicate.edges.len() * EDGE_SIZE_BYTES + 2
85}
86
87/// Decode a predicate from bytes.
88pub fn decode_predicate(bytes: &[u8]) -> Result<Predicate, PredicateDecodeError> {
89    let Some(num_nodes) = bytes.get(..LEN_SIZE_BYTES).map(|x| {
90        let mut arr = [0; LEN_SIZE_BYTES];
91        arr.copy_from_slice(x);
92        u16::from_be_bytes(arr)
93    }) else {
94        return Err(PredicateDecodeError::BytesTooShort);
95    };
96
97    let nodes: Vec<_> =
98        match bytes.get(LEN_SIZE_BYTES..(LEN_SIZE_BYTES + num_nodes as usize * NODE_SIZE_BYTES)) {
99            Some(bytes) => bytes
100                .chunks_exact(NODE_SIZE_BYTES)
101                .take(num_nodes as usize)
102                .map(|node| Node {
103                    edge_start: u16::from_be_bytes(
104                        node[..2].try_into().expect("safe due to chunks exact"),
105                    ),
106                    program_address: ContentAddress(
107                        node[2..34].try_into().expect("safe due to chunks exact"),
108                    ),
109                    reads: Reads::from(node[34]),
110                })
111                .collect(),
112            None => return Err(PredicateDecodeError::BytesTooShort),
113        };
114
115    let num_edges_pos = num_nodes as usize * NODE_SIZE_BYTES + LEN_SIZE_BYTES;
116    let Some(num_edges) = bytes.get(num_edges_pos..(num_edges_pos + 2)).map(|x| {
117        let mut arr = [0; 2];
118        arr.copy_from_slice(x);
119        u16::from_be_bytes(arr)
120    }) else {
121        return Err(PredicateDecodeError::BytesTooShort);
122    };
123
124    let edges_start = num_edges_pos + LEN_SIZE_BYTES;
125    let edges: Vec<_> =
126        match bytes.get(edges_start..(edges_start + num_edges as usize * EDGE_SIZE_BYTES)) {
127            Some(bytes) => bytes
128                .chunks_exact(EDGE_SIZE_BYTES)
129                .map(|edge| {
130                    let mut arr = [0; 2];
131                    arr.copy_from_slice(edge);
132                    u16::from_be_bytes(arr)
133                })
134                .collect(),
135            None => return Err(PredicateDecodeError::BytesTooShort),
136        };
137    Ok(Predicate { nodes, edges })
138}
139
140impl Reads {
141    fn from(byte: u8) -> Self {
142        match byte % (Self::Post as u8 + 1) {
143            0 => Self::Pre,
144            1 => Self::Post,
145            _ => unreachable!(),
146        }
147    }
148}