llam 0.1.0

Safe, Go-style Rust bindings for the LLAM runtime
use std::env;
use std::ffi::OsStr;
use std::path::{Path, PathBuf};
use std::process::Command;

/// LLAM release installed when neither `LLAM_SYS_INSTALL_VERSION` nor `LLAM_INSTALL_VERSION` is set.
const DEFAULT_LLAM_VERSION: &str = "1.0.0";
/// Release-download root; `/<version>` is appended to form the asset base URL and installer URLs.
const DEFAULT_BASE_URL: &str = "https://github.com/Feralthedogg/LLAM/releases/download";

fn main() {
    let manifest_dir =
        PathBuf::from(env::var_os("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR missing"));
    let repo_root = manifest_dir
        .parent()
        .expect("llam must live under the LLAM-rs repository root")
        .to_path_buf();
    let out_dir = PathBuf::from(env::var_os("OUT_DIR").expect("OUT_DIR missing"));

    print_rerun_directives();

    let lib_name = env::var("LLAM_SYS_LIB_NAME").unwrap_or_else(|_| "llam_runtime".to_string());
    let link_kind = env::var("LLAM_SYS_LINK_KIND").unwrap_or_else(|_| "static".to_string());
    let sdk = resolve_sdk(&repo_root, &out_dir);

    println!("cargo:root={}", sdk.prefix.display());
    println!("cargo:include={}", sdk.include_dir.display());
    println!("cargo:lib={}", sdk.lib_dir.display());
    println!("cargo:rustc-link-search=native={}", sdk.lib_dir.display());
    if let Some(bin_dir) = sdk.bin_dir.as_ref() {
        println!("cargo:rustc-link-search=native={}", bin_dir.display());
    }
    println!("cargo:rustc-link-lib={link_kind}={lib_name}");
    link_platform_libraries();
}

/// Asks Cargo to rerun this build script whenever any LLAM configuration variable changes.
fn print_rerun_directives() {
    const TRACKED_ENV_VARS: &[&str] = &[
        "LLAM_SYS_PREFIX",
        "LLAM_SYS_LIB_DIR",
        "LLAM_SYS_INCLUDE_DIR",
        "LLAM_SYS_LINK_KIND",
        "LLAM_SYS_LIB_NAME",
        "LLAM_SYS_NO_BUILD",
        "LLAM_SYS_NO_INSTALL",
        "LLAM_SYS_INSTALL_PREFIX",
        "LLAM_SYS_INSTALL_VERSION",
        "LLAM_SYS_INSTALL_TARGET",
        "LLAM_SYS_INSTALL_BASE_URL",
        "LLAM_SYS_INSTALL_SCRIPT",
        "LLAM_SYS_FORCE_INSTALL",
        "LLAM_SYS_ROOT",
        "LLAM_INSTALL_VERSION",
    ];

    TRACKED_ENV_VARS
        .iter()
        .for_each(|key| println!("cargo:rerun-if-env-changed={key}"));
}

/// Locations of a resolved LLAM SDK, reported to Cargo by `main`.
struct Sdk {
    /// SDK root directory; emitted as `cargo:root`.
    prefix: PathBuf,
    /// Header directory (expected to contain `llam/runtime.h`); emitted as `cargo:include`.
    include_dir: PathBuf,
    /// Library directory; emitted as `cargo:lib` and added as a native link search path.
    lib_dir: PathBuf,
    /// Optional `bin/` directory; when present it is also added as a native link search path.
    bin_dir: Option<PathBuf>,
}

