1use std::collections::{HashMap, HashSet};
2use std::str::FromStr;
3
4use uuid::Uuid;
5
/// Granularity at which a `ReadSet` is interpreted when deciding whether a
/// change affects it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TrackingMode {
    /// Rendered as `"none"`. No mode-specific handling is visible in this
    /// file; presumably disables tracking (confirm against callers).
    None,
    /// Whole-table granularity: `ReadSet::includes_row` answers `true` for
    /// every row of a tracked table.
    Table,
    /// Row granularity: `Change::invalidates` consults the per-table row
    /// filter for updates and deletes that carry a row id.
    Row,
    /// The default mode. Not handled specially in this file; presumably the
    /// choice between table and row tracking is made elsewhere (confirm).
    #[default]
    Adaptive,
}

impl TrackingMode {
    /// Returns the canonical lowercase name of the mode.
    ///
    /// The returned names are exactly the strings accepted by the `FromStr`
    /// implementation for this type.
    pub fn as_str(&self) -> &'static str {
        match *self {
            TrackingMode::Adaptive => "adaptive",
            TrackingMode::Row => "row",
            TrackingMode::Table => "table",
            TrackingMode::None => "none",
        }
    }
}
31
/// Error produced when a string does not name a tracking mode.
///
/// The wrapped `String` is the rejected input, kept verbatim so it can be
/// echoed back in the error message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseTrackingModeError(pub String);

impl std::fmt::Display for ParseTrackingModeError {
    /// Writes `invalid tracking mode: ` followed by the rejected input.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(rejected) = self;
        f.write_fmt(format_args!("invalid tracking mode: {}", rejected))
    }
}

impl std::error::Error for ParseTrackingModeError {}
42
43impl FromStr for TrackingMode {
44 type Err = ParseTrackingModeError;
45
46 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
47 match s.to_lowercase().as_str() {
48 "none" => Ok(Self::None),
49 "table" => Ok(Self::Table),
50 "row" => Ok(Self::Row),
51 "adaptive" => Ok(Self::Adaptive),
52 _ => Err(ParseTrackingModeError(s.to_string())),
53 }
54 }
55}
56
57#[derive(Debug, Clone)]
61pub struct BloomFilter {
62 bits: Vec<u64>,
63 num_hashes: u32,
64 num_bits: u64,
65}
66
67impl BloomFilter {
68 pub fn new(expected_items: usize) -> Self {
70 let num_bits = (expected_items as u64 * 10).max(64);
72 let num_words = num_bits.div_ceil(64) as usize;
73 let num_hashes = 7;
75
76 Self {
77 bits: vec![0u64; num_words],
78 num_hashes,
79 num_bits,
80 }
81 }
82
83 pub fn insert(&mut self, item: Uuid) {
85 let bytes = item.as_bytes();
86 for i in 0..self.num_hashes {
87 let idx = self.hash(bytes, i);
88 let word = (idx / 64) as usize;
89 let bit = idx % 64;
90 if let Some(w) = self.bits.get_mut(word) {
91 *w |= 1u64 << bit;
92 }
93 }
94 }
95
96 pub fn might_contain(&self, item: Uuid) -> bool {
98 let bytes = item.as_bytes();
99 for i in 0..self.num_hashes {
100 let idx = self.hash(bytes, i);
101 let word = (idx / 64) as usize;
102 let bit = idx % 64;
103 match self.bits.get(word) {
104 Some(w) if (w >> bit) & 1 == 1 => continue,
105 _ => return false,
106 }
107 }
108 true
109 }
110
111 fn hash(&self, bytes: &[u8; 16], seed: u32) -> u64 {
112 let h1 = u64::from_le_bytes([
114 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
115 ]);
116 let h2 = u64::from_le_bytes([
117 bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
118 ]);
119 h1.wrapping_add((seed as u64).wrapping_mul(h2)) % self.num_bits
120 }
121
122 pub fn memory_bytes(&self) -> usize {
124 self.bits.len() * 8 + 16
125 }
126}
127
128#[derive(Debug, Clone, Default)]
130pub struct ReadSet {
131 pub tables: Vec<String>,
133 pub row_filter: HashMap<String, BloomFilter>,
135 pub row_counts: HashMap<String, usize>,
137 pub filter_columns: HashMap<String, HashSet<String>>,
139 pub mode: TrackingMode,
141}
142
143impl ReadSet {
144 pub fn new() -> Self {
146 Self::default()
147 }
148
149 pub fn table_level() -> Self {
151 Self {
152 mode: TrackingMode::Table,
153 ..Default::default()
154 }
155 }
156
157 pub fn row_level() -> Self {
159 Self {
160 mode: TrackingMode::Row,
161 ..Default::default()
162 }
163 }
164
165 pub fn add_table(&mut self, table: impl Into<String>) {
167 let table = table.into();
168 if !self.tables.contains(&table) {
169 self.tables.push(table);
170 }
171 }
172
173 pub fn add_row(&mut self, table: impl Into<String>, row_id: Uuid) {
175 let table = table.into();
176 if !self.tables.contains(&table) {
177 self.tables.push(table.clone());
178 }
179 let filter = self
180 .row_filter
181 .entry(table.clone())
182 .or_insert_with(|| BloomFilter::new(1000));
183 filter.insert(row_id);
184 *self.row_counts.entry(table).or_insert(0) += 1;
185 }
186
187 pub fn add_filter_column(&mut self, table: impl Into<String>, column: impl Into<String>) {
189 self.filter_columns
190 .entry(table.into())
191 .or_default()
192 .insert(column.into());
193 }
194
195 pub fn includes_table(&self, table: &str) -> bool {
197 self.tables.iter().any(|t| t == table)
198 }
199
200 pub fn includes_row(&self, table: &str, row_id: Uuid) -> bool {
202 if !self.includes_table(table) {
203 return false;
204 }
205
206 if self.mode == TrackingMode::Table {
207 return true;
208 }
209
210 if let Some(filter) = self.row_filter.get(table) {
211 filter.might_contain(row_id)
212 } else {
213 true
215 }
216 }
217
218 pub fn memory_bytes(&self) -> usize {
220 let table_bytes = self.tables.iter().map(|s| s.len() + 24).sum::<usize>();
221 let filter_bytes: usize = self.row_filter.values().map(|f| f.memory_bytes()).sum();
222 let col_bytes = self
223 .filter_columns
224 .values()
225 .map(|set| set.iter().map(|s| s.len() + 24).sum::<usize>())
226 .sum::<usize>();
227
228 table_bytes + filter_bytes + col_bytes + 64
229 }
230
231 pub fn row_count(&self) -> usize {
233 self.row_counts.values().sum()
234 }
235
236 pub fn merge(&mut self, other: &ReadSet) {
238 for table in &other.tables {
239 if !self.tables.contains(table) {
240 self.tables.push(table.clone());
241 }
242 }
243
244 for (table, columns) in &other.filter_columns {
245 self.filter_columns
246 .entry(table.clone())
247 .or_default()
248 .extend(columns.iter().cloned());
249 }
250 }
251}
252
/// Kind of data modification a `Change` describes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChangeOperation {
    /// A row was inserted.
    Insert,
    /// A row was updated.
    Update,
    /// A row was deleted.
    Delete,
}

