rtc_interceptor_derive/
lib.rs

1//! Derive macros for RTC Interceptor trait.
2//!
3//! This crate provides two macros that work together:
4//!
5//! - `#[derive(Interceptor)]` - Marks a struct as an interceptor and identifies the next field
6//! - `#[interceptor]` - Attribute macro for impl blocks to generate trait implementations
7//!
8//! # Design Pattern
9//!
10//! The design follows Rust's derive pattern (similar to `#[derive(Default)]` with `#[default]`):
11//!
12//! ```ignore
13//! use rtc_interceptor::{Interceptor, interceptor, TaggedPacket, Packet, StreamInfo};
14//! use std::collections::VecDeque;
15//!
16//! #[derive(Interceptor)]
17//! pub struct MyInterceptor<P: Interceptor> {
18//!     #[next]
19//!     next: P,  // The next interceptor in the chain (can use any field name)
20//!     buffer: VecDeque<TaggedPacket>,
21//! }
22//!
23//! #[interceptor]
24//! impl<P: Interceptor> MyInterceptor<P> {
25//!     #[overrides]
26//!     fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
27//!         // Custom logic here
28//!         self.next.handle_read(msg)
29//!     }
30//! }
31//! ```
32//!
33//! # Pure Delegation (No Custom Logic)
34//!
35//! For interceptors that just pass through without modification:
36//!
37//! ```ignore
38//! #[derive(Interceptor)]
39//! pub struct PassthroughInterceptor<P: Interceptor> {
40//!     #[next]
41//!     next: P,
42//! }
43//!
44//! #[interceptor]
45//! impl<P: Interceptor> PassthroughInterceptor<P> {}
46//! // Empty impl block - all methods are auto-generated
47//! ```
48//!
49//! # Required Imports
50//!
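//! Code generated by `#[interceptor]` names `sansio`, `Interceptor`, `TaggedPacket`,
//! `StreamInfo`, and `Error` without any path prefix, so all of them must be in scope
//! wherever the macro is used:
//!
//! ```ignore
//! use rtc_interceptor::{Interceptor, interceptor, TaggedPacket, Packet, StreamInfo};
//! // ...plus `Error` and `sansio` from wherever your project exposes them.
//! // Or, through the `rtc` umbrella crate:
//! use rtc::interceptor::{Interceptor, interceptor, TaggedPacket, Packet, StreamInfo};
//! use rtc::shared::error::Error;
//! use rtc::sansio;  // Required for macro-generated code
//! ```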
60
61use proc_macro::TokenStream;
62use quote::quote;
63use syn::{Data, DeriveInput, Fields, Ident, ImplItem, ItemImpl, Type, parse_macro_input};
64
65/// Derive macro that marks a struct as an interceptor.
66///
67/// This macro validates the struct has a `#[next]` field and generates
68/// a hidden accessor method. It does NOT generate Protocol/Interceptor implementations -
69/// those are generated by the `#[interceptor]` attribute on the impl block.
70///
71/// # Attributes
72///
73/// - `#[next]` - Mark the field that contains the next interceptor in the chain (required)
74///
75/// # Examples
76///
77/// Pure delegation (no custom logic):
78/// ```ignore
79/// #[derive(Interceptor)]
80/// pub struct PassthroughInterceptor<P: Interceptor> {
81///     #[next]
82///     next: P,
83/// }
84///
85/// #[interceptor]
86/// impl<P: Interceptor> PassthroughInterceptor<P> {}
87/// ```
88///
89/// With custom logic:
90/// ```ignore
91/// #[derive(Interceptor)]
92/// pub struct MyInterceptor<P: Interceptor> {
93///     #[next]
94///     next: P,
95///     buffer: VecDeque<TaggedPacket>,
96/// }
97///
98/// #[interceptor]
99/// impl<P: Interceptor> MyInterceptor<P> {
100///     #[overrides]
101///     fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
102///         // Custom logic
103///         self.next.handle_read(msg)
104///     }
105/// }
106/// ```
107#[proc_macro_derive(Interceptor, attributes(next))]
108pub fn derive_interceptor(input: TokenStream) -> TokenStream {
109    let input = parse_macro_input!(input as DeriveInput);
110
111    // Find the next field marked with #[next] - validates it exists and gets its type
112    let (next_name, next_type) = match find_next_field(&input) {
113        Ok(field) => field,
114        Err(err) => return err.into_compile_error().into(),
115    };
116
117    let name = &input.ident;
118    let generics = &input.generics;
119    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
120
121    // Generate hidden accessor method that #[interceptor] will use
122    // This allows #[interceptor] to work without knowing the field name
123    let expanded = quote! {
124        impl #impl_generics #name #ty_generics #where_clause {
125            /// Hidden accessor for the next interceptor (used by #[interceptor] macro)
126            #[doc(hidden)]
127            #[inline(always)]
128            fn __interceptor_inner_mut(&mut self) -> &mut #next_type {
129                &mut self.#next_name
130            }
131        }
132    };
133
134    TokenStream::from(expanded)
135}
136
137/// Attribute macro for impl blocks to generate Protocol and Interceptor implementations.
138///
139/// This macro generates the trait implementations, delegating non-overridden
140/// methods to the next interceptor field (identified by `#[next]` in the struct).
141///
142/// **Important:** The struct must have `#[derive(Interceptor)]` with a `#[next]` field.
143///
144/// # Attributes
145///
146/// - `#[overrides]` - Mark methods that provide custom implementations
147///
148/// # Examples
149///
150/// With custom logic:
151/// ```ignore
152/// #[derive(Interceptor)]
153/// pub struct MyInterceptor<P: Interceptor> {
154///     #[next]
155///     next: P,  // Can use any field name
156///     buffer: VecDeque<TaggedPacket>,
157/// }
158///
159/// #[interceptor]
160/// impl<P: Interceptor> MyInterceptor<P> {
161///     #[overrides]
162///     fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
163///         // Custom logic
164///         self.next.handle_read(msg)
165///     }
166/// }
167/// ```
168///
169/// Pure delegation (no custom logic):
170/// ```ignore
171/// #[derive(Interceptor)]
172/// pub struct PassthroughInterceptor<P: Interceptor> {
173///     #[next]
174///     wrapped: P,  // Can use any field name
175/// }
176///
177/// #[interceptor]
178/// impl<P: Interceptor> PassthroughInterceptor<P> {}
179/// // Empty impl - all methods delegate to wrapped field
180/// ```
181#[proc_macro_attribute]
182pub fn interceptor(_attr: TokenStream, item: TokenStream) -> TokenStream {
183    let mut input = parse_macro_input!(item as ItemImpl);
184
185    // Note: Field name is no longer needed here - we use __interceptor_inner_mut() accessor
186    // which is generated by #[derive(Interceptor)]
187
188    // Collect method names marked with #[overrides]
189    let mut override_methods: Vec<Ident> = Vec::new();
190
191    for item in &mut input.items {
192        if let ImplItem::Fn(method) = item {
193            // Check if method has #[overrides] attribute
194            let has_override = method
195                .attrs
196                .iter()
197                .any(|attr| attr.path().is_ident("overrides"));
198
199            if has_override {
200                override_methods.push(method.sig.ident.clone());
201                // Remove the #[overrides] attribute
202                method
203                    .attrs
204                    .retain(|attr| !attr.path().is_ident("overrides"));
205            }
206        }
207    }
208
209    // Get type name and generics from the impl
210    let self_ty = &input.self_ty;
211    let generics = &input.generics;
212    let where_clause = &generics.where_clause;
213    let (impl_generics, _, _) = generics.split_for_impl();
214
215    // Generate Protocol methods that are NOT overridden (using accessor method)
216    let protocol_methods = generate_protocol_methods(&override_methods);
217    let interceptor_methods = generate_interceptor_methods(&override_methods);
218
219    // Protocol method names
220    let protocol_method_names = [
221        "handle_read",
222        "poll_read",
223        "handle_write",
224        "poll_write",
225        "handle_event",
226        "poll_event",
227        "handle_timeout",
228        "poll_timeout",
229        "close",
230    ];
231
232    // Interceptor method names
233    let interceptor_method_names = [
234        "bind_local_stream",
235        "unbind_local_stream",
236        "bind_remote_stream",
237        "unbind_remote_stream",
238    ];
239
240    // Extract Protocol overridden methods
241    let protocol_override_items: Vec<_> = input
242        .items
243        .iter()
244        .filter(|item| {
245            if let ImplItem::Fn(method) = item {
246                let name = method.sig.ident.to_string();
247                override_methods.contains(&method.sig.ident)
248                    && protocol_method_names.contains(&name.as_str())
249            } else {
250                false
251            }
252        })
253        .collect();
254
255    // Extract Interceptor overridden methods
256    let interceptor_override_items: Vec<_> = input
257        .items
258        .iter()
259        .filter(|item| {
260            if let ImplItem::Fn(method) = item {
261                let name = method.sig.ident.to_string();
262                override_methods.contains(&method.sig.ident)
263                    && interceptor_method_names.contains(&name.as_str())
264            } else {
265                false
266            }
267        })
268        .collect();
269
270    let expanded = quote! {
271        impl #impl_generics sansio::Protocol<
272            TaggedPacket,
273            TaggedPacket,
274            ()
275        > for #self_ty #where_clause {
276            type Rout = TaggedPacket;
277            type Wout = TaggedPacket;
278            type Eout = ();
279            type Error = Error;
280            type Time = std::time::Instant;
281
282            #protocol_methods
283            #(#protocol_override_items)*
284        }
285
286        impl #impl_generics Interceptor for #self_ty #where_clause {
287            #interceptor_methods
288            #(#interceptor_override_items)*
289        }
290    };
291
292    TokenStream::from(expanded)
293}
294
295/// Find the field marked with #[next] attribute, returning both name and type
296fn find_next_field(input: &DeriveInput) -> syn::Result<(Ident, Type)> {
297    let fields = match &input.data {
298        Data::Struct(data) => &data.fields,
299        _ => {
300            return Err(syn::Error::new_spanned(
301                input,
302                "Interceptor can only be derived for structs",
303            ));
304        }
305    };
306
307    let named_fields = match fields {
308        Fields::Named(fields) => &fields.named,
309        _ => {
310            return Err(syn::Error::new_spanned(
311                input,
312                "Interceptor can only be derived for structs with named fields",
313            ));
314        }
315    };
316
317    for field in named_fields {
318        let has_next_attr = field.attrs.iter().any(|attr| attr.path().is_ident("next"));
319        if has_next_attr {
320            let ident = field
321                .ident
322                .clone()
323                .ok_or_else(|| syn::Error::new_spanned(field, "Field must have a name"))?;
324            let ty = field.ty.clone();
325            return Ok((ident, ty));
326        }
327    }
328
329    Err(syn::Error::new_spanned(
330        input,
331        "No field marked with #[next] attribute. Mark the next interceptor field with #[next].",
332    ))
333}
334
335/// Generate Protocol methods that delegate to inner, excluding overridden ones
336fn generate_protocol_methods(override_methods: &[Ident]) -> proc_macro2::TokenStream {
337    let mut methods = proc_macro2::TokenStream::new();
338
339    if !override_methods.iter().any(|m| m == "handle_read") {
340        methods.extend(quote! {
341            fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
342                self.__interceptor_inner_mut().handle_read(msg)
343            }
344        });
345    }
346
347    if !override_methods.iter().any(|m| m == "poll_read") {
348        methods.extend(quote! {
349            fn poll_read(&mut self) -> Option<Self::Rout> {
350                self.__interceptor_inner_mut().poll_read()
351            }
352        });
353    }
354
355    if !override_methods.iter().any(|m| m == "handle_write") {
356        methods.extend(quote! {
357            fn handle_write(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
358                self.__interceptor_inner_mut().handle_write(msg)
359            }
360        });
361    }
362
363    if !override_methods.iter().any(|m| m == "poll_write") {
364        methods.extend(quote! {
365            fn poll_write(&mut self) -> Option<Self::Wout> {
366                self.__interceptor_inner_mut().poll_write()
367            }
368        });
369    }
370
371    if !override_methods.iter().any(|m| m == "handle_event") {
372        methods.extend(quote! {
373            fn handle_event(&mut self, evt: ()) -> Result<(), Self::Error> {
374                self.__interceptor_inner_mut().handle_event(evt)
375            }
376        });
377    }
378
379    if !override_methods.iter().any(|m| m == "poll_event") {
380        methods.extend(quote! {
381            fn poll_event(&mut self) -> Option<Self::Eout> {
382                self.__interceptor_inner_mut().poll_event()
383            }
384        });
385    }
386
387    if !override_methods.iter().any(|m| m == "handle_timeout") {
388        methods.extend(quote! {
389            fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
390                self.__interceptor_inner_mut().handle_timeout(now)
391            }
392        });
393    }
394
395    if !override_methods.iter().any(|m| m == "poll_timeout") {
396        methods.extend(quote! {
397            fn poll_timeout(&mut self) -> Option<Self::Time> {
398                self.__interceptor_inner_mut().poll_timeout()
399            }
400        });
401    }
402
403    if !override_methods.iter().any(|m| m == "close") {
404        methods.extend(quote! {
405            fn close(&mut self) -> Result<(), Self::Error> {
406                self.__interceptor_inner_mut().close()
407            }
408        });
409    }
410
411    methods
412}
413
414/// Generate Interceptor methods that delegate to inner, excluding overridden ones
415fn generate_interceptor_methods(override_methods: &[Ident]) -> proc_macro2::TokenStream {
416    let mut methods = proc_macro2::TokenStream::new();
417
418    if !override_methods.iter().any(|m| m == "bind_local_stream") {
419        methods.extend(quote! {
420            fn bind_local_stream(&mut self, info: &StreamInfo) {
421                self.__interceptor_inner_mut().bind_local_stream(info);
422            }
423        });
424    }
425
426    if !override_methods.iter().any(|m| m == "unbind_local_stream") {
427        methods.extend(quote! {
428            fn unbind_local_stream(&mut self, info: &StreamInfo) {
429                self.__interceptor_inner_mut().unbind_local_stream(info);
430            }
431        });
432    }
433
434    if !override_methods.iter().any(|m| m == "bind_remote_stream") {
435        methods.extend(quote! {
436            fn bind_remote_stream(&mut self, info: &StreamInfo) {
437                self.__interceptor_inner_mut().bind_remote_stream(info);
438            }
439        });
440    }
441
442    if !override_methods.iter().any(|m| m == "unbind_remote_stream") {
443        methods.extend(quote! {
444            fn unbind_remote_stream(&mut self, info: &StreamInfo) {
445                self.__interceptor_inner_mut().unbind_remote_stream(info);
446            }
447        });
448    }
449
450    methods
451}