1use convert_case::Casing;
4use proc_macro::TokenStream;
5use proc_macro2::Span;
6use std::ops::Not;
7use std::sync::{LazyLock, Mutex};
8use syn::spanned::Spanned;
9use syn::{parse_macro_input, Data, DeriveInput, Error, Ident, ItemFn, Meta, PatType};
10
11static QUERIES: LazyLock<Mutex<Vec<NamedQuerySpec>>> = LazyLock::new(|| Mutex::new(vec![]));
12
/// One `#[query]` function as recorded in the `QUERIES` registry.
///
/// Stored as plain strings so it can be rebuilt into tokens later by `queries!()`.
#[derive(Debug, Clone)]
struct NamedQuerySpec {
    /// Name of the generated struct that implements `Query` for this function.
    pub struct_name: String,
    /// Name of the annotated function; used as the query name unless `default` is set.
    pub function: String,
    /// Whether this is the plugin's default query (emitted with the empty name `""`).
    pub default: bool,
}
20
21struct QuerySpec {
22 pub function: Ident,
23 pub input_type: syn::Type,
24 pub output_type: syn::Type,
25 pub default: bool,
26}
27
28fn parse_result_generic(p: &syn::Path) -> Result<syn::Type, Error> {
30 use syn::GenericArgument;
31 use syn::PathArguments;
32 let last = p.segments.last().unwrap();
35 if last.ident != "Result" {
36 return Err(Error::new(
37 p.span(),
38 "Expected return type to be a Result<T: Serialize>",
39 ));
40 }
41 match &last.arguments {
42 PathArguments::AngleBracketed(x) => {
43 let Some(GenericArgument::Type(ty)) = x.args.first() else {
44 return Err(Error::new(
45 p.span(),
46 "Expected return type to be a Result<T: Serialize>",
47 ));
48 };
49 Ok(ty.clone())
50 }
51 _ => Err(Error::new(
52 p.span(),
53 "Expected return type to be a Result<T: Serialize>",
54 )),
55 }
56}
57
58fn parse_plugin_engine(engine_arg: &PatType) -> Result<(), Error> {
60 if let syn::Type::Reference(type_reference) = engine_arg.ty.as_ref() {
61 if type_reference.mutability.is_some() {
62 if let syn::Type::Path(type_path) = type_reference.elem.as_ref() {
63 let last = type_path.path.segments.last().unwrap();
64 if last.ident == "PluginEngine" {
65 return Ok(());
66 }
67 }
68 }
69 }
70
71 Err(Error::new(
72 engine_arg.span(),
73 "The first argument of the query function must be a &mut PluginEngine",
74 ))
75}
76
77fn parse_named_query_spec(opt_meta: Option<Meta>, item_fn: ItemFn) -> Result<QuerySpec, Error> {
78 use syn::Meta::*;
79 use syn::ReturnType;
80 let sig = &item_fn.sig;
81
82 let function = sig.ident.clone();
83
84 let input_type: syn::Type = {
85 let inputs = &sig.inputs;
86 if inputs.len() != 2 {
87 return Err(Error::new(item_fn.span(), "Query function must take two arguments: &mut PluginEngine, and an input type that implements Serialize"));
88 }
89 if let Some(syn::FnArg::Typed(engine_arg)) = inputs.get(0) {
91 parse_plugin_engine(engine_arg)?;
92 }
93
94 if let Some(input_arg) = inputs.get(1) {
95 let syn::FnArg::Typed(input_arg_info) = input_arg else {
96 return Err(Error::new(item_fn.span(), "Query function must take two arguments: &mut PluginEngine, and an input type that implements Serialize"));
97 };
98 input_arg_info.ty.as_ref().clone()
99 } else {
100 return Err(Error::new(item_fn.span(), "Query function must take two arguments: &mut PluginEngine, and an input type that implements Serialize"));
101 }
102 };
103
104 let output_type = match &sig.output {
105 ReturnType::Default => {
106 return Err(Error::new(
107 item_fn.span(),
108 "Query function must return Result<T: Serialize>",
109 ));
110 }
111 ReturnType::Type(_, b_type) => {
112 use syn::Type;
113 match b_type.as_ref() {
114 Type::Path(p) => parse_result_generic(&p.path)?,
115 _ => {
116 return Err(Error::new(
117 item_fn.span(),
118 "Query function must return Result<T: Serialize>",
119 ))
120 }
121 }
122 }
123 };
124
125 let default = match opt_meta {
126 Some(NameValue(nv)) => {
127 if nv.path.segments.first().unwrap().ident == "default" {
129 match nv.value {
130 syn::Expr::Lit(e) => match e.lit {
131 syn::Lit::Bool(s) => s.value,
132 _ => {
133 return Err(Error::new(
134 item_fn.span(),
135 "Default field on query function options must have a Boolean value",
136 ));
137 }
138 },
139 _ => {
140 return Err(Error::new(
141 item_fn.span(),
142 "Default field on query function options must have a Boolean value",
143 ));
144 }
145 }
146 } else {
147 return Err(Error::new(
148 item_fn.span(),
149 "Default field must be set if options are included for the query function",
150 ));
151 }
152 }
153 Some(Path(p)) => {
154 let seg: &syn::PathSegment = p.segments.first().unwrap();
155 if seg.ident == "default" {
156 match seg.arguments {
157 syn::PathArguments::None => true,
158 _ => return Err(Error::new(item_fn.span(), "Default field in query options path cannot have any parenthized or bracketed arguments")),
159 }
160 } else {
161 return Err(Error::new(
162 item_fn.span(),
163 "Default field must be set if options are included for the query function",
164 ));
165 }
166 }
167 None => false,
168 _ => {
169 return Err(Error::new(
170 item_fn.span(),
171 "Cannot parse query function options",
172 ));
173 }
174 };
175
176 Ok(QuerySpec {
177 function,
178 default,
179 input_type,
180 output_type,
181 })
182}
183
184#[proc_macro_attribute]
191pub fn query(attr: TokenStream, item: TokenStream) -> TokenStream {
192 let mut to_return = proc_macro2::TokenStream::from(item.clone());
193 let item_fn = parse_macro_input!(item as ItemFn);
194 let opt_meta: Option<Meta> = if attr.is_empty().not() {
195 Some(parse_macro_input!(attr as Meta))
196 } else {
197 None
198 };
199 let spec = match parse_named_query_spec(opt_meta, item_fn) {
200 Ok(span) => span,
201 Err(err) => return err.to_compile_error().into(),
202 };
203
204 let struct_name = Ident::new(
205 spec.function
206 .to_string()
207 .to_case(convert_case::Case::Pascal)
208 .as_str(),
209 Span::call_site(),
210 );
211 let ident = &spec.function;
212 let input_type = spec.input_type;
213 let output_type = spec.output_type;
214
215 let to_follow = quote::quote! {
216 struct #struct_name {}
217
218 #[hipcheck_sdk::prelude::async_trait]
219 impl hipcheck_sdk::prelude::Query for #struct_name {
220 fn input_schema(&self) -> hipcheck_sdk::prelude::JsonSchema {
221 hipcheck_sdk::prelude::schema_for!(#input_type).schema
222 }
223
224 fn output_schema(&self) -> hipcheck_sdk::prelude::JsonSchema {
225 hipcheck_sdk::prelude::schema_for!(#output_type).schema
226 }
227
228 async fn run(&self, engine: &mut hipcheck_sdk::prelude::PluginEngine, input: hipcheck_sdk::prelude::Value) -> hipcheck_sdk::prelude::Result<hipcheck_sdk::prelude::Value> {
229 let input = hipcheck_sdk::prelude::from_value(input).map_err(|_|
230 hipcheck_sdk::prelude::Error::UnexpectedPluginQueryInputFormat)?;
231 let output = #ident(engine, input).await?;
232 hipcheck_sdk::prelude::to_value(output).map_err(|_|
233 hipcheck_sdk::prelude::Error::UnexpectedPluginQueryOutputFormat)
234 }
235 }
236 };
237
238 QUERIES.lock().unwrap().push(NamedQuerySpec {
239 struct_name: struct_name.to_string(),
240 function: spec.function.to_string(),
241 default: spec.default,
242 });
243
244 to_return.extend(to_follow);
245 proc_macro::TokenStream::from(to_return)
246}
247
248#[proc_macro]
253pub fn queries(_item: TokenStream) -> TokenStream {
254 let mut agg = proc_macro2::TokenStream::new();
255 let q_lock = QUERIES.lock().unwrap();
256 for q in q_lock.iter() {
258 let name = match q.default {
259 true => "",
260 false => q.function.as_str(),
261 };
262 let inner = Ident::new(q.struct_name.as_str(), Span::call_site());
263 let out = quote::quote! {
264 NamedQuery {
265 name: #name,
266 inner: Box::new(#inner {})
267 },
268 };
269 agg.extend(out);
270 }
271 tracing::debug!(
272 "Auto-generating Plugin::queries() with {} detected queries",
273 q_lock.len()
274 );
275 let out = quote::quote! {
277 fn queries(&self) -> impl Iterator<Item = NamedQuery> {
278 vec![#agg].into_iter()
279 }
280 };
281 proc_macro::TokenStream::from(out)
282}
283
284#[proc_macro_derive(PluginConfig)]
290pub fn derive_plugin_config(input: TokenStream) -> TokenStream {
291 let input = parse_macro_input!(input as DeriveInput);
293
294 let struct_name = &input.ident;
296
297 let Data::Struct(syn::DataStruct { fields, .. }) = &input.data else {
298 return syn::Error::new(input.span(), "PluginConfig can only be derived for structs")
300 .to_compile_error()
301 .into();
302 };
303
304 fn to_dashed_field_name(field: &syn::Field) -> String {
306 field.ident.as_ref().unwrap().to_string().replace("_", "-")
307 }
308
309 let field_deserialization: Vec<_> = fields
311 .iter()
312 .map(|field| {
313 let field_name = field.ident.as_ref().unwrap();
314 let field_name_str = to_dashed_field_name(field);
315 let field_type = &field.ty;
316
317 quote::quote! {
318 let #field_name = if let Some(value) = config.remove(#field_name_str) {
319 serde_json::from_value::<#field_type>(value.clone()).map_err(|_| {
321 ConfigError::InvalidConfigValue {
322 field_name: #field_name_str.to_owned(),
323 value: format!("{:?}", value),
324 reason: format!(
325 "Expected type: {}, but got: {:?}",
326 stringify!(#field_type),
327 value
328 ),
329 }
330 })?
331 } else {
332 serde_json::from_value::<#field_type>(serde_json::Value::Null).map_err(|_| {
335 ConfigError::MissingRequiredConfig {
336 field_name: #field_name_str.to_owned(),
337 field_type: stringify!(#field_type).to_owned(),
338 possible_values: vec![],
339 }
340 })?
341 };
342 }
343 })
344 .collect();
345
346 let validate_fields = quote::quote! {
348 if let Some((unexpected_key, value)) = config.iter().next() {
349 return Err(ConfigError::UnrecognizedConfig {
351 field_name: unexpected_key.to_string(),
352 field_value: format!("{:?}", value),
353 possible_confusables: vec![],
354 });
355 }
356 };
357
358 let initialize_struct: Vec<_> = fields
360 .iter()
361 .map(|field| {
362 let field_name = field.ident.as_ref().unwrap();
363 quote::quote! {
364 #field_name
365 }
366 })
367 .collect();
368
369 let impl_block = quote::quote! {
371 impl<'de> PluginConfig<'de> for #struct_name {
372 fn deserialize(conf_ref: &serde_json::Value) -> StdResult<Self, ConfigError> {
373 let mut conf_owned = conf_ref.clone();
374 let mut dummy = serde_json::Map::new();
375 let config = conf_owned.as_object_mut().unwrap_or(&mut dummy);
376
377 #(#field_deserialization)* #validate_fields
379 Ok(Self {
380 #(#initialize_struct),* })
382 }
383 }
384 };
385
386 proc_macro::TokenStream::from(impl_block)
388}