1use std::fmt::Write;
2
3use inflector::Inflector;
4use regex::Regex;
5
6fn type_translation(input: &str) -> String {
7 if input.starts_with("futhark") {
8 auto_ctor(input)
9 } else {
10 let mut buffer = String::new();
11 if input.starts_with("int8") {
12 write!(&mut buffer, "i8");
13 } else if input.starts_with("int") {
14 write!(&mut buffer, "i{}", &input[3..5]);
15 } else if input.starts_with("uint8") {
16 write!(&mut buffer, "u8");
17 } else if input.starts_with("uint") {
18 write!(&mut buffer, "u{}", &input[4..6]);
19 } else if input.starts_with("float") {
20 write!(&mut buffer, "f32");
21 } else if input.starts_with("double") {
22 write!(&mut buffer, "f64");
23 }
24 buffer
25 }
26}
27
/// Builds the Rust wrapper type name for a Futhark array with element type `t`
/// and rank `dim`, e.g. `Array_f32_2d` for (`"f32"`, 2).
fn ctor_array_type(t: &str, dim: usize) -> String {
    let mut type_name = String::from("Array_");
    type_name.push_str(t);
    type_name.push('_');
    type_name.push_str(&dim.to_string());
    type_name.push('d');
    type_name
}
31
/// Pattern for Futhark array type names such as `futhark_f32_2d` or
/// `futhark_i64_10d`: capture group 1 is the element type, group 2 the rank.
/// The rank accepts several digits so arrays of rank 10 or more are not
/// mistaken for opaque types.
const RE_ARRAY_TYPE_STR: &str = r"futhark_(.+)_(\d+)d\b";
33
34fn parse_array_type(t: &str) -> Option<(String, usize)> {
35 let re_array_type = Regex::new(RE_ARRAY_TYPE_STR).unwrap();
36 if let Some(captures) = re_array_type.captures(t) {
37 let dim: usize = captures[2].parse().unwrap();
38 let ftype = &captures[1];
39 Some((ftype.to_string(), dim))
40 } else {
41 None
42 }
43}
44fn auto_ctor(t: &str) -> String {
45 let re_array_type = Regex::new(RE_ARRAY_TYPE_STR).unwrap();
46 if let Some((ftype, dim)) = parse_array_type(t) {
47 ctor_array_type(&ftype, dim)
48 } else {
49 to_opaque_type_name(t)
50 }
51}
52
53pub(crate) fn gen_entry_point(input: &str) -> (String, String, Vec<String>) {
54 let re_name = Regex::new(r"futhark_entry_(.+)\(").unwrap();
55 let re_arg_pairs =
56 Regex::new(r"(?m)\s*(?:const\s*)?(?:struct\s*)?([a-z0-9_]+)\s\**([a-z0-9]+),?\s?").unwrap();
57
58 let arg_pairs: Vec<(String, String)> = re_arg_pairs
59 .captures_iter(input)
60 .skip(2)
61 .map(|c| (c[1].to_owned(), c[2].to_owned()))
62 .collect();
63 let name = re_name.captures(input).unwrap()[1].to_owned();
64 let mut buffer = format!("pub fn {name}", name = name);
65
66 write!(&mut buffer, "(&mut self, ");
67 for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
68 if argname.starts_with("in") {
69 let argtype_string = type_translation(argtype);
70 write!(
71 &mut buffer,
72 "{}: {}{}, ",
73 argname,
74 if argtype_string.starts_with("FutharkOpaque") {
75 "&"
76 } else {
77 ""
78 },
79 argtype_string
80 );
81 }
82 }
83 write!(&mut buffer, ") -> ");
84 let mut output_buffer = String::from("Result<(");
85 let mut output_counter = 0;
86 for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
87 if argname.starts_with("out") {
88 if output_counter > 0 {
89 write!(&mut output_buffer, ", ");
90 }
91 output_counter += 1;
92 write!(&mut output_buffer, "{}", type_translation(argtype));
93 }
94 }
95 write!(&mut output_buffer, ")>");
96 writeln!(&mut buffer, "{}", output_buffer);
97
98 write!(
99 &mut buffer,
100 "{{\nlet ctx = self.ptr();\nunsafe{{\n_{name}(ctx, ",
101 name = name
102 );
103 for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
104 if argname.starts_with("in") {
105 if argtype.starts_with("futhark") {
106 write!(&mut buffer, "{}.as_raw_mut(), ", argname);
107 } else {
108 write!(&mut buffer, "{}, ", argname);
109 }
110 }
111 }
112 write!(&mut buffer, ")\n}}}}\n");
113
114 let mut buffer2 = String::new();
116 write!(
117 &mut buffer2,
118 "unsafe fn _{name}(ctx: *mut bindings::futhark_context, ",
119 name = name
120 );
121 for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
122 if argname.starts_with("in") {
123 if argtype.starts_with("futhark") {
124 write!(&mut buffer2, "{}: *const bindings::{}, ", argname, argtype);
125 } else {
126 write!(&mut buffer2, "{}: {}, ", argname, type_translation(argtype));
127 }
128 }
129 }
130 writeln!(&mut buffer2, ") -> {} {{", output_buffer);
131 for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
132 if argname.starts_with("out") {
133 if argtype.starts_with("futhark") {
134 writeln!(
135 &mut buffer2,
136 "let mut raw_{} = std::ptr::null_mut();",
137 argname
138 );
139 } else {
140 writeln!(
141 &mut buffer2,
142 "let mut raw_{} = {}::default();",
143 argname,
144 type_translation(argtype)
145 );
146 }
147 }
148 }
149
150 write!(
151 &mut buffer2,
152 "\nif bindings::futhark_entry_{name}(ctx, ",
153 name = name
154 );
155 for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
156 if argname.starts_with("out") {
157 write!(&mut buffer2, "&mut raw_{}, ", argname);
158 }
159 }
160 for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
161 if argname.starts_with("in") {
162 write!(&mut buffer2, "{}, ", argname);
163 }
164 }
165 writeln!(
166 &mut buffer2,
167 ") != 0 {{
168return Err(FutharkError::new(ctx).into());}}"
169 );
170
171 let mut opaque_types = Vec::new();
172 let mut result_counter = 0;
174 write!(&mut buffer2, "Ok((");
175 for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
176 if argname.starts_with("out") {
177 if parse_array_type(argtype).is_none() {
178 opaque_types.push(argtype.clone());
179 }
180 if result_counter > 0 {
181 write!(&mut buffer2, ", ");
182 }
183 result_counter += 1;
184 if argtype.starts_with("futhark") {
185 writeln!(
186 &mut buffer2,
187 "{}::from_ptr(ctx, raw_{})",
188 auto_ctor(argtype),
189 argname
190 );
191 } else {
192 writeln!(&mut buffer2, "raw_{}", argname);
193 }
194 }
195 }
196 write!(&mut buffer2, "))\n}}");
197
198 (buffer, buffer2, opaque_types)
199}
200fn to_opaque_type_name(s: &str) -> String {
201 let mut rust_opaque_type = s.to_camel_case();
202
203 if let Some(r) = rust_opaque_type.get_mut(0..1) {
204 r.make_ascii_uppercase();
205 }
206 rust_opaque_type
207}
208
209fn gen_opaque_type(opaque_type: &str) -> String {
210 let rust_opaque_type = to_opaque_type_name(opaque_type);
211 assert!(opaque_type.starts_with("futhark_"),);
212 let base_type = &opaque_type[8..];
213 let bindings = format!("bindings::{}", opaque_type);
214 format!(
215 include_str!("static/static_opaque_types.rs"),
216 opaque_type = rust_opaque_type,
217 futhark_type = bindings,
218 base_type = base_type
219 )
220}
221
222pub(crate) fn gen_entry_points(input: &Vec<String>) -> String {
223 let mut buffer = String::from(
224 r#"impl FutharkContext {
225"#,
226 );
227 let mut opaque_types = Vec::new();
228 let mut buffer2 = String::new();
229 for t in input {
230 let (a, b, otypes) = gen_entry_point(t);
231 opaque_types.extend(otypes);
232 writeln!(&mut buffer, "{}", a).expect("Write failed!");
233 writeln!(&mut buffer2, "{}", b).expect("Write failed!");
234 }
235
236 opaque_types.sort();
237 opaque_types.dedup();
238 for (i, opaque_type) in opaque_types.iter().enumerate() {
239 if opaque_type.starts_with("futhark_") {
240 writeln!(&mut buffer2, "{}", gen_opaque_type(opaque_type));
241 }
242 }
243
244 writeln!(&mut buffer, "}}").expect("Write failed!");
245 writeln!(&mut buffer, "{}", buffer2).expect("Write failed!");
246
247 buffer
248}