// rtc_interceptor_derive/lib.rs
use proc_macro::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Fields, Ident, ImplItem, ItemImpl, Type, parse_macro_input};

65#[proc_macro_derive(Interceptor, attributes(next))]
108pub fn derive_interceptor(input: TokenStream) -> TokenStream {
109 let input = parse_macro_input!(input as DeriveInput);
110
111 let (next_name, next_type) = match find_next_field(&input) {
113 Ok(field) => field,
114 Err(err) => return err.into_compile_error().into(),
115 };
116
117 let name = &input.ident;
118 let generics = &input.generics;
119 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
120
121 let expanded = quote! {
124 impl #impl_generics #name #ty_generics #where_clause {
125 #[doc(hidden)]
127 #[inline(always)]
128 fn __interceptor_inner_mut(&mut self) -> &mut #next_type {
129 &mut self.#next_name
130 }
131 }
132 };
133
134 TokenStream::from(expanded)
135}
136
137#[proc_macro_attribute]
182pub fn interceptor(_attr: TokenStream, item: TokenStream) -> TokenStream {
183 let mut input = parse_macro_input!(item as ItemImpl);
184
185 let mut override_methods: Vec<Ident> = Vec::new();
190
191 for item in &mut input.items {
192 if let ImplItem::Fn(method) = item {
193 let has_override = method
195 .attrs
196 .iter()
197 .any(|attr| attr.path().is_ident("overrides"));
198
199 if has_override {
200 override_methods.push(method.sig.ident.clone());
201 method
203 .attrs
204 .retain(|attr| !attr.path().is_ident("overrides"));
205 }
206 }
207 }
208
209 let self_ty = &input.self_ty;
211 let generics = &input.generics;
212 let where_clause = &generics.where_clause;
213 let (impl_generics, _, _) = generics.split_for_impl();
214
215 let protocol_methods = generate_protocol_methods(&override_methods);
217 let interceptor_methods = generate_interceptor_methods(&override_methods);
218
219 let protocol_method_names = [
221 "handle_read",
222 "poll_read",
223 "handle_write",
224 "poll_write",
225 "handle_event",
226 "poll_event",
227 "handle_timeout",
228 "poll_timeout",
229 "close",
230 ];
231
232 let interceptor_method_names = [
234 "bind_local_stream",
235 "unbind_local_stream",
236 "bind_remote_stream",
237 "unbind_remote_stream",
238 ];
239
240 let protocol_override_items: Vec<_> = input
242 .items
243 .iter()
244 .filter(|item| {
245 if let ImplItem::Fn(method) = item {
246 let name = method.sig.ident.to_string();
247 override_methods.contains(&method.sig.ident)
248 && protocol_method_names.contains(&name.as_str())
249 } else {
250 false
251 }
252 })
253 .collect();
254
255 let interceptor_override_items: Vec<_> = input
257 .items
258 .iter()
259 .filter(|item| {
260 if let ImplItem::Fn(method) = item {
261 let name = method.sig.ident.to_string();
262 override_methods.contains(&method.sig.ident)
263 && interceptor_method_names.contains(&name.as_str())
264 } else {
265 false
266 }
267 })
268 .collect();
269
270 let expanded = quote! {
271 impl #impl_generics sansio::Protocol<
272 TaggedPacket,
273 TaggedPacket,
274 ()
275 > for #self_ty #where_clause {
276 type Rout = TaggedPacket;
277 type Wout = TaggedPacket;
278 type Eout = ();
279 type Error = Error;
280 type Time = std::time::Instant;
281
282 #protocol_methods
283 #(#protocol_override_items)*
284 }
285
286 impl #impl_generics Interceptor for #self_ty #where_clause {
287 #interceptor_methods
288 #(#interceptor_override_items)*
289 }
290 };
291
292 TokenStream::from(expanded)
293}
294
295fn find_next_field(input: &DeriveInput) -> syn::Result<(Ident, Type)> {
297 let fields = match &input.data {
298 Data::Struct(data) => &data.fields,
299 _ => {
300 return Err(syn::Error::new_spanned(
301 input,
302 "Interceptor can only be derived for structs",
303 ));
304 }
305 };
306
307 let named_fields = match fields {
308 Fields::Named(fields) => &fields.named,
309 _ => {
310 return Err(syn::Error::new_spanned(
311 input,
312 "Interceptor can only be derived for structs with named fields",
313 ));
314 }
315 };
316
317 for field in named_fields {
318 let has_next_attr = field.attrs.iter().any(|attr| attr.path().is_ident("next"));
319 if has_next_attr {
320 let ident = field
321 .ident
322 .clone()
323 .ok_or_else(|| syn::Error::new_spanned(field, "Field must have a name"))?;
324 let ty = field.ty.clone();
325 return Ok((ident, ty));
326 }
327 }
328
329 Err(syn::Error::new_spanned(
330 input,
331 "No field marked with #[next] attribute. Mark the next interceptor field with #[next].",
332 ))
333}
334
335fn generate_protocol_methods(override_methods: &[Ident]) -> proc_macro2::TokenStream {
337 let mut methods = proc_macro2::TokenStream::new();
338
339 if !override_methods.iter().any(|m| m == "handle_read") {
340 methods.extend(quote! {
341 fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
342 self.__interceptor_inner_mut().handle_read(msg)
343 }
344 });
345 }
346
347 if !override_methods.iter().any(|m| m == "poll_read") {
348 methods.extend(quote! {
349 fn poll_read(&mut self) -> Option<Self::Rout> {
350 self.__interceptor_inner_mut().poll_read()
351 }
352 });
353 }
354
355 if !override_methods.iter().any(|m| m == "handle_write") {
356 methods.extend(quote! {
357 fn handle_write(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
358 self.__interceptor_inner_mut().handle_write(msg)
359 }
360 });
361 }
362
363 if !override_methods.iter().any(|m| m == "poll_write") {
364 methods.extend(quote! {
365 fn poll_write(&mut self) -> Option<Self::Wout> {
366 self.__interceptor_inner_mut().poll_write()
367 }
368 });
369 }
370
371 if !override_methods.iter().any(|m| m == "handle_event") {
372 methods.extend(quote! {
373 fn handle_event(&mut self, evt: ()) -> Result<(), Self::Error> {
374 self.__interceptor_inner_mut().handle_event(evt)
375 }
376 });
377 }
378
379 if !override_methods.iter().any(|m| m == "poll_event") {
380 methods.extend(quote! {
381 fn poll_event(&mut self) -> Option<Self::Eout> {
382 self.__interceptor_inner_mut().poll_event()
383 }
384 });
385 }
386
387 if !override_methods.iter().any(|m| m == "handle_timeout") {
388 methods.extend(quote! {
389 fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
390 self.__interceptor_inner_mut().handle_timeout(now)
391 }
392 });
393 }
394
395 if !override_methods.iter().any(|m| m == "poll_timeout") {
396 methods.extend(quote! {
397 fn poll_timeout(&mut self) -> Option<Self::Time> {
398 self.__interceptor_inner_mut().poll_timeout()
399 }
400 });
401 }
402
403 if !override_methods.iter().any(|m| m == "close") {
404 methods.extend(quote! {
405 fn close(&mut self) -> Result<(), Self::Error> {
406 self.__interceptor_inner_mut().close()
407 }
408 });
409 }
410
411 methods
412}
413
414fn generate_interceptor_methods(override_methods: &[Ident]) -> proc_macro2::TokenStream {
416 let mut methods = proc_macro2::TokenStream::new();
417
418 if !override_methods.iter().any(|m| m == "bind_local_stream") {
419 methods.extend(quote! {
420 fn bind_local_stream(&mut self, info: &StreamInfo) {
421 self.__interceptor_inner_mut().bind_local_stream(info);
422 }
423 });
424 }
425
426 if !override_methods.iter().any(|m| m == "unbind_local_stream") {
427 methods.extend(quote! {
428 fn unbind_local_stream(&mut self, info: &StreamInfo) {
429 self.__interceptor_inner_mut().unbind_local_stream(info);
430 }
431 });
432 }
433
434 if !override_methods.iter().any(|m| m == "bind_remote_stream") {
435 methods.extend(quote! {
436 fn bind_remote_stream(&mut self, info: &StreamInfo) {
437 self.__interceptor_inner_mut().bind_remote_stream(info);
438 }
439 });
440 }
441
442 if !override_methods.iter().any(|m| m == "unbind_remote_stream") {
443 methods.extend(quote! {
444 fn unbind_remote_stream(&mut self, info: &StreamInfo) {
445 self.__interceptor_inner_mut().unbind_remote_stream(info);
446 }
447 });
448 }
449
450 methods
451}