1use crate::core_bridge::CoreBridge;
4use crate::error::{CollabError, Result};
5use crate::events::{ChangeEvent, EventBus};
6use crate::workspace::WorkspaceService;
7use chrono::Utc;
8use dashmap::DashMap;
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11use sqlx::{Pool, Sqlite};
12use std::sync::Arc;
13use tokio::sync::broadcast;
14use uuid::Uuid;
15
/// Wire-protocol messages exchanged with sync clients.
///
/// Serialized with an adjacent `"type"` tag in `snake_case`, e.g.
/// `{"type": "subscribe", "workspace_id": "..."}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SyncMessage {
    /// Client requests change events for a workspace.
    Subscribe { workspace_id: Uuid },
    /// Client stops receiving change events for a workspace.
    Unsubscribe { workspace_id: Uuid },
    /// A change event propagated between peers.
    Change { event: ChangeEvent },
    /// Client asks for the workspace state at (or after) `version`.
    StateRequest { workspace_id: Uuid, version: i64 },
    /// Server reply carrying a full state snapshot.
    StateResponse {
        workspace_id: Uuid,
        version: i64,
        state: serde_json::Value,
    },
    /// Keep-alive probe.
    Ping,
    /// Keep-alive reply.
    Pong,
    /// Server-side error report.
    Error { message: String },
}
41
/// In-memory snapshot of a workspace's synchronized state.
#[derive(Debug, Clone)]
pub struct SyncState {
    // Monotonically increasing version; starts at 1 for a fresh workspace.
    pub version: i64,
    // Opaque JSON document representing the workspace contents.
    pub state: serde_json::Value,
    // UTC timestamp of the last call to `new`/`update`.
    pub last_updated: chrono::DateTime<Utc>,
}
52
53impl SyncState {
54 #[must_use]
56 pub fn new(version: i64, state: serde_json::Value) -> Self {
57 Self {
58 version,
59 state,
60 last_updated: Utc::now(),
61 }
62 }
63
64 pub fn update(&mut self, new_state: serde_json::Value) {
66 self.version += 1;
67 self.state = new_state;
68 self.last_updated = Utc::now();
69 }
70}
71
/// Coordinates real-time workspace synchronization: fan-out of change
/// events, per-workspace state versioning, and optional persistence of
/// snapshots/changes to SQLite plus write-through to the core workspace
/// service.
pub struct SyncEngine {
    // Broadcast bus every subscriber receives change events from.
    event_bus: Arc<EventBus>,
    // Latest in-memory state per workspace.
    states: DashMap<Uuid, SyncState>,
    // Connected client ids per workspace (used for presence/cleanup).
    connections: DashMap<Uuid, Vec<Uuid>>,
    // Optional SQLite pool; when absent, snapshot/change persistence is disabled.
    db: Option<Pool<Sqlite>>,
    // Optional bridge into the core engine for state (de)serialization.
    core_bridge: Option<Arc<CoreBridge>>,
    // Optional workspace service used to load/save full workspaces.
    workspace_service: Option<Arc<WorkspaceService>>,
}
87
88impl SyncEngine {
89 #[must_use]
91 pub fn new(event_bus: Arc<EventBus>) -> Self {
92 Self {
93 event_bus,
94 states: DashMap::new(),
95 connections: DashMap::new(),
96 db: None,
97 core_bridge: None,
98 workspace_service: None,
99 }
100 }
101
102 #[must_use]
104 pub fn with_db(event_bus: Arc<EventBus>, db: Pool<Sqlite>) -> Self {
105 Self {
106 event_bus,
107 states: DashMap::new(),
108 connections: DashMap::new(),
109 db: Some(db),
110 core_bridge: None,
111 workspace_service: None,
112 }
113 }
114
115 #[must_use]
117 pub fn with_integration(
118 event_bus: Arc<EventBus>,
119 db: Pool<Sqlite>,
120 core_bridge: Arc<CoreBridge>,
121 workspace_service: Arc<WorkspaceService>,
122 ) -> Self {
123 Self {
124 event_bus,
125 states: DashMap::new(),
126 connections: DashMap::new(),
127 db: Some(db),
128 core_bridge: Some(core_bridge),
129 workspace_service: Some(workspace_service),
130 }
131 }
132
133 pub fn subscribe(
135 &self,
136 workspace_id: Uuid,
137 client_id: Uuid,
138 ) -> Result<broadcast::Receiver<ChangeEvent>> {
139 self.connections.entry(workspace_id).or_default().push(client_id);
141
142 Ok(self.event_bus.subscribe())
144 }
145
146 pub fn unsubscribe(&self, workspace_id: Uuid, client_id: Uuid) -> Result<()> {
148 if let Some(mut connections) = self.connections.get_mut(&workspace_id) {
149 connections.retain(|id| *id != client_id);
150 }
151 Ok(())
152 }
153
154 pub fn publish_change(&self, event: ChangeEvent) -> Result<()> {
156 self.event_bus.publish(event)
157 }
158
159 #[must_use]
161 pub fn get_state(&self, workspace_id: Uuid) -> Option<SyncState> {
162 self.states.get(&workspace_id).map(|s| s.clone())
163 }
164
165 pub fn update_state(&self, workspace_id: Uuid, new_state: serde_json::Value) -> Result<()> {
167 let version = if let Some(state) = self.states.get(&workspace_id) {
168 state.version + 1
169 } else {
170 1
171 };
172
173 if let Some(mut state) = self.states.get_mut(&workspace_id) {
174 state.update(new_state.clone());
175 } else {
176 self.states.insert(workspace_id, SyncState::new(version, new_state.clone()));
177 }
178
179 if let (Some(core_bridge), Some(workspace_service)) =
181 (self.core_bridge.as_ref(), self.workspace_service.as_ref())
182 {
183 let core_bridge = core_bridge.clone();
184 let workspace_service = workspace_service.clone();
185 let workspace_id = workspace_id;
186 let state_data = new_state.clone();
187 tokio::spawn(async move {
188 if let Ok(mut team_workspace) = workspace_service.get_workspace(workspace_id).await
189 {
190 if let Err(e) = core_bridge
191 .update_workspace_state_from_json(&mut team_workspace, &state_data)
192 {
193 tracing::error!("Failed to update workspace state: {}", e);
194 } else {
195 if let Err(e) = core_bridge.save_workspace_to_disk(&team_workspace).await {
199 tracing::error!("Failed to save workspace to disk: {}", e);
200 }
201 }
202 }
203 });
204 }
205
206 if let Some(db) = &self.db {
208 let db = db.clone();
210 let workspace_id = workspace_id;
211 let state_data = new_state;
212 tokio::spawn(async move {
213 if let Err(e) =
214 Self::save_state_snapshot(&db, workspace_id, version, &state_data).await
215 {
216 tracing::error!("Failed to save state snapshot: {}", e);
217 }
218 });
219 }
220
221 Ok(())
222 }
223
224 pub async fn get_full_workspace_state(
228 &self,
229 workspace_id: Uuid,
230 ) -> Result<Option<serde_json::Value>> {
231 if let (Some(core_bridge), Some(workspace_service)) =
232 (self.core_bridge.as_ref(), self.workspace_service.as_ref())
233 {
234 let team_workspace = workspace_service.get_workspace(workspace_id).await?;
236
237 let state_json = core_bridge.get_workspace_state_json(&team_workspace)?;
239 Ok(Some(state_json))
240 } else {
241 Ok(self.get_state(workspace_id).map(|s| s.state))
243 }
244 }
245
246 async fn save_state_snapshot(
248 db: &Pool<Sqlite>,
249 workspace_id: Uuid,
250 version: i64,
251 state: &serde_json::Value,
252 ) -> Result<()> {
253 let state_json = serde_json::to_string(state)?;
255 let mut hasher = Sha256::new();
256 hasher.update(state_json.as_bytes());
257 let state_hash = format!("{:x}", hasher.finalize());
258
259 let existing = sqlx::query!(
261 r#"
262 SELECT id FROM workspace_state_snapshots
263 WHERE workspace_id = ? AND state_hash = ?
264 "#,
265 workspace_id,
266 state_hash
267 )
268 .fetch_optional(db)
269 .await?;
270
271 if existing.is_some() {
272 return Ok(());
274 }
275
276 let snapshot_id = Uuid::new_v4();
278 let snapshot_id_str = snapshot_id.to_string();
279 let workspace_id_str = workspace_id.to_string();
280 let now = Utc::now().to_rfc3339();
281 sqlx::query!(
282 r#"
283 INSERT INTO workspace_state_snapshots (id, workspace_id, state_hash, state_data, version, created_at)
284 VALUES (?, ?, ?, ?, ?, ?)
285 "#,
286 snapshot_id_str,
287 workspace_id_str,
288 state_hash,
289 state_json,
290 version,
291 now
292 )
293 .execute(db)
294 .await?;
295
296 Ok(())
297 }
298
299 pub async fn load_state_snapshot(
301 &self,
302 workspace_id: Uuid,
303 version: Option<i64>,
304 ) -> Result<Option<SyncState>> {
305 let db = self.db.as_ref().ok_or_else(|| {
306 CollabError::Internal("Database not available for state snapshots".to_string())
307 })?;
308
309 let workspace_id_str = workspace_id.to_string();
310 let snapshot: Option<(String, i64, String)> = if let Some(version) = version {
312 sqlx::query_as(
313 r"
314 SELECT state_data, version, created_at
315 FROM workspace_state_snapshots
316 WHERE workspace_id = ? AND version = ?
317 ORDER BY created_at DESC
318 LIMIT 1
319 ",
320 )
321 .bind(&workspace_id_str)
322 .bind(version)
323 .fetch_optional(db)
324 .await?
325 } else {
326 sqlx::query_as(
327 r"
328 SELECT state_data, version, created_at
329 FROM workspace_state_snapshots
330 WHERE workspace_id = ?
331 ORDER BY version DESC, created_at DESC
332 LIMIT 1
333 ",
334 )
335 .bind(&workspace_id_str)
336 .fetch_optional(db)
337 .await?
338 };
339
340 if let Some((state_data, snap_version, created_at_str)) = snapshot {
341 let state: serde_json::Value = serde_json::from_str(&state_data)
342 .map_err(|e| CollabError::Internal(format!("Failed to parse state: {e}")))?;
343 let created_at = chrono::DateTime::parse_from_rfc3339(&created_at_str)
346 .map(|dt| dt.with_timezone(&Utc))
347 .or_else(|_| {
348 chrono::NaiveDateTime::parse_from_str(&created_at_str, "%Y-%m-%d %H:%M:%S%.f")
350 .or_else(|_| {
351 chrono::NaiveDateTime::parse_from_str(
352 &created_at_str,
353 "%Y-%m-%d %H:%M:%S",
354 )
355 })
356 .map(|dt| dt.and_utc())
357 })
358 .map_err(|e| {
359 CollabError::Internal(format!(
360 "Failed to parse timestamp '{created_at_str}': {e}"
361 ))
362 })?;
363
364 Ok(Some(SyncState {
365 version: snap_version,
366 state,
367 last_updated: created_at,
368 }))
369 } else {
370 Ok(None)
371 }
372 }
373
374 pub async fn record_state_change(
376 &self,
377 workspace_id: Uuid,
378 change_type: &str,
379 change_data: serde_json::Value,
380 version: i64,
381 user_id: Uuid,
382 ) -> Result<()> {
383 let db = self.db.as_ref().ok_or_else(|| {
384 CollabError::Internal("Database not available for state changes".to_string())
385 })?;
386
387 let change_id = Uuid::new_v4();
388 let change_data_str = serde_json::to_string(&change_data)?;
389 let now = Utc::now().to_rfc3339();
390 sqlx::query!(
391 r#"
392 INSERT INTO workspace_state_changes (id, workspace_id, change_type, change_data, version, created_at, created_by)
393 VALUES (?, ?, ?, ?, ?, ?, ?)
394 "#,
395 change_id,
396 workspace_id,
397 change_type,
398 change_data_str,
399 version,
400 now,
401 user_id
402 )
403 .execute(db)
404 .await?;
405
406 Ok(())
407 }
408
409 pub async fn get_state_changes_since(
411 &self,
412 workspace_id: Uuid,
413 since_version: i64,
414 ) -> Result<Vec<serde_json::Value>> {
415 let db = self.db.as_ref().ok_or_else(|| {
416 CollabError::Internal("Database not available for state changes".to_string())
417 })?;
418
419 let changes = sqlx::query!(
420 r#"
421 SELECT change_data
422 FROM workspace_state_changes
423 WHERE workspace_id = ? AND version > ?
424 ORDER BY version ASC
425 "#,
426 workspace_id,
427 since_version
428 )
429 .fetch_all(db)
430 .await?;
431
432 let mut result = Vec::new();
433 for change in changes {
434 let data: serde_json::Value = serde_json::from_str(&change.change_data)
435 .map_err(|e| CollabError::Internal(format!("Failed to parse change data: {e}")))?;
436 result.push(data);
437 }
438
439 Ok(result)
440 }
441
442 #[must_use]
444 pub fn get_connections(&self, workspace_id: Uuid) -> Vec<Uuid> {
445 self.connections.get(&workspace_id).map(|c| c.clone()).unwrap_or_default()
446 }
447
448 #[must_use]
450 pub fn connection_count(&self) -> usize {
451 self.connections.iter().map(|c| c.value().len()).sum()
452 }
453
454 #[must_use]
456 pub fn has_connections(&self, workspace_id: Uuid) -> bool {
457 self.connections.get(&workspace_id).is_some_and(|c| !c.is_empty())
458 }
459
460 pub fn cleanup_inactive(&self) {
462 let inactive: Vec<Uuid> = self
463 .connections
464 .iter()
465 .filter(|entry| entry.value().is_empty())
466 .map(|entry| *entry.key())
467 .collect();
468
469 for workspace_id in inactive {
470 self.connections.remove(&workspace_id);
471 }
472 }
473}
474
475pub mod crdt {
477 use serde::{Deserialize, Serialize};
478 use uuid::Uuid;
479
480 #[derive(Debug, Clone, Serialize, Deserialize)]
482 pub struct LwwRegister<T> {
483 pub value: T,
485 pub timestamp: u64,
487 pub client_id: Uuid,
489 }
490
491 impl<T> LwwRegister<T> {
492 pub const fn new(value: T, timestamp: u64, client_id: Uuid) -> Self {
494 Self {
495 value,
496 timestamp,
497 client_id,
498 }
499 }
500
501 pub fn merge(&mut self, other: Self)
503 where
504 T: Clone,
505 {
506 if other.timestamp > self.timestamp
507 || (other.timestamp == self.timestamp && other.client_id > self.client_id)
508 {
509 self.value = other.value;
510 self.timestamp = other.timestamp;
511 self.client_id = other.client_id;
512 }
513 }
514 }
515
516 #[derive(Debug, Clone, Serialize, Deserialize)]
518 pub struct TextOperation {
519 pub position: usize,
521 pub op: TextOp,
523 pub timestamp: u64,
525 pub client_id: Uuid,
527 }
528
529 #[derive(Debug, Clone, Serialize, Deserialize)]
530 #[serde(tag = "type", rename_all = "lowercase")]
531 pub enum TextOp {
532 Insert { text: String },
534 Delete { length: usize },
536 }
537}
538
#[cfg(test)]
mod tests {
    use super::*;

