loro_websocket_server/
lib.rs

1//! Loro WebSocket Server (simple skeleton)
2//!
3//! Minimal async WebSocket server that accepts connections and echoes binary
4//! protocol frames back to clients. It also responds to text "ping" with
5//! text "pong" as described in protocol.md keepalive section.
6//!
7//! This is intentionally simple and is meant as a starting point. Application
8//! logic (authorization, room routing, broadcasting, etc.) should be layered
9//! on top using the `loro_protocol` crate for message encoding/decoding.
10//!
11//! Example (not run here because it binds a socket):
12//! ```no_run
13//! # fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
14//! #   let rt = tokio::runtime::Builder::new_current_thread().enable_all().build()?;
15//! #   rt.block_on(async move {
16//! loro_websocket_server::serve("127.0.0.1:9000").await?;
17//! #   Ok(())
18//! # })
19//! # }
20//! ```
21
22use futures_util::{SinkExt, StreamExt};
23use std::{
24    collections::{HashMap, HashSet},
25    future::Future,
26    hash::{Hash, Hasher},
27    pin::Pin,
28    sync::{
29        atomic::{AtomicU64, Ordering},
30        Arc,
31    },
32    time::Duration,
33};
34use tokio::{
35    net::{TcpListener, TcpStream},
36    sync::mpsc,
37};
38use tokio_tungstenite::accept_hdr_async;
39use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
40use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame;
41use tokio_tungstenite::tungstenite::{self, Message};
42
43use loro::awareness::EphemeralStore;
44use loro::{ExportMode, LoroDoc};
45pub use loro_protocol as protocol;
46use protocol::{try_decode, BytesReader, CrdtType, JoinErrorCode, Permission, ProtocolMessage};
47use tracing::{debug, error, info, warn};
48
// Limits to protect server memory from abusive fragment headers.
// NOTE(review): neither constant is referenced in this section of the file —
// presumably the connection handler validates incoming fragment-start headers
// against them before `Hub::start_fragment_batch` allocates state; confirm.
const MAX_FRAGMENTS: u64 = 4096; // hard cap on number of fragments per batch
const MAX_BATCH_BYTES: u64 = 64 * 1024 * 1024; // 64 MiB per batch
52
53#[derive(Clone, Debug, PartialEq, Eq)]
54struct RoomKey {
55    crdt: CrdtType,
56    room: String,
57}
58impl Hash for RoomKey {
59    fn hash<H: Hasher>(&self, state: &mut H) {
60        // CrdtType is repr as enum with a few variants; map to u8 for hashing
61        let tag = match self.crdt {
62            CrdtType::Loro => 0u8,
63            CrdtType::LoroEphemeralStore => 1,
64            CrdtType::LoroEphemeralStorePersisted => 2,
65            CrdtType::Yjs => 3,
66            CrdtType::YjsAwareness => 4,
67            CrdtType::Elo => 5,
68        };
69        tag.hash(state);
70        self.room.hash(state);
71    }
72}
73
/// Channel used to push outgoing WebSocket messages to a connection's writer.
type Sender = mpsc::UnboundedSender<Message>;

// Hook types
/// Snapshot payload returned by `on_load_document` alongside optional metadata
/// that will be passed through to `on_save_document`.
pub struct LoadedDoc<DocCtx> {
    /// Serialized document snapshot, or `None` when no stored state exists.
    pub snapshot: Option<Vec<u8>>,
    /// Opaque caller context, handed back later via `SaveDocArgs::ctx`.
    pub ctx: Option<DocCtx>,
}

/// Arguments provided to `on_load_document`.
pub struct LoadDocArgs {
    /// Workspace id (taken from the WebSocket path).
    pub workspace: String,
    /// Room name within the workspace.
    pub room: String,
    /// CRDT flavor of the room being loaded.
    pub crdt: CrdtType,
}

/// Arguments provided to `on_save_document`.
pub struct SaveDocArgs<DocCtx> {
    /// Workspace id (taken from the WebSocket path).
    pub workspace: String,
    /// Room name within the workspace.
    pub room: String,
    /// CRDT flavor of the room being saved.
    pub crdt: CrdtType,
    /// Snapshot bytes exported from the room document.
    pub data: Vec<u8>,
    /// Context originally returned by `on_load_document`, if any.
    pub ctx: Option<DocCtx>,
}

/// Future returned by the `on_load_document` hook.
type LoadFuture<DocCtx> =
    Pin<Box<dyn Future<Output = Result<LoadedDoc<DocCtx>, String>> + Send + 'static>>;
/// Future returned by the `on_save_document` hook.
type SaveFuture = Pin<Box<dyn Future<Output = Result<(), String>> + Send + 'static>>;
/// Async document-load hook.
type LoadFn<DocCtx> = Arc<dyn Fn(LoadDocArgs) -> LoadFuture<DocCtx> + Send + Sync>;
/// Async document-save hook.
type SaveFn<DocCtx> = Arc<dyn Fn(SaveDocArgs<DocCtx>) -> SaveFuture + Send + Sync>;
/// Future returned by the per-room `authenticate` hook; `Ok(None)` means
/// "fall back to the default permission".
type AuthFuture =
    Pin<Box<dyn Future<Output = Result<Option<Permission>, String>> + Send + 'static>>;
/// Per-room auth hook. NOTE(review): parameter meanings (room name, CRDT type,
/// auth token bytes?) are inferred from the types — the call site is not
/// visible in this section; confirm against the connection handler.
type AuthFn = Arc<dyn Fn(String, CrdtType, Vec<u8>) -> AuthFuture + Send + Sync>;

