1use futures_util::{SinkExt, StreamExt};
44use std::{
45 collections::HashMap,
46 hash::{Hash, Hasher},
47 sync::{
48 atomic::{AtomicU64, Ordering},
49 Arc,
50 },
51};
52use tokio::{
53 net::TcpStream,
54 sync::{mpsc, oneshot, Mutex},
55};
56use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
57
58use aes_gcm::aead::{Aead, KeyInit};
59use loro::LoroDoc;
60pub use loro_protocol as protocol;
61use protocol::{encode, try_decode, CrdtType, ProtocolMessage};
62
63#[derive(Debug)]
65pub enum ClientError {
66 Unauthorized,
68 Ws(Box<tokio_tungstenite::tungstenite::Error>),
70 Protocol(String),
72}
73
74impl std::fmt::Display for ClientError {
75 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 match self {
77 ClientError::Unauthorized => write!(f, "unauthorized"),
78 ClientError::Ws(e) => write!(f, "websocket error: {}", e),
79 ClientError::Protocol(e) => write!(f, "protocol error: {}", e),
80 }
81 }
82}
83impl std::error::Error for ClientError {}
84impl From<tokio_tungstenite::tungstenite::Error> for ClientError {
85 fn from(e: tokio_tungstenite::tungstenite::Error) -> Self {
86 ClientError::Ws(Box::new(e))
87 }
88}
89
90type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;
91
/// Tuning knobs for [`LoroWebsocketClient`].
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// How long a partially received fragmented update is kept. After this, it
    /// is discarded and a `FragmentTimeout` update error is sent back.
    pub fragment_reassembly_timeout: std::time::Duration,
    /// Bytes subtracted from `protocol::MAX_MESSAGE_SIZE` when computing the
    /// outgoing fragment size (presumably room for message framing).
    pub fragment_limit_headroom: usize,
    /// Upper bound on an outgoing update payload before it is split into
    /// fragments.
    pub fragment_limit_soft_max: usize,
}

