// futhark_bindgen/generate/ocaml.rs

1use std::io::Write;
2
3use crate::generate::{convert_struct_name, first_uppercase};
4use crate::*;
5
6/// OCaml codegen
/// Code generator emitting OCaml bindings for a Futhark package.
///
/// The implementation (`.ml`) text is written to `Config::output_file`; the matching
/// interface (`.mli`) text is written to `mli_file`.
pub struct OCaml {
    /// Futhark type name -> OCaml type used in the public API (e.g. `int32`, `Foo.t`).
    typemap: BTreeMap<String, String>,
    /// Futhark type name -> ctypes type expression used inside the `Bindings` module.
    ctypes_map: BTreeMap<String, String>,
    /// Futhark scalar type -> (OCaml element type, Bigarray element-kind type).
    ba_map: BTreeMap<String, (String, String)>,
    /// Destination for the generated `.mli` interface.
    mli_file: std::fs::File,
}
13
/// Futhark scalar type -> ctypes type value used in the generated foreign declarations.
///
/// An empty string marks a type with no OCaml counterpart; `get_ctype` panics on it.
const OCAML_CTYPES_MAP: &[(&str, &str)] = &[
    // 8-bit integers
    ("i8", "char"),
    ("u8", "uint8_t"),
    // 16-bit integers
    ("i16", "int16_t"),
    ("u16", "uint16_t"),
    // 32-bit integers
    ("i32", "int32_t"),
    ("u32", "uint32_t"),
    // 64-bit integers
    ("i64", "int64_t"),
    ("u64", "uint64_t"),
    // floating point: f16 is unsupported (marked by the empty string)
    ("f16", ""),
    ("f32", "float"),
    ("f64", "double"),
    // booleans
    ("bool", "bool"),
];
28
/// Futhark scalar type -> OCaml type exposed by the generated API.
///
/// Unsigned integers use `UIntN.t` types (presumably ctypes' `Unsigned` module — confirm
/// against the bindings template). An empty string marks an unsupported type; `get_type`
/// panics on it.
const OCAML_TYPE_MAP: &[(&str, &str)] = &[
    // signed / unsigned 8- and 16-bit integers
    ("i8", "char"),
    ("u8", "UInt8.t"),
    ("i16", "int"),
    ("u16", "UInt16.t"),
    // signed 32/64-bit integers
    ("i32", "int32"),
    ("i64", "int64"),
    // unsigned 32/64-bit integers
    ("u32", "UInt32.t"),
    ("u64", "UInt64.t"),
    // floating point: f16 is unsupported (marked by the empty string)
    ("f16", ""),
    ("f32", "float"),
    ("f64", "float"),
    // booleans
    ("bool", "bool"),
];
43
/// Futhark scalar type -> (OCaml element type, Bigarray element-kind type) used for the
/// Bigarrays that carry array data in and out of Futhark.
///
/// `bool` is stored as unsigned bytes, and `u32`/`u64` reuse the signed 32/64-bit kinds,
/// since Bigarray offers no unsigned 32/64-bit kinds. An empty pair marks an unsupported
/// type; `get_ba_type` panics on it.
const OCAML_BA_TYPE_MAP: &[(&str, (&str, &str))] = &[
    // 8- and 16-bit integers are held in OCaml `int`s
    ("i8", ("int", "Bigarray.int8_signed_elt")),
    ("u8", ("int", "Bigarray.int8_unsigned_elt")),
    ("i16", ("int", "Bigarray.int16_signed_elt")),
    ("u16", ("int", "Bigarray.int16_unsigned_elt")),
    // 32/64-bit integers (signed kinds for both signednesses)
    ("i32", ("int32", "Bigarray.int32_elt")),
    ("i64", ("int64", "Bigarray.int64_elt")),
    ("u32", ("int32", "Bigarray.int32_elt")),
    ("u64", ("int64", "Bigarray.int64_elt")),
    // floating point: f16 is unsupported (marked by the empty pair)
    ("f16", ("", "")),
    ("f32", ("float", "Bigarray.float32_elt")),
    ("f64", ("float", "Bigarray.float64_elt")),
    // booleans as unsigned bytes
    ("bool", ("int", "Bigarray.int8_unsigned_elt")),
];
58
/// Returns `true` when `t` (an OCaml type name from the type map) is a Futhark array type.
///
/// Array types are registered under names of the form `array_{elemtype}_{rank}d`, so a
/// prefix test is exact. A substring test would also match opaque module names that happen
/// to contain `array_`, e.g. `Type_array_foo.t`.
fn type_is_array(t: &str) -> bool {
    t.starts_with("array_")
}
62
/// Returns `true` when `t` (an OCaml type name from the type map) is a Futhark opaque type.
///
/// Opaque types are registered as `{Module}.t`. The unsigned scalar types in
/// `OCAML_TYPE_MAP` (`UInt8.t` ... `UInt64.t`) share that `.t` shape but are plain values
/// passed directly through ctypes, so they are excluded explicitly.
fn type_is_opaque(t: &str) -> bool {
    // Keep in sync with the `UIntN.t` entries of `OCAML_TYPE_MAP`.
    const UNSIGNED_SCALARS: &[&str] = &["UInt8.t", "UInt16.t", "UInt32.t", "UInt64.t"];
    t.ends_with(".t") && !UNSIGNED_SCALARS.contains(&t)
}
66
/// Converts a Bigarray element-type name into the matching kind constructor, e.g.
/// `Bigarray.int8_signed_elt` -> `Bigarray.Int8_signed`.
///
/// The `_elt` suffix is dropped and the first character after the module path is
/// capitalised. A name without a module path has its first character capitalised.
///
/// # Panics
/// Panics if `t` does not end in `_elt`; every non-empty entry in `OCAML_BA_TYPE_MAP` does.
fn ba_kind(t: &str) -> String {
    let base = t
        .strip_suffix("_elt")
        .unwrap_or_else(|| panic!("Not a Bigarray element type: {t}"));

    // The character to capitalise follows the last '.', so a fixed offset (which would
    // have to account for the 9-byte "Bigarray." prefix) is not needed.
    let start = base.rfind('.').map_or(0, |dot| dot + 1);

    let mut kind = base.to_string();
    if let Some(first) = kind.get_mut(start..start + 1) {
        first.make_ascii_uppercase();
    }
    kind
}
76
77impl OCaml {
78    /// Create new OCaml codegen instance
79    pub fn new(config: &Config) -> Result<Self, Error> {
80        let typemap = OCAML_TYPE_MAP
81            .iter()
82            .map(|(a, b)| (a.to_string(), b.to_string()))
83            .collect();
84
85        let ba_map = OCAML_BA_TYPE_MAP
86            .iter()
87            .map(|(a, (b, c))| (a.to_string(), (b.to_string(), c.to_string())))
88            .collect();
89
90        let ctypes_map = OCAML_CTYPES_MAP
91            .iter()
92            .map(|(a, b)| (a.to_string(), b.to_string()))
93            .collect();
94
95        let mli_path = config.output_path.with_extension("mli");
96        let mli_file = std::fs::File::create(mli_path)?;
97        Ok(OCaml {
98            typemap,
99            ba_map,
100            ctypes_map,
101            mli_file,
102        })
103    }
104
105    fn foreign_function(&mut self, name: &str, ret: &str, args: Vec<&str>) -> String {
106        format!(
107            "let {name} = fn \"{name}\" ({} @-> returning ({ret}))",
108            args.join(" @-> ")
109        )
110    }
111
112    fn get_ctype(&self, t: &str) -> String {
113        let x = self
114            .ctypes_map
115            .get(t)
116            .cloned()
117            .unwrap_or_else(|| t.to_string());
118        if x.is_empty() {
119            panic!("Unsupported type: {t}");
120        }
121        x
122    }
123
124    fn get_type(&self, t: &str) -> String {
125        let x = self
126            .typemap
127            .get(t)
128            .cloned()
129            .unwrap_or_else(|| t.to_string());
130        if x.is_empty() {
131            panic!("Unsupported type: {t}");
132        }
133        x
134    }
135
136    fn get_ba_type(&self, t: &str) -> (String, String) {
137        let x = self.ba_map.get(t).cloned().unwrap();
138        if x.0.is_empty() {
139            panic!("Unsupported type: {t}");
140        }
141        x
142    }
143}
144
145impl Generate for OCaml {
146    fn bindings(&mut self, pkg: &Package, config: &mut Config) -> Result<(), Error> {
147        writeln!(self.mli_file, "(* Generated by futhark-bindgen *)\n")?;
148        writeln!(config.output_file, "(* Generated by futhark-bindgen *)\n")?;
149
150        let mut generated_foreign_functions = Vec::new();
151        match pkg.manifest.backend {
152            Backend::Multicore => {
153                generated_foreign_functions.push(format!(
154                    "  {}",
155                    self.foreign_function(
156                        "futhark_context_config_set_num_threads",
157                        "void",
158                        vec!["context_config", "int"]
159                    )
160                ));
161            }
162            Backend::CUDA | Backend::OpenCL => {
163                generated_foreign_functions.push(format!(
164                    "  {}",
165                    self.foreign_function(
166                        "futhark_context_config_set_device",
167                        "void",
168                        vec!["context_config", "string"]
169                    )
170                ));
171            }
172            _ => (),
173        }
174
175        for (name, ty) in &pkg.manifest.types {
176            match ty {
177                manifest::Type::Array(a) => {
178                    let elemtype = a.elemtype.to_str().to_string();
179                    let ctypes_elemtype = self.get_ctype(&elemtype);
180                    let rank = a.rank;
181                    let ocaml_name = format!("array_{elemtype}_{rank}d");
182                    self.typemap.insert(name.clone(), ocaml_name.clone());
183                    self.ctypes_map.insert(name.clone(), ocaml_name.clone());
184                    let elem_ptr = format!("ptr {ctypes_elemtype}");
185                    generated_foreign_functions.push(format!(
186                        "  let {ocaml_name} = typedef (ptr void) \"{ocaml_name}\""
187                    ));
188                    let mut new_args = vec!["context", &elem_ptr];
189                    new_args.resize(rank as usize + 2, "int64_t");
190                    generated_foreign_functions.push(format!(
191                        "  {}",
192                        self.foreign_function(&a.ops.new, &ocaml_name, new_args)
193                    ));
194                    generated_foreign_functions.push(format!(
195                        "  {}",
196                        self.foreign_function(
197                            &a.ops.values,
198                            "int",
199                            vec!["context", &ocaml_name, &elem_ptr]
200                        )
201                    ));
202                    generated_foreign_functions.push(format!(
203                        "  {}",
204                        self.foreign_function(&a.ops.free, "int", vec!["context", &ocaml_name])
205                    ));
206                    generated_foreign_functions.push(format!(
207                        "  {}",
208                        self.foreign_function(
209                            &a.ops.shape,
210                            "ptr int64_t",
211                            vec!["context", &ocaml_name]
212                        )
213                    ));
214                }
215                manifest::Type::Opaque(ty) => {
216                    let futhark_name = convert_struct_name(&ty.ctype);
217                    let mut ocaml_name = futhark_name
218                        .strip_prefix("futhark_opaque_")
219                        .unwrap()
220                        .to_string();
221                    if ocaml_name.chars().next().unwrap().is_numeric() || name.contains(' ') {
222                        ocaml_name = format!("type_{ocaml_name}");
223                    }
224
225                    self.typemap
226                        .insert(name.clone(), format!("{}.t", first_uppercase(&ocaml_name)));
227                    self.ctypes_map.insert(name.to_string(), ocaml_name.clone());
228                    generated_foreign_functions.push(format!(
229                        "  let {ocaml_name} = typedef (ptr void) \"{futhark_name}\""
230                    ));
231
232                    let free_fn = &ty.ops.free;
233                    generated_foreign_functions.push(format!(
234                        "  {}",
235                        self.foreign_function(free_fn, "int", vec!["context", &ocaml_name])
236                    ));
237
238                    let record = match &ty.record {
239                        Some(r) => r,
240                        None => continue,
241                    };
242
243                    let new_fn = &record.new;
244                    let mut args = vec!["context".to_string(), format!("ptr {ocaml_name}")];
245                    for f in record.fields.iter() {
246                        let cty = self
247                            .ctypes_map
248                            .get(&f.r#type)
249                            .cloned()
250                            .unwrap_or_else(|| f.r#type.clone());
251
252                        // project function
253                        generated_foreign_functions.push(format!(
254                            "  {}",
255                            self.foreign_function(
256                                &f.project,
257                                "int",
258                                vec!["context", &format!("ptr {cty}"), &ocaml_name]
259                            )
260                        ));
261
262                        args.push(cty);
263                    }
264                    let args = args.iter().map(|x| x.as_str()).collect();
265                    generated_foreign_functions
266                        .push(format!("  {}", self.foreign_function(new_fn, "int", args)));
267                }
268            }
269        }
270
271        for entry in pkg.manifest.entry_points.values() {
272            let mut args = vec!["context".to_string()];
273
274            for out in &entry.outputs {
275                let t = self.get_ctype(&out.r#type);
276
277                args.push(format!("ptr {t}"));
278            }
279
280            for input in &entry.inputs {
281                let t = self.get_ctype(&input.r#type);
282                args.push(t);
283            }
284
285            let args = args.iter().map(|x| x.as_str()).collect();
286            generated_foreign_functions.push(format!(
287                "  {}",
288                self.foreign_function(&entry.cfun, "int", args)
289            ));
290        }
291
292        let generated_foreign_functions = generated_foreign_functions.join("\n");
293
294        writeln!(
295            config.output_file,
296            include_str!("templates/ocaml/bindings.ml"),
297            generated_foreign_functions = generated_foreign_functions
298        )?;
299
300        writeln!(self.mli_file, include_str!("templates/ocaml/bindings.mli"))?;
301
302        let (extra_param, extra_line, extra_mli) = match pkg.manifest.backend {
303            Backend::Multicore => (
304                "?(num_threads = 0)",
305                "    Bindings.futhark_context_config_set_num_threads config num_threads;",
306                "?num_threads:int ->",
307            ),
308
309            Backend::CUDA | Backend::OpenCL => (
310                "?device",
311                "    Option.iter (Bindings.futhark_context_config_set_device config) device;",
312                "?device:string ->",
313            ),
314            _ => ("", "", ""),
315        };
316
317        writeln!(
318            config.output_file,
319            include_str!("templates/ocaml/context.ml"),
320            extra_param = extra_param,
321            extra_line = extra_line
322        )?;
323        writeln!(
324            self.mli_file,
325            include_str!("templates/ocaml/context.mli"),
326            extra_mli = extra_mli
327        )?;
328
329        Ok(())
330    }
331
332    fn array_type(
333        &mut self,
334        _pkg: &Package,
335        config: &mut Config,
336        name: &str,
337        ty: &manifest::ArrayType,
338    ) -> Result<(), Error> {
339        let rank = ty.rank;
340        let elemtype = ty.elemtype.to_str().to_string();
341        let ocaml_name = self.typemap.get(name).unwrap();
342        let module_name = first_uppercase(ocaml_name);
343        let mut dim_args = Vec::new();
344        for i in 0..rank {
345            dim_args.push(format!("(Int64.of_int dims.({i}))"));
346        }
347
348        let (ocaml_elemtype, ba_elemtype) = self.get_ba_type(&elemtype);
349        let ocaml_ctype = self.get_ctype(&elemtype);
350
351        writeln!(
352            config.output_file,
353            include_str!("templates/ocaml/array.ml"),
354            module_name = module_name,
355            elemtype = elemtype,
356            rank = rank,
357            dim_args = dim_args.join(" "),
358            ocaml_elemtype = ocaml_elemtype,
359            ba_elemtype = ba_elemtype,
360            ba_kind = ba_kind(&ba_elemtype),
361            ocaml_ctype = ocaml_ctype,
362        )?;
363
364        writeln!(
365            self.mli_file,
366            include_str!("templates/ocaml/array.mli"),
367            module_name = module_name,
368            ocaml_elemtype = ocaml_elemtype,
369            ba_elemtype = ba_elemtype,
370        )?;
371
372        Ok(())
373    }
374
375    fn opaque_type(
376        &mut self,
377        _pkg: &Package,
378        config: &mut Config,
379        name: &str,
380        ty: &manifest::OpaqueType,
381    ) -> Result<(), Error> {
382        let futhark_name = convert_struct_name(&ty.ctype);
383        let mut ocaml_name = futhark_name
384            .strip_prefix("futhark_opaque_")
385            .unwrap()
386            .to_string();
387        if ocaml_name.chars().next().unwrap().is_numeric() || name.contains(' ') {
388            ocaml_name = format!("type_{ocaml_name}");
389        }
390        let module_name = first_uppercase(&ocaml_name);
391        self.typemap
392            .insert(ocaml_name.clone(), format!("{module_name}.t"));
393
394        let free_fn = &ty.ops.free;
395
396        writeln!(config.output_file, "module {module_name} = struct")?;
397        writeln!(self.mli_file, "module {module_name} : sig")?;
398
399        writeln!(
400            config.output_file,
401            include_str!("templates/ocaml/opaque.ml"),
402            free_fn = free_fn,
403            name = ocaml_name,
404        )?;
405        writeln!(self.mli_file, include_str!("templates/ocaml/opaque.mli"),)?;
406
407        let record = match &ty.record {
408            Some(r) => r,
409            None => {
410                writeln!(config.output_file, "end")?;
411                writeln!(self.mli_file, "end")?;
412                return Ok(());
413            }
414        };
415
416        let mut new_params = Vec::new();
417        let mut new_call_args = Vec::new();
418        let mut new_arg_types = Vec::new();
419        for f in record.fields.iter() {
420            let t = self.get_type(&f.r#type);
421
422            new_params.push(format!("field{}", f.name));
423
424            if type_is_array(&t) {
425                new_call_args.push(format!("(get_ptr field{})", f.name));
426                new_arg_types.push(format!("{}.t", first_uppercase(&t)));
427            } else if type_is_opaque(&t) {
428                new_call_args.push(format!("(get_opaque_ptr field{})", f.name));
429                new_arg_types.push(t.to_string());
430            } else {
431                new_call_args.push(format!("field{}", f.name));
432                new_arg_types.push(t.to_string());
433            }
434        }
435
436        writeln!(
437            config.output_file,
438            include_str!("templates/ocaml/record.ml"),
439            new_params = new_params.join(" "),
440            new_fn = record.new,
441            new_call_args = new_call_args.join(" "),
442        )?;
443
444        writeln!(
445            self.mli_file,
446            include_str!("templates/ocaml/record.mli"),
447            new_arg_types = new_arg_types.join(" -> ")
448        )?;
449
450        for f in record.fields.iter() {
451            let t = self.get_type(&f.r#type);
452            let name = &f.name;
453            let project = &f.project;
454
455            let (out, out_type) = if type_is_opaque(&t) {
456                let call = t.replace(".t", ".of_ptr");
457                (format!("{call} t.opaque_ctx !@out"), t.to_string())
458            } else if type_is_array(&t) {
459                let array = first_uppercase(&t);
460                (
461                    format!("{array}.of_ptr t.opaque_ctx !@out"),
462                    format!("{}.t", first_uppercase(&t)),
463                )
464            } else {
465                ("!@out".to_string(), t.to_string())
466            };
467
468            let alloc_type = if type_is_array(&t) {
469                format!("Bindings.{t}")
470            } else if type_is_opaque(&t) {
471                t
472            } else {
473                self.get_ctype(&f.r#type)
474            };
475
476            writeln!(
477                config.output_file,
478                include_str!("templates/ocaml/record_project.ml"),
479                name = name,
480                s = alloc_type,
481                project = project,
482                out = out
483            )?;
484            writeln!(
485                self.mli_file,
486                include_str!("templates/ocaml/record_project.mli"),
487                name = name,
488                out_type = out_type
489            )?;
490        }
491
492        writeln!(config.output_file, "end\n")?;
493        writeln!(self.mli_file, "end\n")?;
494
495        Ok(())
496    }
497
498    fn entry(
499        &mut self,
500        _pkg: &Package,
501        config: &mut Config,
502        name: &str,
503        entry: &manifest::Entry,
504    ) -> Result<(), Error> {
505        let mut arg_types = Vec::new();
506        let mut return_type = Vec::new();
507        let mut entry_params = Vec::new();
508        let mut call_args = Vec::new();
509        let mut out_return = Vec::new();
510        let mut out_decl = Vec::new();
511
512        for (i, out) in entry.outputs.iter().enumerate() {
513            let t = self.get_type(&out.r#type);
514            let ct = self.get_ctype(&out.r#type);
515
516            let mut ocaml_elemtype = t.clone();
517
518            // Transform into `Module.t`
519            if ocaml_elemtype.contains("array_") {
520                ocaml_elemtype = first_uppercase(&ocaml_elemtype) + ".t"
521            }
522
523            return_type.push(ocaml_elemtype);
524
525            let i = if entry.outputs.len() == 1 {
526                String::new()
527            } else {
528                i.to_string()
529            };
530
531            if type_is_array(&t) || type_is_opaque(&t) {
532                out_decl.push(format!("  let out{i}_ptr = allocate (ptr void) null in"));
533            } else {
534                out_decl.push(format!("  let out{i}_ptr = allocate_n {ct} ~count:1 in"));
535            }
536
537            call_args.push(format!("out{i}_ptr"));
538
539            if type_is_array(&t) {
540                let m = first_uppercase(&t);
541                out_return.push(format!("({m}.of_ptr ctx !@out{i}_ptr)"));
542            } else if type_is_opaque(&t) {
543                let m = first_uppercase(&t);
544                let m = m.strip_suffix(".t").unwrap_or(&m);
545                out_return.push(format!("({m}.of_ptr ctx !@out{i}_ptr)"));
546            } else {
547                out_return.push(format!("!@out{i}_ptr"));
548            }
549        }
550
551        for (i, input) in entry.inputs.iter().enumerate() {
552            entry_params.push(format!("input{i}"));
553
554            let mut ocaml_elemtype = self.get_type(&input.r#type);
555
556            // Transform into `Module.t`
557            if type_is_array(&ocaml_elemtype) {
558                ocaml_elemtype = first_uppercase(&ocaml_elemtype) + ".t"
559            }
560
561            arg_types.push(ocaml_elemtype);
562
563            let t = self.get_type(&input.r#type);
564            if type_is_array(&t) {
565                call_args.push(format!("(get_ptr input{i})"));
566            } else if type_is_opaque(&t) {
567                call_args.push(format!("(get_opaque_ptr input{i})"));
568            } else {
569                call_args.push(format!("input{i}"));
570            }
571        }
572
573        writeln!(
574            config.output_file,
575            include_str!("templates/ocaml/entry.ml"),
576            name = name,
577            entry_params = entry_params.join(" "),
578            out_decl = out_decl.join("\n"),
579            call_args = call_args.join(" "),
580            out_return = out_return.join(", ")
581        )?;
582
583        let return_type = if return_type.is_empty() {
584            "unit".to_string()
585        } else {
586            return_type.join(" * ")
587        };
588        writeln!(
589            self.mli_file,
590            include_str!("templates/ocaml/entry.mli"),
591            name = name,
592            arg_types = arg_types.join(" -> "),
593            return_type = return_type,
594        )?;
595
596        Ok(())
597    }
598}