1#![allow(clippy::cast_precision_loss)] use crate::{MapletError, MapletResult};
9
10#[derive(Debug, Clone)]
12pub struct ErrorRateController {
13 target_error_rate: f64,
15 current_error_rate: f64,
17 query_count: u64,
19 false_positive_count: u64,
21 collision_tracker: CollisionTracker,
23}
24
25impl ErrorRateController {
26 pub fn new(target_error_rate: f64) -> MapletResult<Self> {
28 if target_error_rate <= 0.0 || target_error_rate >= 1.0 {
29 return Err(MapletError::InvalidErrorRate(target_error_rate));
30 }
31
32 Ok(Self {
33 target_error_rate,
34 current_error_rate: 0.0,
35 query_count: 0,
36 false_positive_count: 0,
37 collision_tracker: CollisionTracker::new(),
38 })
39 }
40
41 pub fn record_query(&mut self, was_false_positive: bool) {
43 self.query_count += 1;
44 if was_false_positive {
45 self.false_positive_count += 1;
46 }
47
48 #[allow(clippy::cast_precision_loss)] {
51 self.current_error_rate = self.false_positive_count as f64 / self.query_count as f64;
52 }
53 }
54
55 pub fn record_collision(&mut self, fingerprint: u64, slot: usize) {
57 self.collision_tracker.record_collision(fingerprint, slot);
58 }
59
60 #[must_use]
62 pub const fn current_error_rate(&self) -> f64 {
63 self.current_error_rate
64 }
65
66 #[must_use]
68 pub const fn target_error_rate(&self) -> f64 {
69 self.target_error_rate
70 }
71
72 #[must_use]
74 pub fn is_error_rate_acceptable(&self) -> bool {
75 self.current_error_rate <= self.target_error_rate * 1.5 }
77
78 #[must_use]
80 pub fn stats(&self) -> ErrorRateStats {
81 ErrorRateStats {
82 target_error_rate: self.target_error_rate,
83 current_error_rate: self.current_error_rate,
84 query_count: self.query_count,
85 false_positive_count: self.false_positive_count,
86 collision_count: self.collision_tracker.collision_count(),
87 max_chain_length: self.collision_tracker.max_chain_length(),
88 }
89 }
90}
91
/// Bidirectional record of which fingerprints were placed in which slots.
///
/// Every call to [`CollisionTracker::record_collision`] appends to both maps, duplicates
/// included. Per-slot statistics and per-fingerprint lookups are therefore each a direct
/// map access.
#[derive(Debug, Clone)]
pub struct CollisionTracker {
    /// For each fingerprint, every slot it was recorded in (in recording order).
    fingerprint_to_slots: std::collections::HashMap<u64, Vec<usize>>,
    /// For each slot, every fingerprint recorded there (in recording order).
    slot_to_fingerprints: std::collections::HashMap<usize, Vec<u64>>,
}

impl Default for CollisionTracker {
    /// Same as [`CollisionTracker::new`]: an empty tracker.
    fn default() -> Self {
        Self::new()
    }
}

impl CollisionTracker {
    /// Creates a tracker with nothing recorded.
    #[must_use]
    pub fn new() -> Self {
        Self {
            slot_to_fingerprints: std::collections::HashMap::default(),
            fingerprint_to_slots: std::collections::HashMap::default(),
        }
    }

    /// Notes that `fingerprint` was placed in `slot`.
    ///
    /// Recording the same pair again is not deduplicated: it lengthens that slot's chain.
    pub fn record_collision(&mut self, fingerprint: u64, slot: usize) {
        let slots_for_fingerprint = self.fingerprint_to_slots.entry(fingerprint).or_default();
        slots_for_fingerprint.push(slot);

        let fingerprints_in_slot = self.slot_to_fingerprints.entry(slot).or_default();
        fingerprints_in_slot.push(fingerprint);
    }

    /// Number of slots holding two or more recorded fingerprint entries.
    #[must_use]
    pub fn collision_count(&self) -> usize {
        let mut shared_slots = 0;
        for entries in self.slot_to_fingerprints.values() {
            if entries.len() >= 2 {
                shared_slots += 1;
            }
        }
        shared_slots
    }

    /// Largest number of entries recorded for any single slot, or `0` when empty.
    #[must_use]
    pub fn max_chain_length(&self) -> usize {
        self.slot_to_fingerprints
            .values()
            .fold(0, |longest, entries| longest.max(entries.len()))
    }

