use std::{
collections::{HashMap, HashSet},
path::Path,
};
use anyhow::{Context as _, Error, anyhow};
use serde::Deserialize;
/// Top-level configuration produced by [`load`].
///
/// `sdk`, `include` and `inject` carry no `#[serde(default)]`, so they must be
/// present in the file (an empty list such as `[]` is fine). `database` and
/// `custom_types` may be omitted.
#[derive(Debug, Deserialize)]
pub struct Config {
    /// Required `sdk` entries. The tests use glob-style strings such as
    /// `"headers/**/*.h"` — presumably header locations; confirm with callers.
    pub sdk: Vec<String>,
    /// Required `include` entries (presumably include directories — confirm).
    pub include: Vec<String>,
    /// Required `inject` entries (presumably files to inject — confirm).
    pub inject: Vec<String>,
    /// Optional `database` section; defaults to empty ignore lists.
    #[serde(default)]
    pub database: Database,
    /// Optional `custom_types` list; defaults to empty.
    #[serde(default)]
    pub custom_types: Vec<Extended>,
}
/// Ignore lists read from the `database` section of the config file.
///
/// The container-level `#[serde(default)]` fills any missing field from
/// `Database::default()`, so either list may be omitted and then comes out
/// empty.
#[derive(Debug, Default, Deserialize)]
#[serde(default)]
pub struct Database {
    /// Function-name patterns. `validate` rejects a pattern that has both a
    /// leading and a trailing `*`; a lone `*` is accepted.
    pub ignore_functions: Vec<String>,
    /// Interface-name patterns, subject to the same wildcard rule as
    /// `ignore_functions`.
    pub ignore_interfaces: Vec<String>,
}
/// One entry of the `custom_types` list.
#[derive(Debug, Deserialize)]
pub struct Extended {
    /// Numeric identifier. `validate` rejects `0` (reserved) and ids shared by
    /// two entries.
    pub id: u8,
    /// Entry name; appears in validation error messages.
    pub name: String,
    /// Items belonging to this entry. `validate` rejects any item listed more
    /// than once across all `custom_types` entries.
    pub matches: Vec<String>,
}
/// Reads the configuration file at `path` through the `config` crate,
/// deserializes it into a [`Config`], and validates the result.
///
/// # Errors
///
/// * `failed to read config: <path>` — the `config` builder could not read or
///   build the source.
/// * `failed to parse config: <path>` — the contents do not deserialize into
///   [`Config`].
/// * Any error reported by `validate` for a semantically invalid config.
pub fn load(path: &Path) -> Result<Config, Error> {
    let shown = path.display();

    let raw = ::config::Config::builder()
        .add_source(::config::File::from(path))
        .build()
        .with_context(|| format!("failed to read config: {shown}"))?;

    let parsed = raw
        .try_deserialize::<Config>()
        .with_context(|| format!("failed to parse config: {shown}"))?;

    validate(&parsed).map(|()| parsed)
}
fn validate(config: &Config) -> Result<(), Error> {
for name in &config.database.ignore_functions {
if name.starts_with('*') && name.ends_with('*') && name.len() > 1 {
return Err(anyhow!(
"ignore_functions: pattern `{name}` cannot have both leading and trailing `*`"
));
}
}
for name in &config.database.ignore_interfaces {
if name.starts_with('*') && name.ends_with('*') && name.len() > 1 {
return Err(anyhow!(
"ignore_interfaces: pattern `{name}` cannot have both leading and trailing `*`"
));
}
}
let mut seen_ids = HashSet::new();
let mut seen_matches = HashMap::new();
for entry in &config.custom_types {
if entry.id == 0 {
return Err(anyhow!("custom_types[{}]: id 0 is reserved", entry.name));
}
if !seen_ids.insert(entry.id) {
return Err(anyhow!("custom_types: duplicate id {}", entry.id));
}
for item in &entry.matches {
if let Some(prev) = seen_matches.insert(item.as_str(), entry.name.as_str()) {
return Err(anyhow!(
"custom_types: `{}` listed under both `{}` and `{}`",
item,
prev,
entry.name
));
}
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use std::io::Write as _;

    use super::*;

    /// Writes `contents` to a fresh temp file whose name starts with `name` and
    /// ends in `.yaml`. Keep the returned handle alive while the file is read.
    fn write_temp(name: &str, contents: &str) -> tempfile::NamedTempFile {
        let mut f = tempfile::Builder::new()
            .prefix(name)
            .suffix(".yaml")
            .tempfile()
            .unwrap();
        write!(f, "{contents}").unwrap();
        f.flush().unwrap();
        f
    }

    /// Loads a config whose required list fields are empty, followed by the
    /// YAML in `extra`.
    fn load_with(extra: &str) -> Result<Config, Error> {
        let f = write_temp(
            "sigmd-new-test-",
            &format!("sdk: []\ninclude: []\ninject: []\n{extra}"),
        );
        load(f.path())
    }

    #[test]
    fn loads_minimal_config() {
        let f = write_temp(
            "sigmd-new-test-",
            r#"
sdk:
  - "headers/**/*.h"
include:
  - "include"
inject:
  - "include/inject.h"
database:
  ignore_functions:
    - RtlFreeAnsiString
    - "*OpenKey"
"#,
        );
        let cfg = load(f.path()).unwrap();
        assert_eq!(cfg.sdk[0], "headers/**/*.h");
        assert_eq!(
            cfg.database.ignore_functions,
            vec!["RtlFreeAnsiString", "*OpenKey"]
        );
    }

    #[test]
    fn rejects_double_wildcard() {
        let f = write_temp(
            "sigmd-new-test-",
            r#"
sdk: []
include: []
inject: []
database:
  ignore_functions:
    - "*foo*"
"#,
        );
        let err = load(f.path()).unwrap_err();
        assert!(err.to_string().contains("leading and trailing `*`"));
    }

    #[test]
    fn rejects_double_wildcard_in_interfaces() {
        let err = load_with(
            r#"
database:
  ignore_interfaces:
    - "*Foo*"
"#,
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("ignore_interfaces"));
        assert!(msg.contains("leading and trailing `*`"));
    }

    #[test]
    fn accepts_lone_wildcard() {
        let cfg = load_with(
            r#"
database:
  ignore_functions:
    - "*"
"#,
        )
        .unwrap();
        assert_eq!(cfg.database.ignore_functions, vec!["*"]);
    }

    #[test]
    fn loads_custom_types() {
        let cfg = load_with(
            r#"
custom_types:
  - id: 1
    name: Handle
    matches:
      - HANDLE
      - HKEY
"#,
        )
        .unwrap();
        assert_eq!(cfg.custom_types.len(), 1);
        assert_eq!(cfg.custom_types[0].id, 1);
        assert_eq!(cfg.custom_types[0].name, "Handle");
        assert_eq!(cfg.custom_types[0].matches, vec!["HANDLE", "HKEY"]);
    }

    #[test]
    fn rejects_reserved_custom_type_id() {
        let err = load_with(
            r#"
custom_types:
  - id: 0
    name: Handle
    matches:
      - HANDLE
"#,
        )
        .unwrap_err();
        assert!(err.to_string().contains("id 0 is reserved"));
    }

    #[test]
    fn rejects_duplicate_custom_type_id() {
        let err = load_with(
            r#"
custom_types:
  - id: 1
    name: Handle
    matches:
      - HANDLE
  - id: 1
    name: Key
    matches:
      - HKEY
"#,
        )
        .unwrap_err();
        assert!(err.to_string().contains("duplicate id 1"));
    }

    #[test]
    fn rejects_match_shared_between_entries() {
        let err = load_with(
            r#"
custom_types:
  - id: 1
    name: Handle
    matches:
      - HANDLE
  - id: 2
    name: Key
    matches:
      - HANDLE
"#,
        )
        .unwrap_err();
        assert!(err.to_string().contains("listed under both"));
    }
}