Skip to main content

datasynth_generators/
velocity_calibrator.rs

1//! SP3.4 — online velocity-rule calibrator.
2//!
3//! Periodically (every N=10,000 emitted lines), compares current rule-trigger
4//! rates against the target rates derived from a `BehavioralPriors`. For the
5//! most-off rule, proposes a bounded parameter nudge. Hysteresis prevents
6//! direction-flip oscillation.
7
8use std::collections::HashMap;
9
10use chrono::{Datelike, Weekday};
11
12use datasynth_core::models::journal_entry::JournalEntryLine;
13
/// A single parameter nudge emitted by the velocity calibrator.
///
/// `delta` is the nominal signed step (direction times the rule's step size),
/// while `new_value` is the parameter value after that step has been clamped
/// to the rule's bounds. When the clamp bites, `new_value` moves by less than
/// `delta`.
#[derive(Clone, Debug, PartialEq)]
pub struct CalibrationStep {
    /// Identifier of the rule whose trigger rate was furthest from target
    /// (for example `"R9"`).
    pub rule_id: String,
    /// Dotted name of the generator parameter being adjusted
    /// (for example `"amounts.round_dollar_share"`).
    pub parameter: String,
    /// Signed nominal step: negative lowers the parameter, positive raises it.
    pub delta: f64,
    /// Parameter value after applying the step and clamping to the bounds.
    pub new_value: f64,
}
22
23/// Online calibrator. Lives on the generator state.
24pub struct VelocityCalibrator {
25    pub n_lines_between_calibrations: usize,
26    pub target_trigger_rates: HashMap<String, f64>,
27    current_window_counts: HashMap<String, u64>,
28    current_window_total: u64,
29    last_direction: HashMap<String, i32>,
30    windows_since_direction_change: HashMap<String, u32>,
31    pub adjustments: Vec<CalibrationStep>,
32    bounds: HashMap<String, (f64, f64)>,
33    pub current_values: HashMap<String, f64>,
34    step_sizes: HashMap<String, f64>,
35}
36
37impl VelocityCalibrator {
38    pub fn new(target_trigger_rates: HashMap<String, f64>, n_lines: usize) -> Self {
39        let mut bounds: HashMap<String, (f64, f64)> = HashMap::new();
40        bounds.insert("R6".into(), (1.0, 4.0));
41        bounds.insert("R7".into(), (0.0, 0.5));
42        bounds.insert("R8".into(), (0.0, 0.3));
43        bounds.insert("R9".into(), (0.0, 0.5));
44        bounds.insert("R10".into(), (0.0, 0.3));
45        let mut step_sizes: HashMap<String, f64> = HashMap::new();
46        step_sizes.insert("R6".into(), 0.01);
47        step_sizes.insert("R7".into(), 0.005);
48        step_sizes.insert("R8".into(), 0.005);
49        step_sizes.insert("R9".into(), 0.005);
50        step_sizes.insert("R10".into(), 0.005);
51        Self {
52            n_lines_between_calibrations: n_lines.max(1),
53            target_trigger_rates,
54            current_window_counts: HashMap::new(),
55            current_window_total: 0,
56            last_direction: HashMap::new(),
57            windows_since_direction_change: HashMap::new(),
58            adjustments: Vec::new(),
59            bounds,
60            current_values: HashMap::new(),
61            step_sizes,
62        }
63    }
64
65    /// Called per emitted line. Counts trigger events; when window full, may
66    /// return a `CalibrationStep` for the most-off rule.
67    pub fn observe_line(&mut self, line: &JournalEntryLine) -> Option<CalibrationStep> {
68        self.current_window_total += 1;
69        if Self::is_off_hours(line) {
70            *self.current_window_counts.entry("R7".into()).or_insert(0) += 1;
71        }
72        if Self::is_round_dollar(line) {
73            *self.current_window_counts.entry("R9".into()).or_insert(0) += 1;
74        }
75        if self.current_window_total < self.n_lines_between_calibrations as u64 {
76            return None;
77        }
78        let step = self.propose_step();
79        self.current_window_counts.clear();
80        self.current_window_total = 0;
81        step
82    }
83
84    fn propose_step(&mut self) -> Option<CalibrationStep> {
85        let mut worst: Option<(String, f64)> = None;
86        for (rule_id, &target) in &self.target_trigger_rates {
87            let observed = *self.current_window_counts.get(rule_id).unwrap_or(&0) as f64
88                / self.current_window_total.max(1) as f64;
89            let signed = observed - target;
90            let abs = signed.abs();
91            if worst
92                .as_ref()
93                .is_none_or(|(_, prev_abs)| abs > prev_abs.abs())
94            {
95                worst = Some((rule_id.clone(), signed));
96            }
97        }
98        let (rule_id, signed_delta) = worst?;
99        if signed_delta.abs() < 1e-9 {
100            return None;
101        }
102        let step_size = *self.step_sizes.get(&rule_id)?;
103        let (lo, hi) = *self.bounds.get(&rule_id)?;
104        let parameter = match rule_id.as_str() {
105            "R6" => "amounts.lognormal_sigma",
106            "R7" => "posting.off_hours_share",
107            "R8" => "period_close.post_close_share",
108            "R9" => "amounts.round_dollar_share",
109            "R10" => "posting.backdating_share",
110            _ => return None,
111        };
112        let direction = if signed_delta > 0.0 { -1 } else { 1 };
113        let last_dir = *self.last_direction.get(&rule_id).unwrap_or(&0);
114        let windows_since = *self
115            .windows_since_direction_change
116            .get(&rule_id)
117            .unwrap_or(&999);
118        if last_dir != 0 && direction != last_dir && windows_since < 5 {
119            self.windows_since_direction_change
120                .insert(rule_id, windows_since + 1);
121            return None;
122        }
123        let current = *self
124            .current_values
125            .get(parameter)
126            .unwrap_or(&((lo + hi) / 2.0));
127        let new_value = (current + direction as f64 * step_size).clamp(lo, hi);
128        if (new_value - current).abs() < 1e-9 {
129            return None;
130        }
131        self.current_values.insert(parameter.to_string(), new_value);
132        self.last_direction.insert(rule_id.clone(), direction);
133        if last_dir != direction {
134            self.windows_since_direction_change
135                .insert(rule_id.clone(), 0);
136        }
137        let step = CalibrationStep {
138            rule_id,
139            parameter: parameter.to_string(),
140            delta: direction as f64 * step_size,
141            new_value,
142        };
143        self.adjustments.push(step.clone());
144        Some(step)
145    }
146
147    fn is_off_hours(line: &JournalEntryLine) -> bool {
148        if let Some(d) = line.value_date {
149            matches!(d.weekday(), Weekday::Sat | Weekday::Sun)
150        } else {
151            false
152        }
153    }
154
155    fn is_round_dollar(line: &JournalEntryLine) -> bool {
156        let amt = line.local_amount.abs();
157        let amt_f64 = decimal_to_f64(amt);
158        let amt_i = amt_f64.round() as i64;
159        amt_i > 0 && amt_i % 1000 == 0
160    }
161}
162
163fn decimal_to_f64(d: rust_decimal::Decimal) -> f64 {
164    use rust_decimal::prelude::ToPrimitive;
165    d.to_f64().unwrap_or(0.0)
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::NaiveDate;
    use rust_decimal::Decimal;

    /// Builds a line dated Monday 2026-01-05 (so it never counts as a weekend
    /// posting) carrying the given local amount.
    fn line_with_amount(amount: f64) -> JournalEntryLine {
        use rust_decimal::prelude::FromPrimitive;
        let monday = NaiveDate::from_ymd_opt(2026, 1, 5).unwrap();
        JournalEntryLine {
            local_amount: Decimal::from_f64(amount).unwrap_or(Decimal::ZERO),
            value_date: Some(monday),
            ..JournalEntryLine::default()
        }
    }

    /// Feeds ten lines alternating between 1000 and 1500 (so half of them are
    /// round-dollar) and returns the last step the calibrator proposed, if any.
    fn feed_alternating_window(cal: &mut VelocityCalibrator) -> Option<CalibrationStep> {
        (0..10)
            .filter_map(|i| {
                let amt = if i % 2 == 0 { 1000.0 } else { 1500.0 };
                cal.observe_line(&line_with_amount(amt))
            })
            .last()
    }

    #[test]
    fn calibrator_proposes_step_when_off_target() {
        let targets = HashMap::from([("R9".to_string(), 0.10)]);
        let mut cal = VelocityCalibrator::new(targets, 10);
        cal.current_values
            .insert("amounts.round_dollar_share".to_string(), 0.30);

        // Observed 0.5 against a target of 0.10, so the share must be lowered.
        let s = feed_alternating_window(&mut cal).expect("step proposed");
        assert_eq!(s.rule_id, "R9");
        assert!(s.delta < 0.0, "expected downward delta, got {}", s.delta);
        assert!(s.new_value < 0.30);
    }

    #[test]
    fn calibrator_no_step_when_at_target() {
        // Observed 5/10 = 0.5 matches the target exactly.
        let targets = HashMap::from([("R9".to_string(), 0.5)]);
        let mut cal = VelocityCalibrator::new(targets, 10);

        let step = feed_alternating_window(&mut cal);
        // At target: either no step at all, or one with a negligible delta.
        assert!(step.is_none_or(|s| s.delta.abs() < 1e-9));
    }
}