1use std::io::Write;
2
3use crate::generate::{convert_struct_name, first_uppercase};
4use crate::*;
5
/// Generator state for emitting OCaml bindings (`.ml` plus a companion `.mli`).
pub struct OCaml {
    /// Futhark type name -> OCaml type name used in generated signatures.
    typemap: BTreeMap<String, String>,
    /// Futhark scalar type name -> (OCaml element type, Bigarray element type).
    ba_map: BTreeMap<String, (String, String)>,
    /// Futhark type name -> Ctypes value name used in the `Bindings` module.
    ctypes_map: BTreeMap<String, String>,
    /// Handle of the `.mli` interface file written alongside the main output.
    mli_file: std::fs::File,
}
13
/// Futhark scalar type -> Ctypes value used in the generated `Bindings` module.
///
/// An empty string marks a scalar with no Ctypes equivalent; looking it up
/// through `get_ctype` panics with "Unsupported type".
const OCAML_CTYPES_MAP: &[(&str, &str)] = &[
    // 8-bit
    ("i8", "char"),
    ("u8", "uint8_t"),
    // 16-bit
    ("i16", "int16_t"),
    ("u16", "uint16_t"),
    // 32-bit
    ("i32", "int32_t"),
    ("u32", "uint32_t"),
    // 64-bit
    ("i64", "int64_t"),
    ("u64", "uint64_t"),
    // floating point (f16 is unsupported)
    ("f16", ""),
    ("f32", "float"),
    ("f64", "double"),
    // booleans
    ("bool", "bool"),
];
28
/// Futhark scalar type -> OCaml type spelled in generated signatures.
///
/// Unsigned integers use the `UIntN.t` types; an empty string marks a
/// scalar that `get_type` rejects with "Unsupported type".
const OCAML_TYPE_MAP: &[(&str, &str)] = &[
    // 8-bit
    ("i8", "char"),
    ("u8", "UInt8.t"),
    // 16-bit
    ("i16", "int"),
    ("u16", "UInt16.t"),
    // signed 32/64-bit
    ("i32", "int32"),
    ("i64", "int64"),
    // unsigned 32/64-bit
    ("u32", "UInt32.t"),
    ("u64", "UInt64.t"),
    // floating point (f16 is unsupported)
    ("f16", ""),
    ("f32", "float"),
    ("f64", "float"),
    // booleans
    ("bool", "bool"),
];
43
/// Futhark scalar type -> (OCaml element type, Bigarray element type) used
/// when generating array modules.
///
/// Unsigned 32/64-bit values share the signed `int32_elt`/`int64_elt`
/// Bigarray layouts, and `bool` is stored as unsigned 8-bit. Empty strings
/// mark an unsupported scalar (`get_ba_type` panics on it).
const OCAML_BA_TYPE_MAP: &[(&str, (&str, &str))] = &[
    // 8/16-bit integers are exposed as OCaml `int`
    ("i8", ("int", "Bigarray.int8_signed_elt")),
    ("u8", ("int", "Bigarray.int8_unsigned_elt")),
    ("i16", ("int", "Bigarray.int16_signed_elt")),
    ("u16", ("int", "Bigarray.int16_unsigned_elt")),
    // signed 32/64-bit
    ("i32", ("int32", "Bigarray.int32_elt")),
    ("i64", ("int64", "Bigarray.int64_elt")),
    // unsigned 32/64-bit reuse the signed layouts
    ("u32", ("int32", "Bigarray.int32_elt")),
    ("u64", ("int64", "Bigarray.int64_elt")),
    // floating point (f16 is unsupported)
    ("f16", ("", "")),
    ("f32", ("float", "Bigarray.float32_elt")),
    ("f64", ("float", "Bigarray.float64_elt")),
    // booleans
    ("bool", ("int", "Bigarray.int8_unsigned_elt")),
];
58
/// Returns `true` when `t` is the OCaml name of a generated array type.
///
/// `bindings` registers every array as `array_{elemtype}_{rank}d`, so a prefix
/// check identifies them exactly. A substring check (e.g. for `array_i`) would
/// also match opaque module types whose name merely contains that text, such
/// as `My_array_int.t`, and treat them as arrays.
fn type_is_array(t: &str) -> bool {
    t.starts_with("array_")
}
62
/// Returns `true` when `t` is the OCaml type of a generated opaque module
/// (spelled `Module.t`).
///
/// The scalar unsigned types from `OCAML_TYPE_MAP` (`UInt8.t` ... `UInt64.t`)
/// share the `.t` suffix but are plain Ctypes values, not opaque pointers, so
/// they are excluded explicitly. Without that exclusion, u8/u16/u32/u64
/// arguments and results would be wrapped in `get_opaque_ptr` / `of_ptr`.
fn type_is_opaque(t: &str) -> bool {
    // Keep in sync with the unsigned entries of OCAML_TYPE_MAP.
    const UNSIGNED_SCALARS: &[&str] = &["UInt8.t", "UInt16.t", "UInt32.t", "UInt64.t"];
    t.ends_with(".t") && !UNSIGNED_SCALARS.contains(&t)
}
66
/// Turns a Bigarray element-type name (e.g. `Bigarray.int8_signed_elt`) into
/// the matching kind expression by dropping the `_elt` suffix.
///
/// The single-byte character at byte offset 8 (if any) is ASCII-uppercased.
/// For the `Bigarray.`-prefixed names in `OCAML_BA_TYPE_MAP` that byte is the
/// `.` separator, so the lowercase kind value (e.g. `Bigarray.int8_signed`)
/// comes out unchanged.
/// NOTE(review): if the `Int8_signed` constructor was intended, the offset would be 9.
///
/// # Panics
/// Panics if `t` does not end in `_elt`.
fn ba_kind(t: &str) -> String {
    let stem = t.strip_suffix("_elt").unwrap();
    stem.char_indices()
        .map(|(offset, ch)| {
            if offset == 8 && ch.len_utf8() == 1 {
                ch.to_ascii_uppercase()
            } else {
                ch
            }
        })
        .collect()
}
76
77impl OCaml {
78 pub fn new(config: &Config) -> Result<Self, Error> {
80 let typemap = OCAML_TYPE_MAP
81 .iter()
82 .map(|(a, b)| (a.to_string(), b.to_string()))
83 .collect();
84
85 let ba_map = OCAML_BA_TYPE_MAP
86 .iter()
87 .map(|(a, (b, c))| (a.to_string(), (b.to_string(), c.to_string())))
88 .collect();
89
90 let ctypes_map = OCAML_CTYPES_MAP
91 .iter()
92 .map(|(a, b)| (a.to_string(), b.to_string()))
93 .collect();
94
95 let mli_path = config.output_path.with_extension("mli");
96 let mli_file = std::fs::File::create(mli_path)?;
97 Ok(OCaml {
98 typemap,
99 ba_map,
100 ctypes_map,
101 mli_file,
102 })
103 }
104
105 fn foreign_function(&mut self, name: &str, ret: &str, args: Vec<&str>) -> String {
106 format!(
107 "let {name} = fn \"{name}\" ({} @-> returning ({ret}))",
108 args.join(" @-> ")
109 )
110 }
111
112 fn get_ctype(&self, t: &str) -> String {
113 let x = self
114 .ctypes_map
115 .get(t)
116 .cloned()
117 .unwrap_or_else(|| t.to_string());
118 if x.is_empty() {
119 panic!("Unsupported type: {t}");
120 }
121 x
122 }
123
124 fn get_type(&self, t: &str) -> String {
125 let x = self
126 .typemap
127 .get(t)
128 .cloned()
129 .unwrap_or_else(|| t.to_string());
130 if x.is_empty() {
131 panic!("Unsupported type: {t}");
132 }
133 x
134 }
135
136 fn get_ba_type(&self, t: &str) -> (String, String) {
137 let x = self.ba_map.get(t).cloned().unwrap();
138 if x.0.is_empty() {
139 panic!("Unsupported type: {t}");
140 }
141 x
142 }
143}
144
145impl Generate for OCaml {
146 fn bindings(&mut self, pkg: &Package, config: &mut Config) -> Result<(), Error> {
147 writeln!(self.mli_file, "(* Generated by futhark-bindgen *)\n")?;
148 writeln!(config.output_file, "(* Generated by futhark-bindgen *)\n")?;
149
150 let mut generated_foreign_functions = Vec::new();
151 match pkg.manifest.backend {
152 Backend::Multicore => {
153 generated_foreign_functions.push(format!(
154 " {}",
155 self.foreign_function(
156 "futhark_context_config_set_num_threads",
157 "void",
158 vec!["context_config", "int"]
159 )
160 ));
161 }
162 Backend::CUDA | Backend::OpenCL => {
163 generated_foreign_functions.push(format!(
164 " {}",
165 self.foreign_function(
166 "futhark_context_config_set_device",
167 "void",
168 vec!["context_config", "string"]
169 )
170 ));
171 }
172 _ => (),
173 }
174
175 for (name, ty) in &pkg.manifest.types {
176 match ty {
177 manifest::Type::Array(a) => {
178 let elemtype = a.elemtype.to_str().to_string();
179 let ctypes_elemtype = self.get_ctype(&elemtype);
180 let rank = a.rank;
181 let ocaml_name = format!("array_{elemtype}_{rank}d");
182 self.typemap.insert(name.clone(), ocaml_name.clone());
183 self.ctypes_map.insert(name.clone(), ocaml_name.clone());
184 let elem_ptr = format!("ptr {ctypes_elemtype}");
185 generated_foreign_functions.push(format!(
186 " let {ocaml_name} = typedef (ptr void) \"{ocaml_name}\""
187 ));
188 let mut new_args = vec!["context", &elem_ptr];
189 new_args.resize(rank as usize + 2, "int64_t");
190 generated_foreign_functions.push(format!(
191 " {}",
192 self.foreign_function(&a.ops.new, &ocaml_name, new_args)
193 ));
194 generated_foreign_functions.push(format!(
195 " {}",
196 self.foreign_function(
197 &a.ops.values,
198 "int",
199 vec!["context", &ocaml_name, &elem_ptr]
200 )
201 ));
202 generated_foreign_functions.push(format!(
203 " {}",
204 self.foreign_function(&a.ops.free, "int", vec!["context", &ocaml_name])
205 ));
206 generated_foreign_functions.push(format!(
207 " {}",
208 self.foreign_function(
209 &a.ops.shape,
210 "ptr int64_t",
211 vec!["context", &ocaml_name]
212 )
213 ));
214 }
215 manifest::Type::Opaque(ty) => {
216 let futhark_name = convert_struct_name(&ty.ctype);
217 let mut ocaml_name = futhark_name
218 .strip_prefix("futhark_opaque_")
219 .unwrap()
220 .to_string();
221 if ocaml_name.chars().next().unwrap().is_numeric() || name.contains(' ') {
222 ocaml_name = format!("type_{ocaml_name}");
223 }
224
225 self.typemap
226 .insert(name.clone(), format!("{}.t", first_uppercase(&ocaml_name)));
227 self.ctypes_map.insert(name.to_string(), ocaml_name.clone());
228 generated_foreign_functions.push(format!(
229 " let {ocaml_name} = typedef (ptr void) \"{futhark_name}\""
230 ));
231
232 let free_fn = &ty.ops.free;
233 generated_foreign_functions.push(format!(
234 " {}",
235 self.foreign_function(free_fn, "int", vec!["context", &ocaml_name])
236 ));
237
238 let record = match &ty.record {
239 Some(r) => r,
240 None => continue,
241 };
242
243 let new_fn = &record.new;
244 let mut args = vec!["context".to_string(), format!("ptr {ocaml_name}")];
245 for f in record.fields.iter() {
246 let cty = self
247 .ctypes_map
248 .get(&f.r#type)
249 .cloned()
250 .unwrap_or_else(|| f.r#type.clone());
251
252 generated_foreign_functions.push(format!(
254 " {}",
255 self.foreign_function(
256 &f.project,
257 "int",
258 vec!["context", &format!("ptr {cty}"), &ocaml_name]
259 )
260 ));
261
262 args.push(cty);
263 }
264 let args = args.iter().map(|x| x.as_str()).collect();
265 generated_foreign_functions
266 .push(format!(" {}", self.foreign_function(new_fn, "int", args)));
267 }
268 }
269 }
270
271 for entry in pkg.manifest.entry_points.values() {
272 let mut args = vec!["context".to_string()];
273
274 for out in &entry.outputs {
275 let t = self.get_ctype(&out.r#type);
276
277 args.push(format!("ptr {t}"));
278 }
279
280 for input in &entry.inputs {
281 let t = self.get_ctype(&input.r#type);
282 args.push(t);
283 }
284
285 let args = args.iter().map(|x| x.as_str()).collect();
286 generated_foreign_functions.push(format!(
287 " {}",
288 self.foreign_function(&entry.cfun, "int", args)
289 ));
290 }
291
292 let generated_foreign_functions = generated_foreign_functions.join("\n");
293
294 writeln!(
295 config.output_file,
296 include_str!("templates/ocaml/bindings.ml"),
297 generated_foreign_functions = generated_foreign_functions
298 )?;
299
300 writeln!(self.mli_file, include_str!("templates/ocaml/bindings.mli"))?;
301
302 let (extra_param, extra_line, extra_mli) = match pkg.manifest.backend {
303 Backend::Multicore => (
304 "?(num_threads = 0)",
305 " Bindings.futhark_context_config_set_num_threads config num_threads;",
306 "?num_threads:int ->",
307 ),
308
309 Backend::CUDA | Backend::OpenCL => (
310 "?device",
311 " Option.iter (Bindings.futhark_context_config_set_device config) device;",
312 "?device:string ->",
313 ),
314 _ => ("", "", ""),
315 };
316
317 writeln!(
318 config.output_file,
319 include_str!("templates/ocaml/context.ml"),
320 extra_param = extra_param,
321 extra_line = extra_line
322 )?;
323 writeln!(
324 self.mli_file,
325 include_str!("templates/ocaml/context.mli"),
326 extra_mli = extra_mli
327 )?;
328
329 Ok(())
330 }
331
332 fn array_type(
333 &mut self,
334 _pkg: &Package,
335 config: &mut Config,
336 name: &str,
337 ty: &manifest::ArrayType,
338 ) -> Result<(), Error> {
339 let rank = ty.rank;
340 let elemtype = ty.elemtype.to_str().to_string();
341 let ocaml_name = self.typemap.get(name).unwrap();
342 let module_name = first_uppercase(ocaml_name);
343 let mut dim_args = Vec::new();
344 for i in 0..rank {
345 dim_args.push(format!("(Int64.of_int dims.({i}))"));
346 }
347
348 let (ocaml_elemtype, ba_elemtype) = self.get_ba_type(&elemtype);
349 let ocaml_ctype = self.get_ctype(&elemtype);
350
351 writeln!(
352 config.output_file,
353 include_str!("templates/ocaml/array.ml"),
354 module_name = module_name,
355 elemtype = elemtype,
356 rank = rank,
357 dim_args = dim_args.join(" "),
358 ocaml_elemtype = ocaml_elemtype,
359 ba_elemtype = ba_elemtype,
360 ba_kind = ba_kind(&ba_elemtype),
361 ocaml_ctype = ocaml_ctype,
362 )?;
363
364 writeln!(
365 self.mli_file,
366 include_str!("templates/ocaml/array.mli"),
367 module_name = module_name,
368 ocaml_elemtype = ocaml_elemtype,
369 ba_elemtype = ba_elemtype,
370 )?;
371
372 Ok(())
373 }
374
375 fn opaque_type(
376 &mut self,
377 _pkg: &Package,
378 config: &mut Config,
379 name: &str,
380 ty: &manifest::OpaqueType,
381 ) -> Result<(), Error> {
382 let futhark_name = convert_struct_name(&ty.ctype);
383 let mut ocaml_name = futhark_name
384 .strip_prefix("futhark_opaque_")
385 .unwrap()
386 .to_string();
387 if ocaml_name.chars().next().unwrap().is_numeric() || name.contains(' ') {
388 ocaml_name = format!("type_{ocaml_name}");
389 }
390 let module_name = first_uppercase(&ocaml_name);
391 self.typemap
392 .insert(ocaml_name.clone(), format!("{module_name}.t"));
393
394 let free_fn = &ty.ops.free;
395
396 writeln!(config.output_file, "module {module_name} = struct")?;
397 writeln!(self.mli_file, "module {module_name} : sig")?;
398
399 writeln!(
400 config.output_file,
401 include_str!("templates/ocaml/opaque.ml"),
402 free_fn = free_fn,
403 name = ocaml_name,
404 )?;
405 writeln!(self.mli_file, include_str!("templates/ocaml/opaque.mli"),)?;
406
407 let record = match &ty.record {
408 Some(r) => r,
409 None => {
410 writeln!(config.output_file, "end")?;
411 writeln!(self.mli_file, "end")?;
412 return Ok(());
413 }
414 };
415
416 let mut new_params = Vec::new();
417 let mut new_call_args = Vec::new();
418 let mut new_arg_types = Vec::new();
419 for f in record.fields.iter() {
420 let t = self.get_type(&f.r#type);
421
422 new_params.push(format!("field{}", f.name));
423
424 if type_is_array(&t) {
425 new_call_args.push(format!("(get_ptr field{})", f.name));
426 new_arg_types.push(format!("{}.t", first_uppercase(&t)));
427 } else if type_is_opaque(&t) {
428 new_call_args.push(format!("(get_opaque_ptr field{})", f.name));
429 new_arg_types.push(t.to_string());
430 } else {
431 new_call_args.push(format!("field{}", f.name));
432 new_arg_types.push(t.to_string());
433 }
434 }
435
436 writeln!(
437 config.output_file,
438 include_str!("templates/ocaml/record.ml"),
439 new_params = new_params.join(" "),
440 new_fn = record.new,
441 new_call_args = new_call_args.join(" "),
442 )?;
443
444 writeln!(
445 self.mli_file,
446 include_str!("templates/ocaml/record.mli"),
447 new_arg_types = new_arg_types.join(" -> ")
448 )?;
449
450 for f in record.fields.iter() {
451 let t = self.get_type(&f.r#type);
452 let name = &f.name;
453 let project = &f.project;
454
455 let (out, out_type) = if type_is_opaque(&t) {
456 let call = t.replace(".t", ".of_ptr");
457 (format!("{call} t.opaque_ctx !@out"), t.to_string())
458 } else if type_is_array(&t) {
459 let array = first_uppercase(&t);
460 (
461 format!("{array}.of_ptr t.opaque_ctx !@out"),
462 format!("{}.t", first_uppercase(&t)),
463 )
464 } else {
465 ("!@out".to_string(), t.to_string())
466 };
467
468 let alloc_type = if type_is_array(&t) {
469 format!("Bindings.{t}")
470 } else if type_is_opaque(&t) {
471 t
472 } else {
473 self.get_ctype(&f.r#type)
474 };
475
476 writeln!(
477 config.output_file,
478 include_str!("templates/ocaml/record_project.ml"),
479 name = name,
480 s = alloc_type,
481 project = project,
482 out = out
483 )?;
484 writeln!(
485 self.mli_file,
486 include_str!("templates/ocaml/record_project.mli"),
487 name = name,
488 out_type = out_type
489 )?;
490 }
491
492 writeln!(config.output_file, "end\n")?;
493 writeln!(self.mli_file, "end\n")?;
494
495 Ok(())
496 }
497
498 fn entry(
499 &mut self,
500 _pkg: &Package,
501 config: &mut Config,
502 name: &str,
503 entry: &manifest::Entry,
504 ) -> Result<(), Error> {
505 let mut arg_types = Vec::new();
506 let mut return_type = Vec::new();
507 let mut entry_params = Vec::new();
508 let mut call_args = Vec::new();
509 let mut out_return = Vec::new();
510 let mut out_decl = Vec::new();
511
512 for (i, out) in entry.outputs.iter().enumerate() {
513 let t = self.get_type(&out.r#type);
514 let ct = self.get_ctype(&out.r#type);
515
516 let mut ocaml_elemtype = t.clone();
517
518 if ocaml_elemtype.contains("array_") {
520 ocaml_elemtype = first_uppercase(&ocaml_elemtype) + ".t"
521 }
522
523 return_type.push(ocaml_elemtype);
524
525 let i = if entry.outputs.len() == 1 {
526 String::new()
527 } else {
528 i.to_string()
529 };
530
531 if type_is_array(&t) || type_is_opaque(&t) {
532 out_decl.push(format!(" let out{i}_ptr = allocate (ptr void) null in"));
533 } else {
534 out_decl.push(format!(" let out{i}_ptr = allocate_n {ct} ~count:1 in"));
535 }
536
537 call_args.push(format!("out{i}_ptr"));
538
539 if type_is_array(&t) {
540 let m = first_uppercase(&t);
541 out_return.push(format!("({m}.of_ptr ctx !@out{i}_ptr)"));
542 } else if type_is_opaque(&t) {
543 let m = first_uppercase(&t);
544 let m = m.strip_suffix(".t").unwrap_or(&m);
545 out_return.push(format!("({m}.of_ptr ctx !@out{i}_ptr)"));
546 } else {
547 out_return.push(format!("!@out{i}_ptr"));
548 }
549 }
550
551 for (i, input) in entry.inputs.iter().enumerate() {
552 entry_params.push(format!("input{i}"));
553
554 let mut ocaml_elemtype = self.get_type(&input.r#type);
555
556 if type_is_array(&ocaml_elemtype) {
558 ocaml_elemtype = first_uppercase(&ocaml_elemtype) + ".t"
559 }
560
561 arg_types.push(ocaml_elemtype);
562
563 let t = self.get_type(&input.r#type);
564 if type_is_array(&t) {
565 call_args.push(format!("(get_ptr input{i})"));
566 } else if type_is_opaque(&t) {
567 call_args.push(format!("(get_opaque_ptr input{i})"));
568 } else {
569 call_args.push(format!("input{i}"));
570 }
571 }
572
573 writeln!(
574 config.output_file,
575 include_str!("templates/ocaml/entry.ml"),
576 name = name,
577 entry_params = entry_params.join(" "),
578 out_decl = out_decl.join("\n"),
579 call_args = call_args.join(" "),
580 out_return = out_return.join(", ")
581 )?;
582
583 let return_type = if return_type.is_empty() {
584 "unit".to_string()
585 } else {
586 return_type.join(" * ")
587 };
588 writeln!(
589 self.mli_file,
590 include_str!("templates/ocaml/entry.mli"),
591 name = name,
592 arg_types = arg_types.join(" -> "),
593 return_type = return_type,
594 )?;
595
596 Ok(())
597 }
598}