genfut/
entry.rs

1use std::fmt::Write;
2
3use inflector::Inflector;
4use regex::Regex;
5
6fn type_translation(input: &str) -> String {
7    if input.starts_with("futhark") {
8        auto_ctor(input)
9    } else {
10        let mut buffer = String::new();
11        if input.starts_with("int8") {
12            write!(&mut buffer, "i8");
13        } else if input.starts_with("int") {
14            write!(&mut buffer, "i{}", &input[3..5]);
15        } else if input.starts_with("uint8") {
16            write!(&mut buffer, "u8");
17        } else if input.starts_with("uint") {
18            write!(&mut buffer, "u{}", &input[4..6]);
19        } else if input.starts_with("float") {
20            write!(&mut buffer, "f32");
21        } else if input.starts_with("double") {
22            write!(&mut buffer, "f64");
23        }
24        buffer
25    }
26}
27
/// Builds the Rust wrapper type name for a Futhark array with element type
/// `t` and rank `dim`, e.g. `("i32", 2)` gives `"Array_i32_2d"`.
fn ctor_array_type(t: &str, dim: usize) -> String {
    let mut type_name = String::from("Array_");
    type_name.push_str(t);
    type_name.push('_');
    type_name.push_str(&dim.to_string());
    type_name.push('d');
    type_name
}
31
/// Matches Futhark array type names such as `futhark_i32_2d` or
/// `futhark_f64_12d`: group 1 is the element type, group 2 the rank
/// (one or more decimal digits).
const RE_ARRAY_TYPE_STR: &str = r"futhark_(.+)_(\d+)d\b";
33
34fn parse_array_type(t: &str) -> Option<(String, usize)> {
35    let re_array_type = Regex::new(RE_ARRAY_TYPE_STR).unwrap();
36    if let Some(captures) = re_array_type.captures(t) {
37        let dim: usize = captures[2].parse().unwrap();
38        let ftype = &captures[1];
39        Some((ftype.to_string(), dim))
40    } else {
41        None
42    }
43}
44fn auto_ctor(t: &str) -> String {
45    let re_array_type = Regex::new(RE_ARRAY_TYPE_STR).unwrap();
46    if let Some((ftype, dim)) = parse_array_type(t) {
47        ctor_array_type(&ftype, dim)
48    } else {
49        to_opaque_type_name(t)
50    }
51}
52
53pub(crate) fn gen_entry_point(input: &str) -> (String, String, Vec<String>) {
54    let re_name = Regex::new(r"futhark_entry_(.+)\(").unwrap();
55    let re_arg_pairs =
56        Regex::new(r"(?m)\s*(?:const\s*)?(?:struct\s*)?([a-z0-9_]+)\s\**([a-z0-9]+),?\s?").unwrap();
57
58    let arg_pairs: Vec<(String, String)> = re_arg_pairs
59        .captures_iter(input)
60        .skip(2)
61        .map(|c| (c[1].to_owned(), c[2].to_owned()))
62        .collect();
63    let name = re_name.captures(input).unwrap()[1].to_owned();
64    let mut buffer = format!("pub fn {name}", name = name);
65
66    write!(&mut buffer, "(&mut self, ");
67    for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
68        if argname.starts_with("in") {
69            let argtype_string = type_translation(argtype);
70            write!(
71                &mut buffer,
72                "{}: {}{}, ",
73                argname,
74                if argtype_string.starts_with("FutharkOpaque") {
75                    "&"
76                } else {
77                    ""
78                },
79                argtype_string
80            );
81        }
82    }
83    write!(&mut buffer, ") -> ");
84    let mut output_buffer = String::from("Result<(");
85    let mut output_counter = 0;
86    for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
87        if argname.starts_with("out") {
88            if output_counter > 0 {
89                write!(&mut output_buffer, ", ");
90            }
91            output_counter += 1;
92            write!(&mut output_buffer, "{}", type_translation(argtype));
93        }
94    }
95    write!(&mut output_buffer, ")>");
96    writeln!(&mut buffer, "{}", output_buffer);
97
98    write!(
99        &mut buffer,
100        "{{\nlet ctx = self.ptr();\nunsafe{{\n_{name}(ctx, ",
101        name = name
102    );
103    for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
104        if argname.starts_with("in") {
105            if argtype.starts_with("futhark") {
106                write!(&mut buffer, "{}.as_raw_mut(), ", argname);
107            } else {
108                write!(&mut buffer, "{}, ", argname);
109            }
110        }
111    }
112    write!(&mut buffer, ")\n}}}}\n");
113
114    // END OF FIRST PART
115    let mut buffer2 = String::new();
116    write!(
117        &mut buffer2,
118        "unsafe fn _{name}(ctx: *mut bindings::futhark_context, ",
119        name = name
120    );
121    for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
122        if argname.starts_with("in") {
123            if argtype.starts_with("futhark") {
124                write!(&mut buffer2, "{}: *const bindings::{}, ", argname, argtype);
125            } else {
126                write!(&mut buffer2, "{}: {}, ", argname, type_translation(argtype));
127            }
128        }
129    }
130    writeln!(&mut buffer2, ") -> {} {{", output_buffer);
131    for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
132        if argname.starts_with("out") {
133            if argtype.starts_with("futhark") {
134                writeln!(
135                    &mut buffer2,
136                    "let mut raw_{} = std::ptr::null_mut();",
137                    argname
138                );
139            } else {
140                writeln!(
141                    &mut buffer2,
142                    "let mut raw_{} = {}::default();",
143                    argname,
144                    type_translation(argtype)
145                );
146            }
147        }
148    }
149
150    write!(
151        &mut buffer2,
152        "\nif bindings::futhark_entry_{name}(ctx, ",
153        name = name
154    );
155    for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
156        if argname.starts_with("out") {
157            write!(&mut buffer2, "&mut raw_{}, ", argname);
158        }
159    }
160    for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
161        if argname.starts_with("in") {
162            write!(&mut buffer2, "{}, ", argname);
163        }
164    }
165    writeln!(
166        &mut buffer2,
167        ") != 0 {{
168return Err(FutharkError::new(ctx).into());}}"
169    );
170
171    let mut opaque_types = Vec::new();
172    // OUTPUT
173    let mut result_counter = 0;
174    write!(&mut buffer2, "Ok((");
175    for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
176        if argname.starts_with("out") {
177            if parse_array_type(argtype).is_none() {
178                opaque_types.push(argtype.clone());
179            }
180            if result_counter > 0 {
181                write!(&mut buffer2, ", ");
182            }
183            result_counter += 1;
184            if argtype.starts_with("futhark") {
185                writeln!(
186                    &mut buffer2,
187                    "{}::from_ptr(ctx, raw_{})",
188                    auto_ctor(argtype),
189                    argname
190                );
191            } else {
192                writeln!(&mut buffer2, "raw_{}", argname);
193            }
194        }
195    }
196    write!(&mut buffer2, "))\n}}");
197
198    (buffer, buffer2, opaque_types)
199}
200fn to_opaque_type_name(s: &str) -> String {
201    let mut rust_opaque_type = s.to_camel_case();
202
203    if let Some(r) = rust_opaque_type.get_mut(0..1) {
204        r.make_ascii_uppercase();
205    }
206    rust_opaque_type
207}
208
209fn gen_opaque_type(opaque_type: &str) -> String {
210    let rust_opaque_type = to_opaque_type_name(opaque_type);
211    assert!(opaque_type.starts_with("futhark_"),);
212    let base_type = &opaque_type[8..];
213    let bindings = format!("bindings::{}", opaque_type);
214    format!(
215        include_str!("static/static_opaque_types.rs"),
216        opaque_type = rust_opaque_type,
217        futhark_type = bindings,
218        base_type = base_type
219    )
220}
221
222pub(crate) fn gen_entry_points(input: &Vec<String>) -> String {
223    let mut buffer = String::from(
224        r#"impl FutharkContext {
225"#,
226    );
227    let mut opaque_types = Vec::new();
228    let mut buffer2 = String::new();
229    for t in input {
230        let (a, b, otypes) = gen_entry_point(t);
231        opaque_types.extend(otypes);
232        writeln!(&mut buffer, "{}", a).expect("Write failed!");
233        writeln!(&mut buffer2, "{}", b).expect("Write failed!");
234    }
235
236    opaque_types.sort();
237    opaque_types.dedup();
238    for (i, opaque_type) in opaque_types.iter().enumerate() {
239        if opaque_type.starts_with("futhark_") {
240            writeln!(&mut buffer2, "{}", gen_opaque_type(opaque_type));
241        }
242    }
243
244    writeln!(&mut buffer, "}}").expect("Write failed!");
245    writeln!(&mut buffer, "{}", buffer2).expect("Write failed!");
246
247    buffer
248}