hipcheck_sdk_macros/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use convert_case::Casing;
4use proc_macro::TokenStream;
5use proc_macro2::Span;
6use std::ops::Not;
7use std::sync::{LazyLock, Mutex};
8use syn::spanned::Spanned;
9use syn::{parse_macro_input, Data, DeriveInput, Error, Ident, ItemFn, Meta, PatType};
10
11static QUERIES: LazyLock<Mutex<Vec<NamedQuerySpec>>> = LazyLock::new(|| Mutex::new(vec![]));
12
/// What `#[query]` records about one annotated function so that a later
/// `queries!()` invocation can register it.
// NOTE(review): every field appears to be read by `queries!()`; this allow
// may no longer be needed — confirm before removing.
#[allow(unused)]
#[derive(Debug, Clone)]
struct NamedQuerySpec {
	/// Name of the generated struct implementing `Query` (pascal-cased function name).
	pub struct_name: String,
	/// Name of the annotated function; used as the query name unless `default` is set.
	pub function: String,
	/// When true, the query is registered under the empty name.
	pub default: bool,
}
20
21struct QuerySpec {
22	pub function: Ident,
23	pub input_type: syn::Type,
24	pub output_type: syn::Type,
25	pub default: bool,
26}
27
28/// Parse Path to confirm that it represents a Result<T: Serialize> and return the type T
29fn parse_result_generic(p: &syn::Path) -> Result<syn::Type, Error> {
30	use syn::GenericArgument;
31	use syn::PathArguments;
32	// Assert it is a Result
33	// Panic: Safe to unwrap because there should be at least one element in the sequence
34	let last = p.segments.last().unwrap();
35	if last.ident != "Result" {
36		return Err(Error::new(
37			p.span(),
38			"Expected return type to be a Result<T: Serialize>",
39		));
40	}
41	match &last.arguments {
42		PathArguments::AngleBracketed(x) => {
43			let Some(GenericArgument::Type(ty)) = x.args.first() else {
44				return Err(Error::new(
45					p.span(),
46					"Expected return type to be a Result<T: Serialize>",
47				));
48			};
49			Ok(ty.clone())
50		}
51		_ => Err(Error::new(
52			p.span(),
53			"Expected return type to be a Result<T: Serialize>",
54		)),
55	}
56}
57
58/// Parse PatType to confirm that it contains a &mut PluginEngine
59fn parse_plugin_engine(engine_arg: &PatType) -> Result<(), Error> {
60	if let syn::Type::Reference(type_reference) = engine_arg.ty.as_ref() {
61		if type_reference.mutability.is_some() {
62			if let syn::Type::Path(type_path) = type_reference.elem.as_ref() {
63				let last = type_path.path.segments.last().unwrap();
64				if last.ident == "PluginEngine" {
65					return Ok(());
66				}
67			}
68		}
69	}
70
71	Err(Error::new(
72		engine_arg.span(),
73		"The first argument of the query function must be a &mut PluginEngine",
74	))
75}
76
77fn parse_named_query_spec(opt_meta: Option<Meta>, item_fn: ItemFn) -> Result<QuerySpec, Error> {
78	use syn::Meta::*;
79	use syn::ReturnType;
80	let sig = &item_fn.sig;
81
82	let function = sig.ident.clone();
83
84	let input_type: syn::Type = {
85		let inputs = &sig.inputs;
86		if inputs.len() != 2 {
87			return Err(Error::new(item_fn.span(), "Query function must take two arguments: &mut PluginEngine, and an input type that implements Serialize"));
88		}
89		// Validate that the first arg is type &mut PluginEngine
90		if let Some(syn::FnArg::Typed(engine_arg)) = inputs.get(0) {
91			parse_plugin_engine(engine_arg)?;
92		}
93
94		if let Some(input_arg) = inputs.get(1) {
95			let syn::FnArg::Typed(input_arg_info) = input_arg else {
96				return Err(Error::new(item_fn.span(), "Query function must take two arguments: &mut PluginEngine, and an input type that implements Serialize"));
97			};
98			input_arg_info.ty.as_ref().clone()
99		} else {
100			return Err(Error::new(item_fn.span(), "Query function must take two arguments: &mut PluginEngine, and an input type that implements Serialize"));
101		}
102	};
103
104	let output_type = match &sig.output {
105		ReturnType::Default => {
106			return Err(Error::new(
107				item_fn.span(),
108				"Query function must return Result<T: Serialize>",
109			));
110		}
111		ReturnType::Type(_, b_type) => {
112			use syn::Type;
113			match b_type.as_ref() {
114				Type::Path(p) => parse_result_generic(&p.path)?,
115				_ => {
116					return Err(Error::new(
117						item_fn.span(),
118						"Query function must return Result<T: Serialize>",
119					))
120				}
121			}
122		}
123	};
124
125	let default = match opt_meta {
126		Some(NameValue(nv)) => {
127			// Panic: Safe to unwrap because there should be at least one element in the sequence
128			if nv.path.segments.first().unwrap().ident == "default" {
129				match nv.value {
130					syn::Expr::Lit(e) => match e.lit {
131						syn::Lit::Bool(s) => s.value,
132						_ => {
133							return Err(Error::new(
134								item_fn.span(),
135								"Default field on query function options must have a Boolean value",
136							));
137						}
138					},
139					_ => {
140						return Err(Error::new(
141							item_fn.span(),
142							"Default field on query function options must have a Boolean value",
143						));
144					}
145				}
146			} else {
147				return Err(Error::new(
148					item_fn.span(),
149					"Default field must be set if options are included for the query function",
150				));
151			}
152		}
153		Some(Path(p)) => {
154			let seg: &syn::PathSegment = p.segments.first().unwrap();
155			if seg.ident == "default" {
156				match seg.arguments {
157					syn::PathArguments::None => true,
158					_ => return Err(Error::new(item_fn.span(), "Default field in query options path cannot have any parenthized or bracketed arguments")),
159				}
160			} else {
161				return Err(Error::new(
162					item_fn.span(),
163					"Default field must be set if options are included for the query function",
164				));
165			}
166		}
167		None => false,
168		_ => {
169			return Err(Error::new(
170				item_fn.span(),
171				"Cannot parse query function options",
172			));
173		}
174	};
175
176	Ok(QuerySpec {
177		function,
178		default,
179		input_type,
180		output_type,
181	})
182}
183
184/// An attribute on a function that creates an associated struct that implements
185/// the Hipcheck Rust SDK's `Query` trait. The function must have the signature
186/// `fn(&mut PluginEngine, content: impl serde::Deserialize) ->
187/// hipcheck_sdk::Result<impl serde::Serialize>`. The generated struct's name is
188/// the pascal-case version of the function name (e.g. `do_something()` ->
189/// `DoSomething`).
190#[proc_macro_attribute]
191pub fn query(attr: TokenStream, item: TokenStream) -> TokenStream {
192	let mut to_return = proc_macro2::TokenStream::from(item.clone());
193	let item_fn = parse_macro_input!(item as ItemFn);
194	let opt_meta: Option<Meta> = if attr.is_empty().not() {
195		Some(parse_macro_input!(attr as Meta))
196	} else {
197		None
198	};
199	let spec = match parse_named_query_spec(opt_meta, item_fn) {
200		Ok(span) => span,
201		Err(err) => return err.to_compile_error().into(),
202	};
203
204	let struct_name = Ident::new(
205		spec.function
206			.to_string()
207			.to_case(convert_case::Case::Pascal)
208			.as_str(),
209		Span::call_site(),
210	);
211	let ident = &spec.function;
212	let input_type = spec.input_type;
213	let output_type = spec.output_type;
214
215	let to_follow = quote::quote! {
216		struct #struct_name {}
217
218		#[hipcheck_sdk::prelude::async_trait]
219		impl hipcheck_sdk::prelude::Query for #struct_name {
220			fn input_schema(&self) -> hipcheck_sdk::prelude::JsonSchema {
221				hipcheck_sdk::prelude::schema_for!(#input_type).schema
222			}
223
224			fn output_schema(&self) -> hipcheck_sdk::prelude::JsonSchema {
225				hipcheck_sdk::prelude::schema_for!(#output_type).schema
226			}
227
228			async fn run(&self, engine: &mut hipcheck_sdk::prelude::PluginEngine, input: hipcheck_sdk::prelude::Value) -> hipcheck_sdk::prelude::Result<hipcheck_sdk::prelude::Value> {
229				let input = hipcheck_sdk::prelude::from_value(input).map_err(|_|
230					hipcheck_sdk::prelude::Error::UnexpectedPluginQueryInputFormat)?;
231				let output = #ident(engine, input).await?;
232				hipcheck_sdk::prelude::to_value(output).map_err(|_|
233					hipcheck_sdk::prelude::Error::UnexpectedPluginQueryOutputFormat)
234			}
235		}
236	};
237
238	QUERIES.lock().unwrap().push(NamedQuerySpec {
239		struct_name: struct_name.to_string(),
240		function: spec.function.to_string(),
241		default: spec.default,
242	});
243
244	to_return.extend(to_follow);
245	proc_macro::TokenStream::from(to_return)
246}
247
248/// Generates an implementation of the `Plugin::queries()` trait function using
249/// all previously-expanded `#[query]` attribute macros. Due to Rust's macro
250/// expansion ordering, all `#[query]` functions must come before this macro
251/// to ensure they are seen and added.
252#[proc_macro]
253pub fn queries(_item: TokenStream) -> TokenStream {
254	let mut agg = proc_macro2::TokenStream::new();
255	let q_lock = QUERIES.lock().unwrap();
256	// Create a NamedQuery for each #query func we've seen
257	for q in q_lock.iter() {
258		let name = match q.default {
259			true => "",
260			false => q.function.as_str(),
261		};
262		let inner = Ident::new(q.struct_name.as_str(), Span::call_site());
263		let out = quote::quote! {
264			NamedQuery {
265				name: #name,
266				inner: Box::new(#inner {})
267			},
268		};
269		agg.extend(out);
270	}
271	tracing::debug!(
272		"Auto-generating Plugin::queries() with {} detected queries",
273		q_lock.len()
274	);
275	// Impl `Plugin::queries` as a vec of generated NamedQuery instances
276	let out = quote::quote! {
277		fn queries(&self) -> impl Iterator<Item = NamedQuery> {
278			vec![#agg].into_iter()
279		}
280	};
281	proc_macro::TokenStream::from(out)
282}
283
284/// Generates a derived macro implementation of the `PluginConfig` trait to
285/// deserialize each plugin config field derived from the Policy File.
286/// Config-related errors are handled by the `ConfigError` crate to generate
287/// specific error messages that detail the plugin, field, and type expected from
288/// the Policy File.
289#[proc_macro_derive(PluginConfig)]
290pub fn derive_plugin_config(input: TokenStream) -> TokenStream {
291	// Parse the input struct
292	let input = parse_macro_input!(input as DeriveInput);
293
294	// Extract the RawConfig struct name
295	let struct_name = &input.ident;
296
297	let Data::Struct(syn::DataStruct { fields, .. }) = &input.data else {
298		// Return an error if the macro is used on something other than a struct
299		return syn::Error::new(input.span(), "PluginConfig can only be derived for structs")
300			.to_compile_error()
301			.into();
302	};
303
304	// Helper function to convert field names to dashed strings
305	fn to_dashed_field_name(field: &syn::Field) -> String {
306		field.ident.as_ref().unwrap().to_string().replace("_", "-")
307	}
308
309	// Generate deserialization logic for each field
310	let field_deserialization: Vec<_> = fields
311		.iter()
312		.map(|field| {
313			let field_name = field.ident.as_ref().unwrap();
314			let field_name_str = to_dashed_field_name(field);
315			let field_type = &field.ty;
316
317			quote::quote! {
318				let #field_name = if let Some(value) = config.remove(#field_name_str) {
319					// Map contained value, return an error if an invalid value is provided for the field
320					serde_json::from_value::<#field_type>(value.clone()).map_err(|_| {
321						ConfigError::InvalidConfigValue {
322							field_name: #field_name_str.to_owned(),
323							value: format!("{:?}", value),
324							reason: format!(
325								"Expected type: {}, but got: {:?}",
326								stringify!(#field_type),
327								value
328							),
329						}
330					})?
331				} else {
332					// Try deserializing from null. If it works, value's type indicates it was
333					// optional. If this fails, missing required config.
334					serde_json::from_value::<#field_type>(serde_json::Value::Null).map_err(|_| {
335						ConfigError::MissingRequiredConfig {
336							field_name: #field_name_str.to_owned(),
337							field_type: stringify!(#field_type).to_owned(),
338							possible_values: vec![],
339						}
340					})?
341				};
342			}
343		})
344		.collect();
345
346	// After the expected fields are extracted, there should be no remaining fields in the config map
347	let validate_fields = quote::quote! {
348		if let Some((unexpected_key, value)) = config.iter().next() {
349			// Return an error if any remaining key/value pair in map
350			return Err(ConfigError::UnrecognizedConfig {
351				field_name: unexpected_key.to_string(),
352				field_value: format!("{:?}", value),
353				possible_confusables: vec![],
354			});
355		}
356	};
357
358	// Generate code to initialize the struct fields
359	let initialize_struct: Vec<_> = fields
360		.iter()
361		.map(|field| {
362			let field_name = field.ident.as_ref().unwrap();
363			quote::quote! {
364				#field_name
365			}
366		})
367		.collect();
368
369	// Generate the implementation of the PluginConfig trait
370	let impl_block = quote::quote! {
371		impl<'de> PluginConfig<'de> for #struct_name {
372			fn deserialize(conf_ref: &serde_json::Value) -> StdResult<Self, ConfigError> {
373				let mut conf_owned = conf_ref.clone();
374				let mut dummy = serde_json::Map::new();
375				let config = conf_owned.as_object_mut().unwrap_or(&mut dummy);
376
377				#(#field_deserialization)* // Deserialize each field
378				#validate_fields
379				Ok(Self {
380					#(#initialize_struct),* // Initialize each field
381				})
382			}
383		}
384	};
385
386	// Return the generated TokenStream
387	proc_macro::TokenStream::from(impl_block)
388}