1use std::sync::Mutex as StdMutex;
15#[cfg(feature = "sqlite")]
16use std::sync::Arc;
17use tokio::sync::broadcast;
18#[cfg(feature = "sqlite")]
19use tokio::sync::Mutex;
20#[cfg(feature = "sqlite")]
21use tokio::time::{Duration, Instant};
22use yrs::encoding::read::Read;
23use yrs::encoding::write::Write;
24use yrs::updates::decoder::{Decode, DecoderV1};
25use yrs::updates::encoder::{Encode, Encoder, EncoderV1};
26use yrs::{Doc, ReadTxn, StateVector, Transact, Update};
27
28#[cfg(feature = "sqlite")]
29use crate::db::Database;
30
/// Hocuspocus message type: Yjs sync sub-protocol (SyncStep1 / SyncStep2 / Update).
pub const MSG_SYNC: u8 = 0;
/// Hocuspocus message type: awareness payload, relayed verbatim to subscribers.
pub const MSG_AWARENESS: u8 = 1;
/// Hocuspocus message type: authentication.
pub const MSG_AUTH: u8 = 2;
/// Hocuspocus message type: query for the current awareness state.
pub const MSG_QUERY_AWARENESS: u8 = 3;

/// Debounce window, in milliseconds, used when scheduling background persistence.
#[cfg(feature = "sqlite")]
const PERSIST_DEBOUNCE_MS: u64 = 500;
41
42pub struct DocHandler {
48 pub doc_name: String,
49 doc: StdMutex<Doc>,
51 #[cfg(feature = "sqlite")]
53 db: Database,
54 pub broadcast_tx: broadcast::Sender<Vec<u8>>,
56 #[cfg(feature = "sqlite")]
58 last_persist_request: Arc<Mutex<Option<Instant>>>,
59 #[cfg(feature = "sqlite")]
61 persist_pending: Arc<Mutex<bool>>,
62}
63
64unsafe impl Send for DocHandler {}
67unsafe impl Sync for DocHandler {}
68
69impl DocHandler {
70 #[cfg(feature = "sqlite")]
71 pub async fn new(doc_name: String, db: Database) -> Self {
72 let doc = Doc::new();
73 let (broadcast_tx, _) = broadcast::channel(256);
74
75 tracing::info!("Loading document '{}' from database...", doc_name);
77 if let Ok(Some(data)) = db.get_doc(&doc_name).await {
78 tracing::info!(
79 "Found existing data for '{}': {} bytes",
80 doc_name,
81 data.len()
82 );
83 let mut txn = doc.transact_mut();
84 match Update::decode_v1(&data) {
85 Ok(update) => {
86 txn.apply_update(update);
87 tracing::debug!("Applied stored state to document '{}'", doc_name);
88 }
89 Err(e) => {
90 tracing::error!("Failed to decode stored state for '{}': {:?}", doc_name, e);
91 }
92 }
93 } else {
94 tracing::info!("No existing data found for '{}', starting fresh", doc_name);
95 }
96
97 Self {
98 doc_name,
99 doc: StdMutex::new(doc),
100 db,
101 broadcast_tx,
102 last_persist_request: Arc::new(Mutex::new(None)),
103 persist_pending: Arc::new(Mutex::new(false)),
104 }
105 }
106
107 #[cfg(not(feature = "sqlite"))]
108 pub async fn new(doc_name: String) -> Self {
109 let doc = Doc::new();
110 let (broadcast_tx, _) = broadcast::channel(256);
111
112 Self {
113 doc_name,
114 doc: StdMutex::new(doc),
115 broadcast_tx,
116 }
117 }
118
119 pub fn generate_initial_sync(&self) -> Vec<Vec<u8>> {
122 let doc = self.doc.lock().unwrap();
123 let txn = doc.transact();
124 let state_vector = txn.state_vector();
125
126 let mut encoder = EncoderV1::new();
128 encoder.write_var(0u32); let mut sv_encoder = EncoderV1::new();
132 state_vector.encode(&mut sv_encoder);
133 let sv_bytes = sv_encoder.to_vec();
134
135 encoder.write_buf(&sv_bytes);
137
138 let payload = encoder.to_vec();
139
140 let encoded = self.encode_hocuspocus_message(MSG_SYNC, &payload);
141
142 tracing::debug!(
143 "Generated initial sync message ({} bytes): {:02x?}",
144 encoded.len(),
145 encoded
146 );
147
148 vec![encoded]
149 }
150
151 pub async fn handle_message(&self, msg_data: &[u8]) -> Vec<Vec<u8>> {
155 let mut responses = Vec::new();
156
157 if msg_data.is_empty() {
158 return responses;
159 }
160
161 tracing::trace!("Received message ({} bytes)", msg_data.len());
162
163 let (content_data, _doc_name) = match DocHandler::read_and_skip_doc_name(msg_data) {
166 Some(res) => res,
167 None => {
168 tracing::warn!(
169 "Failed to parse document name from message: {:02x?}",
170 msg_data
171 );
172 return responses;
173 }
174 };
175
176 if content_data.is_empty() {
177 return responses;
178 }
179
180 let msg_type = content_data[0];
181 let payload = &content_data[1..];
182
183 match msg_type {
184 MSG_SYNC => {
185 self.handle_sync_message(payload, &mut responses).await;
186 }
187 MSG_AWARENESS => {
188 self.forward_awareness_message(payload);
191 }
192 MSG_QUERY_AWARENESS => {
193 tracing::debug!("Received QUERY_AWARENESS (no server state maintained)");
196 }
197 MSG_AUTH => {
198 tracing::debug!("Received AUTH message (accepted)");
201 }
202 _ => {
203 tracing::warn!("Unknown message type: {}", msg_type);
204 }
205 }
206
207 responses
208 }
209
210 pub fn read_and_skip_doc_name(data: &[u8]) -> Option<(&[u8], String)> {
212 let mut offset = 0;
213 let mut len: usize = 0;
214 let mut shift = 0;
215
216 loop {
218 if offset >= data.len() {
219 return None;
220 }
221 let b = data[offset];
222 offset += 1;
223 len |= ((b & 0x7F) as usize) << shift;
224 if b & 0x80 == 0 {
225 break;
226 }
227 shift += 7;
228 if shift > 64 {
229 return None;
230 }
231 }
232
233 if offset + len > data.len() {
234 return None;
235 }
236
237 let name_bytes = &data[offset..offset + len];
239 let name = String::from_utf8_lossy(name_bytes).to_string();
240
241 Some((&data[offset + len..], name))
242 }
243
244 pub fn encode_hocuspocus_message(&self, msg_type: u8, payload: &[u8]) -> Vec<u8> {
247 let mut encoder = EncoderV1::new();
248 encoder.write_string(&self.doc_name);
249 encoder.write_var(msg_type as u32);
250
251 let mut encoded = encoder.to_vec();
252 encoded.extend_from_slice(payload);
253 encoded
254 }
255
256 fn forward_awareness_message(&self, payload: &[u8]) {
258 let broadcast_msg = self.encode_hocuspocus_message(MSG_AWARENESS, payload);
260 let _ = self.broadcast_tx.send(broadcast_msg);
261 tracing::trace!("Forwarded awareness message for '{}'", self.doc_name);
262 }
263
264 async fn handle_sync_message(&self, payload: &[u8], responses: &mut Vec<Vec<u8>>) {
266 let mut decoder = DecoderV1::from(payload);
267
268 while let Ok(tag) = decoder.read_var::<u32>() {
271 match tag {
272 0 => {
273 match decoder.read_buf() {
276 Ok(sv_data) => {
277 match StateVector::decode(&mut DecoderV1::from(sv_data)) {
279 Ok(client_sv) => {
280 tracing::debug!(
281 "Handling SyncStep1 (SV len: {})",
282 client_sv.len()
283 );
284 let doc = self.doc.lock().unwrap();
285 let txn = doc.transact();
286
287 let update = txn.encode_state_as_update_v1(&client_sv);
290 let mut encoder = EncoderV1::new();
291 encoder.write_var(1u32);
292 encoder.write_buf(&update);
293 responses.push(
294 self.encode_hocuspocus_message(MSG_SYNC, &encoder.to_vec()),
295 );
296
297 let server_sv = txn.state_vector();
300
301 let mut sv_encoder = EncoderV1::new();
302 server_sv.encode(&mut sv_encoder);
303 let sv_bytes = sv_encoder.to_vec();
304
305 let mut encoder_sv = EncoderV1::new();
306 encoder_sv.write_var(0u32);
307 encoder_sv.write_buf(&sv_bytes);
308
309 responses.push(
310 self.encode_hocuspocus_message(
311 MSG_SYNC,
312 &encoder_sv.to_vec(),
313 ),
314 );
315
316 tracing::debug!(
317 "Processed SyncStep1 for '{}', sent SyncStep2 + SyncStep1",
318 self.doc_name
319 );
320 }
321 Err(e) => {
322 tracing::warn!(
323 "Failed to decode StateVector in SyncStep1: {:?}",
324 e
325 );
326 break;
327 }
328 }
329 }
330 Err(e) => {
331 tracing::warn!("Failed to read SyncStep1 payload: {:?}", e);
332 break;
333 }
334 }
335 }
336 1 => {
337 match decoder.read_buf() {
339 Ok(update_data) => {
340 tracing::debug!(
341 "Handling SyncStep2 (payload len: {})",
342 update_data.len()
343 );
344 if update_data.is_empty() {
345 tracing::debug!("Received empty SyncStep2 update, ignoring");
346 continue;
347 }
348
349 if let Err(e) = self.apply_update(update_data) {
350 tracing::error!(
351 "Failed to apply SyncStep2 update: {:?}. Payload: {:02x?}",
352 e,
353 update_data
354 );
355 } else {
356 tracing::debug!("Applied SyncStep2 update for '{}'", self.doc_name);
357 self.request_persist().await;
358 }
359 }
360 Err(e) => {
361 tracing::warn!("Failed to read SyncStep2 payload: {:?}", e);
362 break;
363 }
364 }
365 }
366 2 => {
367 match decoder.read_buf() {
369 Ok(update_data) => {
370 if let Err(e) = self.apply_update(update_data) {
371 tracing::error!("Failed to apply incremental update: {:?}", e);
372 } else {
373 tracing::debug!(
374 "Applied incremental update for '{}'",
375 self.doc_name
376 );
377
378 let mut encoder = EncoderV1::new();
381 encoder.write_var(2u32);
382 encoder.write_buf(update_data);
383 let msg =
384 self.encode_hocuspocus_message(MSG_SYNC, &encoder.to_vec());
385 let _ = self.broadcast_tx.send(msg);
386
387 self.request_persist().await;
388 }
389 }
390 Err(e) => {
391 tracing::warn!("Failed to read Update payload: {:?}", e);
392 break;
393 }
394 }
395 }
396 _ => {
397 tracing::warn!("Unknown sync message tag: {}", tag);
398 break;
399 }
400 }
401 }
402 }
403
404 pub fn apply_update(
406 &self,
407 update_data: &[u8],
408 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
409 let update = Update::decode_v1(update_data)?;
410 let doc = self.doc.lock().unwrap();
411 let mut txn = doc.transact_mut();
412 txn.apply_update(update);
413 Ok(())
414 }
415
416 pub async fn request_persist(&self) {
418 #[cfg(feature = "sqlite")]
419 {
420 let now = Instant::now();
421
422 {
423 let mut last_request = self.last_persist_request.lock().await;
424 *last_request = Some(now);
425 }
426
427 let already_pending = {
429 let pending = self.persist_pending.lock().await;
430 *pending
431 };
432
433 if !already_pending {
434 {
436 let mut pending = self.persist_pending.lock().await;
437 *pending = true;
438 }
439
440 let doc_name = self.doc_name.clone();
442 let db = self.db.clone();
443 let last_persist_request = self.last_persist_request.clone();
444 let persist_pending = self.persist_pending.clone();
445
446 let state = {
448 let doc = self.doc.lock().unwrap();
449 let txn = doc.transact();
450 txn.encode_state_as_update_v1(&StateVector::default())
451 };
452
453 tokio::spawn(async move {
454 tokio::time::sleep(Duration::from_millis(PERSIST_DEBOUNCE_MS)).await;
456
457 let should_persist = {
459 let last_request = last_persist_request.lock().await;
460 if let Some(last) = *last_request {
461 last.elapsed() >= Duration::from_millis(PERSIST_DEBOUNCE_MS - 50)
462 } else {
463 true
464 }
465 };
466
467 if should_persist {
468 if let Err(e) = db.save_doc(&doc_name, state).await {
470 tracing::error!("Failed to persist document '{}': {:?}", doc_name, e);
471 } else {
472 tracing::debug!("Persisted document '{}'", doc_name);
473 }
474 }
475
476 {
478 let mut pending = persist_pending.lock().await;
479 *pending = false;
480 }
481 });
482 }
483 }
484 }
485
486 pub async fn force_persist(&self) {
488 #[cfg(feature = "sqlite")]
489 {
490 let state = {
491 let doc = self.doc.lock().unwrap();
492 let txn = doc.transact();
493 txn.encode_state_as_update_v1(&StateVector::default())
494 };
495
496 if let Err(e) = self.db.save_doc(&self.doc_name, state).await {
497 tracing::error!(
498 "Failed to persist document '{}' on shutdown: {:?}",
499 self.doc_name,
500 e
501 );
502 } else {
503 tracing::info!("Persisted document '{}' on shutdown", self.doc_name);
504 }
505 }
506 }
507
508 pub fn subscribe(&self) -> broadcast::Receiver<Vec<u8>> {
510 self.broadcast_tx.subscribe()
511 }
512}
513
#[cfg(test)]
mod tests {
    use super::*;
    use yrs::encoding::read::Read;
    use yrs::updates::decoder::DecoderV1;