/// Resolves the LLAM SDK to build against, installing one if necessary.
///
/// Resolution order:
/// 1. `LLAM_SYS_LIB_DIR` set: use that library directory directly (see `sdk_from_lib_dir`).
/// 2. `LLAM_SYS_PREFIX` set: use the installed SDK at that prefix.
/// 3. Otherwise install into `LLAM_SYS_INSTALL_PREFIX` (default `OUT_DIR/llam-sdk`), reusing an
///    existing installation there unless `LLAM_SYS_FORCE_INSTALL` is set.
///
/// # Panics
/// - if the selected include directory lacks `llam/runtime.h`, or a prefix lacks the SDK layout;
/// - if `LLAM_SYS_NO_INSTALL` or `LLAM_SYS_NO_BUILD` is set but no explicit SDK location is;
/// - if installation fails.
fn resolve_sdk(repo_root: &Path, out_dir: &Path) -> Sdk {
    if let Some(lib_dir) = env::var_os("LLAM_SYS_LIB_DIR").map(PathBuf::from) {
        return sdk_from_lib_dir(lib_dir);
    }

    if let Some(prefix) = env::var_os("LLAM_SYS_PREFIX").map(PathBuf::from) {
        assert_installed_sdk(&prefix);
        return installed_sdk(prefix);
    }

    // The opt-outs are only satisfiable via an explicit SDK location, handled above.
    assert!(
        env::var_os("LLAM_SYS_NO_INSTALL").is_none() && env::var_os("LLAM_SYS_NO_BUILD").is_none(),
        "LLAM_SYS_NO_INSTALL/LLAM_SYS_NO_BUILD was set, but no installed SDK was provided; set LLAM_SYS_PREFIX or LLAM_SYS_LIB_DIR"
    );

    let prefix = env::var_os("LLAM_SYS_INSTALL_PREFIX")
        .map(PathBuf::from)
        .unwrap_or_else(|| out_dir.join("llam-sdk"));

    if env::var_os("LLAM_SYS_FORCE_INSTALL").is_some() || !is_installed_sdk(&prefix) {
        install_sdk(repo_root, out_dir, &prefix);
    }
    // Re-validate: an installer that exits successfully must still have produced the layout.
    assert_installed_sdk(&prefix);
    installed_sdk(prefix)
}

/// Builds an `Sdk` from an explicit `LLAM_SYS_LIB_DIR`.
///
/// Headers come from `LLAM_SYS_INCLUDE_DIR`, else `LLAM_SYS_PREFIX/include`, else an `include`
/// directory beside `lib_dir`. The reported prefix is `LLAM_SYS_PREFIX` when set (so a layout such
/// as `/usr` + `/usr/lib/<multiarch>` reports `/usr`), otherwise the parent of `lib_dir`.
/// No `bin` directory is reported in this mode.
///
/// # Panics
/// If the chosen include directory does not contain `llam/runtime.h`.
fn sdk_from_lib_dir(lib_dir: PathBuf) -> Sdk {
    let explicit_prefix = env::var_os("LLAM_SYS_PREFIX").map(PathBuf::from);
    let lib_parent = lib_dir.parent().unwrap_or(Path::new("")).to_path_buf();

    let include_dir = env::var_os("LLAM_SYS_INCLUDE_DIR")
        .map(PathBuf::from)
        .or_else(|| explicit_prefix.as_ref().map(|prefix| prefix.join("include")))
        .unwrap_or_else(|| lib_parent.join("include"));
    assert!(
        include_dir.join("llam/runtime.h").is_file(),
        "LLAM_SYS_INCLUDE_DIR or LLAM_SYS_PREFIX must point at an installed LLAM include directory (llam/runtime.h not found under {})",
        include_dir.display()
    );

    Sdk {
        prefix: explicit_prefix.unwrap_or(lib_parent),
        include_dir,
        lib_dir,
        bin_dir: None,
    }
}

/// Describes a conventionally laid-out SDK rooted at `prefix` (`include/`, `lib/`, optional `bin/`).
///
/// `bin_dir` is only populated when `prefix/bin` exists as a directory.
fn installed_sdk(prefix: PathBuf) -> Sdk {
    let bin_candidate = prefix.join("bin");
    let bin_dir = if bin_candidate.is_dir() {
        Some(bin_candidate)
    } else {
        None
    };
    // Field expressions evaluate in source order, so the joins run before `prefix` is moved.
    Sdk {
        include_dir: prefix.join("include"),
        lib_dir: prefix.join("lib"),
        bin_dir,
        prefix,
    }
}

