1use crate::core_bridge::CoreBridge;
4use crate::error::{CollabError, Result};
5use crate::events::{ChangeEvent, EventBus};
6use crate::workspace::WorkspaceService;
7use chrono::Utc;
8use dashmap::DashMap;
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11use sqlx::{Pool, Sqlite};
12use std::sync::Arc;
13use tokio::sync::broadcast;
14use uuid::Uuid;
15
/// Wire-format messages exchanged between sync clients and the server.
///
/// Internally tagged for serde: serialized as `{"type": "subscribe", ...}`
/// with the variant name in `snake_case`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SyncMessage {
    /// Client asks to start receiving change events for a workspace.
    Subscribe { workspace_id: Uuid },
    /// Client asks to stop receiving change events for a workspace.
    Unsubscribe { workspace_id: Uuid },
    /// A workspace change event broadcast to subscribers.
    Change { event: ChangeEvent },
    /// Client requests the workspace state, reporting the version it has.
    StateRequest { workspace_id: Uuid, version: i64 },
    /// Server reply carrying a full JSON state snapshot at `version`.
    StateResponse {
        workspace_id: Uuid,
        version: i64,
        state: serde_json::Value,
    },
    /// Liveness probe.
    Ping,
    /// Reply to `Ping`.
    Pong,
    /// Error report delivered to the peer.
    Error { message: String },
}
41
/// In-memory versioned snapshot of a workspace's synchronized state.
#[derive(Debug, Clone)]
pub struct SyncState {
    /// Monotonically increasing version; new workspaces start at 1.
    pub version: i64,
    /// Arbitrary JSON document representing the workspace state.
    pub state: serde_json::Value,
    /// Timestamp of the most recent creation or update.
    pub last_updated: chrono::DateTime<Utc>,
}
52
53impl SyncState {
54 #[must_use]
56 pub fn new(version: i64, state: serde_json::Value) -> Self {
57 Self {
58 version,
59 state,
60 last_updated: Utc::now(),
61 }
62 }
63
64 pub fn update(&mut self, new_state: serde_json::Value) {
66 self.version += 1;
67 self.state = new_state;
68 self.last_updated = Utc::now();
69 }
70}
71
/// Coordinates real-time state synchronization across connected clients.
///
/// Keeps an in-memory versioned state per workspace, tracks which clients
/// subscribe to which workspace, and — when configured — persists snapshots
/// and change history to SQLite and mirrors updates into the core workspace
/// model through `CoreBridge`.
pub struct SyncEngine {
    /// Broadcast bus used to fan change events out to subscribers.
    event_bus: Arc<EventBus>,
    /// Latest known in-memory state per workspace id.
    states: DashMap<Uuid, SyncState>,
    /// Subscribed client ids per workspace id.
    connections: DashMap<Uuid, Vec<Uuid>>,
    /// Optional SQLite pool for snapshot/change persistence.
    db: Option<Pool<Sqlite>>,
    /// Optional bridge into the core workspace representation.
    core_bridge: Option<Arc<CoreBridge>>,
    /// Optional service for loading workspaces by id.
    workspace_service: Option<Arc<WorkspaceService>>,
}
87
88impl SyncEngine {
89 #[must_use]
91 pub fn new(event_bus: Arc<EventBus>) -> Self {
92 Self {
93 event_bus,
94 states: DashMap::new(),
95 connections: DashMap::new(),
96 db: None,
97 core_bridge: None,
98 workspace_service: None,
99 }
100 }
101
102 #[must_use]
104 pub fn with_db(event_bus: Arc<EventBus>, db: Pool<Sqlite>) -> Self {
105 Self {
106 event_bus,
107 states: DashMap::new(),
108 connections: DashMap::new(),
109 db: Some(db),
110 core_bridge: None,
111 workspace_service: None,
112 }
113 }
114
115 #[must_use]
117 pub fn with_integration(
118 event_bus: Arc<EventBus>,
119 db: Pool<Sqlite>,
120 core_bridge: Arc<CoreBridge>,
121 workspace_service: Arc<WorkspaceService>,
122 ) -> Self {
123 Self {
124 event_bus,
125 states: DashMap::new(),
126 connections: DashMap::new(),
127 db: Some(db),
128 core_bridge: Some(core_bridge),
129 workspace_service: Some(workspace_service),
130 }
131 }
132
133 pub fn subscribe(
135 &self,
136 workspace_id: Uuid,
137 client_id: Uuid,
138 ) -> Result<broadcast::Receiver<ChangeEvent>> {
139 self.connections.entry(workspace_id).or_default().push(client_id);
141
142 Ok(self.event_bus.subscribe())
144 }
145
146 pub fn unsubscribe(&self, workspace_id: Uuid, client_id: Uuid) -> Result<()> {
148 if let Some(mut connections) = self.connections.get_mut(&workspace_id) {
149 connections.retain(|id| *id != client_id);
150 }
151 Ok(())
152 }
153
154 pub fn publish_change(&self, event: ChangeEvent) -> Result<()> {
156 self.event_bus.publish(event)
157 }
158
159 #[must_use]
161 pub fn get_state(&self, workspace_id: Uuid) -> Option<SyncState> {
162 self.states.get(&workspace_id).map(|s| s.clone())
163 }
164
165 pub fn update_state(&self, workspace_id: Uuid, new_state: serde_json::Value) -> Result<()> {
167 let version = if let Some(state) = self.states.get(&workspace_id) {
168 state.version + 1
169 } else {
170 1
171 };
172
173 if let Some(mut state) = self.states.get_mut(&workspace_id) {
174 state.update(new_state.clone());
175 } else {
176 self.states.insert(workspace_id, SyncState::new(version, new_state.clone()));
177 }
178
179 if let (Some(core_bridge), Some(workspace_service)) =
181 (self.core_bridge.as_ref(), self.workspace_service.as_ref())
182 {
183 let core_bridge = core_bridge.clone();
184 let workspace_service = workspace_service.clone();
185 let workspace_id = workspace_id;
186 let state_data = new_state.clone();
187 tokio::spawn(async move {
188 if let Ok(mut team_workspace) = workspace_service.get_workspace(workspace_id).await
189 {
190 if let Err(e) = core_bridge
191 .update_workspace_state_from_json(&mut team_workspace, &state_data)
192 {
193 tracing::error!("Failed to update workspace state: {}", e);
194 } else {
195 if let Err(e) = core_bridge.save_workspace_to_disk(&team_workspace).await {
199 tracing::error!("Failed to save workspace to disk: {}", e);
200 }
201 }
202 }
203 });
204 }
205
206 if let Some(db) = &self.db {
208 let db = db.clone();
210 let workspace_id = workspace_id;
211 let state_data = new_state;
212 tokio::spawn(async move {
213 if let Err(e) =
214 Self::save_state_snapshot(&db, workspace_id, version, &state_data).await
215 {
216 tracing::error!("Failed to save state snapshot: {}", e);
217 }
218 });
219 }
220
221 Ok(())
222 }
223
224 pub async fn get_full_workspace_state(
228 &self,
229 workspace_id: Uuid,
230 ) -> Result<Option<serde_json::Value>> {
231 if let (Some(core_bridge), Some(workspace_service)) =
232 (self.core_bridge.as_ref(), self.workspace_service.as_ref())
233 {
234 let team_workspace = workspace_service.get_workspace(workspace_id).await?;
236
237 let state_json = core_bridge.get_workspace_state_json(&team_workspace)?;
239 Ok(Some(state_json))
240 } else {
241 Ok(self.get_state(workspace_id).map(|s| s.state))
243 }
244 }
245
246 async fn save_state_snapshot(
248 db: &Pool<Sqlite>,
249 workspace_id: Uuid,
250 version: i64,
251 state: &serde_json::Value,
252 ) -> Result<()> {
253 let state_json = serde_json::to_string(state)?;
255 let mut hasher = Sha256::new();
256 hasher.update(state_json.as_bytes());
257 let state_hash = format!("{:x}", hasher.finalize());
258
259 let existing = sqlx::query!(
261 r#"
262 SELECT id FROM workspace_state_snapshots
263 WHERE workspace_id = ? AND state_hash = ?
264 "#,
265 workspace_id,
266 state_hash
267 )
268 .fetch_optional(db)
269 .await?;
270
271 if existing.is_some() {
272 return Ok(());
274 }
275
276 let snapshot_id = Uuid::new_v4();
278 let snapshot_id_str = snapshot_id.to_string();
279 let workspace_id_str = workspace_id.to_string();
280 let now = Utc::now().to_rfc3339();
281 sqlx::query!(
282 r#"
283 INSERT INTO workspace_state_snapshots (id, workspace_id, state_hash, state_data, version, created_at)
284 VALUES (?, ?, ?, ?, ?, ?)
285 "#,
286 snapshot_id_str,
287 workspace_id_str,
288 state_hash,
289 state_json,
290 version,
291 now
292 )
293 .execute(db)
294 .await?;
295
296 Ok(())
297 }
298
299 pub async fn load_state_snapshot(
301 &self,
302 workspace_id: Uuid,
303 version: Option<i64>,
304 ) -> Result<Option<SyncState>> {
305 let db = self.db.as_ref().ok_or_else(|| {
306 CollabError::Internal("Database not available for state snapshots".to_string())
307 })?;
308
309 let workspace_id_str = workspace_id.to_string();
310 let snapshot: Option<(String, i64, String)> = if let Some(version) = version {
312 sqlx::query_as(
313 r"
314 SELECT state_data, version, created_at
315 FROM workspace_state_snapshots
316 WHERE workspace_id = ? AND version = ?
317 ORDER BY created_at DESC
318 LIMIT 1
319 ",
320 )
321 .bind(&workspace_id_str)
322 .bind(version)
323 .fetch_optional(db)
324 .await?
325 } else {
326 sqlx::query_as(
327 r"
328 SELECT state_data, version, created_at
329 FROM workspace_state_snapshots
330 WHERE workspace_id = ?
331 ORDER BY version DESC, created_at DESC
332 LIMIT 1
333 ",
334 )
335 .bind(&workspace_id_str)
336 .fetch_optional(db)
337 .await?
338 };
339
340 if let Some((state_data, snap_version, created_at_str)) = snapshot {
341 let state: serde_json::Value = serde_json::from_str(&state_data)
342 .map_err(|e| CollabError::Internal(format!("Failed to parse state: {e}")))?;
343 let created_at = chrono::DateTime::parse_from_rfc3339(&created_at_str)
346 .map(|dt| dt.with_timezone(&Utc))
347 .or_else(|_| {
348 chrono::NaiveDateTime::parse_from_str(&created_at_str, "%Y-%m-%d %H:%M:%S%.f")
350 .or_else(|_| {
351 chrono::NaiveDateTime::parse_from_str(
352 &created_at_str,
353 "%Y-%m-%d %H:%M:%S",
354 )
355 })
356 .map(|dt| dt.and_utc())
357 })
358 .map_err(|e| {
359 CollabError::Internal(format!(
360 "Failed to parse timestamp '{created_at_str}': {e}"
361 ))
362 })?;
363
364 Ok(Some(SyncState {
365 version: snap_version,
366 state,
367 last_updated: created_at,
368 }))
369 } else {
370 Ok(None)
371 }
372 }
373
374 pub async fn record_state_change(
376 &self,
377 workspace_id: Uuid,
378 change_type: &str,
379 change_data: serde_json::Value,
380 version: i64,
381 user_id: Uuid,
382 ) -> Result<()> {
383 let db = self.db.as_ref().ok_or_else(|| {
384 CollabError::Internal("Database not available for state changes".to_string())
385 })?;
386
387 let change_id = Uuid::new_v4();
388 let change_id_str = change_id.to_string();
389 let change_data_str = serde_json::to_string(&change_data)?;
390 let workspace_id_str = workspace_id.to_string();
391 let user_id_str = user_id.to_string();
392 let now = Utc::now().to_rfc3339();
393 sqlx::query!(
394 r#"
395 INSERT INTO workspace_state_changes (id, workspace_id, change_type, change_data, version, created_at, created_by)
396 VALUES (?, ?, ?, ?, ?, ?, ?)
397 "#,
398 change_id_str,
399 workspace_id_str,
400 change_type,
401 change_data_str,
402 version,
403 now,
404 user_id_str
405 )
406 .execute(db)
407 .await?;
408
409 Ok(())
410 }
411
412 pub async fn get_state_changes_since(
414 &self,
415 workspace_id: Uuid,
416 since_version: i64,
417 ) -> Result<Vec<serde_json::Value>> {
418 let db = self.db.as_ref().ok_or_else(|| {
419 CollabError::Internal("Database not available for state changes".to_string())
420 })?;
421
422 let workspace_id_str = workspace_id.to_string();
423 let changes = sqlx::query!(
424 r#"
425 SELECT change_data
426 FROM workspace_state_changes
427 WHERE workspace_id = ? AND version > ?
428 ORDER BY version ASC
429 "#,
430 workspace_id_str,
431 since_version
432 )
433 .fetch_all(db)
434 .await?;
435
436 let mut result = Vec::new();
437 for change in changes {
438 let data: serde_json::Value = serde_json::from_str(&change.change_data)
439 .map_err(|e| CollabError::Internal(format!("Failed to parse change data: {e}")))?;
440 result.push(data);
441 }
442
443 Ok(result)
444 }
445
446 #[must_use]
448 pub fn get_connections(&self, workspace_id: Uuid) -> Vec<Uuid> {
449 self.connections.get(&workspace_id).map(|c| c.clone()).unwrap_or_default()
450 }
451
452 #[must_use]
454 pub fn connection_count(&self) -> usize {
455 self.connections.iter().map(|c| c.value().len()).sum()
456 }
457
458 #[must_use]
460 pub fn has_connections(&self, workspace_id: Uuid) -> bool {
461 self.connections.get(&workspace_id).is_some_and(|c| !c.is_empty())
462 }
463
464 pub fn cleanup_inactive(&self) {
466 let inactive: Vec<Uuid> = self
467 .connections
468 .iter()
469 .filter(|entry| entry.value().is_empty())
470 .map(|entry| *entry.key())
471 .collect();
472
473 for workspace_id in inactive {
474 self.connections.remove(&workspace_id);
475 }
476 }
477}
478
479pub mod crdt {
481 use serde::{Deserialize, Serialize};
482 use uuid::Uuid;
483
484 #[derive(Debug, Clone, Serialize, Deserialize)]
486 pub struct LwwRegister<T> {
487 pub value: T,
489 pub timestamp: u64,
491 pub client_id: Uuid,
493 }
494
495 impl<T> LwwRegister<T> {
496 pub const fn new(value: T, timestamp: u64, client_id: Uuid) -> Self {
498 Self {
499 value,
500 timestamp,
501 client_id,
502 }
503 }
504
505 pub fn merge(&mut self, other: Self)
507 where
508 T: Clone,
509 {
510 if other.timestamp > self.timestamp
511 || (other.timestamp == self.timestamp && other.client_id > self.client_id)
512 {
513 self.value = other.value;
514 self.timestamp = other.timestamp;
515 self.client_id = other.client_id;
516 }
517 }
518 }
519
520 #[derive(Debug, Clone, Serialize, Deserialize)]
522 pub struct TextOperation {
523 pub position: usize,
525 pub op: TextOp,
527 pub timestamp: u64,
529 pub client_id: Uuid,
531 }
532
533 #[derive(Debug, Clone, Serialize, Deserialize)]
534 #[serde(tag = "type", rename_all = "lowercase")]
535 pub enum TextOp {
536 Insert { text: String },
538 Delete { length: usize },
540 }
541}
542
#[cfg(test)]
mod tests {
    use super::*;

