use serde::Deserialize;
use std::collections::BTreeMap;
use std::path::Path;
/// Upper bound applied by `clamp_general` to every general (non-unit-interval)
/// coefficient: one billion.
pub(crate) const COEFF_MAX: f64 = 1_000_000_000.0;
/// Top-level priority coefficients (`coefficients` inside `[priority]`).
///
/// Missing fields take the values returned by `default_value_coeff` and
/// `default_risk_coeff`. The `Default` impl delegates to those same functions,
/// so each default value is defined exactly once.
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub(crate) struct Coefficients {
    /// Value coefficient (default `1.0`). `load_from_table` clamps it to
    /// `[0, COEFF_MAX]`.
    #[serde(default = "default_value_coeff")]
    pub(crate) value: f64,
    /// Risk coefficient (default `2.0`). `load_from_table` clamps it to
    /// `[0, COEFF_MAX]`.
    #[serde(default = "default_risk_coeff")]
    pub(crate) risk: f64,
}

impl Default for Coefficients {
    fn default() -> Self {
        // Delegate to the serde default fns so the two paths cannot drift apart.
        Self {
            value: default_value_coeff(),
            risk: default_risk_coeff(),
        }
    }
}

/// Default for `Coefficients::value`.
fn default_value_coeff() -> f64 {
    1.0
}

/// Default for `Coefficients::risk`.
fn default_risk_coeff() -> f64 {
    2.0
}
/// Consequence coefficients (`consequence` inside `[priority]`).
///
/// Missing fields take the values returned by `default_dep_coeff` and
/// `default_ref_coeff`. The `Default` impl delegates to those same functions,
/// so each default value is defined exactly once.
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub(crate) struct ConsequenceCoeffs {
    /// `dep_coeff` (default `0.5`). `load_from_table` clamps it to `[0, 1]`
    /// via `clamp_dep`.
    #[serde(default = "default_dep_coeff")]
    pub(crate) dep_coeff: f64,
    /// `ref_coeff` (default `1.0`). `load_from_table` clamps it to
    /// `[0, COEFF_MAX]`.
    #[serde(default = "default_ref_coeff")]
    pub(crate) ref_coeff: f64,
}

impl Default for ConsequenceCoeffs {
    fn default() -> Self {
        // Delegate to the serde default fns so the two paths cannot drift apart.
        Self {
            dep_coeff: default_dep_coeff(),
            ref_coeff: default_ref_coeff(),
        }
    }
}

/// Default for `ConsequenceCoeffs::dep_coeff`.
fn default_dep_coeff() -> f64 {
    0.5
}

/// Default for `ConsequenceCoeffs::ref_coeff`.
fn default_ref_coeff() -> f64 {
    1.0
}
/// Estimate-cost parameters (`estimate` inside `[priority]`).
///
/// Missing fields take the values returned by `default_skew` and
/// `default_margin`. The `Default` impl delegates to those same functions,
/// so each default value is defined exactly once.
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub(crate) struct EstimateCost {
    /// `skew` (default `0.65`). `load_from_table` clamps it to `[0, 1]`
    /// via `clamp_skew`.
    #[serde(default = "default_skew")]
    pub(crate) skew: f64,
    /// `margin` (default `1.0`). `load_from_table` clamps it to
    /// `[0, COEFF_MAX]`.
    #[serde(default = "default_margin")]
    pub(crate) margin: f64,
}

impl Default for EstimateCost {
    fn default() -> Self {
        // Delegate to the serde default fns so the two paths cannot drift apart.
        Self {
            skew: default_skew(),
            margin: default_margin(),
        }
    }
}

/// Default for `EstimateCost::skew`.
fn default_skew() -> f64 {
    0.65
}

