problemreductions_macros/
lib.rs1use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::quote;
9use syn::{parse_macro_input, GenericArgument, ItemImpl, Path, PathArguments, Type};
10
11#[proc_macro_attribute]
39pub fn reduction(attr: TokenStream, item: TokenStream) -> TokenStream {
40 let attrs = parse_macro_input!(attr as ReductionAttrs);
41 let impl_block = parse_macro_input!(item as ItemImpl);
42
43 match generate_reduction_entry(&attrs, &impl_block) {
44 Ok(tokens) => tokens.into(),
45 Err(e) => e.to_compile_error().into(),
46 }
47}
48
49struct ReductionAttrs {
51 source_graph: Option<String>,
52 target_graph: Option<String>,
53 source_weighted: Option<bool>,
54 target_weighted: Option<bool>,
55 overhead: Option<TokenStream2>,
56}
57
58impl syn::parse::Parse for ReductionAttrs {
59 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
60 let mut attrs = ReductionAttrs {
61 source_graph: None,
62 target_graph: None,
63 source_weighted: None,
64 target_weighted: None,
65 overhead: None,
66 };
67
68 while !input.is_empty() {
69 let ident: syn::Ident = input.parse()?;
70 input.parse::<syn::Token![=]>()?;
71
72 match ident.to_string().as_str() {
73 "source_graph" => {
74 let lit: syn::LitStr = input.parse()?;
75 attrs.source_graph = Some(lit.value());
76 }
77 "target_graph" => {
78 let lit: syn::LitStr = input.parse()?;
79 attrs.target_graph = Some(lit.value());
80 }
81 "source_weighted" => {
82 let lit: syn::LitBool = input.parse()?;
83 attrs.source_weighted = Some(lit.value());
84 }
85 "target_weighted" => {
86 let lit: syn::LitBool = input.parse()?;
87 attrs.target_weighted = Some(lit.value());
88 }
89 "overhead" => {
90 let content;
91 syn::braced!(content in input);
92 attrs.overhead = Some(content.parse()?);
93 }
94 _ => {
95 return Err(syn::Error::new(
96 ident.span(),
97 format!("unknown attribute: {}", ident),
98 ));
99 }
100 }
101
102 if input.peek(syn::Token![,]) {
103 input.parse::<syn::Token![,]>()?;
104 }
105 }
106
107 Ok(attrs)
108 }
109}
110
111fn extract_type_name(ty: &Type) -> Option<String> {
113 match ty {
114 Type::Path(type_path) => {
115 let segment = type_path.path.segments.last()?;
116 Some(segment.ident.to_string())
117 }
118 _ => None,
119 }
120}
121
122fn extract_graph_type(ty: &Type) -> Option<String> {
124 match ty {
125 Type::Path(type_path) => {
126 let segment = type_path.path.segments.last()?;
127 if let PathArguments::AngleBracketed(args) = &segment.arguments {
128 for arg in args.args.iter() {
130 if let GenericArgument::Type(Type::Path(inner_path)) = arg {
131 let name = inner_path
132 .path
133 .segments
134 .last()
135 .map(|s| s.ident.to_string())?;
136 if name.len() == 1
138 && name
139 .chars()
140 .next()
141 .map(|c| c.is_ascii_uppercase())
142 .unwrap_or(false)
143 {
144 return None; }
146 if is_weight_type(&name) {
148 return None; }
150 return Some(name);
151 }
152 }
153 }
154 None
155 }
156 _ => None,
157 }
158}
159
/// Returns `true` if `name` is one of the recognised weight type names:
/// `i32`, `i64`, `f32`, `f64` or `Unweighted`.
fn is_weight_type(name: &str) -> bool {
    matches!(name, "i32" | "i64" | "f32" | "f64" | "Unweighted")
}
164
165fn extract_weight_type(ty: &Type) -> Option<Type> {
169 match ty {
170 Type::Path(type_path) => {
171 let segment = type_path.path.segments.last()?;
172 if let PathArguments::AngleBracketed(args) = &segment.arguments {
173 let type_args: Vec<_> = args
174 .args
175 .iter()
176 .filter_map(|arg| {
177 if let GenericArgument::Type(t) = arg {
178 Some(t)
179 } else {
180 None
181 }
182 })
183 .collect();
184
185 match type_args.len() {
186 1 => {
187 let first = type_args[0];
189 if let Type::Path(inner_path) = first {
190 let name = inner_path.path.segments.last()?.ident.to_string();
191 if is_weight_type(&name) {
192 return Some(first.clone());
193 }
194 }
195 None
196 }
197 2 => {
198 Some(type_args[1].clone())
200 }
201 _ => None,
202 }
203 } else {
204 None
205 }
206 }
207 _ => None,
208 }
209}
210
211fn get_weight_name(ty: &Type) -> String {
215 match ty {
216 Type::Path(type_path) => {
217 let name = type_path
218 .path
219 .segments
220 .last()
221 .map(|s| s.ident.to_string())
222 .unwrap_or_else(|| "Unweighted".to_string());
223 if name.len() == 1
225 && name
226 .chars()
227 .next()
228 .map(|c| c.is_ascii_uppercase())
229 .unwrap_or(false)
230 {
231 "Unweighted".to_string()
232 } else {
233 name
234 }
235 }
236 _ => "Unweighted".to_string(),
237 }
238}
239
240fn generate_reduction_entry(
242 attrs: &ReductionAttrs,
243 impl_block: &ItemImpl,
244) -> syn::Result<TokenStream2> {
245 let trait_path = impl_block
247 .trait_
248 .as_ref()
249 .map(|(_, path, _)| path)
250 .ok_or_else(|| syn::Error::new_spanned(impl_block, "Expected impl ReduceTo<T> for S"))?;
251
252 let target_type = extract_target_from_trait(trait_path)?;
254
255 let source_type = &impl_block.self_ty;
257
258 let source_name = extract_type_name(source_type)
260 .ok_or_else(|| syn::Error::new_spanned(source_type, "Cannot extract source type name"))?;
261 let target_name = extract_type_name(&target_type)
262 .ok_or_else(|| syn::Error::new_spanned(&target_type, "Cannot extract target type name"))?;
263
264 let source_weight_name = attrs
266 .source_weighted
267 .map(|w| {
268 if w {
269 "i32".to_string()
270 } else {
271 "Unweighted".to_string()
272 }
273 })
274 .unwrap_or_else(|| {
275 extract_weight_type(source_type)
276 .map(|t| get_weight_name(&t))
277 .unwrap_or_else(|| "Unweighted".to_string())
278 });
279 let target_weight_name = attrs
280 .target_weighted
281 .map(|w| {
282 if w {
283 "i32".to_string()
284 } else {
285 "Unweighted".to_string()
286 }
287 })
288 .unwrap_or_else(|| {
289 extract_weight_type(&target_type)
290 .map(|t| get_weight_name(&t))
291 .unwrap_or_else(|| "Unweighted".to_string())
292 });
293
294 let source_graph = attrs
296 .source_graph
297 .clone()
298 .or_else(|| extract_graph_type(source_type))
299 .unwrap_or_else(|| "SimpleGraph".to_string());
300 let target_graph = attrs
301 .target_graph
302 .clone()
303 .or_else(|| extract_graph_type(&target_type))
304 .unwrap_or_else(|| "SimpleGraph".to_string());
305
306 let overhead = attrs.overhead.clone().unwrap_or_else(|| {
308 quote! {
309 crate::rules::registry::ReductionOverhead::default()
310 }
311 });
312
313 let output = quote! {
315 #impl_block
316
317 inventory::submit! {
318 crate::rules::registry::ReductionEntry {
319 source_name: #source_name,
320 target_name: #target_name,
321 source_variant: &[("graph", #source_graph), ("weight", #source_weight_name)],
322 target_variant: &[("graph", #target_graph), ("weight", #target_weight_name)],
323 overhead_fn: || { #overhead },
324 module_path: module_path!(),
325 }
326 }
327 };
328
329 Ok(output)
330}
331
332fn extract_target_from_trait(path: &Path) -> syn::Result<Type> {
334 let segment = path
335 .segments
336 .last()
337 .ok_or_else(|| syn::Error::new_spanned(path, "Empty trait path"))?;
338
339 if segment.ident != "ReduceTo" {
340 return Err(syn::Error::new_spanned(segment, "Expected ReduceTo trait"));
341 }
342
343 if let PathArguments::AngleBracketed(args) = &segment.arguments {
344 if let Some(GenericArgument::Type(ty)) = args.args.first() {
345 return Ok(ty.clone());
346 }
347 }
348
349 Err(syn::Error::new_spanned(
350 segment,
351 "Expected ReduceTo<Target> with type parameter",
352 ))
353}