Skip to main content

miden_node_proto/decode/
mod.rs

1use std::marker::PhantomData;
2
3use miden_protocol::utils::serde::Deserializable;
4
5use crate::errors::ConversionError;
6// Re-export so callers can import from `conv`.
7pub use crate::errors::ConversionResultExt;
8
// GRPC STRUCT DECODER
// ================================================================================================

/// Zero-cost struct decoder that captures the parent proto message type.
///
/// The parent type `M` is used only for error reporting; the decoder itself stores no data.
///
/// Created via [`GrpcDecodeExt::decoder`] which infers the parent type from the value:
///
/// ```rust,ignore
/// // Before:
/// let body = block.body.try_convert_field::<proto::SignedBlock>("body")?;
/// let header = block.header.try_convert_field::<proto::SignedBlock>("header")?;
///
/// // After:
/// let decoder = block.decoder();
/// let body = decode!(decoder, block.body)?;
/// let header = decode!(decoder, block.header)?;
/// ```
pub struct GrpcStructDecoder<M>(PhantomData<M>);

// `Clone`, `Copy` and `Debug` are written by hand instead of derived: `#[derive]` would add
// `M: Clone` / `M: Debug` bounds even though only a `PhantomData<M>` marker is stored.
impl<M> Clone for GrpcStructDecoder<M> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<M> Copy for GrpcStructDecoder<M> {}

impl<M> std::fmt::Debug for GrpcStructDecoder<M> {
    /// Formats as `GrpcStructDecoder<ParentType>` so the parent message is visible in logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "GrpcStructDecoder<{}>", std::any::type_name::<M>())
    }
}

28impl<M: prost::Message> Default for GrpcStructDecoder<M> {
29    /// Create a decoder for the given parent message type directly.
30    ///
31    /// Prefer [`GrpcDecodeExt::decoder`] when a value of type `M` is available, as it infers
32    /// the type automatically.
33    fn default() -> Self {
34        Self(PhantomData)
35    }
36}
37
38impl<M: prost::Message> GrpcStructDecoder<M> {
39    /// Decode a required optional field: checks for `None`, converts via `TryInto`, and adds
40    /// field context on error.
41    pub fn decode_field<T, F>(
42        &self,
43        name: &'static str,
44        value: Option<T>,
45    ) -> Result<F, ConversionError>
46    where
47        T: TryInto<F>,
48        T::Error: Into<ConversionError>,
49    {
50        value
51            .ok_or_else(|| ConversionError::missing_field::<M>(name))?
52            .try_into()
53            .context(name)
54    }
55}
56
57/// Extension trait on [`prost::Message`] types to create a [`GrpcStructDecoder`] with the parent
58/// type inferred from the value.
59pub trait GrpcDecodeExt: prost::Message + Sized {
60    /// Create a decoder that uses `Self` as the parent message type for error reporting.
61    fn decoder(&self) -> GrpcStructDecoder<Self> {
62        GrpcStructDecoder(PhantomData)
63    }
64}
65
66impl<T: prost::Message> GrpcDecodeExt for T {}
67
/// Decodes a required optional field from a protobuf message using the message's decoder.
///
/// Uses `stringify!` to automatically derive the field name for error reporting, avoiding
/// the duplication between a string literal and the field access.
///
/// Has two forms (a trailing comma is accepted in both):
/// - `decode!(decoder, msg.field)` — expands to `decoder.decode_field("field", msg.field)`. Use
///   when accessing a field directly on the message value.
/// - `decode!(decoder, field)` — expands to `decoder.decode_field("field", field)`. Use after
///   destructuring the message, when the field is a bare identifier.
///
/// The decoder argument may be any expression that yields a decoder (for example
/// `self.decoder` or `&decoder`), not only a local variable. The message side is limited to a
/// single `msg.field` access; bind nested messages to a local first.
///
/// The macro evaluates to the `Result` returned by `decode_field`; apply `?` to propagate errors.
///
/// # Usage
///
/// ```ignore
/// let decoder = value.decoder();
/// // With a field access:
/// let sender = decode!(decoder, value.sender)?;
///
/// // With a bare identifier (after destructuring):
/// let Proto { sender, .. } = value;
/// let sender = decode!(decoder, sender)?;
///
/// // Without `?` to return the Result directly:
/// decode!(decoder, value.id)
/// ```
#[macro_export]
macro_rules! decode {
    // `::core::stringify!` is path-qualified so a caller-side macro named `stringify` cannot
    // shadow it.
    ($decoder:expr, $msg:ident . $field:ident $(,)?) => {
        $decoder.decode_field(::core::stringify!($field), $msg.$field)
    };
    ($decoder:expr, $field:ident $(,)?) => {
        $decoder.decode_field(::core::stringify!($field), $field)
    };
}

