use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::path::{Path, PathBuf};
use crate::ConfigError;
/// Schema identifier an override file must declare in its `schema_version` field.
///
/// `parse_corpus_override` compares the file's value against this string for exact
/// equality and returns `ConfigError::CorpusOverrideSchemaMismatch` on any difference.
pub const SCHEMA_VERSION: &str = "marque-corpus-override-1";
/// In-memory form of a corpus override file, as returned by `parse_corpus_override`.
///
/// The `Default` value has both maps empty and no floors set, which is exactly the
/// state `is_empty` reports as `true`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct CorpusOverride {
/// `log_prior` per key of the file's `token_overrides` object, narrowed to `f32`.
pub token_overrides: BTreeMap<String, f32>,
/// `log_prior` per key of the file's `template_overrides` object, narrowed to `f32`.
pub template_overrides: BTreeMap<String, f32>,
/// Optional floors taken from the file's `strict_context_overrides` object.
pub strict_context_overrides: StrictContextOverrides,
}
impl CorpusOverride {
    /// Returns `true` when the override carries nothing at all: no token entries, no
    /// template entries, and no strict-context floor.
    #[inline]
    pub fn is_empty(&self) -> bool {
        // Check the two tables first; the floors are only consulted when both are empty.
        let has_table_entries =
            !self.token_overrides.is_empty() || !self.template_overrides.is_empty();
        if has_table_entries {
            return false;
        }
        self.strict_context_overrides.is_empty()
    }
}
/// Optional per-level floors from the `strict_context_overrides` section of an override file.
///
/// A field is `None` when the file did not supply that floor. Values produced by
/// `parse_corpus_override` have passed `validate_floor`, i.e. they are finite and lie
/// in `(0.0, 1.0]`; the fields are public, so values built elsewhere carry no such guarantee.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct StrictContextOverrides {
/// Value of the file's `confidential_floor` key, if present.
pub confidential_floor: Option<f32>,
/// Value of the file's `secret_floor` key, if present.
pub secret_floor: Option<f32>,
/// Value of the file's `top_secret_floor` key, if present.
pub top_secret_floor: Option<f32>,
}
impl StrictContextOverrides {
    /// Returns `true` when none of the three floors is set.
    #[inline]
    fn is_empty(&self) -> bool {
        // Exhaustive struct pattern: adding a field later forces this check to be revisited.
        matches!(
            *self,
            Self {
                confidential_floor: None,
                secret_floor: None,
                top_secret_floor: None,
            }
        )
    }
}
/// Wire (JSON) shape of an override file, deserialized by `parse_corpus_override`.
///
/// `deny_unknown_fields` turns any unrecognized top-level key into a deserialization
/// error instead of silently ignoring it.
#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
struct OverrideFile {
/// Required (no `#[serde(default)]`); compared against `SCHEMA_VERSION` after parsing.
schema_version: String,
/// Optional section; defaults to an empty map when the key is absent.
#[serde(default)]
token_overrides: BTreeMap<String, TokenOverrideEntry>,
/// Optional section; defaults to an empty map when the key is absent.
#[serde(default)]
template_overrides: BTreeMap<String, TemplateOverrideEntry>,
/// Optional section; `None` when the key is absent.
#[serde(default)]
strict_context_overrides: Option<StrictContextOverridesFile>,
}
/// Wire shape of one value in the file's `token_overrides` object.
///
/// Unknown keys inside the entry are rejected by `deny_unknown_fields`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
struct TokenOverrideEntry {
/// Read as `f64`; `parse_corpus_override` narrows it to `f32` before validating it.
log_prior: f64,
}
/// Wire shape of one value in the file's `template_overrides` object.
///
/// Unknown keys inside the entry are rejected by `deny_unknown_fields`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
struct TemplateOverrideEntry {
/// Read as `f64`; `parse_corpus_override` narrows it to `f32` before validating it.
log_prior: f64,
}
/// Wire shape of the file's `strict_context_overrides` object.
///
/// Every floor is optional, so a partially filled object is accepted; unknown keys are
/// rejected by `deny_unknown_fields`. Values are read as `f64` and narrowed to `f32`
/// by `parse_corpus_override` before validation.
#[derive(Debug, Deserialize, Serialize, Default)]
#[serde(deny_unknown_fields)]
struct StrictContextOverridesFile {
/// `None` when the key is absent.
#[serde(default)]
confidential_floor: Option<f64>,
/// `None` when the key is absent.
#[serde(default)]
secret_floor: Option<f64>,
/// `None` when the key is absent.
#[serde(default)]
top_secret_floor: Option<f64>,
}
/// Reads the file at `path` as UTF-8 text and parses it as a corpus override.
///
/// # Errors
///
/// Returns `ConfigError::ReadError` (carrying `path` and the underlying I/O error) when
/// the file cannot be read; otherwise returns whatever `parse_corpus_override` returns
/// for the file's contents.
pub fn load_corpus_override(path: &Path) -> Result<CorpusOverride, ConfigError> {
    match std::fs::read_to_string(path) {
        Ok(contents) => parse_corpus_override(&contents, path.to_path_buf()),
        Err(io_err) => Err(ConfigError::ReadError {
            path: path.to_path_buf(),
            source: io_err,
        }),
    }
}
pub fn parse_corpus_override(
raw: &str,
source_path: PathBuf,
) -> Result<CorpusOverride, ConfigError> {
let file: OverrideFile =
serde_json::from_str(raw).map_err(|e| ConfigError::CorpusOverrideParse {
path: source_path.clone(),
reason: e.to_string(),
})?;
if file.schema_version != SCHEMA_VERSION {
return Err(ConfigError::CorpusOverrideSchemaMismatch {
path: source_path,
file_version: file.schema_version,
expected: SCHEMA_VERSION,
});
}
let mut token_overrides = BTreeMap::new();
for (token, entry) in file.token_overrides {
let lp = entry.log_prior as f32;
validate_log_prior(&source_path, "token_overrides", &token, lp)?;
token_overrides.insert(token, lp);
}
let mut template_overrides = BTreeMap::new();
for (template, entry) in file.template_overrides {
let lp = entry.log_prior as f32;
validate_log_prior(&source_path, "template_overrides", &template, lp)?;
template_overrides.insert(template, lp);
}
let strict_context_overrides = match file.strict_context_overrides {
None => StrictContextOverrides::default(),
Some(s) => {
let mut out = StrictContextOverrides::default();
if let Some(v) = s.confidential_floor {
let v32 = v as f32;
validate_floor(&source_path, "confidential_floor", v32)?;
out.confidential_floor = Some(v32);
}
if let Some(v) = s.secret_floor {
let v32 = v as f32;
validate_floor(&source_path, "secret_floor", v32)?;
out.secret_floor = Some(v32);
}
if let Some(v) = s.top_secret_floor {
let v32 = v as f32;
validate_floor(&source_path, "top_secret_floor", v32)?;
out.top_secret_floor = Some(v32);
}
out
}
};
Ok(CorpusOverride {
token_overrides,
template_overrides,
strict_context_overrides,
})
}
/// Checks a single `log_prior` value from the `section` table under `key`.
///
/// Accepts any finite value that is not greater than `1e-3`. NaN and both infinities are
/// rejected first; a finite value above `1e-3` is rejected second.
///
/// NOTE(review): the rejection message says "≤ 0" while the check tolerates values up to
/// `1e-3` — presumably deliberate slop around zero; confirm the intended tolerance.
///
/// # Errors
///
/// Returns `ConfigError::CorpusOverrideInvalidValue` carrying `path`, `section`, `key`
/// and a fixed explanation of which rule was violated.
fn validate_log_prior(
    path: &Path,
    section: &'static str,
    key: &str,
    value: f32,
) -> Result<(), ConfigError> {
    // Pick the explanation first; the error itself is built in one place below.
    let reason = if !value.is_finite() {
        "log_prior must be finite — `-Inf` (`log(0)`) is rejected as a regenerator footgun; \
         express 'very rare' with a finite very-negative number (e.g., -50.0) instead"
    } else if value > 1e-3 {
        "log_prior must be ≤ 0 (probabilities ≤ 1)"
    } else {
        return Ok(());
    };

    Err(ConfigError::CorpusOverrideInvalidValue {
        path: path.to_path_buf(),
        section,
        key: key.to_owned(),
        reason,
    })
}
/// Checks a single strict-context floor named `key`.
///
/// A floor is accepted only when it is finite and lies in the half-open interval
/// `(0.0, 1.0]`; zero, negatives, values above one, NaN and infinities are all rejected.
///
/// # Errors
///
/// Returns `ConfigError::CorpusOverrideInvalidValue` with section
/// `"strict_context_overrides"`, the given `key`, and a fixed explanation.
fn validate_floor(path: &Path, key: &'static str, value: f32) -> Result<(), ConfigError> {
    let within_range = value.is_finite() && value > 0.0 && value <= 1.0;
    if within_range {
        return Ok(());
    }

    Err(ConfigError::CorpusOverrideInvalidValue {
        path: path.to_path_buf(),
        section: "strict_context_overrides",
        key: key.to_owned(),
        reason: "floor must be in (0.0, 1.0] and finite — `0.0` is rejected because it silently \
                 makes the strict-context rule a no-op; write a finite small positive (e.g., 0.01) \
                 for a permissive floor",
    })
}
/// Unit tests for override parsing and validation.
///
/// Every test except the last feeds an inline JSON string to `parse_corpus_override`;
/// the last one exercises the file-reading path of `load_corpus_override`.
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    /// Placeholder source path attached to errors; no file with this name is read.
    fn p() -> PathBuf {
        PathBuf::from("test.json")
    }

    #[test]
    fn parses_minimal_audit_marker_only_override() {
        let raw = r#"{"schema_version": "marque-corpus-override-1"}"#;
        let parsed = parse_corpus_override(raw, p()).unwrap();
        assert!(parsed.is_empty());
        assert!(parsed.token_overrides.is_empty());
        assert!(parsed.template_overrides.is_empty());
        assert_eq!(
            parsed.strict_context_overrides,
            StrictContextOverrides::default()
        );
    }

    #[test]
    fn parses_full_override() {
        let raw = r#"{
"schema_version": "marque-corpus-override-1",
"token_overrides": {
"SECRET": { "log_prior": -2.5 },
"NOFORN": { "log_prior": -3.0 }
},
"template_overrides": {
"classification//dissem": { "log_prior": -1.8 }
},
"strict_context_overrides": {
"confidential_floor": 0.95,
"secret_floor": 0.98,
"top_secret_floor": 0.99
}
}"#;
        let parsed = parse_corpus_override(raw, p()).unwrap();
        assert!(!parsed.is_empty());
        assert_eq!(parsed.token_overrides.len(), 2);
        assert!((parsed.token_overrides["SECRET"] - (-2.5)).abs() < 1e-5);
        assert!((parsed.token_overrides["NOFORN"] - (-3.0)).abs() < 1e-5);
        assert_eq!(parsed.template_overrides.len(), 1);
        assert!((parsed.template_overrides["classification//dissem"] - (-1.8)).abs() < 1e-5);
        assert_eq!(
            parsed.strict_context_overrides.confidential_floor,
            Some(0.95)
        );
        assert_eq!(parsed.strict_context_overrides.secret_floor, Some(0.98));
        assert_eq!(parsed.strict_context_overrides.top_secret_floor, Some(0.99));
    }

    #[test]
    fn parses_partial_strict_context_overrides() {
        let raw = r#"{
"schema_version": "marque-corpus-override-1",
"strict_context_overrides": { "secret_floor": 0.97 }
}"#;
        let parsed = parse_corpus_override(raw, p()).unwrap();
        assert_eq!(parsed.strict_context_overrides.confidential_floor, None);
        assert_eq!(parsed.strict_context_overrides.secret_floor, Some(0.97));
        assert_eq!(parsed.strict_context_overrides.top_secret_floor, None);
    }

    #[test]
    fn rejects_unknown_schema_version() {
        let raw = r#"{"schema_version": "marque-corpus-override-99"}"#;
        let err = parse_corpus_override(raw, p()).unwrap_err();
        match err {
            ConfigError::CorpusOverrideSchemaMismatch {
                file_version,
                expected,
                ..
            } => {
                assert_eq!(file_version, "marque-corpus-override-99");
                assert_eq!(expected, SCHEMA_VERSION);
            }
            other => panic!("expected SchemaMismatch, got {other:?}"),
        }
    }

    #[test]
    fn rejects_missing_schema_version() {
        let raw = r#"{}"#;
        assert!(matches!(
            parse_corpus_override(raw, p()),
            Err(ConfigError::CorpusOverrideParse { .. })
        ));
    }

    #[test]
    fn rejects_unknown_top_level_field() {
        // `token_override` (singular) is a typo of the real section name.
        let raw = r#"{
"schema_version": "marque-corpus-override-1",
"token_override": { "SECRET": { "log_prior": -2.5 } }
}"#;
        assert!(matches!(
            parse_corpus_override(raw, p()),
            Err(ConfigError::CorpusOverrideParse { .. })
        ));
    }

    #[test]
    fn rejects_unknown_token_entry_field() {
        let raw = r#"{
"schema_version": "marque-corpus-override-1",
"token_overrides": {
"SECRET": { "log_prior": -2.5, "weight": 0.5 }
}
}"#;
        assert!(matches!(
            parse_corpus_override(raw, p()),
            Err(ConfigError::CorpusOverrideParse { .. })
        ));
    }

    /// A finite but positive `log_prior` is rejected by the upper-bound check.
    #[test]
    fn rejects_positive_log_prior() {
        let raw = r#"{
"schema_version": "marque-corpus-override-1",
"token_overrides": { "SECRET": { "log_prior": 5.0 } }
}"#;
        match parse_corpus_override(raw, p()).unwrap_err() {
            ConfigError::CorpusOverrideInvalidValue { section, key, .. } => {
                assert_eq!(section, "token_overrides");
                assert_eq!(key, "SECRET");
            }
            other => panic!("expected InvalidValue, got {other:?}"),
        }
    }

    /// `-1e39` is a valid `f64` but lies beyond the `f32` range, so narrowing yields `-inf`.
    /// That value is not above the upper bound, so only the finiteness check can reject it.
    #[test]
    fn rejects_log_prior_that_narrows_to_non_finite() {
        let raw = r#"{
"schema_version": "marque-corpus-override-1",
"token_overrides": { "SECRET": { "log_prior": -1e39 } }
}"#;
        match parse_corpus_override(raw, p()).unwrap_err() {
            ConfigError::CorpusOverrideInvalidValue { section, key, .. } => {
                assert_eq!(section, "token_overrides");
                assert_eq!(key, "SECRET");
            }
            other => panic!("expected InvalidValue, got {other:?}"),
        }
    }

    #[test]
    fn rejects_floor_outside_unit_interval() {
        let raw = r#"{
"schema_version": "marque-corpus-override-1",
"strict_context_overrides": { "secret_floor": 1.5 }
}"#;
        match parse_corpus_override(raw, p()).unwrap_err() {
            ConfigError::CorpusOverrideInvalidValue { section, key, .. } => {
                assert_eq!(section, "strict_context_overrides");
                assert_eq!(key, "secret_floor");
            }
            other => panic!("expected InvalidValue, got {other:?}"),
        }
    }

    /// The lower bound of the floor interval is exclusive: `0.0` must be rejected.
    #[test]
    fn rejects_zero_floor() {
        let raw = r#"{
"schema_version": "marque-corpus-override-1",
"strict_context_overrides": { "confidential_floor": 0.0 }
}"#;
        match parse_corpus_override(raw, p()).unwrap_err() {
            ConfigError::CorpusOverrideInvalidValue { section, key, .. } => {
                assert_eq!(section, "strict_context_overrides");
                assert_eq!(key, "confidential_floor");
            }
            other => panic!("expected InvalidValue, got {other:?}"),
        }
    }

    #[test]
    fn accepts_log_prior_zero_with_slop() {
        let raw = r#"{
"schema_version": "marque-corpus-override-1",
"token_overrides": { "SECRET": { "log_prior": 0.0 } }
}"#;
        let parsed = parse_corpus_override(raw, p()).unwrap();
        assert_eq!(parsed.token_overrides["SECRET"], 0.0);
    }

    #[test]
    fn load_corpus_override_returns_read_error_for_missing_file() {
        // A fresh temp dir guarantees the joined path does not exist.
        let tmp = tempfile::tempdir().unwrap();
        let bad = tmp.path().join("missing-override.json");
        match load_corpus_override(&bad).unwrap_err() {
            ConfigError::ReadError { path, .. } => assert_eq!(path, bad),
            other => panic!("expected ReadError, got {other:?}"),
        }
    }
}