1use crate::error::Result;
4use crate::events::{ChangeEvent, EventBus};
5use dashmap::DashMap;
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8use tokio::sync::broadcast;
9use uuid::Uuid;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13#[serde(tag = "type", rename_all = "snake_case")]
14pub enum SyncMessage {
15 Subscribe { workspace_id: Uuid },
17 Unsubscribe { workspace_id: Uuid },
19 Change { event: ChangeEvent },
21 StateRequest { workspace_id: Uuid, version: i64 },
23 StateResponse {
25 workspace_id: Uuid,
26 version: i64,
27 state: serde_json::Value,
28 },
29 Ping,
31 Pong,
33 Error { message: String },
35}
36
37#[derive(Debug, Clone)]
39pub struct SyncState {
40 pub version: i64,
42 pub state: serde_json::Value,
44 pub last_updated: chrono::DateTime<chrono::Utc>,
46}
47
48impl SyncState {
49 pub fn new(version: i64, state: serde_json::Value) -> Self {
51 Self {
52 version,
53 state,
54 last_updated: chrono::Utc::now(),
55 }
56 }
57
58 pub fn update(&mut self, new_state: serde_json::Value) {
60 self.version += 1;
61 self.state = new_state;
62 self.last_updated = chrono::Utc::now();
63 }
64}
65
66pub struct SyncEngine {
68 event_bus: Arc<EventBus>,
70 states: DashMap<Uuid, SyncState>,
72 connections: DashMap<Uuid, Vec<Uuid>>,
74}
75
76impl SyncEngine {
77 pub fn new(event_bus: Arc<EventBus>) -> Self {
79 Self {
80 event_bus,
81 states: DashMap::new(),
82 connections: DashMap::new(),
83 }
84 }
85
86 pub fn subscribe(
88 &self,
89 workspace_id: Uuid,
90 client_id: Uuid,
91 ) -> Result<broadcast::Receiver<ChangeEvent>> {
92 self.connections.entry(workspace_id).or_insert_with(Vec::new).push(client_id);
94
95 Ok(self.event_bus.subscribe())
97 }
98
99 pub fn unsubscribe(&self, workspace_id: Uuid, client_id: Uuid) -> Result<()> {
101 if let Some(mut connections) = self.connections.get_mut(&workspace_id) {
102 connections.retain(|id| *id != client_id);
103 }
104 Ok(())
105 }
106
107 pub fn publish_change(&self, event: ChangeEvent) -> Result<()> {
109 self.event_bus.publish(event)
110 }
111
112 pub fn get_state(&self, workspace_id: Uuid) -> Option<SyncState> {
114 self.states.get(&workspace_id).map(|s| s.clone())
115 }
116
117 pub fn update_state(&self, workspace_id: Uuid, new_state: serde_json::Value) -> Result<()> {
119 if let Some(mut state) = self.states.get_mut(&workspace_id) {
120 state.update(new_state);
121 } else {
122 self.states.insert(workspace_id, SyncState::new(1, new_state));
123 }
124 Ok(())
125 }
126
127 pub fn get_connections(&self, workspace_id: Uuid) -> Vec<Uuid> {
129 self.connections.get(&workspace_id).map(|c| c.clone()).unwrap_or_default()
130 }
131
132 pub fn connection_count(&self) -> usize {
134 self.connections.iter().map(|c| c.value().len()).sum()
135 }
136
137 pub fn has_connections(&self, workspace_id: Uuid) -> bool {
139 self.connections.get(&workspace_id).map(|c| !c.is_empty()).unwrap_or(false)
140 }
141
142 pub fn cleanup_inactive(&self) {
144 let inactive: Vec<Uuid> = self
145 .connections
146 .iter()
147 .filter(|entry| entry.value().is_empty())
148 .map(|entry| *entry.key())
149 .collect();
150
151 for workspace_id in inactive {
152 self.connections.remove(&workspace_id);
153 }
154 }
155}
156
157pub mod crdt {
159 use serde::{Deserialize, Serialize};
160 use uuid::Uuid;
161
162 #[derive(Debug, Clone, Serialize, Deserialize)]
164 pub struct LwwRegister<T> {
165 pub value: T,
167 pub timestamp: u64,
169 pub client_id: Uuid,
171 }
172
173 impl<T> LwwRegister<T> {
174 pub fn new(value: T, timestamp: u64, client_id: Uuid) -> Self {
176 Self {
177 value,
178 timestamp,
179 client_id,
180 }
181 }
182
183 pub fn merge(&mut self, other: Self)
185 where
186 T: Clone,
187 {
188 if other.timestamp > self.timestamp
189 || (other.timestamp == self.timestamp && other.client_id > self.client_id)
190 {
191 self.value = other.value;
192 self.timestamp = other.timestamp;
193 self.client_id = other.client_id;
194 }
195 }
196 }
197
198 #[derive(Debug, Clone, Serialize, Deserialize)]
200 pub struct TextOperation {
201 pub position: usize,
203 pub op: TextOp,
205 pub timestamp: u64,
207 pub client_id: Uuid,
209 }
210
211 #[derive(Debug, Clone, Serialize, Deserialize)]
212 #[serde(tag = "type", rename_all = "lowercase")]
213 pub enum TextOp {
214 Insert { text: String },
216 Delete { length: usize },
218 }
219}
220
221#[cfg(test)]
222mod tests {
223 use super::*;
224
225 #[test]
226 fn test_sync_state() {
227 let mut state = SyncState::new(1, serde_json::json!({"key": "value"}));
228 assert_eq!(state.version, 1);
229
230 state.update(serde_json::json!({"key": "new_value"}));
231 assert_eq!(state.version, 2);
232 }
233
234 #[test]
235 fn test_sync_engine() {
236 let event_bus = Arc::new(EventBus::new(100));
237 let engine = SyncEngine::new(event_bus);
238
239 let workspace_id = Uuid::new_v4();
240 let client_id = Uuid::new_v4();
241
242 assert_eq!(engine.connection_count(), 0);
243
244 let _rx = engine.subscribe(workspace_id, client_id).unwrap();
245 assert_eq!(engine.connection_count(), 1);
246 assert!(engine.has_connections(workspace_id));
247
248 engine.unsubscribe(workspace_id, client_id).unwrap();
249 assert_eq!(engine.connection_count(), 0);
250 }
251
252 #[test]
253 fn test_state_management() {
254 let event_bus = Arc::new(EventBus::new(100));
255 let engine = SyncEngine::new(event_bus);
256
257 let workspace_id = Uuid::new_v4();
258 let state = serde_json::json!({"mocks": []});
259
260 engine.update_state(workspace_id, state.clone()).unwrap();
261
262 let retrieved = engine.get_state(workspace_id).unwrap();
263 assert_eq!(retrieved.version, 1);
264 assert_eq!(retrieved.state, state);
265 }
266
267 #[test]
268 fn test_crdt_lww_register() {
269 use super::crdt::LwwRegister;
270
271 let client1 = Uuid::new_v4();
272 let client2 = Uuid::new_v4();
273
274 let mut reg1 = LwwRegister::new("value1", 1, client1);
275 let reg2 = LwwRegister::new("value2", 2, client2);
276
277 reg1.merge(reg2);
278 assert_eq!(reg1.value, "value2");
279 assert_eq!(reg1.timestamp, 2);
280 }
281
282 #[test]
283 fn test_cleanup_inactive() {
284 let event_bus = Arc::new(EventBus::new(100));
285 let engine = SyncEngine::new(event_bus);
286
287 let workspace_id = Uuid::new_v4();
288 let client_id = Uuid::new_v4();
289
290 let _rx = engine.subscribe(workspace_id, client_id).unwrap();
291 assert_eq!(engine.connection_count(), 1);
292
293 engine.unsubscribe(workspace_id, client_id).unwrap();
294 assert_eq!(engine.connection_count(), 0);
295
296 engine.cleanup_inactive();
297 assert!(!engine.has_connections(workspace_id));
298 }
299}