genfut 0.5.0

Generate bindings for Futhark
Documentation
use std::fmt::Write;

use inflector::Inflector;
use regex::Regex;

fn type_translation(input: &str) -> String {
    if input.starts_with("futhark") {
        auto_ctor(input)
    } else {
        let mut buffer = String::new();
        if input.starts_with("int8") {
            write!(&mut buffer, "i8");
        } else if input.starts_with("int") {
            write!(&mut buffer, "i{}", &input[3..5]);
        } else if input.starts_with("uint8") {
            write!(&mut buffer, "u8");
        } else if input.starts_with("uint") {
            write!(&mut buffer, "u{}", &input[4..6]);
        } else if input.starts_with("float") {
            write!(&mut buffer, "f32");
        } else if input.starts_with("double") {
            write!(&mut buffer, "f64");
        }
        buffer
    }
}

/// Builds the generated Rust type name for an array with element type name
/// `t` and `dim` dimensions, in the form `Array_<t>_<dim>d`.
fn ctor_array_type(t: &str, dim: usize) -> String {
    let mut type_name = String::from("Array_");
    type_name.push_str(t);
    type_name.push('_');
    type_name.push_str(&dim.to_string());
    type_name.push('d');
    type_name
}

/// Regex source for Futhark C array type names: `futhark_`, then the element
/// type name (capture group 1), `_`, a single digit for the dimension count
/// (capture group 2) and a trailing `d` at a word boundary. The pattern is
/// unanchored, so it may match anywhere inside the searched string.
const RE_ARRAY_TYPE_STR: &str = r"futhark_(.+)_(\d)d\b";

/// Splits a Futhark C array type name into `(element type name, dimensions)`.
///
/// Returns `None` when `t` does not match `RE_ARRAY_TYPE_STR`.
///
/// # Panics
///
/// Panics if the captured dimension digit cannot be parsed as `usize`.
fn parse_array_type(t: &str) -> Option<(String, usize)> {
    let pattern = Regex::new(RE_ARRAY_TYPE_STR).unwrap();
    let groups = pattern.captures(t)?;
    let dim = groups[2].parse::<usize>().unwrap();
    Some((groups[1].to_string(), dim))
}
fn auto_ctor(t: &str) -> String {
    let re_array_type = Regex::new(RE_ARRAY_TYPE_STR).unwrap();
    if let Some((ftype, dim)) = parse_array_type(t) {
        ctor_array_type(&ftype, dim)
    } else {
        to_opaque_type_name(t)
    }
}

