// essential_types/predicate/encode.rs

//! # Encode and Decode Predicates
//!
//! # Encoding
//! All integer fields are encoded as big-endian `u16` values.
//!
//! ## Predicate
//! | Field | Size (bytes) | Description |
//! | --- | --- | --- |
//! | number_of_nodes | 2 | The number of nodes in the predicate. |
//! | nodes | 34 * number_of_nodes | The nodes in the predicate. |
//! | number_of_edges | 2 | The number of edges in the predicate. |
//! | edges | 2 * number_of_edges | The edges in the predicate. |
//!
//! ## Node
//! | Field | Size (bytes) | Description |
//! | --- | --- | --- |
//! | edge_start | 2 | The index of the first edge in the edge list. |
//! | program_address | 32 | The address of the program. |
//!
//! ## Edge
//! | Field | Size (bytes) | Description |
//! | --- | --- | --- |
//! | edge | 2 | The index of the node that this edge points to. |

23use super::*;
24
25#[cfg(test)]
26mod tests;
27
/// Size of one encoded node: a `u16` edge start followed by a 32-byte program address.
const NODE_SIZE_BYTES: usize = core::mem::size_of::<u16>() + 32;
/// Size of one encoded edge, which is a single `u16` node index.
const EDGE_SIZE_BYTES: usize = core::mem::size_of::<u16>();
/// Size of each `u16` length prefix (node count and edge count).
const LEN_SIZE_BYTES: usize = core::mem::size_of::<u16>();
31
32/// Errors that can occur when decoding a predicate.
33#[derive(Debug, PartialEq)]
34pub enum PredicateDecodeError {
35    /// The bytes are too short to contain the number of nodes.
36    BytesTooShort,
37}
38
39/// Errors that can occur when encoding a predicate.
40#[derive(Debug, PartialEq)]
41pub enum PredicateEncodeError {
42    /// The predicate contains too many nodes.
43    TooManyNodes,
44    /// The predicate contains too many edges.
45    TooManyEdges,
46}
47
48impl std::error::Error for PredicateDecodeError {}
49
50impl std::error::Error for PredicateEncodeError {}
51
52/// Encode a predicate into bytes.
53pub fn encode_predicate(
54    predicate: &Predicate,
55) -> Result<impl Iterator<Item = u8> + '_, PredicateEncodeError> {
56    let num_nodes = if predicate.nodes.len() <= Predicate::MAX_NODES as usize {
57        predicate.nodes.len() as u16
58    } else {
59        return Err(PredicateEncodeError::TooManyNodes);
60    };
61    let num_edges = if predicate.edges.len() <= Predicate::MAX_EDGES as usize {
62        predicate.edges.len() as u16
63    } else {
64        return Err(PredicateEncodeError::TooManyEdges);
65    };
66    let iter = num_nodes
67        .to_be_bytes()
68        .into_iter()
69        .chain(predicate.nodes.iter().flat_map(|node| {
70            node.edge_start
71                .to_be_bytes()
72                .into_iter()
73                .chain(node.program_address.0.iter().copied())
74        }))
75        .chain(num_edges.to_be_bytes())
76        .chain(predicate.edges.iter().flat_map(|edge| edge.to_be_bytes()));
77    Ok(iter)
78}
79
80/// The size of the encoded predicate.
81pub fn predicate_encoded_size(predicate: &Predicate) -> usize {
82    predicate.nodes.len() * NODE_SIZE_BYTES + predicate.edges.len() * EDGE_SIZE_BYTES + 2
83}
84
85/// Decode a predicate from bytes.
86pub fn decode_predicate(bytes: &[u8]) -> Result<Predicate, PredicateDecodeError> {
87    let Some(num_nodes) = bytes.get(..LEN_SIZE_BYTES).map(|x| {
88        let mut arr = [0; LEN_SIZE_BYTES];
89        arr.copy_from_slice(x);
90        u16::from_be_bytes(arr)
91    }) else {
92        return Err(PredicateDecodeError::BytesTooShort);
93    };
94
95    let nodes: Vec<_> =
96        match bytes.get(LEN_SIZE_BYTES..(LEN_SIZE_BYTES + num_nodes as usize * NODE_SIZE_BYTES)) {
97            Some(bytes) => bytes
98                .chunks_exact(NODE_SIZE_BYTES)
99                .take(num_nodes as usize)
100                .map(|node| Node {
101                    edge_start: u16::from_be_bytes(
102                        node[..2].try_into().expect("safe due to chunks exact"),
103                    ),
104                    program_address: ContentAddress(
105                        node[2..].try_into().expect("safe due to chunks exact"),
106                    ),
107                })
108                .collect(),
109            None => return Err(PredicateDecodeError::BytesTooShort),
110        };
111
112    let num_edges_pos = num_nodes as usize * NODE_SIZE_BYTES + LEN_SIZE_BYTES;
113    let Some(num_edges) = bytes.get(num_edges_pos..(num_edges_pos + 2)).map(|x| {
114        let mut arr = [0; 2];
115        arr.copy_from_slice(x);
116        u16::from_be_bytes(arr)
117    }) else {
118        return Err(PredicateDecodeError::BytesTooShort);
119    };
120
121    let edges_start = num_edges_pos + LEN_SIZE_BYTES;
122    let edges: Vec<_> =
123        match bytes.get(edges_start..(edges_start + num_edges as usize * EDGE_SIZE_BYTES)) {
124            Some(bytes) => bytes
125                .chunks_exact(EDGE_SIZE_BYTES)
126                .map(|edge| {
127                    let mut arr = [0; 2];
128                    arr.copy_from_slice(edge);
129                    u16::from_be_bytes(arr)
130                })
131                .collect(),
132            None => return Err(PredicateDecodeError::BytesTooShort),
133        };
134    Ok(Predicate { nodes, edges })
135}