Skip to main content

sqlmodel_session/
change_tracker.rs

1//! Change tracking and dirty detection for SQLModel Session.
2//!
3//! This module provides snapshot-based change tracking to detect when objects
4//! have been modified since they were loaded from the database.
5
6use crate::ObjectKey;
7use serde::Serialize;
8use sqlmodel_core::Model;
9use std::collections::HashMap;
10use std::time::Instant;
11
/// Point-in-time copy of an object's serialized state.
///
/// Pairs the raw bytes handed to [`ObjectSnapshot::new`] with the [`Instant`]
/// at which the snapshot was created.
#[derive(Debug)]
pub struct ObjectSnapshot {
    /// Serialized state as raw bytes (this type does not enforce a format).
    data: Vec<u8>,
    /// Moment the snapshot was created.
    taken_at: Instant,
}

impl ObjectSnapshot {
    /// Wrap `data` in a snapshot stamped with the current time.
    pub fn new(data: Vec<u8>) -> Self {
        let taken_at = Instant::now();
        Self { data, taken_at }
    }

    /// Borrow the stored bytes.
    pub fn data(&self) -> &[u8] {
        self.data.as_slice()
    }

    /// Return the moment this snapshot was created.
    pub fn taken_at(&self) -> Instant {
        let stamp = self.taken_at;
        stamp
    }
}
40
41/// Tracks changes to objects in the session.
42///
43/// Uses snapshot comparison to detect when objects have been modified.
44pub struct ChangeTracker {
45    /// Original snapshots by object key.
46    snapshots: HashMap<ObjectKey, ObjectSnapshot>,
47}
48
49impl ChangeTracker {
50    /// Create a new empty change tracker.
51    pub fn new() -> Self {
52        Self {
53            snapshots: HashMap::new(),
54        }
55    }
56
57    /// Take a snapshot of an object.
58    ///
59    /// This stores the serialized state of the object for later comparison.
60    #[tracing::instrument(level = "trace", skip(self, obj))]
61    pub fn snapshot<T: Model + Serialize>(&mut self, key: ObjectKey, obj: &T) {
62        let data = match serde_json::to_vec(obj) {
63            Ok(d) => d,
64            Err(e) => {
65                tracing::warn!(
66                    model = std::any::type_name::<T>(),
67                    error = %e,
68                    "Snapshot serialization failed, storing empty snapshot"
69                );
70                Vec::new()
71            }
72        };
73        tracing::trace!(
74            model = std::any::type_name::<T>(),
75            pk_hash = key.pk_hash(),
76            snapshot_bytes = data.len(),
77            "Taking object snapshot"
78        );
79        self.snapshots.insert(key, ObjectSnapshot::new(data));
80    }
81
82    /// Take a snapshot from raw bytes.
83    pub fn snapshot_raw(&mut self, key: ObjectKey, data: Vec<u8>) {
84        self.snapshots.insert(key, ObjectSnapshot::new(data));
85    }
86
87    /// Check if an object has changed since its snapshot.
88    ///
89    /// Returns `true` if:
90    /// - The object has no snapshot (treated as dirty)
91    /// - The current state differs from the snapshot
92    #[tracing::instrument(level = "trace", skip(self, obj))]
93    pub fn is_dirty<T: Model + Serialize>(&self, key: &ObjectKey, obj: &T) -> bool {
94        let Some(snapshot) = self.snapshots.get(key) else {
95            tracing::trace!(
96                pk_hash = key.pk_hash(),
97                dirty = true,
98                "No snapshot - treating as dirty"
99            );
100            return true;
101        };
102
103        let current = match serde_json::to_vec(obj) {
104            Ok(d) => d,
105            Err(e) => {
106                tracing::warn!(
107                    model = std::any::type_name::<T>(),
108                    error = %e,
109                    "Dirty check serialization failed, treating as dirty"
110                );
111                return true;
112            }
113        };
114        let dirty = current != snapshot.data;
115        tracing::trace!(pk_hash = key.pk_hash(), dirty = dirty, "Dirty check result");
116        dirty
117    }
118
119    /// Check if raw bytes match the snapshot.
120    pub fn is_dirty_raw(&self, key: &ObjectKey, current: &[u8]) -> bool {
121        let Some(snapshot) = self.snapshots.get(key) else {
122            return true;
123        };
124        current != snapshot.data
125    }
126
127    /// Get changed fields between snapshot and current state.
128    ///
129    /// Returns a list of field names that have different values.
130    #[tracing::instrument(level = "debug", skip(self, obj))]
131    pub fn changed_fields<T: Model + Serialize>(
132        &self,
133        key: &ObjectKey,
134        obj: &T,
135    ) -> Vec<&'static str> {
136        let Some(snapshot) = self.snapshots.get(key) else {
137            // No snapshot = all fields are "changed"
138            let fields: Vec<&'static str> = T::fields().iter().map(|f| f.name).collect();
139            tracing::debug!(
140                model = std::any::type_name::<T>(),
141                changed_count = fields.len(),
142                "No snapshot - all fields considered changed"
143            );
144            return fields;
145        };
146
147        // Parse both as JSON objects and compare fields
148        let original: serde_json::Value = match serde_json::from_slice(&snapshot.data) {
149            Ok(v) => v,
150            Err(e) => {
151                tracing::warn!(
152                    model = std::any::type_name::<T>(),
153                    error = %e,
154                    "Snapshot deserialization failed in changed_fields, treating all as changed"
155                );
156                serde_json::Value::Null
157            }
158        };
159        let current: serde_json::Value = match serde_json::to_value(obj) {
160            Ok(v) => v,
161            Err(e) => {
162                tracing::warn!(
163                    model = std::any::type_name::<T>(),
164                    error = %e,
165                    "Current serialization failed in changed_fields, treating all as changed"
166                );
167                serde_json::Value::Null
168            }
169        };
170
171        let mut changed = Vec::new();
172        for field in T::fields() {
173            let orig_val = original.get(field.name);
174            let curr_val = current.get(field.name);
175            if orig_val != curr_val {
176                changed.push(field.name);
177            }
178        }
179
180        tracing::debug!(
181            model = std::any::type_name::<T>(),
182            changed_count = changed.len(),
183            fields = ?changed,
184            "Detected changed fields"
185        );
186        changed
187    }
188
189    /// Get changed fields from raw JSON bytes.
190    pub fn changed_fields_raw(
191        &self,
192        key: &ObjectKey,
193        current_bytes: &[u8],
194        field_names: &[&'static str],
195    ) -> Vec<&'static str> {
196        let Some(snapshot) = self.snapshots.get(key) else {
197            return field_names.to_vec();
198        };
199
200        let original: serde_json::Value = match serde_json::from_slice(&snapshot.data) {
201            Ok(v) => v,
202            Err(e) => {
203                tracing::warn!(
204                    error = %e,
205                    "Snapshot deserialization failed in changed_fields_raw, treating all as changed"
206                );
207                serde_json::Value::Null
208            }
209        };
210        let current: serde_json::Value = match serde_json::from_slice(current_bytes) {
211            Ok(v) => v,
212            Err(e) => {
213                tracing::warn!(
214                    error = %e,
215                    "Current deserialization failed in changed_fields_raw, treating all as changed"
216                );
217                serde_json::Value::Null
218            }
219        };
220
221        let mut changed = Vec::new();
222        for name in field_names {
223            let orig_val = original.get(*name);
224            let curr_val = current.get(*name);
225            if orig_val != curr_val {
226                changed.push(*name);
227            }
228        }
229        changed
230    }
231
232    /// Get detailed attribute changes between snapshot and current state.
233    ///
234    /// Returns `AttributeChange` structs with field name, old value, and new value.
235    pub fn attribute_changes<T: Model + Serialize>(
236        &self,
237        key: &ObjectKey,
238        obj: &T,
239    ) -> Vec<sqlmodel_core::AttributeChange> {
240        let Some(snapshot) = self.snapshots.get(key) else {
241            return Vec::new();
242        };
243
244        let original: serde_json::Value = match serde_json::from_slice(&snapshot.data) {
245            Ok(v) => v,
246            Err(e) => {
247                tracing::warn!(
248                    model = std::any::type_name::<T>(),
249                    error = %e,
250                    "Snapshot deserialization failed in attribute_changes, treating as empty"
251                );
252                serde_json::Value::Null
253            }
254        };
255        let current: serde_json::Value = match serde_json::to_value(obj) {
256            Ok(v) => v,
257            Err(e) => {
258                tracing::warn!(
259                    model = std::any::type_name::<T>(),
260                    error = %e,
261                    "Current serialization failed in attribute_changes, treating as empty"
262                );
263                serde_json::Value::Null
264            }
265        };
266
267        let mut changes = Vec::new();
268        for field in T::fields() {
269            let orig_val = original
270                .get(field.name)
271                .cloned()
272                .unwrap_or(serde_json::Value::Null);
273            let curr_val = current
274                .get(field.name)
275                .cloned()
276                .unwrap_or(serde_json::Value::Null);
277            if orig_val != curr_val {
278                changes.push(sqlmodel_core::AttributeChange {
279                    field_name: field.name,
280                    old_value: orig_val,
281                    new_value: curr_val,
282                });
283            }
284        }
285        changes
286    }
287
288    /// Check if a snapshot exists for the given key.
289    pub fn has_snapshot(&self, key: &ObjectKey) -> bool {
290        self.snapshots.contains_key(key)
291    }
292
293    /// Get the snapshot for a key.
294    pub fn get_snapshot(&self, key: &ObjectKey) -> Option<&ObjectSnapshot> {
295        self.snapshots.get(key)
296    }
297
298    /// Clear snapshot for a specific object.
299    ///
300    /// Call this after commit or when discarding changes.
301    pub fn clear(&mut self, key: &ObjectKey) {
302        self.snapshots.remove(key);
303    }
304
305    /// Clear all snapshots.
306    ///
307    /// Call this after commit or rollback to reset tracking state.
308    pub fn clear_all(&mut self) {
309        self.snapshots.clear();
310    }
311
312    /// Update snapshot after flush (new baseline).
313    ///
314    /// Call this after a successful flush to set the current state as the new baseline.
315    #[tracing::instrument(level = "trace", skip(self, obj))]
316    pub fn refresh<T: Model + Serialize>(&mut self, key: ObjectKey, obj: &T) {
317        tracing::trace!(pk_hash = key.pk_hash(), "Refreshing snapshot");
318        self.snapshot(key, obj);
319    }
320
321    /// Number of tracked snapshots.
322    pub fn len(&self) -> usize {
323        self.snapshots.len()
324    }
325
326    /// Check if there are no snapshots.
327    pub fn is_empty(&self) -> bool {
328        self.snapshots.is_empty()
329    }
330}
331
332impl Default for ChangeTracker {
333    fn default() -> Self {
334        Self::new()
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use sqlmodel_core::{FieldInfo, Row, Value};

    // Mock model for testing
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestHero {
        id: i64,
        name: String,
        age: Option<i32>,
    }

    impl Model for TestHero {
        const TABLE_NAME: &'static str = "hero";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];

        fn fields() -> &'static [FieldInfo] {
            static FIELDS: [FieldInfo; 3] = [
                FieldInfo::new("id", "id", sqlmodel_core::SqlType::BigInt)
                    .primary_key(true)
                    .auto_increment(true),
                FieldInfo::new("name", "name", sqlmodel_core::SqlType::Text),
                FieldInfo::new("age", "age", sqlmodel_core::SqlType::Integer).nullable(true),
            ];
            &FIELDS
        }

        fn primary_key_value(&self) -> Vec<Value> {
            vec![Value::BigInt(self.id)]
        }

        fn from_row(_row: &Row) -> Result<Self, sqlmodel_core::Error> {
            unimplemented!("Not needed for these tests")
        }

        fn to_row(&self) -> Vec<(&'static str, Value)> {
            vec![
                ("id", Value::BigInt(self.id)),
                ("name", Value::Text(self.name.clone())),
                ("age", self.age.map_or(Value::Null, Value::Int)),
            ]
        }

        fn is_new(&self) -> bool {
            false
        }
    }

    /// Build a `TestHero` without repeating the struct literal in every test.
    fn make_hero(id: i64, name: &str, age: Option<i32>) -> TestHero {
        TestHero {
            id,
            name: name.to_string(),
            age,
        }
    }

    fn make_key(id: i64) -> ObjectKey {
        ObjectKey::from_pk::<TestHero>(&[Value::BigInt(id)])
    }

    #[test]
    fn test_snapshot_captures_current_state() {
        let mut tracker = ChangeTracker::new();
        let hero = make_hero(1, "Spider-Man", Some(25));
        let key = make_key(1);

        tracker.snapshot(key, &hero);

        assert!(tracker.has_snapshot(&key));
        let snapshot = tracker.get_snapshot(&key).unwrap();
        assert!(!snapshot.data().is_empty());
    }

    #[test]
    fn test_snapshot_overwrites_previous() {
        let mut tracker = ChangeTracker::new();
        let key = make_key(1);

        let hero1 = make_hero(1, "Spider-Man", Some(25));
        tracker.snapshot(key, &hero1);
        let first_data = tracker.get_snapshot(&key).unwrap().data().to_vec();

        let hero2 = make_hero(1, "Peter Parker", Some(26));
        tracker.snapshot(key, &hero2);
        let second_data = tracker.get_snapshot(&key).unwrap().data().to_vec();

        assert_ne!(first_data, second_data);
    }

    #[test]
    fn test_is_dirty_false_if_unchanged() {
        let mut tracker = ChangeTracker::new();
        let hero = make_hero(1, "Spider-Man", Some(25));
        let key = make_key(1);

        tracker.snapshot(key, &hero);

        // Same object = not dirty
        assert!(!tracker.is_dirty(&key, &hero));
    }

    #[test]
    fn test_is_dirty_true_if_field_changed() {
        let mut tracker = ChangeTracker::new();
        let hero = make_hero(1, "Spider-Man", Some(25));
        let key = make_key(1);

        tracker.snapshot(key, &hero);

        // Modify the hero
        let modified_hero = make_hero(1, "Peter Parker", Some(25));

        assert!(tracker.is_dirty(&key, &modified_hero));
    }

    #[test]
    fn test_is_dirty_true_if_no_snapshot() {
        let tracker = ChangeTracker::new();
        let hero = make_hero(1, "Spider-Man", Some(25));
        let key = make_key(1);

        // No snapshot = dirty
        assert!(tracker.is_dirty(&key, &hero));
    }

    #[test]
    fn test_is_dirty_raw_compares_bytes() {
        let mut tracker = ChangeTracker::new();
        let key = make_key(1);

        // No snapshot = dirty
        assert!(tracker.is_dirty_raw(&key, b"abc"));

        tracker.snapshot_raw(key, b"abc".to_vec());
        assert!(!tracker.is_dirty_raw(&key, b"abc"));
        assert!(tracker.is_dirty_raw(&key, b"abd"));
    }

    #[test]
    fn test_changed_fields_empty_if_unchanged() {
        let mut tracker = ChangeTracker::new();
        let hero = make_hero(1, "Spider-Man", Some(25));
        let key = make_key(1);

        tracker.snapshot(key, &hero);

        let changed = tracker.changed_fields(&key, &hero);
        assert!(changed.is_empty());
    }

    #[test]
    fn test_changed_fields_lists_modified() {
        let mut tracker = ChangeTracker::new();
        let hero = make_hero(1, "Spider-Man", Some(25));
        let key = make_key(1);

        tracker.snapshot(key, &hero);

        let modified_hero = make_hero(1, "Peter Parker", Some(25));

        let changed = tracker.changed_fields(&key, &modified_hero);
        assert_eq!(changed, vec!["name"]);
    }

    #[test]
    fn test_changed_fields_multiple_changes() {
        let mut tracker = ChangeTracker::new();
        let hero = make_hero(1, "Spider-Man", Some(25));
        let key = make_key(1);

        tracker.snapshot(key, &hero);

        let modified_hero = make_hero(1, "Peter Parker", Some(30));

        let changed = tracker.changed_fields(&key, &modified_hero);
        assert!(changed.contains(&"name"));
        assert!(changed.contains(&"age"));
        assert!(!changed.contains(&"id"));
    }

    #[test]
    fn test_changed_fields_all_if_no_snapshot() {
        let tracker = ChangeTracker::new();
        let hero = make_hero(1, "Spider-Man", Some(25));
        let key = make_key(1);

        // No snapshot = every declared field is reported, in declaration order
        assert_eq!(
            tracker.changed_fields(&key, &hero),
            vec!["id", "name", "age"]
        );
    }

    #[test]
    fn test_changed_fields_raw_lists_modified() {
        let mut tracker = ChangeTracker::new();
        let key = make_key(1);
        tracker.snapshot_raw(
            key,
            br#"{"id":1,"name":"Spider-Man","age":25}"#.to_vec(),
        );

        let changed = tracker.changed_fields_raw(
            &key,
            br#"{"id":1,"name":"Peter Parker","age":25}"#,
            &["id", "name", "age"],
        );
        assert_eq!(changed, vec!["name"]);
    }

    #[test]
    fn test_changed_fields_raw_all_if_no_snapshot() {
        let tracker = ChangeTracker::new();
        let key = make_key(1);

        let changed = tracker.changed_fields_raw(&key, br#"{"id":1}"#, &["id", "name"]);
        assert_eq!(changed, vec!["id", "name"]);
    }

    #[test]
    fn test_changed_fields_raw_invalid_snapshot_reports_all() {
        let mut tracker = ChangeTracker::new();
        let key = make_key(1);
        tracker.snapshot_raw(key, b"not json".to_vec());

        let changed =
            tracker.changed_fields_raw(&key, br#"{"id":1,"name":"x"}"#, &["id", "name"]);
        assert_eq!(changed, vec!["id", "name"]);
    }

    #[test]
    fn test_clear_removes_snapshot() {
        let mut tracker = ChangeTracker::new();
        let hero = make_hero(1, "Spider-Man", Some(25));
        let key = make_key(1);

        tracker.snapshot(key, &hero);
        assert!(tracker.has_snapshot(&key));

        tracker.clear(&key);
        assert!(!tracker.has_snapshot(&key));
    }

    #[test]
    fn test_clear_all_removes_all() {
        let mut tracker = ChangeTracker::new();

        let hero1 = make_hero(1, "Spider-Man", Some(25));
        let hero2 = make_hero(2, "Iron Man", Some(40));

        tracker.snapshot(make_key(1), &hero1);
        tracker.snapshot(make_key(2), &hero2);

        assert_eq!(tracker.len(), 2);

        tracker.clear_all();

        assert!(tracker.is_empty());
    }

    #[test]
    fn test_refresh_updates_baseline() {
        let mut tracker = ChangeTracker::new();
        let hero = make_hero(1, "Spider-Man", Some(25));
        let key = make_key(1);

        tracker.snapshot(key, &hero);

        let modified_hero = make_hero(1, "Peter Parker", Some(25));

        // Should be dirty before refresh
        assert!(tracker.is_dirty(&key, &modified_hero));

        // Refresh the baseline
        tracker.refresh(key, &modified_hero);

        // No longer dirty
        assert!(!tracker.is_dirty(&key, &modified_hero));
    }

    #[test]
    fn test_attribute_changes_empty_when_unchanged() {
        let mut tracker = ChangeTracker::new();
        let hero = make_hero(1, "Spider-Man", Some(25));
        let key = ObjectKey::from_model(&hero);
        tracker.snapshot(key, &hero);

        let changes = tracker.attribute_changes(&key, &hero);
        assert!(changes.is_empty());
    }

    #[test]
    fn test_attribute_changes_detects_field_change() {
        let mut tracker = ChangeTracker::new();
        let hero = make_hero(1, "Spider-Man", Some(25));
        let key = ObjectKey::from_model(&hero);
        tracker.snapshot(key, &hero);

        let modified = make_hero(1, "Peter Parker", Some(26));

        let changes = tracker.attribute_changes(&key, &modified);
        assert_eq!(changes.len(), 2);
        assert_eq!(changes[0].field_name, "name");
        assert_eq!(changes[0].old_value, serde_json::json!("Spider-Man"));
        assert_eq!(changes[0].new_value, serde_json::json!("Peter Parker"));
        assert_eq!(changes[1].field_name, "age");
        assert_eq!(changes[1].old_value, serde_json::json!(25));
        assert_eq!(changes[1].new_value, serde_json::json!(26));
    }

    #[test]
    fn test_attribute_changes_empty_without_snapshot() {
        let tracker = ChangeTracker::new();
        let hero = make_hero(1, "Spider-Man", Some(25));
        let key = ObjectKey::from_model(&hero);

        // No snapshot → empty changes (not all fields)
        let changes = tracker.attribute_changes(&key, &hero);
        assert!(changes.is_empty());
    }
}