1use std::collections::{HashMap, HashSet};
2use std::str::FromStr;
3
4use uuid::Uuid;
5
/// Granularity setting for read tracking.
///
/// This type only defines the available modes, their canonical lowercase
/// names ([`TrackingMode::as_str`]) and how they are parsed (via [`FromStr`]).
/// What each mode means in practice is decided by the code that consumes it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TrackingMode {
    /// Canonical name `"none"`.
    ///
    /// NOTE(review): presumably disables tracking — confirm against the consumer.
    None,
    /// Canonical name `"table"`.
    Table,
    /// Canonical name `"row"`.
    Row,
    /// Canonical name `"adaptive"`. This is the [`Default`] mode.
    #[default]
    Adaptive,
}

impl TrackingMode {
    /// Returns the canonical lowercase name of this mode.
    ///
    /// The returned name parses back to the same variant through [`FromStr`].
    pub fn as_str(&self) -> &'static str {
        match *self {
            TrackingMode::None => "none",
            TrackingMode::Table => "table",
            TrackingMode::Row => "row",
            TrackingMode::Adaptive => "adaptive",
        }
    }
}

/// Error returned when a string does not name any [`TrackingMode`].
///
/// The payload is the rejected input, exactly as it was handed to the parser.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseTrackingModeError(pub String);

impl std::fmt::Display for ParseTrackingModeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "invalid tracking mode: {}", self.0)
    }
}

impl std::error::Error for ParseTrackingModeError {}

impl FromStr for TrackingMode {
    type Err = ParseTrackingModeError;

    /// Parses a mode name, ignoring ASCII case.
    ///
    /// Accepted names are exactly the ones produced by
    /// [`TrackingMode::as_str`]. The input is not trimmed, so surrounding
    /// whitespace makes it invalid.
    ///
    /// # Errors
    ///
    /// Returns [`ParseTrackingModeError`] carrying the original input when it
    /// matches no mode name.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        // Every variant, so parsing is driven by `as_str` and the two can
        // never disagree about a name.
        const ALL: [TrackingMode; 4] = [
            TrackingMode::None,
            TrackingMode::Table,
            TrackingMode::Row,
            TrackingMode::Adaptive,
        ];

