1use std::collections::HashMap;
9
10use chrono::{Datelike, Weekday};
11
12use datasynth_core::models::journal_entry::JournalEntryLine;
13
/// One parameter adjustment emitted by [`VelocityCalibrator`].
#[derive(Clone, Debug, PartialEq)]
pub struct CalibrationStep {
    /// Rule whose trigger rate motivated the adjustment (e.g. `"R9"`).
    pub rule_id: String,
    /// Dotted name of the generator parameter that was changed.
    pub parameter: String,
    /// Signed adjustment to the parameter.
    pub delta: f64,
    /// Parameter value after the adjustment.
    pub new_value: f64,
}
22
23pub struct VelocityCalibrator {
25 pub n_lines_between_calibrations: usize,
26 pub target_trigger_rates: HashMap<String, f64>,
27 current_window_counts: HashMap<String, u64>,
28 current_window_total: u64,
29 last_direction: HashMap<String, i32>,
30 windows_since_direction_change: HashMap<String, u32>,
31 pub adjustments: Vec<CalibrationStep>,
32 bounds: HashMap<String, (f64, f64)>,
33 pub current_values: HashMap<String, f64>,
34 step_sizes: HashMap<String, f64>,
35}
36
37impl VelocityCalibrator {
38 pub fn new(target_trigger_rates: HashMap<String, f64>, n_lines: usize) -> Self {
39 let mut bounds: HashMap<String, (f64, f64)> = HashMap::new();
40 bounds.insert("R6".into(), (1.0, 4.0));
41 bounds.insert("R7".into(), (0.0, 0.5));
42 bounds.insert("R8".into(), (0.0, 0.3));
43 bounds.insert("R9".into(), (0.0, 0.5));
44 bounds.insert("R10".into(), (0.0, 0.3));
45 let mut step_sizes: HashMap<String, f64> = HashMap::new();
46 step_sizes.insert("R6".into(), 0.01);
47 step_sizes.insert("R7".into(), 0.005);
48 step_sizes.insert("R8".into(), 0.005);
49 step_sizes.insert("R9".into(), 0.005);
50 step_sizes.insert("R10".into(), 0.005);
51 Self {
52 n_lines_between_calibrations: n_lines.max(1),
53 target_trigger_rates,
54 current_window_counts: HashMap::new(),
55 current_window_total: 0,
56 last_direction: HashMap::new(),
57 windows_since_direction_change: HashMap::new(),
58 adjustments: Vec::new(),
59 bounds,
60 current_values: HashMap::new(),
61 step_sizes,
62 }
63 }
64
65 pub fn observe_line(&mut self, line: &JournalEntryLine) -> Option<CalibrationStep> {
68 self.current_window_total += 1;
69 if Self::is_off_hours(line) {
70 *self.current_window_counts.entry("R7".into()).or_insert(0) += 1;
71 }
72 if Self::is_round_dollar(line) {
73 *self.current_window_counts.entry("R9".into()).or_insert(0) += 1;
74 }
75 if self.current_window_total < self.n_lines_between_calibrations as u64 {
76 return None;
77 }
78 let step = self.propose_step();
79 self.current_window_counts.clear();
80 self.current_window_total = 0;
81 step
82 }
83
84 fn propose_step(&mut self) -> Option<CalibrationStep> {
85 let mut worst: Option<(String, f64)> = None;
86 for (rule_id, &target) in &self.target_trigger_rates {
87 let observed = *self.current_window_counts.get(rule_id).unwrap_or(&0) as f64
88 / self.current_window_total.max(1) as f64;
89 let signed = observed - target;
90 let abs = signed.abs();
91 if worst
92 .as_ref()
93 .is_none_or(|(_, prev_abs)| abs > prev_abs.abs())
94 {
95 worst = Some((rule_id.clone(), signed));
96 }
97 }
98 let (rule_id, signed_delta) = worst?;
99 if signed_delta.abs() < 1e-9 {
100 return None;
101 }
102 let step_size = *self.step_sizes.get(&rule_id)?;
103 let (lo, hi) = *self.bounds.get(&rule_id)?;
104 let parameter = match rule_id.as_str() {
105 "R6" => "amounts.lognormal_sigma",
106 "R7" => "posting.off_hours_share",
107 "R8" => "period_close.post_close_share",
108 "R9" => "amounts.round_dollar_share",
109 "R10" => "posting.backdating_share",
110 _ => return None,
111 };
112 let direction = if signed_delta > 0.0 { -1 } else { 1 };
113 let last_dir = *self.last_direction.get(&rule_id).unwrap_or(&0);
114 let windows_since = *self
115 .windows_since_direction_change
116 .get(&rule_id)
117 .unwrap_or(&999);
118 if last_dir != 0 && direction != last_dir && windows_since < 5 {
119 self.windows_since_direction_change
120 .insert(rule_id, windows_since + 1);
121 return None;
122 }
123 let current = *self
124 .current_values
125 .get(parameter)
126 .unwrap_or(&((lo + hi) / 2.0));
127 let new_value = (current + direction as f64 * step_size).clamp(lo, hi);
128 if (new_value - current).abs() < 1e-9 {
129 return None;
130 }
131 self.current_values.insert(parameter.to_string(), new_value);
132 self.last_direction.insert(rule_id.clone(), direction);
133 if last_dir != direction {
134 self.windows_since_direction_change
135 .insert(rule_id.clone(), 0);
136 }
137 let step = CalibrationStep {
138 rule_id,
139 parameter: parameter.to_string(),
140 delta: direction as f64 * step_size,
141 new_value,
142 };
143 self.adjustments.push(step.clone());
144 Some(step)
145 }
146
147 fn is_off_hours(line: &JournalEntryLine) -> bool {
148 if let Some(d) = line.value_date {
149 matches!(d.weekday(), Weekday::Sat | Weekday::Sun)
150 } else {
151 false
152 }
153 }
154
155 fn is_round_dollar(line: &JournalEntryLine) -> bool {
156 let amt = line.local_amount.abs();
157 let amt_f64 = decimal_to_f64(amt);
158 let amt_i = amt_f64.round() as i64;
159 amt_i > 0 && amt_i % 1000 == 0
160 }
161}
162
163fn decimal_to_f64(d: rust_decimal::Decimal) -> f64 {
164 use rust_decimal::prelude::ToPrimitive;
165 d.to_f64().unwrap_or(0.0)
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::NaiveDate;
    use rust_decimal::Decimal;

    /// Builds a journal line with `amount`, dated Monday 2026-01-05 so that
    /// it never counts toward the weekend (R7) rule.
    fn line_with_amount(amount: f64) -> JournalEntryLine {
        use rust_decimal::prelude::FromPrimitive;
        let weekday = NaiveDate::from_ymd_opt(2026, 1, 5).unwrap();
        JournalEntryLine {
            local_amount: Decimal::from_f64(amount).unwrap_or(Decimal::ZERO),
            value_date: Some(weekday),
            ..JournalEntryLine::default()
        }
    }

    /// Feeds `count` lines alternating between 1000 (round) and 1500 (not
    /// round) and returns the last step the calibrator proposed, if any.
    fn last_step_over_alternating_window(
        calibrator: &mut VelocityCalibrator,
        count: usize,
    ) -> Option<CalibrationStep> {
        (0..count)
            .map(|idx| if idx % 2 == 0 { 1000.0 } else { 1500.0 })
            .filter_map(|amount| calibrator.observe_line(&line_with_amount(amount)))
            .last()
    }

    #[test]
    fn calibrator_proposes_step_when_off_target() {
        let targets = HashMap::from([("R9".to_string(), 0.10)]);
        let mut calibrator = VelocityCalibrator::new(targets, 10);
        calibrator
            .current_values
            .insert("amounts.round_dollar_share".into(), 0.30);

        let proposed =
            last_step_over_alternating_window(&mut calibrator, 10).expect("step proposed");
        assert_eq!(proposed.rule_id, "R9");
        assert!(proposed.delta < 0.0, "expected downward delta, got {}", proposed.delta);
        assert!(proposed.new_value < 0.30);
    }

    #[test]
    fn calibrator_no_step_when_at_target() {
        let targets = HashMap::from([("R9".to_string(), 0.5)]);
        let mut calibrator = VelocityCalibrator::new(targets, 10);

        let proposed = last_step_over_alternating_window(&mut calibrator, 10);
        assert!(proposed.is_none_or(|s| s.delta.abs() < 1e-9));
    }
}
222}