use std::cmp::Ordering;
use crate::ast::*;
use proc_macro2::TokenStream;
use quote::quote;
/// Generates the `Encode<DagCborCodec>` impl for a schema type.
///
/// `lurk_ipld` is the identifier used as the path root for every codec item
/// the generated code refers to. The method body is produced by
/// `gen_encode_struct` for structs and `gen_encode_union` for unions.
///
/// # Panics
///
/// Panics if `ast` is a struct whose `generics` is `None`.
pub fn gen_encode(ast: &SchemaType, lurk_ipld: &syn::Ident) -> TokenStream {
    let (ident, generics, body) = match ast {
        SchemaType::Struct(def) => {
            let generics = def.generics.as_ref().unwrap();
            (&def.name, generics, gen_encode_struct(def))
        }
        SchemaType::Union(def) => (&def.name, &def.generics, gen_encode_union(def)),
    };
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
    let trait_name = quote!(#lurk_ipld::codec::Encode<#lurk_ipld::cbor::DagCborCodec>);

    // Names the generated bodies rely on; brought into scope inside the method
    // so they cannot clash with items in the user's module.
    let prelude = quote! {
        use #lurk_ipld::codec::Encode;
        use #lurk_ipld::cbor::cbor::MajorKind;
        use #lurk_ipld::cbor::encode::{write_null, write_u8, write_u64};
    };

    quote! {
        impl #impl_generics #trait_name for #ident #ty_generics #where_clause {
            fn encode<W: std::io::Write>(
                &self,
                c: #lurk_ipld::cbor::DagCborCodec,
                w: &mut W,
            ) -> #lurk_ipld::Result<()> {
                #prelude
                #body
            }
        }
    }
}
/// Generates the `Decode<DagCborCodec>` impl for a schema type.
///
/// `lurk_ipld` is the identifier used as the path root for every codec item
/// the generated code refers to. The method body is produced by
/// `gen_decode_struct` for structs and `gen_decode_union` for unions.
///
/// # Panics
///
/// Panics if `ast` is a struct whose `generics` is `None`.
pub fn gen_decode(ast: &SchemaType, lurk_ipld: &syn::Ident) -> TokenStream {
    let (ident, generics, body) = match ast {
        SchemaType::Struct(def) => {
            let generics = def.generics.as_ref().unwrap();
            (&def.name, generics, gen_decode_struct(def))
        }
        SchemaType::Union(def) => (&def.name, &def.generics, gen_decode_union(def)),
    };
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
    let trait_name = quote!(#lurk_ipld::codec::Decode<#lurk_ipld::cbor::DagCborCodec>);

    // Names the generated bodies rely on; brought into scope inside the method
    // so they cannot clash with items in the user's module.
    let prelude = quote! {
        use #lurk_ipld::cbor::cbor::{MajorKind, NULL};
        use #lurk_ipld::cbor::decode::{read_uint, read_major};
        use #lurk_ipld::cbor::error::{LengthOutOfRange, MissingKey, UnexpectedCode, UnexpectedKey};
        use #lurk_ipld::codec::Decode;
        use #lurk_ipld::error::Result;
        use std::io::SeekFrom;
    };

    quote! {
        impl #impl_generics #trait_name for #ident #ty_generics #where_clause {
            fn decode<R: std::io::Read + std::io::Seek>(
                c: #lurk_ipld::cbor::DagCborCodec,
                r: &mut R,
            ) -> #lurk_ipld::Result<Self> {
                #prelude
                #body
            }
        }
    }
}
/// Returns the wire key for a member as a string-literal token stream.
///
/// An explicit `rename` wins; otherwise the key is the field identifier (named
/// members) or the positional index rendered as text (unnamed members).
fn rename(name: &syn::Member, rename: Option<&String>) -> TokenStream {
    match rename {
        Some(alias) => quote!(#alias),
        None => {
            let text = match name {
                syn::Member::Named(ident) => ident.to_string(),
                syn::Member::Unnamed(index) => index.index.to_string(),
            };
            quote!(#text)
        }
    }
}
/// Guards `tokens` so they only run when `binding` differs from its default.
///
/// With no default expression the tokens are returned untouched. With one, the
/// generated code compares `binding` against a reference to the default value
/// and executes `tokens` only when they are not equal.
fn default(binding: &syn::Ident, default: Option<&syn::Expr>, tokens: TokenStream) -> TokenStream {
    match default {
        None => tokens,
        Some(fallback) => quote! {
            if #binding != &#fallback {
                #tokens
            }
        },
    }
}
/// Wraps the given match arms in `match *self { ... }` followed by `Ok(())`.
///
/// Each arm is emitted with a trailing comma, so arms are expected to be
/// complete `pattern => expression` token sequences.
fn gen_encode_match(arms: impl Iterator<Item = TokenStream>) -> TokenStream {
    let dispatch = quote! {
        match *self {
            #(#arms,)*
        }
    };
    quote! {
        #dispatch
        Ok(())
    }
}
/// Generates the `encode` body for a struct: a single-arm match that
/// destructures `self` with the struct's pattern and runs the encoding
/// statements produced by `gen_encode_struct_body`.
fn gen_encode_struct(s: &Struct) -> TokenStream {
    let pat = &*s.pat;
    let statements = gen_encode_struct_body(s);
    let arm = quote!(#pat => { #statements });
    gen_encode_match(std::iter::once(arm))
}
fn gen_encode_struct_body(s: &Struct) -> TokenStream {
match s.repr {
StructRepr::Map => {
let len = s.fields.len() as u64;
let dfields = s.fields.iter().filter_map(|field| {
if let Some(default) = field.default.as_ref() {
let binding = &field.binding;
Some(quote! {
if #binding == &#default {
len -= 1;
}
})
} else {
None
}
});
let mut cbor_order = s
.fields
.iter()
.map(|field| {
let key = rename(&field.name, field.rename.as_ref());
let binding = &field.binding;
let field = default(
binding,
field.default.as_deref(),
quote! {
Encode::encode(#key, c, w)?;
Encode::encode(#binding, c, w)?;
},
);
(key.to_string(), field)
})
.collect::<Vec<(String, _)>>();
cbor_order.sort_unstable_by(|(key_a, _), (key_b, _)| {
match key_a.len().cmp(&key_b.len()) {
Ordering::Greater => Ordering::Greater,
Ordering::Less => Ordering::Less,
Ordering::Equal => key_a.cmp(key_b),
}
});
let fields = cbor_order.iter().map(|(_, field)| field);
quote! {
let mut len = #len;
#(#dfields)*
write_u64(w, MajorKind::Map, len)?;
#(#fields)*
}
}
StructRepr::Tuple => {
let len = s.fields.len() as u64;
let fields = s.fields.iter().map(|field| {
let binding = &field.binding;
quote! {
Encode::encode(#binding, c, w)?;
}
});
quote! {
write_u64(w, MajorKind::Array, #len)?;
#(#fields)*
}
}
StructRepr::Value => {
assert_eq!(s.fields.len(), 1);
let field = &s.fields[0];
let binding = &field.binding;
default(
binding,
field.default.as_deref(),
quote! {
Encode::encode(#binding, c, w)?;
},
)
}
StructRepr::Null => {
assert_eq!(s.fields.len(), 0);
quote!(write_null(w)?;)
}
}
}
/// Generates the `encode` body for a union, one match arm per variant.
///
/// The arm shape depends on the union's representation:
///
/// * `Keyed`: a one-entry map from the variant key to the variant's encoding.
/// * `Kinded`: the variant's encoding on its own.
/// * `String`: the variant key encoded as the whole value.
/// * `Int`: no match at all; `*self` is cast to `u64` and encoded.
/// * `IntTuple`: a two-element array holding the variant's position followed
///   by the variant's encoding.
///
/// # Panics
///
/// Panics if a `String` or `Int` union contains a variant whose representation
/// is not `Null`, or if `gen_encode_struct_body` rejects a variant.
fn gen_encode_union(u: &Union) -> TokenStream {
    // Every variant is visited eagerly, even for `Int` where the arms end up
    // unused, so that the representation assertions below always run.
    let mut arms = Vec::new();
    for (index, variant) in u.variants.iter().enumerate() {
        let pat = &*variant.pat;
        let key = rename(
            &syn::Member::Named(variant.name.clone()),
            variant.rename.as_ref(),
        );
        let value = gen_encode_struct_body(variant);
        let arm = match u.repr {
            UnionRepr::Keyed => quote! {
                #pat => {
                    write_u8(w, MajorKind::Map, 1)?;
                    Encode::encode(#key, c, w)?;
                    #value
                }
            },
            UnionRepr::Kinded => quote!(#pat => { #value }),
            UnionRepr::String => {
                assert_eq!(variant.repr, StructRepr::Null);
                quote!(#pat => Encode::encode(#key, c, w)?)
            }
            UnionRepr::Int => {
                assert_eq!(variant.repr, StructRepr::Null);
                quote!()
            }
            UnionRepr::IntTuple => quote! {
                #pat => {
                    write_u8(w, MajorKind::Array, 2)?;
                    write_u64(w, MajorKind::UnsignedInt, #index as u64)?;
                    #value
                }
            },
        };
        arms.push(arm);
    }

    match u.repr {
        UnionRepr::Int => quote!(Encode::encode(&(*self as u64), c, w)),
        _ => gen_encode_match(arms.into_iter()),
    }
}
/// Generates the statements that decode a struct and `return Ok(..)` with the
/// constructed value.
///
/// The output depends on the struct's representation:
///
/// * `Map`: requires a map with at most as many entries as the struct has
///   fields. Known keys fill their field, other keys have their value decoded
///   and discarded. Afterwards a missing field takes its default if it has
///   one, otherwise decoding fails with `MissingKey`.
/// * `Tuple`: requires an array whose length equals the field count, then
///   decodes the fields in declaration order.
/// * `Value`: decodes the single field directly.
/// * `Null`: requires the major byte to equal `NULL`.
///
/// A wrong major kind yields `UnexpectedCode`; a bad length yields
/// `LengthOutOfRange`.
///
/// # Panics
///
/// Panics if a `Value` struct does not have exactly one field or a `Null`
/// struct has any fields.
fn gen_decode_struct(s: &Struct) -> TokenStream {
    let len = s.fields.len() as u64;
    let construct = &*s.construct;
    match s.repr {
        StructRepr::Map => {
            let bindings: Vec<_> = s.fields.iter().map(|field| &field.binding).collect();
            let keys: Vec<_> = s
                .fields
                .iter()
                .map(|field| rename(&field.name, field.rename.as_ref()))
                .collect();
            // After the map has been read, turn every `Option` slot into the
            // final field value.
            let finalize: Vec<_> = s
                .fields
                .iter()
                .map(|field| {
                    let binding = &field.binding;
                    match field.default.as_ref() {
                        Some(default) => quote!(let #binding = #binding.unwrap_or(#default);),
                        None => {
                            let key = rename(&field.name, field.rename.as_ref());
                            // Build the error lazily so the success path does
                            // not construct a `MissingKey` per required field.
                            quote! {
                                let #binding = #binding
                                    .ok_or_else(|| MissingKey::new::<Self>(#key))?;
                            }
                        }
                    }
                })
                .collect();
            // NOTE(review): in the catch-all arm below the decoded type is left
            // to inference. Presumably this resolves to `()` through the
            // pre-2024 never-type fallback, in which case unknown keys are only
            // tolerated when `()` can decode their value — confirm, and
            // consider naming an explicit skip type.
            quote! {
                let major = read_major(r)?;
                match major.kind() {
                    MajorKind::Map => {
                        let len = read_uint(r, major)?;
                        if len > #len {
                            return Err(LengthOutOfRange::new::<Self>().into());
                        }
                        #(let mut #bindings = None;)*
                        for _ in 0..len {
                            let key: String = Decode::decode(c, r)?;
                            match key.as_str() {
                                #(#keys => { #bindings = Some(Decode::decode(c, r)?); })*
                                _ => {
                                    Decode::decode(c, r)?;
                                }
                            }
                        }
                        #(#finalize)*
                        return Ok(#construct);
                    }
                    _ => {
                        return Err(UnexpectedCode::new::<Self>(major.into()).into());
                    }
                }
            }
        }
        StructRepr::Tuple => {
            let fields = s.fields.iter().map(|field| {
                let binding = &field.binding;
                quote! {
                    let #binding = Decode::decode(c, r)?;
                }
            });
            quote! {
                let major = read_major(r)?;
                match major.kind() {
                    MajorKind::Array => {
                        let len = read_uint(r, major)?;
                        if len != #len {
                            return Err(LengthOutOfRange::new::<Self>().into());
                        }
                        #(#fields)*
                        return Ok(#construct);
                    }
                    _ => {
                        return Err(UnexpectedCode::new::<Self>(major.into()).into());
                    }
                }
            }
        }
        StructRepr::Value => {
            assert_eq!(s.fields.len(), 1);
            let binding = &s.fields[0].binding;
            quote! {
                let #binding = Decode::decode(c, r)?;
                return Ok(#construct);
            }
        }
        StructRepr::Null => {
            assert_eq!(s.fields.len(), 0);
            quote! {
                let major = read_major(r)?;
                match major {
                    NULL => {
                        return Ok(#construct);
                    }
                    _ => {
                        return Err(UnexpectedCode::new::<Self>(major.into()).into());
                    }
                }
            }
        }
    }
}
fn gen_decode_union(u: &Union) -> TokenStream {
match u.repr {
UnionRepr::Keyed => {
let variants = u.variants.iter().map(|s| {
let key = rename(&syn::Member::Named(s.name.clone()), s.rename.as_ref());
let parse = gen_decode_struct(s);
quote! {
if key.as_str() == #key {
#parse
}
}
});
quote! {
let major = read_major(r)?;
if major.kind() != MajorKind::Map {
return Err(UnexpectedCode::new::<Self>(major.into()).into());
} else if read_uint(r, major)? != 1 {
return Err(LengthOutOfRange::new::<Self>().into());
}
let key: String = Decode::decode(c, r)?;
#(#variants;)*
Err(UnexpectedKey::new::<Self>(key).into())
}
}
UnionRepr::Kinded => {
let variants = u.variants.iter().map(|s| {
let parse = gen_decode_struct(s);
quote! {
let pos = r.seek(SeekFrom::Current(0))?;
let result: Result<Self> = (|| {
#parse
})();
match result {
Ok(res) => return Ok(res),
Err(err) => {
r.seek(SeekFrom::Start(pos))?;
}
};
}
});
quote! {
#(#variants;)*
Err(UnexpectedCode::new::<Self>(read_major(r)?.into()).into())
}
}
UnionRepr::String => {
let arms = u.variants.iter().map(|v| {
let pat = &*v.pat;
let value = rename(&syn::Member::Named(v.name.clone()), v.rename.as_ref());
quote!(#value => #pat)
});
quote! {
let key: String = Decode::decode(c, r)?;
let res = match key.as_str() {
#(#arms,)*
_ => return Err(UnexpectedKey::new::<Self>(key).into()),
};
Ok(res)
}
}
UnionRepr::Int => {
let arms = u.variants.iter().map(|v| {
let pat = &*v.pat;
quote!(x if x == #pat as u64 => #pat)
});
quote! {
let key: u64 = Decode::decode(c, r)?;
let res = match key {
#(#arms,)*
_ => return Err(UnexpectedKey::new::<Self>(key.to_string()).into()),
};
Ok(res)
}
}
UnionRepr::IntTuple => {
let variants = u.variants.iter().enumerate().map(|(i, s)| {
let i = i as u64;
let parse = gen_decode_struct(s);
quote!(#i => { #parse })
});
quote! {
let major = read_major(r)?;
if major.kind() != MajorKind::Array {
return Err(UnexpectedCode::new::<Self>(major.into()).into());
}
if read_uint(r, major)? != 2 {
return Err(LengthOutOfRange::new::<Self>().into());
}
let ty: u64 = Decode::decode(c, r)?;
match ty {
#(#variants,)*
_ => return Err(UnexpectedKey::new::<Self>(ty.to_string()).into()),
}
}
}
}
}