1use crate::core_bridge::CoreBridge;
4use crate::error::{CollabError, Result};
5use crate::events::{ChangeEvent, EventBus};
6use crate::workspace::WorkspaceService;
7use chrono::Utc;
8use dashmap::DashMap;
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11use sqlx::{Pool, Sqlite};
12use std::sync::Arc;
13use tokio::sync::broadcast;
14use uuid::Uuid;
15
/// Wire-format messages exchanged between sync clients and the server.
///
/// Serialized as internally tagged JSON (`tag = "type"`) with `snake_case`
/// variant names, e.g. `{"type": "subscribe", "workspace_id": "…"}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SyncMessage {
    /// Client asks to receive change events for a workspace.
    Subscribe {
        workspace_id: Uuid,
    },
    /// Client stops receiving change events for a workspace.
    Unsubscribe {
        workspace_id: Uuid,
    },
    /// A change event fanned out to subscribers.
    Change {
        event: ChangeEvent,
    },
    /// Client requests workspace state, supplying a `version`.
    // NOTE(review): whether `version` means "send me this exact version" or
    // "send me everything since this version" is decided by the handler —
    // confirm against the message-processing code.
    StateRequest {
        workspace_id: Uuid,
        version: i64,
    },
    /// Server reply carrying the workspace state at `version`.
    StateResponse {
        workspace_id: Uuid,
        version: i64,
        state: serde_json::Value,
    },
    /// Liveness probe.
    Ping,
    /// Liveness reply.
    Pong,
    /// Error report carried over the sync channel.
    Error {
        message: String,
    },
}
61
/// Versioned, timestamped snapshot of a workspace's synchronized state.
#[derive(Debug, Clone)]
pub struct SyncState {
    /// Version counter, bumped on every `update`.
    pub version: i64,
    /// The state payload as arbitrary JSON.
    pub state: serde_json::Value,
    /// When this state was last written.
    pub last_updated: chrono::DateTime<Utc>,
}
72
73impl SyncState {
74 #[must_use]
76 pub fn new(version: i64, state: serde_json::Value) -> Self {
77 Self {
78 version,
79 state,
80 last_updated: Utc::now(),
81 }
82 }
83
84 pub fn update(&mut self, new_state: serde_json::Value) {
86 self.version += 1;
87 self.state = new_state;
88 self.last_updated = Utc::now();
89 }
90}
91
/// Coordinates real-time workspace synchronization: tracks per-workspace
/// state and client connections, fans changes out over the [`EventBus`], and
/// — when configured — persists snapshots to SQLite and mirrors updates into
/// the core workspace model.
pub struct SyncEngine {
    /// Shared broadcast bus for change events (one channel for all workspaces).
    event_bus: Arc<EventBus>,
    /// Latest in-memory state per workspace.
    states: DashMap<Uuid, SyncState>,
    /// Connected client ids per workspace.
    connections: DashMap<Uuid, Vec<Uuid>>,
    /// Optional SQLite pool for snapshots and the change log.
    db: Option<Pool<Sqlite>>,
    /// Optional bridge into the core workspace model.
    core_bridge: Option<Arc<CoreBridge>>,
    /// Optional workspace lookup service (used together with `core_bridge`).
    workspace_service: Option<Arc<WorkspaceService>>,
}
107
108impl SyncEngine {
109 #[must_use]
111 pub fn new(event_bus: Arc<EventBus>) -> Self {
112 Self {
113 event_bus,
114 states: DashMap::new(),
115 connections: DashMap::new(),
116 db: None,
117 core_bridge: None,
118 workspace_service: None,
119 }
120 }
121
122 #[must_use]
124 pub fn with_db(event_bus: Arc<EventBus>, db: Pool<Sqlite>) -> Self {
125 Self {
126 event_bus,
127 states: DashMap::new(),
128 connections: DashMap::new(),
129 db: Some(db),
130 core_bridge: None,
131 workspace_service: None,
132 }
133 }
134
135 #[must_use]
137 pub fn with_integration(
138 event_bus: Arc<EventBus>,
139 db: Pool<Sqlite>,
140 core_bridge: Arc<CoreBridge>,
141 workspace_service: Arc<WorkspaceService>,
142 ) -> Self {
143 Self {
144 event_bus,
145 states: DashMap::new(),
146 connections: DashMap::new(),
147 db: Some(db),
148 core_bridge: Some(core_bridge),
149 workspace_service: Some(workspace_service),
150 }
151 }
152
153 pub fn subscribe(
159 &self,
160 workspace_id: Uuid,
161 client_id: Uuid,
162 ) -> Result<broadcast::Receiver<ChangeEvent>> {
163 self.connections.entry(workspace_id).or_default().push(client_id);
165
166 Ok(self.event_bus.subscribe())
168 }
169
170 pub fn unsubscribe(&self, workspace_id: Uuid, client_id: Uuid) -> Result<()> {
176 if let Some(mut connections) = self.connections.get_mut(&workspace_id) {
177 connections.retain(|id| *id != client_id);
178 }
179 Ok(())
180 }
181
182 pub fn publish_change(&self, event: ChangeEvent) -> Result<()> {
188 self.event_bus.publish(event)
189 }
190
191 #[must_use]
193 pub fn get_state(&self, workspace_id: Uuid) -> Option<SyncState> {
194 self.states.get(&workspace_id).map(|s| s.clone())
195 }
196
197 pub fn update_state(&self, workspace_id: Uuid, new_state: serde_json::Value) -> Result<()> {
203 let version = if let Some(state) = self.states.get(&workspace_id) {
204 state.version + 1
205 } else {
206 1
207 };
208
209 if let Some(mut state) = self.states.get_mut(&workspace_id) {
210 state.update(new_state.clone());
211 } else {
212 self.states.insert(workspace_id, SyncState::new(version, new_state.clone()));
213 }
214
215 if let (Some(core_bridge), Some(workspace_service)) =
217 (self.core_bridge.as_ref(), self.workspace_service.as_ref())
218 {
219 let core_bridge = core_bridge.clone();
220 let workspace_service = workspace_service.clone();
221 let state_data = new_state.clone();
222 tokio::spawn(async move {
223 if let Ok(mut team_workspace) = workspace_service.get_workspace(workspace_id).await
224 {
225 if let Err(e) = core_bridge
226 .update_workspace_state_from_json(&mut team_workspace, &state_data)
227 {
228 tracing::error!("Failed to update workspace state: {}", e);
229 } else {
230 if let Err(e) = core_bridge.save_workspace_to_disk(&team_workspace).await {
234 tracing::error!("Failed to save workspace to disk: {}", e);
235 }
236 }
237 }
238 });
239 }
240
241 if let Some(db) = &self.db {
243 let db = db.clone();
245 let state_data = new_state;
246 tokio::spawn(async move {
247 if let Err(e) =
248 Self::save_state_snapshot(&db, workspace_id, version, &state_data).await
249 {
250 tracing::error!("Failed to save state snapshot: {}", e);
251 }
252 });
253 }
254
255 Ok(())
256 }
257
258 pub async fn get_full_workspace_state(
266 &self,
267 workspace_id: Uuid,
268 ) -> Result<Option<serde_json::Value>> {
269 if let (Some(core_bridge), Some(workspace_service)) =
270 (self.core_bridge.as_ref(), self.workspace_service.as_ref())
271 {
272 let team_workspace = workspace_service.get_workspace(workspace_id).await?;
274
275 let state_json = core_bridge.get_workspace_state_json(&team_workspace)?;
277 Ok(Some(state_json))
278 } else {
279 Ok(self.get_state(workspace_id).map(|s| s.state))
281 }
282 }
283
284 async fn save_state_snapshot(
286 db: &Pool<Sqlite>,
287 workspace_id: Uuid,
288 version: i64,
289 state: &serde_json::Value,
290 ) -> Result<()> {
291 let state_json = serde_json::to_string(state)?;
293 let mut hasher = Sha256::new();
294 hasher.update(state_json.as_bytes());
295 let state_hash = format!("{:x}", hasher.finalize());
296
297 let existing = sqlx::query!(
299 r#"
300 SELECT id FROM workspace_state_snapshots
301 WHERE workspace_id = ? AND state_hash = ?
302 "#,
303 workspace_id,
304 state_hash
305 )
306 .fetch_optional(db)
307 .await?;
308
309 if existing.is_some() {
310 return Ok(());
312 }
313
314 let snapshot_id = Uuid::new_v4();
316 let snapshot_id_str = snapshot_id.to_string();
317 let workspace_id_str = workspace_id.to_string();
318 let now = Utc::now().to_rfc3339();
319 sqlx::query!(
320 r#"
321 INSERT INTO workspace_state_snapshots (id, workspace_id, state_hash, state_data, version, created_at)
322 VALUES (?, ?, ?, ?, ?, ?)
323 "#,
324 snapshot_id_str,
325 workspace_id_str,
326 state_hash,
327 state_json,
328 version,
329 now
330 )
331 .execute(db)
332 .await?;
333
334 Ok(())
335 }
336
337 pub async fn load_state_snapshot(
343 &self,
344 workspace_id: Uuid,
345 version: Option<i64>,
346 ) -> Result<Option<SyncState>> {
347 let db = self.db.as_ref().ok_or_else(|| {
348 CollabError::Internal("Database not available for state snapshots".to_string())
349 })?;
350
351 let workspace_id_str = workspace_id.to_string();
352 let snapshot: Option<(String, i64, String)> = if let Some(version) = version {
354 sqlx::query_as(
355 r"
356 SELECT state_data, version, created_at
357 FROM workspace_state_snapshots
358 WHERE workspace_id = ? AND version = ?
359 ORDER BY created_at DESC
360 LIMIT 1
361 ",
362 )
363 .bind(&workspace_id_str)
364 .bind(version)
365 .fetch_optional(db)
366 .await?
367 } else {
368 sqlx::query_as(
369 r"
370 SELECT state_data, version, created_at
371 FROM workspace_state_snapshots
372 WHERE workspace_id = ?
373 ORDER BY version DESC, created_at DESC
374 LIMIT 1
375 ",
376 )
377 .bind(&workspace_id_str)
378 .fetch_optional(db)
379 .await?
380 };
381
382 if let Some((state_data, snap_version, created_at_str)) = snapshot {
383 let state: serde_json::Value = serde_json::from_str(&state_data)
384 .map_err(|e| CollabError::Internal(format!("Failed to parse state: {e}")))?;
385 let created_at = chrono::DateTime::parse_from_rfc3339(&created_at_str)
388 .map(|dt| dt.with_timezone(&Utc))
389 .or_else(|_| {
390 chrono::NaiveDateTime::parse_from_str(&created_at_str, "%Y-%m-%d %H:%M:%S%.f")
392 .or_else(|_| {
393 chrono::NaiveDateTime::parse_from_str(
394 &created_at_str,
395 "%Y-%m-%d %H:%M:%S",
396 )
397 })
398 .map(|dt| dt.and_utc())
399 })
400 .map_err(|e| {
401 CollabError::Internal(format!(
402 "Failed to parse timestamp '{created_at_str}': {e}"
403 ))
404 })?;
405
406 Ok(Some(SyncState {
407 version: snap_version,
408 state,
409 last_updated: created_at,
410 }))
411 } else {
412 Ok(None)
413 }
414 }
415
416 pub async fn record_state_change(
422 &self,
423 workspace_id: Uuid,
424 change_type: &str,
425 change_data: serde_json::Value,
426 version: i64,
427 user_id: Uuid,
428 ) -> Result<()> {
429 let db = self.db.as_ref().ok_or_else(|| {
430 CollabError::Internal("Database not available for state changes".to_string())
431 })?;
432
433 let change_id = Uuid::new_v4();
434 let change_data_str = serde_json::to_string(&change_data)?;
435 let now = Utc::now().to_rfc3339();
436 sqlx::query!(
437 r#"
438 INSERT INTO workspace_state_changes (id, workspace_id, change_type, change_data, version, created_at, created_by)
439 VALUES (?, ?, ?, ?, ?, ?, ?)
440 "#,
441 change_id,
442 workspace_id,
443 change_type,
444 change_data_str,
445 version,
446 now,
447 user_id
448 )
449 .execute(db)
450 .await?;
451
452 Ok(())
453 }
454
455 pub async fn get_state_changes_since(
461 &self,
462 workspace_id: Uuid,
463 since_version: i64,
464 ) -> Result<Vec<serde_json::Value>> {
465 let db = self.db.as_ref().ok_or_else(|| {
466 CollabError::Internal("Database not available for state changes".to_string())
467 })?;
468
469 let changes = sqlx::query!(
470 r#"
471 SELECT change_data
472 FROM workspace_state_changes
473 WHERE workspace_id = ? AND version > ?
474 ORDER BY version ASC
475 "#,
476 workspace_id,
477 since_version
478 )
479 .fetch_all(db)
480 .await?;
481
482 let mut result = Vec::new();
483 for change in changes {
484 let data: serde_json::Value = serde_json::from_str(&change.change_data)
485 .map_err(|e| CollabError::Internal(format!("Failed to parse change data: {e}")))?;
486 result.push(data);
487 }
488
489 Ok(result)
490 }
491
492 #[must_use]
494 pub fn get_connections(&self, workspace_id: Uuid) -> Vec<Uuid> {
495 self.connections.get(&workspace_id).map(|c| c.clone()).unwrap_or_default()
496 }
497
498 #[must_use]
500 pub fn connection_count(&self) -> usize {
501 self.connections.iter().map(|c| c.value().len()).sum()
502 }
503
504 #[must_use]
506 pub fn has_connections(&self, workspace_id: Uuid) -> bool {
507 self.connections.get(&workspace_id).is_some_and(|c| !c.is_empty())
508 }
509
510 pub fn cleanup_inactive(&self) {
512 let inactive: Vec<Uuid> = self
513 .connections
514 .iter()
515 .filter(|entry| entry.value().is_empty())
516 .map(|entry| *entry.key())
517 .collect();
518
519 for workspace_id in inactive {
520 self.connections.remove(&workspace_id);
521 }
522 }
523}
524
525pub mod crdt {
527 use serde::{Deserialize, Serialize};
528 use uuid::Uuid;
529
530 #[derive(Debug, Clone, Serialize, Deserialize)]
532 pub struct LwwRegister<T> {
533 pub value: T,
535 pub timestamp: u64,
537 pub client_id: Uuid,
539 }
540
541 impl<T> LwwRegister<T> {
542 pub const fn new(value: T, timestamp: u64, client_id: Uuid) -> Self {
544 Self {
545 value,
546 timestamp,
547 client_id,
548 }
549 }
550
551 pub fn merge(&mut self, other: Self)
553 where
554 T: Clone,
555 {
556 if other.timestamp > self.timestamp
557 || (other.timestamp == self.timestamp && other.client_id > self.client_id)
558 {
559 self.value = other.value;
560 self.timestamp = other.timestamp;
561 self.client_id = other.client_id;
562 }
563 }
564 }
565
566 #[derive(Debug, Clone, Serialize, Deserialize)]
568 pub struct TextOperation {
569 pub position: usize,
571 pub op: TextOp,
573 pub timestamp: u64,
575 pub client_id: Uuid,
577 }
578
579 #[derive(Debug, Clone, Serialize, Deserialize)]
581 #[serde(tag = "type", rename_all = "lowercase")]
582 pub enum TextOp {
583 Insert {
585 text: String,
587 },
588 Delete {
590 length: usize,
592 },
593 }
594}
595
#[cfg(test)]
mod tests {
    use super::*;

