// hocuspocus_rs/sync.rs

//! Yjs sync protocol implementation.
//!
//! This module implements the y-websocket sync protocol for compatibility with
//! Hocuspocus/y-websocket clients. Every frame uses Hocuspocus V2 framing:
//! `[document name: VarString] [message type: VarUint] [payload]`.
//! The protocol flow is:
//!
//! 1. Client connects and sends SyncStep1(client_state_vector); the server may
//!    also send its own SyncStep1 as soon as the client connects
//! 2. Server responds with SyncStep2(server_diff) + SyncStep1(server_state_vector)
//! 3. Client sends SyncStep2(client_diff)
//! 4. Ongoing: both sides exchange Update messages
//!
//! Note: the awareness protocol (cursors/presence) is handled by forwarding
//! messages between clients without keeping any server-side state.
13
14use std::sync::Mutex as StdMutex;
15#[cfg(feature = "sqlite")]
16use std::sync::Arc;
17use tokio::sync::broadcast;
18#[cfg(feature = "sqlite")]
19use tokio::sync::Mutex;
20#[cfg(feature = "sqlite")]
21use tokio::time::{Duration, Instant};
22use yrs::encoding::read::Read;
23use yrs::encoding::write::Write;
24use yrs::updates::decoder::{Decode, DecoderV1};
25use yrs::updates::encoder::{Encode, Encoder, EncoderV1};
26use yrs::{Doc, ReadTxn, StateVector, Transact, Update};
27
28#[cfg(feature = "sqlite")]
29use crate::db::Database;
30
// Outer message types of the Hocuspocus V2 / y-websocket protocol. A frame is
// `[document name: VarString] [message type: VarUint] [payload]`, so the type
// follows the document name rather than being the first byte of the frame.

/// Sync protocol message (carries SyncStep1, SyncStep2 and Update sub-messages).
pub const MSG_SYNC: u8 = 0;
/// Awareness (cursor/presence) message.
pub const MSG_AWARENESS: u8 = 1;
/// Authentication message (Hocuspocus-specific).
pub const MSG_AUTH: u8 = 2;
/// Request for the current awareness states.
pub const MSG_QUERY_AWARENESS: u8 = 3;