pub(crate) fn gen_entry_point(input: &str) -> (String, String, Vec<String>) {
    let re_name = Regex::new(r"futhark_entry_(.+)\(").unwrap();
    let re_arg_pairs =
        Regex::new(r"(?m)\s*(?:const\s*)?(?:struct\s*)?([a-z0-9_]+)\s\**([a-z0-9]+),?\s?").unwrap();

    let arg_pairs: Vec<(String, String)> = re_arg_pairs
        .captures_iter(input)
        .skip(2)
        .map(|c| (c[1].to_owned(), c[2].to_owned()))
        .collect();
    let name = re_name.captures(input).unwrap()[1].to_owned();
    let mut buffer = format!("pub fn {name}", name = name);

    write!(&mut buffer, "(&mut self, ");
    for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
        if argname.starts_with("in") {
            let argtype_string = type_translation(argtype);
            write!(
                &mut buffer,
                "{}: {}{}, ",
                argname,
                if argtype_string.starts_with("FutharkOpaque") {
                    "&"
                } else {
                    ""
                },
                argtype_string
            );
        }
    }
    write!(&mut buffer, ") -> ");
    let mut output_buffer = String::from("Result<(");
    let mut output_counter = 0;
    for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
        if argname.starts_with("out") {
            if output_counter > 0 {
                write!(&mut output_buffer, ", ");
            }
            output_counter += 1;
            write!(&mut output_buffer, "{}", type_translation(argtype));
        }
    }
    write!(&mut output_buffer, ")>");
    writeln!(&mut buffer, "{}", output_buffer);

    write!(
        &mut buffer,
        "{{\nlet ctx = self.ptr();\nunsafe{{\n_{name}(ctx, ",
        name = name
    );
    for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
        if argname.starts_with("in") {
            if argtype.starts_with("futhark") {
                write!(&mut buffer, "{}.as_raw_mut(), ", argname);
            } else {
                write!(&mut buffer, "{}, ", argname);
            }
        }
    }
    write!(&mut buffer, ")\n}}}}\n");

    // END OF FIRST PART
    let mut buffer2 = String::new();
    write!(
        &mut buffer2,
        "unsafe fn _{name}(ctx: *mut bindings::futhark_context, ",
        name = name
    );
    for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
        if argname.starts_with("in") {
            if argtype.starts_with("futhark") {
                write!(&mut buffer2, "{}: *const bindings::{}, ", argname, argtype);
            } else {
                write!(&mut buffer2, "{}: {}, ", argname, type_translation(argtype));
            }
        }
    }
    writeln!(&mut buffer2, ") -> {} {{", output_buffer);
    for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
        if argname.starts_with("out") {
            if argtype.starts_with("futhark") {
                writeln!(
                    &mut buffer2,
                    "let mut raw_{} = std::ptr::null_mut();",
                    argname
                );
            } else {
                writeln!(
                    &mut buffer2,
                    "let mut raw_{} = {}::default();",
                    argname,
                    type_translation(argtype)
                );
            }
        }
    }

    write!(
        &mut buffer2,
        "\nif bindings::futhark_entry_{name}(ctx, ",
        name = name
    );
    for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
        if argname.starts_with("out") {
            write!(&mut buffer2, "&mut raw_{}, ", argname);
        }
    }
    for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
        if argname.starts_with("in") {
            write!(&mut buffer2, "{}, ", argname);
        }
    }
    writeln!(
        &mut buffer2,
        ") != 0 {{
return Err(FutharkError::new(ctx).into());}}"
    );

    let mut opaque_types = Vec::new();
    // OUTPUT
    let mut result_counter = 0;
    write!(&mut buffer2, "Ok((");
    for (i, (argtype, argname)) in arg_pairs.iter().enumerate() {
        if argname.starts_with("out") {
            if parse_array_type(argtype).is_none() {
                opaque_types.push(argtype.clone());
            }
            if result_counter > 0 {
                write!(&mut buffer2, ", ");
            }
            result_counter += 1;
            if argtype.starts_with("futhark") {
                writeln!(
                    &mut buffer2,
                    "{}::from_ptr(ctx, raw_{})",
                    auto_ctor(argtype),
                    argname
                );
            } else {
                writeln!(&mut buffer2, "raw_{}", argname);
            }
        }
    }
    write!(&mut buffer2, "))\n}}");

    (buffer, buffer2, opaque_types)
}
/// Converts a C type name into the Rust name used for its opaque wrapper:
/// the camel-cased form of `s` with its first character upper-cased when that
/// character is ASCII.
fn to_opaque_type_name(s: &str) -> String {
    let camel = s.to_camel_case();
    let mut rest = camel.chars();
    match rest.next() {
        Some(first) => {
            let mut name = String::with_capacity(camel.len());
            // A non-ASCII first character is left as it is.
            name.push(first.to_ascii_uppercase());
            name.push_str(rest.as_str());
            name
        }
        None => String::new(),
    }
}

/// Renders the `static/static_opaque_types.rs` template for one opaque type.
///
/// The template receives three named values: `opaque_type` (the Rust wrapper
/// name from `to_opaque_type_name`), `futhark_type` (the C type name prefixed
/// with `bindings::`) and `base_type` (the C type name without its leading
/// `futhark_`).
///
/// # Panics
///
/// Panics if `opaque_type` does not start with `futhark_`.
fn gen_opaque_type(opaque_type: &str) -> String {
    let rust_name = to_opaque_type_name(opaque_type);
    assert!(opaque_type.starts_with("futhark_"),);
    // The assertion above guarantees the prefix is present, so slicing it off is safe.
    let prefix_len = "futhark_".len();
    format!(
        include_str!("static/static_opaque_types.rs"),
        opaque_type = rust_name,
        futhark_type = format!("bindings::{}", opaque_type),
        base_type = &opaque_type[prefix_len..]
    )
}

/// Generates the Rust source for all entry points in `input`.
///
/// Each element of `input` is passed to `gen_entry_point`. The returned text
/// consists of an `impl FutharkContext { … }` block holding the safe methods,
/// followed by the unsafe helper functions and then one `gen_opaque_type`
/// expansion per distinct `futhark_`-prefixed type reported by the entry
/// points (sorted by type name).
pub(crate) fn gen_entry_points(input: &[String]) -> String {
    let mut methods = String::from("impl FutharkContext {\n");
    let mut free_items = String::new();
    let mut opaque_types = Vec::new();

    for prototype in input {
        let (method, raw_fn, otypes) = gen_entry_point(prototype);
        opaque_types.extend(otypes);
        writeln!(&mut methods, "{}", method).expect("Write failed!");
        writeln!(&mut free_items, "{}", raw_fn).expect("Write failed!");
    }

    // Several entry points may return the same type; emit each wrapper once.
    opaque_types.sort();
    opaque_types.dedup();
    // `gen_entry_point` also reports scalar C types; only `futhark_` types get
    // a wrapper.
    for opaque_type in opaque_types.iter().filter(|t| t.starts_with("futhark_")) {
        writeln!(&mut free_items, "{}", gen_opaque_type(opaque_type)).expect("Write failed!");
    }

    writeln!(&mut methods, "}}").expect("Write failed!");
    writeln!(&mut methods, "{}", free_items).expect("Write failed!");

    methods
}