use futures_util::{SinkExt, StreamExt};
use std::{
    collections::{HashMap, HashSet},
    future::Future,
    hash::{Hash, Hasher},
    pin::Pin,
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc,
    },
    time::Duration,
};
use tokio::{
    net::{TcpListener, TcpStream},
    sync::mpsc,
};
use tokio_tungstenite::accept_hdr_async;
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame;
use tokio_tungstenite::tungstenite::{self, Message};

use loro::awareness::EphemeralStore;
use loro::{ExportMode, LoroDoc};
pub use loro_protocol as protocol;
use protocol::{try_decode, BytesReader, CrdtType, JoinErrorCode, Permission, ProtocolMessage};
use tracing::{debug, error, info, warn};

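/// Limits for fragmented update batches; headers exceeding either bound are
/// rejected with `UpdateErrorCode::PayloadTooLarge`.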
const MAX_FRAGMENTS: u64 = 4096;
const MAX_BATCH_BYTES: u64 = 64 * 1024 * 1024;

#[derive(Clone, Debug, PartialEq, Eq)]
struct RoomKey {
    crdt: CrdtType,
    room: String,
}
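// `Hash` is implemented by hand over a stable per-variant tag plus the room
// name, so `RoomKey` can be a map key without requiring `CrdtType: Hash`.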
impl Hash for RoomKey {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let tag = match self.crdt {
            CrdtType::Loro => 0u8,
            CrdtType::LoroEphemeralStore => 1,
            CrdtType::LoroEphemeralStorePersisted => 2,
            CrdtType::Yjs => 3,
            CrdtType::YjsAwareness => 4,
            CrdtType::Elo => 5,
        };
        tag.hash(state);
        self.room.hash(state);
    }
}

type Sender = mpsc::UnboundedSender<Message>;

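/// Result of `on_load_document`: an optional stored snapshot plus an optional
/// caller-defined context that is handed back to `on_save_document`.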
pub struct LoadedDoc<DocCtx> {
    pub snapshot: Option<Vec<u8>>,
    pub ctx: Option<DocCtx>,
}

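/// Identifies the document an `on_load_document` callback should fetch.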
pub struct LoadDocArgs {
    pub workspace: String,
    pub room: String,
    pub crdt: CrdtType,
}

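/// Payload handed to `on_save_document`; `ctx` is whatever the loader returned.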
pub struct SaveDocArgs<DocCtx> {
    pub workspace: String,
    pub room: String,
    pub crdt: CrdtType,
    pub data: Vec<u8>,
    pub ctx: Option<DocCtx>,
}

type LoadFuture<DocCtx> =
    Pin<Box<dyn Future<Output = Result<LoadedDoc<DocCtx>, String>> + Send + 'static>>;
type SaveFuture = Pin<Box<dyn Future<Output = Result<(), String>> + Send + 'static>>;
type LoadFn<DocCtx> = Arc<dyn Fn(LoadDocArgs) -> LoadFuture<DocCtx> + Send + Sync>;
type SaveFn<DocCtx> = Arc<dyn Fn(SaveDocArgs<DocCtx>) -> SaveFuture + Send + Sync>;
type AuthFuture =
    Pin<Box<dyn Future<Output = Result<Option<Permission>, String>> + Send + 'static>>;
type AuthFn = Arc<dyn Fn(String, CrdtType, Vec<u8>) -> AuthFuture + Send + Sync>;

type HandshakeAuthFn = dyn Fn(&str, Option<&str>) -> bool + Send + Sync;

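/// Per-server configuration. `DocCtx` is an app-defined context threaded from
/// load to save.
///
/// A minimal wiring sketch (the storage closure body and the token rule below
/// are illustrative placeholders, not part of this crate):
///
/// ```ignore
/// use std::sync::Arc;
///
/// let config: ServerConfig<()> = ServerConfig {
///     on_load_document: Some(Arc::new(|_args: LoadDocArgs| {
///         Box::pin(async move {
///             // Fetch a snapshot from your own storage here.
///             Ok(LoadedDoc { snapshot: None, ctx: None })
///         })
///     })),
///     save_interval_ms: Some(5_000),
///     handshake_auth: Some(Arc::new(|workspace, token| {
///         // Accept any non-empty token; replace with a real check.
///         !workspace.is_empty() && token.map_or(false, |t| !t.is_empty())
///     })),
///     ..ServerConfig::default()
/// };
/// ```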
#[derive(Clone)]
pub struct ServerConfig<DocCtx = ()> {
    pub on_load_document: Option<LoadFn<DocCtx>>,
    pub on_save_document: Option<SaveFn<DocCtx>>,
    pub save_interval_ms: Option<u64>,
    pub default_permission: Permission,
    pub authenticate: Option<AuthFn>,
    pub handshake_auth: Option<Arc<HandshakeAuthFn>>,
}

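/// Abstraction over the CRDT kinds a room can host. Every method has a no-op
/// default, so each implementation opts in to versioning, backfill, and
/// persistence as appropriate.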
trait CrdtDoc: Send {
    fn get_version(&self) -> Vec<u8> {
        Vec::new()
    }
    fn compute_backfill(&self, _client_version: &[u8]) -> Vec<Vec<u8>> {
        Vec::new()
    }
    fn apply_updates(&mut self, _updates: &[Vec<u8>]) -> Result<(), String> {
        Ok(())
    }
    fn should_persist(&self) -> bool {
        false
    }
    fn export_snapshot(&self) -> Option<Vec<u8>> {
        None
    }
    fn import_snapshot(&mut self, _data: &[u8]) {}
    fn allow_backfill_when_no_other_clients(&self) -> bool {
        false
    }
    fn remove_when_last_subscriber_leaves(&self) -> bool {
        false
    }
}

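// A server-side Loro replica: incoming updates are imported eagerly so the
// whole room can be persisted as a single snapshot.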
struct LoroRoomDoc {
    doc: LoroDoc,
}
impl LoroRoomDoc {
    fn new() -> Self {
        Self {
            doc: LoroDoc::new(),
        }
    }
}
impl CrdtDoc for LoroRoomDoc {
    fn apply_updates(&mut self, updates: &[Vec<u8>]) -> Result<(), String> {
        for u in updates {
            let _ = self.doc.import(u);
        }
        Ok(())
    }
    fn should_persist(&self) -> bool {
        true
    }
    fn export_snapshot(&self) -> Option<Vec<u8>> {
        self.doc.export(ExportMode::Snapshot).ok()
    }
    fn import_snapshot(&mut self, data: &[u8]) {
        let _ = self.doc.import(data);
    }
}

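// Presence/awareness state: backfilled to joiners and relayed, never
// persisted; the doc is dropped when its last subscriber leaves.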
struct EphemeralRoomDoc {
    store: EphemeralStore,
}
impl EphemeralRoomDoc {
    fn new(timeout_ms: i64) -> Self {
        Self {
            store: EphemeralStore::new(timeout_ms),
        }
    }
}
impl CrdtDoc for EphemeralRoomDoc {
    fn compute_backfill(&self, _client_version: &[u8]) -> Vec<Vec<u8>> {
        let data = self.store.encode_all();
        if data.is_empty() {
            Vec::new()
        } else {
            vec![data]
        }
    }
    fn apply_updates(&mut self, updates: &[Vec<u8>]) -> Result<(), String> {
        for u in updates {
            if !u.is_empty() {
                self.store.apply(u);
            }
        }
        Ok(())
    }
    fn remove_when_last_subscriber_leaves(&self) -> bool {
        true
    }
}

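// Like `EphemeralRoomDoc`, but also snapshotted through the save/load
// callbacks so the latest state survives server restarts.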
struct PersistentEphemeralRoomDoc {
    store: EphemeralStore,
    timeout_ms: i64,
}
impl PersistentEphemeralRoomDoc {
    fn new(timeout_ms: i64) -> Self {
        Self {
            store: EphemeralStore::new(timeout_ms),
            timeout_ms,
        }
    }
}
impl CrdtDoc for PersistentEphemeralRoomDoc {
    fn compute_backfill(&self, _client_version: &[u8]) -> Vec<Vec<u8>> {
        let data = self.store.encode_all();
        if data.is_empty() {
            Vec::new()
        } else {
            vec![data]
        }
    }
    fn apply_updates(&mut self, updates: &[Vec<u8>]) -> Result<(), String> {
        for u in updates {
            if !u.is_empty() {
                self.store.apply(u);
            }
        }
        Ok(())
    }
    fn should_persist(&self) -> bool {
        true
    }
    fn export_snapshot(&self) -> Option<Vec<u8>> {
        Some(self.store.encode_all())
    }
    fn import_snapshot(&mut self, data: &[u8]) {
        self.store = EphemeralStore::new(self.timeout_ms);
        if !data.is_empty() {
            self.store.apply(data);
        }
    }
    fn allow_backfill_when_no_other_clients(&self) -> bool {
        true
    }
}

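// ELO rooms: the server keeps a per-peer index of opaque delta-span records so
// it can answer backfill requests from a client's version vector without ever
// decoding the payloads.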
struct EloDeltaSpanIndexEntry {
    start: u64,
    end: u64,
    key_id: String,
    record: Vec<u8>,
}

struct EloRoomDoc {
    spans_by_peer: HashMap<String, Vec<EloDeltaSpanIndexEntry>>,
}
impl EloRoomDoc {
    fn new() -> Self {
        Self {
            spans_by_peer: HashMap::new(),
        }
    }

    fn peer_key_from_bytes(bytes: &[u8]) -> String {
        match std::str::from_utf8(bytes) {
            Ok(s) => s.to_string(),
            Err(_) => {
                // Non-UTF-8 peer IDs are keyed by their lowercase hex encoding.
                let mut out = String::with_capacity(bytes.len() * 2);
                for b in bytes {
                    use std::fmt::Write as _;
                    let _ = write!(&mut out, "{:02x}", b);
                }
                out
            }
        }
    }

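    // Version vectors on the wire: a ULEB128 entry count followed by
    // (length-prefixed peer id, ULEB128 counter) pairs.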
    fn decode_version_vector(&self, buf: &[u8]) -> Option<HashMap<String, u64>> {
        let mut r = BytesReader::new(buf);
        let count = usize::try_from(r.read_uleb128().ok()?).ok()?;
        let mut map: HashMap<String, u64> = HashMap::with_capacity(count);
        for _ in 0..count {
            let peer_bytes = r.read_var_bytes().ok()?;
            let ctr = r.read_uleb128().ok()?;
            map.insert(Self::peer_key_from_bytes(peer_bytes), ctr);
        }
        Some(map)
    }

    fn encode_current_vv(&self) -> Vec<u8> {
        use loro_protocol::bytes::BytesWriter;
        // Only peers with purely numeric (ASCII-digit) keys are included.
        let mut entries: Vec<(String, u64)> = Vec::new();
        for (peer, spans) in self.spans_by_peer.iter() {
            if !peer.as_bytes().iter().all(|b| b.is_ascii_digit()) {
                continue;
            }
            let mut max_end = 0u64;
            for s in spans.iter() {
                if s.end > max_end {
                    max_end = s.end;
                }
            }
            if max_end > 0 {
                entries.push((peer.clone(), max_end));
            }
        }
        let mut w = BytesWriter::new();
        w.push_uleb128(entries.len() as u64);
        for (peer, ctr) in entries.iter() {
            w.push_var_bytes(peer.as_bytes());
            w.push_uleb128(*ctr);
        }
        w.finalize()
    }
}
impl CrdtDoc for EloRoomDoc {
    fn get_version(&self) -> Vec<u8> {
        if self.spans_by_peer.is_empty() {
            return Vec::new();
        }
        self.encode_current_vv()
    }
    fn compute_backfill(&self, client_version: &[u8]) -> Vec<Vec<u8>> {
        let known = self
            .decode_version_vector(client_version)
            .unwrap_or_default();
        let mut records: Vec<Vec<u8>> = Vec::new();
        for (peer, spans) in self.spans_by_peer.iter() {
            let k = known.get(peer).copied().unwrap_or(0);
            for e in spans {
                if e.end > k {
                    records.push(e.record.clone());
                }
            }
        }
        if records.is_empty() {
            return Vec::new();
        }
        // Re-encode the missing records as a single ELO container.
        let mut w = loro_protocol::bytes::BytesWriter::new();
        w.push_uleb128(records.len() as u64);
        for rec in records.iter() {
            w.push_var_bytes(rec);
        }
        vec![w.finalize()]
    }
    fn apply_updates(&mut self, updates: &[Vec<u8>]) -> Result<(), String> {
        use loro_protocol::elo::{
            decode_elo_container, parse_elo_record_header, EloHeader, EloRecordKind,
        };
        for u in updates {
            let records = decode_elo_container(u.as_slice())?;
            for rec in records {
                let parsed = parse_elo_record_header(rec)?;
                match parsed.kind {
                    EloRecordKind::DeltaSpan => {
                        if let EloHeader::Delta(h) = parsed.header {
                            if h.end <= h.start {
                                return Err("invalid ELO delta span: end must be > start".into());
                            }
                            if h.iv.len() != 12 {
                                return Err("invalid ELO delta span: IV must be 12 bytes".into());
                            }
                            let peer = Self::peer_key_from_bytes(&h.peer_id);
                            let list = self.spans_by_peer.entry(peer).or_default();
                            // Insert the new span in start order and drop any
                            // existing span it fully covers.
                            let mut kept: Vec<EloDeltaSpanIndexEntry> =
                                Vec::with_capacity(list.len() + 1);
                            let mut inserted = false;
                            for e in list.iter() {
                                if !inserted && e.start >= h.start {
                                    kept.push(EloDeltaSpanIndexEntry {
                                        start: h.start,
                                        end: h.end,
                                        key_id: h.key_id.clone(),
                                        record: rec.to_vec(),
                                    });
                                    inserted = true;
                                }
                                let covered = e.start >= h.start && e.end <= h.end;
                                if !covered {
                                    kept.push(EloDeltaSpanIndexEntry {
                                        start: e.start,
                                        end: e.end,
                                        key_id: e.key_id.clone(),
                                        record: e.record.clone(),
                                    });
                                }
                            }
                            if !inserted {
                                kept.push(EloDeltaSpanIndexEntry {
                                    start: h.start,
                                    end: h.end,
                                    key_id: h.key_id.clone(),
                                    record: rec.to_vec(),
                                });
                            }
                            *list = kept;
                        }
                    }
                    // Snapshot records are relayed but not indexed server-side.
                    EloRecordKind::Snapshot => {}
                }
            }
        }
        Ok(())
    }
    fn allow_backfill_when_no_other_clients(&self) -> bool {
        true
    }
}
impl<DocCtx> Default for ServerConfig<DocCtx> {
    fn default() -> Self {
        Self {
            on_load_document: None,
            on_save_document: None,
            save_interval_ms: None,
            default_permission: Permission::Write,
            authenticate: None,
            handshake_auth: None,
        }
    }
}

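/// A room's live CRDT instance plus persistence bookkeeping: `dirty` marks
/// unsaved changes and `ctx` is the context returned by the loader.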
struct RoomDocState<DocCtx> {
    doc: Box<dyn CrdtDoc>,
    dirty: bool,
    ctx: Option<DocCtx>,
}

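/// All mutable state for one workspace: per-room subscriber lists, loaded
/// documents, per-connection permissions, and in-flight fragment batches.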
struct Hub<DocCtx> {
    subs: HashMap<RoomKey, Vec<(u64, Sender)>>,
    docs: HashMap<RoomKey, RoomDocState<DocCtx>>,
    config: ServerConfig<DocCtx>,
    perms: HashMap<(u64, RoomKey), Permission>,
    workspace: String,
    fragments: HashMap<(RoomKey, protocol::BatchId), FragmentBatch>,
}

impl<DocCtx> Hub<DocCtx>
where
    DocCtx: Clone + Send + Sync + 'static,
{
    fn new(config: ServerConfig<DocCtx>, workspace: String) -> Self {
        Self {
            subs: HashMap::new(),
            docs: HashMap::new(),
            config,
            perms: HashMap::new(),
            workspace,
            fragments: HashMap::new(),
        }
    }

    const EPHEMERAL_TIMEOUT_MS: i64 = 60_000;

    fn join(&mut self, conn_id: u64, room: RoomKey, tx: &Sender) {
        let entry = self.subs.entry(room).or_default();
        if !entry.iter().any(|(id, _)| *id == conn_id) {
            entry.push((conn_id, tx.clone()));
        }
    }

    fn leave_all(&mut self, conn_id: u64) {
        let mut emptied: Vec<RoomKey> = Vec::new();
        for (k, vec) in self.subs.iter_mut() {
            vec.retain(|(id, _)| *id != conn_id);
            if vec.is_empty() {
                emptied.push(k.clone());
            }
        }
        for k in &emptied {
            let _ = self.subs.remove(k);
        }

        self.perms.retain(|(id, _), _| *id != conn_id);

        for k in emptied.clone() {
            if let Some(state) = self.docs.get(&k) {
                if state.doc.remove_when_last_subscriber_leaves() {
                    self.docs.remove(&k);
                    debug!(room=?k.room, "cleaned up ephemeral doc after last subscriber left");
                }
            }
        }

        if !self.fragments.is_empty() {
            let emptied_set: HashSet<RoomKey> = emptied.into_iter().collect();
            self.fragments
                .retain(|(rk, _), b| b.from_conn != conn_id && !emptied_set.contains(rk));
        }
    }

    fn broadcast(&mut self, room: &RoomKey, from: u64, msg: Message) {
        if let Some(list) = self.subs.get_mut(room) {
            let mut dead: HashSet<u64> = HashSet::new();
            for (id, tx) in list.iter() {
                if *id == from {
                    continue;
                }
                if tx.send(msg.clone()).is_err() {
                    dead.insert(*id);
                }
            }
            if !dead.is_empty() {
                list.retain(|(id, _)| !dead.contains(id));
                debug!(room=?room.room, removed=%dead.len(), "removed dead subscribers");
            }
        }
    }

    async fn ensure_room_loaded(&mut self, room: &RoomKey) {
        if self.docs.contains_key(room) {
            return;
        }
        match room.crdt {
            CrdtType::Loro => {
                let mut d = LoroRoomDoc::new();
                let mut ctx = None;
                if let Some(loader) = &self.config.on_load_document {
                    let args = LoadDocArgs {
                        workspace: self.workspace.clone(),
                        room: room.room.clone(),
                        crdt: room.crdt,
                    };
                    match (loader)(args).await {
                        Ok(loaded) => {
                            if let Some(bytes) = loaded.snapshot {
                                d.import_snapshot(&bytes);
                            }
                            ctx = loaded.ctx;
                        }
                        Err(e) => {
                            warn!(room=?room.room, %e, "load document failed");
                        }
                    }
                }
                self.docs.insert(
                    room.clone(),
                    RoomDocState {
                        doc: Box::new(d),
                        dirty: false,
                        ctx,
                    },
                );
            }
            CrdtType::LoroEphemeralStore => {
                let d = EphemeralRoomDoc::new(Self::EPHEMERAL_TIMEOUT_MS);
                self.docs.insert(
                    room.clone(),
                    RoomDocState {
                        doc: Box::new(d),
                        dirty: false,
                        ctx: None,
                    },
                );
            }
            CrdtType::LoroEphemeralStorePersisted => {
                let mut d = PersistentEphemeralRoomDoc::new(Self::EPHEMERAL_TIMEOUT_MS);
                let mut ctx = None;
                if let Some(loader) = &self.config.on_load_document {
                    let args = LoadDocArgs {
                        workspace: self.workspace.clone(),
                        room: room.room.clone(),
                        crdt: room.crdt,
                    };
                    match (loader)(args).await {
                        Ok(loaded) => {
                            if let Some(bytes) = loaded.snapshot {
                                d.import_snapshot(&bytes);
                            }
                            ctx = loaded.ctx;
                        }
                        Err(e) => {
                            warn!(room=?room.room, %e, "load persisted ephemeral store failed");
                        }
                    }
                }
                self.docs.insert(
                    room.clone(),
                    RoomDocState {
                        doc: Box::new(d),
                        dirty: false,
                        ctx,
                    },
                );
            }
            CrdtType::Elo => {
                let d = EloRoomDoc::new();
                self.docs.insert(
                    room.clone(),
                    RoomDocState {
                        doc: Box::new(d),
                        dirty: false,
                        ctx: None,
                    },
                );
            }
            // Yjs rooms are relay-only: no server-side document is kept.
            _ => {}
        }
    }

    fn current_version_bytes(&self, room: &RoomKey) -> Vec<u8> {
        match self.docs.get(room) {
            Some(state) => state.doc.get_version(),
            None => Vec::new(),
        }
    }

    fn apply_updates(&mut self, room: &RoomKey, updates: &[Vec<u8>]) {
        if let Some(state) = self.docs.get_mut(room) {
            if let Err(e) = state.doc.apply_updates(updates) {
                warn!(room=?room.room, %e, "apply_updates failed");
            } else if state.doc.should_persist() {
                state.dirty = true;
            }
        }
    }

    fn snapshot_bytes(&self, room: &RoomKey) -> Option<Vec<u8>> {
        self.docs
            .get(room)
            .and_then(|s| s.doc.export_snapshot())
            .filter(|data| !data.is_empty())
    }
}

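/// Reassembly buffer for one fragmented update batch. `chunks` is indexed by
/// fragment index, so fragments may arrive out of order.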
struct FragmentBatch {
    from_conn: u64,
    fragment_count: u64,
    total_size: u64,
    received: u64,
    chunks: Vec<Option<Vec<u8>>>,
}

impl<DocCtx> Hub<DocCtx>
where
    DocCtx: Clone + Send + Sync + 'static,
{
    fn start_fragment_batch(
        &mut self,
        room: &RoomKey,
        from_conn: u64,
        batch_id: protocol::BatchId,
        fragment_count: u64,
        total_size: u64,
    ) {
        let key = (room.clone(), batch_id);
        let chunks_len = usize::try_from(fragment_count).unwrap_or(0);
        let batch = FragmentBatch {
            from_conn,
            fragment_count,
            total_size,
            received: 0,
            chunks: vec![None; chunks_len],
        };
        self.fragments.insert(key, batch);
    }

    /// Store one fragment; when the batch is complete, drop the buffer and
    /// return the reassembled payload.
    fn add_fragment_and_maybe_finish(
        &mut self,
        room: &RoomKey,
        batch_id: protocol::BatchId,
        index: u64,
        fragment: Vec<u8>,
    ) -> Option<Vec<u8>> {
        let key = (room.clone(), batch_id);
        if let Some(b) = self.fragments.get_mut(&key) {
            let idx = match usize::try_from(index) {
                Ok(i) => i,
                Err(_) => return None,
            };
            if idx >= b.chunks.len() {
                return None;
            }
            if b.chunks[idx].is_none() {
                b.chunks[idx] = Some(fragment);
                b.received += 1;
            }
            if b.received == b.fragment_count {
                let mut out = Vec::with_capacity(b.total_size as usize);
                for ch in b.chunks.iter() {
                    if let Some(bytes) = ch.as_ref() {
                        out.extend_from_slice(bytes);
                    }
                }
                self.fragments.remove(&key);
                return Some(out);
            }
        }
        None
    }
}

static NEXT_ID: AtomicU64 = AtomicU64::new(1);

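/// Lazily creates one `Hub` per workspace and, when saving is configured,
/// spawns that hub's periodic snapshot-save task.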
struct HubRegistry<DocCtx> {
    config: ServerConfig<DocCtx>,
    hubs: tokio::sync::Mutex<HashMap<String, Arc<tokio::sync::Mutex<Hub<DocCtx>>>>>,
}

impl<DocCtx> HubRegistry<DocCtx>
where
    DocCtx: Clone + Send + Sync + 'static,
{
    fn new(config: ServerConfig<DocCtx>) -> Self {
        Self {
            config,
            hubs: tokio::sync::Mutex::new(HashMap::new()),
        }
    }

    async fn get_or_create(&self, workspace: &str) -> Arc<tokio::sync::Mutex<Hub<DocCtx>>> {
        let mut map = self.hubs.lock().await;
        if let Some(h) = map.get(workspace) {
            return h.clone();
        }
        let hub = Arc::new(tokio::sync::Mutex::new(Hub::new(
            self.config.clone(),
            workspace.to_string(),
        )));
        if let (Some(ms), Some(saver)) = (
            self.config.save_interval_ms,
            self.config.on_save_document.clone(),
        ) {
            let hub_clone = hub.clone();
            tokio::spawn(async move {
                let mut interval = tokio::time::interval(Duration::from_millis(ms));
                // Periodic save loop: flush every dirty, persistable room.
                loop {
                    interval.tick().await;
                    let mut guard = hub_clone.lock().await;
                    let ws = guard.workspace.clone();
                    let rooms: Vec<RoomKey> = guard.docs.keys().cloned().collect();
                    for room in rooms {
                        if let Some(state) = guard.docs.get_mut(&room) {
                            if state.dirty && state.doc.should_persist() {
                                let start = std::time::Instant::now();
                                if let Some(snapshot) = state.doc.export_snapshot() {
                                    let room_str = room.room.clone();
                                    let ctx = state.ctx.clone();
                                    let args = SaveDocArgs {
                                        workspace: ws.clone(),
                                        room: room_str.clone(),
                                        crdt: room.crdt,
                                        data: snapshot,
                                        ctx,
                                    };
                                    match (saver)(args).await {
                                        Ok(()) => {
                                            state.dirty = false;
                                            let elapsed = start.elapsed();
                                            debug!(workspace=%ws, room=%room_str, ms=%elapsed.as_millis(), "snapshot saved");
                                        }
                                        Err(e) => {
                                            warn!(workspace=%ws, room=%room_str, %e, "snapshot save failed");
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            });
        }
        map.insert(workspace.to_string(), hub.clone());
        hub
    }
}

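/// Bind `addr` and serve connections with the default configuration.
///
/// A minimal usage sketch (the address is illustrative):
///
/// ```ignore
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
///     serve("127.0.0.1:9000").await
/// }
/// ```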
pub async fn serve(addr: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    info!(%addr, "binding TCP listener");
    let listener = TcpListener::bind(addr).await?;
    serve_incoming_with_config::<()>(listener, ServerConfig::default()).await
}

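/// Serve an already-bound listener with the default configuration.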
pub async fn serve_incoming(
    listener: TcpListener,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    serve_incoming_with_config::<()>(listener, ServerConfig::default()).await
}

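/// Serve an already-bound listener with a custom `ServerConfig`. A sketch of
/// pairing this with the config example above (the address is illustrative):
///
/// ```ignore
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:9000").await?;
/// serve_incoming_with_config(listener, config).await?;
/// ```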
pub async fn serve_incoming_with_config<DocCtx>(
    listener: TcpListener,
    config: ServerConfig<DocCtx>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
    DocCtx: Clone + Send + Sync + 'static,
{
    let registry = Arc::new(HubRegistry::new(config.clone()));

    loop {
        match listener.accept().await {
            Ok((stream, peer)) => {
                debug!(remote=%peer, "accepted TCP connection");
                let registry = registry.clone();
                tokio::spawn(async move {
                    if let Err(e) = handle_conn(stream, registry).await {
                        warn!(%e, "connection task ended with error");
                    }
                });
            }
            Err(e) => {
                error!(%e, "accept failed; continuing");
                continue;
            }
        }
    }
}

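// Per-connection task: run the WebSocket handshake (with optional auth),
// spawn a writer task that drains the outbound channel, then loop over
// incoming frames routing protocol messages until the socket closes.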
async fn handle_conn<DocCtx>(
    stream: TcpStream,
    registry: Arc<HubRegistry<DocCtx>>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
    DocCtx: Clone + Send + Sync + 'static,
{
    let handshake_auth = registry.config.handshake_auth.clone();
    // The workspace ID is parsed inside the handshake callback, so it is
    // passed out through this shared slot.
    let workspace_holder: Arc<std::sync::Mutex<Option<String>>> =
        Arc::new(std::sync::Mutex::new(None));
    let workspace_holder_c = workspace_holder.clone();

    let ws = accept_hdr_async(
        stream,
        move |req: &tungstenite::handshake::server::Request,
              resp: tungstenite::handshake::server::Response| {
            if let Some(check) = &handshake_auth {
                let uri = req.uri();
                let path = uri.path();
                // The first path segment names the workspace: `/{workspace}/...`.
                let mut workspace_id = "";
                if let Some(rest) = path.strip_prefix('/') {
                    if !rest.is_empty() {
                        workspace_id = rest.split('/').next().unwrap_or("");
                    }
                }
                if let Ok(mut guard) = workspace_holder_c.lock() {
                    *guard = Some(workspace_id.to_string());
                }

                // The auth token, if any, arrives as a `token` query parameter.
                let token = uri.query().and_then(|q| {
                    for pair in q.split('&') {
                        let mut it = pair.splitn(2, '=');
                        let k = it.next().unwrap_or("");
                        let v = it.next();
                        if k == "token" {
                            return Some(v.unwrap_or(""));
                        }
                    }
                    None
                });

                let allowed = (check)(workspace_id, token);
                if !allowed {
                    warn!(workspace=%workspace_id, token=?token, "handshake auth denied");
                    let builder = tungstenite::http::Response::builder()
                        .status(tungstenite::http::StatusCode::UNAUTHORIZED);
                    let response = builder
                        .body(Some("Unauthorized".to_string()))
                        .unwrap_or_else(|e| {
                            warn!(?e, "failed to build unauthorized response");
                            let mut fallback =
                                tungstenite::http::Response::new(Some("Unauthorized".to_string()));
                            *fallback.status_mut() = tungstenite::http::StatusCode::UNAUTHORIZED;
                            fallback
                        });
                    return Err(response);
                }
                debug!(workspace=%workspace_id, token=?token, "handshake auth accepted");
            }
            Ok(resp)
        },
    )
    .await?;

    let workspace_id = workspace_holder
        .lock()
        .ok()
        .and_then(|g| g.clone())
        .unwrap_or_default();
    let hub = registry.get_or_create(&workspace_id).await;

    let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
    let (mut sink, mut stream) = ws.split();
    // Writer task: all outbound frames funnel through `tx`, so the read loop
    // and broadcasts never block on the socket.
    let sink_task = tokio::spawn(async move {
        while let Some(msg) = rx.recv().await {
            if sink.send(msg).await.is_err() {
                debug!("sink send error; writer task exiting");
                break;
            }
        }
    });

    let conn_id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
    let mut joined_rooms: HashSet<RoomKey> = HashSet::new();

    while let Some(msg) = stream.next().await {
        match msg? {
            Message::Text(txt) => {
                if txt == "ping" {
                    let _ = tx.send(Message::Text("pong".into()));
                }
            }
            Message::Binary(data) => {
                if let Some(proto) = try_decode(data.as_ref()) {
                    match proto {
                        ProtocolMessage::JoinRequest {
                            crdt,
                            room_id,
                            auth,
                            version,
                        } => {
                            let room = RoomKey {
                                crdt,
                                room: room_id.clone(),
                            };
                            let mut h = hub.lock().await;
                            h.ensure_room_loaded(&room).await;
                            let mut permission = h.config.default_permission;
                            if let Some(auth_fn) = &h.config.authenticate {
                                let room_str = room.room.clone();
                                match (auth_fn)(room_str, room.crdt, auth.clone()).await {
                                    Ok(Some(p)) => {
                                        permission = p;
                                    }
                                    Ok(None) => {
                                        let err = ProtocolMessage::JoinError {
                                            crdt,
                                            room_id: room.room.clone(),
                                            code: JoinErrorCode::AuthFailed,
                                            message: "Authentication failed".into(),
                                            receiver_version: None,
                                            app_code: None,
                                        };
                                        if let Ok(bytes) = loro_protocol::encode(&err) {
                                            let _ = tx.send(Message::Binary(bytes.into()));
                                        }
                                        warn!(room=?room.room, "join denied by authenticate() returning None");
                                        continue;
                                    }
                                    Err(e) => {
                                        let err = ProtocolMessage::JoinError {
                                            crdt,
                                            room_id: room.room.clone(),
                                            code: JoinErrorCode::Unknown,
                                            message: e,
                                            receiver_version: None,
                                            app_code: None,
                                        };
                                        if let Ok(bytes) = loro_protocol::encode(&err) {
                                            let _ = tx.send(Message::Binary(bytes.into()));
                                        }
                                        warn!(room=?room.room, "join denied due to authenticate() error");
                                        continue;
                                    }
                                }
                            }
                            h.join(conn_id, room.clone(), &tx);
                            h.perms.insert((conn_id, room.clone()), permission);
                            joined_rooms.insert(room.clone());
                            info!(workspace=%h.workspace, room=?room.room, ?permission, "join ok");
                            let current_version = h.current_version_bytes(&room);
                            let ok = ProtocolMessage::JoinResponseOk {
                                crdt,
                                room_id: room.room.clone(),
                                permission,
                                version: current_version,
                                extra: Some(Vec::new()),
                            };
                            if let Ok(bytes) = loro_protocol::encode(&ok) {
                                let _ = tx.send(Message::Binary(bytes.into()));
                            }
                            // Prefer a full snapshot; otherwise fall back to
                            // version-based backfill when other peers are
                            // present (or the doc type allows backfill to a
                            // lone client).
                            if let Some(snap) = h.snapshot_bytes(&room) {
                                let du = ProtocolMessage::DocUpdate {
                                    crdt,
                                    room_id: room.room.clone(),
                                    updates: vec![snap],
                                };
                                if let Ok(bytes) = loro_protocol::encode(&du) {
                                    let _ = tx.send(Message::Binary(bytes.into()));
                                    debug!(room=?room.room, "sent initial snapshot after join");
                                }
                            } else {
                                let others_in_room =
                                    h.subs.get(&room).map(|v| v.len()).unwrap_or(0) > 1;
                                let allow_when_empty = h
                                    .docs
                                    .get(&room)
                                    .map(|s| s.doc.allow_backfill_when_no_other_clients())
                                    .unwrap_or(false);
                                if others_in_room || allow_when_empty {
                                    let backfill = h
                                        .docs
                                        .get(&room)
                                        .map(|s| s.doc.compute_backfill(&version))
                                        .unwrap_or_default();
                                    let backfill_cnt = backfill.len();
                                    for u in backfill {
                                        let du = ProtocolMessage::DocUpdate {
                                            crdt,
                                            room_id: room.room.clone(),
                                            updates: vec![u],
                                        };
                                        if let Ok(bytes) = loro_protocol::encode(&du) {
                                            let _ = tx.send(Message::Binary(bytes.into()));
                                        }
                                    }
                                    if backfill_cnt > 0 {
                                        debug!(room=?room.room, cnt=%backfill_cnt, "sent backfill after join");
                                    }
                                }
                            }
                        }
                        ProtocolMessage::DocUpdateFragmentHeader {
                            crdt,
                            room_id,
                            batch_id,
                            fragment_count,
                            total_size_bytes,
                        } => {
                            let room = RoomKey {
                                crdt,
                                room: room_id.clone(),
                            };
                            if !joined_rooms.contains(&room) {
                                let err = ProtocolMessage::UpdateError {
                                    crdt,
                                    room_id: room.room.clone(),
                                    code: protocol::UpdateErrorCode::PermissionDenied,
                                    message: "Must join room before sending updates".into(),
                                    batch_id: Some(batch_id),
                                    app_code: None,
                                };
                                if let Ok(bytes) = loro_protocol::encode(&err) {
                                    let _ = tx.send(Message::Binary(bytes.into()));
                                }
                                continue;
                            }
                            let perm = hub
                                .lock()
                                .await
                                .perms
                                .get(&(conn_id, room.clone()))
                                .copied();
                            if !matches!(perm, Some(Permission::Write)) {
                                let err = ProtocolMessage::UpdateError {
                                    crdt,
                                    room_id: room.room.clone(),
                                    code: protocol::UpdateErrorCode::PermissionDenied,
                                    message: "Write permission required to update document".into(),
                                    batch_id: Some(batch_id),
                                    app_code: None,
                                };
                                if let Ok(bytes) = loro_protocol::encode(&err) {
                                    let _ = tx.send(Message::Binary(bytes.into()));
                                }
                                continue;
                            }
                            if fragment_count == 0
                                || fragment_count > MAX_FRAGMENTS
                                || total_size_bytes > MAX_BATCH_BYTES
                            {
                                let err = ProtocolMessage::UpdateError {
                                    crdt,
                                    room_id: room.room.clone(),
                                    code: protocol::UpdateErrorCode::PayloadTooLarge,
                                    message: "Fragmented batch exceeds server limits".into(),
                                    batch_id: Some(batch_id),
                                    app_code: None,
                                };
                                if let Ok(bytes) = loro_protocol::encode(&err) {
                                    let _ = tx.send(Message::Binary(bytes.into()));
                                }
                                continue;
                            }
                            let mut h = hub.lock().await;
                            let key = (room.clone(), batch_id);
                            if let Some(existing) = h.fragments.get(&key) {
                                if existing.from_conn != conn_id {
                                    let err = ProtocolMessage::UpdateError {
                                        crdt,
                                        room_id: room.room.clone(),
                                        code: protocol::UpdateErrorCode::InvalidUpdate,
                                        message: "Batch ID already in use by another sender".into(),
                                        batch_id: Some(batch_id),
                                        app_code: None,
                                    };
                                    if let Ok(bytes) = loro_protocol::encode(&err) {
                                        let _ = tx.send(Message::Binary(bytes.into()));
                                    }
                                    continue;
                                }
                            } else {
                                h.start_fragment_batch(
                                    &room,
                                    conn_id,
                                    batch_id,
                                    fragment_count,
                                    total_size_bytes,
                                );
                            }
                            // Relay the header so peers can start reassembling too.
                            h.broadcast(&room, conn_id, Message::Binary(data));
                        }
                        ProtocolMessage::DocUpdateFragment {
                            crdt,
                            room_id,
                            batch_id,
                            index,
                            fragment,
                        } => {
                            let room = RoomKey {
                                crdt,
                                room: room_id.clone(),
                            };
                            if !joined_rooms.contains(&room) {
                                let err = ProtocolMessage::UpdateError {
                                    crdt,
                                    room_id: room.room.clone(),
                                    code: protocol::UpdateErrorCode::PermissionDenied,
                                    message: "Must join room before sending updates".into(),
                                    batch_id: Some(batch_id),
                                    app_code: None,
                                };
                                if let Ok(bytes) = loro_protocol::encode(&err) {
                                    let _ = tx.send(Message::Binary(bytes.into()));
                                }
                                continue;
                            }
                            let mut h = hub.lock().await;
                            let key = (room.clone(), batch_id);
                            if let Some(b) = h.fragments.get(&key) {
                                if b.from_conn != conn_id {
                                    let err = ProtocolMessage::UpdateError {
                                        crdt,
                                        room_id: room.room.clone(),
                                        code: protocol::UpdateErrorCode::InvalidUpdate,
                                        message: "Fragment from wrong sender for batch".into(),
                                        batch_id: Some(batch_id),
                                        app_code: None,
                                    };
                                    if let Ok(bytes) = loro_protocol::encode(&err) {
                                        let _ = tx.send(Message::Binary(bytes.into()));
                                    }
                                    continue;
                                }
                                if !usize::try_from(index)
                                    .ok()
                                    .map(|i| i < b.chunks.len())
                                    .unwrap_or(false)
                                {
                                    let err = ProtocolMessage::UpdateError {
                                        crdt,
                                        room_id: room.room.clone(),
                                        code: protocol::UpdateErrorCode::InvalidUpdate,
                                        message: "Fragment index out of range".into(),
                                        batch_id: Some(batch_id),
                                        app_code: None,
                                    };
                                    if let Ok(bytes) = loro_protocol::encode(&err) {
                                        let _ = tx.send(Message::Binary(bytes.into()));
                                    }
                                    continue;
                                }
                            } else {
                                let err = ProtocolMessage::UpdateError {
                                    crdt,
                                    room_id: room.room.clone(),
                                    code: protocol::UpdateErrorCode::InvalidUpdate,
                                    message: "Unknown fragment batch".into(),
                                    batch_id: Some(batch_id),
                                    app_code: None,
                                };
                                if let Ok(bytes) = loro_protocol::encode(&err) {
                                    let _ = tx.send(Message::Binary(bytes.into()));
                                }
                                continue;
                            }
                            h.broadcast(&room, conn_id, Message::Binary(data.clone()));
                            // Once the last fragment lands, apply the whole
                            // reassembled batch to the server-side doc.
                            if let Some(buf) =
                                h.add_fragment_and_maybe_finish(&room, batch_id, index, fragment)
                            {
                                match crdt {
                                    CrdtType::Loro
                                    | CrdtType::LoroEphemeralStore
                                    | CrdtType::LoroEphemeralStorePersisted => {
                                        if let Ok(updates) = parse_docupdate_payload(&buf) {
                                            let start = std::time::Instant::now();
                                            h.apply_updates(&room, &updates);
                                            let elapsed_ms = start.elapsed().as_millis();
                                            debug!(room=?room.room, updates=%updates.len(), ms=%elapsed_ms, "applied reassembled updates");
                                        }
                                    }
                                    CrdtType::Elo => {
                                        h.apply_updates(&room, &[buf]);
                                    }
                                    _ => {}
                                }
                            }
                        }
                        ProtocolMessage::DocUpdate {
                            crdt,
                            room_id,
                            updates,
                        } => {
                            let room = RoomKey {
                                crdt,
                                room: room_id.clone(),
                            };
                            if !joined_rooms.contains(&room) {
                                let err = ProtocolMessage::UpdateError {
                                    crdt,
                                    room_id: room.room.clone(),
                                    code: protocol::UpdateErrorCode::PermissionDenied,
                                    message: "Must join room before sending updates".into(),
                                    batch_id: None,
                                    app_code: None,
                                };
                                if let Ok(bytes) = loro_protocol::encode(&err) {
                                    let _ = tx.send(Message::Binary(bytes.into()));
                                }
                                warn!(room=?room.room, "update rejected: not joined");
                            } else {
                                let perm = hub
                                    .lock()
                                    .await
                                    .perms
                                    .get(&(conn_id, room.clone()))
                                    .copied();
                                if !matches!(perm, Some(Permission::Write)) {
                                    let err = ProtocolMessage::UpdateError {
                                        crdt,
                                        room_id: room.room.clone(),
                                        code: protocol::UpdateErrorCode::PermissionDenied,
                                        message: "Write permission required to update document"
                                            .into(),
                                        batch_id: None,
                                        app_code: None,
                                    };
                                    if let Ok(bytes) = loro_protocol::encode(&err) {
                                        let _ = tx.send(Message::Binary(bytes.into()));
                                    }
                                    continue;
                                }
                                let mut h = hub.lock().await;
                                match crdt {
                                    CrdtType::Loro
                                    | CrdtType::LoroEphemeralStore
                                    | CrdtType::LoroEphemeralStorePersisted => {
                                        let start = std::time::Instant::now();
                                        h.apply_updates(&room, &updates);
                                        let elapsed_ms = start.elapsed().as_millis();
                                        h.broadcast(&room, conn_id, Message::Binary(data));
                                        debug!(room=?room.room, updates=%updates.len(), ms=%elapsed_ms, "applied and broadcast updates");
                                    }
                                    CrdtType::Elo => {
                                        h.apply_updates(&room, &updates);
                                        h.broadcast(&room, conn_id, Message::Binary(data));
                                        debug!(room=?room.room, updates=%updates.len(), "indexed and broadcast ELO updates");
                                    }
                                    // Other CRDT kinds (e.g. Yjs) are relayed untouched.
                                    _ => {
                                        h.broadcast(&room, conn_id, Message::Binary(data));
                                    }
                                }
                            }
                        }
                        // Any other protocol message is ignored by the server.
                        _ => {}
                    }
                } else {
                    warn!("invalid protocol frame; closing connection");
                    let _ = tx.send(Message::Close(Some(CloseFrame {
                        code: CloseCode::Protocol,
                        reason: "Protocol error".into(),
                    })));
                    break;
                }
            }
            Message::Close(frame) => {
                let _ = tx.send(Message::Close(frame.clone()));
                break;
            }
            Message::Ping(p) => {
                // Reply with a protocol-level Pong and mirror the text-level
                // "pong" keepalive as well.
                let _ = tx.send(Message::Pong(p));
                let _ = tx.send(Message::Text("pong".into()));
            }
            _ => {}
        }
    }

    // Unsubscribe from every room, then close the outbound channel so the
    // writer task drains and exits.
    {
        let mut h = hub.lock().await;
        h.leave_all(conn_id);
    }
    drop(tx);
    let _ = sink_task.await;
    debug!(conn_id, "connection closed and cleaned up");
    Ok(())
}

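/// Decode the payload of a reassembled `DocUpdate` batch: a ULEB128 count
/// followed by that many length-prefixed update blobs, with no trailing bytes.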
fn parse_docupdate_payload(buf: &[u8]) -> Result<Vec<Vec<u8>>, String> {
    let mut r = BytesReader::new(buf);
    let n = usize::try_from(r.read_uleb128()?).map_err(|_| "length too large".to_string())?;
    let mut out: Vec<Vec<u8>> = Vec::with_capacity(n);
    for _ in 0..n {
        let b = r.read_var_bytes()?.to_vec();
        out.push(b);
    }
    if r.remaining() != 0 {
        return Err("trailing bytes".into());
    }
    Ok(out)
}