use std::fmt::Write;

use proc_macro::{Delimiter, Ident, Literal, TokenStream, TokenTree};
4
5#[proc_macro_attribute]
9pub fn generate_target_variant(attrs: TokenStream, input: TokenStream) -> TokenStream {
10 let attrs = attrs.to_string();
11
12 let feature = attrs
13 .trim()
14 .strip_prefix("\"")
15 .unwrap()
16 .strip_suffix("\"")
17 .unwrap()
18 .split(",")
19 .collect::<Vec<_>>();
20
21 let mut comments = Vec::new();
22
23 let input_str = input.to_string();
24 let mut input_str = input_str.lines();
25
26 let mut def_line = input_str.next().unwrap();
27
28 while def_line.trim().starts_with("//")
29 || def_line
30 .trim()
31 .strip_prefix("#")
32 .map_or(false, |s| s.trim().starts_with("["))
33 {
34 comments.push(def_line.trim().to_string());
35 let Some(def_line_next) = input_str.next() else {
36 return "".parse().unwrap();
37 };
38 def_line = def_line_next;
39 }
40
41 let rest = if def_line.trim().starts_with("pub") {
42 let trimmed = def_line.trim();
43 let split = trimmed
44 .char_indices()
45 .filter(|(_, c)| c.is_whitespace())
46 .next()
47 .unwrap()
48 .0;
49 &trimmed[split..]
50 } else {
51 def_line
52 };
53
54 let Some(rest) = rest.trim().strip_prefix("fn ") else {
55 panic!("Expected a function definition, got {}", rest);
56 };
57
58 let name = rest
59 .chars()
60 .take_while(|c| c.is_alphanumeric() || *c == '_')
61 .collect::<String>();
62
63 let def_line_mutated =
64 def_line.replace(&name, format!("{}_{}", name, feature.join("_")).as_str());
65
66 let rest_lines = input_str.collect::<Vec<_>>();
67
68 let mut output = String::new();
69 for comment in comments.iter() {
70 writeln!(&mut output, "{}", comment).unwrap();
71 }
72 writeln!(&mut output, "{}", def_line).unwrap();
73 for line in rest_lines.iter() {
74 writeln!(&mut output, "{}", line).unwrap();
75 }
76
77 for comment in comments.iter() {
78 writeln!(&mut output, "{}", comment).unwrap();
79 }
80 writeln!(
81 &mut output,
82 "// This is a generated function for CPU target {}",
83 feature.join(", ")
84 )
85 .unwrap();
86 for feature in feature.iter() {
87 writeln!(&mut output, "#[target_feature(enable = \"{}\")]", feature).unwrap();
88 }
89 writeln!(&mut output, "{}", def_line_mutated).unwrap();
90 output.extend(rest_lines);
91
92 output.parse().unwrap()
93}