/// Debounce interval for persistence, in milliseconds.
#[cfg(feature = "sqlite")]
const PERSIST_DEBOUNCE_MS: u64 = 500;
41
42/// A handler for a single Yjs document/room
43/// Manages the document state, persistence, and broadcasting
44///
45/// Note: We use std::sync::Mutex for the Doc because yrs::Doc operations are
46/// synchronous and fast. tokio::sync::Mutex would cause unnecessary overhead.
47pub struct DocHandler {
48    pub doc_name: String,
49    /// Thread-safe document access using std Mutex (Doc ops are sync & fast)
50    doc: StdMutex<Doc>,
51    /// Database for persistence
52    #[cfg(feature = "sqlite")]
53    db: Database,
54    /// Broadcast channel for sending updates to other clients
55    pub broadcast_tx: broadcast::Sender<Vec<u8>>,
56    /// Track when persistence was last requested for debouncing
57    #[cfg(feature = "sqlite")]
58    last_persist_request: Arc<Mutex<Option<Instant>>>,
59    /// Flag to indicate persistence is pending
60    #[cfg(feature = "sqlite")]
61    persist_pending: Arc<Mutex<bool>>,
62}
63
64// Explicitly mark DocHandler as Send + Sync since we use std::sync::Mutex
65// and all fields are thread-safe
66unsafe impl Send for DocHandler {}
67unsafe impl Sync for DocHandler {}
68
69impl DocHandler {
70    #[cfg(feature = "sqlite")]
71    pub async fn new(doc_name: String, db: Database) -> Self {
72        let doc = Doc::new();
73        let (broadcast_tx, _) = broadcast::channel(256);
74
75        // Load existing state from DB
76        tracing::info!("Loading document '{}' from database...", doc_name);
77        if let Ok(Some(data)) = db.get_doc(&doc_name).await {
78            tracing::info!(
79                "Found existing data for '{}': {} bytes",
80                doc_name,
81                data.len()
82            );
83            let mut txn = doc.transact_mut();
84            match Update::decode_v1(&data) {
85                Ok(update) => {
86                    txn.apply_update(update);
87                    tracing::debug!("Applied stored state to document '{}'", doc_name);
88                }
89                Err(e) => {
90                    tracing::error!("Failed to decode stored state for '{}': {:?}", doc_name, e);
91                }
92            }
93        } else {
94            tracing::info!("No existing data found for '{}', starting fresh", doc_name);
95        }
96
97        Self {
98            doc_name,
99            doc: StdMutex::new(doc),
100            db,
101            broadcast_tx,
102            last_persist_request: Arc::new(Mutex::new(None)),
103            persist_pending: Arc::new(Mutex::new(false)),
104        }
105    }
106
107    #[cfg(not(feature = "sqlite"))]
108    pub async fn new(doc_name: String) -> Self {
109        let doc = Doc::new();
110        let (broadcast_tx, _) = broadcast::channel(256);
111
112        Self {
113            doc_name,
114            doc: StdMutex::new(doc),
115            broadcast_tx,
116        }
117    }
118
119    /// Generate the initial sync messages to send when a client connects
120    /// Returns: [SyncStep1(server_state_vector)]
121    pub fn generate_initial_sync(&self) -> Vec<Vec<u8>> {
122        let doc = self.doc.lock().unwrap();
123        let txn = doc.transact();
124        let state_vector = txn.state_vector();
125
126        // Encode SyncStep1: [Tag 0] [Len] [SV]
127        let mut encoder = EncoderV1::new();
128        encoder.write_var(0u32); // Tag SyncStep1
129
130        // Encode SV to bytes first
131        let mut sv_encoder = EncoderV1::new();
132        state_vector.encode(&mut sv_encoder);
133        let sv_bytes = sv_encoder.to_vec();
134
135        // Write as buffer (Length + Bytes)
136        encoder.write_buf(&sv_bytes);
137
138        let payload = encoder.to_vec();
139
140        let encoded = self.encode_hocuspocus_message(MSG_SYNC, &payload);
141
142        tracing::debug!(
143            "Generated initial sync message ({} bytes): {:02x?}",
144            encoded.len(),
145            encoded
146        );
147
148        vec![encoded]
149    }
150
151    /// Process an incoming message from a client
152    /// Returns a list of response messages to send back to this client
153    /// Also broadcasts updates to other clients via the broadcast channel
154    pub async fn handle_message(&self, msg_data: &[u8]) -> Vec<Vec<u8>> {
155        let mut responses = Vec::new();
156
157        if msg_data.is_empty() {
158            return responses;
159        }
160
161        tracing::trace!("Received message ({} bytes)", msg_data.len());
162
163        // Hocuspocus Protocol V2: [DocName (VarString)] [MessageType (VarUint)] [Payload]
164        // We first need to skip the document name since we already know context from the room connection
165        let (content_data, _doc_name) = match DocHandler::read_and_skip_doc_name(msg_data) {
166            Some(res) => res,
167            None => {
168                tracing::warn!(
169                    "Failed to parse document name from message: {:02x?}",
170                    msg_data
171                );
172                return responses;
173            }
174        };
175
176        if content_data.is_empty() {
177            return responses;
178        }
179
180        let msg_type = content_data[0];
181        let payload = &content_data[1..];
182
183        match msg_type {
184            MSG_SYNC => {
185                self.handle_sync_message(payload, &mut responses).await;
186            }
187            MSG_AWARENESS => {
188                // Awareness messages are forwarded to other clients
189                // We re-wrap them with our doc_name to ensure clients route them correctly
190                self.forward_awareness_message(payload);
191            }
192            MSG_QUERY_AWARENESS => {
193                // Query awareness - we don't maintain server-side awareness state
194                // Clients will receive awareness updates from other clients directly
195                tracing::debug!("Received QUERY_AWARENESS (no server state maintained)");
196            }
197            MSG_AUTH => {
198                // Auth messages are handled at the WebSocket layer
199                // For now, we accept all connections
200                tracing::debug!("Received AUTH message (accepted)");
201            }
202            _ => {
203                tracing::warn!("Unknown message type: {}", msg_type);
204            }
205        }
206
207        responses
208    }
209
210    /// Helper to read the VarString document name and return the rest of the buffer
211    pub fn read_and_skip_doc_name(data: &[u8]) -> Option<(&[u8], String)> {
212        let mut offset = 0;
213        let mut len: usize = 0;
214        let mut shift = 0;
215
216        // Decode VarUint length
217        loop {
218            if offset >= data.len() {
219                return None;
220            }
221            let b = data[offset];
222            offset += 1;
223            len |= ((b & 0x7F) as usize) << shift;
224            if b & 0x80 == 0 {
225                break;
226            }
227            shift += 7;
228            if shift > 64 {
229                return None;
230            }
231        }
232
233        if offset + len > data.len() {
234            return None;
235        }
236
237        // Decode string for debug/verification (optional but good for logging)
238        let name_bytes = &data[offset..offset + len];
239        let name = String::from_utf8_lossy(name_bytes).to_string();
240
241        Some((&data[offset + len..], name))
242    }
243
244    /// Wraps a raw payload in the Hocuspocus V2 protocol structure:
245    /// [DocName : VarString] [MsgType : VarUint] [Payload : Bytes]
246    pub fn encode_hocuspocus_message(&self, msg_type: u8, payload: &[u8]) -> Vec<u8> {
247        let mut encoder = EncoderV1::new();
248        encoder.write_string(&self.doc_name);
249        encoder.write_var(msg_type as u32);
250
251        let mut encoded = encoder.to_vec();
252        encoded.extend_from_slice(payload);
253        encoded
254    }
255
256    /// Forward awareness message to other clients
257    fn forward_awareness_message(&self, payload: &[u8]) {
258        // Re-wrap the awareness message with the Hocuspocus protocol V2 prefix
259        let broadcast_msg = self.encode_hocuspocus_message(MSG_AWARENESS, payload);
260        let _ = self.broadcast_tx.send(broadcast_msg);
261        tracing::trace!("Forwarded awareness message for '{}'", self.doc_name);
262    }
263
264    /// Handle sync protocol messages (SyncStep1, SyncStep2, Update)
265    async fn handle_sync_message(&self, payload: &[u8], responses: &mut Vec<Vec<u8>>) {
266        let mut decoder = DecoderV1::from(payload);
267
268        // Loop over the payload to decode multiple messages (Hocuspocus/y-protocols stream)
269        // We check loop by trying to read the next tag
270        while let Ok(tag) = decoder.read_var::<u32>() {
271            match tag {
272                0 => {
273                    // SyncStep1: [Tag 0] [Len] [StateVector]
274                    // First read the length-prefixed buffer
275                    match decoder.read_buf() {
276                        Ok(sv_data) => {
277                            // Then decode SV from the buffer
278                            match StateVector::decode(&mut DecoderV1::from(sv_data)) {
279                                Ok(client_sv) => {
280                                    tracing::debug!(
281                                        "Handling SyncStep1 (SV len: {})",
282                                        client_sv.len()
283                                    );
284                                    let doc = self.doc.lock().unwrap();
285                                    let txn = doc.transact();
286
287                                    // Reply with SyncStep2 (updates client needs)
288                                    // [Tag 1] [Len] [Bytes]
289                                    let update = txn.encode_state_as_update_v1(&client_sv);
290                                    let mut encoder = EncoderV1::new();
291                                    encoder.write_var(1u32);
292                                    encoder.write_buf(&update);
293                                    responses.push(
294                                        self.encode_hocuspocus_message(MSG_SYNC, &encoder.to_vec()),
295                                    );
296
297                                    // Also send our SyncStep1 (server SV) so client can sync vs us
298                                    // [Tag 0] [Len] [StateVector]
299                                    let server_sv = txn.state_vector();
300
301                                    let mut sv_encoder = EncoderV1::new();
302                                    server_sv.encode(&mut sv_encoder);
303                                    let sv_bytes = sv_encoder.to_vec();
304
305                                    let mut encoder_sv = EncoderV1::new();
306                                    encoder_sv.write_var(0u32);
307                                    encoder_sv.write_buf(&sv_bytes);
308
309                                    responses.push(
310                                        self.encode_hocuspocus_message(
311                                            MSG_SYNC,
312                                            &encoder_sv.to_vec(),
313                                        ),
314                                    );
315
316                                    tracing::debug!(
317                                        "Processed SyncStep1 for '{}', sent SyncStep2 + SyncStep1",
318                                        self.doc_name
319                                    );
320                                }
321                                Err(e) => {
322                                    tracing::warn!(
323                                        "Failed to decode StateVector in SyncStep1: {:?}",
324                                        e
325                                    );
326                                    break;
327                                }
328                            }
329                        }
330                        Err(e) => {
331                            tracing::warn!("Failed to read SyncStep1 payload: {:?}", e);
332                            break;
333                        }
334                    }
335                }
336                1 => {
337                    // SyncStep2: [Tag 1] [Len] [Bytes]
338                    match decoder.read_buf() {
339                        Ok(update_data) => {
340                            tracing::debug!(
341                                "Handling SyncStep2 (payload len: {})",
342                                update_data.len()
343                            );
344                            if update_data.is_empty() {
345                                tracing::debug!("Received empty SyncStep2 update, ignoring");
346                                continue;
347                            }
348
349                            if let Err(e) = self.apply_update(update_data) {
350                                tracing::error!(
351                                    "Failed to apply SyncStep2 update: {:?}. Payload: {:02x?}",
352                                    e,
353                                    update_data
354                                );
355                            } else {
356                                tracing::debug!("Applied SyncStep2 update for '{}'", self.doc_name);
357                                self.request_persist().await;
358                            }
359                        }
360                        Err(e) => {
361                            tracing::warn!("Failed to read SyncStep2 payload: {:?}", e);
362                            break;
363                        }
364                    }
365                }
366                2 => {
367                    // Update: [Tag 2] [Len] [Bytes]
368                    match decoder.read_buf() {
369                        Ok(update_data) => {
370                            if let Err(e) = self.apply_update(update_data) {
371                                tracing::error!("Failed to apply incremental update: {:?}", e);
372                            } else {
373                                tracing::debug!(
374                                    "Applied incremental update for '{}'",
375                                    self.doc_name
376                                );
377
378                                // Broadcast to other clients - MUST include Hocuspocus prefix
379                                // Broadcast format: [Tag 2] [Len] [Bytes]
380                                let mut encoder = EncoderV1::new();
381                                encoder.write_var(2u32);
382                                encoder.write_buf(update_data);
383                                let msg =
384                                    self.encode_hocuspocus_message(MSG_SYNC, &encoder.to_vec());
385                                let _ = self.broadcast_tx.send(msg);
386
387                                self.request_persist().await;
388                            }
389                        }
390                        Err(e) => {
391                            tracing::warn!("Failed to read Update payload: {:?}", e);
392                            break;
393                        }
394                    }
395                }
396                _ => {
397                    tracing::warn!("Unknown sync message tag: {}", tag);
398                    break;
399                }
400            }
401        }
402    }
403
404    /// Apply a Yjs update to the document
405    pub fn apply_update(
406        &self,
407        update_data: &[u8],
408    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
409        let update = Update::decode_v1(update_data)?;
410        let doc = self.doc.lock().unwrap();
411        let mut txn = doc.transact_mut();
412        txn.apply_update(update);
413        Ok(())
414    }
415
416    /// Request persistence with debouncing
417    pub async fn request_persist(&self) {
418        #[cfg(feature = "sqlite")]
419        {
420            let now = Instant::now();
421
422            {
423                let mut last_request = self.last_persist_request.lock().await;
424                *last_request = Some(now);
425            }
426
427            // Check if persistence is already pending
428            let already_pending = {
429                let pending = self.persist_pending.lock().await;
430                *pending
431            };
432
433            if !already_pending {
434                // Mark as pending
435                {
436                    let mut pending = self.persist_pending.lock().await;
437                    *pending = true;
438                }
439
440                // Clone what we need for the spawned task
441                let doc_name = self.doc_name.clone();
442                let db = self.db.clone();
443                let last_persist_request = self.last_persist_request.clone();
444                let persist_pending = self.persist_pending.clone();
445
446                // Encode the state before spawning since Doc isn't Send
447                let state = {
448                    let doc = self.doc.lock().unwrap();
449                    let txn = doc.transact();
450                    txn.encode_state_as_update_v1(&StateVector::default())
451                };
452
453                tokio::spawn(async move {
454                    // Wait for debounce interval
455                    tokio::time::sleep(Duration::from_millis(PERSIST_DEBOUNCE_MS)).await;
456
457                    // Check if there were more updates during the wait
458                    let should_persist = {
459                        let last_request = last_persist_request.lock().await;
460                        if let Some(last) = *last_request {
461                            last.elapsed() >= Duration::from_millis(PERSIST_DEBOUNCE_MS - 50)
462                        } else {
463                            true
464                        }
465                    };
466
467                    if should_persist {
468                        // Save to database
469                        if let Err(e) = db.save_doc(&doc_name, state).await {
470                            tracing::error!("Failed to persist document '{}': {:?}", doc_name, e);
471                        } else {
472                            tracing::debug!("Persisted document '{}'", doc_name);
473                        }
474                    }
475
476                    // Clear pending flag
477                    {
478                        let mut pending = persist_pending.lock().await;
479                        *pending = false;
480                    }
481                });
482            }
483        }
484    }
485
486    /// Force immediate persistence (for graceful shutdown)
487    pub async fn force_persist(&self) {
488        #[cfg(feature = "sqlite")]
489        {
490            let state = {
491                let doc = self.doc.lock().unwrap();
492                let txn = doc.transact();
493                txn.encode_state_as_update_v1(&StateVector::default())
494            };
495
496            if let Err(e) = self.db.save_doc(&self.doc_name, state).await {
497                tracing::error!(
498                    "Failed to persist document '{}' on shutdown: {:?}",
499                    self.doc_name,
500                    e
501                );
502            } else {
503                tracing::info!("Persisted document '{}' on shutdown", self.doc_name);
504            }
505        }
506    }
507
508    /// Get a subscription to broadcast messages
509    pub fn subscribe(&self) -> broadcast::Receiver<Vec<u8>> {
510        self.broadcast_tx.subscribe()
511    }
512}
513
#[cfg(test)]
mod tests {
    // `super::*` also brings the parent module's imports into scope
    // (`Doc`, `StateVector`, `Duration`, the yrs `Read`/`Write` traits, ...).
    use super::*;
    #[cfg(feature = "sqlite")]
    use yrs::{GetString, Text};

