rtc-interceptor-derive 0.20.0-alpha.1

Derive macros for RTC Interceptor trait
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
//! Derive macros for RTC Interceptor trait.
//!
//! This crate provides two macros that work together:
//!
//! - `#[derive(Interceptor)]` - Marks a struct as an interceptor and identifies the next field
//! - `#[interceptor]` - Attribute macro for impl blocks to generate trait implementations
//!
//! # Design Pattern
//!
//! The design follows Rust's derive pattern (similar to `#[derive(Default)]` with `#[default]`):
//!
//! ```ignore
//! use rtc_interceptor::{Interceptor, interceptor, TaggedPacket, Packet, StreamInfo};
//! use std::collections::VecDeque;
//!
//! #[derive(Interceptor)]
//! pub struct MyInterceptor<P: Interceptor> {
//!     #[next]
//!     next: P,  // The next interceptor in the chain (can use any field name)
//!     buffer: VecDeque<TaggedPacket>,
//! }
//!
//! #[interceptor]
//! impl<P: Interceptor> MyInterceptor<P> {
//!     #[overrides]
//!     fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
//!         // Custom logic here
//!         self.next.handle_read(msg)
//!     }
//! }
//! ```
//!
//! # Pure Delegation (No Custom Logic)
//!
//! For interceptors that just pass through without modification:
//!
//! ```ignore
//! #[derive(Interceptor)]
//! pub struct PassthroughInterceptor<P: Interceptor> {
//!     #[next]
//!     next: P,
//! }
//!
//! #[interceptor]
//! impl<P: Interceptor> PassthroughInterceptor<P> {}
//! // Empty impl block - all methods are auto-generated
//! ```
//!
//! # Required Imports
//!
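//! The generated code refers to `Interceptor`, `TaggedPacket`, `StreamInfo`, `Error`, and
//! `sansio` by unqualified path, so all of them must be in scope wherever the macros are used: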
//!
//! ```ignore
//! use rtc_interceptor::{Interceptor, interceptor, TaggedPacket, Packet, StreamInfo};
//! // Or through rtc umbrella crate:
//! use rtc::interceptor::{Interceptor, interceptor, TaggedPacket, Packet, StreamInfo};
//! use rtc::shared::error::Error;
//! use rtc::sansio;  // Required for macro-generated code
//! ```

use proc_macro::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Fields, Ident, ImplItem, ItemImpl, Type, parse_macro_input};

/// Derive macro that marks a struct as an interceptor.
///
/// This macro validates the struct has a `#[next]` field and generates
/// a hidden accessor method. It does NOT generate Protocol/Interceptor implementations -
/// those are generated by the `#[interceptor]` attribute on the impl block.
///
/// # Attributes
///
/// - `#[next]` - Mark the field that contains the next interceptor in the chain (required)
///
/// # Errors
///
/// Emits a compile error (instead of any impl) when the input is not a struct with
/// named fields that has a field marked with `#[next]`.
///
/// # Examples
///
/// Pure delegation (no custom logic):
/// ```ignore
/// #[derive(Interceptor)]
/// pub struct PassthroughInterceptor<P: Interceptor> {
///     #[next]
///     next: P,
/// }
///
/// #[interceptor]
/// impl<P: Interceptor> PassthroughInterceptor<P> {}
/// ```
///
/// With custom logic:
/// ```ignore
/// #[derive(Interceptor)]
/// pub struct MyInterceptor<P: Interceptor> {
///     #[next]
///     next: P,
///     buffer: VecDeque<TaggedPacket>,
/// }
///
/// #[interceptor]
/// impl<P: Interceptor> MyInterceptor<P> {
///     #[overrides]
///     fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
///         // Custom logic
///         self.next.handle_read(msg)
///     }
/// }
/// ```
#[proc_macro_derive(Interceptor, attributes(next))]
pub fn derive_interceptor(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    // Locate the #[next] field; this also validates the struct shape and turns any
    // problem into a spanned compile error.
    let (next_name, next_type) = match find_next_field(&input) {
        Ok(field) => field,
        Err(err) => return err.into_compile_error().into(),
    };

    let name = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    // Generate a hidden accessor that #[interceptor] uses, so the attribute macro can
    // reach the next interceptor without knowing the field's name.
    //
    // `#[allow(dead_code)]`: the accessor is private and is only called from the
    // delegating methods that #[interceptor] generates. When every method is
    // overridden (or #[interceptor] is not used) nothing calls it, which would
    // otherwise surface as an "unused method" warning in the user's crate.
    let expanded = quote! {
        impl #impl_generics #name #ty_generics #where_clause {
            /// Hidden accessor for the next interceptor (used by #[interceptor] macro)
            #[doc(hidden)]
            #[inline(always)]
            #[allow(dead_code)]
            fn __interceptor_inner_mut(&mut self) -> &mut #next_type {
                &mut self.#next_name
            }
        }
    };

    TokenStream::from(expanded)
}

/// Attribute macro for impl blocks to generate Protocol and Interceptor implementations.
///
/// This macro generates the trait implementations, delegating non-overridden
/// methods to the next interceptor field (identified by `#[next]` in the struct).
///
/// **Important:** The struct must have `#[derive(Interceptor)]` with a `#[next]` field.
///
/// # Attributes
///
/// - `#[overrides]` - Mark methods that provide custom implementations
///
/// # Examples
///
/// With custom logic:
/// ```ignore
/// #[derive(Interceptor)]
/// pub struct MyInterceptor<P: Interceptor> {
///     #[next]
///     next: P,  // Can use any field name
///     buffer: VecDeque<TaggedPacket>,
/// }
///
/// #[interceptor]
/// impl<P: Interceptor> MyInterceptor<P> {
///     #[overrides]
///     fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
///         // Custom logic
///         self.next.handle_read(msg)
///     }
/// }
/// ```
///
/// Pure delegation (no custom logic):
/// ```ignore
/// #[derive(Interceptor)]
/// pub struct PassthroughInterceptor<P: Interceptor> {
///     #[next]
///     wrapped: P,  // Can use any field name
/// }
///
/// #[interceptor]
/// impl<P: Interceptor> PassthroughInterceptor<P> {}
/// // Empty impl - all methods delegate to wrapped field
/// ```
#[proc_macro_attribute]
pub fn interceptor(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let mut input = parse_macro_input!(item as ItemImpl);

    // Note: Field name is no longer needed here - we use __interceptor_inner_mut() accessor
    // which is generated by #[derive(Interceptor)]

    // Collect method names marked with #[overrides]
    let mut override_methods: Vec<Ident> = Vec::new();

    for item in &mut input.items {
        if let ImplItem::Fn(method) = item {
            // Check if method has #[overrides] attribute
            let has_override = method
                .attrs
                .iter()
                .any(|attr| attr.path().is_ident("overrides"));

            if has_override {
                override_methods.push(method.sig.ident.clone());
                // Remove the #[overrides] attribute
                method
                    .attrs
                    .retain(|attr| !attr.path().is_ident("overrides"));
            }
        }
    }

    // Get type name and generics from the impl
    let self_ty = &input.self_ty;
    let generics = &input.generics;
    let where_clause = &generics.where_clause;
    let (impl_generics, _, _) = generics.split_for_impl();

    // Generate Protocol methods that are NOT overridden (using accessor method)
    let protocol_methods = generate_protocol_methods(&override_methods);
    let interceptor_methods = generate_interceptor_methods(&override_methods);

    // Protocol method names
    let protocol_method_names = [
        "handle_read",
        "poll_read",
        "handle_write",
        "poll_write",
        "handle_event",
        "poll_event",
        "handle_timeout",
        "poll_timeout",
        "close",
    ];

    // Interceptor method names
    let interceptor_method_names = [
        "bind_local_stream",
        "unbind_local_stream",
        "bind_remote_stream",
        "unbind_remote_stream",
    ];

    // Extract Protocol overridden methods
    let protocol_override_items: Vec<_> = input
        .items
        .iter()
        .filter(|item| {
            if let ImplItem::Fn(method) = item {
                let name = method.sig.ident.to_string();
                override_methods.contains(&method.sig.ident)
                    && protocol_method_names.contains(&name.as_str())
            } else {
                false
            }
        })
        .collect();

    // Extract Interceptor overridden methods
    let interceptor_override_items: Vec<_> = input
        .items
        .iter()
        .filter(|item| {
            if let ImplItem::Fn(method) = item {
                let name = method.sig.ident.to_string();
                override_methods.contains(&method.sig.ident)
                    && interceptor_method_names.contains(&name.as_str())
            } else {
                false
            }
        })
        .collect();

    let expanded = quote! {
        impl #impl_generics sansio::Protocol<
            TaggedPacket,
            TaggedPacket,
            ()
        > for #self_ty #where_clause {
            type Rout = TaggedPacket;
            type Wout = TaggedPacket;
            type Eout = ();
            type Error = Error;
            type Time = std::time::Instant;

            #protocol_methods
            #(#protocol_override_items)*
        }

        impl #impl_generics Interceptor for #self_ty #where_clause {
            #interceptor_methods
            #(#interceptor_override_items)*
        }
    };

    TokenStream::from(expanded)
}

/// Find the field marked with #[next] attribute, returning both name and type
fn find_next_field(input: &DeriveInput) -> syn::Result<(Ident, Type)> {
    let fields = match &input.data {
        Data::Struct(data) => &data.fields,
        _ => {
            return Err(syn::Error::new_spanned(
                input,
                "Interceptor can only be derived for structs",
            ));
        }
    };

    let named_fields = match fields {
        Fields::Named(fields) => &fields.named,
        _ => {
            return Err(syn::Error::new_spanned(
                input,
                "Interceptor can only be derived for structs with named fields",
            ));
        }
    };

    for field in named_fields {
        let has_next_attr = field.attrs.iter().any(|attr| attr.path().is_ident("next"));
        if has_next_attr {
            let ident = field
                .ident
                .clone()
                .ok_or_else(|| syn::Error::new_spanned(field, "Field must have a name"))?;
            let ty = field.ty.clone();
            return Ok((ident, ty));
        }
    }

    Err(syn::Error::new_spanned(
        input,
        "No field marked with #[next] attribute. Mark the next interceptor field with #[next].",
    ))
}

/// Build the delegating `sansio::Protocol` methods for every method the user did
/// not mark with `#[overrides]`.
///
/// Each generated method forwards its arguments unchanged to the same method on the
/// next interceptor, reached through the `__interceptor_inner_mut()` accessor.
/// Methods are emitted in a fixed order (read, write, event, timeout, close),
/// independent of the order of `override_methods`.
fn generate_protocol_methods(override_methods: &[Ident]) -> proc_macro2::TokenStream {
    let is_overridden = |method: &str| override_methods.iter().any(|ident| ident == method);

    // (method name, delegating default implementation), in emission order.
    let delegations: [(&str, proc_macro2::TokenStream); 9] = [
        (
            "handle_read",
            quote! {
                fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
                    self.__interceptor_inner_mut().handle_read(msg)
                }
            },
        ),
        (
            "poll_read",
            quote! {
                fn poll_read(&mut self) -> Option<Self::Rout> {
                    self.__interceptor_inner_mut().poll_read()
                }
            },
        ),
        (
            "handle_write",
            quote! {
                fn handle_write(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
                    self.__interceptor_inner_mut().handle_write(msg)
                }
            },
        ),
        (
            "poll_write",
            quote! {
                fn poll_write(&mut self) -> Option<Self::Wout> {
                    self.__interceptor_inner_mut().poll_write()
                }
            },
        ),
        (
            "handle_event",
            quote! {
                fn handle_event(&mut self, evt: ()) -> Result<(), Self::Error> {
                    self.__interceptor_inner_mut().handle_event(evt)
                }
            },
        ),
        (
            "poll_event",
            quote! {
                fn poll_event(&mut self) -> Option<Self::Eout> {
                    self.__interceptor_inner_mut().poll_event()
                }
            },
        ),
        (
            "handle_timeout",
            quote! {
                fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
                    self.__interceptor_inner_mut().handle_timeout(now)
                }
            },
        ),
        (
            "poll_timeout",
            quote! {
                fn poll_timeout(&mut self) -> Option<Self::Time> {
                    self.__interceptor_inner_mut().poll_timeout()
                }
            },
        ),
        (
            "close",
            quote! {
                fn close(&mut self) -> Result<(), Self::Error> {
                    self.__interceptor_inner_mut().close()
                }
            },
        ),
    ];

    let mut generated = proc_macro2::TokenStream::new();
    for (method, delegation) in delegations {
        if !is_overridden(method) {
            generated.extend(delegation);
        }
    }
    generated
}

/// Generate Interceptor methods that delegate to inner, excluding overridden ones
fn generate_interceptor_methods(override_methods: &[Ident]) -> proc_macro2::TokenStream {
    let mut methods = proc_macro2::TokenStream::new();

    if !override_methods.iter().any(|m| m == "bind_local_stream") {
        methods.extend(quote! {
            fn bind_local_stream(&mut self, info: &StreamInfo) {
                self.__interceptor_inner_mut().bind_local_stream(info);
            }
        });
    }

    if !override_methods.iter().any(|m| m == "unbind_local_stream") {
        methods.extend(quote! {
            fn unbind_local_stream(&mut self, info: &StreamInfo) {
                self.__interceptor_inner_mut().unbind_local_stream(info);
            }
        });
    }

    if !override_methods.iter().any(|m| m == "bind_remote_stream") {
        methods.extend(quote! {
            fn bind_remote_stream(&mut self, info: &StreamInfo) {
                self.__interceptor_inner_mut().bind_remote_stream(info);
            }
        });
    }

    if !override_methods.iter().any(|m| m == "unbind_remote_stream") {
        methods.extend(quote! {
            fn unbind_remote_stream(&mut self, info: &StreamInfo) {
                self.__interceptor_inner_mut().unbind_remote_stream(info);
            }
        });
    }

    methods
}