1use crate::core_bridge::CoreBridge;
4use crate::error::{CollabError, Result};
5use crate::events::{ChangeEvent, EventBus};
6use crate::workspace::WorkspaceService;
7use chrono::Utc;
8use dashmap::DashMap;
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11use sqlx::{Pool, Sqlite};
12use std::sync::Arc;
13use tokio::sync::broadcast;
14use uuid::Uuid;
15
/// Wire-format messages exchanged with sync clients.
///
/// Serialized as internally tagged JSON: a `"type"` field in `snake_case`
/// selects the variant, e.g. `{"type": "subscribe", "workspace_id": "..."}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SyncMessage {
    /// Client asks to start receiving change events for a workspace.
    Subscribe { workspace_id: Uuid },
    /// Client asks to stop receiving change events for a workspace.
    Unsubscribe { workspace_id: Uuid },
    /// A state change fanned out to subscribers.
    Change { event: ChangeEvent },
    /// Client requests workspace state; `version` is presumably the client's
    /// last-known version — TODO confirm against the message handler.
    StateRequest { workspace_id: Uuid, version: i64 },
    /// Full state payload for a workspace at `version`.
    StateResponse {
        workspace_id: Uuid,
        version: i64,
        state: serde_json::Value,
    },
    /// Keepalive probe.
    Ping,
    /// Keepalive reply.
    Pong,
    /// Error notification carrying a human-readable message.
    Error { message: String },
}
41
/// Versioned in-memory snapshot of a workspace's synchronized state.
#[derive(Debug, Clone)]
pub struct SyncState {
    // Monotonically increasing version; bumped by `update`.
    pub version: i64,
    // Full workspace state as arbitrary JSON.
    pub state: serde_json::Value,
    // UTC timestamp of the most recent update.
    pub last_updated: chrono::DateTime<chrono::Utc>,
}
52
53impl SyncState {
54 pub fn new(version: i64, state: serde_json::Value) -> Self {
56 Self {
57 version,
58 state,
59 last_updated: chrono::Utc::now(),
60 }
61 }
62
63 pub fn update(&mut self, new_state: serde_json::Value) {
65 self.version += 1;
66 self.state = new_state;
67 self.last_updated = chrono::Utc::now();
68 }
69}
70
/// Coordinates real-time workspace synchronization: tracks per-workspace
/// versioned state and client subscriptions, and optionally persists
/// snapshots/changes to SQLite and pushes state into the core engine.
pub struct SyncEngine {
    // Broadcast bus used to fan change events out to subscribers.
    event_bus: Arc<EventBus>,
    // Latest in-memory state per workspace.
    states: DashMap<Uuid, SyncState>,
    // Subscribed client ids per workspace.
    connections: DashMap<Uuid, Vec<Uuid>>,
    // Optional SQLite pool for snapshot/change persistence.
    db: Option<Pool<Sqlite>>,
    // Optional bridge for applying state to core workspaces.
    core_bridge: Option<Arc<CoreBridge>>,
    // Optional service for loading workspaces by id.
    workspace_service: Option<Arc<WorkspaceService>>,
}
86
87impl SyncEngine {
88 pub fn new(event_bus: Arc<EventBus>) -> Self {
90 Self {
91 event_bus,
92 states: DashMap::new(),
93 connections: DashMap::new(),
94 db: None,
95 core_bridge: None,
96 workspace_service: None,
97 }
98 }
99
100 pub fn with_db(event_bus: Arc<EventBus>, db: Pool<Sqlite>) -> Self {
102 Self {
103 event_bus,
104 states: DashMap::new(),
105 connections: DashMap::new(),
106 db: Some(db),
107 core_bridge: None,
108 workspace_service: None,
109 }
110 }
111
112 pub fn with_integration(
114 event_bus: Arc<EventBus>,
115 db: Pool<Sqlite>,
116 core_bridge: Arc<CoreBridge>,
117 workspace_service: Arc<WorkspaceService>,
118 ) -> Self {
119 Self {
120 event_bus,
121 states: DashMap::new(),
122 connections: DashMap::new(),
123 db: Some(db),
124 core_bridge: Some(core_bridge),
125 workspace_service: Some(workspace_service),
126 }
127 }
128
129 pub fn subscribe(
131 &self,
132 workspace_id: Uuid,
133 client_id: Uuid,
134 ) -> Result<broadcast::Receiver<ChangeEvent>> {
135 self.connections.entry(workspace_id).or_insert_with(Vec::new).push(client_id);
137
138 Ok(self.event_bus.subscribe())
140 }
141
142 pub fn unsubscribe(&self, workspace_id: Uuid, client_id: Uuid) -> Result<()> {
144 if let Some(mut connections) = self.connections.get_mut(&workspace_id) {
145 connections.retain(|id| *id != client_id);
146 }
147 Ok(())
148 }
149
    /// Broadcasts a change event to all current subscribers via the event bus.
    pub fn publish_change(&self, event: ChangeEvent) -> Result<()> {
        self.event_bus.publish(event)
    }
154
155 pub fn get_state(&self, workspace_id: Uuid) -> Option<SyncState> {
157 self.states.get(&workspace_id).map(|s| s.clone())
158 }
159
160 pub fn update_state(&self, workspace_id: Uuid, new_state: serde_json::Value) -> Result<()> {
162 let version = if let Some(state) = self.states.get(&workspace_id) {
163 state.version + 1
164 } else {
165 1
166 };
167
168 if let Some(mut state) = self.states.get_mut(&workspace_id) {
169 state.update(new_state.clone());
170 } else {
171 self.states.insert(workspace_id, SyncState::new(version, new_state.clone()));
172 }
173
174 if let (Some(core_bridge), Some(workspace_service)) =
176 (self.core_bridge.as_ref(), self.workspace_service.as_ref())
177 {
178 let core_bridge = core_bridge.clone();
179 let workspace_service = workspace_service.clone();
180 let workspace_id = workspace_id;
181 let state_data = new_state.clone();
182 tokio::spawn(async move {
183 if let Ok(mut team_workspace) = workspace_service.get_workspace(workspace_id).await
184 {
185 if let Err(e) = core_bridge
186 .update_workspace_state_from_json(&mut team_workspace, &state_data)
187 {
188 tracing::error!("Failed to update workspace state: {}", e);
189 } else {
190 if let Err(e) = core_bridge.save_workspace_to_disk(&team_workspace).await {
194 tracing::error!("Failed to save workspace to disk: {}", e);
195 }
196 }
197 }
198 });
199 }
200
201 if let Some(db) = &self.db {
203 let db = db.clone();
205 let workspace_id = workspace_id;
206 let state_data = new_state.clone();
207 tokio::spawn(async move {
208 if let Err(e) =
209 Self::save_state_snapshot(&db, workspace_id, version, &state_data).await
210 {
211 tracing::error!("Failed to save state snapshot: {}", e);
212 }
213 });
214 }
215
216 Ok(())
217 }
218
219 pub async fn get_full_workspace_state(
223 &self,
224 workspace_id: Uuid,
225 ) -> Result<Option<serde_json::Value>> {
226 if let (Some(core_bridge), Some(workspace_service)) =
227 (self.core_bridge.as_ref(), self.workspace_service.as_ref())
228 {
229 let team_workspace = workspace_service.get_workspace(workspace_id).await?;
231
232 let state_json = core_bridge.get_workspace_state_json(&team_workspace)?;
234 Ok(Some(state_json))
235 } else {
236 Ok(self.get_state(workspace_id).map(|s| s.state))
238 }
239 }
240
241 async fn save_state_snapshot(
243 db: &Pool<Sqlite>,
244 workspace_id: Uuid,
245 version: i64,
246 state: &serde_json::Value,
247 ) -> Result<()> {
248 let state_json = serde_json::to_string(state)?;
250 let mut hasher = Sha256::new();
251 hasher.update(state_json.as_bytes());
252 let state_hash = format!("{:x}", hasher.finalize());
253
254 let existing = sqlx::query!(
256 r#"
257 SELECT id FROM workspace_state_snapshots
258 WHERE workspace_id = ? AND state_hash = ?
259 "#,
260 workspace_id,
261 state_hash
262 )
263 .fetch_optional(db)
264 .await?;
265
266 if existing.is_some() {
267 return Ok(());
269 }
270
271 let snapshot_id = Uuid::new_v4();
273 let snapshot_id_str = snapshot_id.to_string();
274 let workspace_id_str = workspace_id.to_string();
275 let now = Utc::now().to_rfc3339();
276 sqlx::query!(
277 r#"
278 INSERT INTO workspace_state_snapshots (id, workspace_id, state_hash, state_data, version, created_at)
279 VALUES (?, ?, ?, ?, ?, ?)
280 "#,
281 snapshot_id_str,
282 workspace_id_str,
283 state_hash,
284 state_json,
285 version,
286 now
287 )
288 .execute(db)
289 .await?;
290
291 Ok(())
292 }
293
    /// Loads a persisted snapshot for `workspace_id` from SQLite.
    ///
    /// With `version = Some(v)`, returns the newest snapshot at exactly that
    /// version; with `None`, the newest snapshot overall. Returns `Ok(None)`
    /// when no matching row exists.
    ///
    /// # Errors
    /// Fails when no database pool is configured, on query errors, or when a
    /// stored row contains unparseable JSON or an unparseable timestamp.
    pub async fn load_state_snapshot(
        &self,
        workspace_id: Uuid,
        version: Option<i64>,
    ) -> Result<Option<SyncState>> {
        let db = self.db.as_ref().ok_or_else(|| {
            CollabError::Internal("Database not available for state snapshots".to_string())
        })?;

        let workspace_id_str = workspace_id.to_string();
        // Row tuple: (state_data JSON text, version, created_at text).
        let snapshot: Option<(String, i64, String)> = if let Some(version) = version {
            sqlx::query_as(
                r#"
                SELECT state_data, version, created_at
                FROM workspace_state_snapshots
                WHERE workspace_id = ? AND version = ?
                ORDER BY created_at DESC
                LIMIT 1
                "#,
            )
            .bind(&workspace_id_str)
            .bind(version)
            .fetch_optional(db)
            .await?
        } else {
            sqlx::query_as(
                r#"
                SELECT state_data, version, created_at
                FROM workspace_state_snapshots
                WHERE workspace_id = ?
                ORDER BY version DESC, created_at DESC
                LIMIT 1
                "#,
            )
            .bind(&workspace_id_str)
            .fetch_optional(db)
            .await?
        };

        if let Some((state_data, snap_version, created_at_str)) = snapshot {
            let state: serde_json::Value = serde_json::from_str(&state_data)
                .map_err(|e| CollabError::Internal(format!("Failed to parse state: {}", e)))?;
            // Timestamp parsing: try RFC 3339 first (what save_state_snapshot
            // writes), then fall back to SQLite-style `YYYY-MM-DD HH:MM:SS`
            // with and without fractional seconds, interpreted as UTC.
            let created_at = chrono::DateTime::parse_from_rfc3339(&created_at_str)
                .map(|dt| dt.with_timezone(&chrono::Utc))
                .or_else(|_| {
                    chrono::NaiveDateTime::parse_from_str(&created_at_str, "%Y-%m-%d %H:%M:%S%.f")
                        .or_else(|_| {
                            chrono::NaiveDateTime::parse_from_str(
                                &created_at_str,
                                "%Y-%m-%d %H:%M:%S",
                            )
                        })
                        .map(|dt| dt.and_utc())
                })
                .map_err(|e| {
                    CollabError::Internal(format!(
                        "Failed to parse timestamp '{}': {}",
                        created_at_str, e
                    ))
                })?;

            Ok(Some(SyncState {
                version: snap_version,
                state,
                last_updated: created_at,
            }))
        } else {
            Ok(None)
        }
    }
369
    /// Appends one entry to the workspace's change log.
    ///
    /// `change_type` is a free-form tag, `change_data` an arbitrary JSON
    /// payload, `version` the state version this change belongs to, and
    /// `user_id` the author, stored as `created_by`.
    ///
    /// # Errors
    /// Fails when no database pool is configured, when `change_data` cannot
    /// be serialized, or on query errors.
    pub async fn record_state_change(
        &self,
        workspace_id: Uuid,
        change_type: &str,
        change_data: serde_json::Value,
        version: i64,
        user_id: Uuid,
    ) -> Result<()> {
        let db = self.db.as_ref().ok_or_else(|| {
            CollabError::Internal("Database not available for state changes".to_string())
        })?;

        // Ids and timestamp are stored as text, matching the other tables.
        let change_id = Uuid::new_v4();
        let change_id_str = change_id.to_string();
        let change_data_str = serde_json::to_string(&change_data)?;
        let workspace_id_str = workspace_id.to_string();
        let user_id_str = user_id.to_string();
        let now = Utc::now().to_rfc3339();
        sqlx::query!(
            r#"
            INSERT INTO workspace_state_changes (id, workspace_id, change_type, change_data, version, created_at, created_by)
            VALUES (?, ?, ?, ?, ?, ?, ?)
            "#,
            change_id_str,
            workspace_id_str,
            change_type,
            change_data_str,
            version,
            now,
            user_id_str
        )
        .execute(db)
        .await?;

        Ok(())
    }
407
408 pub async fn get_state_changes_since(
410 &self,
411 workspace_id: Uuid,
412 since_version: i64,
413 ) -> Result<Vec<serde_json::Value>> {
414 let db = self.db.as_ref().ok_or_else(|| {
415 CollabError::Internal("Database not available for state changes".to_string())
416 })?;
417
418 let workspace_id_str = workspace_id.to_string();
419 let changes = sqlx::query!(
420 r#"
421 SELECT change_data
422 FROM workspace_state_changes
423 WHERE workspace_id = ? AND version > ?
424 ORDER BY version ASC
425 "#,
426 workspace_id_str,
427 since_version
428 )
429 .fetch_all(db)
430 .await?;
431
432 let mut result = Vec::new();
433 for change in changes {
434 let data: serde_json::Value =
435 serde_json::from_str(&change.change_data).map_err(|e| {
436 CollabError::Internal(format!("Failed to parse change data: {}", e))
437 })?;
438 result.push(data);
439 }
440
441 Ok(result)
442 }
443
444 pub fn get_connections(&self, workspace_id: Uuid) -> Vec<Uuid> {
446 self.connections.get(&workspace_id).map(|c| c.clone()).unwrap_or_default()
447 }
448
449 pub fn connection_count(&self) -> usize {
451 self.connections.iter().map(|c| c.value().len()).sum()
452 }
453
454 pub fn has_connections(&self, workspace_id: Uuid) -> bool {
456 self.connections.get(&workspace_id).map(|c| !c.is_empty()).unwrap_or(false)
457 }
458
459 pub fn cleanup_inactive(&self) {
461 let inactive: Vec<Uuid> = self
462 .connections
463 .iter()
464 .filter(|entry| entry.value().is_empty())
465 .map(|entry| *entry.key())
466 .collect();
467
468 for workspace_id in inactive {
469 self.connections.remove(&workspace_id);
470 }
471 }
472}
473
474pub mod crdt {
476 use serde::{Deserialize, Serialize};
477 use uuid::Uuid;
478
479 #[derive(Debug, Clone, Serialize, Deserialize)]
481 pub struct LwwRegister<T> {
482 pub value: T,
484 pub timestamp: u64,
486 pub client_id: Uuid,
488 }
489
490 impl<T> LwwRegister<T> {
491 pub fn new(value: T, timestamp: u64, client_id: Uuid) -> Self {
493 Self {
494 value,
495 timestamp,
496 client_id,
497 }
498 }
499
500 pub fn merge(&mut self, other: Self)
502 where
503 T: Clone,
504 {
505 if other.timestamp > self.timestamp
506 || (other.timestamp == self.timestamp && other.client_id > self.client_id)
507 {
508 self.value = other.value;
509 self.timestamp = other.timestamp;
510 self.client_id = other.client_id;
511 }
512 }
513 }
514
515 #[derive(Debug, Clone, Serialize, Deserialize)]
517 pub struct TextOperation {
518 pub position: usize,
520 pub op: TextOp,
522 pub timestamp: u64,
524 pub client_id: Uuid,
526 }
527
528 #[derive(Debug, Clone, Serialize, Deserialize)]
529 #[serde(tag = "type", rename_all = "lowercase")]
530 pub enum TextOp {
531 Insert { text: String },
533 Delete { length: usize },
535 }
536}
537
#[cfg(test)]
mod tests {
    use super::*;

