1use std::collections::{HashMap, HashSet};
2use std::str::FromStr;
3
4use uuid::Uuid;
5
/// Granularity at which a `ReadSet` records what a query read.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum TrackingMode {
    /// Parsed from `"none"`; presumably disables tracking (nothing in this file special-cases it).
    None,
    /// Parsed from `"table"`; every row of a recorded table counts as read.
    Table,
    /// Parsed from `"row"`; updates/deletes are narrowed to the recorded rows.
    Row,
    /// Parsed from `"adaptive"`; the default mode.
    #[default]
    Adaptive,
}

impl TrackingMode {
    /// Lower-case name of the mode, exactly as accepted by the `FromStr` impl.
    pub fn as_str(&self) -> &'static str {
        match *self {
            TrackingMode::None => "none",
            TrackingMode::Table => "table",
            TrackingMode::Row => "row",
            TrackingMode::Adaptive => "adaptive",
        }
    }
}

/// Error produced when a string names no `TrackingMode`; carries the rejected input verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseTrackingModeError(pub String);

impl std::fmt::Display for ParseTrackingModeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("invalid tracking mode: ")?;
        f.write_str(&self.0)
    }
}

impl std::error::Error for ParseTrackingModeError {}

impl FromStr for TrackingMode {
    type Err = ParseTrackingModeError;

    /// Parses a mode name, ignoring ASCII case. Surrounding whitespace is not trimmed.
    ///
    /// # Errors
    /// Returns [`ParseTrackingModeError`] holding `s` when it names no mode.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        const ALL_MODES: [TrackingMode; 4] = [
            TrackingMode::None,
            TrackingMode::Table,
            TrackingMode::Row,
            TrackingMode::Adaptive,
        ];

        ALL_MODES
            .iter()
            .copied()
            .find(|candidate| s.eq_ignore_ascii_case(candidate.as_str()))
            .ok_or_else(|| ParseTrackingModeError(s.to_string()))
    }
}
56
57#[derive(Debug, Clone, Default)]
59pub struct ReadSet {
60 pub tables: HashSet<String>,
62 pub rows: HashMap<String, HashSet<Uuid>>,
64 pub filter_columns: HashMap<String, HashSet<String>>,
66 pub mode: TrackingMode,
68}
69
70impl ReadSet {
71 pub fn new() -> Self {
73 Self::default()
74 }
75
76 pub fn table_level() -> Self {
78 Self {
79 mode: TrackingMode::Table,
80 ..Default::default()
81 }
82 }
83
84 pub fn row_level() -> Self {
86 Self {
87 mode: TrackingMode::Row,
88 ..Default::default()
89 }
90 }
91
92 pub fn add_table(&mut self, table: impl Into<String>) {
94 self.tables.insert(table.into());
95 }
96
97 pub fn add_row(&mut self, table: impl Into<String>, row_id: Uuid) {
99 let table = table.into();
100 self.tables.insert(table.clone());
101 self.rows.entry(table).or_default().insert(row_id);
102 }
103
104 pub fn add_filter_column(&mut self, table: impl Into<String>, column: impl Into<String>) {
106 self.filter_columns
107 .entry(table.into())
108 .or_default()
109 .insert(column.into());
110 }
111
112 pub fn includes_table(&self, table: &str) -> bool {
114 self.tables.contains(table)
115 }
116
117 pub fn includes_row(&self, table: &str, row_id: Uuid) -> bool {
119 if !self.tables.contains(table) {
120 return false;
121 }
122
123 if self.mode == TrackingMode::Table {
125 return true;
126 }
127
128 if let Some(rows) = self.rows.get(table) {
130 rows.contains(&row_id)
131 } else {
132 true
134 }
135 }
136
137 pub fn memory_bytes(&self) -> usize {
139 let table_bytes = self.tables.iter().map(|s| s.len() + 24).sum::<usize>();
140 let row_bytes = self
141 .rows
142 .values()
143 .map(|set| set.len() * 16 + 24)
144 .sum::<usize>();
145 let filter_bytes = self
146 .filter_columns
147 .values()
148 .map(|set| set.iter().map(|s| s.len() + 24).sum::<usize>())
149 .sum::<usize>();
150
151 table_bytes + row_bytes + filter_bytes + 64 }
153
154 pub fn row_count(&self) -> usize {
156 self.rows.values().map(|set| set.len()).sum()
157 }
158
159 pub fn merge(&mut self, other: &ReadSet) {
161 self.tables.extend(other.tables.iter().cloned());
162
163 for (table, rows) in &other.rows {
164 self.rows
165 .entry(table.clone())
166 .or_default()
167 .extend(rows.iter().cloned());
168 }
169
170 for (table, columns) in &other.filter_columns {
171 self.filter_columns
172 .entry(table.clone())
173 .or_default()
174 .extend(columns.iter().cloned());
175 }
176 }
177}
178
/// Kind of write described by a `Change`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChangeOperation {
    /// Parsed from `"INSERT"` / `"I"`.
    Insert,
    /// Parsed from `"UPDATE"` / `"U"`.
    Update,
    /// Parsed from `"DELETE"` / `"D"`.
    Delete,
}

