Skip to main content

forge_core/realtime/
readset.rs

1use std::collections::{HashMap, HashSet};
2use std::str::FromStr;
3
4use uuid::Uuid;
5
/// Granularity at which reads are recorded in a [`ReadSet`].
///
/// [`TrackingMode::Table`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum TrackingMode {
    /// Tracking is disabled.
    None,
    /// Coarse-grained tracking: only tables are recorded.
    #[default]
    Table,
}

impl TrackingMode {
    /// Returns the lowercase name of this mode: `"none"` or `"table"`.
    ///
    /// The returned text is accepted back by the `FromStr` implementation.
    pub fn as_str(&self) -> &'static str {
        match *self {
            TrackingMode::None => "none",
            TrackingMode::Table => "table",
        }
    }
}
26
/// Error produced when a string does not name a known [`TrackingMode`].
///
/// The wrapped `String` holds the text that was rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseTrackingModeError(pub String);

impl std::fmt::Display for ParseTrackingModeError {
    /// Writes `invalid tracking mode: ` followed by the rejected text.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(rejected) = self;
        f.write_str("invalid tracking mode: ")?;
        f.write_str(rejected)
    }
}

impl std::error::Error for ParseTrackingModeError {}
37
38impl FromStr for TrackingMode {
39    type Err = ParseTrackingModeError;
40
41    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
42        match s.to_lowercase().as_str() {
43            "none" => Ok(Self::None),
44            "table" => Ok(Self::Table),
45            _ => Err(ParseTrackingModeError(s.to_string())),
46        }
47    }
48}
49
50/// Read set tracking tables read during query execution.
51#[derive(Debug, Clone, Default)]
52pub struct ReadSet {
53    pub tables: Vec<String>,
54    pub filter_columns: HashMap<String, HashSet<String>>,
55    pub mode: TrackingMode,
56}
57
58impl ReadSet {
59    pub fn new() -> Self {
60        Self::default()
61    }
62
63    /// Create a read set with table-level tracking.
64    pub fn table_level() -> Self {
65        Self {
66            mode: TrackingMode::Table,
67            ..Default::default()
68        }
69    }
70
71    pub fn add_table(&mut self, table: impl Into<String>) {
72        let table = table.into();
73        if !self.tables.contains(&table) {
74            self.tables.push(table);
75        }
76    }
77
78    pub fn add_filter_column(&mut self, table: impl Into<String>, column: impl Into<String>) {
79        self.filter_columns
80            .entry(table.into())
81            .or_default()
82            .insert(column.into());
83    }
84
85    pub fn includes_table(&self, table: &str) -> bool {
86        self.tables.iter().any(|t| t == table)
87    }
88
89    pub fn memory_bytes(&self) -> usize {
90        let table_bytes = self.tables.iter().map(|s| s.len() + 24).sum::<usize>();
91        let col_bytes = self
92            .filter_columns
93            .values()
94            .map(|set| set.iter().map(|s| s.len() + 24).sum::<usize>())
95            .sum::<usize>();
96
97        table_bytes + col_bytes + 64
98    }
99
100    pub fn merge(&mut self, other: &ReadSet) {
101        for table in &other.tables {
102            if !self.tables.contains(table) {
103                self.tables.push(table.clone());
104            }
105        }
106
107        for (table, columns) in &other.filter_columns {
108            self.filter_columns
109                .entry(table.clone())
110                .or_default()
111                .extend(columns.iter().cloned());
112        }
113    }
114}
115
/// Kind of operation a [`Change`] represents.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ChangeOperation {
    /// An insert.
    Insert,
    /// An update.
    Update,
    /// A delete.
    Delete,
}

impl ChangeOperation {
    /// Returns the uppercase keyword for this operation: `"INSERT"`,
    /// `"UPDATE"` or `"DELETE"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            ChangeOperation::Insert => "INSERT",
            ChangeOperation::Update => "UPDATE",
            ChangeOperation::Delete => "DELETE",
        }
    }
}
134
/// Error produced when a string does not name a known [`ChangeOperation`].
///
/// The wrapped `String` holds the text that was rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseChangeOperationError(pub String);

impl std::fmt::Display for ParseChangeOperationError {
    /// Writes `invalid change operation: ` followed by the rejected text.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(rejected) = self;
        f.write_str("invalid change operation: ")?;
        f.write_str(rejected)
    }
}

