1use crate::{MapletError, MapletResult};
7
8#[derive(Debug, Clone)]
10pub struct ErrorRateController {
11 target_error_rate: f64,
13 current_error_rate: f64,
15 query_count: u64,
17 false_positive_count: u64,
19 collision_tracker: CollisionTracker,
21}
22
23impl ErrorRateController {
24 pub fn new(target_error_rate: f64) -> MapletResult<Self> {
26 if target_error_rate <= 0.0 || target_error_rate >= 1.0 {
27 return Err(MapletError::InvalidErrorRate(target_error_rate));
28 }
29
30 Ok(Self {
31 target_error_rate,
32 current_error_rate: 0.0,
33 query_count: 0,
34 false_positive_count: 0,
35 collision_tracker: CollisionTracker::new(),
36 })
37 }
38
39 pub fn record_query(&mut self, was_false_positive: bool) {
41 self.query_count += 1;
42 if was_false_positive {
43 self.false_positive_count += 1;
44 }
45
46 self.current_error_rate = self.false_positive_count as f64 / self.query_count as f64;
48 }
49
50 pub fn record_collision(&mut self, fingerprint: u64, slot: usize) {
52 self.collision_tracker.record_collision(fingerprint, slot);
53 }
54
55 pub fn current_error_rate(&self) -> f64 {
57 self.current_error_rate
58 }
59
60 pub fn target_error_rate(&self) -> f64 {
62 self.target_error_rate
63 }
64
65 pub fn is_error_rate_acceptable(&self) -> bool {
67 self.current_error_rate <= self.target_error_rate * 1.5 }
69
70 pub fn stats(&self) -> ErrorRateStats {
72 ErrorRateStats {
73 target_error_rate: self.target_error_rate,
74 current_error_rate: self.current_error_rate,
75 query_count: self.query_count,
76 false_positive_count: self.false_positive_count,
77 collision_count: self.collision_tracker.collision_count(),
78 max_chain_length: self.collision_tracker.max_chain_length(),
79 }
80 }
81}
82
/// Records which fingerprints were observed in which slots so that slot
/// collisions can be counted and inspected.
///
/// Each call to [`CollisionTracker::record_collision`] appends one entry to both
/// directions of the mapping. A repeated `(fingerprint, slot)` pair is stored
/// again, so it also counts towards that slot's chain length.
#[derive(Debug, Clone, Default)]
pub struct CollisionTracker {
    /// Fingerprint -> slots it was recorded in, in insertion order.
    fingerprint_to_slots: std::collections::HashMap<u64, Vec<usize>>,
    /// Slot -> fingerprints recorded in it, in insertion order.
    slot_to_fingerprints: std::collections::HashMap<usize, Vec<u64>>,
}

impl CollisionTracker {
    /// Creates an empty tracker; equivalent to [`CollisionTracker::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Records that `fingerprint` was placed in `slot`.
    pub fn record_collision(&mut self, fingerprint: u64, slot: usize) {
        self.fingerprint_to_slots.entry(fingerprint).or_default().push(slot);
        self.slot_to_fingerprints.entry(slot).or_default().push(fingerprint);
    }

    /// Number of slots holding more than one recorded entry.
    pub fn collision_count(&self) -> usize {
        self.slot_to_fingerprints
            .values()
            .filter(|fingerprints| fingerprints.len() > 1)
            .count()
    }

    /// Largest number of entries recorded for any single slot (0 when empty).
    pub fn max_chain_length(&self) -> usize {
        self.slot_to_fingerprints
            .values()
            .map(Vec::len)
            .max()
            .unwrap_or(0)
    }