/// Returns `true` when `prefix` holds `include/llam/runtime.h` and a `lib/` directory.
fn is_installed_sdk(prefix: &Path) -> bool {
    let header = prefix.join("include").join("llam").join("runtime.h");
    let libs = prefix.join("lib");
    header.is_file() && libs.is_dir()
}

/// Panics with a descriptive message unless `prefix` contains an installed LLAM SDK
/// (as judged by `is_installed_sdk`).
fn assert_installed_sdk(prefix: &Path) {
    if !is_installed_sdk(prefix) {
        panic!(
            "LLAM SDK is not installed at {}; expected include/llam/runtime.h and lib/",
            prefix.display()
        );
    }
}

/// Runs the LLAM installer script to populate `prefix` with an SDK.
///
/// - Version: `LLAM_SYS_INSTALL_VERSION`, then `LLAM_INSTALL_VERSION`, then `DEFAULT_LLAM_VERSION`.
/// - Asset base URL: `LLAM_SYS_INSTALL_BASE_URL`, default `DEFAULT_BASE_URL/<version>`.
/// - Target name (e.g. `linux-x86_64`): `LLAM_SYS_INSTALL_TARGET`, default `default_install_target()`.
///
/// Windows hosts run the PowerShell installer; other hosts run the POSIX installer via `sh`.
/// The installer is always invoked with its force flag.
///
/// # Panics
/// - if a `windows-*` target is requested from a non-Windows host (checked before any installer
///   script is located or downloaded);
/// - if the installer cannot be spawned or exits unsuccessfully.
fn install_sdk(repo_root: &Path, out_dir: &Path, prefix: &Path) {
    let version = env::var("LLAM_SYS_INSTALL_VERSION")
        .or_else(|_| env::var("LLAM_INSTALL_VERSION"))
        .unwrap_or_else(|_| DEFAULT_LLAM_VERSION.to_string());
    let base_url = env::var("LLAM_SYS_INSTALL_BASE_URL")
        .unwrap_or_else(|_| format!("{DEFAULT_BASE_URL}/{version}"));
    let target = env::var("LLAM_SYS_INSTALL_TARGET").unwrap_or_else(|_| default_install_target());

    // `cfg!(windows)` describes the host running this build script. Reject the unsupported
    // combination first so we never download an installer we are about to refuse to run.
    if target.starts_with("windows-") && !cfg!(windows) {
        panic!(
            "automatic Windows SDK installation is only supported on Windows hosts; set LLAM_SYS_PREFIX or LLAM_SYS_LIB_DIR for cross builds"
        );
    }

    let script = installer_script(repo_root, out_dir, &version);

    if cfg!(windows) {
        let status = Command::new("powershell")
            .arg("-NoProfile")
            .arg("-ExecutionPolicy")
            .arg("Bypass")
            .arg("-File")
            .arg(&script)
            .arg("-Prefix")
            .arg(prefix)
            .arg("-Version")
            .arg(&version)
            .arg("-Target")
            .arg(&target)
            .arg("-BaseUrl")
            .arg(&base_url)
            .arg("-Force")
            .status()
            .expect("failed to run LLAM PowerShell installer");
        assert!(status.success(), "LLAM PowerShell installer failed ({status})");
    } else {
        let status = Command::new("sh")
            .arg(&script)
            .arg("--prefix")
            .arg(prefix)
            .arg("--version")
            .arg(&version)
            .arg("--target")
            .arg(&target)
            .arg("--base-url")
            .arg(&base_url)
            .arg("--force")
            .status()
            .expect("failed to run LLAM installer");
        assert!(status.success(), "LLAM installer failed ({status})");
    }
}

