lufa-rs 0.1.1

Rust bindings and utility macros for the LUFA library.
#![feature(trim_prefix_suffix, string_remove_matches)]

use std::io::{self, Write};
use std::path::{Path, PathBuf};
use std::process::{Command, Stdio};
use std::{env, fs::OpenOptions};

use bindgen::callbacks::ParseCallbacks;
use bindgen::RustTarget;
use heck::{ToShoutySnakeCase, ToSnakeCase, ToUpperCamelCase};
use quote::{format_ident, quote};

/// Regex patterns of C functions to exclude from the generated bindings; each
/// entry is passed to `bindgen::Builder::blocklist_function` in `main`.
/// NOTE(review): `CALLBACK_*` functions are presumably hooks the application
/// must implement itself rather than LUFA code to bind — confirm.
const BLOCKLIST_FN: &[&str] = &["CALLBACK.*"];

/// LUFA build configuration, read from environment variables by [`Config::get_config`].
struct Config {
    /// Directory holding a user-supplied `LUFAConfig.h` (`LUFA_CONFIG_PATH`), if any.
    config_path: Option<String>,
    /// Target microcontroller, passed as `-mmcu=` (`MMCU`).
    mmcu: String,
    /// CPU clock frequency forwarded verbatim as the `F_CPU` C define (`F_CPU`).
    cpu_frequency: String,
    /// USB clock frequency forwarded as the `F_USB` C define (`F_USB`).
    usb_frequency: String,
    /// LUFA architecture forwarded as the `ARCH` C define (`ARCH`).
    arch: String,
    /// LUFA board forwarded as the `BOARD` C define (`BOARD`).
    board: String,
}

impl Config {
    /// Reads the configuration from the environment and asks Cargo to re-run
    /// the build script whenever one of the consulted variables changes.
    ///
    /// # Panics
    /// Panics if `MMCU`, `F_CPU`, `ARCH` or `BOARD` is unset or not valid Unicode.
    fn get_config() -> Self {
        const TRACKED_VARS: [&str; 6] = ["LUFA_CONFIG_PATH", "MMCU", "F_CPU", "F_USB", "ARCH", "BOARD"];
        for var_name in TRACKED_VARS {
            println!("cargo::rerun-if-env-changed={var_name}");
        }

        let config_path = env::var("LUFA_CONFIG_PATH").ok();
        let mmcu = env::var("MMCU").expect("MMCU environment variable must be set !");
        let cpu_frequency = env::var("F_CPU").expect("F_CPU environment variable must be set !");
        // When F_USB is absent, forward the macro *name* `F_CPU`: with `-DF_USB=F_CPU`
        // the C preprocessor expands F_USB to whatever F_CPU is defined as.
        let usb_frequency = env::var("F_USB").unwrap_or_else(|_| String::from("F_CPU"));
        let arch = env::var("ARCH").expect("ARCH environment variable must be set !");
        let board = env::var("BOARD").expect("BOARD environment variable must be set !");

        Config {
            config_path,
            mmcu,
            cpu_frequency,
            usb_frequency,
            arch,
            board,
        }
    }
}

