use std::borrow::Cow;
use aws_smithy_types::{Blob, DateTime};
use minicbor::decode::Error;
use crate::data::Type;
#[derive(Debug, Clone)]
pub struct Decoder<'b> {
decoder: minicbor::Decoder<'b>,
}
#[derive(Debug)]
pub struct DeserializeError {
#[allow(dead_code)]
_inner: Error,
}
impl std::fmt::Display for DeserializeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self._inner.fmt(f)
}
}
impl std::error::Error for DeserializeError {}
impl DeserializeError {
pub(crate) fn new(inner: Error) -> Self {
Self { _inner: inner }
}
pub fn unexpected_union_variant(unexpected_type: Type, at: usize) -> Self {
Self {
_inner: Error::type_mismatch(unexpected_type.into_minicbor_type())
.with_message("encountered unexpected union variant; expected end of union")
.at(at),
}
}
pub fn unknown_union_variant(variant_name: &str, at: usize) -> Self {
Self {
_inner: Error::message(format!("encountered unknown union variant {variant_name}"))
.at(at),
}
}
pub fn mixed_union_variants(at: usize) -> Self {
Self {
_inner: Error::message(
"encountered mixed variants in union; expected a single union variant to be set",
)
.at(at),
}
}
pub fn expected_end_of_stream(at: usize) -> Self {
Self {
_inner: Error::message("encountered additional data; expected end of stream").at(at),
}
}
pub fn custom(message: impl Into<Cow<'static, str>>, at: usize) -> Self {
Self {
_inner: Error::message(message.into()).at(at),
}
}
pub fn is_type_mismatch(&self) -> bool {
self._inner.is_type_mismatch()
}
}
macro_rules! delegate_method {
($($(#[$meta:meta])* $wrapper_name:ident => $encoder_name:ident($result_type:ty);)+) => {
$(
pub fn $wrapper_name(&mut self) -> Result<$result_type, DeserializeError> {
self.decoder.$encoder_name().map_err(DeserializeError::new)
}
)+
};
}
impl<'b> Decoder<'b> {
pub fn new(bytes: &'b [u8]) -> Self {
Self {
decoder: minicbor::Decoder::new(bytes),
}
}
pub fn datatype(&self) -> Result<Type, DeserializeError> {
self.decoder
.datatype()
.map(Type::new)
.map_err(DeserializeError::new)
}
delegate_method! {
skip => skip(());
boolean => bool(bool);
byte => i8(i8);
short => i16(i16);
integer => i32(i32);
long => i64(i64);
float => f32(f32);
double => f64(f64);
null => null(());
list => array(Option<u64>);
map => map(Option<u64>);
}
pub fn position(&self) -> usize {
self.decoder.position()
}
pub fn set_position(&mut self, pos: usize) {
self.decoder.set_position(pos)
}
pub fn str(&mut self) -> Result<Cow<'b, str>, DeserializeError> {
let bookmark = self.decoder.position();
match self.decoder.str() {
Ok(str_value) => Ok(Cow::Borrowed(str_value)),
Err(e) if e.is_type_mismatch() => {
self.decoder.set_position(bookmark);
Ok(Cow::Owned(self.string()?))
}
Err(e) => Err(DeserializeError::new(e)),
}
}
pub fn string(&mut self) -> Result<String, DeserializeError> {
let mut iter = self.decoder.str_iter().map_err(DeserializeError::new)?;
let head = iter.next();
let decoded_string = match head {
None => String::new(),
Some(head) => {
let mut combined_chunks = String::from(head.map_err(DeserializeError::new)?);
for chunk in iter {
combined_chunks.push_str(chunk.map_err(DeserializeError::new)?);
}
combined_chunks
}
};
Ok(decoded_string)
}
pub fn blob(&mut self) -> Result<Blob, DeserializeError> {
let iter = self.decoder.bytes_iter().map_err(DeserializeError::new)?;
let parts: Vec<&[u8]> = iter
.collect::<Result<_, _>>()
.map_err(DeserializeError::new)?;
Ok(if parts.len() == 1 {
Blob::new(parts[0]) } else {
Blob::new(parts.concat()) })
}
pub fn timestamp(&mut self) -> Result<DateTime, DeserializeError> {
let tag = self.decoder.tag().map_err(DeserializeError::new)?;
let timestamp_tag = minicbor::data::Tag::from(minicbor::data::IanaTag::Timestamp);
if tag != timestamp_tag {
Err(DeserializeError::new(Error::message(
"expected timestamp tag",
)))
} else {
let epoch_seconds = self.decoder.f64().map_err(DeserializeError::new)?;
let mut result = DateTime::from_secs_f64(epoch_seconds);
let subsec_nanos = result.subsec_nanos();
result.set_subsec_nanos((subsec_nanos / 1_000_000) * 1_000_000);
Ok(result)
}
}
}
#[allow(dead_code)] #[derive(Debug)]
pub struct ArrayIter<'a, 'b, T> {
inner: minicbor::decode::ArrayIter<'a, 'b, T>,
}
impl<'b, T: minicbor::Decode<'b, ()>> Iterator for ArrayIter<'_, 'b, T> {
type Item = Result<T, DeserializeError>;
fn next(&mut self) -> Option<Self::Item> {
self.inner
.next()
.map(|opt| opt.map_err(DeserializeError::new))
}
}
#[allow(dead_code)] #[derive(Debug)]
pub struct MapIter<'a, 'b, K, V> {
inner: minicbor::decode::MapIter<'a, 'b, K, V>,
}
impl<'b, K, V> Iterator for MapIter<'_, 'b, K, V>
where
K: minicbor::Decode<'b, ()>,
V: minicbor::Decode<'b, ()>,
{
type Item = Result<(K, V), DeserializeError>;
fn next(&mut self) -> Option<Self::Item> {
self.inner
.next()
.map(|opt| opt.map_err(DeserializeError::new))
}
}
pub fn set_optional<B, F>(builder: B, decoder: &mut Decoder, f: F) -> Result<B, DeserializeError>
where
F: Fn(B, &mut Decoder) -> Result<B, DeserializeError>,
{
match decoder.datatype()? {
crate::data::Type::Null => {
decoder.null()?;
Ok(builder)
}
_ => f(builder, decoder),
}
}
#[cfg(test)]
mod tests {
use crate::Decoder;
use aws_smithy_types::date_time::Format;
#[test]
fn test_definite_str_is_cow_borrowed() {
let definite_bytes = [
0x6a, 0x74, 0x68, 0x69, 0x73, 0x49, 0x73, 0x41, 0x4b, 0x65, 0x79,
];
let mut decoder = Decoder::new(&definite_bytes);
let member = decoder.str().expect("could not decode str");
assert_eq!(member, "thisIsAKey");
assert!(matches!(member, std::borrow::Cow::Borrowed(_)));
}
#[test]
fn test_indefinite_str_is_cow_owned() {
let indefinite_bytes = [
0x7f, 0x64, 0x74, 0x68, 0x69, 0x73, 0x62, 0x49, 0x73, 0x61, 0x41, 0x63, 0x4b, 0x65,
0x79, 0xff,
];
let mut decoder = Decoder::new(&indefinite_bytes);
let member = decoder.str().expect("could not decode str");
assert_eq!(member, "thisIsAKey");
assert!(matches!(member, std::borrow::Cow::Owned(_)));
}
#[test]
fn test_empty_str_works() {
let bytes = [0x60];
let mut decoder = Decoder::new(&bytes);
let member = decoder.str().expect("could not decode empty str");
assert_eq!(member, "");
}
#[test]
fn test_empty_blob_works() {
let bytes = [0x40];
let mut decoder = Decoder::new(&bytes);
let member = decoder.blob().expect("could not decode an empty blob");
assert_eq!(member, aws_smithy_types::Blob::new([]));
}
#[test]
fn test_indefinite_length_blob() {
let indefinite_bytes = [
0x5f, 0x50, 0x69, 0x6e, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x69, 0x74, 0x65, 0x2d, 0x62,
0x79, 0x74, 0x65, 0x2c, 0x49, 0x20, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x65, 0x64, 0x2c,
0x4e, 0x20, 0x6f, 0x6e, 0x20, 0x65, 0x61, 0x63, 0x68, 0x20, 0x63, 0x6f, 0x6d, 0x6d,
0x61, 0xff,
];
let mut decoder = Decoder::new(&indefinite_bytes);
let member = decoder.blob().expect("could not decode blob");
assert_eq!(
member,
aws_smithy_types::Blob::new("indefinite-byte, chunked, on each comma".as_bytes())
);
}
#[test]
fn test_timestamp_should_be_truncated_to_fit_millisecond_precision() {
let bytes = [
0xc1, 0xfb, 0x41, 0xcc, 0x37, 0xdb, 0x38, 0x0f, 0xbe, 0x77, 0xff,
];
let mut decoder = Decoder::new(&bytes);
let timestamp = decoder.timestamp().expect("should decode timestamp");
assert_eq!(
timestamp,
aws_smithy_types::date_time::DateTime::from_str(
"2000-01-02T20:34:56.123Z",
Format::DateTime
)
.unwrap()
);
}
}