use quote::quote;
use syn::parse::Parser;
use syn::spanned::Spanned;
use crate::add_angle_bracketed;
fn add_io_generic(generics: &mut syn::Generics) -> syn::parse::Result<()> {
let io_predicate = generics.params.iter().position(|p| {
if let syn::GenericParam::Type(tp) = p {
tp.ident == "IO"
} else {
false
}
});
if io_predicate.is_none() {
let ident = syn::Ident::new("IO", proc_macro2::Span::call_site());
generics
.params
.push(syn::GenericParam::Type(syn::TypeParam::from(ident)));
}
generics.make_where_clause();
if let Some(ref mut where_clause) = generics.where_clause {
let io_where: syn::WherePredicate = syn::parse2(
quote! { IO: 'static + tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + std::marker::Unpin + std::marker::Send },
)?;
where_clause.predicates.push(io_where);
} else {
return Err(syn::Error::new(
generics.span(),
"wrong where clause on the struct",
));
}
Ok(())
}
fn generate_struct(mut item_struct: syn::ItemStruct) -> syn::parse::Result<syn::ItemStruct> {
if let syn::Fields::Named(ref mut fields) = item_struct.fields {
fields
.named
.push(syn::Field::parse_named.parse2(quote! { stream: IO })?);
fields.named.push(
syn::Field::parse_named
.parse2(quote! { addr: std::option::Option<prosa::io::SocketAddr> })?,
);
fields
.named
.push(syn::Field::parse_named.parse2(quote! { buffer: bytes::BytesMut })?);
fields
.named
.push(syn::Field::parse_named.parse2(quote! { socket_id: u32 })?);
}
add_io_generic(&mut item_struct.generics)?;
Ok(item_struct)
}
fn generate_struct_impl(
item_struct: &syn::ItemStruct,
) -> syn::parse::Result<proc_macro2::TokenStream> {
let item_ident = &item_struct.ident;
let item_generics = &item_struct.generics;
let item_other_fields =
if let syn::Fields::Named(syn::FieldsNamed { named, .. }) = &item_struct.fields {
let token_other_fields =
named.iter().filter_map(|f| {
if let Some(ident) = f.ident.as_ref().filter(|&i| {
i != "stream" && i != "addr" && i != "buffer" && i != "socket_id"
}) {
return Some(quote! { #ident: std::default::Default::default() });
}
None
});
quote! { #(#token_other_fields,)* }
} else {
proc_macro2::TokenStream::new()
};
Ok(quote! {
impl #item_generics std::convert::From<IO> for #item_ident #item_generics
where
IO: 'static + tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + std::os::fd::AsRawFd + std::marker::Unpin + std::marker::Send
{
fn from(stream: IO) -> Self {
let socket_id = stream.as_raw_fd() as u32;
#item_ident {
stream,
addr: None,
buffer: bytes::BytesMut::with_capacity(16384),
socket_id,
#item_other_fields
}
}
}
impl #item_generics std::convert::From<(IO, prosa::io::SocketAddr)> for #item_ident #item_generics
where
IO: 'static + tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + std::os::fd::AsRawFd + std::marker::Unpin + std::marker::Send
{
fn from(socket: (IO, prosa::io::SocketAddr)) -> Self {
let (stream, addr) = socket;
let socket_id = stream.as_raw_fd() as u32;
#item_ident {
stream,
addr: Some(addr),
buffer: bytes::BytesMut::with_capacity(16384),
socket_id,
#item_other_fields
}
}
}
})
}
fn add_struct_impl(mut item_impl: syn::ItemImpl) -> syn::parse::Result<syn::ItemImpl> {
add_io_generic(&mut item_impl.generics)?;
if let syn::Type::Path(syn::TypePath {
path: syn::Path {
ref mut segments, ..
},
..
}) = *item_impl.self_ty
{
if let Some(segment) = segments.first_mut() {
if let syn::PathArguments::AngleBracketed(ref mut ab) = segment.arguments {
add_angle_bracketed("IO", ab)?;
} else {
let io_arg: syn::AngleBracketedGenericArguments = syn::parse2(quote! { <IO> })?;
segment.arguments = syn::PathArguments::AngleBracketed(io_arg);
}
} else {
return Err(syn::Error::new(
segments.span(),
"no path segment for the impl",
));
}
} else {
return Err(syn::Error::new(
item_impl.self_ty.span(),
"wrong path for impl",
));
}
Ok(item_impl)
}
pub(crate) fn io_impl(item: syn::Item) -> syn::parse::Result<proc_macro2::TokenStream> {
match item {
syn::Item::Struct(item_struct) => {
let struct_output = generate_struct(item_struct)?;
let struct_impl = generate_struct_impl(&struct_output)?;
Ok(quote! {
#struct_output
#struct_impl
})
}
syn::Item::Impl(item_impl) => {
let impl_output = add_struct_impl(item_impl)?;
Ok(quote! {
#impl_output
})
}
_ => Err(syn::Error::new(
proc_macro2::Span::call_site(),
"expected struct or impl expression",
)),
}
}