/// Returns the path of the installer script to run, downloading it into `out_dir` if needed.
///
/// Lookup order:
/// 1. `LLAM_SYS_INSTALL_SCRIPT`: an `http(s)://` URL is downloaded; anything else is used as a
///    local path.
/// 2. `scripts/install.ps1` (Windows hosts) or `scripts/install.sh` from a local LLAM checkout
///    found by `find_llam_root`.
/// 3. The installer published under `DEFAULT_BASE_URL/<version>/`.
///
/// Local scripts that exist are registered with `cargo:rerun-if-changed`, so editing them
/// retriggers the build.
fn installer_script(repo_root: &Path, out_dir: &Path, version: &str) -> PathBuf {
    if let Some(path) = env::var_os("LLAM_SYS_INSTALL_SCRIPT").map(PathBuf::from) {
        if is_url(path.as_os_str()) {
            return download_installer(path.to_string_lossy().as_ref(), out_dir);
        }
        // Track a user-supplied script the same way as the checkout script below. Only existing
        // files are registered: Cargo treats a missing path as always-changed.
        if path.is_file() {
            println!("cargo:rerun-if-changed={}", path.display());
        }
        return path;
    }

    if let Some(root) = find_llam_root(repo_root) {
        let local = if cfg!(windows) {
            root.join("scripts/install.ps1")
        } else {
            root.join("scripts/install.sh")
        };
        if local.is_file() {
            println!("cargo:rerun-if-changed={}", local.display());
            return local;
        }
    }

    let url = if cfg!(windows) {
        format!("{DEFAULT_BASE_URL}/{version}/install.ps1")
    } else {
        format!("{DEFAULT_BASE_URL}/{version}/install.sh")
    };
    download_installer(&url, out_dir)
}

/// Returns `true` if `path` (lossily decoded) starts with a lowercase `https://` or `http://` scheme.
fn is_url(path: &OsStr) -> bool {
    let text = path.to_string_lossy();
    ["https://", "http://"]
        .iter()
        .any(|scheme| text.starts_with(*scheme))
}

/// Downloads the installer at `url` into `out_dir` and returns the saved path.
///
/// The file is named `install.ps1` on Windows hosts and `install.sh` elsewhere.
fn download_installer(url: &str, out_dir: &Path) -> PathBuf {
    let file_name = if cfg!(windows) { "install.ps1" } else { "install.sh" };
    let destination = out_dir.join(file_name);
    download_file(url, &destination);
    destination
}

/// Downloads `url` to the file `out`.
///
/// Windows hosts use PowerShell's `Invoke-WebRequest`; other hosts use `curl` when available,
/// otherwise `wget`.
///
/// # Panics
/// If no download tool can be run, or the download exits unsuccessfully.
fn download_file(url: &str, out: &Path) {
    let status = if cfg!(windows) {
        // - `$ErrorActionPreference = 'Stop'` turns any request error into a non-zero exit code.
        // - Hiding the progress bar avoids Invoke-WebRequest's very slow progress rendering.
        // - Windows PowerShell 5.1 may not enable TLS 1.2 by default, which GitHub requires.
        // - `-UseBasicParsing` avoids the legacy Internet Explorer engine; newer PowerShell ignores it.
        let script = format!(
            "$ErrorActionPreference = 'Stop'; \
             $ProgressPreference = 'SilentlyContinue'; \
             [Net.ServicePointManager]::SecurityProtocol = [Net.ServicePointManager]::SecurityProtocol -bor [Net.SecurityProtocolType]::Tls12; \
             Invoke-WebRequest -UseBasicParsing -Uri '{}' -OutFile '{}'",
            escape_powershell(url),
            escape_powershell(&out.display().to_string())
        );
        Command::new("powershell")
            .arg("-NoProfile")
            .arg("-Command")
            .arg(script)
            .status()
            .expect("failed to run PowerShell download")
    } else if command_exists("curl") {
        Command::new("curl")
            .arg("-fsSL")
            .arg(url)
            .arg("-o")
            .arg(out)
            .status()
            .expect("failed to run curl")
    } else if command_exists("wget") {
        Command::new("wget")
            .arg("-O")
            .arg(out)
            .arg(url)
            .status()
            .expect("failed to run wget")
    } else {
        panic!(
            "neither curl nor wget is available to download {url}; install one of them or set LLAM_SYS_PREFIX or LLAM_SYS_LIB_DIR"
        );
    };
    assert!(status.success(), "failed to download {url}");
}

