Skip to main content

miden_node_proto/decode/
mod.rs

1use std::marker::PhantomData;
2
3use miden_protocol::utils::serde::Deserializable;
4
5mod utils;
6pub use utils::*;
7
8use crate::errors::ConversionError;
9// Re-export so callers can import from `conv`.
10pub use crate::errors::ConversionResultExt;
11
12// GRPC STRUCT DECODER
13// ================================================================================================
14
/// Zero-cost struct decoder that captures the parent proto message type.
///
/// The decoder is zero-sized: it only holds a [`PhantomData`] marker for `M`. `M` is used to
/// attribute "missing field" errors to the parent message in
/// [`GrpcStructDecoder::decode_field`].
///
/// Created via [`GrpcDecodeExt::decoder`], which infers the parent type from the value, or via
/// [`Default::default`] when only the type is known:
///
/// ```rust,ignore
/// // Before:
/// let body = block.body.try_convert_field::<proto::SignedBlock>("body")?;
/// let header = block.header.try_convert_field::<proto::SignedBlock>("header")?;
///
/// // After:
/// let decoder = block.decoder();
/// let body = decode!(decoder, block.body)?;
/// let header = decode!(decoder, block.header)?;
/// ```
pub struct GrpcStructDecoder<M>(PhantomData<M>);
30
31impl<M: prost::Message> Default for GrpcStructDecoder<M> {
32    /// Create a decoder for the given parent message type directly.
33    ///
34    /// Prefer [`GrpcDecodeExt::decoder`] when a value of type `M` is available, as it infers
35    /// the type automatically.
36    fn default() -> Self {
37        Self(PhantomData)
38    }
39}
40
41impl<M: prost::Message> GrpcStructDecoder<M> {
42    /// Decode a required optional field: checks for `None`, converts via `TryInto`, and adds field
43    /// context on error.
44    pub fn decode_field<T, F>(
45        &self,
46        name: &'static str,
47        value: Option<T>,
48    ) -> Result<F, ConversionError>
49    where
50        T: TryInto<F>,
51        T::Error: Into<ConversionError>,
52    {
53        value
54            .ok_or_else(|| ConversionError::missing_field::<M>(name))?
55            .try_into()
56            .context(name)
57    }
58}
59
60/// Extension trait on [`prost::Message`] types to create a [`GrpcStructDecoder`] with the parent
61/// type inferred from the value.
62pub trait GrpcDecodeExt: prost::Message + Sized {
63    /// Create a decoder that uses `Self` as the parent message type for error reporting.
64    fn decoder(&self) -> GrpcStructDecoder<Self> {
65        GrpcStructDecoder(PhantomData)
66    }
67}
68
69impl<T: prost::Message> GrpcDecodeExt for T {}
70
/// Decodes a required optional field from a protobuf message using a [`GrpcStructDecoder`].
///
/// Uses `stringify!` to derive the field name for error reporting automatically. This avoids
/// writing the field name twice, once as a string literal and once as the field access.
///
/// Has two forms:
/// - `decode!(decoder, msg.field)` — expands to `decoder.decode_field("field", msg.field)`. Use
///   when accessing a field directly on the message value.
/// - `decode!(decoder, field)` — expands to `decoder.decode_field("field", field)`. Use after
///   destructuring the message, when the field is a bare identifier.
///
/// The decoder argument may be any expression, e.g. `decoder`, `&decoder` or `self.decoder`. It
/// is used as a single grouped expression, so operator precedence is preserved, and it is
/// evaluated once as the receiver of `decode_field`. The first form supports exactly one level
/// of field access (`msg.field`).
///
/// # Usage
///
/// ```ignore
/// let decoder = value.decoder();
/// // With a field access:
/// let sender = decode!(decoder, value.sender)?;
///
/// // With a bare identifier (after destructuring):
/// let Proto { sender, .. } = value;
/// let sender = decode!(decoder, sender)?;
///
/// // With a decoder expression instead of a plain identifier:
/// let sender = decode!(&decoder, sender)?;
///
/// // Without `?` to return the Result directly:
/// decode!(decoder, value.id)
/// ```
#[macro_export]
macro_rules! decode {
    // Field access: the reported name is the field, not the message binding.
    ($decoder:expr, $msg:ident . $field:ident) => {
        $decoder.decode_field(stringify!($field), $msg.$field)
    };
    // Bare identifier, typically obtained by destructuring the message.
    ($decoder:expr, $field:ident) => {
        $decoder.decode_field(stringify!($field), $field)
    };
}
105
106// BYTE DESERIALIZATION EXTENSION TRAIT
107// ================================================================================================
108
109/// Extension trait on [`Deserializable`](miden_protocol::utils::Deserializable) types to
110/// deserialize from bytes and wrap errors as [`ConversionError`].
111///
112/// This removes the boilerplate of calling `T::read_from_bytes(&bytes)` followed by
113/// `.map_err(|source| ConversionError::deserialization("T", source))`:
114///
115/// ```rust,ignore
116/// // Before:
117/// BlockBody::read_from_bytes(&value.block_body)
118///     .map_err(|source| ConversionError::deserialization("BlockBody", source))
119///
120/// // After:
121/// BlockBody::decode_bytes(&value.block_body, "BlockBody")
122/// ```
123pub trait DecodeBytesExt: Deserializable {
124    /// Deserialize from bytes, wrapping any error as a [`ConversionError`].
125    fn decode_bytes(bytes: &[u8], entity: &'static str) -> Result<Self, ConversionError> {
126        Self::read_from_bytes(bytes)
127            .map_err(|source| ConversionError::deserialization(entity, source))
128    }
129}
130
131impl<T: Deserializable> DecodeBytesExt for T {}
132
#[cfg(test)]
mod tests {
    use miden_protocol::Felt;