impl ChangeOperation {
    /// Upper-case keyword for the operation: `"INSERT"`, `"UPDATE"` or `"DELETE"`.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Insert => "INSERT",
            Self::Update => "UPDATE",
            Self::Delete => "DELETE",
        }
    }
}

/// Error produced when a string names no `ChangeOperation`; carries the rejected input verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseChangeOperationError(pub String);

impl std::fmt::Display for ParseChangeOperationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "invalid change operation: {}", self.0)
    }
}

impl std::error::Error for ParseChangeOperationError {}

impl FromStr for ChangeOperation {
    type Err = ParseChangeOperationError;

    /// Parses the full keyword or its one-letter abbreviation (`I`, `U`, `D`), ignoring ASCII
    /// case. Surrounding whitespace is not trimmed.
    ///
    /// # Errors
    /// Returns [`ParseChangeOperationError`] holding `s` when it names no operation.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        // ASCII-only case folding: no allocation, and unlike Unicode `to_uppercase` it does
        // not fold lookalikes such as 'ı' (U+0131) or 'ſ' (U+017F) onto 'I' / 'S'.
        let is = |long: &str, short: &str| {
            s.eq_ignore_ascii_case(long) || s.eq_ignore_ascii_case(short)
        };

        if is("INSERT", "I") {
            Ok(Self::Insert)
        } else if is("UPDATE", "U") {
            Ok(Self::Update)
        } else if is("DELETE", "D") {
            Ok(Self::Delete)
        } else {
            Err(ParseChangeOperationError(s.to_string()))
        }
    }
}
224
225#[derive(Debug, Clone)]
227pub struct Change {
228 pub table: String,
230 pub operation: ChangeOperation,
232 pub row_id: Option<Uuid>,
234 pub changed_columns: Vec<String>,
236}
237
238impl Change {
239 pub fn new(table: impl Into<String>, operation: ChangeOperation) -> Self {
241 Self {
242 table: table.into(),
243 operation,
244 row_id: None,
245 changed_columns: Vec::new(),
246 }
247 }
248
249 pub fn with_row_id(mut self, row_id: Uuid) -> Self {
251 self.row_id = Some(row_id);
252 self
253 }
254
255 pub fn with_columns(mut self, columns: Vec<String>) -> Self {
257 self.changed_columns = columns;
258 self
259 }
260
261 pub fn invalidates(&self, read_set: &ReadSet) -> bool {
263 if !read_set.includes_table(&self.table) {
265 return false;
266 }
267
268 if read_set.mode == TrackingMode::Row
270 && let Some(row_id) = self.row_id
271 {
272 match self.operation {
273 ChangeOperation::Update | ChangeOperation::Delete => {
275 return read_set.includes_row(&self.table, row_id);
276 }
277 ChangeOperation::Insert => {}
279 }
280 }
281
282 true
284 }
285}
286
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    #[test]
    fn test_tracking_mode_conversion() {
        assert_eq!("table".parse::<TrackingMode>(), Ok(TrackingMode::Table));
        assert_eq!("row".parse::<TrackingMode>(), Ok(TrackingMode::Row));
        assert_eq!(
            "adaptive".parse::<TrackingMode>(),
            Ok(TrackingMode::Adaptive)
        );
        assert!("invalid".parse::<TrackingMode>().is_err());
    }

    #[test]
    fn test_tracking_mode_round_trip_and_default() {
        for mode in [
            TrackingMode::None,
            TrackingMode::Table,
            TrackingMode::Row,
            TrackingMode::Adaptive,
        ] {
            assert_eq!(mode.as_str().parse::<TrackingMode>(), Ok(mode));
        }
        assert_eq!("ROW".parse::<TrackingMode>(), Ok(TrackingMode::Row));
        assert_eq!(TrackingMode::default(), TrackingMode::Adaptive);
        assert_eq!(ReadSet::new().mode, TrackingMode::Adaptive);
    }

    #[test]
    fn test_change_operation_parsing() {
        assert_eq!(
            "insert".parse::<ChangeOperation>(),
            Ok(ChangeOperation::Insert)
        );
        assert_eq!("U".parse::<ChangeOperation>(), Ok(ChangeOperation::Update));
        assert_eq!("d".parse::<ChangeOperation>(), Ok(ChangeOperation::Delete));
        for op in [
            ChangeOperation::Insert,
            ChangeOperation::Update,
            ChangeOperation::Delete,
        ] {
            assert_eq!(op.as_str().parse::<ChangeOperation>(), Ok(op));
        }
        let err = "upsert".parse::<ChangeOperation>().unwrap_err();
        assert_eq!(err.to_string(), "invalid change operation: upsert");
    }

    #[test]
    fn test_read_set_add_table() {
        let mut read_set = ReadSet::new();
        read_set.add_table("projects");

        assert!(read_set.includes_table("projects"));
        assert!(!read_set.includes_table("users"));
    }

    #[test]
    fn test_read_set_add_row() {
        let mut read_set = ReadSet::row_level();
        let row_id = Uuid::new_v4();
        read_set.add_row("projects", row_id);

        assert!(read_set.includes_table("projects"));
        assert!(read_set.includes_row("projects", row_id));
        assert!(!read_set.includes_row("projects", Uuid::new_v4()));
    }

    #[test]
    fn test_includes_row_table_level() {
        let mut read_set = ReadSet::table_level();
        read_set.add_table("projects");

        assert!(read_set.includes_row("projects", Uuid::new_v4()));
        assert!(!read_set.includes_row("users", Uuid::new_v4()));
    }

    #[test]
    fn test_change_invalidates_table_level() {
        let mut read_set = ReadSet::table_level();
        read_set.add_table("projects");

        let change = Change::new("projects", ChangeOperation::Insert);
        assert!(change.invalidates(&read_set));

        let change = Change::new("users", ChangeOperation::Insert);
        assert!(!change.invalidates(&read_set));
    }

    #[test]
    fn test_change_invalidates_row_level() {
        let mut read_set = ReadSet::row_level();
        let tracked_id = Uuid::new_v4();
        let other_id = Uuid::new_v4();
        read_set.add_row("projects", tracked_id);

        // Updating a row that was read invalidates.
        let change = Change::new("projects", ChangeOperation::Update).with_row_id(tracked_id);
        assert!(change.invalidates(&read_set));

        // Updating a row that was not read does not.
        let change = Change::new("projects", ChangeOperation::Update).with_row_id(other_id);
        assert!(!change.invalidates(&read_set));

        // Inserts always invalidate a table that was read.
        let change = Change::new("projects", ChangeOperation::Insert).with_row_id(other_id);
        assert!(change.invalidates(&read_set));
    }

    #[test]
    fn test_change_invalidates_row_level_delete_and_missing_row_id() {
        let mut read_set = ReadSet::row_level();
        let tracked_id = Uuid::new_v4();
        read_set.add_row("projects", tracked_id);

        let change = Change::new("projects", ChangeOperation::Delete).with_row_id(tracked_id);
        assert!(change.invalidates(&read_set));

        let change =
            Change::new("projects", ChangeOperation::Delete).with_row_id(Uuid::new_v4());
        assert!(!change.invalidates(&read_set));

        // Without a row id there is nothing to narrow by.
        let change = Change::new("projects", ChangeOperation::Update);
        assert!(change.invalidates(&read_set));
    }

    #[test]
    fn test_read_set_merge() {
        let mut read_set1 = ReadSet::new();
        read_set1.add_table("projects");

        let mut read_set2 = ReadSet::new();
        read_set2.add_table("users");

        read_set1.merge(&read_set2);

        assert!(read_set1.includes_table("projects"));
        assert!(read_set1.includes_table("users"));
    }

    #[test]
    fn test_read_set_merge_unions_rows_and_filter_columns() {
        let first_id = Uuid::new_v4();
        let second_id = Uuid::new_v4();

        let mut merged = ReadSet::row_level();
        merged.add_row("projects", first_id);
        merged.add_filter_column("projects", "owner_id");

        let mut other = ReadSet::row_level();
        other.add_row("projects", second_id);
        other.add_filter_column("projects", "status");

        merged.merge(&other);

        assert!(merged.includes_row("projects", first_id));
        assert!(merged.includes_row("projects", second_id));
        assert_eq!(merged.row_count(), 2);
        let columns = &merged.filter_columns["projects"];
        assert!(columns.contains("owner_id"));
        assert!(columns.contains("status"));
    }
}