    /// Creates an in-memory test database.
    #[cfg(feature = "sqlite")]
    async fn create_test_db() -> Database {
        Database::init_in_memory().expect("Failed to create test database")
    }

    /// Frames `payload` like a Hocuspocus V2 client:
    /// `[doc name: VarString] [msg type: VarUint] [payload]`.
    #[cfg(feature = "sqlite")]
    fn encode_test_msg(doc_name: &str, msg_type: u8, payload: &[u8]) -> Vec<u8> {
        let mut encoder = EncoderV1::new();
        encoder.write_string(doc_name);
        encoder.write_var(msg_type as u32);
        let mut framed = encoder.to_vec();
        framed.extend_from_slice(payload);
        framed
    }

    /// Builds a SyncStep1 sync sub-message: `[0] [len] [state vector]`.
    #[cfg(feature = "sqlite")]
    fn encode_sync_step1(sv: &StateVector) -> Vec<u8> {
        let mut sv_encoder = EncoderV1::new();
        sv.encode(&mut sv_encoder);
        let sv_bytes = sv_encoder.to_vec();

        let mut encoder = EncoderV1::new();
        encoder.write_var(0u32);
        encoder.write_buf(&sv_bytes);
        encoder.to_vec()
    }

    /// Builds an Update sync sub-message: `[2] [len] [v1 update]`.
    #[cfg(feature = "sqlite")]
    fn encode_update(update: &[u8]) -> Vec<u8> {
        let mut encoder = EncoderV1::new();
        encoder.write_var(2u32);
        encoder.write_buf(update);
        encoder.to_vec()
    }

