// sigmd 0.1.0
//
// Windows API signature metadata
// Documentation
use std::{
    collections::HashMap,
    fs::File,
    io::{BufWriter, Write as _},
    path::{Path, PathBuf},
    sync::Arc,
    time::Instant,
};

use anyhow::{Context as _, Error, anyhow};
use clap::Args;
use glob::MatchOptions;
use indicatif::{ProgressBar, ProgressStyle};
use rayon::iter::{IntoParallelRefIterator as _, ParallelIterator as _};
use sigmd::{
    Architecture,
    model::{Database, Metadata},
};

use super::Arch;
use crate::cli::{
    compiler::{Compiler, ScoredFunction, ScoredInterface, ScoredMetadata},
    config::Config,
};

// Command-line arguments accepted by `run`.
//
// NOTE: this header is a plain `//` comment on purpose. The `///` comments
// on the fields below are picked up by clap's derive and shown as `--help`
// text, so editing them changes user-visible output.
#[derive(Args, Debug)]
pub struct BuildArgs {
    /// Architectures to include in the bundle. Repeat to select more
    /// than one.
    ///
    /// Default: both x86 and x64.
    #[arg(short, long, default_values_t = [Arch::X86, Arch::X64])]
    arch: Vec<Arch>,

    /// Output path for the binary archive.
    #[arg(short, long, default_value = "metadata.bin")]
    output: PathBuf,

    /// Optional path for the human-readable JSON dump.
    #[arg(short, long)]
    json: Option<PathBuf>,
}

/// Accumulates function and interface candidates for one architecture,
/// keeping a single highest-scoring entry per name.
#[derive(Default)]
struct MetadataBuilder {
    /// Best candidate seen so far for each function name.
    functions: HashMap<String, ScoredFunction>,
    /// Best candidate seen so far for each interface name.
    interfaces: HashMap<String, ScoredInterface>,
}

impl MetadataBuilder {
    /// Creates a builder that holds no functions or interfaces.
    fn new() -> Self {
        Self::default()
    }

    /// Folds every function and interface of `scored` into this builder,
    /// applying the keep-the-best rule of [`Self::insert_function`] and
    /// [`Self::insert_interface`] to each item.
    fn merge(&mut self, scored: ScoredMetadata) {
        for scored_function in scored.functions {
            self.insert_function(scored_function);
        }

        for scored_interface in scored.interfaces {
            self.insert_interface(scored_interface);
        }
    }

    /// Stores `new` unless an entry with the same function name already
    /// exists with an equal or higher score (ties keep the earlier entry).
    fn insert_function(&mut self, new: ScoredFunction) {
        if let Some(old) = self.functions.get(&new.function.name)
            && new.score <= old.score
        {
            return;
        }

        // The key is cloned only on the insert path, so discarded
        // candidates cost a lookup but no allocation.
        self.functions.insert(new.function.name.clone(), new);
    }

    /// Stores `new` unless an entry with the same interface name already
    /// exists with an equal or higher score (ties keep the earlier entry).
    fn insert_interface(&mut self, new: ScoredInterface) {
        if let Some(old) = self.interfaces.get(&new.interface.name)
            && new.score <= old.score
        {
            return;
        }

        self.interfaces.insert(new.interface.name.clone(), new);
    }

    /// Removes every function and interface whose name matches one of the
    /// ignore patterns in `db`.
    ///
    /// Supported pattern forms (matching is case-sensitive):
    ///
    /// * `*text*` - name contains `text`
    /// * `*text`  - name ends with `text`
    /// * `text*`  - name starts with `text`
    /// * `text`   - name equals `text`
    ///
    /// A lone `*` removes everything. A `*` anywhere other than the first
    /// or last position is compared literally.
    fn apply_ignore(&mut self, db: &crate::cli::config::Database) {
        /// Drops all keys of `map` that match `pattern`.
        fn remove_matching<V>(map: &mut HashMap<String, V>, pattern: &str) {
            if let Some(rest) = pattern.strip_prefix('*') {
                if let Some(infix) = rest.strip_suffix('*') {
                    // `*text*`: previously this fell into the suffix branch
                    // and was compared against the literal `text*`, so the
                    // pattern silently ignored nothing.
                    map.retain(|key, _| !key.contains(infix));
                }
                else {
                    map.retain(|key, _| !key.ends_with(rest));
                }
            }
            else if let Some(prefix) = pattern.strip_suffix('*') {
                map.retain(|key, _| !key.starts_with(prefix));
            }
            else {
                map.remove(pattern);
            }
        }

        for pattern in &db.ignore_functions {
            remove_matching(&mut self.functions, pattern);
        }

        for pattern in &db.ignore_interfaces {
            remove_matching(&mut self.interfaces, pattern);
        }
    }

    /// Consumes the builder and hands the surviving definitions (scores
    /// stripped) to the `Metadata` builder.
    ///
    /// The values are passed in `HashMap` iteration order, which is
    /// unspecified; any ordering guarantee in the result has to come from
    /// the `Metadata` builder itself.
    fn build(self) -> Metadata {
        Metadata::builder()
            .functions(self.functions.into_values().map(|item| item.function))
            .interfaces(self.interfaces.into_values().map(|item| item.interface))
            .build()
    }
}

/// Holds one [`MetadataBuilder`] per supported architecture.
struct DatabaseBuilder {
    /// Definitions collected for `Architecture::X86`.
    x86: MetadataBuilder,
    /// Definitions collected for `Architecture::X64`.
    x64: MetadataBuilder,
}

impl DatabaseBuilder {
    /// Creates a builder whose per-architecture buckets are both empty.
    fn new() -> Self {
        let x86 = MetadataBuilder::new();
        let x64 = MetadataBuilder::new();

        Self { x86, x64 }
    }