    // Version increments by one on each update.
    #[test]
    fn test_sync_state() {
        let mut state = SyncState::new(1, serde_json::json!({"key": "value"}));
        assert_eq!(state.version, 1);

        state.update(serde_json::json!({"key": "new_value"}));
        assert_eq!(state.version, 2);
    }

    // subscribe/unsubscribe round-trip is reflected in the connection count.
    #[test]
    fn test_sync_engine() {
        let event_bus = Arc::new(EventBus::new(100));
        let engine = SyncEngine::new(event_bus);

        let workspace_id = Uuid::new_v4();
        let client_id = Uuid::new_v4();

        assert_eq!(engine.connection_count(), 0);

        let _rx = engine.subscribe(workspace_id, client_id).unwrap();
        assert_eq!(engine.connection_count(), 1);
        assert!(engine.has_connections(workspace_id));

        engine.unsubscribe(workspace_id, client_id).unwrap();
        assert_eq!(engine.connection_count(), 0);
    }

    // First update_state on a fresh workspace yields version 1 and stores the
    // provided state verbatim.
    #[test]
    fn test_state_management() {
        let event_bus = Arc::new(EventBus::new(100));
        let engine = SyncEngine::new(event_bus);

        let workspace_id = Uuid::new_v4();
        let state = serde_json::json!({"mocks": []});

        engine.update_state(workspace_id, state.clone()).unwrap();

        let retrieved = engine.get_state(workspace_id).unwrap();
        assert_eq!(retrieved.version, 1);
        assert_eq!(retrieved.state, state);
    }

    // A register with a higher timestamp wins the merge.
    #[test]
    fn test_crdt_lww_register() {
        use super::crdt::LwwRegister;

        let client1 = Uuid::new_v4();
        let client2 = Uuid::new_v4();

        let mut reg1 = LwwRegister::new("value1", 1, client1);
        let reg2 = LwwRegister::new("value2", 2, client2);

        reg1.merge(reg2);
        assert_eq!(reg1.value, "value2");
        assert_eq!(reg1.timestamp, 2);
    }

    // cleanup_inactive leaves no trace of a workspace whose last client left.
    #[test]
    fn test_cleanup_inactive() {
        let event_bus = Arc::new(EventBus::new(100));
        let engine = SyncEngine::new(event_bus);

        let workspace_id = Uuid::new_v4();
        let client_id = Uuid::new_v4();

        let _rx = engine.subscribe(workspace_id, client_id).unwrap();
        assert_eq!(engine.connection_count(), 1);

        engine.unsubscribe(workspace_id, client_id).unwrap();
        assert_eq!(engine.connection_count(), 0);

        engine.cleanup_inactive();
        assert!(!engine.has_connections(workspace_id));
    }
}