    /// Fingerprints other than `fingerprint` that share at least one slot with it.
    ///
    /// The result is sorted ascending and deduplicated. It is empty when `fingerprint` was
    /// never recorded.
    #[must_use]
    pub fn get_colliding_fingerprints(&self, fingerprint: u64) -> Vec<u64> {
        let slots = match self.fingerprint_to_slots.get(&fingerprint) {
            Some(slots) => slots,
            None => return Vec::new(),
        };

        let mut neighbours: Vec<u64> = slots
            .iter()
            .filter_map(|slot| self.slot_to_fingerprints.get(slot))
            .flat_map(|entries| entries.iter().copied())
            .filter(|&other| other != fingerprint)
            .collect();
        neighbours.sort_unstable();
        neighbours.dedup();
        neighbours
    }
}
170
171#[derive(Debug, Clone)]
173pub struct StrongMapletValidator {
174 max_chain_length: usize,
176 error_threshold: f64,
178}
179
180impl StrongMapletValidator {
181 #[must_use]
183 pub const fn new(max_chain_length: usize, error_threshold: f64) -> Self {
184 Self {
185 max_chain_length,
186 error_threshold,
187 }
188 }
189
190 pub fn validate_strong_property<V, Op>(
196 &self,
197 collision_tracker: &CollisionTracker,
198 error_rate: f64,
199 ) -> MapletResult<StrongMapletValidation>
200 where
201 V: Clone,
202 Op: crate::operators::MergeOperator<V>,
203 {
204 let max_chain = collision_tracker.max_chain_length();
205 let collision_count = collision_tracker.collision_count();
206
207 let chain_length_ok = max_chain <= self.max_chain_length;
209
210 let error_rate_ok = error_rate <= self.error_threshold;
212
213 let prob_exceed_chain = if max_chain > 0 {
215 error_rate.powi(max_chain as i32)
216 } else {
217 0.0
218 };
219
220 let validation = StrongMapletValidation {
221 chain_length_ok,
222 error_rate_ok,
223 max_chain_length: max_chain,
224 collision_count,
225 error_rate,
226 prob_exceed_chain,
227 is_valid: chain_length_ok && error_rate_ok,
228 };
229
230 if !validation.is_valid {
231 tracing::warn!(
232 "Strong maplet property validation failed: chain_length_ok={}, error_rate_ok={}",
233 chain_length_ok,
234 error_rate_ok
235 );
236 }
237
238 Ok(validation)
239 }
240}
241
/// Outcome of `StrongMapletValidator::validate_strong_property`.
///
/// Derives `PartialEq` so whole results can be compared in tests and assertions. `f64`
/// fields rule out `Eq`.
#[derive(Debug, Clone, PartialEq)]
pub struct StrongMapletValidation {
    /// Whether the observed maximum chain length is within the validator's limit.
    pub chain_length_ok: bool,
    /// Whether the supplied error rate is at or below the validator's threshold.
    pub error_rate_ok: bool,
    /// Longest per-slot chain observed in the collision tracker.
    pub max_chain_length: usize,
    /// Number of slots holding more than one recorded fingerprint entry.
    pub collision_count: usize,
    /// The error rate that was validated.
    pub error_rate: f64,
    /// `error_rate` raised to `max_chain_length`, or `0.0` when no chain was recorded.
    pub prob_exceed_chain: f64,
    /// `chain_length_ok && error_rate_ok`.
    pub is_valid: bool,
}
260
/// Point-in-time snapshot produced by `ErrorRateController::stats`.
///
/// Derives `PartialEq` so snapshots can be compared directly. `f64` fields rule out `Eq`.
#[derive(Debug, Clone, PartialEq)]
pub struct ErrorRateStats {
    /// Target false-positive rate the controller was configured with.
    pub target_error_rate: f64,
    /// Observed false-positive rate at snapshot time.
    pub current_error_rate: f64,
    /// Number of queries recorded.
    pub query_count: u64,
    /// Number of recorded queries that were false positives.
    pub false_positive_count: u64,
    /// Number of slots holding more than one recorded fingerprint entry.
    pub collision_count: usize,
    /// Longest per-slot chain recorded.
    pub max_chain_length: usize,
}
271
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_rate_controller() {
        let mut controller = ErrorRateController::new(0.01).unwrap();

        // 100 correct answers followed by a single false positive.
        for _ in 0..100 {
            controller.record_query(false);
        }
        controller.record_query(true);

        assert_eq!(controller.query_count, 101);
        assert_eq!(controller.false_positive_count, 1);
        assert!((controller.current_error_rate() - 1.0 / 101.0).abs() < 1e-10);
        assert!(controller.is_error_rate_acceptable());
    }

    #[test]
    fn test_error_rate_controller_rejects_out_of_range_targets() {
        for bad in [0.0, 1.0, -0.5, 1.5, f64::INFINITY, f64::NEG_INFINITY] {
            assert!(ErrorRateController::new(bad).is_err(), "target {bad} accepted");
        }
    }

    #[test]
    fn test_fresh_controller_stats() {
        let mut controller = ErrorRateController::new(0.05).unwrap();
        let stats = controller.stats();
        assert_eq!(stats.query_count, 0);
        assert_eq!(stats.false_positive_count, 0);
        assert_eq!(stats.collision_count, 0);
        assert_eq!(stats.max_chain_length, 0);
        assert!(stats.current_error_rate.abs() < f64::EPSILON);
        assert!((stats.target_error_rate - 0.05).abs() < 1e-12);

        controller.record_collision(0x1234, 3);
        controller.record_collision(0x5678, 3);
        let stats = controller.stats();
        assert_eq!(stats.collision_count, 1);
        assert_eq!(stats.max_chain_length, 2);
    }

    #[test]
    fn test_collision_tracker() {
        let mut tracker = CollisionTracker::new();

        // Two fingerprints share slot 0; slot 1 holds a single fingerprint.
        tracker.record_collision(0x1234, 0);
        tracker.record_collision(0x5678, 0);
        tracker.record_collision(0x9ABC, 1);

        assert_eq!(tracker.collision_count(), 1);
        assert_eq!(tracker.max_chain_length(), 2);

        let colliding = tracker.get_colliding_fingerprints(0x1234);
        assert_eq!(colliding, vec![0x5678]);
    }

    #[test]
    fn test_empty_collision_tracker() {
        let tracker = CollisionTracker::default();
        assert_eq!(tracker.collision_count(), 0);
        assert_eq!(tracker.max_chain_length(), 0);
        assert!(tracker.get_colliding_fingerprints(0x1234).is_empty());
    }

    #[test]
    fn test_strong_maplet_validator() {
        let validator = StrongMapletValidator::new(5, 0.01);
        let mut tracker = CollisionTracker::new();

        // Two fingerprints in slot 0 give a chain of length 2.
        tracker.record_collision(0x1234, 0);
        tracker.record_collision(0x5678, 0);

        let validation = validator
            .validate_strong_property::<u64, crate::operators::CounterOperator>(&tracker, 0.005)
            .unwrap();

        assert!(validation.chain_length_ok);
        assert!(validation.error_rate_ok);
        assert!(validation.is_valid);
        assert_eq!(validation.max_chain_length, 2);
        assert_eq!(validation.collision_count, 1);
    }

    #[test]
    fn test_strong_maplet_validator_failure() {
        let validator = StrongMapletValidator::new(1, 0.01);
        let mut tracker = CollisionTracker::new();
        tracker.record_collision(0x1234, 0);
        tracker.record_collision(0x5678, 0);

        let validation = validator
            .validate_strong_property::<u64, crate::operators::CounterOperator>(&tracker, 0.05)
            .unwrap();

        assert!(!validation.chain_length_ok);
        assert!(!validation.error_rate_ok);
        assert!(!validation.is_valid);
        // 0.05^2
        assert!((validation.prob_exceed_chain - 0.0025).abs() < 1e-12);
    }

    #[test]
    fn test_strong_maplet_validator_empty_tracker() {
        let validator = StrongMapletValidator::new(5, 0.01);
        let tracker = CollisionTracker::new();

        let validation = validator
            .validate_strong_property::<u64, crate::operators::CounterOperator>(&tracker, 0.005)
            .unwrap();

        assert!(validation.is_valid);
        assert_eq!(validation.max_chain_length, 0);
        assert!(validation.prob_exceed_chain.abs() < f64::EPSILON);
    }
}