    /// Returns the mutable bucket that collects definitions for `arch`.
    fn bucket(&mut self, arch: Architecture) -> &mut MetadataBuilder {
        match arch {
            Architecture::X86 => &mut self.x86,
            Architecture::X64 => &mut self.x64,
        }
    }

    /// Applies the ignore lists from `db` to both buckets, then converts
    /// them into the final `Database`.
    fn build(self, db: &crate::cli::config::Database) -> Database {
        let Self { mut x86, mut x64 } = self;

        // Filter first so ignored names never reach the output.
        x86.apply_ignore(db);
        x64.apply_ignore(db);

        let x86 = x86.build();
        let x64 = x64.build();

        Database { x86, x64 }
    }
}

/// Expands a glob `pattern` into the list of matching paths.
///
/// Matching is case-insensitive, wildcards may match path separators, and
/// leading dots need not be spelled out literally.
///
/// This function is best-effort and never fails:
///
/// * an invalid pattern is logged at error level and yields an empty list;
/// * an entry that cannot be read while walking the file system is logged
///   at warn level and skipped, instead of being dropped silently.
fn expand_glob(pattern: impl AsRef<str>) -> Vec<PathBuf> {
    let pattern = pattern.as_ref();
    let opts = MatchOptions {
        case_sensitive: false,
        require_literal_separator: false,
        require_literal_leading_dot: false,
    };

    let entries = match glob::glob_with(pattern, opts) {
        Ok(entries) => entries,
        Err(err) => {
            tracing::error!(%err, pattern, "glob expansion failed");
            return Vec::new();
        }
    };

    entries
        .filter_map(|entry| match entry {
            Ok(path) => Some(path),
            Err(err) => {
                // Surface I/O problems (e.g. an unreadable directory) so
                // missing headers can be diagnosed from the log.
                tracing::warn!(%err, pattern, "skipping unreadable glob entry");
                None
            }
        })
        .collect()
}

fn write_binary(database: &Database, path: &Path) -> Result<(), Error> {
    let bytes =
        rkyv::to_bytes::<rkyv::rancor::Error>(database).context("rkyv serialization failed")?;

    let file = File::create(path).with_context(|| format!("create {}", path.display()))?;
    let mut file = BufWriter::new(file);

    file.write_all(&bytes)
        .with_context(|| format!("write {}", path.display()))?;
    file.flush()
        .with_context(|| format!("flush {}", path.display()))?;

    Ok(())
}

fn write_json(database: &Database, path: &Path) -> Result<(), Error> {
    let file = File::create(path).with_context(|| format!("create {}", path.display()))?;
    let mut file = BufWriter::new(file);

    serde_json::to_writer_pretty(&mut file, database)
        .with_context(|| format!("json serialize to {}", path.display()))?;
    file.flush()
        .with_context(|| format!("flush {}", path.display()))?;

    Ok(())
}

pub fn run(config: Config, args: BuildArgs) -> Result<(), Error> {
    let header_paths = config.sdk.iter().flat_map(expand_glob).collect::<Vec<_>>();

    if header_paths.is_empty() {
        return Err(anyhow!("config.sdk did not match any files"));
    }

    tracing::info!(count = header_paths.len(), "expanded SDK headers");

    let custom_types = Arc::new(
        config
            .custom_types
            .iter()
            .flat_map(|entry| {
                entry
                    .matches
                    .iter()
                    .map(move |item| (item.clone(), entry.id))
            })
            .collect::<HashMap<String, u8>>(),
    );

    tracing::info!(count = custom_types.len(), "loaded custom types");

    let mut builder = DatabaseBuilder::new();

    let total = (header_paths.len() * args.arch.len()) as u64;
    let pb = ProgressBar::new(total);
    pb.set_style(
        ProgressStyle::with_template(
            "{spinner} [{elapsed_precise}] [{bar:45.cyan/blue}] {msg} {pos}/{len} ({eta})",
        )
        .unwrap(),
    );

    let started = Instant::now();

    for arch in args.arch {
        tracing::info!(%arch, "starting group");
        pb.set_message(format!("[{arch:>3}]"));

        let scoreds = header_paths
            .par_iter()
            .map_init(
                || Compiler::new().expect("clang compiler init"),
                |compiler, path| {
                    pb.println(format!("[{arch:>3}] compiling {}", path.display()));

                    let custom_types = custom_types.clone();

                    let result = compiler
                        .translation_unit()
                        .path(path)
                        .architecture(arch.into())
                        .include(config.include.clone())
                        .inject(config.inject.clone())
                        .custom_type_fn(move |name| custom_types.get(name).copied())
                        .compile();

                    pb.inc(1);

                    match result {
                        Ok(scored) => Some(scored),
                        Err(err) => {
                            tracing::error!(%err, ?path, "parse failed");
                            None
                        }
                    }
                },
            )
            .flatten()
            .collect::<Vec<ScoredMetadata>>();

        for scored in scoreds {
            builder.bucket(arch.into()).merge(scored);
        }
    }

    pb.finish_and_clear();

    let database = builder.build(&config.database);

    tracing::info!(
        "x86: {} functions, {} interfaces",
        database.x86.functions.len(),
        database.x86.interfaces.len()
    );

    tracing::info!(
        "x64: {} functions, {} interfaces",
        database.x64.functions.len(),
        database.x64.interfaces.len()
    );

    write_binary(&database, &args.output).context("write binary metadata")?;
    tracing::info!("wrote {}", args.output.display());

    if let Some(json) = args.json.as_ref() {
        write_json(&database, json).context("write JSON metadata")?;
        tracing::info!("wrote {}", json.display());
    }

    tracing::info!("done in {:?}", started.elapsed());
    Ok(())
}