netsim_macros/
lib.rs

1use {
2    std::{
3        net::{Ipv4Addr, Ipv6Addr},
4        str::FromStr,
5    },
6    proc_macro::TokenStream,
7    syn::spanned::Spanned,
8    quote::quote_spanned,
9};
10
11/// Creates a `Ipv4Network` given an address range in CIDR notation.
12///
13/// # Example
14///
15/// ```rust
16/// assert_eq!(Ipv4Network::new(ipv4!("192.168.0.0"), 16), ipv4_network!("192.168.0.0/16"));
17/// ```
18#[proc_macro]
19pub fn ipv4_network(input: TokenStream) -> TokenStream {
20    let input = syn::parse_macro_input!(input as syn::LitStr);
21    let span = input.span();
22    let s = input.value();
23    let output = match parse_ipv4_network(&s) {
24        Ok((addr, subnet_mask_bits)) => {
25            let [b0, b1, b2, b3] = addr.octets();
26            quote_spanned!(span=> {
27                ::netsim::Ipv4Network::new(::std::net::Ipv4Addr::new(#b0, #b1, #b2, #b3), #subnet_mask_bits)
28            })
29        },
30        Err(err) => {
31            quote_spanned!(span=> {
32                compile_error!(#err)
33            })
34        },
35    };
36    output.into()
37}
38
/// Parses an IPv4 network in CIDR notation, e.g. `"192.168.0.0/16"`.
///
/// On success, returns the address and the number of subnet mask bits.
/// The prefix length must be written as plain decimal digits and be at most 32.
///
/// # Errors
///
/// Returns a human-readable message (used verbatim in `compile_error!`) when:
/// - the `/` separator is missing;
/// - the address part is not a valid IPv4 address;
/// - the prefix length is not a decimal number;
/// - the prefix length is greater than 32.
fn parse_ipv4_network(s: &str) -> Result<(Ipv4Addr, u8), String> {
    let (addr, subnet_mask_bits) = s
        .split_once('/')
        .ok_or_else(|| String::from("missing '/' character"))?;
    let addr = Ipv4Addr::from_str(addr).map_err(|err| err.to_string())?;

    // `u8::from_str` also accepts a leading `+`, which is not valid CIDR
    // notation, so insist on plain digits.
    if subnet_mask_bits.is_empty() || !subnet_mask_bits.bytes().all(|b| b.is_ascii_digit()) {
        return Err(format!(
            "invalid subnet mask bits {:?}: expected a decimal number between 0 and 32",
            subnet_mask_bits,
        ));
    }

    // The string is now all digits, so parsing can only fail by overflowing
    // `u8`. Such a value is certainly > 32, so report it like any other
    // out-of-range prefix instead of with the generic overflow text.
    match u8::from_str(subnet_mask_bits) {
        Ok(bits) if bits <= 32 => Ok((addr, bits)),
        _ => Err(String::from("subnet mask bits cannot be greater than 32")),
    }
}
57
58/// Creates a `Ipv6Network` given an address range in CIDR notation.
59///
60/// # Example
61///
62/// ```rust
63/// assert_eq!(Ipv6Network::new(ipv6!("ff00::"), 8), ipv6_network!("ff00::/8"));
64/// ```
65#[proc_macro]
66pub fn ipv6_network(input: TokenStream) -> TokenStream {
67    let input = syn::parse_macro_input!(input as syn::LitStr);
68    let span = input.span();
69    let s = input.value();
70    let output = match parse_ipv6_network(&s) {
71        Ok((addr, subnet_mask_bits)) => {
72            let [b0, b1, b2, b3, b4, b5, b6, b7] = addr.segments();
73            quote_spanned!(span=> {
74                ::netsim::Ipv6Network::new(
75                    ::std::net::Ipv6Addr::new(#b0, #b1, #b2, #b3, #b4, #b5, #b6, #b7),
76                    #subnet_mask_bits,
77                )
78            })
79        },
80        Err(err) => {
81            quote_spanned!(span=> {
82                compile_error!(#err)
83            })
84        },
85    };
86    output.into()
87}
88
/// Parses an IPv6 network in CIDR notation, e.g. `"ff00::/8"`.
///
/// On success, returns the address and the number of subnet mask bits.
/// The prefix length must be written as plain decimal digits and be at most 128.
///
/// # Errors
///
/// Returns a human-readable message (used verbatim in `compile_error!`) when:
/// - the `/` separator is missing;
/// - the address part is not a valid IPv6 address;
/// - the prefix length is not a decimal number;
/// - the prefix length is greater than 128.
fn parse_ipv6_network(s: &str) -> Result<(Ipv6Addr, u8), String> {
    let (addr, subnet_mask_bits) = s
        .split_once('/')
        .ok_or_else(|| String::from("missing '/' character"))?;
    let addr = Ipv6Addr::from_str(addr).map_err(|err| err.to_string())?;

    // `u8::from_str` also accepts a leading `+`, which is not valid CIDR
    // notation, so insist on plain digits.
    if subnet_mask_bits.is_empty() || !subnet_mask_bits.bytes().all(|b| b.is_ascii_digit()) {
        return Err(format!(
            "invalid subnet mask bits {:?}: expected a decimal number between 0 and 128",
            subnet_mask_bits,
        ));
    }

    // The string is now all digits, so parsing can only fail by overflowing
    // `u8` (> 255). Such a value is certainly > 128, so report it like any
    // other out-of-range prefix instead of with the generic overflow text.
    match u8::from_str(subnet_mask_bits) {
        Ok(bits) if bits <= 128 => Ok((addr, bits)),
        _ => Err(String::from("subnet mask bits cannot be greater than 128")),
    }
}
107
108/// Makes a function run in an isolated network environment.
109#[proc_macro_attribute]
110pub fn isolate(_attr: TokenStream, input: TokenStream) -> TokenStream {
111    let item_fn = syn::parse_macro_input!(input as syn::ItemFn);
112    let span = item_fn.span();
113    let syn::ItemFn { attrs, vis, sig, block } = item_fn;
114    let is_async = sig.asyncness.is_some();
115    let output = if is_async {
116        quote_spanned! {span=>
117            #(#attrs)*
118            #vis #sig {
119                let machine = netsim::Machine::new().expect("error creating machine");
120                let join_handle = machine.spawn(async move #block);
121                join_handle.await.unwrap().unwrap()
122            }
123        }
124    } else {
125        quote_spanned! {span=>
126            #(#attrs)*
127            #vis #sig {
128                let machine = netsim::Machine::new().expect("error creating machine");
129                let join_handle = machine.spawn(async move {
130                    ::netsim::tokio::task::spawn_blocking(move || #block).await.unwrap()
131                });
132                join_handle.join_blocking().unwrap().unwrap()
133            }
134        }
135    };
136    output.into()
137}
138