use proc_macro::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Fields, Ident, ImplItem, ItemImpl, Type, parse_macro_input};
/// Derives a hidden accessor that hands out `&mut` access to the `#[next]` field.
///
/// For a struct such as
///
/// ```ignore
/// #[derive(Interceptor)]
/// struct Logger<P> {
///     #[next]
///     inner: P,
/// }
/// ```
///
/// this emits an inherent `fn __interceptor_inner_mut(&mut self) -> &mut P`.
/// The pass-through methods generated by `#[interceptor]` call it to forward
/// work to the next interceptor in the chain.
///
/// Any problem locating the `#[next]` field (see `find_next_field`) is turned
/// into a compile error at the derive site.
#[proc_macro_derive(Interceptor, attributes(next))]
pub fn derive_interceptor(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);

    match find_next_field(&ast) {
        Err(err) => err.into_compile_error().into(),
        Ok((field_ident, field_ty)) => {
            let struct_ident = &ast.ident;
            let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();

            quote! {
                impl #impl_generics #struct_ident #ty_generics #where_clause {
                    #[doc(hidden)]
                    #[inline(always)]
                    fn __interceptor_inner_mut(&mut self) -> &mut #field_ty {
                        &mut self.#field_ident
                    }
                }
            }
            .into()
        }
    }
}
#[proc_macro_attribute]
pub fn interceptor(_attr: TokenStream, item: TokenStream) -> TokenStream {
let mut input = parse_macro_input!(item as ItemImpl);
let mut override_methods: Vec<Ident> = Vec::new();
for item in &mut input.items {
if let ImplItem::Fn(method) = item {
let has_override = method
.attrs
.iter()
.any(|attr| attr.path().is_ident("overrides"));
if has_override {
override_methods.push(method.sig.ident.clone());
method
.attrs
.retain(|attr| !attr.path().is_ident("overrides"));
}
}
}
let self_ty = &input.self_ty;
let generics = &input.generics;
let where_clause = &generics.where_clause;
let (impl_generics, _, _) = generics.split_for_impl();
let protocol_methods = generate_protocol_methods(&override_methods);
let interceptor_methods = generate_interceptor_methods(&override_methods);
let protocol_method_names = [
"handle_read",
"poll_read",
"handle_write",
"poll_write",
"handle_event",
"poll_event",
"handle_timeout",
"poll_timeout",
"close",
];
let interceptor_method_names = [
"bind_local_stream",
"unbind_local_stream",
"bind_remote_stream",
"unbind_remote_stream",
];
let protocol_override_items: Vec<_> = input
.items
.iter()
.filter(|item| {
if let ImplItem::Fn(method) = item {
let name = method.sig.ident.to_string();
override_methods.contains(&method.sig.ident)
&& protocol_method_names.contains(&name.as_str())
} else {
false
}
})
.collect();
let interceptor_override_items: Vec<_> = input
.items
.iter()
.filter(|item| {
if let ImplItem::Fn(method) = item {
let name = method.sig.ident.to_string();
override_methods.contains(&method.sig.ident)
&& interceptor_method_names.contains(&name.as_str())
} else {
false
}
})
.collect();
let expanded = quote! {
impl #impl_generics sansio::Protocol<
TaggedPacket,
TaggedPacket,
()
> for #self_ty #where_clause {
type Rout = TaggedPacket;
type Wout = TaggedPacket;
type Eout = ();
type Error = Error;
type Time = std::time::Instant;
#protocol_methods
#(#protocol_override_items)*
}
impl #impl_generics Interceptor for #self_ty #where_clause {
#interceptor_methods
#(#interceptor_override_items)*
}
};
TokenStream::from(expanded)
}
/// Locates the single field marked `#[next]` on a struct with named fields.
///
/// Returns that field's identifier and type. `#[derive(Interceptor)]` uses
/// them to build the `__interceptor_inner_mut` accessor.
///
/// # Errors
///
/// Returns a spanned `syn::Error` when `input`:
/// - is not a struct,
/// - is a tuple or unit struct,
/// - has no field marked `#[next]`, or
/// - has more than one field marked `#[next]` (previously the first one was
///   silently chosen and the others ignored).
fn find_next_field(input: &DeriveInput) -> syn::Result<(Ident, Type)> {
    let fields = match &input.data {
        Data::Struct(data) => &data.fields,
        _ => {
            return Err(syn::Error::new_spanned(
                input,
                "Interceptor can only be derived for structs",
            ));
        }
    };
    let named_fields = match fields {
        Fields::Named(fields) => &fields.named,
        _ => {
            return Err(syn::Error::new_spanned(
                input,
                "Interceptor can only be derived for structs with named fields",
            ));
        }
    };

    let mut marked = named_fields
        .iter()
        .filter(|field| field.attrs.iter().any(|attr| attr.path().is_ident("next")));

    let field = match marked.next() {
        Some(field) => field,
        None => {
            return Err(syn::Error::new_spanned(
                input,
                "No field marked with #[next] attribute. Mark the next interceptor field with #[next].",
            ));
        }
    };

    // The accessor can only forward to one inner interceptor, so an extra
    // `#[next]` is ambiguous rather than something to ignore.
    if let Some(duplicate) = marked.next() {
        return Err(syn::Error::new_spanned(
            duplicate,
            "Only one field may be marked with #[next]",
        ));
    }

    let ident = field
        .ident
        .clone()
        .ok_or_else(|| syn::Error::new_spanned(field, "Field must have a name"))?;
    Ok((ident, field.ty.clone()))
}
/// Builds pass-through implementations for every `sansio::Protocol` method
/// whose name does not appear in `override_methods`.
///
/// Each generated method forwards its argument (if any) unchanged to the
/// method of the same name on `self.__interceptor_inner_mut()`. Methods are
/// emitted in the fixed order of the table below.
fn generate_protocol_methods(override_methods: &[Ident]) -> proc_macro2::TokenStream {
    let is_overridden = |method: &str| override_methods.iter().any(|m| m == method);

    let delegates = [
        (
            "handle_read",
            quote! {
                fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
                    self.__interceptor_inner_mut().handle_read(msg)
                }
            },
        ),
        (
            "poll_read",
            quote! {
                fn poll_read(&mut self) -> Option<Self::Rout> {
                    self.__interceptor_inner_mut().poll_read()
                }
            },
        ),
        (
            "handle_write",
            quote! {
                fn handle_write(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
                    self.__interceptor_inner_mut().handle_write(msg)
                }
            },
        ),
        (
            "poll_write",
            quote! {
                fn poll_write(&mut self) -> Option<Self::Wout> {
                    self.__interceptor_inner_mut().poll_write()
                }
            },
        ),
        (
            "handle_event",
            quote! {
                fn handle_event(&mut self, evt: ()) -> Result<(), Self::Error> {
                    self.__interceptor_inner_mut().handle_event(evt)
                }
            },
        ),
        (
            "poll_event",
            quote! {
                fn poll_event(&mut self) -> Option<Self::Eout> {
                    self.__interceptor_inner_mut().poll_event()
                }
            },
        ),
        (
            "handle_timeout",
            quote! {
                fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
                    self.__interceptor_inner_mut().handle_timeout(now)
                }
            },
        ),
        (
            "poll_timeout",
            quote! {
                fn poll_timeout(&mut self) -> Option<Self::Time> {
                    self.__interceptor_inner_mut().poll_timeout()
                }
            },
        ),
        (
            "close",
            quote! {
                fn close(&mut self) -> Result<(), Self::Error> {
                    self.__interceptor_inner_mut().close()
                }
            },
        ),
    ];

    let mut methods = proc_macro2::TokenStream::new();
    for (method, tokens) in delegates {
        if !is_overridden(method) {
            methods.extend(tokens);
        }
    }
    methods
}
/// Builds pass-through implementations for every `Interceptor` stream-binding
/// method whose name does not appear in `override_methods`.
///
/// Each generated method forwards its `&StreamInfo` argument to the method of
/// the same name on `self.__interceptor_inner_mut()`. Methods are emitted in
/// the fixed order of the table below.
fn generate_interceptor_methods(override_methods: &[Ident]) -> proc_macro2::TokenStream {
    let is_overridden = |method: &str| override_methods.iter().any(|m| m == method);

    let delegates = [
        (
            "bind_local_stream",
            quote! {
                fn bind_local_stream(&mut self, info: &StreamInfo) {
                    self.__interceptor_inner_mut().bind_local_stream(info);
                }
            },
        ),
        (
            "unbind_local_stream",
            quote! {
                fn unbind_local_stream(&mut self, info: &StreamInfo) {
                    self.__interceptor_inner_mut().unbind_local_stream(info);
                }
            },
        ),
        (
            "bind_remote_stream",
            quote! {
                fn bind_remote_stream(&mut self, info: &StreamInfo) {
                    self.__interceptor_inner_mut().bind_remote_stream(info);
                }
            },
        ),
        (
            "unbind_remote_stream",
            quote! {
                fn unbind_remote_stream(&mut self, info: &StreamInfo) {
                    self.__interceptor_inner_mut().unbind_remote_stream(info);
                }
            },
        ),
    ];

    let mut methods = proc_macro2::TokenStream::new();
    for (method, tokens) in delegates {
        if !is_overridden(method) {
            methods.extend(tokens);
        }
    }
    methods
}