        // `eq_ignore_ascii_case` compares in place, so no lowercased copy of
        // the input has to be allocated.
        ALL.iter()
            .copied()
            .find(|mode| s.eq_ignore_ascii_case(mode.as_str()))
            .ok_or_else(|| ParseTrackingModeError(s.to_string()))
    }
}
57
58#[derive(Debug, Clone, Default)]
60pub struct ReadSet {
61 pub tables: HashSet<String>,
63 pub rows: HashMap<String, HashSet<Uuid>>,
65 pub filter_columns: HashMap<String, HashSet<String>>,
67 pub mode: TrackingMode,
69}
70
71impl ReadSet {
72 pub fn new() -> Self {
74 Self::default()
75 }
76
77 pub fn table_level() -> Self {
79 Self {
80 mode: TrackingMode::Table,
81 ..Default::default()
82 }
83 }
84
85 pub fn row_level() -> Self {
87 Self {
88 mode: TrackingMode::Row,
89 ..Default::default()
90 }
91 }
92
93 pub fn add_table(&mut self, table: impl Into<String>) {
95 self.tables.insert(table.into());
96 }
97
98 pub fn add_row(&mut self, table: impl Into<String>, row_id: Uuid) {
100 let table = table.into();
101 self.tables.insert(table.clone());
102 self.rows.entry(table).or_default().insert(row_id);
103 }
104
105 pub fn add_filter_column(&mut self, table: impl Into<String>, column: impl Into<String>) {
107 self.filter_columns
108 .entry(table.into())
109 .or_default()
110 .insert(column.into());
111 }
112
113 pub fn includes_table(&self, table: &str) -> bool {
115 self.tables.contains(table)
116 }
117
118 pub fn includes_row(&self, table: &str, row_id: Uuid) -> bool {
120 if !self.tables.contains(table) {
121 return false;
122 }
123
124 if self.mode == TrackingMode::Table {
126 return true;
127 }
128
129 if let Some(rows) = self.rows.get(table) {
131 rows.contains(&row_id)
132 } else {
133 true
135 }
136 }
137
138 pub fn memory_bytes(&self) -> usize {
140 let table_bytes = self.tables.iter().map(|s| s.len() + 24).sum::<usize>();
141 let row_bytes = self
142 .rows
143 .values()
144 .map(|set| set.len() * 16 + 24)
145 .sum::<usize>();
146 let filter_bytes = self
147 .filter_columns
148 .values()
149 .map(|set| set.iter().map(|s| s.len() + 24).sum::<usize>())
150 .sum::<usize>();
151
152 table_bytes + row_bytes + filter_bytes + 64 }
154
155 pub fn row_count(&self) -> usize {
157 self.rows.values().map(|set| set.len()).sum()
158 }
159
160 pub fn merge(&mut self, other: &ReadSet) {
162 self.tables.extend(other.tables.iter().cloned());
163
164 for (table, rows) in &other.rows {
165 self.rows
166 .entry(table.clone())
167 .or_default()
168 .extend(rows.iter().cloned());
169 }
170
171 for (table, columns) in &other.filter_columns {
172 self.filter_columns
173 .entry(table.clone())
174 .or_default()
175 .extend(columns.iter().cloned());
176 }
177 }
178}
179
/// Kind of data modification described by a change.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChangeOperation {
    /// Canonical name `"INSERT"`, short form `"I"`.
    Insert,
    /// Canonical name `"UPDATE"`, short form `"U"`.
    Update,
    /// Canonical name `"DELETE"`, short form `"D"`.
    Delete,
}

impl ChangeOperation {
    /// Returns the canonical uppercase name of this operation.
    pub fn as_str(&self) -> &'static str {
        match *self {
            ChangeOperation::Insert => "INSERT",
            ChangeOperation::Update => "UPDATE",
            ChangeOperation::Delete => "DELETE",
        }
    }
}

/// Error returned when a string does not name any [`ChangeOperation`].
///
/// The payload is the rejected input, exactly as it was handed to the parser.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseChangeOperationError(pub String);

impl std::fmt::Display for ParseChangeOperationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "invalid change operation: {}", self.0)
    }
}

impl std::error::Error for ParseChangeOperationError {}

impl FromStr for ChangeOperation {
    type Err = ParseChangeOperationError;

    /// Parses an operation from its full name (`INSERT`, `UPDATE`, `DELETE`)
    /// or its one-letter short form (`I`, `U`, `D`), ignoring ASCII case.
    ///
    /// Matching is ASCII-only on purpose: Unicode upper-casing would let
    /// look-alike input such as `"ı"` (U+0131, which upper-cases to `I`) pass
    /// as a valid operation. The input is not trimmed.
    ///
    /// # Errors
    ///
    /// Returns [`ParseChangeOperationError`] carrying the original input when
    /// it matches no accepted spelling.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        // (full name, short form, operation)
        const SPELLINGS: [(&str, &str, ChangeOperation); 3] = [
            ("INSERT", "I", ChangeOperation::Insert),
            ("UPDATE", "U", ChangeOperation::Update),
            ("DELETE", "D", ChangeOperation::Delete),
        ];