    /// `SyncState::update` bumps the version and swaps in the new document.
    #[test]
    fn test_sync_state() {
        let first = serde_json::json!({"key": "value"});
        let second = serde_json::json!({"key": "new_value"});

        let mut state = SyncState::new(1, first);
        assert_eq!(state.version, 1);

        state.update(second);
        assert_eq!(state.version, 2);
    }

    /// Subscribing and unsubscribing is reflected in the connection counters.
    #[test]
    fn test_sync_engine() {
        let engine = SyncEngine::new(Arc::new(EventBus::new(100)));
        let (workspace, client) = (Uuid::new_v4(), Uuid::new_v4());

        assert_eq!(engine.connection_count(), 0);

        let _rx = engine.subscribe(workspace, client).unwrap();
        assert!(engine.has_connections(workspace));
        assert_eq!(engine.connection_count(), 1);

        engine.unsubscribe(workspace, client).unwrap();
        assert_eq!(engine.connection_count(), 0);
    }

    /// A first `update_state` stores the document at version 1.
    #[test]
    fn test_state_management() {
        let engine = SyncEngine::new(Arc::new(EventBus::new(100)));
        let workspace = Uuid::new_v4();
        let doc = serde_json::json!({"mocks": []});

        engine.update_state(workspace, doc.clone()).unwrap();

        let stored = engine.get_state(workspace).unwrap();
        assert_eq!(stored.version, 1);
        assert_eq!(stored.state, doc);
    }

    /// Merging keeps the write with the newer timestamp.
    #[test]
    fn test_crdt_lww_register() {
        use super::crdt::LwwRegister;

        let mut older = LwwRegister::new("value1", 1, Uuid::new_v4());
        let newer = LwwRegister::new("value2", 2, Uuid::new_v4());

        older.merge(newer);
        assert_eq!(older.value, "value2");
        assert_eq!(older.timestamp, 2);
    }

    /// `cleanup_inactive` drops workspaces whose client list went empty.
    #[test]
    fn test_cleanup_inactive() {
        let engine = SyncEngine::new(Arc::new(EventBus::new(100)));
        let (workspace, client) = (Uuid::new_v4(), Uuid::new_v4());

        let _rx = engine.subscribe(workspace, client).unwrap();
        assert_eq!(engine.connection_count(), 1);

        engine.unsubscribe(workspace, client).unwrap();
        assert_eq!(engine.connection_count(), 0);

        engine.cleanup_inactive();
        assert!(!engine.has_connections(workspace));
    }
}