hipcheck_sdk_macros/
lib.rs1use convert_case::Casing;
4use proc_macro::TokenStream;
5use proc_macro2::Span;
6use std::ops::Not;
7use std::sync::{LazyLock, Mutex};
8use syn::spanned::Spanned;
9use syn::{parse_macro_input, Error, Ident, ItemFn, Meta, PatType};
10
11static QUERIES: LazyLock<Mutex<Vec<NamedQuerySpec>>> = LazyLock::new(|| Mutex::new(vec![]));
12
/// The name-level record kept in `QUERIES` for every `#[query]` function.
///
/// Only plain strings and a flag are stored (no `syn` types), so entries can live in a `static`.
// NOTE(review): it is unclear which lint this `allow` silences — every field is read when
// `queries!()` expands. Consider removing it once confirmed.
#[allow(unused)]
#[derive(Debug, Clone)]
struct NamedQuerySpec {
    /// PascalCase name of the generated struct that implements `Query`.
    pub struct_name: String,
    /// Name of the annotated function; used as the query name unless `default` is set.
    pub function: String,
    /// Whether this query was marked as the plugin's default (registered under the empty name).
    pub default: bool,
}
20
21struct QuerySpec {
22 pub function: Ident,
23 pub input_type: syn::Type,
24 pub output_type: syn::Type,
25 pub default: bool,
26}
27
28fn parse_result_generic(p: &syn::Path) -> Result<syn::Type, Error> {
30 use syn::GenericArgument;
31 use syn::PathArguments;
32 let last = p.segments.last().unwrap();
35 if last.ident != "Result" {
36 return Err(Error::new(
37 p.span(),
38 "Expected return type to be a Result<T: Serialize>",
39 ));
40 }
41 match &last.arguments {
42 PathArguments::AngleBracketed(x) => {
43 let Some(GenericArgument::Type(ty)) = x.args.first() else {
44 return Err(Error::new(
45 p.span(),
46 "Expected return type to be a Result<T: Serialize>",
47 ));
48 };
49 Ok(ty.clone())
50 }
51 _ => Err(Error::new(
52 p.span(),
53 "Expected return type to be a Result<T: Serialize>",
54 )),
55 }
56}
57
58fn parse_plugin_engine(engine_arg: &PatType) -> Result<(), Error> {
60 if let syn::Type::Reference(type_reference) = engine_arg.ty.as_ref() {
61 if type_reference.mutability.is_some() {
62 if let syn::Type::Path(type_path) = type_reference.elem.as_ref() {
63 let last = type_path.path.segments.last().unwrap();
64 if last.ident == "PluginEngine" {
65 return Ok(());
66 }
67 }
68 }
69 }
70
71 Err(Error::new(
72 engine_arg.span(),
73 "The first argument of the query function must be a &mut PluginEngine",
74 ))
75}
76
77fn parse_named_query_spec(opt_meta: Option<Meta>, item_fn: ItemFn) -> Result<QuerySpec, Error> {
78 use syn::Meta::*;
79 use syn::ReturnType;
80 let sig = &item_fn.sig;
81
82 let function = sig.ident.clone();
83
84 let input_type: syn::Type = {
85 let inputs = &sig.inputs;
86 if inputs.len() != 2 {
87 return Err(Error::new(item_fn.span(), "Query function must take two arguments: &mut PluginEngine, and an input type that implements Serialize"));
88 }
89 if let Some(syn::FnArg::Typed(engine_arg)) = inputs.get(0) {
91 parse_plugin_engine(engine_arg)?;
92 }
93
94 if let Some(input_arg) = inputs.get(1) {
95 let syn::FnArg::Typed(input_arg_info) = input_arg else {
96 return Err(Error::new(item_fn.span(), "Query function must take two arguments: &mut PluginEngine, and an input type that implements Serialize"));
97 };
98 input_arg_info.ty.as_ref().clone()
99 } else {
100 return Err(Error::new(item_fn.span(), "Query function must take two arguments: &mut PluginEngine, and an input type that implements Serialize"));
101 }
102 };
103
104 let output_type = match &sig.output {
105 ReturnType::Default => {
106 return Err(Error::new(
107 item_fn.span(),
108 "Query function must return Result<T: Serialize>",
109 ));
110 }
111 ReturnType::Type(_, b_type) => {
112 use syn::Type;
113 match b_type.as_ref() {
114 Type::Path(p) => parse_result_generic(&p.path)?,
115 _ => {
116 return Err(Error::new(
117 item_fn.span(),
118 "Query function must return Result<T: Serialize>",
119 ))
120 }
121 }
122 }
123 };
124
125 let default = match opt_meta {
126 Some(NameValue(nv)) => {
127 if nv.path.segments.first().unwrap().ident == "default" {
129 match nv.value {
130 syn::Expr::Lit(e) => match e.lit {
131 syn::Lit::Bool(s) => s.value,
132 _ => {
133 return Err(Error::new(
134 item_fn.span(),
135 "Default field on query function options must have a Boolean value",
136 ));
137 }
138 },
139 _ => {
140 return Err(Error::new(
141 item_fn.span(),
142 "Default field on query function options must have a Boolean value",
143 ));
144 }
145 }
146 } else {
147 return Err(Error::new(
148 item_fn.span(),
149 "Default field must be set if options are included for the query function",
150 ));
151 }
152 }
153 Some(Path(p)) => {
154 let seg: &syn::PathSegment = p.segments.first().unwrap();
155 if seg.ident == "default" {
156 match seg.arguments {
157 syn::PathArguments::None => true,
158 _ => return Err(Error::new(item_fn.span(), "Default field in query options path cannot have any parenthized or bracketed arguments")),
159 }
160 } else {
161 return Err(Error::new(
162 item_fn.span(),
163 "Default field must be set if options are included for the query function",
164 ));
165 }
166 }
167 None => false,
168 _ => {
169 return Err(Error::new(
170 item_fn.span(),
171 "Cannot parse query function options",
172 ));
173 }
174 };
175
176 Ok(QuerySpec {
177 function,
178 default,
179 input_type,
180 output_type,
181 })
182}
183
184#[proc_macro_attribute]
191pub fn query(attr: TokenStream, item: TokenStream) -> TokenStream {
192 let mut to_return = proc_macro2::TokenStream::from(item.clone());
193 let item_fn = parse_macro_input!(item as ItemFn);
194 let opt_meta: Option<Meta> = if attr.is_empty().not() {
195 Some(parse_macro_input!(attr as Meta))
196 } else {
197 None
198 };
199 let spec = match parse_named_query_spec(opt_meta, item_fn) {
200 Ok(span) => span,
201 Err(err) => return err.to_compile_error().into(),
202 };
203
204 let struct_name = Ident::new(
205 spec.function
206 .to_string()
207 .to_case(convert_case::Case::Pascal)
208 .as_str(),
209 Span::call_site(),
210 );
211 let ident = &spec.function;
212 let input_type = spec.input_type;
213 let output_type = spec.output_type;
214
215 let to_follow = quote::quote! {
216 struct #struct_name {}
217
218 #[hipcheck_sdk::prelude::async_trait]
219 impl hipcheck_sdk::prelude::Query for #struct_name {
220 fn input_schema(&self) -> hipcheck_sdk::prelude::JsonSchema {
221 hipcheck_sdk::prelude::schema_for!(#input_type).schema
222 }
223
224 fn output_schema(&self) -> hipcheck_sdk::prelude::JsonSchema {
225 hipcheck_sdk::prelude::schema_for!(#output_type).schema
226 }
227
228 async fn run(&self, engine: &mut hipcheck_sdk::prelude::PluginEngine, input: hipcheck_sdk::prelude::Value) -> hipcheck_sdk::prelude::Result<hipcheck_sdk::prelude::Value> {
229 let input = hipcheck_sdk::prelude::from_value(input).map_err(|_|
230 hipcheck_sdk::prelude::Error::UnexpectedPluginQueryInputFormat)?;
231 let output = #ident(engine, input).await?;
232 hipcheck_sdk::prelude::to_value(output).map_err(|_|
233 hipcheck_sdk::prelude::Error::UnexpectedPluginQueryOutputFormat)
234 }
235 }
236 };
237
238 QUERIES.lock().unwrap().push(NamedQuerySpec {
239 struct_name: struct_name.to_string(),
240 function: spec.function.to_string(),
241 default: spec.default,
242 });
243
244 to_return.extend(to_follow);
245 proc_macro::TokenStream::from(to_return)
246}
247
248#[proc_macro]
253pub fn queries(_item: TokenStream) -> TokenStream {
254 let mut agg = proc_macro2::TokenStream::new();
255 let q_lock = QUERIES.lock().unwrap();
256 for q in q_lock.iter() {
258 let name = match q.default {
259 true => "",
260 false => q.function.as_str(),
261 };
262 let inner = Ident::new(q.struct_name.as_str(), Span::call_site());
263 let out = quote::quote! {
264 NamedQuery {
265 name: #name,
266 inner: Box::new(#inner {})
267 },
268 };
269 agg.extend(out);
270 }
271 tracing::debug!(
272 "Auto-generating Plugin::queries() with {} detected queries",
273 q_lock.len()
274 );
275 let out = quote::quote! {
277 fn queries(&self) -> impl Iterator<Item = NamedQuery> {
278 vec![#agg].into_iter()
279 }
280 };
281 proc_macro::TokenStream::from(out)
282}