/// Default for `EstimateCost::margin`.
fn default_margin() -> f64 {
    1.0
}
/// Complete `[priority]` configuration.
///
/// Every section is optional. The container-level `#[serde(default)]` fills
/// any omitted field from the derived `Default`. The two maps start empty, and
/// `kind_weight` / `tag_coeff` return `1.0` for keys they do not contain.
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default)]
pub(crate) struct PriorityConfig {
    /// The `coefficients` section.
    pub(crate) coefficients: Coefficients,
    /// Per-kind weights, looked up by `kind_weight`.
    pub(crate) kind_weights: BTreeMap<String, f64>,
    /// Per-tag coefficients, looked up by `tag_coeff`.
    pub(crate) tag_coefficients: BTreeMap<String, f64>,
    /// The `consequence` section.
    pub(crate) consequence: ConsequenceCoeffs,
    /// The `estimate` section.
    pub(crate) estimate: EstimateCost,
}
impl PriorityConfig {
pub(crate) fn kind_weight(&self, kind: &str) -> f64 {
self.kind_weights.get(kind).copied().unwrap_or(1.0)
}
pub(crate) fn tag_coeff(&self, tag: &str) -> f64 {
self.tag_coefficients.get(tag).copied().unwrap_or(1.0)
}
}
/// Loads the priority configuration for the project rooted at `root`.
///
/// Returns `PriorityConfig::default()` whenever `read_priority_table` yields
/// nothing. That covers a missing or unreadable config file, invalid TOML, and
/// an absent or non-table `priority` entry.
pub(crate) fn load(root: &Path) -> PriorityConfig {
    match read_priority_table(root) {
        Some(priority) => load_from_table(&priority),
        None => PriorityConfig::default(),
    }
}
/// Reads the doctrine TOML under `root` and extracts its `priority` table.
///
/// Best-effort by design: returns `None` in any of these cases, and the caller
/// falls back to defaults:
/// - the file cannot be read;
/// - the file is not valid TOML;
/// - it has no `priority` key;
/// - `priority` is not a table.
pub(crate) fn read_priority_table(root: &Path) -> Option<toml::Table> {
    let text = std::fs::read_to_string(root.join(crate::dtoml::DOCTRINE_TOML)).ok()?;
    // Parse as a document (`toml::Table`), then move the sub-table out of the
    // owned document instead of cloning it.
    let mut doc: toml::Table = text.parse().ok()?;
    let toml::Value::Table(priority) = doc.remove("priority")? else {
        return None;
    };
    Some(priority)
}
/// Builds a `PriorityConfig` from an already-extracted `priority` table.
///
/// Each recognized sub-table overrides the matching defaults: `coefficients`,
/// `consequence`, `estimate`, `kind_weights` and `tag_coefficients`.
/// Absent sub-tables, absent keys and non-numeric values keep the default.
/// Unknown keys are ignored. The result passes through `clamp`, so every
/// coefficient ends up finite and within its allowed range.
pub(crate) fn load_from_table(table: &toml::value::Table) -> PriorityConfig {
    let mut cfg = PriorityConfig::default();
    // Scalars fall back to the value `PriorityConfig::default()` already holds,
    // so the defaults are not repeated here.
    if let Some(t) = sub_table(table, "coefficients") {
        cfg.coefficients.value = f64_or(t, "value", cfg.coefficients.value);
        cfg.coefficients.risk = f64_or(t, "risk", cfg.coefficients.risk);
    }
    if let Some(t) = sub_table(table, "consequence") {
        cfg.consequence.dep_coeff = f64_or(t, "dep_coeff", cfg.consequence.dep_coeff);
        cfg.consequence.ref_coeff = f64_or(t, "ref_coeff", cfg.consequence.ref_coeff);
    }
    if let Some(t) = sub_table(table, "estimate") {
        cfg.estimate.skew = f64_or(t, "skew", cfg.estimate.skew);
        cfg.estimate.margin = f64_or(t, "margin", cfg.estimate.margin);
    }
    if let Some(t) = sub_table(table, "kind_weights") {
        cfg.kind_weights.extend(numeric_entries(t));
    }
    if let Some(t) = sub_table(table, "tag_coefficients") {
        cfg.tag_coefficients.extend(numeric_entries(t));
    }
    clamp(cfg)
}

/// Returns `table[key]` when it is present and is itself a table.
fn sub_table<'a>(table: &'a toml::value::Table, key: &str) -> Option<&'a toml::value::Table> {
    table.get(key).and_then(toml::Value::as_table)
}