/// Returns `true` if `command --version` can be spawned.
///
/// Only the ability to launch the process matters; its exit code and output are ignored.
fn command_exists(command: &str) -> bool {
    let mut probe = Command::new(command);
    probe
        .arg("--version")
        .stdout(std::process::Stdio::null())
        .stderr(std::process::Stdio::null());
    matches!(probe.status(), Ok(_))
}

/// Escapes `value` for use inside a single-quoted PowerShell string by doubling every `'`.
fn escape_powershell(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if ch == '\'' {
            escaped.push_str("''");
        } else {
            escaped.push(ch);
        }
    }
    escaped
}

/// Computes the installer target name `<os>-<arch>` (e.g. `linux-x86_64`).
///
/// Uses Cargo's `CARGO_CFG_TARGET_OS` / `CARGO_CFG_TARGET_ARCH`, falling back to the host values.
///
/// # Panics
/// If the OS is not macos/linux/windows (checked first), or the arch is not x86_64/aarch64.
fn default_install_target() -> String {
    let target_os =
        env::var("CARGO_CFG_TARGET_OS").unwrap_or_else(|_| env::consts::OS.to_string());
    let target_arch =
        env::var("CARGO_CFG_TARGET_ARCH").unwrap_or_else(|_| env::consts::ARCH.to_string());

    let os_name = match target_os.as_str() {
        supported @ ("macos" | "linux" | "windows") => supported,
        other => panic!("unsupported LLAM target OS for installer: {other}"),
    };
    let arch_name = match target_arch.as_str() {
        supported @ ("x86_64" | "aarch64") => supported,
        other => panic!("unsupported LLAM target arch for installer: {other}"),
    };

    format!("{os_name}-{arch_name}")
}

fn find_llam_root(repo_root: &Path) -> Option<PathBuf> {
    let parent = repo_root.parent();
    let candidates = [
        parent.map(|path| path.join("LLAM")),
        parent.map(|path| path.join("llam")),
    ];

    candidates
        .into_iter()
        .flatten()
        .find(|candidate| is_llam_root(candidate))
}

/// Returns `true` if `path` looks like an LLAM checkout: it must contain the runtime header and
/// both installer scripts.
fn is_llam_root(path: &Path) -> bool {
    const REQUIRED_FILES: [&str; 3] = [
        "include/llam/runtime.h",
        "scripts/install.sh",
        "scripts/install.ps1",
    ];
    REQUIRED_FILES
        .iter()
        .all(|relative| path.join(relative).is_file())
}

/// Emits link directives for the system libraries the LLAM runtime needs on the target OS.
///
/// Linux: `uring`, `m`, `pthread`. Windows: `ws2_32`, `mswsock`, `synchronization`.
/// No extra libraries are emitted for any other target.
fn link_platform_libraries() {
    let target_os = env::var("CARGO_CFG_TARGET_OS");
    let system_libs: &[&str] = match target_os.as_deref() {
        Ok("linux") => &["uring", "m", "pthread"],
        Ok("windows") => &["ws2_32", "mswsock", "synchronization"],
        _ => &[],
    };
    for lib in system_libs {
        println!("cargo:rustc-link-lib={lib}");
    }
}