futhark_bindgen/generate/
rust.rs1use crate::generate::{convert_struct_name, first_uppercase};
2use crate::*;
3use std::io::Write;
4
/// Rust back end of the binding generator.
pub struct Rust {
    /// Type-name translation table filled in while generating.
    ///
    /// Manifest type names map to C-level struct names, and C-level struct names (or
    /// primitives such as `f16`) map to the Rust type used in the generated bindings.
    /// Names without an entry are used verbatim.
    typemap: BTreeMap<String, String>,
}
9
/// Reports whether `t`, a generated Rust type name, is one of the array wrappers.
///
/// Array wrappers are named `Array{Elem}D{rank}` (for example `ArrayF32D1` or `ArrayBoolD2`),
/// where `Elem` starts with `F`, `I`, `U` or `B` and `rank` is a decimal number. The whole
/// name is matched instead of searching for a substring, so opaque types whose names merely
/// contain e.g. `ArrayI` (such as `MyArrayIndex` or `ArrayInfo`) are not misclassified.
fn type_is_array(t: &str) -> bool {
    let rest = match t.strip_prefix("Array") {
        Some(rest) => rest,
        None => return false,
    };
    // The rank follows the last `D`; it must be a non-empty run of digits.
    let (elem, rank) = match rest.rsplit_once('D') {
        Some(parts) => parts,
        None => return false,
    };
    let elem_ok = match elem.as_bytes().split_first() {
        Some((first, tail)) => {
            matches!(*first, b'F' | b'I' | b'U' | b'B')
                && tail.iter().all(|b| b.is_ascii_alphanumeric())
        }
        None => false,
    };
    elem_ok && !rank.is_empty() && rank.bytes().all(|b| b.is_ascii_digit())
}
13
/// Reports whether `a`, a C-level type name, names a Futhark opaque struct.
fn type_is_opaque(a: &str) -> bool {
    const OPAQUE_MARKER: &str = "futhark_opaque_";
    a.contains(OPAQUE_MARKER)
}
17
/// Futhark primitive types whose Rust spelling differs from the Futhark name.
///
/// These pairs seed the generator's type map; any type not listed here (and not registered
/// later) is emitted under its Futhark name.
const RUST_TYPE_MAP: &[(&str, &str)] = &[
    // Futhark's 16-bit float maps to the `half` crate's type.
    ("f16", "half::f16"),
];
20
21impl Default for Rust {
22 fn default() -> Self {
23 let typemap = RUST_TYPE_MAP
24 .iter()
25 .map(|(a, b)| (a.to_string(), b.to_string()))
26 .collect();
27 Rust { typemap }
28 }
29}
30
/// Names describing one generated array wrapper type.
struct ArrayInfo {
    /// C-level struct name of the array, as produced by `convert_struct_name`.
    futhark_type: String,
    /// Name of the generated Rust wrapper, of the form `Array{Elem}D{rank}`.
    rust_type: String,
    /// Element type name, substituted into the array template as `elemtype`.
    // No `#[allow(unused)]`: this field is read when the array template is rendered.
    elem: String,
}
37
38impl Rust {
39 fn get_type(typemap: &BTreeMap<String, String>, t: &str) -> String {
40 let a = typemap.get(t);
41 let x = match a {
42 Some(t) => t.clone(),
43 None => t.to_string(),
44 };
45 if x.is_empty() {
46 panic!("Unsupported type: {t}");
47 }
48 x
49 }
50}
51
52impl Generate for Rust {
53 fn array_type(
54 &mut self,
55 _pkg: &Package,
56 config: &mut Config,
57 name: &str,
58 a: &manifest::ArrayType,
59 ) -> Result<(), Error> {
60 let elemtype = a.elemtype.to_str();
61 let rank = a.rank;
62
63 let futhark_type = convert_struct_name(&a.ctype).to_string();
64 let rust_type = format!("Array{}D{rank}", first_uppercase(elemtype));
65 let info = ArrayInfo {
66 futhark_type,
67 rust_type,
68 elem: elemtype.to_string(),
69 };
70
71 let mut dim_params = Vec::new();
72 let mut new_dim_args = Vec::new();
73
74 for i in 0..a.rank {
75 let dim = format!("dims[{i}]");
76 dim_params.push(dim);
77 new_dim_args.push(format!("dim{i}: i64"));
78 }
79
80 writeln!(
81 config.output_file,
82 include_str!("templates/rust/array.rs"),
83 futhark_type = info.futhark_type,
84 rust_type = info.rust_type,
85 rank = a.rank,
86 elemtype = info.elem,
87 new_fn = a.ops.new,
88 free_fn = a.ops.free,
89 values_fn = a.ops.values,
90 shape_fn = a.ops.shape,
91 dim_params = dim_params.join(", "),
92 new_dim_args = new_dim_args.join(", ")
93 )?;
94
95 self.typemap
96 .insert(name.to_string(), info.futhark_type.clone());
97 self.typemap.insert(info.futhark_type, info.rust_type);
98 Ok(())
99 }
100
101 fn opaque_type(
102 &mut self,
103 _pkg: &Package,
104 config: &mut Config,
105 name: &str,
106 ty: &manifest::OpaqueType,
107 ) -> Result<(), Error> {
108 let futhark_type = convert_struct_name(&ty.ctype).to_string();
109 let mut rust_type = first_uppercase(futhark_type.strip_prefix("futhark_opaque_").unwrap());
110 if rust_type.chars().next().unwrap().is_numeric() || name.contains(' ') {
111 rust_type = format!("Type{}", rust_type);
112 }
113
114 writeln!(
115 config.output_file,
116 include_str!("templates/rust/opaque.rs"),
117 futhark_type = futhark_type,
118 rust_type = rust_type,
119 free_fn = ty.ops.free,
120 )?;
121
122 let record = match &ty.record {
123 Some(r) => r,
124 None => {
125 self.typemap.insert(name.to_string(), futhark_type.clone());
126 self.typemap.insert(futhark_type, rust_type);
127 return Ok(());
128 }
129 };
130
131 let mut new_call_args = vec![];
132 let mut new_params = vec![];
133 let mut new_extern_params = vec![];
134 for field in record.fields.iter() {
135 let a = Self::get_type(&self.typemap, &field.r#type);
137 let t = Self::get_type(&self.typemap, &a);
138
139 let u = if t == field.r#type {
140 t.to_string()
141 } else {
142 format!("&{t}")
143 };
144
145 if type_is_opaque(&a) {
146 new_call_args.push(format!("field{}.data", field.name));
147 new_extern_params.push(format!("field{}: *const {a}", field.name));
148 } else if type_is_array(&t) {
149 new_call_args.push(format!("field{}.ptr", field.name));
150 new_extern_params.push(format!("field{}: *const {a}", field.name));
151 } else {
152 new_call_args.push(format!("field{}", field.name));
153 new_extern_params.push(format!("field{}: {a}", field.name));
154 }
155
156 new_params.push(format!("field{}: {u}", field.name));
157
158 let (output, futhark_field_type) = if type_is_opaque(&a) || type_is_array(&t) {
162 (
163 format!("Ok({t}::from_ptr(self.ctx, out))"),
164 format!("*mut {a}"),
165 )
166 } else {
167 ("Ok(out)".to_string(), a)
168 };
169
170 writeln!(
171 config.output_file,
172 include_str!("templates/rust/record_project.rs"),
173 project_fn = field.project,
174 rust_type = rust_type,
175 futhark_type = futhark_type,
176 field_name = field.name,
177 futhark_field_type = futhark_field_type,
178 rust_field_type = t,
179 output = output
180 )?;
181 }
182
183 writeln!(
184 config.output_file,
185 include_str!("templates/rust/record.rs"),
186 rust_type = rust_type,
187 futhark_type = futhark_type,
188 new_fn = record.new,
189 new_params = new_params.join(", "),
190 new_call_args = new_call_args.join(", "),
191 new_extern_params = new_extern_params.join(", "),
192 )?;
193
194 self.typemap.insert(name.to_string(), futhark_type.clone());
195 self.typemap.insert(futhark_type, rust_type);
196
197 Ok(())
198 }
199
200 fn entry(
201 &mut self,
202 _pkg: &Package,
203 config: &mut Config,
204 name: &str,
205 entry: &manifest::Entry,
206 ) -> Result<(), Error> {
207 let mut call_args = Vec::new();
208 let mut entry_params = Vec::new();
209 let mut return_type = Vec::new();
210 let mut out_decl = Vec::new();
211 let mut futhark_entry_params = Vec::new();
212 let mut entry_return = Vec::new();
213
214 for (i, arg) in entry.outputs.iter().enumerate() {
216 let a = Self::get_type(&self.typemap, &arg.r#type);
217
218 let name = format!("out{i}");
219
220 let t = Self::get_type(&self.typemap, &a);
221
222 if type_is_array(&t) || type_is_opaque(&a) {
223 futhark_entry_params.push(format!("{name}: *mut *mut {a}"));
224 } else {
225 futhark_entry_params.push(format!("{name}: *mut {a}"));
226 }
227
228 if type_is_array(&t) || type_is_opaque(&a) {
229 entry_return.push(format!("{t}::from_ptr(self, {name}.assume_init())",));
230 } else {
231 entry_return.push(format!("{name}.assume_init()"));
232 }
233
234 out_decl.push(format!("let mut {name} = std::mem::MaybeUninit::zeroed();"));
235 call_args.push(format!("{name}.as_mut_ptr()"));
236 return_type.push(t);
237 }
238
239 for (i, arg) in entry.inputs.iter().enumerate() {
241 let a = Self::get_type(&self.typemap, &arg.r#type);
242 let name = format!("input{i}");
243
244 let t = Self::get_type(&self.typemap, &a);
245
246 if type_is_array(&t) {
247 futhark_entry_params.push(format!("{name}: *const {a}"));
248
249 entry_params.push(format!("{name}: &{t}"));
250 call_args.push(format!("{name}.ptr as *mut _"));
251 } else if type_is_opaque(&a) {
252 futhark_entry_params.push(format!("{name}: *const {a}"));
253
254 entry_params.push(format!("{name}: &{t}"));
255 call_args.push(format!("{name}.data as *mut _"));
256 } else {
257 futhark_entry_params.push(format!("{name}: {a}"));
258 entry_params.push(format!("{name}: {t}"));
259 call_args.push(name);
260 }
261 }
262
263 let (entry_return_type, entry_return) = match entry.outputs.len() {
264 0 => ("()".to_string(), "()".to_string()),
265 1 => (return_type.join(", "), entry_return.join(", ")),
266 _ => (
267 format!("({})", return_type.join(", ")),
268 format!("({})", entry_return.join(", ")),
269 ),
270 };
271
272 writeln!(
273 config.output_file,
274 include_str!("templates/rust/entry.rs"),
275 entry_fn = entry.cfun,
276 entry_name = name,
277 entry_params = entry_params.join(", "),
278 entry_return_type = entry_return_type,
279 out_decl = out_decl.join(";\n"),
280 call_args = call_args.join(", "),
281 entry_return = entry_return,
282 futhark_entry_params = futhark_entry_params.join(", "),
283 )?;
284
285 Ok(())
286 }
287
288 fn bindings(&mut self, pkg: &Package, config: &mut Config) -> Result<(), Error> {
289 writeln!(config.output_file, "// Generated by futhark-bindgen\n")?;
290 let backend_extern_functions = match &pkg.manifest.backend {
291 Backend::Multicore => {
292 "fn futhark_context_config_set_num_threads(_: *mut futhark_context_config, _: std::os::raw::c_int);"
293 }
294 Backend::OpenCL | Backend::CUDA => {
295 "fn futhark_context_config_set_device(_: *mut futhark_context_config, _: *const std::os::raw::c_char);"
296 }
297 _ => "",
298 };
299
300 let backend_options = match pkg.manifest.backend {
301 Backend::Multicore => {
302 "pub fn threads(mut self, n: u32) -> Options { self.num_threads = n as u32; self }"
303 }
304 Backend::CUDA | Backend::OpenCL => {
305 "pub fn device(mut self, s: impl AsRef<str>) -> Options { self.device = Some(std::ffi::CString::new(s.as_ref()).expect(\"Invalid device\")); self }"
306 }
307 _ => "",
308 };
309
310 let configure_num_threads = if pkg.manifest.backend == Backend::Multicore {
311 "futhark_context_config_set_num_threads(config, options.num_threads as std::os::raw::c_int);"
312 } else {
313 "let _ = &options.num_threads;"
314 };
315
316 let configure_set_device = if matches!(
317 pkg.manifest.backend,
318 Backend::CUDA | Backend::OpenCL
319 ) {
320 "if let Some(d) = &options.device { futhark_context_config_set_device(config, d.as_ptr()); }"
321 } else {
322 "let _ = &options.device;"
323 };
324
325 writeln!(
326 config.output_file,
327 include_str!("templates/rust/context.rs"),
328 backend_options = backend_options,
329 configure_num_threads = configure_num_threads,
330 configure_set_device = configure_set_device,
331 backend_extern_functions = backend_extern_functions,
332 )?;
333
334 Ok(())
335 }
336
337 fn format(&mut self, path: &std::path::Path) -> Result<(), Error> {
338 let _ = std::process::Command::new("rustfmt").arg(path).status();
339 Ok(())
340 }
341}