    use yrs::updates::encoder::{Encoder, EncoderV1};
    use yrs::{GetString, Text, Transact};

    /// Opens a fresh in-memory database for one test.
    #[cfg(feature = "sqlite")]
    async fn create_test_db() -> Database {
        Database::init_in_memory().expect("Failed to create test database")
    }

    /// Builds a handler for `name` backed by its own in-memory database.
    #[cfg(feature = "sqlite")]
    async fn fresh_handler(name: &str) -> DocHandler {
        let db = create_test_db().await;
        DocHandler::new(name.to_string(), db).await
    }

    /// Frames `body` exactly as the server does: doc name, message type, then the body.
    fn frame(doc_name: &str, msg_type: u8, body: &[u8]) -> Vec<u8> {
        let mut header = EncoderV1::new();
        header.write_string(doc_name);
        header.write_var(msg_type as u32);
        let mut framed = header.to_vec();
        framed.extend_from_slice(body);
        framed
    }

    /// Sync body for `SyncStep1`: tag 0 followed by the length-prefixed state vector.
    fn sync_step1_body(sv: &StateVector) -> Vec<u8> {
        let mut inner = EncoderV1::new();
        sv.encode(&mut inner);
        let sv_bytes = inner.to_vec();

        let mut outer = EncoderV1::new();
        outer.write_var(0u32);
        outer.write_buf(&sv_bytes);
        outer.to_vec()
    }