103// BYTE DESERIALIZATION EXTENSION TRAIT
104// ================================================================================================
105
106/// Extension trait on [`Deserializable`](miden_protocol::utils::Deserializable) types to
107/// deserialize from bytes and wrap errors as [`ConversionError`].
108///
109/// This removes the boilerplate of calling `T::read_from_bytes(&bytes)` followed by
110/// `.map_err(|source| ConversionError::deserialization("T", source))`:
111///
112/// ```rust,ignore
113/// // Before:
114/// BlockBody::read_from_bytes(&value.block_body)
115///     .map_err(|source| ConversionError::deserialization("BlockBody", source))
116///
117/// // After:
118/// BlockBody::decode_bytes(&value.block_body, "BlockBody")
119/// ```
120pub trait DecodeBytesExt: Deserializable {
121    /// Deserialize from bytes, wrapping any error as a [`ConversionError`].
122    fn decode_bytes(bytes: &[u8], entity: &'static str) -> Result<Self, ConversionError> {
123        Self::read_from_bytes(bytes)
124            .map_err(|source| ConversionError::deserialization(entity, source))
125    }
126}
127
128impl<T: Deserializable> DecodeBytesExt for T {}
129
#[cfg(test)]
mod tests {
    use miden_protocol::Felt;

    use super::*;
    use crate::generated::blockchain::BlockHeader;
    use crate::generated::primitives::Digest;

    /// Simulates a deeply nested conversion where each layer adds its field context.
    fn inner_conversion() -> Result<(), ConversionError> {
        Err(ConversionError::message("value is not in range 0..MODULUS"))
    }

    fn outer_conversion() -> Result<(), ConversionError> {
        inner_conversion().context("account_root").context("header")
    }

    #[test]
    fn test_context_builds_dotted_field_path() {
        let err = outer_conversion().unwrap_err();
        assert_eq!(err.to_string(), "header.account_root: value is not in range 0..MODULUS");
    }

    #[test]
    fn test_context_single_field() {
        let err = inner_conversion().context("nullifier").unwrap_err();
        assert_eq!(err.to_string(), "nullifier: value is not in range 0..MODULUS");
    }

    #[test]
    fn test_context_deep_nesting() {
        let err = outer_conversion().context("block").context("response").unwrap_err();
        assert_eq!(
            err.to_string(),
            "response.block.header.account_root: value is not in range 0..MODULUS"
        );
    }

    #[test]
    fn test_no_context_shows_source_only() {
        let err = inner_conversion().unwrap_err();
        assert_eq!(err.to_string(), "value is not in range 0..MODULUS");
    }

    #[test]
    fn test_context_on_external_error_type() {
        let result: Result<u8, std::num::TryFromIntError> = u8::try_from(256u16);
        let err = result.context("fee_amount").unwrap_err();
        assert!(err.to_string().starts_with("fee_amount: "), "expected field prefix, got: {err}");
    }

    #[test]
    fn test_decode_field_missing() {
        let decoder = GrpcStructDecoder::<BlockHeader>::default();
        let account_root: Option<Digest> = None;
        let result: Result<[Felt; 4], _> = decode!(decoder, account_root);
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("account_root") && err.to_string().contains("missing"),
            "expected missing field error, got: {err}",
        );
    }

    #[test]
    fn test_decode_field_conversion_error() {
        let decoder = GrpcStructDecoder::<BlockHeader>::default();
        // Create a digest with an out-of-range value.
        let account_root = Some(Digest { d0: u64::MAX, d1: 0, d2: 0, d3: 0 });
        let result: Result<[Felt; 4], _> = decode!(decoder, account_root);
        let err = result.unwrap_err();
        assert!(
            err.to_string().starts_with("account_root: "),
            "expected field prefix, got: {err}",
        );
    }

    /// Exercises the `decode!(decoder, msg.field)` form, which the tests above do not cover.
    #[test]
    fn test_decode_field_access_form_succeeds() {
        struct Holder {
            account_root: Option<Digest>,
        }

        let decoder = GrpcStructDecoder::<BlockHeader>::default();
        let holder = Holder {
            account_root: Some(Digest { d0: 1, d1: 2, d2: 3, d3: 4 }),
        };
        let result: Result<[Felt; 4], _> = decode!(decoder, holder.account_root);
        if let Err(err) = result {
            panic!("expected in-range digest to convert, got: {err}");
        }
    }

    /// Exercises `GrpcDecodeExt::decoder`, which infers the parent type from a message value.
    #[test]
    fn test_decoder_inferred_from_value() {
        let header = BlockHeader::default();
        let decoder = header.decoder();
        let account_root: Option<Digest> = None;
        let result: Result<[Felt; 4], _> = decode!(decoder, account_root);
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("account_root") && err.to_string().contains("missing"),
            "expected missing field error, got: {err}",
        );
    }

    /// Exercises `DecodeBytesExt::decode_bytes` on input too short to hold a value.
    #[test]
    fn test_decode_bytes_rejects_empty_input() {
        let result = Felt::decode_bytes(&[], "Felt");
        assert!(result.is_err(), "empty input must not deserialize into a Felt");
    }
}