impl ChangeOperation {
    /// Returns the uppercase name of the operation (`INSERT`, `UPDATE` or
    /// `DELETE`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            ChangeOperation::Delete => "DELETE",
            ChangeOperation::Update => "UPDATE",
            ChangeOperation::Insert => "INSERT",
        }
    }
}
274
/// Error produced when a string does not name a change operation.
///
/// The wrapped `String` is the rejected input, kept verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseChangeOperationError(pub String);

impl std::fmt::Display for ParseChangeOperationError {
    /// Writes `invalid change operation: ` followed by the rejected input.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(rejected) = self;
        f.write_fmt(format_args!("invalid change operation: {}", rejected))
    }
}

impl std::error::Error for ParseChangeOperationError {}
285
286impl FromStr for ChangeOperation {
287 type Err = ParseChangeOperationError;
288
289 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
290 match s.to_uppercase().as_str() {
291 "INSERT" | "I" => Ok(Self::Insert),
292 "UPDATE" | "U" => Ok(Self::Update),
293 "DELETE" | "D" => Ok(Self::Delete),
294 _ => Err(ParseChangeOperationError(s.to_string())),
295 }
296 }
297}
298
299#[derive(Debug, Clone)]
301pub struct Change {
302 pub table: String,
304 pub operation: ChangeOperation,
306 pub row_id: Option<Uuid>,
308 pub changed_columns: Vec<String>,
310}
311
312impl Change {
313 pub fn new(table: impl Into<String>, operation: ChangeOperation) -> Self {
315 Self {
316 table: table.into(),
317 operation,
318 row_id: None,
319 changed_columns: Vec::new(),
320 }
321 }
322
323 pub fn with_row_id(mut self, row_id: Uuid) -> Self {
325 self.row_id = Some(row_id);
326 self
327 }
328
329 pub fn with_columns(mut self, columns: Vec<String>) -> Self {
331 self.changed_columns = columns;
332 self
333 }
334
335 pub fn invalidates(&self, read_set: &ReadSet) -> bool {
338 if !read_set.includes_table(&self.table) {
339 return false;
340 }
341
342 if read_set.mode == TrackingMode::Row
344 && let Some(row_id) = self.row_id
345 {
346 match self.operation {
347 ChangeOperation::Update | ChangeOperation::Delete => {
348 return read_set.includes_row(&self.table, row_id);
349 }
350 ChangeOperation::Insert => {}
352 }
353 }
354
355 true
356 }
357
358 pub fn invalidates_columns(&self, selected_columns: &[&str]) -> bool {
361 if self.changed_columns.is_empty() || selected_columns.is_empty() {
363 return true;
364 }
365
366 if self.operation != ChangeOperation::Update {
368 return true;
369 }
370
371 self.changed_columns
372 .iter()
373 .any(|c| selected_columns.contains(&c.as_str()))
374 }
375}
376
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Known mode names parse to their variants; unknown names are rejected.
    #[test]
    fn test_tracking_mode_conversion() {
        let accepted = [
            ("table", TrackingMode::Table),
            ("row", TrackingMode::Row),
            ("adaptive", TrackingMode::Adaptive),
        ];
        for (text, mode) in accepted {
            assert_eq!(text.parse::<TrackingMode>(), Ok(mode));
        }
        assert!("invalid".parse::<TrackingMode>().is_err());
    }

    /// Only tables that were added are reported as included.
    #[test]
    fn test_read_set_add_table() {
        let mut reads = ReadSet::new();
        reads.add_table("projects");

        assert!(reads.includes_table("projects"));
        assert!(!reads.includes_table("users"));
    }

    /// Adding a row tracks both the table and the row.
    #[test]
    fn test_read_set_add_row() {
        let tracked = Uuid::new_v4();
        let mut reads = ReadSet::row_level();
        reads.add_row("projects", tracked);

        assert!(reads.includes_table("projects"));
        assert!(reads.includes_row("projects", tracked));
    }

    /// Every inserted id must be reported as possibly present.
    #[test]
    fn test_bloom_filter_no_false_negatives() {
        let ids: Vec<Uuid> = std::iter::repeat_with(Uuid::new_v4).take(100).collect();
        let mut filter = BloomFilter::new(100);
        ids.iter().for_each(|id| filter.insert(*id));

        for id in ids {
            assert!(
                filter.might_contain(id),
                "bloom filter should never miss an inserted item"
            );
        }
    }

    /// With 1000 inserted ids, fewer than 200 of 10000 fresh ids may be
    /// reported as present.
    #[test]
    fn test_bloom_filter_false_positive_rate() {
        let mut filter = BloomFilter::new(1000);
        for _ in 0..1000 {
            filter.insert(Uuid::new_v4());
        }

        let false_positives = std::iter::repeat_with(Uuid::new_v4)
            .take(10000)
            .filter(|id| filter.might_contain(*id))
            .count();

        assert!(
            false_positives < 200,
            "false positive rate too high: {}/10000",
            false_positives
        );
    }

    /// At table level, any change to a tracked table invalidates; changes to
    /// other tables do not.
    #[test]
    fn test_change_invalidates_table_level() {
        let mut reads = ReadSet::table_level();
        reads.add_table("projects");

        let on_tracked = Change::new("projects", ChangeOperation::Insert);
        let on_untracked = Change::new("users", ChangeOperation::Insert);

        assert!(on_tracked.invalidates(&reads));
        assert!(!on_untracked.invalidates(&reads));
    }

    /// At row level, an update to a tracked row and an insert of any row
    /// both invalidate.
    #[test]
    fn test_change_invalidates_row_level() {
        let tracked_id = Uuid::new_v4();
        let mut reads = ReadSet::row_level();
        reads.add_row("projects", tracked_id);

        let update_tracked =
            Change::new("projects", ChangeOperation::Update).with_row_id(tracked_id);
        assert!(update_tracked.invalidates(&reads));

        let insert_other =
            Change::new("projects", ChangeOperation::Insert).with_row_id(Uuid::new_v4());
        assert!(insert_other.invalidates(&reads));
    }

    /// Updates invalidate only overlapping column selections; an empty
    /// selection always invalidates.
    #[test]
    fn test_column_invalidation() {
        let changed = vec!["name".to_string(), "email".to_string()];
        let change = Change::new("users", ChangeOperation::Update).with_columns(changed);

        assert!(change.invalidates_columns(&["name", "age"]));
        assert!(!change.invalidates_columns(&["age", "phone"]));
        assert!(change.invalidates_columns(&[]));
    }

    /// Non-update operations invalidate regardless of the columns involved.
    #[test]
    fn test_column_invalidation_non_update() {
        let changed = vec!["name".to_string()];
        let change = Change::new("users", ChangeOperation::Insert).with_columns(changed);
        assert!(change.invalidates_columns(&["age"]));
    }

    /// Merging brings the other read set's tables into the receiver.
    #[test]
    fn test_read_set_merge() {
        let mut target = ReadSet::new();
        target.add_table("projects");

        let mut source = ReadSet::new();
        source.add_table("users");

        target.merge(&source);

        for table in ["projects", "users"] {
            assert!(target.includes_table(table));
        }
    }
}