1use {
2 std::{
3 net::{Ipv4Addr, Ipv6Addr},
4 str::FromStr,
5 },
6 proc_macro::TokenStream,
7 syn::spanned::Spanned,
8 quote::quote_spanned,
9};
10
11#[proc_macro]
19pub fn ipv4_network(input: TokenStream) -> TokenStream {
20 let input = syn::parse_macro_input!(input as syn::LitStr);
21 let span = input.span();
22 let s = input.value();
23 let output = match parse_ipv4_network(&s) {
24 Ok((addr, subnet_mask_bits)) => {
25 let [b0, b1, b2, b3] = addr.octets();
26 quote_spanned!(span=> {
27 ::netsim::Ipv4Network::new(::std::net::Ipv4Addr::new(#b0, #b1, #b2, #b3), #subnet_mask_bits)
28 })
29 },
30 Err(err) => {
31 quote_spanned!(span=> {
32 compile_error!(#err)
33 })
34 },
35 };
36 output.into()
37}
38
/// Parses an IPv4 network written in CIDR notation, e.g. `"10.0.0.0/8"`.
///
/// On success returns the address and the number of subnet mask (prefix) bits.
///
/// # Errors
///
/// Returns a human-readable message, suitable for `compile_error!`, when:
/// - the string contains no `/` separator;
/// - the text before the first `/` is not a valid [`Ipv4Addr`];
/// - the text after it does not parse with `u8::from_str`; or
/// - the prefix length is greater than 32. This includes numbers too large to fit in a
///   `u8`, such as `"/300"`.
fn parse_ipv4_network(s: &str) -> Result<(Ipv4Addr, u8), String> {
    const MAX_BITS: u8 = 32;
    let too_many_bits = || format!("subnet mask bits cannot be greater than {}", MAX_BITS);

    let (addr, subnet_mask_bits) = s
        .split_once('/')
        .ok_or_else(|| String::from("missing '/' character"))?;
    let addr = Ipv4Addr::from_str(addr).map_err(|err| err.to_string())?;
    let subnet_mask_bits = match u8::from_str(subnet_mask_bits) {
        Ok(bits) => bits,
        // A non-empty all-digit string can only fail to parse as `u8` by overflowing.
        // Report that as out of range instead of "number too large to fit in target type".
        Err(_)
            if !subnet_mask_bits.is_empty()
                && subnet_mask_bits.bytes().all(|b| b.is_ascii_digit()) =>
        {
            return Err(too_many_bits());
        }
        Err(err) => return Err(err.to_string()),
    };
    if subnet_mask_bits > MAX_BITS {
        return Err(too_many_bits());
    }
    Ok((addr, subnet_mask_bits))
}
57
58#[proc_macro]
66pub fn ipv6_network(input: TokenStream) -> TokenStream {
67 let input = syn::parse_macro_input!(input as syn::LitStr);
68 let span = input.span();
69 let s = input.value();
70 let output = match parse_ipv6_network(&s) {
71 Ok((addr, subnet_mask_bits)) => {
72 let [b0, b1, b2, b3, b4, b5, b6, b7] = addr.segments();
73 quote_spanned!(span=> {
74 ::netsim::Ipv6Network::new(
75 ::std::net::Ipv6Addr::new(#b0, #b1, #b2, #b3, #b4, #b5, #b6, #b7),
76 #subnet_mask_bits,
77 )
78 })
79 },
80 Err(err) => {
81 quote_spanned!(span=> {
82 compile_error!(#err)
83 })
84 },
85 };
86 output.into()
87}
88
/// Parses an IPv6 network written in CIDR notation, e.g. `"fe80::/10"`.
///
/// On success returns the address and the number of subnet mask (prefix) bits.
///
/// # Errors
///
/// Returns a human-readable message, suitable for `compile_error!`, when:
/// - the string contains no `/` separator;
/// - the text before the first `/` is not a valid [`Ipv6Addr`];
/// - the text after it does not parse with `u8::from_str`; or
/// - the prefix length is greater than 128. This includes numbers too large to fit in a
///   `u8`, such as `"/300"`.
fn parse_ipv6_network(s: &str) -> Result<(Ipv6Addr, u8), String> {
    const MAX_BITS: u8 = 128;
    let too_many_bits = || format!("subnet mask bits cannot be greater than {}", MAX_BITS);

    let (addr, subnet_mask_bits) = s
        .split_once('/')
        .ok_or_else(|| String::from("missing '/' character"))?;
    let addr = Ipv6Addr::from_str(addr).map_err(|err| err.to_string())?;
    let subnet_mask_bits = match u8::from_str(subnet_mask_bits) {
        Ok(bits) => bits,
        // A non-empty all-digit string can only fail to parse as `u8` by overflowing.
        // Report that as out of range instead of "number too large to fit in target type".
        Err(_)
            if !subnet_mask_bits.is_empty()
                && subnet_mask_bits.bytes().all(|b| b.is_ascii_digit()) =>
        {
            return Err(too_many_bits());
        }
        Err(err) => return Err(err.to_string()),
    };
    if subnet_mask_bits > MAX_BITS {
        return Err(too_many_bits());
    }
    Ok((addr, subnet_mask_bits))
}
107
108#[proc_macro_attribute]
110pub fn isolate(_attr: TokenStream, input: TokenStream) -> TokenStream {
111 let item_fn = syn::parse_macro_input!(input as syn::ItemFn);
112 let span = item_fn.span();
113 let syn::ItemFn { attrs, vis, sig, block } = item_fn;
114 let is_async = sig.asyncness.is_some();
115 let output = if is_async {
116 quote_spanned! {span=>
117 #(#attrs)*
118 #vis #sig {
119 let machine = netsim::Machine::new().expect("error creating machine");
120 let join_handle = machine.spawn(async move #block);
121 join_handle.await.unwrap().unwrap()
122 }
123 }
124 } else {
125 quote_spanned! {span=>
126 #(#attrs)*
127 #vis #sig {
128 let machine = netsim::Machine::new().expect("error creating machine");
129 let join_handle = machine.spawn(async move {
130 ::netsim::tokio::task::spawn_blocking(move || #block).await.unwrap()
131 });
132 join_handle.join_blocking().unwrap().unwrap()
133 }
134 }
135 };
136 output.into()
137}
138