    /// Sync body for an incremental update: tag 2 followed by the length-prefixed update.
    fn update_body(update: &[u8]) -> Vec<u8> {
        let mut outer = EncoderV1::new();
        outer.write_var(2u32);
        outer.write_buf(update);
        outer.to_vec()
    }

    /// Produces the v1 update for a client doc whose text `field` contains `content`.
    #[cfg(feature = "sqlite")]
    fn text_update(field: &str, content: &str) -> Vec<u8> {
        let client_doc = Doc::new();
        let text = client_doc.get_or_insert_text(field);
        let mut txn = client_doc.transact_mut();
        text.push(&mut txn, content);
        txn.encode_update_v1()
    }

    #[tokio::test]
    #[cfg(feature = "sqlite")]
    async fn test_doc_handler_creation() {
        let handler = fresh_handler("test-room").await;
        assert_eq!(handler.doc_name, "test-room");
    }

    #[tokio::test]
    #[cfg(feature = "sqlite")]
    async fn test_initial_sync_generation() {
        let handler = fresh_handler("test-room").await;

        let initial = handler.generate_initial_sync();
        assert_eq!(initial.len(), 1);

        let (body, name) =
            DocHandler::read_and_skip_doc_name(&initial[0]).expect("Should parse doc name");
        assert_eq!(name, "test-room");

        let mut reader = DecoderV1::from(body);
        let kind: u32 = reader.read_var().expect("Should parse msg type");
        assert_eq!(kind as u8, MSG_SYNC);
    }

