Skip to main content

mnemo_core/sync/
engine.rs

1use std::sync::Arc;
2
3use chrono::Utc;
4use serde::{Deserialize, Serialize};
5use uuid::Uuid;
6
7use crate::error::Result;
8use crate::storage::StorageBackend;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct SyncResult {
12    pub pushed: usize,
13    pub pulled: usize,
14    pub conflicts: Vec<SyncConflict>,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct SyncConflict {
19    pub memory_id: Uuid,
20    pub local_updated_at: String,
21    pub remote_updated_at: String,
22}
23
24pub struct SyncEngine {
25    local: Arc<dyn StorageBackend>,
26    remote: Arc<dyn StorageBackend>,
27}
28
29impl SyncEngine {
30    pub fn new(local: Arc<dyn StorageBackend>, remote: Arc<dyn StorageBackend>) -> Self {
31        Self { local, remote }
32    }
33
34    /// Push local changes to remote. Returns number of memories pushed.
35    /// Uses watermark persistence to resume from last sync point.
36    pub async fn push(&self, since: &str) -> Result<usize> {
37        let watermark_key = "push_watermark";
38        let effective_since = self
39            .local
40            .get_sync_watermark(watermark_key)
41            .await?
42            .unwrap_or_else(|| since.to_string());
43        let local_memories = self
44            .local
45            .list_memories_since(&effective_since, crate::query::MAX_BATCH_QUERY_LIMIT)
46            .await?;
47        let mut pushed = 0;
48        for record in &local_memories {
49            self.remote.upsert_memory(record).await?;
50            pushed += 1;
51        }
52        if pushed > 0 {
53            let now = Utc::now().to_rfc3339();
54            self.local.set_sync_watermark(watermark_key, &now).await?;
55        }
56        Ok(pushed)
57    }
58
59    /// Pull remote changes to local. Returns number of memories pulled.
60    /// Uses watermark persistence to resume from last sync point.
61    pub async fn pull(&self, since: &str) -> Result<usize> {
62        let watermark_key = "pull_watermark";
63        let effective_since = self
64            .local
65            .get_sync_watermark(watermark_key)
66            .await?
67            .unwrap_or_else(|| since.to_string());
68        let remote_memories = self
69            .remote
70            .list_memories_since(&effective_since, crate::query::MAX_BATCH_QUERY_LIMIT)
71            .await?;
72        let mut pulled = 0;
73        for record in &remote_memories {
74            self.local.upsert_memory(record).await?;
75            pulled += 1;
76        }
77        if pulled > 0 {
78            let now = Utc::now().to_rfc3339();
79            self.local.set_sync_watermark(watermark_key, &now).await?;
80        }
81        Ok(pulled)
82    }
83
84    /// Full bidirectional sync. Pushes local changes, then pulls remote changes.
85    /// Detects conflicts where both sides have been modified since `since`.
86    pub async fn full_sync(&self, since: &str) -> Result<SyncResult> {
87        let local_memories = self
88            .local
89            .list_memories_since(since, crate::query::MAX_BATCH_QUERY_LIMIT)
90            .await?;
91        let remote_memories = self
92            .remote
93            .list_memories_since(since, crate::query::MAX_BATCH_QUERY_LIMIT)
94            .await?;
95
96        // Build a map of remote memory IDs → updated_at for conflict detection
97        let remote_map: std::collections::HashMap<Uuid, String> = remote_memories
98            .iter()
99            .map(|m| (m.id, m.updated_at.clone()))
100            .collect();
101
102        let mut conflicts = Vec::new();
103        let mut pushed = 0;
104
105        // Push local → remote, detecting conflicts
106        for record in &local_memories {
107            if let Some(remote_updated) = remote_map.get(&record.id) {
108                // Both sides modified — conflict (last-writer-wins: push local anyway)
109                if *remote_updated != record.updated_at {
110                    conflicts.push(SyncConflict {
111                        memory_id: record.id,
112                        local_updated_at: record.updated_at.clone(),
113                        remote_updated_at: remote_updated.clone(),
114                    });
115                }
116            }
117            self.remote.upsert_memory(record).await?;
118            pushed += 1;
119        }
120
121        // Pull remote → local (skip items we just pushed)
122        let local_ids: std::collections::HashSet<Uuid> =
123            local_memories.iter().map(|m| m.id).collect();
124        let mut pulled = 0;
125        for record in &remote_memories {
126            if !local_ids.contains(&record.id) {
127                self.local.upsert_memory(record).await?;
128                pulled += 1;
129            }
130        }
131
132        Ok(SyncResult {
133            pushed,
134            pulled,
135            conflicts,
136        })
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// A `SyncResult` with one conflict must survive a JSON round trip with every
    /// field intact, not just the counts.
    #[test]
    fn test_sync_result_serde() {
        let memory_id = Uuid::now_v7();
        let result = SyncResult {
            pushed: 5,
            pulled: 3,
            conflicts: vec![SyncConflict {
                memory_id,
                local_updated_at: "2025-01-01T00:00:00Z".to_string(),
                remote_updated_at: "2025-01-01T01:00:00Z".to_string(),
            }],
        };
        let json = serde_json::to_string(&result).expect("SyncResult should serialize");
        let deserialized: SyncResult =
            serde_json::from_str(&json).expect("SyncResult should deserialize");

        assert_eq!(deserialized.pushed, 5);
        assert_eq!(deserialized.pulled, 3);
        assert_eq!(deserialized.conflicts.len(), 1);

        // Guard against lost or swapped conflict fields (e.g. a bad serde rename).
        let conflict = &deserialized.conflicts[0];
        assert_eq!(conflict.memory_id, memory_id);
        assert_eq!(conflict.local_updated_at, "2025-01-01T00:00:00Z");
        assert_eq!(conflict.remote_updated_at, "2025-01-01T01:00:00Z");
    }

    /// A conflict-free result round-trips to an empty `conflicts` list.
    #[test]
    fn test_sync_result_serde_without_conflicts() {
        let result = SyncResult {
            pushed: 0,
            pulled: 7,
            conflicts: Vec::new(),
        };
        let json = serde_json::to_string(&result).expect("SyncResult should serialize");
        let deserialized: SyncResult =
            serde_json::from_str(&json).expect("SyncResult should deserialize");

        assert_eq!(deserialized.pushed, 0);
        assert_eq!(deserialized.pulled, 7);
        assert!(deserialized.conflicts.is_empty());
    }
}