/// Build-script entry point.
///
/// Compiles the LUFA USB core with `avr-gcc` into the static library `lufa`,
/// generates bindings for `USB.h` (plus the optional user `LUFAConfig.h`) with
/// bindgen, renames the generated items to Rust conventions and writes the
/// result to `$OUT_DIR/bindings.rs`.
///
/// # Panics
/// Panics on any configuration or build failure (missing environment
/// variables, invalid `LUFA_CONFIG_PATH`, bindgen/syn/rustfmt/compiler errors),
/// which is how build scripts report errors to Cargo.
pub fn main() {
    // Build LUFA
    let mut cc_builder = cc::Build::new();
    let mut bindings_builder = bindgen::Builder::default();

    let config = Config::get_config();
    let out_path =
        PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR is set by Cargo for build scripts"));

    // `Config::get_config` emits `rerun-if-env-changed` directives. Once a build
    // script emits any `rerun-if-*` directive, Cargo stops re-running it when
    // arbitrary package files change, so the LUFA sources must be tracked
    // explicitly or edits to them would never be rebuilt.
    println!("cargo::rerun-if-changed=lufa");

    // User variables
    cc_builder
        .flag(format!("-mmcu={}", config.mmcu))
        .define("F_CPU", config.cpu_frequency.as_str())
        .define("F_USB", config.usb_frequency.as_str())
        .define("ARCH", config.arch.as_str())
        .define("BOARD", config.board.as_str());

    bindings_builder = bindings_builder
        .clang_arg(format!("-mmcu={}", config.mmcu))
        .clang_arg(format!("-DF_CPU={}", config.cpu_frequency))
        .clang_arg(format!("-DF_USB={}", config.usb_frequency))
        .clang_arg(format!("-DARCH={}", config.arch))
        .clang_arg(format!("-DBOARD={}", config.board));

    if let Some(config_path) = config.config_path {
        let config_dir = Path::new(&config_path);
        if !config_dir.exists() {
            panic!("LUFA_CONFIG_PATH \"{config_path}\" is not a valid path !");
        }
        let config_header = config_dir.join("LUFAConfig.h");
        if !config_header.exists() {
            panic!("LUFA_CONFIG_PATH \"{config_path}\" doesn't contains a LUFAConfig.h file !");
        }
        // The user header usually lives outside the package: track it so that
        // editing it triggers a rebuild.
        println!("cargo::rerun-if-changed={}", config_header.display());

        cc_builder
            .include(&config_path)
            .define("USE_LUFA_CONFIG_HEADER", None);
        bindings_builder = bindings_builder
            .clang_arg(format!("-I{config_path}"))
            // Give bindgen the same preprocessor view as the C build above.
            .clang_arg("-DUSE_LUFA_CONFIG_HEADER")
            .header("LUFAConfig.h");
    }

    cc_builder.include("lufa");
    cc_builder.include(".");

    cc_builder
        .compiler("avr-gcc")
        .target("avr-none")
        .flag("-Wno-discarded-qualifiers") // Avoid warning in the extern.c file generated by bindgen to wrap static functions
        .flag("-Os")
        .flag("-fshort-enums")
        .flag("-fno-inline-small-functions")
        .flag("-fno-strict-aliasing")
        .flag("-funsigned-char")
        .flag("-funsigned-bitfields")
        .flag("-ffunction-sections")
        .flag("-mrelax")
        .flag("-fno-jump-tables");

    // More files needed in some use cases ?
    cc_builder.file("lufa/LUFA/Drivers/USB/Core/AVR8/Device_AVR8.c");
    cc_builder.file("lufa/LUFA/Drivers/USB/Core/AVR8/Endpoint_AVR8.c");
    cc_builder.file("lufa/LUFA/Drivers/USB/Core/AVR8/EndpointStream_AVR8.c");
    cc_builder.file("lufa/LUFA/Drivers/USB/Core/AVR8/Host_AVR8.c");
    cc_builder.file("lufa/LUFA/Drivers/USB/Core/AVR8/PipeStream_AVR8.c");
    cc_builder.file("lufa/LUFA/Drivers/USB/Core/AVR8/Pipe_AVR8.c");
    cc_builder.file("lufa/LUFA/Drivers/USB/Core/AVR8/USBController_AVR8.c");
    cc_builder.file("lufa/LUFA/Drivers/USB/Core/AVR8/USBInterrupt_AVR8.c");
    cc_builder.file("lufa/LUFA/Drivers/USB/Core/ConfigDescriptors.c");
    cc_builder.file("lufa/LUFA/Drivers/USB/Core/DeviceStandardReq.c");
    cc_builder.file("lufa/LUFA/Drivers/USB/Core/Events.c");
    cc_builder.file("lufa/LUFA/Drivers/USB/Core/HostStandardReq.c");
    cc_builder.file("lufa/LUFA/Drivers/USB/Core/USBTask.c");

    // C wrappers for static functions, written by bindgen (`wrap_static_fns`
    // below) during `generate()`, before `compile()` runs. They go into this
    // build's OUT_DIR instead of bindgen's default shared system temp directory,
    // so concurrent or differently configured builds cannot overwrite each
    // other's file. Bindgen appends the `.c` extension to the given stem.
    let extern_stem = out_path.join("extern");
    cc_builder.file(out_path.join("extern.c"));

    // Generate bindings
    bindings_builder = bindings_builder
        .header("lufa/LUFA/Drivers/USB/USB.h")
        .clang_arg("--target=avr-none")
        .blocklist_type("size_t")
        .allowlist_file(".*lufa.*")
        .derive_default(true)
        .derive_eq(true)
        .derive_ord(true)
        .wrap_static_fns(true)
        .wrap_static_fns_path(&extern_stem)
        .formatter(bindgen::Formatter::Rustfmt)
        .enable_function_attribute_detection()
        .disable_name_namespacing()
        .layout_tests(false)
        .default_enum_style(bindgen::EnumVariation::Rust {
            non_exhaustive: true,
        })
        .rust_target(RustTarget::nightly())
        .parse_callbacks(Box::new(RenameToRust {}))
        .use_core();

    for func in BLOCKLIST_FN {
        bindings_builder = bindings_builder.blocklist_function(func)
    }

    let bindings = bindings_builder
        .generate()
        .expect("Unable to generate bindings");

    // Parse bindings with syn to apply renaming bindgen callbacks cannot do
    let mut ast = syn::parse_str::<syn::File>(&bindings.to_string())
        .expect("bindgen produced bindings that syn cannot parse");

    convert_names(&mut ast);

    // Format the resulting bindings
    let not_formatted = quote! {#ast}.to_string();
    let formatted = rustfmt_generated_string(not_formatted)
        .expect("failed to run rustfmt on the generated bindings");

    // Output the result in a file included by the library
    let mut file = OpenOptions::new()
        .write(true)
        .truncate(true)
        .create(true)
        .open(out_path.join("bindings.rs"))
        .expect("cannot create bindings.rs in OUT_DIR");

    file.write_all(formatted.as_bytes())
        .expect("cannot write bindings.rs in OUT_DIR");

    cc_builder.compile("lufa");

    // Link to the library
    println!("cargo::rustc-link-lib=static=lufa");
}

/// Bindgen parse callbacks that map C names onto Rust naming conventions
/// (types, variables and enum variants) and type `HID_KEYBOARD*` macros as `u8`.
///
/// Stateless: all behavior lives in its `ParseCallbacks` implementation.
#[derive(Clone, Debug)]
struct RenameToRust;

impl ParseCallbacks for RenameToRust {
    /// Renames C types and variables to Rust conventions.
    ///
    /// Types lose one trailing `_t` and become `UpperCamelCase`; variables
    /// become `SHOUTY_SNAKE_CASE`. Modules, functions and any other item kind
    /// keep their original name (`None`).
    fn item_name(&self, item_info: bindgen::callbacks::ItemInfo) -> Option<String> {
        let original: &str = &item_info.name;
        match item_info.kind {
            bindgen::callbacks::ItemKind::Type => {
                // Stable equivalent of the nightly `trim_suffix`: strip a single
                // trailing `_t` when present.
                let name = original.strip_suffix("_t").unwrap_or(original);
                Some(name.to_upper_camel_case())
            }
            bindgen::callbacks::ItemKind::Var => Some(original.to_shouty_snake_case()),
            // Modules and functions keep their C names. Any other kind (e.g. one
            // added by a newer bindgen) is left unchanged as well instead of
            // aborting the whole build with `unreachable!()`.
            _ => None,
        }
    }

    /// Converts enum variants to `UpperCamelCase` and removes every `Dtype`
    /// fragment from the result (e.g. one produced by a `DTYPE_` prefix).
    fn enum_variant_name(
        &self,
        _enum_name: Option<&str>,
        original_variant_name: &str,
        _variant_value: bindgen::callbacks::EnumVariantValue,
    ) -> Option<String> {
        Some(
            original_variant_name
                .to_upper_camel_case()
                .replace("Dtype", ""),
        )
    }

    /// Types `HID_KEYBOARD*` integer macros as `u8`; every other macro keeps
    /// bindgen's default integer type.
    ///
    /// # Panics
    /// Panics if a `HID_KEYBOARD*` macro's value does not fit in a `u8`.
    fn int_macro(&self, name: &str, value: i64) -> Option<bindgen::callbacks::IntKind> {
        if !name.starts_with("HID_KEYBOARD") {
            return None;
        }
        assert!(
            u8::try_from(value).is_ok(),
            "macro {} = {} does not fit in a u8",
            name,
            value
        );
        Some(bindgen::callbacks::IntKind::U8)
    }
}

/// Renames identifiers that bindgen callbacks cannot reach — the named fields
/// of top-level structs — to `snake_case`.
///
/// Fields whose name contains `_bitfield` (bindgen's bitfield storage) are
/// left untouched, and a field that would become the keyword `type` is emitted
/// as the raw identifier `r#type`.
pub fn convert_names(ast: &mut syn::File) {
    let structs = ast.items.iter_mut().filter_map(|item| match item {
        syn::Item::Struct(structure) => Some(structure),
        _ => None,
    });

    for structure in structs {
        let named = structure
            .fields
            .iter_mut()
            .filter_map(|field| field.ident.as_mut());
        for ident in named {
            let current = ident.to_string();
            if current.contains("_bitfield") {
                continue;
            }
            let snake = current.to_snake_case();
            let renamed = if snake == "type" {
                "r#type".to_string()
            } else {
                snake
            };
            *ident = format_ident!("{renamed}");
        }
    }
}

/// Gets the rustfmt path to rustfmt the generated bindings.
fn rustfmt_path() -> io::Result<PathBuf> {
    if let Ok(rustfmt) = env::var("RUSTFMT") {
        return Ok(rustfmt.into());
    }
    match which::which("rustfmt") {
        Ok(p) => Ok(p),
        Err(e) => Err(io::Error::other(format!("{e}"))),
    }
}

/// Pipes `source` through `rustfmt` (located by `rustfmt_path`) and returns
/// the formatted code.
///
/// Exit status 0 yields the formatted output. Status 3 (rustfmt could not
/// format some lines) prints a note to stdout and still yields the output.
/// If rustfmt's output is not valid UTF-8, the unformatted `source` is
/// returned instead, whatever the exit status.
///
/// # Errors
/// Fails if rustfmt cannot be located or spawned, if reading its output
/// fails, or if it exits with status 2 (parse errors) or with any other
/// status, including termination by a signal.
fn rustfmt_generated_string(source: String) -> io::Result<String> {
    let mut child = Command::new(rustfmt_path()?)
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .spawn()?;
    let mut to_rustfmt = child.stdin.take().unwrap();
    let mut from_rustfmt = child.stdout.take().unwrap();

    // Feed stdin from a helper thread while this thread drains stdout, so that
    // neither pipe can fill up and deadlock the exchange. The thread hands the
    // source back for the non-UTF-8 fallback below; dropping the stdin handle
    // when it finishes signals end-of-input to rustfmt.
    let writer = std::thread::spawn(move || {
        // Best effort: a write failure surfaces through rustfmt's exit status.
        let _ = to_rustfmt.write_all(source.as_bytes());
        source
    });

    let mut raw_output = Vec::new();
    io::copy(&mut from_rustfmt, &mut raw_output)?;

    let status = child.wait()?;
    let source = writer
        .join()
        .expect("The thread writing to rustfmt's stdin doesn't do anything that could panic");

    let Ok(formatted) = String::from_utf8(raw_output) else {
        return Ok(source);
    };

    match status.code() {
        Some(0) => Ok(formatted),
        Some(2) => Err(io::Error::other("Rustfmt parsing errors.")),
        Some(3) => {
            println!("Rustfmt could not format some lines.");
            Ok(formatted)
        }
        _ => Err(io::Error::other("Internal rustfmt error")),
    }
}