hpt-cudakernels 0.0.13

A library that implements CUDA kernels for hpt
Documentation
use std::collections::HashMap;
use std::io::Read;
use std::io::Write;
use std::path::Path;
use std::path::PathBuf;
use std::process::Command;

use regex::Regex;

/// Build-script entry point.
///
/// When the `cuda` feature is enabled, every `.cu` file directly under `src`
/// is compiled to PTX once per detected compute capability, and
/// `$OUT_DIR/generated_constants.rs` is generated. For each PTX module
/// `<STEM>_<CAP>` the generated file contains:
///
/// * `<NAME>`: the PTX text (via `include_str!`),
/// * `<NAME>_REG_INFO`: a `phf::Map` from kernel name to `RegisterInfo`,
/// * `<NAME>_FUNC_LIST`: the kernel names found in that PTX,
///
/// plus one `<STEM>` map per `.cu` file, keyed by compute capability, that
/// ties the three together.
///
/// # Panics
///
/// Panics if `OUT_DIR` is unset, if a `.cu` file fails to compile, or if any
/// of the involved files cannot be read or written.
fn main() {
    if !cfg!(feature = "cuda") {
        return;
    }
    let caps = compute_cap();
    let cu_files = find_cu_files(Path::new("src")).expect("find cu files");
    for cu_file in &cu_files {
        println!("cargo:rerun-if-changed={}", cu_file.display());
    }

    let out_dir = PathBuf::from(std::env::var("OUT_DIR").expect("OUT_DIR not set"));
    // `create_dir_all` succeeds when the directory already exists.
    std::fs::create_dir_all(&out_dir).expect("create OUT_DIR");
    let generated_constants = out_dir.join("generated_constants.rs");

    let mut buffer = String::new();
    for cu_file in &cu_files {
        let ptx_names = compile_cu(cu_file, &out_dir, &caps).expect("compile failed");

        // `compile_cu` returns one module name per capability, in `caps` order.
        // NOTE(review): the keys are hashed here as `u32`, while the generated
        // map below is declared with `usize` keys — confirm both hash
        // identically in phf, otherwise lookups in multi-entry maps may miss.
        let mut cap_map = phf_codegen::Map::new();
        for (cap, name) in caps.iter().zip(ptx_names) {
            let name_upper_case = name.to_uppercase();
            cap_map.entry(
                *cap,
                &format!("({0}, &{0}_REG_INFO, &{0}_FUNC_LIST)", name_upper_case),
            );

            let content = std::fs::read_to_string(out_dir.join(format!("{}.ptx", name)))
                .expect("read ptx file");

            // Sort by kernel name so the generated source does not depend on
            // `HashMap` iteration order.
            let mut reg_info: Vec<_> = count_registers(&content).into_iter().collect();
            reg_info.sort_unstable_by(|a, b| a.0.cmp(&b.0));

            let mut reg_map = phf_codegen::Map::new();
            let mut func_list = Vec::with_capacity(reg_info.len());
            for (func, (pred, b16, b32, b64)) in reg_info {
                func_list.push(format!("\"{}\"", func));
                reg_map.entry(
                    func,
                    &format!(
                        "RegisterInfo {{ pred: {}, b16: {}, b32: {}, b64: {} }}",
                        pred, b16, b32, b64
                    ),
                );
            }

            buffer.push_str(&format!(
                "pub const {}_REG_INFO: phf::Map<&'static str, RegisterInfo> = {};\n",
                &name_upper_case,
                reg_map.build()
            ));
            buffer.push_str(&format!(
                "pub const {}: &str = include_str!(concat!(env!(\"OUT_DIR\"), \"/{}.ptx\"));\n",
                &name_upper_case, name
            ));
            buffer.push_str(&format!(
                "pub const {}_FUNC_LIST: [&str; {}] = [{}];\n",
                &name_upper_case,
                func_list.len(),
                func_list.join(", ")
            ));
        }

        let file_stem = cu_file
            .file_stem()
            .and_then(|s| s.to_str())
            .unwrap_or_else(|| panic!("Invalid file stem for {:?}", cu_file));
        buffer.push_str(&format!(
            "pub const {}: phf::Map<usize, (&'static str, &'static phf::Map<&'static str, RegisterInfo>, &'static [&str])> = {};",
            file_stem.to_uppercase(),
            cap_map.build()
        ));
    }

    // Written once, after all files are processed, so the file also exists
    // (empty) when `src` contains no `.cu` files.
    write_if_changed(&generated_constants, &buffer);
}

/// Writes `contents` to `path` unless the file already holds exactly that text.
///
/// A missing or unreadable file is treated as "changed" and (re)created.
fn write_if_changed(path: &Path, contents: &str) {
    let mut existing = String::new();
    let unchanged = std::fs::File::open(path)
        .and_then(|mut file| file.read_to_string(&mut existing))
        .is_ok()
        && existing == contents;
    if unchanged {
        return;
    }
    let mut file = std::fs::File::create(path).expect("create generated.rs");
    file.write_all(contents.as_bytes())
        .expect("write generated.rs");
}