        // Compare in place instead of allocating an upper-cased copy of `s`.
        SPELLINGS
            .iter()
            .find(|(full, short, _)| {
                s.eq_ignore_ascii_case(full) || s.eq_ignore_ascii_case(short)
            })
            .map(|&(_, _, operation)| operation)
            .ok_or_else(|| ParseChangeOperationError(s.to_string()))
    }
}
226
227#[derive(Debug, Clone)]
229pub struct Change {
230 pub table: String,
232 pub operation: ChangeOperation,
234 pub row_id: Option<Uuid>,
236 pub changed_columns: Vec<String>,
238}
239
240impl Change {
241 pub fn new(table: impl Into<String>, operation: ChangeOperation) -> Self {
243 Self {
244 table: table.into(),
245 operation,
246 row_id: None,
247 changed_columns: Vec::new(),
248 }
249 }
250
251 pub fn with_row_id(mut self, row_id: Uuid) -> Self {
253 self.row_id = Some(row_id);
254 self
255 }
256
257 pub fn with_columns(mut self, columns: Vec<String>) -> Self {
259 self.changed_columns = columns;
260 self
261 }
262
263 pub fn invalidates(&self, read_set: &ReadSet) -> bool {
265 if !read_set.includes_table(&self.table) {
267 return false;
268 }
269
270 if read_set.mode == TrackingMode::Row {
272 if let Some(row_id) = self.row_id {
273 match self.operation {
274 ChangeOperation::Update | ChangeOperation::Delete => {
276 return read_set.includes_row(&self.table, row_id);
277 }
278 ChangeOperation::Insert => {}
280 }
281 }
282 }
283
284 true
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Lowercase mode names parse to the matching variant; unknown names fail.
    #[test]
    fn test_tracking_mode_conversion() {
        let cases = [
            ("table", TrackingMode::Table),
            ("row", TrackingMode::Row),
            ("adaptive", TrackingMode::Adaptive),
        ];
        for &(text, expected) in &cases {
            assert_eq!(text.parse::<TrackingMode>(), Ok(expected));
        }

        assert!("invalid".parse::<TrackingMode>().is_err());
    }

    /// Only tables that were added are reported as included.
    #[test]
    fn test_read_set_add_table() {
        let mut set = ReadSet::new();
        set.add_table("projects");

        assert!(set.includes_table("projects"));
        assert!(!set.includes_table("users"));
    }

    /// Adding a row registers its table and only that row.
    #[test]
    fn test_read_set_add_row() {
        let mut set = ReadSet::row_level();
        let tracked = Uuid::new_v4();
        set.add_row("projects", tracked);

        assert!(set.includes_table("projects"));
        assert!(set.includes_row("projects", tracked));

        let untracked = Uuid::new_v4();
        assert!(!set.includes_row("projects", untracked));
    }

    /// In table mode, any change to a tracked table invalidates.
    #[test]
    fn test_change_invalidates_table_level() {
        let mut set = ReadSet::table_level();
        set.add_table("projects");

        let on_tracked_table = Change::new("projects", ChangeOperation::Insert);
        assert!(on_tracked_table.invalidates(&set));

        let on_other_table = Change::new("users", ChangeOperation::Insert);
        assert!(!on_other_table.invalidates(&set));
    }

    /// In row mode, updates invalidate only for tracked rows; inserts always do.
    #[test]
    fn test_change_invalidates_row_level() {
        let tracked = Uuid::new_v4();
        let untracked = Uuid::new_v4();
        let mut set = ReadSet::row_level();
        set.add_row("projects", tracked);

        let update_tracked = Change::new("projects", ChangeOperation::Update).with_row_id(tracked);
        assert!(update_tracked.invalidates(&set));

        let update_untracked =
            Change::new("projects", ChangeOperation::Update).with_row_id(untracked);
        assert!(!update_untracked.invalidates(&set));

        let insert_untracked =
            Change::new("projects", ChangeOperation::Insert).with_row_id(untracked);
        assert!(insert_untracked.invalidates(&set));
    }

    /// Merging brings the other set's tables into the receiver.
    #[test]
    fn test_read_set_merge() {
        let mut target = ReadSet::new();
        target.add_table("projects");

        let mut source = ReadSet::new();
        source.add_table("users");

        target.merge(&source);

        for table in ["projects", "users"].iter() {
            assert!(target.includes_table(table));
        }
    }
}