1use std::collections::{HashMap, HashSet};
2use std::str::FromStr;
3
4use uuid::Uuid;
5
/// Granularity at which reads are tracked.
///
/// The textual forms are `"none"` and `"table"` (see [`TrackingMode::as_str`]);
/// `Table` is the [`Default`] variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum TrackingMode {
    /// Textual form `"none"`.
    None,
    /// Textual form `"table"`; the default mode.
    #[default]
    Table,
}

impl TrackingMode {
    /// Returns the lowercase name of this mode; the [`FromStr`] impl accepts it back.
    pub fn as_str(&self) -> &'static str {
        match *self {
            TrackingMode::Table => "table",
            TrackingMode::None => "none",
        }
    }
}

/// Error produced when a string does not name any [`TrackingMode`].
///
/// The wrapped `String` is the rejected input, unchanged (original case kept).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseTrackingModeError(pub String);

impl std::fmt::Display for ParseTrackingModeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let input = &self.0;
        write!(f, "invalid tracking mode: {input}")
    }
}

impl std::error::Error for ParseTrackingModeError {}

impl FromStr for TrackingMode {
    type Err = ParseTrackingModeError;

    /// Parses a mode name case-insensitively (`"TABLE"`, `"Table"` and `"table"` all match).
    ///
    /// # Errors
    ///
    /// Returns [`ParseTrackingModeError`] holding the original input when it
    /// matches no mode name.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        // Every accepted spelling is a variant's `as_str` name, so match against those.
        [Self::None, Self::Table]
            .iter()
            .copied()
            .find(|mode| s.eq_ignore_ascii_case(mode.as_str()))
            .ok_or_else(|| ParseTrackingModeError(s.to_owned()))
    }
}
49
50#[derive(Debug, Clone, Default)]
52pub struct ReadSet {
53 pub tables: Vec<String>,
54 pub filter_columns: HashMap<String, HashSet<String>>,
55 pub mode: TrackingMode,
56}
57
58impl ReadSet {
59 pub fn new() -> Self {
60 Self::default()
61 }
62
63 pub fn table_level() -> Self {
65 Self {
66 mode: TrackingMode::Table,
67 ..Default::default()
68 }
69 }
70
71 pub fn add_table(&mut self, table: impl Into<String>) {
72 let table = table.into();
73 if !self.tables.contains(&table) {
74 self.tables.push(table);
75 }
76 }
77
78 pub fn add_filter_column(&mut self, table: impl Into<String>, column: impl Into<String>) {
79 self.filter_columns
80 .entry(table.into())
81 .or_default()
82 .insert(column.into());
83 }
84
85 pub fn includes_table(&self, table: &str) -> bool {
86 self.tables.iter().any(|t| t == table)
87 }
88
89 pub fn memory_bytes(&self) -> usize {
90 let table_bytes = self.tables.iter().map(|s| s.len() + 24).sum::<usize>();
91 let col_bytes = self
92 .filter_columns
93 .values()
94 .map(|set| set.iter().map(|s| s.len() + 24).sum::<usize>())
95 .sum::<usize>();
96
97 table_bytes + col_bytes + 64
98 }
99
100 pub fn merge(&mut self, other: &ReadSet) {
101 for table in &other.tables {
102 if !self.tables.contains(table) {
103 self.tables.push(table.clone());
104 }
105 }
106
107 for (table, columns) in &other.filter_columns {
108 self.filter_columns
109 .entry(table.clone())
110 .or_default()
111 .extend(columns.iter().cloned());
112 }
113 }
114}
115
/// Kind of modification carried by a [`Change`].
///
/// [`ChangeOperation::as_str`] renders it as an upper-case keyword. Parsing
/// accepts that keyword or its one-letter code, ignoring ASCII case.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ChangeOperation {
    /// `"INSERT"`, short code `"I"`.
    Insert,
    /// `"UPDATE"`, short code `"U"`.
    Update,
    /// `"DELETE"`, short code `"D"`.
    Delete,
}

impl ChangeOperation {
    /// Returns the upper-case keyword for this operation (`"INSERT"`, `"UPDATE"` or `"DELETE"`).
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Insert => "INSERT",
            Self::Update => "UPDATE",
            Self::Delete => "DELETE",
        }
    }
}

/// Error returned when a string is not a recognised [`ChangeOperation`].
///
/// The wrapped `String` is the rejected input, unchanged.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseChangeOperationError(pub String);

impl std::fmt::Display for ParseChangeOperationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "invalid change operation: {}", self.0)
    }
}

impl std::error::Error for ParseChangeOperationError {}

impl FromStr for ChangeOperation {
    type Err = ParseChangeOperationError;