    /// Creates a v1 update, from a separate client doc, that appends
    /// `content` to the text named `name`.
    #[cfg(feature = "sqlite")]
    fn text_update(name: &str, content: &str) -> Vec<u8> {
        let client_doc = Doc::new();
        let text = client_doc.get_or_insert_text(name);
        let mut txn = client_doc.transact_mut();
        text.push(&mut txn, content);
        txn.encode_update_v1()
    }

    /// Splits a server sync frame into its document name, outer message type
    /// and first sync sub-message tag.
    #[cfg(feature = "sqlite")]
    fn parse_sync_frame(frame: &[u8]) -> (String, u8, u32) {
        let (rest, name) =
            DocHandler::read_and_skip_doc_name(frame).expect("frame should start with a doc name");
        let mut decoder = DecoderV1::from(rest);
        let msg_type: u32 = decoder.read_var().expect("frame should carry a message type");
        let tag: u32 = decoder
            .read_var()
            .expect("sync frame should carry a sub-message tag");
        (name, msg_type as u8, tag)
    }

    #[test]
    fn test_read_and_skip_doc_name() {
        let frame = [3u8, b'a', b'b', b'c', 9, 8];
        let (rest, name) = DocHandler::read_and_skip_doc_name(&frame).expect("valid name");
        assert_eq!(name, "abc");
        assert_eq!(rest, &[9u8, 8][..]);

        // Empty input, an unterminated length prefix, and a length past the end.
        assert!(DocHandler::read_and_skip_doc_name(&[]).is_none());
        assert!(DocHandler::read_and_skip_doc_name(&[0x80]).is_none());
        assert!(DocHandler::read_and_skip_doc_name(&[5, b'a']).is_none());
    }

