#![allow(unused)]
mod buildtime_downloader;
use buildtime_downloader::download_model;
/// Describes one set of CUDA kernel sources and where its generated output goes.
struct KernelDirectories {
/// Glob pattern selecting the kernel source files handed to the kernel builder.
kernel_glob: &'static str,
/// Path naming the generated Rust file; the build script keeps only its final
/// file-name component and places that inside `OUT_DIR`.
rust_target: &'static str,
}
/// Kernel sets processed by the build script when the `cuda` feature is enabled.
/// Currently holds a single entry, for the `custom-ops` example.
const KERNEL_DIRS: [KernelDirectories; 1] = [KernelDirectories {
kernel_glob: "examples/custom-ops/kernels/*.cu",
rust_target: "examples/custom-ops/cuda_kernels.rs",
}];
/// Builds PTX for every entry of `KERNEL_DIRS` and writes the resulting
/// bindings into Cargo's `OUT_DIR`.
///
/// # Panics
///
/// Panics if `OUT_DIR` is not set, if a `rust_target` has no file name, or if
/// building or writing the PTX bindings fails.
#[cfg(feature = "cuda")]
fn compile_cuda_kernels() {
    use std::env;
    use std::path::{Path, PathBuf};

    let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
    for entry in &KERNEL_DIRS {
        // Only the file name of `rust_target` is kept, so the output always
        // lands directly inside `OUT_DIR` rather than in the source tree.
        let target_name = Path::new(entry.rust_target)
            .file_name()
            .expect("Failed to get filename from rust_target");
        let destination = out_dir.join(target_name);

        let ptx_bindings = cudaforge::KernelBuilder::new()
            .source_glob(entry.kernel_glob)
            .build_ptx()
            .expect("Failed to build ptx");
        ptx_bindings
            .write(destination)
            .expect("Failed to write ptx bindings");
    }
}

/// Build-script entry point.
///
/// Tells Cargo to re-run the script when `build.rs` changes, compiles the CUDA
/// kernels when the `cuda` feature is enabled, and downloads a model when
/// `CANDLE_BUILDTIME_MODEL_REVISION` was set while this script was compiled.
fn main() {
    println!("cargo::rerun-if-changed=build.rs");

    #[cfg(feature = "cuda")]
    compile_cuda_kernels();

    // `option_env!` is resolved when the build script itself is compiled, not
    // when it runs.
    if let Some(revision) = core::option_env!("CANDLE_BUILDTIME_MODEL_REVISION") {
        buildtime_downloader::download_model(revision).expect("Model download failed!");
    }
}