    /// Parses `INSERT`/`UPDATE`/`DELETE` or the codes `I`/`U`/`D`, ignoring ASCII case.
    ///
    /// # Errors
    ///
    /// Returns [`ParseChangeOperationError`] carrying the original input for
    /// anything else. That includes non-ASCII look-alikes such as `"ı"`
    /// (U+0131 dotless i), which Unicode upper-casing would fold to `"I"`.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        // ASCII-only case folding needs no allocation. Unlike `str::to_uppercase`,
        // it cannot map non-ASCII letters (e.g. U+0131, U+017F) onto the keywords.
        let is = |keyword: &str, code: &str| {
            s.eq_ignore_ascii_case(keyword) || s.eq_ignore_ascii_case(code)
        };

        if is("INSERT", "I") {
            Ok(Self::Insert)
        } else if is("UPDATE", "U") {
            Ok(Self::Update)
        } else if is("DELETE", "D") {
            Ok(Self::Delete)
        } else {
            Err(ParseChangeOperationError(s.to_string()))
        }
    }
}
158
159#[derive(Debug, Clone)]
161pub struct Change {
162 pub table: String,
163 pub operation: ChangeOperation,
164 pub row_id: Option<Uuid>,
165 pub changed_columns: Vec<String>,
167}
168
169impl Change {
170 pub fn new(table: impl Into<String>, operation: ChangeOperation) -> Self {
171 Self {
172 table: table.into(),
173 operation,
174 row_id: None,
175 changed_columns: Vec::new(),
176 }
177 }
178
179 pub fn with_row_id(mut self, row_id: Uuid) -> Self {
180 self.row_id = Some(row_id);
181 self
182 }
183
184 pub fn with_columns(mut self, columns: Vec<String>) -> Self {
185 self.changed_columns = columns;
186 self
187 }
188
189 pub fn invalidates(&self, read_set: &ReadSet) -> bool {
192 read_set.includes_table(&self.table)
193 }
194
195 pub fn invalidates_columns(&self, selected_columns: &[&str]) -> bool {
198 if self.changed_columns.is_empty() || selected_columns.is_empty() {
200 return true;
201 }
202
203 if self.operation != ChangeOperation::Update {
205 return true;
206 }
207
208 self.changed_columns
209 .iter()
210 .any(|c| selected_columns.contains(&c.as_str()))
211 }
212}
213
#[cfg(test)]
// Tests assert on known-good fixtures, so panicking on a missing value is the desired failure mode.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    #[test]
    fn test_tracking_mode_conversion() {
        let parsed: Result<TrackingMode, _> = "table".parse();
        assert_eq!(parsed, Ok(TrackingMode::Table));
        assert!(TrackingMode::from_str("invalid").is_err());
    }

    #[test]
    fn test_read_set_add_table() {
        let mut rs = ReadSet::new();
        rs.add_table("projects");

        assert!(rs.includes_table("projects"));
        assert!(!rs.includes_table("users"));
    }

    #[test]
    fn test_change_invalidates_table_level() {
        let mut rs = ReadSet::table_level();
        rs.add_table("projects");

        let insert_into = |table: &str| Change::new(table, ChangeOperation::Insert);
        assert!(insert_into("projects").invalidates(&rs));
        assert!(!insert_into("users").invalidates(&rs));
    }

    #[test]
    fn test_column_invalidation() {
        let columns = ["name", "email"].iter().map(|c| c.to_string()).collect();
        let change = Change::new("users", ChangeOperation::Update).with_columns(columns);

        assert!(change.invalidates_columns(&["name", "age"]));
        assert!(!change.invalidates_columns(&["age", "phone"]));
        assert!(change.invalidates_columns(&[]));
    }

    #[test]
    fn test_column_invalidation_non_update() {
        let change = Change::new("users", ChangeOperation::Insert)
            .with_columns(vec![String::from("name")]);
        assert!(change.invalidates_columns(&["age"]));
    }

    #[test]
    fn test_read_set_merge() {
        let mut left = ReadSet::new();
        left.add_table("projects");
        let mut right = ReadSet::new();
        right.add_table("users");

        left.merge(&right);

        for table in ["projects", "users"] {
            assert!(left.includes_table(table), "missing table: {table}");
        }
    }

    #[test]
    fn tracking_mode_default_is_table() {
        let default_mode: TrackingMode = Default::default();
        assert_eq!(default_mode, TrackingMode::Table);
    }

    #[test]
    fn tracking_mode_as_str_round_trips() {
        let modes = [TrackingMode::None, TrackingMode::Table];
        for mode in modes.iter().copied() {
            assert_eq!(mode.as_str().parse::<TrackingMode>(), Ok(mode));
        }
    }

    #[test]
    fn tracking_mode_parse_is_case_insensitive() {
        let cases = [
            ("NONE", TrackingMode::None),
            ("Table", TrackingMode::Table),
            ("TaBlE", TrackingMode::Table),
        ];
        for (input, expected) in cases {
            assert_eq!(input.parse::<TrackingMode>(), Ok(expected), "input: {input}");
        }
    }

    #[test]
    fn tracking_mode_parse_error_preserves_original_input() {
        let err = match "row".parse::<TrackingMode>() {
            Err(e) => e,
            Ok(mode) => panic!("unexpectedly parsed {mode:?}"),
        };
        assert_eq!(err, ParseTrackingModeError(String::from("row")));
        assert_eq!(format!("{err}"), "invalid tracking mode: row");
    }

    #[test]
    fn add_table_is_idempotent() {
        let mut rs = ReadSet::new();
        for _ in 0..3 {
            rs.add_table("projects");
        }
        assert_eq!(rs.tables, vec![String::from("projects")]);
    }

    #[test]
    fn add_filter_column_accumulates_per_table() {
        let mut rs = ReadSet::new();
        let pairs = [
            ("users", "id"),
            ("users", "email"),
            ("users", "id"),
            ("projects", "owner_id"),
        ];
        for (table, column) in pairs {
            rs.add_filter_column(table, column);
        }

        let users = &rs.filter_columns["users"];
        assert_eq!(users.len(), 2);
        assert!(["id", "email"].iter().all(|c| users.contains(*c)));

        assert_eq!(rs.filter_columns["projects"].len(), 1);
    }

    #[test]
    fn memory_bytes_grows_with_content() {
        let baseline = ReadSet::new().memory_bytes();
        assert_eq!(baseline, 64);

        let mut rs = ReadSet::new();
        rs.add_table("users");
        rs.add_filter_column("users", "email");
        assert!(rs.memory_bytes() > baseline);
    }

    #[test]
    fn table_level_constructor_sets_mode() {
        let ReadSet {
            tables,
            filter_columns,
            mode,
        } = ReadSet::table_level();
        assert_eq!(mode, TrackingMode::Table);
        assert!(tables.is_empty());
        assert!(filter_columns.is_empty());
    }

    #[test]
    fn merge_dedups_tables_and_unions_filter_columns() {
        let mut a = ReadSet::new();
        a.add_table("users");
        a.add_filter_column("users", "id");

        let mut b = ReadSet::new();
        b.add_table("users");
        b.add_table("projects");
        b.add_filter_column("users", "email");
        b.add_filter_column("projects", "owner_id");

        a.merge(&b);

        assert_eq!(a.tables, vec!["users".to_string(), "projects".to_string()]);
        let expected: HashSet<String> = ["id", "email"].iter().map(|s| s.to_string()).collect();
        assert_eq!(a.filter_columns["users"], expected);
    }

    #[test]
    fn change_operation_as_str_round_trips() {
        let all = [
            ChangeOperation::Insert,
            ChangeOperation::Update,
            ChangeOperation::Delete,
        ];
        for op in all {
            assert_eq!(op.as_str().parse::<ChangeOperation>(), Ok(op));
        }
    }

    #[test]
    fn change_operation_accepts_short_codes_and_lowercase() {
        let cases = [
            ("i", ChangeOperation::Insert),
            ("U", ChangeOperation::Update),
            ("delete", ChangeOperation::Delete),
        ];
        for (input, expected) in cases {
            assert_eq!(input.parse::<ChangeOperation>(), Ok(expected), "input: {input}");
        }
    }

    #[test]
    fn change_operation_parse_error_preserves_input() {
        let err = ChangeOperation::from_str("TRUNCATE").unwrap_err();
        assert_eq!(err, ParseChangeOperationError("TRUNCATE".into()));
        assert_eq!(err.to_string(), "invalid change operation: TRUNCATE");
    }

    #[test]
    fn change_builders_populate_optional_fields() {
        let row = Uuid::new_v4();
        let change = Change::new("users", ChangeOperation::Update)
            .with_row_id(row)
            .with_columns(vec![String::from("email")]);

        assert_eq!(change.row_id, Some(row));
        assert_eq!(change.changed_columns, vec![String::from("email")]);
        assert_eq!(change.operation, ChangeOperation::Update);
    }

    #[test]
    fn column_invalidation_is_conservative_when_change_lacks_columns() {
        // An UPDATE with no recorded columns cannot be narrowed, so it must invalidate.
        let change = Change::new("users", ChangeOperation::Update);
        assert!(change.invalidates_columns(&["email"]));
    }
}