    #[tokio::test]
    #[cfg(feature = "sqlite")]
    async fn test_doc_handler_creation() {
        let db = create_test_db().await;
        let handler = DocHandler::new("test-room".to_string(), db).await;
        assert_eq!(handler.doc_name, "test-room");
    }

    #[tokio::test]
    #[cfg(feature = "sqlite")]
    async fn test_initial_sync_generation() {
        let db = create_test_db().await;
        let handler = DocHandler::new("test-room".to_string(), db).await;

        let messages = handler.generate_initial_sync();
        assert_eq!(messages.len(), 1);

        // Doc name prefix, then MSG_SYNC, then a SyncStep1 sub-message.
        let (name, msg_type, tag) = parse_sync_frame(&messages[0]);
        assert_eq!(name, "test-room");
        assert_eq!(msg_type, MSG_SYNC);
        assert_eq!(tag, 0);
    }

    #[tokio::test]
    #[cfg(feature = "sqlite")]
    async fn test_sync_step1_response() {
        let db = create_test_db().await;
        let handler = DocHandler::new("test-room".to_string(), db).await;

        // An empty client state vector requests every update.
        let payload = encode_sync_step1(&StateVector::default());
        let msg = encode_test_msg("test-room", MSG_SYNC, &payload);

        let responses = handler.handle_message(&msg).await;
        assert_eq!(responses.len(), 2);

        let tags: Vec<u32> = responses
            .iter()
            .map(|resp| {
                let (name, msg_type, tag) = parse_sync_frame(resp);
                assert_eq!(name, "test-room");
                assert_eq!(msg_type, MSG_SYNC);
                tag
            })
            .collect();
        // SyncStep2 (what the client is missing) first, then our SyncStep1.
        assert_eq!(tags, vec![1, 0]);
    }