    /// Version starts at the constructed value and bumps on each update.
    #[test]
    fn test_sync_state() {
        let mut snapshot = SyncState::new(1, serde_json::json!({"key": "value"}));
        assert_eq!(snapshot.version, 1);

        snapshot.update(serde_json::json!({"key": "new_value"}));
        assert_eq!(snapshot.version, 2);
    }

    /// Subscribe registers a connection; unsubscribe removes it.
    #[test]
    fn test_sync_engine() {
        let engine = SyncEngine::new(Arc::new(EventBus::new(100)));
        let (ws, client) = (Uuid::new_v4(), Uuid::new_v4());

        assert_eq!(engine.connection_count(), 0);

        let _rx = engine.subscribe(ws, client).unwrap();
        assert!(engine.has_connections(ws));
        assert_eq!(engine.connection_count(), 1);

        engine.unsubscribe(ws, client).unwrap();
        assert_eq!(engine.connection_count(), 0);
    }

    /// A first update creates version 1 holding the given state.
    #[test]
    fn test_state_management() {
        let engine = SyncEngine::new(Arc::new(EventBus::new(100)));
        let ws = Uuid::new_v4();
        let payload = serde_json::json!({"mocks": []});

        engine.update_state(ws, payload.clone()).unwrap();

        let stored = engine.get_state(ws).expect("state should exist");
        assert_eq!(stored.version, 1);
        assert_eq!(stored.state, payload);
    }

    /// The register with the higher timestamp wins a merge.
    #[test]
    fn test_crdt_lww_register() {
        use super::crdt::LwwRegister;

        let mut older = LwwRegister::new("value1", 1, Uuid::new_v4());
        let newer = LwwRegister::new("value2", 2, Uuid::new_v4());

        older.merge(newer);
        assert_eq!(older.value, "value2");
        assert_eq!(older.timestamp, 2);
    }

    /// Empty connection lists are pruned by cleanup_inactive.
    #[test]
    fn test_cleanup_inactive() {
        let engine = SyncEngine::new(Arc::new(EventBus::new(100)));
        let (ws, client) = (Uuid::new_v4(), Uuid::new_v4());

        let _rx = engine.subscribe(ws, client).unwrap();
        assert_eq!(engine.connection_count(), 1);

        engine.unsubscribe(ws, client).unwrap();
        assert_eq!(engine.connection_count(), 0);

        engine.cleanup_inactive();
        assert!(!engine.has_connections(ws));
    }
}