    use super::*;
    use crate::generated::primitives::Digest;

    /// Simulates a deeply nested conversion where each layer adds its field context.
    fn inner_conversion() -> Result<(), ConversionError> {
        Err(ConversionError::message("value is not in range 0..MODULUS"))
    }

    fn outer_conversion() -> Result<(), ConversionError> {
        inner_conversion().context("account_root").context("header")
    }

    #[test]
    fn test_context_builds_dotted_field_path() {
        let err = outer_conversion().unwrap_err();
        assert_eq!(err.to_string(), "header.account_root: value is not in range 0..MODULUS");
    }

    #[test]
    fn test_context_single_field() {
        let err = inner_conversion().context("nullifier").unwrap_err();
        assert_eq!(err.to_string(), "nullifier: value is not in range 0..MODULUS");
    }

    #[test]
    fn test_context_deep_nesting() {
        let err = outer_conversion().context("block").context("response").unwrap_err();
        assert_eq!(
            err.to_string(),
            "response.block.header.account_root: value is not in range 0..MODULUS"
        );
    }

    #[test]
    fn test_no_context_shows_source_only() {
        let err = inner_conversion().unwrap_err();
        assert_eq!(err.to_string(), "value is not in range 0..MODULUS");
    }

    #[test]
    fn test_context_on_external_error_type() {
        let result: Result<u8, std::num::TryFromIntError> = u8::try_from(256u16);
        let err = result.context("fee_amount").unwrap_err();
        assert!(err.to_string().starts_with("fee_amount: "), "expected field prefix, got: {err}",);
    }

    #[test]
    fn test_decode_field_missing() {
        let decoder = GrpcStructDecoder::<crate::generated::blockchain::BlockHeader>::default();
        let account_root: Option<Digest> = None;
        let result: Result<[Felt; 4], _> = decode!(decoder, account_root);
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("account_root") && err.to_string().contains("missing"),
            "expected missing field error, got: {err}",
        );
    }

    #[test]
    fn test_decode_field_conversion_error() {
        let decoder = GrpcStructDecoder::<crate::generated::blockchain::BlockHeader>::default();
        // Create a digest with an out-of-range value.
        let account_root = Some(Digest { d0: u64::MAX, d1: 0, d2: 0, d3: 0 });
        let result: Result<[Felt; 4], _> = decode!(decoder, account_root);
        let err = result.unwrap_err();
        assert!(
            err.to_string().starts_with("account_root: "),
            "expected field prefix, got: {err}",
        );
    }

    /// Exercises the `msg.field` arm of `decode!`, with the decoder obtained through
    /// `GrpcDecodeExt::decoder` rather than `Default`.
    #[test]
    fn test_decode_field_access_form_missing() {
        struct Wrapper {
            account_root: Option<Digest>,
        }
        let decoder = crate::generated::blockchain::BlockHeader::default().decoder();
        let wrapper = Wrapper { account_root: None };
        let result: Result<[Felt; 4], _> = decode!(decoder, wrapper.account_root);
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("account_root") && err.to_string().contains("missing"),
            "expected missing field error, got: {err}",
        );
    }

    #[test]
    fn test_decode_field_access_form_success() {
        struct Wrapper {
            account_root: Option<Digest>,
        }
        let decoder = crate::generated::blockchain::BlockHeader::default().decoder();
        let wrapper = Wrapper { account_root: Some(Digest { d0: 1, d1: 2, d2: 3, d3: 4 }) };
        let result: Result<[Felt; 4], _> = decode!(decoder, wrapper.account_root);
        assert!(result.is_ok(), "expected successful conversion, got: {result:?}");
    }

    /// `decode_bytes` must surface a deserialization failure as an error rather than panicking.
    #[test]
    fn test_decode_bytes_rejects_truncated_input() {
        let result = Felt::decode_bytes(&[0u8; 3], "Felt");
        assert!(result.is_err(), "expected deserialization error for truncated input");
    }
}