    /// Returns every other fingerprint that shares a slot with `fingerprint`.
    ///
    /// The result is sorted ascending and deduplicated. It is empty when
    /// `fingerprint` was never recorded or never shared a slot.
    pub fn get_colliding_fingerprints(&self, fingerprint: u64) -> Vec<u64> {
        let slots = match self.fingerprint_to_slots.get(&fingerprint) {
            Some(slots) => slots,
            None => return Vec::new(),
        };

        let mut colliding: Vec<u64> = slots
            .iter()
            .filter_map(|slot| self.slot_to_fingerprints.get(slot))
            .flatten()
            .copied()
            .filter(|&fp| fp != fingerprint)
            .collect();
        colliding.sort_unstable();
        colliding.dedup();
        colliding
    }
}
143
144#[derive(Debug, Clone)]
146pub struct StrongMapletValidator {
147 max_chain_length: usize,
149 error_threshold: f64,
151}
152
153impl StrongMapletValidator {
154 pub fn new(max_chain_length: usize, error_threshold: f64) -> Self {
156 Self {
157 max_chain_length,
158 error_threshold,
159 }
160 }
161
162 pub fn validate_strong_property<V, Op>(
168 &self,
169 collision_tracker: &CollisionTracker,
170 error_rate: f64,
171 ) -> MapletResult<StrongMapletValidation>
172 where
173 V: Clone,
174 Op: crate::operators::MergeOperator<V>,
175 {
176 let max_chain = collision_tracker.max_chain_length();
177 let collision_count = collision_tracker.collision_count();
178
179 let chain_length_ok = max_chain <= self.max_chain_length;
181
182 let error_rate_ok = error_rate <= self.error_threshold;
184
185 let prob_exceed_chain = if max_chain > 0 {
187 error_rate.powi(max_chain as i32)
188 } else {
189 0.0
190 };
191
192 let validation = StrongMapletValidation {
193 chain_length_ok,
194 error_rate_ok,
195 max_chain_length: max_chain,
196 collision_count,
197 error_rate,
198 prob_exceed_chain,
199 is_valid: chain_length_ok && error_rate_ok,
200 };
201
202 if !validation.is_valid {
203 tracing::warn!(
204 "Strong maplet property validation failed: chain_length_ok={}, error_rate_ok={}",
205 chain_length_ok, error_rate_ok
206 );
207 }
208
209 Ok(validation)
210 }
211}
212
/// Outcome of a strong-maplet validation run, produced by
/// `StrongMapletValidator::validate_strong_property`.
#[derive(Clone, Debug)]
pub struct StrongMapletValidation {
    /// Whether the longest observed chain stayed within the configured limit.
    pub chain_length_ok: bool,
    /// Whether the supplied error rate stayed within the configured threshold.
    pub error_rate_ok: bool,
    /// Longest chain (entries in one slot) seen at validation time.
    pub max_chain_length: usize,
    /// Number of slots holding more than one entry at validation time.
    pub collision_count: usize,
    /// Error rate that was validated.
    pub error_rate: f64,
    /// Estimated probability associated with the observed chain length.
    pub prob_exceed_chain: f64,
    /// `true` only when both the chain-length and error-rate checks passed.
    pub is_valid: bool,
}
231
/// Point-in-time snapshot of an `ErrorRateController`'s counters and
/// collision metrics.
#[derive(Clone, Debug)]
pub struct ErrorRateStats {
    /// Configured target false-positive rate.
    pub target_error_rate: f64,
    /// Observed false-positive rate at snapshot time.
    pub current_error_rate: f64,
    /// Number of recorded queries.
    pub query_count: u64,
    /// Number of recorded false positives.
    pub false_positive_count: u64,
    /// Number of slots holding more than one entry.
    pub collision_count: usize,
    /// Longest chain (entries in one slot).
    pub max_chain_length: usize,
}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// 100 clean queries plus one false positive give a rate of 1/101, within 1.5x of 1%.
    #[test]
    fn test_error_rate_controller() {
        let mut controller = ErrorRateController::new(0.01).unwrap();

        (0..100).for_each(|_| controller.record_query(false));
        controller.record_query(true);

        let expected_rate = 1.0 / 101.0;
        assert_eq!(controller.query_count, 101);
        assert_eq!(controller.false_positive_count, 1);
        assert!((controller.current_error_rate() - expected_rate).abs() < 1e-10);
        assert!(controller.is_error_rate_acceptable());
    }

    /// Two fingerprints share slot 0 and one sits alone in slot 1.
    #[test]
    fn test_collision_tracker() {
        let mut tracker = CollisionTracker::new();
        let placements: [(u64, usize); 3] = [(0x1234, 0), (0x5678, 0), (0x9ABC, 1)];
        for &(fingerprint, slot) in placements.iter() {
            tracker.record_collision(fingerprint, slot);
        }

        assert_eq!(tracker.collision_count(), 1);
        assert_eq!(tracker.max_chain_length(), 2);
        assert_eq!(tracker.get_colliding_fingerprints(0x1234), vec![0x5678]);
    }

    /// A single two-entry chain with a low error rate satisfies both limits.
    #[test]
    fn test_strong_maplet_validator() {
        let validator = StrongMapletValidator::new(5, 0.01);
        let mut tracker = CollisionTracker::new();
        for &fingerprint in [0x1234u64, 0x5678].iter() {
            tracker.record_collision(fingerprint, 0);
        }

        let validation = validator
            .validate_strong_property::<u64, crate::operators::CounterOperator>(&tracker, 0.005)
            .unwrap();

        assert!(validation.chain_length_ok);
        assert!(validation.error_rate_ok);
        assert!(validation.is_valid);
        assert_eq!(validation.max_chain_length, 2);
        assert_eq!(validation.collision_count, 1);
    }
}
301