1use convert_case::{Case, Casing};
6use proc_macro2::{Ident, Span, TokenStream};
7use quote::{quote, ToTokens};
8use std::collections::{HashMap, HashSet};
9use syn::{
10 parse::Parse, parse_macro_input, Attribute, Error, ItemEnum, Lit, LitStr, Meta, MetaList,
11 MetaNameValue, NestedMeta, Path,
12};
13
14#[proc_macro_derive(Privsep, attributes(connect, main_path, username, disable_privdrop))]
22pub fn derive_privsep(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
23 let input = parse_macro_input!(item as ItemEnum);
24
25 derive_privsep_enum(input)
26 .unwrap_or_else(|err| err.to_compile_error())
27 .into()
28}
29
30fn parse_attribute_value(attrs: &[Attribute], name: &str) -> Result<Option<LitStr>, Error> {
31 if let Some(attr) = attrs.iter().find(|attr| attr.path.is_ident(name)) {
32 match attr.parse_meta()? {
33 Meta::NameValue(MetaNameValue {
34 lit: Lit::Str(lit_str),
35 ..
36 }) => Ok(Some(lit_str)),
37 meta => Err(Error::new_spanned(
38 meta,
39 &format!("invalid `{}` attribute", name),
40 )),
41 }
42 } else {
43 Ok(None)
44 }
45}
46
47fn parse_attribute_ident(attrs: &[Attribute], name: &str) -> Result<Vec<Ident>, Error> {
48 let mut result = vec![];
49
50 if let Some(attr) = attrs.iter().find(|attr| attr.path.is_ident(name)) {
52 match attr.parse_meta()? {
53 Meta::List(MetaList { nested, .. }) => {
54 for nested in nested.iter() {
55 if let NestedMeta::Meta(Meta::Path(path)) = nested {
56 if let Some(ident) = path.get_ident() {
57 result.push(ident.clone());
58 }
59 }
60 }
61 }
62 ref meta => {
63 return Err(Error::new_spanned(
64 meta,
65 &format!("invalid `{}` attribute", name),
66 ))
67 }
68 }
69 }
70
71 Ok(result)
72}
73
74fn parse_attribute_type<T: Parse + ToTokens>(
75 attrs: &[Attribute],
76 name: &str,
77 default: &str,
78) -> Result<T, Error> {
79 parse_attribute_value(attrs, name)?
80 .unwrap_or_else(|| LitStr::new(default, Span::call_site()))
81 .parse()
82}
83
84fn derive_privsep_enum(item: ItemEnum) -> Result<TokenStream, Error> {
85 let ident = item.ident.clone();
86 let attrs = &item.attrs;
87 let mut as_ref_str = vec![];
88 let mut child_main = vec![];
89 let mut child_peers = vec![];
90 let mut const_as_array = vec![];
91 let mut const_id = vec![];
92 let mut const_ids = vec![];
93 let mut const_names = vec![];
94 let mut child_names = vec![];
95 let mut from_id = vec![];
96 let mut children = vec![];
97 let mut connect_map = HashMap::new();
98 let not_connected = HashSet::new();
99 let array_len = item.variants.len();
100
101 let disable_privdrop = attrs.iter().any(|a| a.path.is_ident("disable_privdrop"));
103 let username = if let Some(username) = parse_attribute_value(attrs, "username")? {
104 username
105 } else if disable_privdrop {
106 LitStr::new("", Span::call_site())
107 } else {
108 return Err(Error::new_spanned(
109 item,
110 "`Privsep` requires `username` attribute",
111 ));
112 };
113 let doc = attrs
114 .iter()
115 .filter(|a| a.path.is_ident("doc"))
116 .collect::<Vec<_>>();
117
118 for variant in item.variants.iter() {
120 let child_ident = variant.ident.clone();
121 children.push(child_ident.clone());
122
123 let connect = parse_attribute_ident(&variant.attrs, "connect")?
124 .into_iter()
125 .collect::<HashSet<_>>();
126 connect_map.insert(child_ident, connect);
127 }
128
129 let temp_map = connect_map.clone();
130 for (key, value) in temp_map.into_iter() {
131 for entry in value.iter() {
132 if !children.contains(entry) {
133 return Err(Error::new_spanned(
134 item,
135 &format!("Connection to unknown process `{}`", entry),
136 ));
137 }
138 if let Some(other) = connect_map.get_mut(entry) {
139 other.insert(key.clone());
140 }
141 }
142 }
143
144 let mut main_path = quote! {
145 unimplemented!()
146 };
147 let mut options = quote! {
148 Options {
149 config,
150 ..Default::default()
151 }
152 };
153
154 for (id, variant) in item.variants.iter().enumerate() {
156 let child_doc = variant
157 .attrs
158 .iter()
159 .filter(|a| a.path.is_ident("doc"))
160 .collect::<Vec<_>>();
161 let child_ident = &variant.ident;
162 let name_ident = child_ident.to_string();
163 let name = name_ident.to_case(Case::Kebab);
164 let name_snake = name_ident.to_case(Case::Snake);
165 let name_upper = name_ident.to_case(Case::UpperSnake);
166 let id_name = Ident::new(&(name_upper + "_ID"), Span::call_site());
167 let child_main_path: Path =
168 parse_attribute_type(&variant.attrs, "main_path", &(name_snake + "::main"))?;
169
170 let child_username =
171 parse_attribute_value(&variant.attrs, "username")?.unwrap_or_else(|| username.clone());
172 let child_disable_privdrop =
173 disable_privdrop || attrs.iter().any(|a| a.path.is_ident("disable_privdrop"));
174 let child_options = quote! {
175 privsep::process::Options {
176 config: config.clone(),
177 disable_privdrop: #child_disable_privdrop,
178 username: #child_username.into(),
179 }
180 };
181 child_names.push(name.clone());
182
183 let connect = connect_map.get(child_ident).unwrap_or(¬_connected);
184
185 let child_connect = children
186 .iter()
187 .enumerate()
188 .map(|(id, child)| {
189 let is_connected = id == 0 || connect.contains(child);
190 quote! {
191 Process {
192 name: Self::as_static_str(&Self::#child),
193 connect: #is_connected
194 },
195 }
196 })
197 .collect::<Vec<_>>();
198
199 let is_child = id != 0;
200
201 const_as_array.push(quote! {
202 Process { name: #name, connect: #is_child },
203 });
204
205 const_id.push(quote! {
206 #(#child_doc)*
207 pub const #id_name: usize = #id;
208 });
209
210 const_ids.push(quote! {
211 #id,
212 });
213
214 const_names.push(quote! {
215 #name,
216 });
217
218 as_ref_str.push(quote! {
219 Self::#child_ident => #name,
220 });
221
222 from_id.push(quote! {
223 #id => Ok(Self::#child_ident),
224 });
225
226 child_peers.push(quote! {
227 [#(#child_connect)*],
228 });
229
230 if is_child {
231 let process = quote! {
232 Child::<#array_len>::new([#(#child_connect)*], #name, &#child_options).await?
233 };
234 child_main.push(quote! {
235 #name => {
236 let process = #process;
237 #child_main_path(process, config).await
238 }
239 });
240 } else {
241 options = child_options;
242 main_path = quote! {
243 #child_main_path
244 };
245 }
246 }
247 let child_main = child_main.into_iter().rev().collect::<Vec<_>>();
248
249 if child_names.first().map(AsRef::as_ref) != Some("parent") {
250 return Err(Error::new_spanned(
251 item.variants,
252 "Missing `Parent` variant",
253 ));
254 }
255
256 Ok(quote! {
257 #(#doc)*
258 impl #ident {
259 #(#const_id)*
260
261 #[doc = "IDs of all child processes."]
262 pub const PROCESS_IDS: [usize; #array_len] = [#(#const_ids)*];
263
264 #[doc = "Names of all child processes."]
265 pub const PROCESS_NAMES: [&'static str; #array_len] = [#(#const_names)*];
266
267 #[doc = "Return processes as const list."]
268 pub const fn as_array() -> [privsep::process::Process; #array_len] {
269 use privsep::process::Process;
270 [
271 #(#const_as_array)*
272 ]
273 }
274
275 #[doc = "Start parent or child process."]
276 pub async fn main(config: privsep::Config) -> Result<(), privsep::Error> {
277 use privsep::process::{Child, Parent, Process};
278 let name = std::env::args().next().unwrap_or_default();
279 match name.as_ref() {
280 #(#child_main)*
281 _ => {
282 let process = Parent::new(Self::as_array(), &#options).await?;
283 #main_path(process.connect([#(#child_peers)*]).await?, config).await
284 }
285 }
286 }
287
288 pub const fn as_static_str(&self) -> &'static str {
289 match self {
290 #(#as_ref_str)*
291 }
292 }
293 }
294
295 impl AsRef<str> for #ident {
296 fn as_ref(&self) -> &str {
297 self.as_static_str()
298 }
299 }
300
301 impl std::fmt::Display for #ident {
302 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303 write!(f, "{}", self.as_ref())
304 }
305 }
306
307 impl std::convert::TryFrom<usize> for #ident {
308 type Error = &'static str;
309
310 fn try_from(id: usize) -> Result<Self, Self::Error> {
311 match id {
312 #(#from_id)*
313 _ => Err("Invalid privsep process ID"),
314 }
315 }
316 }
317 })
318}