impl Default for ClientConfig {
    /// Defaults: 10 s reassembly timeout, 4 KiB headroom, 240 KiB soft fragment limit.
    fn default() -> Self {
        const KIB: usize = 1024;
        ClientConfig {
            fragment_limit_soft_max: 240 * KIB,
            fragment_limit_headroom: 4 * KIB,
            fragment_reassembly_timeout: std::time::Duration::from_millis(10_000),
        }
    }
}
112
113pub struct Client {
115 ws: Ws,
116}
117
118impl Client {
119 pub async fn connect(url: &str) -> Result<Self, ClientError> {
121 match connect_async(url).await {
122 Ok((ws, _resp)) => Ok(Self { ws }),
123 Err(e) => {
124 if let tokio_tungstenite::tungstenite::Error::Http(resp) = &e {
126 if resp.status()
127 == tokio_tungstenite::tungstenite::http::StatusCode::UNAUTHORIZED
128 {
129 return Err(ClientError::Unauthorized);
130 }
131 }
132 let s = e.to_string().to_lowercase();
133 if s.contains("401") || s.contains("unauthorized") {
134 return Err(ClientError::Unauthorized);
135 }
136 Err(ClientError::Ws(Box::new(e)))
137 }
138 }
139 }
140
141 pub async fn send(&mut self, msg: &ProtocolMessage) -> Result<(), ClientError> {
143 let data = encode(msg).map_err(ClientError::Protocol)?;
144 self.ws.send(Message::Binary(data.into())).await?;
145 Ok(())
146 }
147
148 pub async fn ping(&mut self) -> Result<(), ClientError> {
150 self.ws.send(Message::Text("ping".into())).await?;
151 Ok(())
152 }
153
154 pub async fn next(&mut self) -> Result<Option<ProtocolMessage>, ClientError> {
159 loop {
160 match self.ws.next().await {
161 Some(Ok(Message::Binary(data))) => {
162 if let Some(msg) = try_decode(data.as_ref()) {
163 return Ok(Some(msg));
164 }
165 }
167 Some(Ok(Message::Text(txt))) => {
168 if txt == "ping" {
169 self.ws.send(Message::Text("pong".into())).await?;
170 }
171 }
173 Some(Ok(Message::Ping(_))) => {
174 self.ws.send(Message::Text("pong".into())).await?;
177 }
178 Some(Ok(Message::Close(_))) => return Ok(None),
179 Some(Ok(_)) => { }
180 Some(Err(e)) => return Err(ClientError::Ws(Box::new(e))),
181 None => return Ok(None),
182 }
183 }
184 }
185
186 pub async fn close(mut self) -> Result<(), ClientError> {
188 self.ws.close(None).await?;
189 Ok(())
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    // Display output must name the error category.
    #[test]
    fn client_error_display() {
        let e = ClientError::Protocol("bad".into());
        assert!(format!("{}", e).contains("protocol error"));
    }

    /// Test adaptor that records every update batch it is asked to apply into a
    /// shared vector, so the test can inspect what the worker delivered.
    #[derive(Default)]
    struct RecordingAdaptor {
        updates: Arc<Mutex<Vec<Vec<u8>>>>,
    }

    #[async_trait::async_trait]
    impl CrdtDocAdaptor for RecordingAdaptor {
        fn crdt_type(&self) -> CrdtType {
            CrdtType::Loro
        }

        async fn version(&self) -> Vec<u8> {
            Vec::new()
        }

        async fn set_ctx(&mut self, _ctx: CrdtAdaptorContext) {}

        async fn handle_join_ok(
            &mut self,
            _permission: protocol::Permission,
            _version: Vec<u8>,
        ) {
        }

        async fn apply_update(&mut self, updates: Vec<Vec<u8>>) {
            self.updates.lock().await.extend(updates);
        }
    }

    // Fragments delivered out of order (index 1 before index 0) must be
    // reassembled by index and delivered once as a single update.
    #[tokio::test(flavor = "current_thread")]
    async fn fragment_reassembly_delivers_updates_in_order() {
        // Outgoing frames are not inspected; the receiver is only kept alive.
        let (tx, _rx) = mpsc::unbounded_channel::<Message>();
        let rooms = Arc::new(Mutex::new(HashMap::new()));
        let pending = Arc::new(Mutex::new(HashMap::new()));
        let adaptors = Arc::new(Mutex::new(HashMap::new()));
        let pre_join_buf = Arc::new(Mutex::new(HashMap::new()));
        let frag_batches = Arc::new(Mutex::new(HashMap::new()));
        let config = Arc::new(ClientConfig::default());

        let worker = ConnectionWorker::new(
            tx,
            rooms,
            pending,
            adaptors.clone(),
            pre_join_buf,
            frag_batches,
            config,
        );

        let room_id = "room-frag".to_string();
        let key = RoomKey {
            crdt: CrdtType::Loro,
            room: room_id.clone(),
        };
        // Register the recording adaptor so reassembled data is routed to it.
        let collected = Arc::new(Mutex::new(Vec::<Vec<u8>>::new()));
        adaptors.lock().await.insert(
            key.clone(),
            Box::new(RecordingAdaptor {
                updates: collected.clone(),
            }),
        );

        let batch_id = protocol::BatchId([1, 2, 3, 4, 5, 6, 7, 8]);
        worker
            .handle_message(ProtocolMessage::DocUpdateFragmentHeader {
                crdt: CrdtType::Loro,
                room_id: room_id.clone(),
                batch_id,
                fragment_count: 2,
                total_size_bytes: 10,
            })
            .await;
        // Deliver the second fragment first to exercise out-of-order arrival.
        worker
            .handle_message(ProtocolMessage::DocUpdateFragment {
                crdt: CrdtType::Loro,
                room_id: room_id.clone(),
                batch_id,
                index: 1,
                fragment: b"world".to_vec(),
            })
            .await;
        worker
            .handle_message(ProtocolMessage::DocUpdateFragment {
                crdt: CrdtType::Loro,
                room_id,
                batch_id,
                index: 0,
                fragment: b"hello".to_vec(),
            })
            .await;

        let updates = collected.lock().await;
        assert_eq!(updates.as_slice(), &[b"helloworld".to_vec()]);
    }

    // A snapshot container built with a fixed IV must decode into one Snapshot
    // record whose ciphertext decrypts (with the header as AAD) to the input.
    #[tokio::test(flavor = "current_thread")]
    async fn elo_snapshot_container_roundtrips_plaintext() {
        let doc = Arc::new(Mutex::new(LoroDoc::new()));
        let key = [7u8; 32];
        let adaptor = EloDocAdaptor::new(doc, "kid", key)
            .with_iv_factory(Arc::new(|| [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]));
        let plaintext = b"hello-elo".to_vec();

        let container = adaptor.encode_elo_snapshot_container(&plaintext);
        let records =
            protocol::elo::decode_elo_container(&container).expect("container should decode");
        assert_eq!(records.len(), 1);
        let parsed =
            protocol::elo::parse_elo_record_header(records[0]).expect("header should parse");
        match parsed.header {
            protocol::elo::EloHeader::Snapshot(hdr) => {
                assert_eq!(hdr.key_id, "kid");
                assert_eq!(hdr.iv, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
                // Decrypt independently to check the ciphertext/AAD pairing.
                let cipher = aes_gcm::Aes256Gcm::new_from_slice(&key).unwrap();
                let decrypted = cipher
                    .decrypt(
                        aes_gcm::Nonce::from_slice(&hdr.iv),
                        aes_gcm::aead::Payload {
                            msg: parsed.ct,
                            aad: parsed.aad,
                        },
                    )
                    .unwrap();
                assert_eq!(decrypted, plaintext);
            }
            _ => panic!("expected snapshot header"),
        }
        assert!(matches!(
            parsed.kind,
            protocol::elo::EloRecordKind::Snapshot
        ));
    }
}
337
338#[derive(Clone)]
339struct ConnectionWorker {
340 tx: mpsc::UnboundedSender<Message>,
341 rooms: Arc<Mutex<HashMap<RoomKey, RoomState>>>,
342 pending: Arc<Mutex<PendingMap>>,
343 adaptors: Arc<Mutex<HashMap<RoomKey, Box<dyn CrdtDocAdaptor + Send + Sync>>>>,
344 pre_join_buf: Arc<Mutex<HashMap<RoomKey, Vec<Vec<u8>>>>>,
345 frag_batches: Arc<Mutex<HashMap<(RoomKey, protocol::BatchId), FragmentBatch>>>,
346 config: Arc<ClientConfig>,
347}
348
349impl ConnectionWorker {
350 fn new(
351 tx: mpsc::UnboundedSender<Message>,
352 rooms: Arc<Mutex<HashMap<RoomKey, RoomState>>>,
353 pending: Arc<Mutex<PendingMap>>,
354 adaptors: Arc<Mutex<HashMap<RoomKey, Box<dyn CrdtDocAdaptor + Send + Sync>>>>,
355 pre_join_buf: Arc<Mutex<HashMap<RoomKey, Vec<Vec<u8>>>>>,
356 frag_batches: Arc<Mutex<HashMap<(RoomKey, protocol::BatchId), FragmentBatch>>>,
357 config: Arc<ClientConfig>,
358 ) -> Self {
359 Self {
360 tx,
361 rooms,
362 pending,
363 adaptors,
364 pre_join_buf,
365 frag_batches,
366 config,
367 }
368 }
369
370 fn spawn(self, mut stream: futures_util::stream::SplitStream<Ws>) {
371 tokio::spawn(async move {
372 while let Some(frame) = stream.next().await {
373 match frame {
374 Ok(Message::Text(txt)) => {
375 self.handle_text(txt.to_string()).await;
376 }
377 Ok(Message::Binary(data)) => {
378 self.handle_binary(data.to_vec()).await;
379 }
380 Ok(Message::Ping(p)) => {
381 let _ = self.tx.send(Message::Pong(p));
382 let _ = self.tx.send(Message::Text("pong".into()));
383 }
384 Ok(Message::Close(_)) => break,
385 Ok(_) => {}
386 Err(e) => {
387 eprintln!("ws read error: {}", e);
388 break;
389 }
390 }
391 }
392 });
393 }
394
395 async fn handle_text(&self, txt: String) {
396 if txt == "ping" {
397 let _ = self.tx.send(Message::Text("pong".into()));
398 }
399 }
401
402 async fn handle_binary(&self, data: Vec<u8>) {
403 if let Some(msg) = try_decode(data.as_ref()) {
404 self.handle_message(msg).await;
405 }
406 }
407
408 async fn handle_message(&self, msg: ProtocolMessage) {
409 let key = RoomKey {
410 crdt: msg_crdt(&msg),
411 room: msg_room_id(&msg),
412 };
413 match msg {
414 ProtocolMessage::JoinResponseOk {
415 permission,
416 version,
417 ..
418 } => {
419 if let Some(ch) = self.pending.lock().await.remove(&key) {
420 let _ = ch.send(JoinOutcome::Ok {
421 permission,
422 version,
423 });
424 }
425 }
426 ProtocolMessage::JoinError {
427 code,
428 message,
429 receiver_version,
430 ..
431 } => {
432 if let Some(ch) = self.pending.lock().await.remove(&key) {
433 let _ = ch.send(JoinOutcome::Err {
434 code,
435 message: message.clone(),
436 receiver_version,
437 });
438 }
439 eprintln!("join error: {:?} - {}", code, message);
440 }
441 ProtocolMessage::DocUpdate { updates, .. } => {
442 if let Some(adaptor) = self.adaptors.lock().await.get_mut(&key) {
443 adaptor.apply_update(updates).await;
444 } else if let Some(state) = self.rooms.lock().await.get(&key) {
445 let doc = state.doc.lock().await;
446 for u in updates {
447 let _ = doc.import(&u);
448 }
449 } else {
450 let mut buf = self.pre_join_buf.lock().await;
451 buf.entry(key).or_default().extend(updates);
452 }
453 }
454 ProtocolMessage::DocUpdateFragmentHeader {
455 batch_id,
456 fragment_count,
457 total_size_bytes,
458 ..
459 } => {
460 let mut map = self.frag_batches.lock().await;
462 map.insert(
463 (key.clone(), batch_id),
464 FragmentBatch {
465 fragment_count: fragment_count as usize,
466 total_size_bytes: total_size_bytes as usize,
467 slots: vec![Vec::new(); fragment_count as usize],
468 received: 0,
469 },
470 );
471 drop(map);
472
473 let batches = self.frag_batches.clone();
475 let key_clone = key.clone();
476 let tx_timeout = self.tx.clone();
477 let timeout = self.config.fragment_reassembly_timeout;
478 tokio::spawn(async move {
479 use tokio::time::sleep;
480 sleep(timeout).await;
481 let mut m = batches.lock().await;
482 if m.remove(&(key_clone.clone(), batch_id)).is_some() {
483 let err = ProtocolMessage::UpdateError {
484 crdt: key_clone.crdt,
485 room_id: key_clone.room.clone(),
486 code: protocol::UpdateErrorCode::FragmentTimeout,
487 message: format!(
488 "Fragment reassembly timeout for batch {}",
489 batch_id.to_hex()
490 ),
491 batch_id: Some(batch_id),
492 app_code: None,
493 };
494 if let Ok(data) = encode(&err) {
495 let _ = tx_timeout.send(Message::Binary(data.into()));
496 }
497 }
498 });
499 }
500 ProtocolMessage::DocUpdateFragment {
501 batch_id,
502 index,
503 fragment,
504 ..
505 } => {
506 let mut map = self.frag_batches.lock().await;
507 if let Some(batch) = map.get_mut(&(key.clone(), batch_id)) {
508 let i = index as usize;
509 if i < batch.slots.len() && batch.slots[i].is_empty() {
510 batch.slots[i] = fragment;
511 batch.received += 1;
512 }
513 if batch.received == batch.fragment_count {
514 let mut reassembled = Vec::with_capacity(batch.total_size_bytes);
515 for s in batch.slots.iter() {
516 reassembled.extend_from_slice(s);
517 }
518 map.remove(&(key.clone(), batch_id));
519 drop(map);
520 if let Some(adaptor) = self.adaptors.lock().await.get_mut(&key) {
521 adaptor.apply_update(vec![reassembled]).await;
522 } else if let Some(state) = self.rooms.lock().await.get(&key) {
523 let doc = state.doc.lock().await;
524 let _ = doc.import(&reassembled);
525 } else {
526 let mut buf = self.pre_join_buf.lock().await;
527 buf.entry(key).or_default().push(reassembled);
528 }
529 }
530 } else {
531 eprintln!(
532 "Received fragment for unknown batch {:?} in room {:?}",
533 batch_id, key.room
534 );
535 }
536 }
537 ProtocolMessage::UpdateError { code, message, .. } => {
538 if let Some(adaptor) = self.adaptors.lock().await.get_mut(&key) {
539 adaptor.handle_update_error(code, &message).await;
540 } else {
541 eprintln!("update error (no adaptor): {:?} - {}", code, message);
542 }
543 }
544 ProtocolMessage::Leave { .. } | ProtocolMessage::JoinRequest { .. } => {}
545 }
546 }
547}
548
549#[derive(Clone, Debug, PartialEq, Eq)]
550struct RoomKey {
551 crdt: CrdtType,
552 room: String,
553}
554impl Hash for RoomKey {
555 fn hash<H: Hasher>(&self, state: &mut H) {
556 let tag = match self.crdt {
557 CrdtType::Loro => 0u8,
558 CrdtType::LoroEphemeralStore => 1,
559 CrdtType::LoroEphemeralStorePersisted => 2,
560 CrdtType::Yjs => 3,
561 CrdtType::YjsAwareness => 4,
562 CrdtType::Elo => 5,
563 };
564 tag.hash(state);
565 self.room.hash(state);
566 }
567}
568
569struct RoomState {
570 doc: Arc<Mutex<LoroDoc>>,
571 sub: Option<loro::Subscription>,
572}
573
574enum JoinOutcome {
575 Ok {
576 permission: protocol::Permission,
577 version: Vec<u8>,
578 },
579 Err {
580 code: protocol::JoinErrorCode,
581 message: String,
582 receiver_version: Option<Vec<u8>>,
583 },
584}
585
586type PendingMap = HashMap<RoomKey, oneshot::Sender<JoinOutcome>>;
587
588#[derive(Clone)]
590pub struct LoroWebsocketClient {
591 tx: mpsc::UnboundedSender<Message>,
592 rooms: Arc<Mutex<HashMap<RoomKey, RoomState>>>,
593 pending: Arc<Mutex<PendingMap>>,
595 adaptors: Arc<Mutex<HashMap<RoomKey, Box<dyn CrdtDocAdaptor + Send + Sync>>>>,
597 pre_join_buf: Arc<Mutex<HashMap<RoomKey, Vec<Vec<u8>>>>>,
599 next_batch_id: Arc<AtomicU64>,
601 config: Arc<ClientConfig>,
603}
604
605impl LoroWebsocketClient {
606 pub async fn connect(url: &str) -> Result<Self, ClientError> {
608 Self::connect_with_config(url, ClientConfig::default()).await
609 }
610
611 pub async fn connect_with_config(url: &str, config: ClientConfig) -> Result<Self, ClientError> {
613 let (ws, _resp) = match connect_async(url).await {
614 Ok(ok) => ok,
615 Err(e) => {
616 if let tokio_tungstenite::tungstenite::Error::Http(resp) = &e {
617 if resp.status()
618 == tokio_tungstenite::tungstenite::http::StatusCode::UNAUTHORIZED
619 {
620 return Err(ClientError::Unauthorized);
621 }
622 }
623 let s = e.to_string().to_lowercase();
624 if s.contains("401") || s.contains("unauthorized") {
625 return Err(ClientError::Unauthorized);
626 }
627 return Err(ClientError::Ws(Box::new(e)));
628 }
629 };
630 let (mut sink, stream) = ws.split();
631 let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
632
633 tokio::spawn(async move {
635 while let Some(msg) = rx.recv().await {
636 if sink.send(msg).await.is_err() {
637 break;
638 }
639 }
640 });
641
642 let rooms: Arc<Mutex<HashMap<RoomKey, RoomState>>> = Arc::new(Mutex::new(HashMap::new()));
643 let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
644 let adaptors_reader: Arc<Mutex<HashMap<RoomKey, Box<dyn CrdtDocAdaptor + Send + Sync>>>> =
645 Arc::new(Mutex::new(HashMap::new()));
646 let pre_join_buf_reader: Arc<Mutex<HashMap<RoomKey, Vec<Vec<u8>>>>> =
647 Arc::new(Mutex::new(HashMap::new()));
648 let frag_batches_reader: Arc<Mutex<HashMap<(RoomKey, protocol::BatchId), FragmentBatch>>> =
649 Arc::new(Mutex::new(HashMap::new()));
650
651 let cfg = Arc::new(config);
653 ConnectionWorker::new(
654 tx.clone(),
655 rooms.clone(),
656 pending.clone(),
657 adaptors_reader.clone(),
658 pre_join_buf_reader.clone(),
659 frag_batches_reader,
660 cfg.clone(),
661 )
662 .spawn(stream);
663
664 Ok(Self {
665 tx,
666 rooms,
667 pending,
668 adaptors: adaptors_reader,
669 pre_join_buf: pre_join_buf_reader,
670 next_batch_id: Arc::new(AtomicU64::new(1)),
671 config: cfg,
672 })
673 }
674
675 pub async fn join_loro(
677 &self,
678 room_id: &str,
679 doc: Arc<Mutex<LoroDoc>>,
680 ) -> Result<LoroWebsocketClientRoom, ClientError> {
681 let key = RoomKey {
682 crdt: CrdtType::Loro,
683 room: room_id.to_string(),
684 };
685 self.rooms.lock().await.insert(
687 key.clone(),
688 RoomState {
689 doc: doc.clone(),
690 sub: None,
691 },
692 );
693
694 let (tx_done, rx_done) = oneshot::channel::<JoinOutcome>();
695 self.pending.lock().await.insert(key.clone(), tx_done);
696
697 let local_version = doc.lock().await.oplog_vv().encode();
699 let msg = ProtocolMessage::JoinRequest {
700 crdt: CrdtType::Loro,
701 room_id: key.room.clone(),
702 auth: Vec::new(),
703 version: local_version,
704 };
705 let data = encode(&msg).map_err(ClientError::Protocol)?;
706 self.tx
707 .send(Message::Binary(data.into()))
708 .map_err(|_| ClientError::Protocol("send failed".into()))?;
709
710 match rx_done.await {
712 Ok(JoinOutcome::Ok { permission, .. }) => {
713 if matches!(permission, protocol::Permission::Write) {
715 let tx2 = self.tx.clone();
716 let key2 = key.clone();
717 let sub = {
718 let guard = doc.lock().await;
719 guard.subscribe_local_update(Box::new(move |bytes| {
720 let msg = ProtocolMessage::DocUpdate {
721 crdt: key2.crdt,
722 room_id: key2.room.clone(),
723 updates: vec![bytes.clone()],
724 };
725 if let Ok(data) = encode(&msg) {
726 let _ = tx2.send(Message::Binary(data.into()));
727 }
728 true
729 }))
730 };
731 self.rooms.lock().await.insert(
733 key.clone(),
734 RoomState {
735 doc: doc.clone(),
736 sub: Some(sub),
737 },
738 );
739 } else {
740 self.rooms.lock().await.insert(
742 key.clone(),
743 RoomState {
744 doc: doc.clone(),
745 sub: None,
746 },
747 );
748 }
749 }
750 Ok(JoinOutcome::Err { code, message, .. }) => {
751 self.rooms.lock().await.remove(&key);
753 return Err(ClientError::Protocol(format!(
754 "join error: {:?} - {}",
755 code, message
756 )));
757 }
758 Err(_) => {
759 self.rooms.lock().await.remove(&key);
760 return Err(ClientError::Protocol("join canceled".into()));
761 }
762 }
763
764 Ok(LoroWebsocketClientRoom {
765 inner: self.clone(),
766 key,
767 })
768 }
769
770 pub fn ping(&self) -> Result<(), ClientError> {
772 self.tx
773 .send(Message::Text("ping".into()))
774 .map_err(|_| ClientError::Protocol("send failed".into()))
775 }
776
777 pub async fn join_with_adaptor(
779 &self,
780 room_id: &str,
781 mut adaptor: Box<dyn CrdtDocAdaptor + Send + Sync>,
782 ) -> Result<LoroWebsocketClientRoom, ClientError> {
783 let key = RoomKey {
784 crdt: adaptor.crdt_type(),
785 room: room_id.to_string(),
786 };
787
788 let tx2 = self.tx.clone();
791 let room_vec = key.room.clone();
792 let crdt = key.crdt;
793 let batch_counter = self.next_batch_id.clone();
794 let cfg = self.config.clone();
795 let send_update = move |upd: Vec<u8>| {
796 let frag_limit = std::cmp::max(
798 1usize,
799 std::cmp::min(
800 cfg.fragment_limit_soft_max,
801 protocol::MAX_MESSAGE_SIZE.saturating_sub(cfg.fragment_limit_headroom),
802 ),
803 );
804
805 if upd.len() <= frag_limit {
806 let msg = ProtocolMessage::DocUpdate {
807 crdt,
808 room_id: room_vec.clone(),
809 updates: vec![upd],
810 };
811 if let Ok(data) = encode(&msg) {
812 let _ = tx2.send(Message::Binary(data.into()));
813 }
814 } else {
815 let total = upd.len();
816 let n = (total + frag_limit - 1) / frag_limit;
817 let batch_id =
818 protocol::BatchId(batch_counter.fetch_add(1, Ordering::Relaxed).to_be_bytes());
819 let header = ProtocolMessage::DocUpdateFragmentHeader {
821 crdt,
822 room_id: room_vec.clone(),
823 batch_id,
824 fragment_count: n as u64,
825 total_size_bytes: total as u64,
826 };
827 if let Ok(data) = encode(&header) {
828 let _ = tx2.send(Message::Binary(data.into()));
829 }
830 for i in 0..n {
832 let start = i * frag_limit;
833 let end = ((i + 1) * frag_limit).min(total);
834 let frag = upd[start..end].to_vec();
835 let msg = ProtocolMessage::DocUpdateFragment {
836 crdt,
837 room_id: room_vec.clone(),
838 batch_id,
839 index: i as u64,
840 fragment: frag,
841 };
842 if let Ok(data) = encode(&msg) {
843 let _ = tx2.send(Message::Binary(data.into()));
844 }
845 }
846 }
847 };
848
849 let tx_err = self.tx.clone();
850 let room_vec2 = key.room.clone();
851 let crdt2 = key.crdt;
852 let on_join_failed = move |reason: String| {
853 let msg = ProtocolMessage::JoinError {
854 crdt: crdt2,
855 room_id: room_vec2.clone(),
856 code: protocol::JoinErrorCode::AppError,
857 message: reason,
858 receiver_version: None,
859 app_code: None,
860 };
861 if let Ok(data) = encode(&msg) {
862 let _ = tx_err.send(Message::Binary(data.into()));
863 }
864 };
865 let tx_imp = self.tx.clone();
866 let room_vec3 = key.room.clone();
867 let crdt3 = key.crdt;
868 let on_import_error = move |err: String, _data: Vec<Vec<u8>>| {
869 let msg = ProtocolMessage::UpdateError {
870 crdt: crdt3,
871 room_id: room_vec3.clone(),
872 code: protocol::UpdateErrorCode::AppError,
873 message: err,
874 batch_id: None,
875 app_code: None,
876 };
877 if let Ok(data) = encode(&msg) {
878 let _ = tx_imp.send(Message::Binary(data.into()));
879 }
880 };
881
882 adaptor
883 .set_ctx(CrdtAdaptorContext {
884 send_update: Arc::new(send_update),
885 on_join_failed: Arc::new(on_join_failed),
886 on_import_error: Arc::new(on_import_error),
887 })
888 .await;
889
890 self.adaptors.lock().await.insert(key.clone(), adaptor);
892
893 let mut current_version = if let Some(ad) = self.adaptors.lock().await.get(&key) {
895 ad.version().await
896 } else {
897 Vec::new()
898 };
899 let mut tried_empty = false;
900 loop {
901 let (tx_done, rx_done) = oneshot::channel::<JoinOutcome>();
903 self.pending.lock().await.insert(key.clone(), tx_done);
904 let msg = ProtocolMessage::JoinRequest {
905 crdt: key.crdt,
906 room_id: key.room.clone(),
907 auth: Vec::new(),
908 version: current_version.clone(),
909 };
910 let data = encode(&msg).map_err(ClientError::Protocol)?;
911 self.tx
912 .send(Message::Binary(data.into()))
913 .map_err(|_| ClientError::Protocol("send failed".into()))?;
914
915 match rx_done.await {
916 Ok(JoinOutcome::Ok {
917 permission,
918 version: server_version,
919 }) => {
920 if let Some(adaptor) = self.adaptors.lock().await.get_mut(&key) {
921 if let Some(buf) = self.pre_join_buf.lock().await.remove(&key) {
922 adaptor.apply_update(buf).await;
923 }
924 adaptor.handle_join_ok(permission, server_version).await;
925 }
926 break;
927 }
928 Ok(JoinOutcome::Err {
929 code,
930 message,
931 receiver_version: _rv,
932 }) => {
933 if let Some(adaptor) = self.adaptors.lock().await.get_mut(&key) {
935 adaptor.handle_join_err(code, &message).await;
936 if code == protocol::JoinErrorCode::VersionUnknown {
937 if let Some(alt) =
939 adaptor.get_alternative_version(¤t_version).await
940 {
941 current_version = alt;
942 continue;
943 } else if !tried_empty {
944 current_version = Vec::new();
945 tried_empty = true;
946 continue;
947 }
948 }
949 }
950 self.adaptors.lock().await.remove(&key);
951 return Err(ClientError::Protocol(format!(
952 "join error: {:?} - {}",
953 code, message
954 )));
955 }
956 Err(_) => {
957 self.adaptors.lock().await.remove(&key);
958 return Err(ClientError::Protocol("join canceled".into()));
959 }
960 }
961 }
962
963 Ok(LoroWebsocketClientRoom {
964 inner: self.clone(),
965 key,
966 })
967 }
968
969 pub async fn join_loro_with_adaptor(
971 &self,
972 room_id: &str,
973 doc: Arc<Mutex<LoroDoc>>,
974 ) -> Result<LoroWebsocketClientRoom, ClientError> {
975 let adaptor: Box<dyn CrdtDocAdaptor + Send + Sync> = Box::new(LoroDocAdaptor::new(doc));
976 self.join_with_adaptor(room_id, adaptor).await
977 }
978
979 pub async fn join_elo_with_adaptor(
981 &self,
982 room_id: &str,
983 doc: Arc<Mutex<LoroDoc>>,
984 key_id: impl Into<String>,
985 key: [u8; 32],
986 ) -> Result<LoroWebsocketClientRoom, ClientError> {
987 let adaptor: Box<dyn CrdtDocAdaptor + Send + Sync> =
988 Box::new(EloDocAdaptor::new(doc, key_id, key));
989 self.join_with_adaptor(room_id, adaptor).await
990 }
991}
992
993#[derive(Clone)]
995pub struct LoroWebsocketClientRoom {
996 inner: LoroWebsocketClient,
997 key: RoomKey,
998}
999
1000impl LoroWebsocketClientRoom {
1001 pub async fn leave(&self) -> Result<(), ClientError> {
1003 let msg = ProtocolMessage::Leave {
1004 crdt: self.key.crdt,
1005 room_id: self.key.room.clone(),
1006 };
1007 let data = encode(&msg).map_err(ClientError::Protocol)?;
1008 self.inner
1009 .tx
1010 .send(Message::Binary(data.into()))
1011 .map_err(|_| ClientError::Protocol("send failed".into()))?;
1012 if let Some(state) = self.inner.rooms.lock().await.remove(&self.key) {
1014 if let Some(sub) = state.sub {
1015 sub.unsubscribe();
1016 }
1017 }
1018 self.inner.adaptors.lock().await.remove(&self.key);
1020 Ok(())
1021 }
1022}
1023
1024fn msg_crdt(msg: &ProtocolMessage) -> CrdtType {
1025 match msg {
1026 ProtocolMessage::JoinRequest { crdt, .. }
1027 | ProtocolMessage::JoinResponseOk { crdt, .. }
1028 | ProtocolMessage::JoinError { crdt, .. }
1029 | ProtocolMessage::DocUpdate { crdt, .. }
1030 | ProtocolMessage::DocUpdateFragmentHeader { crdt, .. }
1031 | ProtocolMessage::DocUpdateFragment { crdt, .. }
1032 | ProtocolMessage::UpdateError { crdt, .. }
1033 | ProtocolMessage::Leave { crdt, .. } => *crdt,
1034 }
1035}
1036
1037fn msg_room_id(msg: &ProtocolMessage) -> String {
1038 match msg {
1039 ProtocolMessage::JoinRequest { room_id, .. }
1040 | ProtocolMessage::JoinResponseOk { room_id, .. }
1041 | ProtocolMessage::JoinError { room_id, .. }
1042 | ProtocolMessage::DocUpdate { room_id, .. }
1043 | ProtocolMessage::DocUpdateFragmentHeader { room_id, .. }
1044 | ProtocolMessage::DocUpdateFragment { room_id, .. }
1045 | ProtocolMessage::UpdateError { room_id, .. }
1046 | ProtocolMessage::Leave { room_id, .. } => room_id.clone(),
1047 }
1048}
1049
/// Reassembly state for one fragmented update batch.
struct FragmentBatch {
    /// One slot per fragment index. An empty slot means "not received yet".
    slots: Vec<Vec<u8>>,
    /// Number of slots filled so far.
    received: usize,
    /// Fragment count announced by the batch header.
    fragment_count: usize,
    /// Payload size in bytes announced by the batch header.
    total_size_bytes: usize,
}
1057
1058#[async_trait::async_trait]
1060pub trait CrdtDocAdaptor {
1061 fn crdt_type(&self) -> CrdtType;
1062 async fn version(&self) -> Vec<u8>;
1063 async fn set_ctx(&mut self, ctx: CrdtAdaptorContext);
1064 async fn handle_join_ok(&mut self, permission: protocol::Permission, version: Vec<u8>);
1065 async fn apply_update(&mut self, updates: Vec<Vec<u8>>);
1066 async fn handle_update_error(&mut self, _code: protocol::UpdateErrorCode, _message: &str) {}
1067 async fn handle_join_err(&mut self, _code: protocol::JoinErrorCode, _message: &str) {}
1068 async fn get_alternative_version(&mut self, _current: &[u8]) -> Option<Vec<u8>> {
1069 None
1070 }
1071}
1072
1073pub struct CrdtAdaptorContext {
1074 pub send_update: Arc<dyn Fn(Vec<u8>) + Send + Sync>,
1075 pub on_join_failed: Arc<dyn Fn(String) + Send + Sync>,
1076 pub on_import_error: Arc<dyn Fn(String, Vec<Vec<u8>>) + Send + Sync>,
1077}
1078
1079pub struct LoroDocAdaptor {
1081 doc: Arc<Mutex<LoroDoc>>,
1082 sub: Option<loro::Subscription>,
1083 ctx: Option<CrdtAdaptorContext>,
1084}
1085
1086impl LoroDocAdaptor {
1087 pub fn new(doc: Arc<Mutex<LoroDoc>>) -> Self {
1088 Self {
1089 doc,
1090 sub: None,
1091 ctx: None,
1092 }
1093 }
1094}
1095
1096#[async_trait::async_trait]
1097impl CrdtDocAdaptor for LoroDocAdaptor {
1098 fn crdt_type(&self) -> CrdtType {
1099 CrdtType::Loro
1100 }
1101
1102 async fn version(&self) -> Vec<u8> {
1103 self.doc.lock().await.oplog_vv().encode()
1104 }
1105
1106 async fn set_ctx(&mut self, ctx: CrdtAdaptorContext) {
1107 self.ctx = Some(CrdtAdaptorContext {
1108 send_update: ctx.send_update.clone(),
1109 on_join_failed: ctx.on_join_failed.clone(),
1110 on_import_error: ctx.on_import_error.clone(),
1111 });
1112 let doc = self.doc.clone();
1113 let send = ctx.send_update.clone();
1114 let sub = {
1116 let guard = doc.lock().await;
1117 guard.subscribe_local_update(Box::new(move |bytes| {
1118 (send)(bytes.clone());
1119 true
1120 }))
1121 };
1122 self.sub = Some(sub);
1123 }
1124
1125 async fn handle_join_ok(&mut self, _permission: protocol::Permission, version: Vec<u8>) {
1126 if version.is_empty() {
1128 if let Ok(pt) = self.doc.lock().await.export(loro::ExportMode::Snapshot) {
1129 if let Some(ctx) = &self.ctx {
1130 (ctx.send_update)(pt);
1131 }
1132 }
1133 }
1134 }
1135
1136 async fn apply_update(&mut self, updates: Vec<Vec<u8>>) {
1137 let guard = self.doc.lock().await;
1138 for u in updates {
1139 let _ = guard.import(&u);
1140 }
1141 }
1142
1143 async fn handle_update_error(&mut self, _code: protocol::UpdateErrorCode, _message: &str) {}
1144}
1145
1146impl Drop for LoroDocAdaptor {
1147 fn drop(&mut self) {
1148 if let Some(sub) = self.sub.take() {
1149 sub.unsubscribe();
1150 }
1151 }
1152}
1153
1154pub struct EloDocAdaptor {
1158 doc: Arc<Mutex<LoroDoc>>,
1159 ctx: Option<CrdtAdaptorContext>,
1160 key_id: String,
1161 key: [u8; 32],
1162 iv_factory: Option<Arc<dyn Fn() -> [u8; 12] + Send + Sync>>,
1163 sub: Option<loro::Subscription>,
1164}
1165
1166impl EloDocAdaptor {
1167 pub fn new(doc: Arc<Mutex<LoroDoc>>, key_id: impl Into<String>, key: [u8; 32]) -> Self {
1168 Self {
1169 doc,
1170 ctx: None,
1171 key_id: key_id.into(),
1172 key,
1173 iv_factory: None,
1174 sub: None,
1175 }
1176 }
1177
1178 pub fn with_iv_factory(mut self, f: Arc<dyn Fn() -> [u8; 12] + Send + Sync>) -> Self {
1179 self.iv_factory = Some(f);
1180 self
1181 }
1182
1183 fn encode_elo_snapshot_container(&self, plaintext: &[u8]) -> Vec<u8> {
1184 use protocol::bytes::BytesWriter;
1185 let iv: [u8; 12] = self.iv_factory.as_ref().map(|f| (f)()).unwrap_or([0u8; 12]);
1187 let mut hdr = BytesWriter::new();
1188 hdr.push_byte(protocol::elo::EloRecordKind::Snapshot as u8);
1189 hdr.push_uleb128(0); hdr.push_var_string(&self.key_id);
1191 hdr.push_var_bytes(&iv);
1192 let header_bytes = hdr.finalize();
1193
1194 let cipher = aes_gcm::Aes256Gcm::new_from_slice(&self.key).expect("key");
1196 let nonce = aes_gcm::Nonce::from_slice(&iv);
1197 let ct = cipher
1198 .encrypt(
1199 nonce,
1200 aes_gcm::aead::Payload {
1201 msg: plaintext,
1202 aad: &header_bytes,
1203 },
1204 )
1205 .expect("encrypt elo snapshot");
1206
1207 let mut rec = BytesWriter::new();
1209 rec.push_bytes(&header_bytes);
1210 rec.push_var_bytes(&ct);
1211 let record = rec.finalize();
1212
1213 let mut cont = BytesWriter::new();
1215 cont.push_uleb128(1);
1216 cont.push_var_bytes(&record);
1217 cont.finalize()
1218 }
1219}
1220
1221#[async_trait::async_trait]
1222impl CrdtDocAdaptor for EloDocAdaptor {
1223 fn crdt_type(&self) -> CrdtType {
1224 CrdtType::Elo
1225 }
1226
1227 async fn version(&self) -> Vec<u8> {
1228 self.doc.lock().await.oplog_vv().encode()
1229 }
1230
1231 async fn set_ctx(&mut self, ctx: CrdtAdaptorContext) {
1232 self.ctx = Some(CrdtAdaptorContext {
1234 send_update: ctx.send_update.clone(),
1235 on_join_failed: ctx.on_join_failed.clone(),
1236 on_import_error: ctx.on_import_error.clone(),
1237 });
1238
1239 let doc = self.doc.clone();
1240 let send = ctx.send_update.clone();
1241 let key_id = self.key_id.clone();
1242 let key = self.key;
1243 let iv_factory = self.iv_factory.clone();
1244 let sub = {
1247 let guard = doc.lock().await;
1248 guard.subscribe_local_update(Box::new(move |bytes| {
1249 use protocol::bytes::BytesWriter;
1250 let iv: [u8; 12] = iv_factory.as_ref().map(|f| (f)()).unwrap_or([0u8; 12]);
1251 let mut hdr = BytesWriter::new();
1252 hdr.push_byte(protocol::elo::EloRecordKind::Snapshot as u8);
1253 hdr.push_uleb128(0); hdr.push_var_string(&key_id);
1255 hdr.push_var_bytes(&iv);
1256 let header_bytes = hdr.finalize();
1257
1258 let cipher = aes_gcm::Aes256Gcm::new_from_slice(&key).expect("key");
1260 let nonce = aes_gcm::Nonce::from_slice(&iv);
1261 if let Ok(ct) = cipher.encrypt(
1262 nonce,
1263 aes_gcm::aead::Payload {
1264 msg: &bytes,
1265 aad: &header_bytes,
1266 },
1267 ) {
1268 let mut rec = BytesWriter::new();
1269 rec.push_bytes(&header_bytes);
1270 rec.push_var_bytes(&ct);
1271 let record = rec.finalize();
1272 let mut cont = BytesWriter::new();
1273 cont.push_uleb128(1);
1274 cont.push_var_bytes(&record);
1275 let container = cont.finalize();
1276 (send)(container);
1277 }
1278 true
1279 }))
1280 };
1281 self.sub = Some(sub);
1282 }
1283
1284 async fn handle_join_ok(&mut self, _permission: protocol::Permission, _version: Vec<u8>) {
1285 if let Ok(snap) = self.doc.lock().await.export(loro::ExportMode::Snapshot) {
1290 let ct = self.encode_elo_snapshot_container(&snap);
1291 if let Some(ctx) = &self.ctx {
1292 (ctx.send_update)(ct);
1293 }
1294 }
1295 }
1297
1298 async fn apply_update(&mut self, updates: Vec<Vec<u8>>) {
1299 for u in updates {
1300 if let Ok(records) = protocol::elo::decode_elo_container(&u) {
1301 for rec in records {
1302 if let Ok(parsed) = protocol::elo::parse_elo_record_header(rec) {
1303 let iv = match &parsed.header {
1304 protocol::elo::EloHeader::Delta(h) => h.iv,
1305 protocol::elo::EloHeader::Snapshot(h) => h.iv,
1306 };
1307 let aad = parsed.aad;
1308 let cipher = aes_gcm::Aes256Gcm::new_from_slice(&self.key).expect("key");
1309 if let Ok(pt) = cipher.decrypt(
1310 aes_gcm::Nonce::from_slice(&iv),
1311 aes_gcm::aead::Payload {
1312 msg: parsed.ct,
1313 aad,
1314 },
1315 ) {
1316 let _ = self.doc.lock().await.import(&pt);
1317 } else if let Some(ctx) = &self.ctx {
1318 (ctx.on_import_error)("decrypt failed".to_string(), vec![u.clone()]);
1319 }
1320 }
1321 }
1322 }
1323 }
1324 }
1325
1326 async fn handle_update_error(&mut self, _code: protocol::UpdateErrorCode, _message: &str) {}
1327}
1328
1329impl Drop for EloDocAdaptor {
1330 fn drop(&mut self) {
1331 if let Some(sub) = self.sub.take() {
1332 sub.unsubscribe();
1333 }
1334 }
1335}
1336
1337pub use CrdtDocAdaptor as DocAdaptor;
1339pub use EloDocAdaptor as EloAdaptor;
1340pub use LoroDocAdaptor as LoroAdaptor;