    #[tokio::test]
    #[cfg(feature = "sqlite")]
    async fn test_sync_step1_response() {
        let handler = fresh_handler("test-room").await;

        let request = frame("test-room", MSG_SYNC, &sync_step1_body(&StateVector::default()));
        let replies = handler.handle_message(&request).await;

        // Expect SyncStep2 followed by the server's own SyncStep1.
        assert_eq!(replies.len(), 2);

        for reply in &replies {
            let (body, name) = DocHandler::read_and_skip_doc_name(reply).unwrap();
            assert_eq!(name, "test-room");
            let mut reader = DecoderV1::from(body);
            let kind: u32 = reader.read_var().unwrap();
            assert_eq!(kind as u8, MSG_SYNC);
        }
    }

    #[tokio::test]
    #[cfg(feature = "sqlite")]
    async fn test_update_application_and_broadcast() {
        let handler = fresh_handler("test-room").await;
        let mut listener = handler.subscribe();

        let update = text_update("test", "Hello, World!");
        let request = frame("test-room", MSG_SYNC, &update_body(&update));
        let _replies = handler.handle_message(&request).await;

        let received = tokio::time::timeout(Duration::from_millis(100), listener.recv()).await;
        assert!(received.is_ok());
        let frame_bytes = received.unwrap().unwrap();

        let (_, name) = DocHandler::read_and_skip_doc_name(&frame_bytes).unwrap();
        assert_eq!(name, "test-room");
    }

