sigmd 0.1.0

Windows API signature metadata
Documentation
//! Configuration loader.

use std::{
    collections::{HashMap, HashSet},
    path::Path,
};

use anyhow::{Context as _, Error, anyhow};
use serde::Deserialize;

/// Top-level configuration.
///
/// `sdk`, `include`, and `inject` are required keys (they carry no serde
/// default). `database` and `custom_types` fall back to empty defaults when
/// omitted.
#[derive(Debug, Deserialize)]
pub struct Config {
    /// Glob patterns expanding to SDK header paths.
    pub sdk: Vec<String>,

    /// Directories passed to clang via `-isystem`.
    pub include: Vec<String>,

    /// Header files prepended to every TU via `-include`.
    pub inject: Vec<String>,

    /// Database filter settings.
    #[serde(default)]
    pub database: Database,

    /// User-defined custom type entries.
    ///
    /// Ids must be non-zero and unique. Each `matches` item may appear only
    /// once across all entries. [`load`] rejects configs that break these
    /// rules.
    #[serde(default)]
    pub custom_types: Vec<Extended>,
}

/// Database ignore-list configuration.
#[derive(Debug, Default, Deserialize)]
pub struct Database {
    /// Function names to ignore.
    ///
    /// Each entry is matched verbatim, as a prefix (`Foo*`), or as a suffix
    /// (`*Foo`). A pattern with both a leading and a trailing `*` (e.g.
    /// `*Foo*`) is invalid and rejected by [`load`]. A lone `*` is not
    /// rejected.
    #[serde(default)]
    pub ignore_functions: Vec<String>,

    /// Same syntax as `ignore_functions`, applied to interface names.
    #[serde(default)]
    pub ignore_interfaces: Vec<String>,
}

/// A user-defined custom type declared under `custom_types`.
#[derive(Debug, Deserialize)]
pub struct Extended {
    /// Numeric identifier.
    ///
    /// `0` is reserved, and ids must be unique across entries. [`load`]
    /// enforces both rules.
    pub id: u8,

    /// Name of the custom type. Also used to identify the entry in
    /// validation error messages.
    pub name: String,

    /// Names this custom type applies to.
    ///
    /// Each item may be listed only once across all entries; [`load`]
    /// enforces this.
    // NOTE(review): presumably these are SDK type names — confirm against the consumer.
    pub matches: Vec<String>,
}

/// Reads the config file at `path`, deserializes it into a [`Config`], and
/// checks its ignore-list patterns and custom type entries.
///
/// # Errors
///
/// Returns an error in three cases:
/// - the file cannot be read ("failed to read config: …");
/// - the file cannot be deserialized ("failed to parse config: …");
/// - the parsed config fails validation.
pub fn load(path: &Path) -> Result<Config, Error> {
    let raw = ::config::Config::builder()
        .add_source(::config::File::from(path))
        .build()
        .with_context(|| format!("failed to read config: {}", path.display()))?;

    let parsed: Config = raw
        .try_deserialize()
        .with_context(|| format!("failed to parse config: {}", path.display()))?;

    // Hand back the config only once it has passed validation.
    validate(&parsed).map(|()| parsed)
}