/// Handshake-time auth callback: `(workspace_id, token) -> allow`.
type HandshakeAuthFn = dyn Fn(&str, Option<&str>) -> bool + Send + Sync;
110
/// Server configuration, generic over a caller-supplied per-document context
/// type `DocCtx` that flows from `on_load_document` to `on_save_document`.
#[derive(Clone)]
pub struct ServerConfig<DocCtx = ()> {
    /// Hook invoked when a room is first loaded; may supply a stored snapshot.
    pub on_load_document: Option<LoadFn<DocCtx>>,
    /// Hook invoked by the periodic saver to persist dirty documents.
    pub on_save_document: Option<SaveFn<DocCtx>>,
    /// Interval between save sweeps; persistence is disabled when `None`.
    pub save_interval_ms: Option<u64>,
    /// Permission used when no `authenticate` hook applies.
    /// NOTE(review): usage site is not visible in this section — confirm.
    pub default_permission: Permission,
    /// Optional per-room authentication hook.
    pub authenticate: Option<AuthFn>,
    /// Optional handshake auth: called during WS HTTP upgrade.
    ///
    /// Parameters:
    /// - `workspace_id`: extracted from request path `/{workspace}` (empty if missing)
    /// - `token`: `token` query parameter if present
    ///
    /// Return true to accept, false to reject with 401.
    pub handshake_auth: Option<Arc<HandshakeAuthFn>>,
}
127
// CRDT document abstraction to reduce match-based branching
/// Server-side per-room document behavior, implemented once per CRDT flavor.
/// Every method has a conservative default so implementations only override
/// what they need.
trait CrdtDoc: Send {
    /// Encoded version/state vector; empty signals "unknown/empty" baseline.
    fn get_version(&self) -> Vec<u8> {
        Vec::new()
    }
    /// Updates a client at `_client_version` is missing; empty means nothing
    /// to send.
    fn compute_backfill(&self, _client_version: &[u8]) -> Vec<Vec<u8>> {
        Vec::new()
    }
    /// Apply a batch of client updates. Default: accept and ignore.
    fn apply_updates(&mut self, _updates: &[Vec<u8>]) -> Result<(), String> {
        Ok(())
    }
    /// Whether this document participates in snapshot persistence.
    fn should_persist(&self) -> bool {
        false
    }
    /// Full snapshot for persistence; `None` when unsupported.
    fn export_snapshot(&self) -> Option<Vec<u8>> {
        None
    }
    /// Restore state from a previously exported snapshot. Default: no-op.
    fn import_snapshot(&mut self, _data: &[u8]) {}
    /// Whether backfill should be sent even with no other connected clients.
    fn allow_backfill_when_no_other_clients(&self) -> bool {
        false
    }
    /// Whether the doc should be dropped when its last subscriber leaves.
    fn remove_when_last_subscriber_leaves(&self) -> bool {
        false
    }
}
153
154struct LoroRoomDoc {
155    doc: LoroDoc,
156}
157impl LoroRoomDoc {
158    fn new() -> Self {
159        Self {
160            doc: LoroDoc::new(),
161        }
162    }
163}
164impl CrdtDoc for LoroRoomDoc {
165    fn apply_updates(&mut self, updates: &[Vec<u8>]) -> Result<(), String> {
166        for u in updates {
167            let _ = self.doc.import(u);
168        }
169        Ok(())
170    }
171    fn should_persist(&self) -> bool {
172        true
173    }
174    fn export_snapshot(&self) -> Option<Vec<u8>> {
175        self.doc.export(ExportMode::Snapshot).ok()
176    }
177    fn import_snapshot(&mut self, data: &[u8]) {
178        let _ = self.doc.import(data);
179    }
180}
181
182struct EphemeralRoomDoc {
183    store: EphemeralStore,
184}
185impl EphemeralRoomDoc {
186    fn new(timeout_ms: i64) -> Self {
187        Self {
188            store: EphemeralStore::new(timeout_ms),
189        }
190    }
191}
192impl CrdtDoc for EphemeralRoomDoc {
193    fn compute_backfill(&self, _client_version: &[u8]) -> Vec<Vec<u8>> {
194        let data = self.store.encode_all();
195        if data.is_empty() {
196            Vec::new()
197        } else {
198            vec![data]
199        }
200    }
201    fn apply_updates(&mut self, updates: &[Vec<u8>]) -> Result<(), String> {
202        for u in updates {
203            if !u.is_empty() {
204                self.store.apply(u);
205            }
206        }
207        Ok(())
208    }
209    fn remove_when_last_subscriber_leaves(&self) -> bool {
210        true
211    }
212}
213
214struct PersistentEphemeralRoomDoc {
215    store: EphemeralStore,
216    timeout_ms: i64,
217}
218impl PersistentEphemeralRoomDoc {
219    fn new(timeout_ms: i64) -> Self {
220        Self {
221            store: EphemeralStore::new(timeout_ms),
222            timeout_ms,
223        }
224    }
225}
226impl CrdtDoc for PersistentEphemeralRoomDoc {
227    fn compute_backfill(&self, _client_version: &[u8]) -> Vec<Vec<u8>> {
228        let data = self.store.encode_all();
229        if data.is_empty() {
230            Vec::new()
231        } else {
232            vec![data]
233        }
234    }
235    fn apply_updates(&mut self, updates: &[Vec<u8>]) -> Result<(), String> {
236        for u in updates {
237            if !u.is_empty() {
238                self.store.apply(u);
239            }
240        }
241        Ok(())
242    }
243    fn should_persist(&self) -> bool {
244        true
245    }
246    fn export_snapshot(&self) -> Option<Vec<u8>> {
247        Some(self.store.encode_all())
248    }
249    fn import_snapshot(&mut self, data: &[u8]) {
250        self.store = EphemeralStore::new(self.timeout_ms);
251        if !data.is_empty() {
252            self.store.apply(data);
253        }
254    }
255    fn allow_backfill_when_no_other_clients(&self) -> bool {
256        true
257    }
258}
259
// ELO header index entries
/// Index entry for one ELO delta-span record, kept per peer.
struct EloDeltaSpanIndexEntry {
    // Span counters from the record header. NOTE(review): whether `end` is
    // inclusive or exclusive is not evident from this section (comparisons use
    // `e.end > k` and coverage `e.end <= h.end`) — confirm against the ELO spec.
    start: u64,
    end: u64,
    // Encryption key id declared in the record header.
    key_id: String,
    // Raw encoded record, re-sent verbatim during backfill.
    record: Vec<u8>,
}
267
268struct EloRoomDoc {
269    spans_by_peer: std::collections::HashMap<String, Vec<EloDeltaSpanIndexEntry>>,
270}
271impl EloRoomDoc {
272    fn new() -> Self {
273        Self {
274            spans_by_peer: std::collections::HashMap::new(),
275        }
276    }
277
278    fn peer_key_from_bytes(bytes: &[u8]) -> String {
279        // Prefer UTF-8 if valid, else hex
280        match std::str::from_utf8(bytes) {
281            Ok(s) => s.to_string(),
282            Err(_) => {
283                let mut out = String::with_capacity(bytes.len() * 2);
284                for b in bytes {
285                    use std::fmt::Write as _;
286                    let _ = write!(&mut out, "{:02x}", b);
287                }
288                out
289            }
290        }
291    }
292
293    fn decode_version_vector(&self, buf: &[u8]) -> Option<std::collections::HashMap<String, u64>> {
294        use loro_protocol::bytes::BytesReader;
295        let mut r = BytesReader::new(buf);
296        let count = usize::try_from(r.read_uleb128().ok()?).ok()?;
297        let mut map: std::collections::HashMap<String, u64> =
298            std::collections::HashMap::with_capacity(count);
299        for _ in 0..count {
300            let peer_bytes = r.read_var_bytes().ok()?;
301            let ctr = r.read_uleb128().ok()?;
302            map.insert(Self::peer_key_from_bytes(peer_bytes), ctr);
303        }
304        Some(map)
305    }
306
307    fn encode_current_vv(&self) -> Vec<u8> {
308        use loro_protocol::bytes::BytesWriter;
309        let mut entries: Vec<(String, u64)> = Vec::new();
310        for (peer, spans) in self.spans_by_peer.iter() {
311            if !peer.as_bytes().iter().all(|b| b.is_ascii_digit()) {
312                continue;
313            }
314            let mut max_end = 0u64;
315            for s in spans.iter() {
316                if s.end > max_end {
317                    max_end = s.end;
318                }
319            }
320            if max_end > 0 {
321                entries.push((peer.clone(), max_end));
322            }
323        }
324        let mut w = BytesWriter::new();
325        w.push_uleb128(entries.len() as u64);
326        for (peer, ctr) in entries.iter() {
327            w.push_var_bytes(peer.as_bytes());
328            w.push_uleb128(*ctr);
329        }
330        w.finalize()
331    }
332}
impl CrdtDoc for EloRoomDoc {
    /// Version is the encoded VV of indexed spans; empty when nothing indexed.
    fn get_version(&self) -> Vec<u8> {
        // If we have no indexed entries yet, return an empty version to signal
        // "unknown/empty" baseline so clients may choose to send a snapshot.
        if self.spans_by_peer.is_empty() {
            return Vec::new();
        }
        self.encode_current_vv()
    }
    /// Re-package, into a single container, every indexed record whose span
    /// extends past the client's known counter for that peer.
    fn compute_backfill(&self, client_version: &[u8]) -> Vec<Vec<u8>> {
        // Malformed client version decodes to an empty map => full backfill.
        let known = self
            .decode_version_vector(client_version)
            .unwrap_or_default();
        let mut records: Vec<Vec<u8>> = Vec::new();
        for (peer, spans) in self.spans_by_peer.iter() {
            // Counter the client already has for this peer (0 if unknown).
            let k = known.get(peer).copied().unwrap_or(0);
            for e in spans {
                if e.end > k {
                    records.push(e.record.clone());
                }
            }
        }
        if records.is_empty() {
            return Vec::new();
        }
        // Container format: uleb128 record count, then var-bytes records.
        let mut w = loro_protocol::bytes::BytesWriter::new();
        w.push_uleb128(records.len() as u64);
        for rec in records.iter() {
            w.push_var_bytes(rec);
        }
        vec![w.finalize()]
    }
    /// Validate and index incoming ELO records. Delta spans are kept sorted by
    /// `start` per peer; existing entries fully covered by a new span are
    /// dropped. Snapshot records are validated by the parser but not indexed.
    fn apply_updates(&mut self, updates: &[Vec<u8>]) -> Result<(), String> {
        use loro_protocol::elo::{
            decode_elo_container, parse_elo_record_header, EloHeader, EloRecordKind,
        };
        for u in updates {
            let records = decode_elo_container(u.as_slice())?;
            for rec in records {
                let parsed = parse_elo_record_header(rec)?;
                match parsed.kind {
                    EloRecordKind::DeltaSpan => {
                        if let EloHeader::Delta(h) = parsed.header {
                            // Reject degenerate spans and malformed AES-GCM IVs.
                            if !(h.end > h.start) {
                                return Err("invalid ELO delta span: end must be > start".into());
                            }
                            if h.iv.len() != 12 {
                                return Err("invalid ELO delta span: IV must be 12 bytes".into());
                            }
                            let peer = Self::peer_key_from_bytes(&h.peer_id);
                            let list = self.spans_by_peer.entry(peer).or_default();
                            // Insert keeping order by start; remove fully covered entries [start, end]
                            let mut kept: Vec<EloDeltaSpanIndexEntry> =
                                Vec::with_capacity(list.len() + 1);
                            let mut inserted = false;
                            for e in list.iter() {
                                // Insert the new span just before the first
                                // existing entry that starts at/after it.
                                if !inserted && e.start >= h.start {
                                    kept.push(EloDeltaSpanIndexEntry {
                                        start: h.start,
                                        end: h.end,
                                        key_id: h.key_id.clone(),
                                        record: rec.to_vec(),
                                    });
                                    inserted = true;
                                }
                                // keep entries not fully covered by [start, end]
                                let covered = e.start >= h.start && e.end <= h.end;
                                if !covered {
                                    kept.push(EloDeltaSpanIndexEntry {
                                        start: e.start,
                                        end: e.end,
                                        key_id: e.key_id.clone(),
                                        record: e.record.clone(),
                                    });
                                }
                            }
                            // New span starts after every existing entry: append.
                            if !inserted {
                                kept.push(EloDeltaSpanIndexEntry {
                                    start: h.start,
                                    end: h.end,
                                    key_id: h.key_id.clone(),
                                    record: rec.to_vec(),
                                });
                            }
                            *list = kept;
                        }
                    }
                    EloRecordKind::Snapshot => {
                        // Snapshot header validation already done by parser; no indexing needed
                    }
                }
            }
        }
        Ok(())
    }
    fn allow_backfill_when_no_other_clients(&self) -> bool {
        true
    }
}
432impl<DocCtx> Default for ServerConfig<DocCtx> {
433    fn default() -> Self {
434        Self {
435            on_load_document: None,
436            on_save_document: None,
437            save_interval_ms: None,
438            default_permission: Permission::Write,
439            authenticate: None,
440            handshake_auth: None,
441        }
442    }
443}
444
/// Per-room document plus persistence bookkeeping.
struct RoomDocState<DocCtx> {
    // Type-erased CRDT behavior for this room's flavor.
    doc: Box<dyn CrdtDoc>,
    // True when the doc has unsaved changes; cleared by the saver task.
    dirty: bool,
    // Caller context from `on_load_document`, passed back to `on_save_document`.
    ctx: Option<DocCtx>,
}
450
/// Per-workspace state: subscriptions, documents, permissions and in-flight
/// fragment reassembly.
struct Hub<DocCtx> {
    // room -> vec of (conn_id, sender)
    subs: HashMap<RoomKey, Vec<(u64, Sender)>>,
    // room -> document state (Loro persistent, Ephemeral in-memory, Elo index)
    docs: HashMap<RoomKey, RoomDocState<DocCtx>>,
    // Shared server configuration (hooks, defaults).
    config: ServerConfig<DocCtx>,
    // (conn_id, room) -> permission
    perms: HashMap<(u64, RoomKey), Permission>,
    // Workspace id this hub serves (extracted from the WS path).
    workspace: String,
    // Fragment reassembly state: per room + batch id
    fragments: HashMap<(RoomKey, protocol::BatchId), FragmentBatch>,
}
463
464impl<DocCtx> Hub<DocCtx>
465where
466    DocCtx: Clone + Send + Sync + 'static,
467{
468    fn new(config: ServerConfig<DocCtx>, workspace: String) -> Self {
469        Self {
470            subs: HashMap::new(),
471            docs: HashMap::new(),
472            config,
473            perms: HashMap::new(),
474            workspace,
475            fragments: HashMap::new(),
476        }
477    }
478
479    const EPHEMERAL_TIMEOUT_MS: i64 = 60_000;
480
481    fn join(&mut self, conn_id: u64, room: RoomKey, tx: &Sender) {
482        let entry = self.subs.entry(room).or_default();
483        if !entry.iter().any(|(id, _)| *id == conn_id) {
484            entry.push((conn_id, tx.clone()));
485        }
486    }
487
488    fn leave_all(&mut self, conn_id: u64) {
489        let mut emptied: Vec<RoomKey> = Vec::new();
490        for (k, vec) in self.subs.iter_mut() {
491            vec.retain(|(id, _)| *id != conn_id);
492            if vec.is_empty() {
493                emptied.push(k.clone());
494            }
495        }
496        // Drop empty rooms from subscription map
497        for k in &emptied {
498            let _ = self.subs.remove(k);
499        }
500
501        // Remove permissions for this connection
502        self.perms.retain(|(id, _), _| *id != conn_id);
503
504        // Clean up ephemeral state for rooms that no longer have subscribers
505        for k in emptied.clone() {
506            if let Some(state) = self.docs.get(&k) {
507                if state.doc.remove_when_last_subscriber_leaves() {
508                    self.docs.remove(&k);
509                    debug!(room=?k.room, "cleaned up ephemeral doc after last subscriber left");
510                }
511            }
512        }
513
514        // Clean up in-flight fragment batches started by this connection, or for rooms now emptied
515        if !self.fragments.is_empty() {
516            use std::collections::HashSet;
517            let emptied_set: HashSet<RoomKey> = emptied.into_iter().collect();
518            self.fragments
519                .retain(|(rk, _), b| b.from_conn != conn_id && !emptied_set.contains(rk));
520        }
521    }
522
523    fn broadcast(&mut self, room: &RoomKey, from: u64, msg: Message) {
524        if let Some(list) = self.subs.get_mut(room) {
525            // drop dead senders
526            let mut dead: HashSet<u64> = HashSet::new();
527            for (id, tx) in list.iter() {
528                if *id == from {
529                    continue;
530                }
531                if tx.send(msg.clone()).is_err() {
532                    dead.insert(*id);
533                }
534            }
535            if !dead.is_empty() {
536                list.retain(|(id, _)| !dead.contains(id));
537                debug!(room=?room.room, removed=%dead.len(), "removed dead subscribers");
538            }
539        }
540    }
541
542    async fn ensure_room_loaded(&mut self, room: &RoomKey) {
543        if self.docs.contains_key(room) {
544            return;
545        }
546        match room.crdt {
547            CrdtType::Loro => {
548                let mut d = LoroRoomDoc::new();
549                let mut ctx = None;
550                if let Some(loader) = &self.config.on_load_document {
551                    let args = LoadDocArgs {
552                        workspace: self.workspace.clone(),
553                        room: room.room.clone(),
554                        crdt: room.crdt,
555                    };
556                    match (loader)(args).await {
557                        Ok(loaded) => {
558                            if let Some(bytes) = loaded.snapshot {
559                                d.import_snapshot(&bytes);
560                            }
561                            ctx = loaded.ctx;
562                        }
563                        Err(e) => {
564                            warn!(room=?room.room, %e, "load document failed");
565                        }
566                    }
567                }
568                self.docs.insert(
569                    room.clone(),
570                    RoomDocState {
571                        doc: Box::new(d),
572                        dirty: false,
573                        ctx,
574                    },
575                );
576            }
577            CrdtType::LoroEphemeralStore => {
578                let d = EphemeralRoomDoc::new(Self::EPHEMERAL_TIMEOUT_MS);
579                self.docs.insert(
580                    room.clone(),
581                    RoomDocState {
582                        doc: Box::new(d),
583                        dirty: false,
584                        ctx: None,
585                    },
586                );
587            }
588            CrdtType::LoroEphemeralStorePersisted => {
589                let mut d = PersistentEphemeralRoomDoc::new(Self::EPHEMERAL_TIMEOUT_MS);
590                let mut ctx = None;
591                if let Some(loader) = &self.config.on_load_document {
592                    let args = LoadDocArgs {
593                        workspace: self.workspace.clone(),
594                        room: room.room.clone(),
595                        crdt: room.crdt,
596                    };
597                    match (loader)(args).await {
598                        Ok(loaded) => {
599                            if let Some(bytes) = loaded.snapshot {
600                                d.import_snapshot(&bytes);
601                            }
602                            ctx = loaded.ctx;
603                        }
604                        Err(e) => {
605                            warn!(room=?room.room, %e, "load persisted ephemeral store failed");
606                        }
607                    }
608                }
609                self.docs.insert(
610                    room.clone(),
611                    RoomDocState {
612                        doc: Box::new(d),
613                        dirty: false,
614                        ctx,
615                    },
616                );
617            }
618            CrdtType::Elo => {
619                let d = EloRoomDoc::new();
620                self.docs.insert(
621                    room.clone(),
622                    RoomDocState {
623                        doc: Box::new(d),
624                        dirty: false,
625                        ctx: None,
626                    },
627                );
628            }
629            _ => {}
630        }
631    }
632
633    fn current_version_bytes(&self, room: &RoomKey) -> Vec<u8> {
634        match self.docs.get(room) {
635            Some(state) => state.doc.get_version(),
636            None => Vec::new(),
637        }
638    }
639
640    fn apply_updates(&mut self, room: &RoomKey, updates: &[Vec<u8>]) {
641        if let Some(state) = self.docs.get_mut(room) {
642            if let Err(e) = state.doc.apply_updates(updates) {
643                warn!(room=?room.room, %e, "apply_updates failed");
644            } else if state.doc.should_persist() {
645                state.dirty = true;
646            }
647        }
648    }
649
650    fn snapshot_bytes(&self, room: &RoomKey) -> Option<Vec<u8>> {
651        let Some(data) = self.docs.get(room).and_then(|s| s.doc.export_snapshot()) else {
652            return None;
653        };
654        if data.is_empty() {
655            None
656        } else {
657            Some(data)
658        }
659    }
660}
661
/// Reassembly state for one fragmented message batch.
struct FragmentBatch {
    // Connection that started the batch; used for cleanup on disconnect.
    from_conn: u64,
    // Number of fragments the sender declared.
    fragment_count: u64,
    // Total payload size declared by the sender; used to pre-size the output.
    // NOTE(review): sender-controlled — confirm MAX_BATCH_BYTES is enforced
    // upstream before this value is trusted for allocation.
    total_size: u64,
    // Count of distinct fragment indices received so far.
    received: u64,
    // Fragment payloads by index; `None` until that index arrives.
    chunks: Vec<Option<Vec<u8>>>,
}
669
670impl<DocCtx> Hub<DocCtx>
671where
672    DocCtx: Clone + Send + Sync + 'static,
673{
674    fn start_fragment_batch(
675        &mut self,
676        room: &RoomKey,
677        from_conn: u64,
678        batch_id: protocol::BatchId,
679        fragment_count: u64,
680        total_size: u64,
681    ) {
682        let key = (room.clone(), batch_id);
683        let chunks_len = usize::try_from(fragment_count).unwrap_or(0);
684        let batch = FragmentBatch {
685            from_conn,
686            fragment_count,
687            total_size,
688            received: 0,
689            chunks: vec![None; chunks_len],
690        };
691        self.fragments.insert(key, batch);
692    }
693
694    /// Returns Some(reassembled) when complete; removes batch.
695    fn add_fragment_and_maybe_finish(
696        &mut self,
697        room: &RoomKey,
698        batch_id: protocol::BatchId,
699        index: u64,
700        fragment: Vec<u8>,
701    ) -> Option<Vec<u8>> {
702        let key = (room.clone(), batch_id);
703        if let Some(b) = self.fragments.get_mut(&key) {
704            let idx = match usize::try_from(index) {
705                Ok(i) => i,
706                Err(_) => return None,
707            };
708            if idx >= b.chunks.len() {
709                return None;
710            }
711            if b.chunks[idx].is_none() {
712                b.chunks[idx] = Some(fragment);
713                b.received += 1;
714            }
715            if b.received == b.fragment_count {
716                let mut out = Vec::with_capacity(b.total_size as usize);
717                for ch in b.chunks.iter() {
718                    if let Some(bytes) = ch.as_ref() {
719                        out.extend_from_slice(bytes);
720                    }
721                }
722                self.fragments.remove(&key);
723                return Some(out);
724            }
725        }
726        None
727    }
728}
729
730static NEXT_ID: AtomicU64 = AtomicU64::new(1);
731
/// Lazily-populated set of per-workspace hubs sharing one server config.
struct HubRegistry<DocCtx> {
    // Template configuration cloned into each newly created hub.
    config: ServerConfig<DocCtx>,
    // workspace id -> hub. Async mutex: `get_or_create` holds it across awaits.
    hubs: tokio::sync::Mutex<HashMap<String, Arc<tokio::sync::Mutex<Hub<DocCtx>>>>>,
}
736
impl<DocCtx> HubRegistry<DocCtx>
where
    DocCtx: Clone + Send + Sync + 'static,
{
    /// Build an empty registry around `config`.
    fn new(config: ServerConfig<DocCtx>) -> Self {
        Self {
            config,
            hubs: tokio::sync::Mutex::new(HashMap::new()),
        }
    }

    /// Return the hub for `workspace`, creating it — and, when persistence is
    /// configured, its periodic saver task — on first use.
    ///
    /// NOTE(review): the saver holds the hub mutex across `(saver)(args).await`,
    /// so a slow save blocks every other hub operation for this workspace for
    /// the duration — confirm that is acceptable or restructure to save
    /// outside the lock (which would need care around the `dirty` flag).
    /// NOTE(review): hubs are never removed in this section, so each workspace
    /// keeps one saver task forever — confirm this is by design.
    async fn get_or_create(&self, workspace: &str) -> Arc<tokio::sync::Mutex<Hub<DocCtx>>> {
        let mut map = self.hubs.lock().await;
        if let Some(h) = map.get(workspace) {
            return h.clone();
        }
        let hub = Arc::new(tokio::sync::Mutex::new(Hub::new(
            self.config.clone(),
            workspace.to_string(),
        )));
        // Spawn saver task for this hub if configured
        if let (Some(ms), Some(saver)) = (
            self.config.save_interval_ms,
            self.config.on_save_document.clone(),
        ) {
            let hub_clone = hub.clone();
            tokio::spawn(async move {
                let mut interval = tokio::time::interval(Duration::from_millis(ms));
                loop {
                    interval.tick().await;
                    let mut guard = hub_clone.lock().await;
                    let ws = guard.workspace.clone();
                    // Collect keys first so entries can be borrowed mutably below.
                    let rooms: Vec<RoomKey> = guard.docs.keys().cloned().collect();
                    for room in rooms {
                        if let Some(state) = guard.docs.get_mut(&room) {
                            if state.dirty && state.doc.should_persist() {
                                let start = std::time::Instant::now();
                                if let Some(snapshot) = state.doc.export_snapshot() {
                                    let room_str = room.room.clone();
                                    let ctx = state.ctx.clone();
                                    let args = SaveDocArgs {
                                        workspace: ws.clone(),
                                        room: room_str.clone(),
                                        crdt: room.crdt,
                                        data: snapshot,
                                        ctx,
                                    };
                                    match (saver)(args).await {
                                        Ok(()) => {
                                            // Clear dirty only after a confirmed save.
                                            state.dirty = false;
                                            let elapsed = start.elapsed();
                                            debug!(workspace=%ws, room=%room_str, ms=%elapsed.as_millis(), "snapshot saved");
                                        }
                                        Err(e) => {
                                            warn!(workspace=%ws, room=%room_str, %e, "snapshot save failed");
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            });
        }
        map.insert(workspace.to_string(), hub.clone());
        hub
    }
}
805
806/// Start a simple broadcast server on the given socket address.
807pub async fn serve(addr: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
808    info!(%addr, "binding TCP listener");
809    let listener = TcpListener::bind(addr).await?;
810    serve_incoming_with_config::<()>(listener, ServerConfig::default()).await
811}
812
/// Serve a pre-bound listener. Useful for tests to bind on port 0.
///
/// Equivalent to [`serve_incoming_with_config`] with `ServerConfig::default()`
/// (no persistence hooks, no auth, write access for all).
pub async fn serve_incoming(
    listener: TcpListener,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    serve_incoming_with_config::<()>(listener, ServerConfig::default()).await
}
819
820pub async fn serve_incoming_with_config<DocCtx>(
821    listener: TcpListener,
822    config: ServerConfig<DocCtx>,
823) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
824where
825    DocCtx: Clone + Send + Sync + 'static,
826{
827    let registry = Arc::new(HubRegistry::new(config.clone()));
828
829    loop {
830        match listener.accept().await {
831            Ok((stream, peer)) => {
832                debug!(remote=%peer, "accepted TCP connection");
833                let registry = registry.clone();
834                tokio::spawn(async move {
835                    if let Err(e) = handle_conn(stream, registry).await {
836                        warn!(%e, "connection task ended with error");
837                    }
838                });
839            }
840            Err(e) => {
841                error!(%e, "accept failed; continuing");
842                continue;
843            }
844        }
845    }
846}
847
848async fn handle_conn<DocCtx>(
849    stream: TcpStream,
850    registry: Arc<HubRegistry<DocCtx>>,
851) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
852where
853    DocCtx: Clone + Send + Sync + 'static,
854{
855    // Capture config outside of non-async closure
856    let handshake_auth = registry.config.handshake_auth.clone();
857    let workspace_holder: Arc<std::sync::Mutex<Option<String>>> =
858        Arc::new(std::sync::Mutex::new(None));
859    let workspace_holder_c = workspace_holder.clone();
860
861    let ws = accept_hdr_async(
862        stream,
863        move |req: &tungstenite::handshake::server::Request,
864              resp: tungstenite::handshake::server::Response| {
865            if let Some(check) = &handshake_auth {
866                // Parse path: expect "/{workspace}" (workspace may be empty)
867                let uri = req.uri();
868                let path = uri.path();
869                let mut workspace_id = "";
870                if let Some(rest) = path.strip_prefix('/') {
871                    if !rest.is_empty() {
872                        // take first segment as workspace id
873                        workspace_id = rest.split('/').next().unwrap_or("");
874                    }
875                }
876                // Save for later
877                {
878                    if let Ok(mut guard) = workspace_holder_c.lock() {
879                        *guard = Some(workspace_id.to_string());
880                    }
881                }
882
883                // Parse query token parameter (no external deps)
884                let token = uri.query().and_then(|q| {
885                    for pair in q.split('&') {
886                        let mut it = pair.splitn(2, '=');
887                        let k = it.next().unwrap_or("");
888                        let v = it.next();
889                        if k == "token" {
890                            return Some(v.unwrap_or(""));
891                        }
892                    }
893                    None
894                });
895
896                let allowed = (check)(workspace_id, token);
897                if !allowed {
898                    warn!(workspace=%workspace_id, token=?token, "handshake auth denied");
899                    // Build a 401 Unauthorized response
900                    let builder = tungstenite::http::Response::builder()
901                        .status(tungstenite::http::StatusCode::UNAUTHORIZED);
902                    // Provide a small body for clarity
903                    let response = builder
904                        .body(Some("Unauthorized".to_string()))
905                        .unwrap_or_else(|e| {
906                            warn!(?e, "failed to build unauthorized response");
907                            let mut fallback =
908                                tungstenite::http::Response::new(Some("Unauthorized".to_string()));
909                            *fallback.status_mut() = tungstenite::http::StatusCode::UNAUTHORIZED;
910                            fallback
911                        });
912                    return Err(response);
913                }
914                debug!(workspace=%workspace_id, token=?token, "handshake auth accepted");
915            }
916            Ok(resp)
917        },
918    )
919    .await?;
920
921    // Determine workspace id (default to empty string)
922    let workspace_id = workspace_holder
923        .lock()
924        .ok()
925        .and_then(|g| g.clone())
926        .unwrap_or_default();
927    let hub = registry.get_or_create(&workspace_id).await;
928
929    // writer task channel
930    let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
931    let (mut sink, mut stream) = ws.split();
932    // writer
933    let sink_task = tokio::spawn(async move {
934        while let Some(msg) = rx.recv().await {
935            if sink.send(msg).await.is_err() {
936                debug!("sink send error; writer task exiting");
937                break;
938            }
939        }
940    });
941
942    let conn_id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
943    let mut joined_rooms: HashSet<RoomKey> = HashSet::new();
944
945    while let Some(msg) = stream.next().await {
946        match msg? {
947            Message::Text(txt) => {
948                if txt == "ping" {
949                    let _ = tx.send(Message::Text("pong".into()));
950                }
951            }
952            Message::Binary(data) => {
953                if let Some(proto) = try_decode(data.as_ref()) {
954                    match proto {
955                        ProtocolMessage::JoinRequest {
956                            crdt,
957                            room_id,
958                            auth,
959                            version,
960                        } => {
961                            let room = RoomKey {
962                                crdt,
963                                room: room_id.clone(),
964                            };
965                            let mut h = hub.lock().await;
966                            // ensure doc exists / load
967                            h.ensure_room_loaded(&room).await;
968                            // authenticate
969                            let mut permission = h.config.default_permission;
970                            if let Some(auth_fn) = &h.config.authenticate {
971                                let room_str = room.room.clone();
972                                match (auth_fn)(room_str, room.crdt, auth.clone()).await {
973                                    Ok(Some(p)) => {
974                                        permission = p;
975                                    }
976                                    Ok(None) => {
977                                        let err = ProtocolMessage::JoinError {
978                                            crdt,
979                                            room_id: room.room.clone(),
980                                            code: JoinErrorCode::AuthFailed,
981                                            message: "Authentication failed".into(),
982                                            receiver_version: None,
983                                            app_code: None,
984                                        };
985                                        if let Ok(bytes) = loro_protocol::encode(&err) {
986                                            let _ = tx.send(Message::Binary(bytes.into()));
987                                        }
988                                        warn!(room=?room.room, "join denied by authenticate() returning None");
989                                        continue;
990                                    }
991                                    Err(e) => {
992                                        let err = ProtocolMessage::JoinError {
993                                            crdt,
994                                            room_id: room.room.clone(),
995                                            code: JoinErrorCode::Unknown,
996                                            message: e,
997                                            receiver_version: None,
998                                            app_code: None,
999                                        };
1000                                        if let Ok(bytes) = loro_protocol::encode(&err) {
1001                                            let _ = tx.send(Message::Binary(bytes.into()));
1002                                        }
1003                                        warn!(room=?room.room, "join denied due to authenticate() error");
1004                                        continue;
1005                                    }
1006                                }
1007                            }
1008                            // register subscriber and record permission
1009                            h.join(conn_id, room.clone(), &tx);
1010                            h.perms.insert((conn_id, room.clone()), permission);
1011                            joined_rooms.insert(room.clone());
1012                            info!(workspace=%h.workspace, room=?room.room, ?permission, "join ok");
1013                            // respond ok with current version and empty extra
1014                            let current_version = h.current_version_bytes(&room);
1015                            let ok = ProtocolMessage::JoinResponseOk {
1016                                crdt,
1017                                room_id: room.room.clone(),
1018                                permission,
1019                                version: current_version,
1020                                extra: Some(Vec::new()),
1021                            };
1022                            if let Ok(bytes) = loro_protocol::encode(&ok) {
1023                                let _ = tx.send(Message::Binary(bytes.into()));
1024                            }
1025                            // send initial state:
1026                            // - If snapshot available (Loro), send as a DocUpdate.
1027                            if let Some(snap) = h.snapshot_bytes(&room) {
1028                                let du = ProtocolMessage::DocUpdate {
1029                                    crdt,
1030                                    room_id: room.room.clone(),
1031                                    updates: vec![snap],
1032                                };
1033                                if let Ok(bytes) = loro_protocol::encode(&du) {
1034                                    let _ = tx.send(Message::Binary(bytes.into()));
1035                                    debug!(room=?room.room, "sent initial snapshot after join");
1036                                }
1037                            } else {
1038                                // Otherwise, attempt backfill if other clients present or the CRDT allows
1039                                let others_in_room =
1040                                    h.subs.get(&room).map(|v| v.len()).unwrap_or(0) > 1;
1041                                let allow_when_empty = h
1042                                    .docs
1043                                    .get(&room)
1044                                    .map(|s| s.doc.allow_backfill_when_no_other_clients())
1045                                    .unwrap_or(false);
1046                                if others_in_room || allow_when_empty {
1047                                    let backfill = h
1048                                        .docs
1049                                        .get(&room)
1050                                        .map(|s| s.doc.compute_backfill(&version))
1051                                        .unwrap_or_default();
1052                                    let backfill_cnt = backfill.len();
1053                                    for u in backfill {
1054                                        let du = ProtocolMessage::DocUpdate {
1055                                            crdt,
1056                                            room_id: room.room.clone(),
1057                                            updates: vec![u],
1058                                        };
1059                                        if let Ok(bytes) = loro_protocol::encode(&du) {
1060                                            let _ = tx.send(Message::Binary(bytes.into()));
1061                                        }
1062                                    }
1063                                    if backfill_cnt > 0 {
1064                                        debug!(room=?room.room, cnt=%backfill_cnt, "sent backfill after join");
1065                                    }
1066                                }
1067                            }
1068                        }
1069                        ProtocolMessage::DocUpdateFragmentHeader {
1070                            crdt,
1071                            room_id,
1072                            batch_id,
1073                            fragment_count,
1074                            total_size_bytes,
1075                        } => {
1076                            let room = RoomKey {
1077                                crdt,
1078                                room: room_id.clone(),
1079                            };
1080                            if !joined_rooms.contains(&room) {
1081                                let err = ProtocolMessage::UpdateError {
1082                                    crdt,
1083                                    room_id: room.room.clone(),
1084                                    code: protocol::UpdateErrorCode::PermissionDenied,
1085                                    message: "Must join room before sending updates".into(),
1086                                    batch_id: Some(batch_id),
1087                                    app_code: None,
1088                                };
1089                                if let Ok(bytes) = loro_protocol::encode(&err) {
1090                                    let _ = tx.send(Message::Binary(bytes.into()));
1091                                }
1092                                continue;
1093                            }
1094                            // Permission check
1095                            let perm = hub
1096                                .lock()
1097                                .await
1098                                .perms
1099                                .get(&(conn_id, room.clone()))
1100                                .copied();
1101                            if !matches!(perm, Some(Permission::Write)) {
1102                                let err = ProtocolMessage::UpdateError {
1103                                    crdt,
1104                                    room_id: room.room.clone(),
1105                                    code: protocol::UpdateErrorCode::PermissionDenied,
1106                                    message: "Write permission required to update document".into(),
1107                                    batch_id: Some(batch_id),
1108                                    app_code: None,
1109                                };
1110                                if let Ok(bytes) = loro_protocol::encode(&err) {
1111                                    let _ = tx.send(Message::Binary(bytes.into()));
1112                                }
1113                                continue;
1114                            }
1115                            // Bounds checks
1116                            if fragment_count == 0
1117                                || fragment_count > MAX_FRAGMENTS
1118                                || total_size_bytes > MAX_BATCH_BYTES
1119                            {
1120                                let err = ProtocolMessage::UpdateError {
1121                                    crdt,
1122                                    room_id: room.room.clone(),
1123                                    code: protocol::UpdateErrorCode::PayloadTooLarge,
1124                                    message: "Fragmented batch exceeds server limits".into(),
1125                                    batch_id: Some(batch_id),
1126                                    app_code: None,
1127                                };
1128                                if let Ok(bytes) = loro_protocol::encode(&err) {
1129                                    let _ = tx.send(Message::Binary(bytes.into()));
1130                                }
1131                                continue;
1132                            }
1133                            // Initialize batch (guard against hijack by another sender)
1134                            let mut h = hub.lock().await;
1135                            let key = (room.clone(), batch_id);
1136                            if let Some(existing) = h.fragments.get(&key) {
1137                                if existing.from_conn != conn_id {
1138                                    let err = ProtocolMessage::UpdateError {
1139                                        crdt,
1140                                        room_id: room.room.clone(),
1141                                        code: protocol::UpdateErrorCode::InvalidUpdate,
1142                                        message: "Batch ID already in use by another sender".into(),
1143                                        batch_id: Some(batch_id),
1144                                        app_code: None,
1145                                    };
1146                                    if let Ok(bytes) = loro_protocol::encode(&err) {
1147                                        let _ = tx.send(Message::Binary(bytes.into()));
1148                                    }
1149                                    continue;
1150                                }
1151                                // else: duplicate header from same sender -> accept and broadcast as-is
1152                            } else {
1153                                h.start_fragment_batch(
1154                                    &room,
1155                                    conn_id,
1156                                    batch_id,
1157                                    fragment_count,
1158                                    total_size_bytes,
1159                                );
1160                            }
1161                            // Broadcast header as-is
1162                            h.broadcast(&room, conn_id, Message::Binary(data));
1163                        }
1164                        ProtocolMessage::DocUpdateFragment {
1165                            crdt,
1166                            room_id,
1167                            batch_id,
1168                            index,
1169                            fragment,
1170                        } => {
1171                            let room = RoomKey {
1172                                crdt,
1173                                room: room_id.clone(),
1174                            };
1175                            if !joined_rooms.contains(&room) {
1176                                let err = ProtocolMessage::UpdateError {
1177                                    crdt,
1178                                    room_id: room.room.clone(),
1179                                    code: protocol::UpdateErrorCode::PermissionDenied,
1180                                    message: "Must join room before sending updates".into(),
1181                                    batch_id: Some(batch_id),
1182                                    app_code: None,
1183                                };
1184                                if let Ok(bytes) = loro_protocol::encode(&err) {
1185                                    let _ = tx.send(Message::Binary(bytes.into()));
1186                                }
1187                                continue;
1188                            }
1189                            // Validate batch existence and sender binding; also index bounds
1190                            let mut h = hub.lock().await;
1191                            let key = (room.clone(), batch_id);
1192                            if let Some(b) = h.fragments.get(&key) {
1193                                if b.from_conn != conn_id {
1194                                    let err = ProtocolMessage::UpdateError {
1195                                        crdt,
1196                                        room_id: room.room.clone(),
1197                                        code: protocol::UpdateErrorCode::InvalidUpdate,
1198                                        message: "Fragment from wrong sender for batch".into(),
1199                                        batch_id: Some(batch_id),
1200                                        app_code: None,
1201                                    };
1202                                    if let Ok(bytes) = loro_protocol::encode(&err) {
1203                                        let _ = tx.send(Message::Binary(bytes.into()));
1204                                    }
1205                                    // do not broadcast
1206                                    continue;
1207                                }
1208                                if !usize::try_from(index)
1209                                    .ok()
1210                                    .map(|i| i < b.chunks.len())
1211                                    .unwrap_or(false)
1212                                {
1213                                    let err = ProtocolMessage::UpdateError {
1214                                        crdt,
1215                                        room_id: room.room.clone(),
1216                                        code: protocol::UpdateErrorCode::InvalidUpdate,
1217                                        message: "Fragment index out of range".into(),
1218                                        batch_id: Some(batch_id),
1219                                        app_code: None,
1220                                    };
1221                                    if let Ok(bytes) = loro_protocol::encode(&err) {
1222                                        let _ = tx.send(Message::Binary(bytes.into()));
1223                                    }
1224                                    continue;
1225                                }
1226                            } else {
1227                                let err = ProtocolMessage::UpdateError {
1228                                    crdt,
1229                                    room_id: room.room.clone(),
1230                                    code: protocol::UpdateErrorCode::InvalidUpdate,
1231                                    message: "Unknown fragment batch".into(),
1232                                    batch_id: Some(batch_id),
1233                                    app_code: None,
1234                                };
1235                                if let Ok(bytes) = loro_protocol::encode(&err) {
1236                                    let _ = tx.send(Message::Binary(bytes.into()));
1237                                }
1238                                continue;
1239                            }
1240                            // Broadcast this fragment as-is to others (only after validation)
1241                            h.broadcast(&room, conn_id, Message::Binary(data.clone()));
1242                            // Accumulate and possibly finish
1243                            if let Some(buf) =
1244                                h.add_fragment_and_maybe_finish(&room, batch_id, index, fragment)
1245                            {
1246                                // On completion: parse and apply to stored doc state if applicable
1247                                match crdt {
1248                                    CrdtType::Loro
1249                                    | CrdtType::LoroEphemeralStore
1250                                    | CrdtType::LoroEphemeralStorePersisted => {
1251                                        if let Ok(updates) = parse_docupdate_payload(&buf) {
1252                                            let start = std::time::Instant::now();
1253                                            h.apply_updates(&room, &updates);
1254                                            let elapsed_ms = start.elapsed().as_millis();
1255                                            debug!(room=?room.room, updates=%updates.len(), ms=%elapsed_ms, "applied reassembled updates");
1256                                        }
1257                                    }
1258                                    CrdtType::Elo => {
1259                                        // Apply as indexing-only
1260                                        h.apply_updates(&room, &[buf]);
1261                                    }
1262                                    _ => {}
1263                                }
1264                            }
1265                        }
1266                        ProtocolMessage::DocUpdate {
1267                            crdt,
1268                            room_id,
1269                            updates,
1270                        } => {
1271                            let room = RoomKey {
1272                                crdt,
1273                                room: room_id.clone(),
1274                            };
1275                            if !joined_rooms.contains(&room) {
1276                                // Not joined: reject with PermissionDenied
1277                                let err = ProtocolMessage::UpdateError {
1278                                    crdt,
1279                                    room_id: room.room.clone(),
1280                                    code: protocol::UpdateErrorCode::PermissionDenied,
1281                                    message: "Must join room before sending updates".into(),
1282                                    batch_id: None,
1283                                    app_code: None,
1284                                };
1285                                if let Ok(bytes) = loro_protocol::encode(&err) {
1286                                    let _ = tx.send(Message::Binary(bytes.into()));
1287                                }
1288                                warn!(room=?room.room, "update rejected: not joined");
1289                            } else {
1290                                // Check permission
1291                                let perm = hub
1292                                    .lock()
1293                                    .await
1294                                    .perms
1295                                    .get(&(conn_id, room.clone()))
1296                                    .copied();
1297                                if !matches!(perm, Some(Permission::Write)) {
1298                                    let err = ProtocolMessage::UpdateError {
1299                                        crdt,
1300                                        room_id: room.room.clone(),
1301                                        code: protocol::UpdateErrorCode::PermissionDenied,
1302                                        message: "Write permission required to update document"
1303                                            .into(),
1304                                        batch_id: None,
1305                                        app_code: None,
1306                                    };
1307                                    if let Ok(bytes) = loro_protocol::encode(&err) {
1308                                        let _ = tx.send(Message::Binary(bytes.into()));
1309                                    }
1310                                    continue;
1311                                }
1312                                let mut h = hub.lock().await;
1313                                match crdt {
1314                                    CrdtType::Loro
1315                                    | CrdtType::LoroEphemeralStore
1316                                    | CrdtType::LoroEphemeralStorePersisted => {
1317                                        let start = std::time::Instant::now();
1318                                        h.apply_updates(&room, &updates);
1319                                        let elapsed_ms = start.elapsed().as_millis();
1320                                        h.broadcast(&room, conn_id, Message::Binary(data));
1321                                        debug!(room=?room.room, updates=%updates.len(), ms=%elapsed_ms, "applied and broadcast updates");
1322                                    }
1323                                    CrdtType::Elo => {
1324                                        // Index headers only; payload remains opaque to server.
1325                                        h.apply_updates(&room, &updates);
1326                                        h.broadcast(&room, conn_id, Message::Binary(data));
1327                                        debug!(room=?room.room, updates=%updates.len(), "indexed and broadcast ELO updates");
1328                                    }
1329                                    _ => {
1330                                        h.broadcast(&room, conn_id, Message::Binary(data));
1331                                    }
1332                                }
1333                            }
1334                        }
1335                        _ => {
1336                            // For simplicity, ignore other messages in minimal server.
1337                        }
1338                    }
1339                } else {
1340                    // Invalid frame: close with Protocol error, but keep server running
1341                    warn!("invalid protocol frame; closing connection");
1342                    let _ = tx.send(Message::Close(Some(CloseFrame {
1343                        code: CloseCode::Protocol,
1344                        reason: "Protocol error".into(),
1345                    })));
1346                    break;
1347                }
1348            }
1349            Message::Close(frame) => {
1350                let _ = tx.send(Message::Close(frame.clone()));
1351                break;
1352            }
1353            Message::Ping(p) => {
1354                let _ = tx.send(Message::Pong(p));
1355                let _ = tx.send(Message::Text("pong".into()));
1356            }
1357            _ => {}
1358        }
1359    }
1360
1361    // cleanup
1362    {
1363        let mut h = hub.lock().await;
1364        h.leave_all(conn_id);
1365    }
1366    // drop tx to stop writer
1367    drop(tx);
1368    let _ = sink_task.await;
1369    debug!(conn_id, "connection closed and cleaned up");
1370    Ok(())
1371}
1372
1373fn parse_docupdate_payload(buf: &[u8]) -> Result<Vec<Vec<u8>>, String> {
1374    let mut r = BytesReader::new(buf);
1375    let n = usize::try_from(r.read_uleb128()?).map_err(|_| "length too large".to_string())?;
1376    let mut out: Vec<Vec<u8>> = Vec::with_capacity(n);
1377    for _ in 0..n {
1378        let b = r.read_var_bytes()?.to_vec();
1379        out.push(b);
1380    }
1381    if r.remaining() != 0 {
1382        return Err("trailing bytes".into());
1383    }
1384    Ok(out)
1385}