privsep_derive/
lib.rs

//! Helper macros for the [`privsep`] crate.
2//!
3//! [`privsep`]: http://docs.rs/privsep/
4
5use convert_case::{Case, Casing};
6use proc_macro2::{Ident, Span, TokenStream};
7use quote::{quote, ToTokens};
8use std::collections::{HashMap, HashSet};
9use syn::{
10    parse::Parse, parse_macro_input, Attribute, Error, ItemEnum, Lit, LitStr, Meta, MetaList,
11    MetaNameValue, NestedMeta, Path,
12};
13
14/// Derive privsep processes from an enum.
15///
16/// Attributes:
17/// - `connect`: Connect child with the specified peer.
18/// - `main_path`: Set the path of the parent or process `main` function.
19/// - `username`: Set the default or the per-process privdrop user.
20/// - `disable_privdrop`: disable privdrop for the program or process.
21#[proc_macro_derive(Privsep, attributes(connect, main_path, username, disable_privdrop))]
22pub fn derive_privsep(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
23    let input = parse_macro_input!(item as ItemEnum);
24
25    derive_privsep_enum(input)
26        .unwrap_or_else(|err| err.to_compile_error())
27        .into()
28}
29
30fn parse_attribute_value(attrs: &[Attribute], name: &str) -> Result<Option<LitStr>, Error> {
31    if let Some(attr) = attrs.iter().find(|attr| attr.path.is_ident(name)) {
32        match attr.parse_meta()? {
33            Meta::NameValue(MetaNameValue {
34                lit: Lit::Str(lit_str),
35                ..
36            }) => Ok(Some(lit_str)),
37            meta => Err(Error::new_spanned(
38                meta,
39                &format!("invalid `{}` attribute", name),
40            )),
41        }
42    } else {
43        Ok(None)
44    }
45}
46
47fn parse_attribute_ident(attrs: &[Attribute], name: &str) -> Result<Vec<Ident>, Error> {
48    let mut result = vec![];
49
50    // TODO: could we use `darling` here?
51    if let Some(attr) = attrs.iter().find(|attr| attr.path.is_ident(name)) {
52        match attr.parse_meta()? {
53            Meta::List(MetaList { nested, .. }) => {
54                for nested in nested.iter() {
55                    if let NestedMeta::Meta(Meta::Path(path)) = nested {
56                        if let Some(ident) = path.get_ident() {
57                            result.push(ident.clone());
58                        }
59                    }
60                }
61            }
62            ref meta => {
63                return Err(Error::new_spanned(
64                    meta,
65                    &format!("invalid `{}` attribute", name),
66                ))
67            }
68        }
69    }
70
71    Ok(result)
72}
73
74fn parse_attribute_type<T: Parse + ToTokens>(
75    attrs: &[Attribute],
76    name: &str,
77    default: &str,
78) -> Result<T, Error> {
79    parse_attribute_value(attrs, name)?
80        .unwrap_or_else(|| LitStr::new(default, Span::call_site()))
81        .parse()
82}
83
84fn derive_privsep_enum(item: ItemEnum) -> Result<TokenStream, Error> {
85    let ident = item.ident.clone();
86    let attrs = &item.attrs;
87    let mut as_ref_str = vec![];
88    let mut child_main = vec![];
89    let mut child_peers = vec![];
90    let mut const_as_array = vec![];
91    let mut const_id = vec![];
92    let mut const_ids = vec![];
93    let mut const_names = vec![];
94    let mut child_names = vec![];
95    let mut from_id = vec![];
96    let mut children = vec![];
97    let mut connect_map = HashMap::new();
98    let not_connected = HashSet::new();
99    let array_len = item.variants.len();
100
101    // Get the global attributes.
102    let disable_privdrop = attrs.iter().any(|a| a.path.is_ident("disable_privdrop"));
103    let username = if let Some(username) = parse_attribute_value(attrs, "username")? {
104        username
105    } else if disable_privdrop {
106        LitStr::new("", Span::call_site())
107    } else {
108        return Err(Error::new_spanned(
109            item,
110            "`Privsep` requires `username` attribute",
111        ));
112    };
113    let doc = attrs
114        .iter()
115        .filter(|a| a.path.is_ident("doc"))
116        .collect::<Vec<_>>();
117
118    // Resolve bi-directional connections between processes.
119    for variant in item.variants.iter() {
120        let child_ident = variant.ident.clone();
121        children.push(child_ident.clone());
122
123        let connect = parse_attribute_ident(&variant.attrs, "connect")?
124            .into_iter()
125            .collect::<HashSet<_>>();
126        connect_map.insert(child_ident, connect);
127    }
128
129    let temp_map = connect_map.clone();
130    for (key, value) in temp_map.into_iter() {
131        for entry in value.iter() {
132            if !children.contains(entry) {
133                return Err(Error::new_spanned(
134                    item,
135                    &format!("Connection to unknown process `{}`", entry),
136                ));
137            }
138            if let Some(other) = connect_map.get_mut(entry) {
139                other.insert(key.clone());
140            }
141        }
142    }
143
144    let mut main_path = quote! {
145        unimplemented!()
146    };
147    let mut options = quote! {
148        Options {
149            config,
150            ..Default::default()
151        }
152    };
153
154    // Configure processes.
155    for (id, variant) in item.variants.iter().enumerate() {
156        let child_doc = variant
157            .attrs
158            .iter()
159            .filter(|a| a.path.is_ident("doc"))
160            .collect::<Vec<_>>();
161        let child_ident = &variant.ident;
162        let name_ident = child_ident.to_string();
163        let name = name_ident.to_case(Case::Kebab);
164        let name_snake = name_ident.to_case(Case::Snake);
165        let name_upper = name_ident.to_case(Case::UpperSnake);
166        let id_name = Ident::new(&(name_upper + "_ID"), Span::call_site());
167        let child_main_path: Path =
168            parse_attribute_type(&variant.attrs, "main_path", &(name_snake + "::main"))?;
169
170        let child_username =
171            parse_attribute_value(&variant.attrs, "username")?.unwrap_or_else(|| username.clone());
172        let child_disable_privdrop =
173            disable_privdrop || attrs.iter().any(|a| a.path.is_ident("disable_privdrop"));
174        let child_options = quote! {
175            privsep::process::Options {
176                config: config.clone(),
177                disable_privdrop: #child_disable_privdrop,
178                username: #child_username.into(),
179            }
180        };
181        child_names.push(name.clone());
182
183        let connect = connect_map.get(child_ident).unwrap_or(&not_connected);
184
185        let child_connect = children
186            .iter()
187            .enumerate()
188            .map(|(id, child)| {
189                let is_connected = id == 0 || connect.contains(child);
190                quote! {
191                    Process {
192                        name: Self::as_static_str(&Self::#child),
193                        connect: #is_connected
194                    },
195                }
196            })
197            .collect::<Vec<_>>();
198
199        let is_child = id != 0;
200
201        const_as_array.push(quote! {
202            Process { name: #name, connect: #is_child },
203        });
204
205        const_id.push(quote! {
206            #(#child_doc)*
207            pub const #id_name: usize = #id;
208        });
209
210        const_ids.push(quote! {
211            #id,
212        });
213
214        const_names.push(quote! {
215            #name,
216        });
217
218        as_ref_str.push(quote! {
219            Self::#child_ident => #name,
220        });
221
222        from_id.push(quote! {
223            #id => Ok(Self::#child_ident),
224        });
225
226        child_peers.push(quote! {
227            [#(#child_connect)*],
228        });
229
230        if is_child {
231            let process = quote! {
232                Child::<#array_len>::new([#(#child_connect)*], #name, &#child_options).await?
233            };
234            child_main.push(quote! {
235                #name => {
236                    let process = #process;
237                    #child_main_path(process, config).await
238                }
239            });
240        } else {
241            options = child_options;
242            main_path = quote! {
243                #child_main_path
244            };
245        }
246    }
247    let child_main = child_main.into_iter().rev().collect::<Vec<_>>();
248
249    if child_names.first().map(AsRef::as_ref) != Some("parent") {
250        return Err(Error::new_spanned(
251            item.variants,
252            "Missing `Parent` variant",
253        ));
254    }
255
256    Ok(quote! {
257        #(#doc)*
258        impl #ident {
259            #(#const_id)*
260
261            #[doc = "IDs of all child processes."]
262            pub const PROCESS_IDS: [usize; #array_len] = [#(#const_ids)*];
263
264            #[doc = "Names of all child processes."]
265            pub const PROCESS_NAMES: [&'static str; #array_len] = [#(#const_names)*];
266
267            #[doc = "Return processes as const list."]
268            pub const fn as_array() -> [privsep::process::Process; #array_len] {
269                use privsep::process::Process;
270                [
271                    #(#const_as_array)*
272                ]
273            }
274
275            #[doc = "Start parent or child process."]
276            pub async fn main(config: privsep::Config) -> Result<(), privsep::Error> {
277                use privsep::process::{Child, Parent, Process};
278                let name = std::env::args().next().unwrap_or_default();
279                match name.as_ref() {
280                    #(#child_main)*
281                    _ => {
282                        let process = Parent::new(Self::as_array(), &#options).await?;
283                        #main_path(process.connect([#(#child_peers)*]).await?, config).await
284                    }
285                }
286            }
287
288            pub const fn as_static_str(&self) -> &'static str {
289                match self {
290                    #(#as_ref_str)*
291                }
292            }
293        }
294
295        impl AsRef<str> for #ident {
296            fn as_ref(&self) -> &str {
297                self.as_static_str()
298            }
299        }
300
301        impl std::fmt::Display for #ident {
302            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303                write!(f, "{}", self.as_ref())
304            }
305        }
306
307        impl std::convert::TryFrom<usize> for #ident {
308            type Error = &'static str;
309
310            fn try_from(id: usize) -> Result<Self, Self::Error> {
311                match id {
312                    #(#from_id)*
313                    _ => Err("Invalid privsep process ID"),
314                }
315            }
316        }
317    })
318}