pub(crate) fn count_registers(ptx: &str) -> HashMap<String, (usize, usize, usize, usize)> {
    let mut reg_counts = HashMap::new();

    // 匹配函数名和寄存器声明
    let func_re = Regex::new(r"\.visible \.entry (\w+)\(").unwrap();
    let pred_re = Regex::new(r"\.reg \.pred\s+%p<(\d+)>").unwrap();
    let b16_re = Regex::new(r"\.reg \.b16\s+%rs<(\d+)>").unwrap();
    let b32_re = Regex::new(r"\.reg \.b32\s+%r<(\d+)>").unwrap();
    let b64_re = Regex::new(r"\.reg \.b64\s+%rd<(\d+)>").unwrap();
    let mut current_func = String::new();

    for line in ptx.lines() {
        if let Some(cap) = func_re.captures(line) {
            current_func = cap[1].to_string();
        } else if !current_func.is_empty() {
            if let Some(cap) = pred_re.captures(line) {
                let pred_count = cap[1].parse::<usize>().unwrap();
                reg_counts
                    .entry(current_func.clone())
                    .or_insert((0, 0, 0, 0))
                    .0 = pred_count;
            }
            if let Some(cap) = b16_re.captures(line) {
                let b16_count = cap[1].parse::<usize>().unwrap();
                reg_counts
                    .entry(current_func.clone())
                    .or_insert((0, 0, 0, 0))
                    .1 = b16_count;
            }
            if let Some(cap) = b32_re.captures(line) {
                let b32_count = cap[1].parse::<usize>().unwrap();
                reg_counts
                    .entry(current_func.clone())
                    .or_insert((0, 0, 0, 0))
                    .2 = b32_count;
            }
            if let Some(cap) = b64_re.captures(line) {
                let b64_count = cap[1].parse::<usize>().unwrap();
                reg_counts
                    .entry(current_func.clone())
                    .or_insert((0, 0, 0, 0))
                    .3 = b64_count;
            }
        }
    }

    reg_counts
}

/// Lists the files directly inside `dir` whose extension is exactly `cu`.
///
/// The search is not recursive, and the order of the returned paths is
/// whatever `read_dir` yields.
///
/// # Errors
///
/// Returns the first I/O error hit while opening or iterating the directory.
fn find_cu_files(dir: &Path) -> Result<Vec<PathBuf>, std::io::Error> {
    std::fs::read_dir(dir)?
        .filter_map(|item| match item {
            Ok(dir_entry) => {
                let candidate = dir_entry.path();
                let is_cuda_source =
                    matches!(candidate.extension().and_then(|ext| ext.to_str()), Some("cu"));
                if is_cuda_source {
                    Some(Ok(candidate))
                } else {
                    None
                }
            }
            // Keep errors in the stream so `collect` stops at the first one.
            Err(err) => Some(Err(err)),
        })
        .collect()
}

/// Compiles one CUDA source file to PTX, once per compute capability.
///
/// For every `cap` in `caps`, runs
/// `nvcc -ptx -O3 -allow-unsupported-compiler <cu_file> -o <out_dir>/<stem>_<cap>.ptx -arch=sm_<cap>`.
///
/// Returns the module names (`<stem>_<cap>`, without the `.ptx` extension) in
/// the same order as `caps`.
///
/// # Errors
///
/// Returns a message if `cu_file` has no UTF-8 file stem, if `nvcc` cannot be
/// started, or if `nvcc` exits unsuccessfully.
fn compile_cu(cu_file: &Path, out_dir: &Path, caps: &[u32]) -> Result<Vec<String>, String> {
    // The stem does not depend on the capability, so validate it once.
    let file_stem = cu_file
        .file_stem()
        .and_then(|s| s.to_str())
        .ok_or_else(|| format!("Invalid file stem for {:?}", cu_file))?;

    let mut ptx_names = Vec::with_capacity(caps.len());
    for cap in caps {
        let name = format!("{}_{}", file_stem, cap);
        let ptx_file = out_dir.join(format!("{}.ptx", name));

        // Paths are passed as OS strings, so non-UTF-8 paths do not panic.
        let status = Command::new("nvcc")
            .arg("-ptx")
            .arg("-O3")
            .arg("-allow-unsupported-compiler")
            .arg(cu_file)
            .arg("-o")
            .arg(&ptx_file)
            .arg(format!("-arch=sm_{}", cap))
            .status()
            .map_err(|e| format!("Failed to execute nvcc: {}", e))?;

        if !status.success() {
            return Err(format!(
                "nvcc failed to compile {:?} with status: {}",
                cu_file, status
            ));
        }
        ptx_names.push(name);
    }

    Ok(ptx_names)
}

/// Queries `nvidia-smi` for the compute capabilities of the installed GPUs.
///
/// Runs `nvidia-smi --query-gpu=compute_cap --format=csv` and returns each
/// distinct capability with the dot removed (`8.6` becomes `86`), in
/// first-seen order.
///
/// # Panics
///
/// Panics if `nvidia-smi` cannot be executed, if its output is not UTF-8, or
/// if the output does not have the expected shape.
fn compute_cap() -> Vec<u32> {
    let out = std::process::Command::new("nvidia-smi")
        .arg("--query-gpu=compute_cap")
        .arg("--format=csv")
        .output()
        .expect("cannot find `nvidia-smi` in PATH.");
    let stdout = std::str::from_utf8(&out.stdout).expect("stdout is not a utf8 string");
    parse_compute_caps(stdout)
}

/// Parses the CSV text printed by `nvidia-smi --query-gpu=compute_cap --format=csv`.
///
/// The first line must be the `compute_cap` header. Blank lines are skipped
/// and repeated capabilities are reported once, so machines with several
/// identical GPUs do not produce duplicate entries.
fn parse_compute_caps(csv: &str) -> Vec<u32> {
    let mut lines = csv.lines();
    assert_eq!(
        lines.next().expect("missing line in stdout").trim(),
        "compute_cap"
    );

    let mut caps = Vec::new();
    for line in lines.map(str::trim).filter(|line| !line.is_empty()) {
        let cap = line
            .replace('.', "")
            .parse::<u32>()
            .unwrap_or_else(|e| panic!("unexpected compute capability {:?}: {}", line, e));
        // The list is tiny (one entry per GPU), so a linear scan is fine.
        if !caps.contains(&cap) {
            caps.push(cap);
        }
    }
    caps
}