use std::marker::PhantomData;
use miden_protocol::utils::serde::Deserializable;
use crate::errors::ConversionError;
pub use crate::errors::ConversionResultExt;
/// Zero-sized helper for decoding the fields of the protobuf message type `M`.
///
/// The decoder stores no data; `M` is a type-level tag only. Missing-field errors produced by
/// its `decode_field` method are reported against `M`.
#[derive(Debug)]
pub struct GrpcStructDecoder<M>(PhantomData<M>);
impl<M: prost::Message> Default for GrpcStructDecoder<M> {
fn default() -> Self {
Self(PhantomData)
}
}
impl<M: prost::Message> GrpcStructDecoder<M> {
    /// Extracts a required field of message `M` and converts it into its domain type.
    ///
    /// `name` is the field's name. It is used both when building the missing-field error and
    /// as the context prefix attached to a failed conversion.
    ///
    /// # Errors
    ///
    /// - A missing-field error for message `M` when `value` is `None`.
    /// - The error from `T::try_into`, converted into a `ConversionError` with `name` as context.
    pub fn decode_field<T, F>(
        &self,
        name: &'static str,
        value: Option<T>,
    ) -> Result<F, ConversionError>
    where
        T: TryInto<F>,
        T::Error: Into<ConversionError>,
    {
        match value {
            // An absent field is reported against the message type `M`, not the field's type.
            None => Err(ConversionError::missing_field::<M>(name)),
            Some(raw) => raw.try_into().context(name),
        }
    }
}
pub trait GrpcDecodeExt: prost::Message + Sized {
fn decoder(&self) -> GrpcStructDecoder<Self> {
GrpcStructDecoder(PhantomData)
}
}
impl<T: prost::Message> GrpcDecodeExt for T {}
/// Decodes a required protobuf field through a `GrpcStructDecoder`.
///
/// Two forms are accepted:
///
/// - `decode!(decoder, msg.field)` passes `msg.field` as the value and `"field"` as its name.
/// - `decode!(decoder, field)` passes the local binding `field` and `"field"` as its name.
///
/// The decoder may be any expression, for example `decoder`, `self.decoder` or
/// `msg.decoder()`. The macro expands to
/// `decoder.decode_field(stringify!(field), value)`, so it evaluates to whatever
/// `decode_field` returns.
#[macro_export]
macro_rules! decode {
    // `expr` fragments may be followed by `,`, so accepting an expression here stays
    // compatible with the previous identifier-only form.
    ($decoder:expr, $msg:ident . $field:ident) => {
        $decoder.decode_field(stringify!($field), $msg.$field)
    };
    ($decoder:expr, $field:ident) => {
        $decoder.decode_field(stringify!($field), $field)
    };
}
/// Extension trait for decoding any `Deserializable` type from raw bytes.
pub trait DecodeBytesExt: Deserializable {
    /// Deserializes `Self` from `bytes` with `read_from_bytes`.
    ///
    /// # Errors
    ///
    /// If deserialization fails, the underlying error is wrapped in a deserialization
    /// `ConversionError` labelled with `entity`.
    fn decode_bytes(bytes: &[u8], entity: &'static str) -> Result<Self, ConversionError> {
        match Self::read_from_bytes(bytes) {
            Ok(decoded) => Ok(decoded),
            Err(cause) => Err(ConversionError::deserialization(entity, cause)),
        }
    }
}

impl<T: Deserializable> DecodeBytesExt for T {}
#[cfg(test)]
mod tests {
    use miden_protocol::Felt;

    use super::*;
    use crate::generated::primitives::Digest;

    fn inner_conversion() -> Result<(), ConversionError> {
        Err(ConversionError::message("value is not in range 0..MODULUS"))
    }

    fn outer_conversion() -> Result<(), ConversionError> {
        inner_conversion().context("account_root").context("header")
    }

    #[test]
    fn test_context_builds_dotted_field_path() {
        let err = outer_conversion().unwrap_err();
        assert_eq!(err.to_string(), "header.account_root: value is not in range 0..MODULUS");
    }

    #[test]
    fn test_context_single_field() {
        let err = inner_conversion().context("nullifier").unwrap_err();
        assert_eq!(err.to_string(), "nullifier: value is not in range 0..MODULUS");
    }

    #[test]
    fn test_context_deep_nesting() {
        let err = outer_conversion().context("block").context("response").unwrap_err();
        assert_eq!(
            err.to_string(),
            "response.block.header.account_root: value is not in range 0..MODULUS"
        );
    }

    #[test]
    fn test_no_context_shows_source_only() {
        let err = inner_conversion().unwrap_err();
        assert_eq!(err.to_string(), "value is not in range 0..MODULUS");
    }

    #[test]
    fn test_context_on_external_error_type() {
        let result: Result<u8, std::num::TryFromIntError> = u8::try_from(256u16);
        let err = result.context("fee_amount").unwrap_err();
        assert!(err.to_string().starts_with("fee_amount: "), "expected field prefix, got: {err}");
    }

    #[test]
    fn test_decode_field_missing() {
        let decoder = GrpcStructDecoder::<crate::generated::blockchain::BlockHeader>::default();
        let account_root: Option<Digest> = None;
        let result: Result<[Felt; 4], _> = decode!(decoder, account_root);
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("account_root") && err.to_string().contains("missing"),
            "expected missing field error, got: {err}",
        );
    }

    #[test]
    fn test_decode_field_conversion_error() {
        let decoder = GrpcStructDecoder::<crate::generated::blockchain::BlockHeader>::default();
        let account_root = Some(Digest { d0: u64::MAX, d1: 0, d2: 0, d3: 0 });
        let result: Result<[Felt; 4], _> = decode!(decoder, account_root);
        let err = result.unwrap_err();
        assert!(
            err.to_string().starts_with("account_root: "),
            "expected field prefix, got: {err}",
        );
    }

    /// Minimal stand-in for a generated message, used to exercise the `msg.field` form of
    /// `decode!`.
    struct HeaderFields {
        account_root: Option<Digest>,
    }

    #[test]
    fn test_decode_macro_field_access_form() {
        let decoder = GrpcStructDecoder::<crate::generated::blockchain::BlockHeader>::default();

        // Present and in range: conversion succeeds.
        let present = HeaderFields {
            account_root: Some(Digest { d0: 1, d1: 2, d2: 3, d3: 4 }),
        };
        let result: Result<[Felt; 4], _> = decode!(decoder, present.account_root);
        if let Err(err) = result {
            panic!("expected successful conversion, got: {err}");
        }

        // Absent: the error names the field taken from the `msg.field` path.
        let absent = HeaderFields { account_root: None };
        let result: Result<[Felt; 4], _> = decode!(decoder, absent.account_root);
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("account_root") && err.to_string().contains("missing"),
            "expected missing field error, got: {err}",
        );
    }
}