futhark_bindgen/generate/rust.rs

1use crate::generate::{convert_struct_name, first_uppercase};
2use crate::*;
3use std::io::Write;
4
/// Code generator that emits Rust bindings for a compiled Futhark package.
pub struct Rust {
    /// Type translation table, filled in as array and opaque types are
    /// generated. It maps a Futhark type name to its C struct name, and a C
    /// struct name (or a remapped primitive such as `f16`) to the Rust type
    /// used in the generated bindings.
    typemap: BTreeMap<String, String>,
}
9
/// Returns `true` when the Rust type name `t` contains one of the markers of
/// the generated array wrappers (`ArrayF…`, `ArrayI…`, `ArrayU…`, `ArrayB…`).
fn type_is_array(t: &str) -> bool {
    const ARRAY_MARKERS: [&str; 4] = ["ArrayF", "ArrayI", "ArrayU", "ArrayB"];
    ARRAY_MARKERS.iter().any(|&marker| t.contains(marker))
}
13
/// Returns `true` when the C-level type name `ctype` refers to a Futhark
/// opaque type, i.e. it contains the `futhark_opaque_` marker.
fn type_is_opaque(ctype: &str) -> bool {
    let marker = "futhark_opaque_";
    ctype.contains(marker)
}
17
/// Futhark primitive types whose Rust spelling differs from their Futhark
/// name; used to seed the generator's typemap.
///
/// Rust `f16` codegen requires the `half` crate.
const RUST_TYPE_MAP: &[(&str, &str)] = &[("f16", "half::f16")];
20
21impl Default for Rust {
22    fn default() -> Self {
23        let typemap = RUST_TYPE_MAP
24            .iter()
25            .map(|(a, b)| (a.to_string(), b.to_string()))
26            .collect();
27        Rust { typemap }
28    }
29}
30
/// Names computed for one Futhark array type while generating its wrapper.
struct ArrayInfo {
    /// C struct name of the array type, derived from the manifest `ctype`.
    futhark_type: String,
    /// Name of the generated Rust wrapper, of the form `Array{Elem}D{rank}`.
    rust_type: String,
    /// Element type name as reported by the manifest (e.g. `f32`); it is
    /// passed to the array template, so no `allow(unused)` is needed.
    elem: String,
}
37
38impl Rust {
39    fn get_type(typemap: &BTreeMap<String, String>, t: &str) -> String {
40        let a = typemap.get(t);
41        let x = match a {
42            Some(t) => t.clone(),
43            None => t.to_string(),
44        };
45        if x.is_empty() {
46            panic!("Unsupported type: {t}");
47        }
48        x
49    }
50}
51
52impl Generate for Rust {
53    fn array_type(
54        &mut self,
55        _pkg: &Package,
56        config: &mut Config,
57        name: &str,
58        a: &manifest::ArrayType,
59    ) -> Result<(), Error> {
60        let elemtype = a.elemtype.to_str();
61        let rank = a.rank;
62
63        let futhark_type = convert_struct_name(&a.ctype).to_string();
64        let rust_type = format!("Array{}D{rank}", first_uppercase(elemtype));
65        let info = ArrayInfo {
66            futhark_type,
67            rust_type,
68            elem: elemtype.to_string(),
69        };
70
71        let mut dim_params = Vec::new();
72        let mut new_dim_args = Vec::new();
73
74        for i in 0..a.rank {
75            let dim = format!("dims[{i}]");
76            dim_params.push(dim);
77            new_dim_args.push(format!("dim{i}: i64"));
78        }
79
80        writeln!(
81            config.output_file,
82            include_str!("templates/rust/array.rs"),
83            futhark_type = info.futhark_type,
84            rust_type = info.rust_type,
85            rank = a.rank,
86            elemtype = info.elem,
87            new_fn = a.ops.new,
88            free_fn = a.ops.free,
89            values_fn = a.ops.values,
90            shape_fn = a.ops.shape,
91            dim_params = dim_params.join(", "),
92            new_dim_args = new_dim_args.join(", ")
93        )?;
94
95        self.typemap
96            .insert(name.to_string(), info.futhark_type.clone());
97        self.typemap.insert(info.futhark_type, info.rust_type);
98        Ok(())
99    }
100
101    fn opaque_type(
102        &mut self,
103        _pkg: &Package,
104        config: &mut Config,
105        name: &str,
106        ty: &manifest::OpaqueType,
107    ) -> Result<(), Error> {
108        let futhark_type = convert_struct_name(&ty.ctype).to_string();
109        let mut rust_type = first_uppercase(futhark_type.strip_prefix("futhark_opaque_").unwrap());
110        if rust_type.chars().next().unwrap().is_numeric() || name.contains(' ') {
111            rust_type = format!("Type{}", rust_type);
112        }
113
114        writeln!(
115            config.output_file,
116            include_str!("templates/rust/opaque.rs"),
117            futhark_type = futhark_type,
118            rust_type = rust_type,
119            free_fn = ty.ops.free,
120        )?;
121
122        let record = match &ty.record {
123            Some(r) => r,
124            None => {
125                self.typemap.insert(name.to_string(), futhark_type.clone());
126                self.typemap.insert(futhark_type, rust_type);
127                return Ok(());
128            }
129        };
130
131        let mut new_call_args = vec![];
132        let mut new_params = vec![];
133        let mut new_extern_params = vec![];
134        for field in record.fields.iter() {
135            // Build new function
136            let a = Self::get_type(&self.typemap, &field.r#type);
137            let t = Self::get_type(&self.typemap, &a);
138
139            let u = if t == field.r#type {
140                t.to_string()
141            } else {
142                format!("&{t}")
143            };
144
145            if type_is_opaque(&a) {
146                new_call_args.push(format!("field{}.data", field.name));
147                new_extern_params.push(format!("field{}: *const {a}", field.name));
148            } else if type_is_array(&t) {
149                new_call_args.push(format!("field{}.ptr", field.name));
150                new_extern_params.push(format!("field{}: *const {a}", field.name));
151            } else {
152                new_call_args.push(format!("field{}", field.name));
153                new_extern_params.push(format!("field{}: {a}", field.name));
154            }
155
156            new_params.push(format!("field{}: {u}", field.name));
157
158            // Implement get function
159
160            // If the output type is an array or opaque type then we need to wrap the return value
161            let (output, futhark_field_type) = if type_is_opaque(&a) || type_is_array(&t) {
162                (
163                    format!("Ok({t}::from_ptr(self.ctx, out))"),
164                    format!("*mut {a}"),
165                )
166            } else {
167                ("Ok(out)".to_string(), a)
168            };
169
170            writeln!(
171                config.output_file,
172                include_str!("templates/rust/record_project.rs"),
173                project_fn = field.project,
174                rust_type = rust_type,
175                futhark_type = futhark_type,
176                field_name = field.name,
177                futhark_field_type = futhark_field_type,
178                rust_field_type = t,
179                output = output
180            )?;
181        }
182
183        writeln!(
184            config.output_file,
185            include_str!("templates/rust/record.rs"),
186            rust_type = rust_type,
187            futhark_type = futhark_type,
188            new_fn = record.new,
189            new_params = new_params.join(", "),
190            new_call_args = new_call_args.join(", "),
191            new_extern_params = new_extern_params.join(", "),
192        )?;
193
194        self.typemap.insert(name.to_string(), futhark_type.clone());
195        self.typemap.insert(futhark_type, rust_type);
196
197        Ok(())
198    }
199
200    fn entry(
201        &mut self,
202        _pkg: &Package,
203        config: &mut Config,
204        name: &str,
205        entry: &manifest::Entry,
206    ) -> Result<(), Error> {
207        let mut call_args = Vec::new();
208        let mut entry_params = Vec::new();
209        let mut return_type = Vec::new();
210        let mut out_decl = Vec::new();
211        let mut futhark_entry_params = Vec::new();
212        let mut entry_return = Vec::new();
213
214        // Output arguments
215        for (i, arg) in entry.outputs.iter().enumerate() {
216            let a = Self::get_type(&self.typemap, &arg.r#type);
217
218            let name = format!("out{i}");
219
220            let t = Self::get_type(&self.typemap, &a);
221
222            if type_is_array(&t) || type_is_opaque(&a) {
223                futhark_entry_params.push(format!("{name}: *mut *mut {a}"));
224            } else {
225                futhark_entry_params.push(format!("{name}: *mut {a}"));
226            }
227
228            if type_is_array(&t) || type_is_opaque(&a) {
229                entry_return.push(format!("{t}::from_ptr(self, {name}.assume_init())",));
230            } else {
231                entry_return.push(format!("{name}.assume_init()"));
232            }
233
234            out_decl.push(format!("let mut {name} = std::mem::MaybeUninit::zeroed();"));
235            call_args.push(format!("{name}.as_mut_ptr()"));
236            return_type.push(t);
237        }
238
239        // Input arguments
240        for (i, arg) in entry.inputs.iter().enumerate() {
241            let a = Self::get_type(&self.typemap, &arg.r#type);
242            let name = format!("input{i}");
243
244            let t = Self::get_type(&self.typemap, &a);
245
246            if type_is_array(&t) {
247                futhark_entry_params.push(format!("{name}: *const {a}"));
248
249                entry_params.push(format!("{name}: &{t}"));
250                call_args.push(format!("{name}.ptr as *mut _"));
251            } else if type_is_opaque(&a) {
252                futhark_entry_params.push(format!("{name}: *const {a}"));
253
254                entry_params.push(format!("{name}: &{t}"));
255                call_args.push(format!("{name}.data as *mut _"));
256            } else {
257                futhark_entry_params.push(format!("{name}: {a}"));
258                entry_params.push(format!("{name}: {t}"));
259                call_args.push(name);
260            }
261        }
262
263        let (entry_return_type, entry_return) = match entry.outputs.len() {
264            0 => ("()".to_string(), "()".to_string()),
265            1 => (return_type.join(", "), entry_return.join(", ")),
266            _ => (
267                format!("({})", return_type.join(", ")),
268                format!("({})", entry_return.join(", ")),
269            ),
270        };
271
272        writeln!(
273            config.output_file,
274            include_str!("templates/rust/entry.rs"),
275            entry_fn = entry.cfun,
276            entry_name = name,
277            entry_params = entry_params.join(", "),
278            entry_return_type = entry_return_type,
279            out_decl = out_decl.join(";\n"),
280            call_args = call_args.join(", "),
281            entry_return = entry_return,
282            futhark_entry_params = futhark_entry_params.join(", "),
283        )?;
284
285        Ok(())
286    }
287
288    fn bindings(&mut self, pkg: &Package, config: &mut Config) -> Result<(), Error> {
289        writeln!(config.output_file, "// Generated by futhark-bindgen\n")?;
290        let backend_extern_functions = match &pkg.manifest.backend {
291            Backend::Multicore => {
292                "fn futhark_context_config_set_num_threads(_: *mut futhark_context_config, _: std::os::raw::c_int);"
293            }
294            Backend::OpenCL | Backend::CUDA => {
295                "fn futhark_context_config_set_device(_: *mut futhark_context_config, _: *const std::os::raw::c_char);"
296            }
297            _ => "",
298        };
299
300        let backend_options = match pkg.manifest.backend {
301            Backend::Multicore => {
302                "pub fn threads(mut self, n: u32) -> Options { self.num_threads = n as u32; self }"
303            }
304            Backend::CUDA | Backend::OpenCL => {
305                "pub fn device(mut self, s: impl AsRef<str>) -> Options { self.device = Some(std::ffi::CString::new(s.as_ref()).expect(\"Invalid device\")); self }"
306            }
307            _ => "",
308        };
309
310        let configure_num_threads = if pkg.manifest.backend == Backend::Multicore {
311            "futhark_context_config_set_num_threads(config, options.num_threads as std::os::raw::c_int);"
312        } else {
313            "let _ = &options.num_threads;"
314        };
315
316        let configure_set_device = if matches!(
317            pkg.manifest.backend,
318            Backend::CUDA | Backend::OpenCL
319        ) {
320            "if let Some(d) = &options.device { futhark_context_config_set_device(config, d.as_ptr()); }"
321        } else {
322            "let _ = &options.device;"
323        };
324
325        writeln!(
326            config.output_file,
327            include_str!("templates/rust/context.rs"),
328            backend_options = backend_options,
329            configure_num_threads = configure_num_threads,
330            configure_set_device = configure_set_device,
331            backend_extern_functions = backend_extern_functions,
332        )?;
333
334        Ok(())
335    }
336
337    fn format(&mut self, path: &std::path::Path) -> Result<(), Error> {
338        let _ = std::process::Command::new("rustfmt").arg(path).status();
339        Ok(())
340    }
341}