    /// `SyncState::update` bumps the version on every call.
    #[test]
    fn test_sync_state() {
        let mut snapshot = SyncState::new(1, serde_json::json!({"key": "value"}));
        assert_eq!(snapshot.version, 1);

        snapshot.update(serde_json::json!({"key": "new_value"}));
        assert_eq!(snapshot.version, 2);
    }

    /// Subscribing registers a connection; unsubscribing removes it.
    #[test]
    fn test_sync_engine() {
        let engine = SyncEngine::new(Arc::new(EventBus::new(100)));
        let workspace = Uuid::new_v4();
        let client = Uuid::new_v4();

        assert_eq!(engine.connection_count(), 0);

        let _rx = engine.subscribe(workspace, client).unwrap();
        assert_eq!(engine.connection_count(), 1);
        assert!(engine.has_connections(workspace));

        engine.unsubscribe(workspace, client).unwrap();
        assert_eq!(engine.connection_count(), 0);
    }

    /// A first `update_state` seeds the cache at version 1 with the payload.
    #[test]
    fn test_state_management() {
        let engine = SyncEngine::new(Arc::new(EventBus::new(100)));
        let workspace = Uuid::new_v4();
        let payload = serde_json::json!({"mocks": []});

        engine.update_state(workspace, payload.clone()).unwrap();

        let cached = engine.get_state(workspace).unwrap();
        assert_eq!(cached.version, 1);
        assert_eq!(cached.state, payload);
    }

    /// Merging a newer-timestamped register wins over the older value.
    #[test]
    fn test_crdt_lww_register() {
        use super::crdt::LwwRegister;

        let older = LwwRegister::new("value1", 1, Uuid::new_v4());
        let newer = LwwRegister::new("value2", 2, Uuid::new_v4());

        let mut merged = older;
        merged.merge(newer);
        assert_eq!(merged.value, "value2");
        assert_eq!(merged.timestamp, 2);
    }

    /// `cleanup_inactive` removes workspaces whose client list has emptied.
    #[test]
    fn test_cleanup_inactive() {
        let engine = SyncEngine::new(Arc::new(EventBus::new(100)));
        let workspace = Uuid::new_v4();
        let client = Uuid::new_v4();

        let _rx = engine.subscribe(workspace, client).unwrap();
        assert_eq!(engine.connection_count(), 1);

        engine.unsubscribe(workspace, client).unwrap();
        assert_eq!(engine.connection_count(), 0);

        engine.cleanup_inactive();
        assert!(!engine.has_connections(workspace));
    }
}