    #[tokio::test]
    #[cfg(feature = "sqlite")]
    async fn test_update_application_and_broadcast() {
        let db = create_test_db().await;
        let handler = DocHandler::new("test-room".to_string(), db).await;
        let mut rx = handler.subscribe();

        let update = text_update("test", "Hello, World!");
        let msg = encode_test_msg("test-room", MSG_SYNC, &encode_update(&update));

        let responses = handler.handle_message(&msg).await;
        assert!(responses.is_empty());

        let broadcast_data = tokio::time::timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("update should be broadcast")
            .expect("broadcast channel should be open");
        // Same doc name and the same Update sub-message as the client sent.
        assert_eq!(broadcast_data, msg);

        // The update must also have been applied to the server document.
        {
            let doc = handler.doc.lock().unwrap();
            let text = doc.get_or_insert_text("test");
            let txn = doc.transact();
            assert_eq!(text.get_string(&txn), "Hello, World!");
        }
    }

    #[tokio::test]
    #[cfg(feature = "sqlite")]
    async fn test_persistence_after_update() {
        let db = create_test_db().await;
        let handler = DocHandler::new("test-room".to_string(), db.clone()).await;

        let update = text_update("test", "Persistent data");
        let msg = encode_test_msg("test-room", MSG_SYNC, &encode_update(&update));
        let _responses = handler.handle_message(&msg).await;

        // Bypass the debounce.
        handler.force_persist().await;

        let saved = db.get_doc("test-room").await.unwrap();
        assert!(saved.is_some());
        assert!(!saved.unwrap().is_empty());
    }