/// Yields `(key, number)` for every numeric entry of `table`, skipping the rest.
fn numeric_entries(table: &toml::value::Table) -> impl Iterator<Item = (String, f64)> + '_ {
    table
        .iter()
        .filter_map(|(key, value)| f64_val(value).map(|num| (key.clone(), num)))
}
/// Interprets a TOML value as `f64`.
///
/// Floats are returned as-is and integers are widened. Every other value
/// type yields `None`.
#[expect(
    clippy::as_conversions,
    clippy::cast_precision_loss,
    reason = "i64→f64 safe for TOML config coefficients (never near i64::MAX)"
)]
fn f64_val(v: &toml::Value) -> Option<f64> {
    if let Some(float) = v.as_float() {
        return Some(float);
    }
    v.as_integer().map(|int| int as f64)
}
/// Reads `table[key]` as an `f64`.
///
/// Returns `default` when the key is missing or its value is not numeric.
fn f64_or(table: &toml::value::Table, key: &str, default: f64) -> f64 {
    let parsed = table.get(key).and_then(f64_val);
    parsed.unwrap_or(default)
}
fn clamp(mut cfg: PriorityConfig) -> PriorityConfig {
cfg.coefficients.value = clamp_general(cfg.coefficients.value, 1.0);
cfg.coefficients.risk = clamp_general(cfg.coefficients.risk, 2.0);
cfg.consequence.ref_coeff = clamp_general(cfg.consequence.ref_coeff, 1.0);
cfg.consequence.dep_coeff = clamp_dep(cfg.consequence.dep_coeff);
cfg.estimate.skew = clamp_skew(cfg.estimate.skew);
cfg.estimate.margin = clamp_general(cfg.estimate.margin, 1.0);
for v in cfg.kind_weights.values_mut() {
*v = clamp_general(*v, 1.0);
}
for v in cfg.tag_coefficients.values_mut() {
*v = clamp_general(*v, 1.0);
}
cfg
}
pub(crate) fn clamp_general(value: f64, fallback: f64) -> f64 {
if !value.is_finite() {
return fallback;
}
if value < 0.0 {
return 0.0;
}
if value > COEFF_MAX {
return COEFF_MAX;
}
value
}
/// Clamps `dep_coeff` into `[0, 1]`.
///
/// Behavior by input:
/// - non-finite (NaN, ±inf): `0.5`;
/// - at or below zero, including `-0.0`: `+0.0`;
/// - above one: `1.0`.
pub(crate) fn clamp_dep(value: f64) -> f64 {
    match value {
        v if !v.is_finite() => 0.5,
        v if v <= 0.0 => 0.0,
        v if v > 1.0 => 1.0,
        v => v,
    }
}
/// Clamps `skew` into `[0, 1]`.
///
/// Non-finite input (NaN, ±inf) yields `0.65`.
pub(crate) fn clamp_skew(value: f64) -> f64 {
    if value.is_finite() {
        value.clamp(0.0, 1.0)
    } else {
        0.65
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Writes `body` as the doctrine TOML under a fresh temporary project root
    /// and loads the priority config from that root.
    fn load_from(body: &str) -> PriorityConfig {
        let dir = tempfile::tempdir().unwrap();
        // NOTE(review): `.doctrine/` is created up front, presumably because
        // `DOCTRINE_TOML` points inside it — confirm against `crate::dtoml`.
        let config_dir = dir.path().join(".doctrine");
        fs::create_dir_all(&config_dir).unwrap();
        fs::write(dir.path().join(crate::dtoml::DOCTRINE_TOML), body).unwrap();
        load(dir.path())
    }

    #[test]
    fn default_impls_match_serde_defaults() {
        let cfg = PriorityConfig::default();
        assert_eq!(cfg.coefficients.value, default_value_coeff());
        assert_eq!(cfg.coefficients.risk, default_risk_coeff());
        assert_eq!(cfg.consequence.dep_coeff, default_dep_coeff());
        assert_eq!(cfg.consequence.ref_coeff, default_ref_coeff());
        assert_eq!(cfg.estimate.skew, default_skew());
        assert_eq!(cfg.estimate.margin, default_margin());
    }

    #[test]
    fn missing_priority_section_is_defaults() {
        let cfg = load_from("[dispatch]\npreferred-subprocess-harness = \"pi\"\n");
        assert_eq!(cfg.coefficients.value, 1.0);
        assert_eq!(cfg.coefficients.risk, 2.0);
        assert_eq!(cfg.consequence.dep_coeff, 0.5);
        assert_eq!(cfg.consequence.ref_coeff, 1.0);
        assert_eq!(cfg.estimate.skew, 0.65);
        assert_eq!(cfg.estimate.margin, 1.0);
        assert!(cfg.kind_weights.is_empty());
        assert!(cfg.tag_coefficients.is_empty());
    }

    #[test]
    fn priority_not_a_table_is_defaults() {
        let cfg = load_from("priority = 5\n");
        assert_eq!(cfg.coefficients.value, 1.0);
        assert_eq!(cfg.coefficients.risk, 2.0);
        assert!(cfg.kind_weights.is_empty());
        assert!(cfg.tag_coefficients.is_empty());
    }

    #[test]
    fn no_doctrine_toml_is_defaults() {
        let dir = tempfile::tempdir().unwrap();
        let cfg = load(dir.path());
        assert_eq!(cfg.coefficients.value, 1.0);
        assert_eq!(cfg.coefficients.risk, 2.0);
    }

    #[test]
    fn partial_section_fills_defaults() {
        let cfg = load_from("[priority]\nkind_weights = { SL = 2.5 }\n");
        assert_eq!(cfg.coefficients.value, 1.0);
        assert_eq!(cfg.coefficients.risk, 2.0);
        assert_eq!(cfg.consequence.dep_coeff, 0.5);
        assert_eq!(cfg.consequence.ref_coeff, 1.0);
        assert_eq!(cfg.kind_weight("SL"), 2.5);
        assert_eq!(cfg.kind_weight("ADR"), 1.0);
        assert!(cfg.tag_coefficients.is_empty());
    }

    #[test]
    fn unknown_key_ignored() {
        let cfg = load_from("[priority]\ncoefficients = { value = 3.0, risk = 4.0, extra = 99 }\n");
        assert_eq!(cfg.coefficients.value, 3.0);
        assert_eq!(cfg.coefficients.risk, 4.0);
    }

    #[test]
    fn nan_coefficient_clamps_to_default() {
        let cfg = load_from("[priority]\ncoefficients = { value = nan, risk = nan }\n");
        assert_eq!(cfg.coefficients.value, 1.0);
        assert_eq!(cfg.coefficients.risk, 2.0);
    }

    #[test]
    fn inf_coefficient_clamps_to_default() {
        let cfg = load_from("[priority]\ncoefficients = { value = inf, risk = -inf }\n");
        assert_eq!(cfg.coefficients.value, 1.0);
        assert_eq!(cfg.coefficients.risk, 2.0);
    }

    #[test]
    fn negative_coefficient_clamps_to_zero() {
        let cfg = load_from("[priority]\ncoefficients = { value = -5.0, risk = -0.1 }\n");
        assert_eq!(cfg.coefficients.value, 0.0);
        assert_eq!(cfg.coefficients.risk, 0.0);
    }

    #[test]
    fn over_max_coefficient_clamps_to_max() {
        let body = format!(
            "[priority]\ncoefficients = {{ value = {max}, risk = {max} }}\n",
            max = COEFF_MAX + 1.0
        );
        let cfg = load_from(&body);
        assert_eq!(cfg.coefficients.value, COEFF_MAX);
        assert_eq!(cfg.coefficients.risk, COEFF_MAX);
    }

    #[test]
    fn dep_coeff_over_one_clamps_to_one() {
        let cfg = load_from("[priority]\nconsequence = { dep_coeff = 5.0 }\n");
        assert_eq!(cfg.consequence.dep_coeff, 1.0);
    }

    #[test]
    fn dep_coeff_zero_or_negative_clamps_to_zero() {
        let cfg = load_from("[priority]\nconsequence = { dep_coeff = 0.0 }\n");
        assert_eq!(cfg.consequence.dep_coeff, 0.0);
        let cfg2 = load_from("[priority]\nconsequence = { dep_coeff = -0.5 }\n");
        assert_eq!(cfg2.consequence.dep_coeff, 0.0);
    }

    #[test]
    fn dep_coeff_nan_falls_back_to_default() {
        let cfg = load_from("[priority]\nconsequence = { dep_coeff = nan, ref_coeff = -1 }\n");
        assert_eq!(cfg.consequence.dep_coeff, 0.5);
        assert_eq!(cfg.consequence.ref_coeff, 0.0);
    }

    #[test]
    fn malformed_toml_in_priority_section_returns_defaults() {
        let cfg = load_from("[priority]\ncoefficients = { value = 3.0\n");
        assert_eq!(cfg.coefficients.value, 1.0);
    }

    #[test]
    fn non_numeric_value_falls_back_to_default() {
        let cfg = load_from("[priority]\ncoefficients = { value = \"abc\", risk = 4.0 }\n");
        assert_eq!(cfg.coefficients.value, 1.0);
        assert_eq!(cfg.coefficients.risk, 4.0);
    }

    #[test]
    fn map_values_are_clamped_and_non_numeric_skipped() {
        let cfg = load_from(
            "[priority]\nkind_weights = { SL = -2.0, ADR = \"x\" }\n\
             tag_coefficients = { \"area:risk\" = nan }\n",
        );
        assert_eq!(cfg.kind_weight("SL"), 0.0);
        assert!(!cfg.kind_weights.contains_key("ADR"));
        assert_eq!(cfg.tag_coefficients.get("area:risk"), Some(&1.0));
    }

    #[test]
    fn kind_weight_absent_key_returns_default_one() {
        let cfg = PriorityConfig::default();
        assert_eq!(cfg.kind_weight("NONEXISTENT"), 1.0);
    }

    #[test]
    fn tag_coeff_absent_key_returns_default_one() {
        let cfg = PriorityConfig::default();
        assert_eq!(cfg.tag_coeff("nonexistent"), 1.0);
    }

    #[test]
    fn kind_weight_present_key_returns_stored() {
        let cfg = load_from("[priority]\nkind_weights = { SL = 3.0, ADR = 1.5 }\n");
        assert_eq!(cfg.kind_weight("SL"), 3.0);
        assert_eq!(cfg.kind_weight("ADR"), 1.5);
    }

    #[test]
    fn tag_coeff_present_key_returns_stored() {
        let cfg = load_from("[priority]\ntag_coefficients = { \"area:risk\" = 2.0 }\n");
        assert_eq!(cfg.tag_coeff("area:risk"), 2.0);
    }

    #[test]
    fn estimate_absent_uses_defaults() {
        let cfg = load_from("[priority]\ncoefficients = { value = 3.0 }\n");
        assert_eq!(cfg.estimate.skew, 0.65);
        assert_eq!(cfg.estimate.margin, 1.0);
        let dir = tempfile::tempdir().unwrap();
        let cfg2 = load(dir.path());
        assert_eq!(cfg2.estimate.skew, 0.65);
        assert_eq!(cfg2.estimate.margin, 1.0);
    }

    #[test]
    fn estimate_clamps_values() {
        let cfg = load_from("[priority]\nestimate = { skew = 1.5, margin = -3 }\n");
        assert_eq!(cfg.estimate.skew, 1.0);
        assert_eq!(cfg.estimate.margin, 0.0);
        let cfg2 = load_from("[priority]\nestimate = { skew = -0.2 }\n");
        assert_eq!(cfg2.estimate.skew, 0.0);
        let cfg3 = load_from("[priority]\nestimate = { skew = nan, margin = inf }\n");
        assert_eq!(cfg3.estimate.skew, 0.65);
        assert_eq!(cfg3.estimate.margin, 1.0);
    }

    #[test]
    fn estimate_roundtrip_valid_values() {
        let cfg = load_from("[priority]\nestimate = { skew = 0.7, margin = 2 }\n");
        assert_eq!(cfg.estimate.skew, 0.7);
        assert_eq!(cfg.estimate.margin, 2.0);
    }
}