impl std::error::Error for ParseChangeOperationError {}
145
146impl FromStr for ChangeOperation {
147    type Err = ParseChangeOperationError;
148
149    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
150        match s.to_uppercase().as_str() {
151            "INSERT" | "I" => Ok(Self::Insert),
152            "UPDATE" | "U" => Ok(Self::Update),
153            "DELETE" | "D" => Ok(Self::Delete),
154            _ => Err(ParseChangeOperationError(s.to_string())),
155        }
156    }
157}
158
159/// A database change event.
160#[derive(Debug, Clone)]
161pub struct Change {
162    pub table: String,
163    pub operation: ChangeOperation,
164    pub row_id: Option<Uuid>,
165    /// Columns that changed (for updates).
166    pub changed_columns: Vec<String>,
167}
168
169impl Change {
170    pub fn new(table: impl Into<String>, operation: ChangeOperation) -> Self {
171        Self {
172            table: table.into(),
173            operation,
174            row_id: None,
175            changed_columns: Vec::new(),
176        }
177    }
178
179    pub fn with_row_id(mut self, row_id: Uuid) -> Self {
180        self.row_id = Some(row_id);
181        self
182    }
183
184    pub fn with_columns(mut self, columns: Vec<String>) -> Self {
185        self.changed_columns = columns;
186        self
187    }
188
189    /// Check if this change should invalidate a read set, optionally filtering
190    /// by compile-time selected columns from the query.
191    pub fn invalidates(&self, read_set: &ReadSet) -> bool {
192        read_set.includes_table(&self.table)
193    }
194
195    /// Column-aware invalidation: returns false if the changed columns don't
196    /// overlap with the query's selected columns.
197    pub fn invalidates_columns(&self, selected_columns: &[&str]) -> bool {
198        // If we don't know changed columns or selected columns, be conservative
199        if self.changed_columns.is_empty() || selected_columns.is_empty() {
200            return true;
201        }
202
203        // Only for updates, since inserts/deletes affect row presence
204        if self.operation != ChangeOperation::Update {
205            return true;
206        }
207
208        self.changed_columns
209            .iter()
210            .any(|c| selected_columns.contains(&c.as_str()))
211    }
212}
213
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    #[test]
    fn test_tracking_mode_conversion() {
        let accepted = "table".parse::<TrackingMode>();
        let rejected = "invalid".parse::<TrackingMode>();

        assert_eq!(accepted, Ok(TrackingMode::Table));
        assert!(rejected.is_err());
    }

    #[test]
    fn test_read_set_add_table() {
        let mut tracked = ReadSet::new();
        tracked.add_table("projects");

        for (table, expected) in [("projects", true), ("users", false)] {
            assert_eq!(tracked.includes_table(table), expected);
        }
    }

    #[test]
    fn test_change_invalidates_table_level() {
        let mut tracked = ReadSet::table_level();
        tracked.add_table("projects");

        let on_tracked_table = Change::new("projects", ChangeOperation::Insert);
        let on_other_table = Change::new("users", ChangeOperation::Insert);

        assert!(on_tracked_table.invalidates(&tracked));
        assert!(!on_other_table.invalidates(&tracked));
    }

    #[test]
    fn test_column_invalidation() {
        let changed = vec![String::from("name"), String::from("email")];
        let update = Change::new("users", ChangeOperation::Update).with_columns(changed);

        // Overlap on "name" invalidates.
        assert!(update.invalidates_columns(&["name", "age"]));
        // Disjoint column sets do not.
        assert!(!update.invalidates_columns(&["age", "phone"]));
        // An empty selection is treated conservatively.
        assert!(update.invalidates_columns(&[]));
    }

    #[test]
    fn test_column_invalidation_non_update() {
        let insert =
            Change::new("users", ChangeOperation::Insert).with_columns(vec![String::from("name")]);

        assert!(insert.invalidates_columns(&["age"]));
    }

    #[test]
    fn test_read_set_merge() {
        let mut target = ReadSet::new();
        target.add_table("projects");

        let mut incoming = ReadSet::new();
        incoming.add_table("users");

        target.merge(&incoming);

        for table in ["projects", "users"] {
            assert!(target.includes_table(table));
        }
    }

    #[test]
    fn tracking_mode_default_is_table() {
        // Table-level tracking has to be the default: if it silently became
        // something else, handlers that never opt in would lose invalidation.
        let default_mode = TrackingMode::default();
        assert_eq!(default_mode, TrackingMode::Table);
    }

    #[test]
    fn tracking_mode_as_str_round_trips() {
        for mode in [TrackingMode::None, TrackingMode::Table] {
            let text = mode.as_str();
            assert_eq!(text.parse::<TrackingMode>(), Ok(mode));
        }
    }

    #[test]
    fn tracking_mode_parse_is_case_insensitive() {
        let cases = [
            ("NONE", TrackingMode::None),
            ("Table", TrackingMode::Table),
            ("TaBlE", TrackingMode::Table),
        ];

        for (input, expected) in cases {
            assert_eq!(input.parse::<TrackingMode>(), Ok(expected));
        }
    }

    #[test]
    fn tracking_mode_parse_error_preserves_original_input() {
        let err = "row".parse::<TrackingMode>().unwrap_err();
        let expected = ParseTrackingModeError(String::from("row"));

        assert_eq!(err, expected);
        assert_eq!(err.to_string(), "invalid tracking mode: row");
    }

    #[test]
    fn add_table_is_idempotent() {
        let mut rs = ReadSet::new();
        for _ in 0..3 {
            rs.add_table("projects");
        }

        assert_eq!(rs.tables, vec![String::from("projects")]);
    }

    #[test]
    fn add_filter_column_accumulates_per_table() {
        let mut rs = ReadSet::new();
        let additions = [
            ("users", "id"),
            ("users", "email"),
            ("users", "id"),
            ("projects", "owner_id"),
        ];
        for (table, column) in additions {
            rs.add_filter_column(table, column);
        }

        let users = rs.filter_columns.get("users").unwrap();
        assert_eq!(users.len(), 2);
        for column in ["id", "email"] {
            assert!(users.contains(column));
        }

        let projects = rs.filter_columns.get("projects").unwrap();
        assert_eq!(projects.len(), 1);
    }

    #[test]
    fn memory_bytes_grows_with_content() {
        // An empty read set reports only the constant 64-byte struct overhead.
        let baseline = ReadSet::new().memory_bytes();
        assert_eq!(baseline, 64);

        let mut populated = ReadSet::new();
        populated.add_table("users");
        populated.add_filter_column("users", "email");

        assert!(populated.memory_bytes() > baseline);
    }

    #[test]
    fn table_level_constructor_sets_mode() {
        let ReadSet {
            tables,
            filter_columns,
            mode,
        } = ReadSet::table_level();

        assert_eq!(mode, TrackingMode::Table);
        assert!(tables.is_empty());
        assert!(filter_columns.is_empty());
    }

    #[test]
    fn merge_dedups_tables_and_unions_filter_columns() {
        let mut target = ReadSet::new();
        target.add_table("users");
        target.add_filter_column("users", "id");

        let mut incoming = ReadSet::new();
        incoming.add_table("users");
        incoming.add_table("projects");
        incoming.add_filter_column("users", "email");
        incoming.add_filter_column("projects", "owner_id");

        target.merge(&incoming);

        let expected_tables = vec![String::from("users"), String::from("projects")];
        assert_eq!(target.tables, expected_tables);

        let users = target.filter_columns.get("users").unwrap();
        assert!(users.contains("id"));
        assert!(users.contains("email"));
        assert_eq!(users.len(), 2);
    }

    #[test]
    fn change_operation_as_str_round_trips() {
        let all = [
            ChangeOperation::Insert,
            ChangeOperation::Update,
            ChangeOperation::Delete,
        ];

        for op in all {
            let text = op.as_str();
            assert_eq!(text.parse::<ChangeOperation>(), Ok(op));
        }
    }

    #[test]
    fn change_operation_accepts_short_codes_and_lowercase() {
        let cases = [
            ("i", ChangeOperation::Insert),
            ("U", ChangeOperation::Update),
            ("delete", ChangeOperation::Delete),
        ];

        for (input, expected) in cases {
            assert_eq!(input.parse::<ChangeOperation>(), Ok(expected));
        }
    }

    #[test]
    fn change_operation_parse_error_preserves_input() {
        let err = "TRUNCATE".parse::<ChangeOperation>().unwrap_err();
        let expected = ParseChangeOperationError(String::from("TRUNCATE"));

        assert_eq!(err, expected);
        assert_eq!(err.to_string(), "invalid change operation: TRUNCATE");
    }

    #[test]
    fn change_builders_populate_optional_fields() {
        let row = Uuid::new_v4();
        let columns = vec![String::from("email")];

        let change = Change::new("users", ChangeOperation::Update)
            .with_row_id(row)
            .with_columns(columns.clone());

        assert_eq!(change.row_id, Some(row));
        assert_eq!(change.changed_columns, columns);
        assert_eq!(change.operation, ChangeOperation::Update);
    }

    #[test]
    fn column_invalidation_is_conservative_when_change_lacks_columns() {
        // With no changed_columns recorded we cannot show that the selected
        // columns were left alone, so the update has to invalidate.
        let update_without_columns = Change::new("users", ChangeOperation::Update);

        assert!(update_without_columns.invalidates_columns(&["email"]));
    }
}