    #[tokio::test]
    #[cfg(feature = "sqlite")]
    async fn test_document_reload_from_db() {
        let db = create_test_db().await;

        let handler1 = DocHandler::new("reload-test".to_string(), db.clone()).await;
        let update = text_update("content", "Test content for reload");
        let msg = encode_test_msg("reload-test", MSG_SYNC, &encode_update(&update));

        handler1.handle_message(&msg).await;
        handler1.force_persist().await;
        drop(handler1);

        // A new handler for the same room must load the stored state.
        let handler2 = DocHandler::new("reload-test".to_string(), db).await;

        let doc = handler2.doc.lock().unwrap();
        let text = doc.get_or_insert_text("content");
        let txn = doc.transact();
        assert_eq!(text.get_string(&txn), "Test content for reload");
    }

    #[tokio::test]
    #[cfg(feature = "sqlite")]
    async fn test_awareness_forwarding() {
        let db = create_test_db().await;
        let handler = DocHandler::new("test-room".to_string(), db).await;
        let mut rx = handler.subscribe();

        let body = vec![1, 2, 3, 4];
        let awareness_msg = encode_test_msg("test-room", MSG_AWARENESS, &body);

        let _responses = handler.handle_message(&awareness_msg).await;

        let received = tokio::time::timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("awareness should be broadcast")
            .expect("broadcast channel should be open");
        // Re-wrapped with the same doc name, so identical to the input.
        assert_eq!(received, awareness_msg);
    }

    #[tokio::test]
    #[cfg(feature = "sqlite")]
    async fn test_empty_message_handling() {
        let db = create_test_db().await;
        let handler = DocHandler::new("test-room".to_string(), db).await;

        let responses = handler.handle_message(&[]).await;
        assert!(responses.is_empty());
    }

    #[tokio::test]
    #[cfg(feature = "sqlite")]
    async fn test_message_without_type_is_ignored() {
        let db = create_test_db().await;
        let handler = DocHandler::new("test-room".to_string(), db).await;

        // Only the doc-name prefix, no message type or payload.
        let mut encoder = EncoderV1::new();
        encoder.write_string("test-room");
        let name_only = encoder.to_vec();

        let responses = handler.handle_message(&name_only).await;
        assert!(responses.is_empty());
    }

    #[tokio::test]
    #[cfg(not(feature = "sqlite"))]
    async fn test_doc_handler_no_sqlite() {
        let handler = DocHandler::new("test-room-no-db".to_string()).await;
        assert_eq!(handler.doc_name, "test-room-no-db");

        // Initial sync generation works without a database.
        let messages = handler.generate_initial_sync();
        assert_eq!(messages.len(), 1);

        let (_, name) = DocHandler::read_and_skip_doc_name(&messages[0]).unwrap();
        assert_eq!(name, "test-room-no-db");
    }
}