    #[tokio::test]
    #[cfg(feature = "sqlite")]
    async fn test_persistence_after_update() {
        let db = create_test_db().await;
        let handler = DocHandler::new("test-room".to_string(), db.clone()).await;

        let update = text_update("test", "Persistent data");
        let request = frame("test-room", MSG_SYNC, &update_body(&update));
        let _replies = handler.handle_message(&request).await;

        handler.force_persist().await;

        let stored = db.get_doc("test-room").await.unwrap();
        assert!(stored.is_some());
        assert!(!stored.unwrap().is_empty());
    }

    #[tokio::test]
    #[cfg(feature = "sqlite")]
    async fn test_document_reload_from_db() {
        let db = create_test_db().await;

        let writer = DocHandler::new("reload-test".to_string(), db.clone()).await;
        let update = text_update("content", "Test content for reload");
        let request = frame("reload-test", MSG_SYNC, &update_body(&update));
        writer.handle_message(&request).await;
        writer.force_persist().await;
        drop(writer);

        let reader = DocHandler::new("reload-test".to_string(), db).await;
        let doc = reader.doc.lock().unwrap();
        let text = doc.get_or_insert_text("content");
        let txn = doc.transact();
        let restored = text.get_string(&txn);

        assert_eq!(restored, "Test content for reload");
    }

    #[tokio::test]
    #[cfg(feature = "sqlite")]
    async fn test_awareness_forwarding() {
        let handler = fresh_handler("test-room").await;
        let mut listener = handler.subscribe();

        let awareness_frame = frame("test-room", MSG_AWARENESS, &[1, 2, 3, 4]);
        let _replies = handler.handle_message(&awareness_frame).await;

        let received = tokio::time::timeout(Duration::from_millis(100), listener.recv()).await;
        assert!(received.is_ok());
        let relayed = received.unwrap().unwrap();

        // The relayed frame must be byte-for-byte what the client sent.
        assert_eq!(relayed, awareness_frame);
    }

    #[tokio::test]
    #[cfg(feature = "sqlite")]
    async fn test_empty_message_handling() {
        let handler = fresh_handler("test-room").await;
        let replies = handler.handle_message(&[]).await;
        assert!(replies.is_empty());
    }

    #[tokio::test]
    #[cfg(not(feature = "sqlite"))]
    async fn test_doc_handler_no_sqlite() {
        let handler = DocHandler::new("test-room-no-db".to_string()).await;
        assert_eq!(handler.doc_name, "test-room-no-db");

        let initial = handler.generate_initial_sync();
        assert_eq!(initial.len(), 1);

        let (_, name) = DocHandler::read_and_skip_doc_name(&initial[0]).unwrap();
        assert_eq!(name, "test-room-no-db");
    }
}