/// Validates ignore-list patterns.
fn validate(config: &Config) -> Result<(), Error> {
    for name in &config.database.ignore_functions {
        if name.starts_with('*') && name.ends_with('*') && name.len() > 1 {
            return Err(anyhow!(
                "ignore_functions: pattern `{name}` cannot have both leading and trailing `*`"
            ));
        }
    }

    for name in &config.database.ignore_interfaces {
        if name.starts_with('*') && name.ends_with('*') && name.len() > 1 {
            return Err(anyhow!(
                "ignore_interfaces: pattern `{name}` cannot have both leading and trailing `*`"
            ));
        }
    }

    let mut seen_ids = HashSet::new();
    let mut seen_matches = HashMap::new();

    for entry in &config.custom_types {
        if entry.id == 0 {
            return Err(anyhow!("custom_types[{}]: id 0 is reserved", entry.name));
        }

        if !seen_ids.insert(entry.id) {
            return Err(anyhow!("custom_types: duplicate id {}", entry.id));
        }

        for item in &entry.matches {
            if let Some(prev) = seen_matches.insert(item.as_str(), entry.name.as_str()) {
                return Err(anyhow!(
                    "custom_types: `{}` listed under both `{}` and `{}`",
                    item,
                    prev,
                    entry.name
                ));
            }
        }
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use std::io::Write as _;

    use super::*;

    /// Writes `contents` to a new temp file whose name starts with `name`.
    ///
    /// The `.yaml` suffix matters: `config::File::from` picks the file format
    /// from the extension.
    fn write_temp(name: &str, contents: &str) -> tempfile::NamedTempFile {
        let mut f = tempfile::Builder::new()
            .prefix(name)
            .suffix(".yaml")
            .tempfile()
            .unwrap();
        f.write_all(contents.as_bytes()).unwrap();
        f.flush().unwrap();
        f
    }

    /// Loads `contents` and returns the resulting error message. Panics if
    /// loading succeeds.
    fn load_err(contents: &str) -> String {
        let f = write_temp("sigmd-new-test-", contents);
        load(f.path()).unwrap_err().to_string()
    }

    #[test]
    fn loads_minimal_config() {
        let f = write_temp(
            "sigmd-new-test-",
            r#"
sdk:
  - "headers/**/*.h"
include:
  - "include"
inject:
  - "include/inject.h"
database:
  ignore_functions:
    - RtlFreeAnsiString
    - "*OpenKey"
"#,
        );
        let cfg = load(f.path()).unwrap();
        assert_eq!(cfg.sdk[0], "headers/**/*.h");
        assert_eq!(
            cfg.database.ignore_functions,
            vec!["RtlFreeAnsiString", "*OpenKey"]
        );
        // Omitted keys fall back to their serde defaults.
        assert!(cfg.database.ignore_interfaces.is_empty());
        assert!(cfg.custom_types.is_empty());
    }

    #[test]
    fn rejects_double_wildcard() {
        let err = load_err(
            r#"
sdk: []
include: []
inject: []
database:
  ignore_functions:
    - "*foo*"
"#,
        );
        assert!(err.contains("leading and trailing `*`"));
    }

    #[test]
    fn rejects_double_wildcard_in_interfaces() {
        let err = load_err(
            r#"
sdk: []
include: []
inject: []
database:
  ignore_interfaces:
    - "*IFoo*"
"#,
        );
        assert!(err.contains("ignore_interfaces"));
        assert!(err.contains("leading and trailing `*`"));
    }

    #[test]
    fn accepts_lone_wildcard() {
        let f = write_temp(
            "sigmd-new-test-",
            r#"
sdk: []
include: []
inject: []
database:
  ignore_functions:
    - "*"
"#,
        );
        let cfg = load(f.path()).unwrap();
        assert_eq!(cfg.database.ignore_functions, vec!["*"]);
    }

    #[test]
    fn rejects_reserved_custom_type_id() {
        let err = load_err(
            r#"
sdk: []
include: []
inject: []
custom_types:
  - id: 0
    name: Reserved
    matches: ["HANDLE"]
"#,
        );
        assert!(err.contains("id 0 is reserved"));
    }

    #[test]
    fn rejects_duplicate_custom_type_id() {
        let err = load_err(
            r#"
sdk: []
include: []
inject: []
custom_types:
  - id: 3
    name: First
    matches: ["HANDLE"]
  - id: 3
    name: Second
    matches: ["HWND"]
"#,
        );
        assert!(err.contains("duplicate id 3"));
    }

    #[test]
    fn rejects_match_shared_between_custom_types() {
        let err = load_err(
            r#"
sdk: []
include: []
inject: []
custom_types:
  - id: 1
    name: First
    matches: ["HANDLE"]
  - id: 2
    name: Second
    matches: ["HANDLE"]
"#,
        );
        assert!(err.contains("`HANDLE` listed under both `First` and `Second`"));
    }
}