1use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10#[serde(tag = "type", rename_all = "snake_case")]
11pub enum ClientMessage {
12 Join { room: String },
14 Leave,
16 Sync { data: String },
18 Awareness {
20 peer_id: u64,
21 #[serde(flatten)]
22 state: AwarenessState,
23 },
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28#[serde(tag = "type", rename_all = "snake_case")]
29pub enum ServerMessage {
30 Joined {
32 room: String,
33 peer_count: usize,
34 #[serde(skip_serializing_if = "Option::is_none")]
36 initial_sync: Option<String>,
37 },
38 PeerJoined { peer_id: String },
40 PeerLeft { peer_id: String },
42 Sync { from: String, data: String },
44 Awareness {
46 from: String,
47 peer_id: u64,
48 #[serde(flatten)]
49 state: AwarenessState,
50 },
51 Error { message: String },
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize, Default)]
57pub struct AwarenessState {
58 #[serde(skip_serializing_if = "Option::is_none")]
60 pub cursor: Option<CursorPosition>,
61 #[serde(skip_serializing_if = "Option::is_none")]
63 pub user: Option<UserInfo>,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct CursorPosition {
68 pub x: f64,
69 pub y: f64,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct UserInfo {
74 pub name: String,
75 pub color: String,
76}
77
/// Coarse lifecycle state tracked by the platform WebSocket clients.
///
/// A client starts `Disconnected` and becomes `Connecting` when `connect` is called.
/// After that, `poll_events` updates the state from the events it drains.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No socket is open.
    Disconnected,
    /// `connect` was called and the socket has not reported opening yet.
    Connecting,
    /// A `Connected` event has been observed.
    Connected,
    /// An `Error` event has been observed.
    Error,
}
86
87#[derive(Debug, Clone)]
89pub enum SyncEvent {
90 Connected,
92 Disconnected,
94 JoinedRoom { room: String, peer_count: usize, initial_sync: Option<Vec<u8>> },
96 PeerJoined { peer_id: String },
98 PeerLeft { peer_id: String },
100 SyncReceived { from: String, data: Vec<u8> },
102 AwarenessReceived { from: String, peer_id: u64, state: AwarenessState },
104 Error { message: String },
106}
107
/// Decodes standard-alphabet base64 (`A-Z a-z 0-9 + /`) into bytes.
///
/// Trailing `=` padding is optional. When padding is present, it must bring the
/// input length to a multiple of four, and there may be at most two `=` characters.
///
/// Returns `None` in any of these cases:
/// - the input contains a character outside the alphabet (including `=` before
///   the end, whitespace, or any non-ASCII byte);
/// - the padding is malformed;
/// - the unpadded length leaves a single dangling character, which encodes only
///   6 bits and never a whole byte.
pub fn base64_decode(input: &str) -> Option<Vec<u8>> {
    // ASCII byte -> sextet value; -1 marks characters outside the alphabet.
    const DECODE_TABLE: [i8; 128] = [
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, -1, 63,
        52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, -1,
        -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
        15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, -1, -1, -1, -1, -1,
        -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
        41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1,
    ];

    let body = input.trim_end_matches('=');
    let padding = input.len() - body.len();

    // Valid padding is at most "==" and only ever completes a 4-character group.
    if padding > 2 || (padding > 0 && input.len() % 4 != 0) {
        return None;
    }
    // One leftover character carries 6 bits, which can never form a byte.
    if body.len() % 4 == 1 {
        return None;
    }

    let mut result = Vec::with_capacity(body.len() * 3 / 4);
    let mut buf = 0u32;
    let mut bits = 0u32;

    for c in body.bytes() {
        // Bytes >= 128 fall outside the table, and so does `get`'s range, so both are rejected.
        let val = *DECODE_TABLE.get(usize::from(c))?;
        if val < 0 {
            return None;
        }
        buf = (buf << 6) | (val as u32);
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            // `buf` holds exactly `bits + 8` significant bits here, so this cast is lossless.
            result.push((buf >> bits) as u8);
            buf &= (1 << bits) - 1;
        }
    }

    Some(result)
}
145
/// Encodes bytes as standard-alphabet base64 with `=` padding.
///
/// Each 3-byte group becomes 4 characters. In a final short group, the missing
/// bytes are treated as zero, and the characters that would come only from those
/// bytes are replaced by `=`.
pub fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    // Maps the low 6 bits of `word` to its alphabet character.
    let sextet = |word: u32| ALPHABET[(word & 0x3f) as usize] as char;

    let mut encoded = String::with_capacity((data.len() + 2) / 3 * 4);

    for group in data.chunks(3) {
        // Pack up to three bytes big-endian into the low 24 bits; absent bytes stay 0.
        let word = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (i, &byte)| acc | (u32::from(byte) << (16 - 8 * i)));

        encoded.push(sextet(word >> 18));
        encoded.push(sextet(word >> 12));
        encoded.push(if group.len() >= 2 { sextet(word >> 6) } else { '=' });
        encoded.push(if group.len() == 3 { sextet(word) } else { '=' });
    }

    encoded
}
175
#[cfg(target_arch = "wasm32")]
mod wasm_client {
    use super::*;
    use std::cell::RefCell;
    use std::rc::Rc;
    use wasm_bindgen::prelude::*;
    use wasm_bindgen::JsCast;
    use web_sys::{CloseEvent, ErrorEvent, MessageEvent, WebSocket};

    /// Browser `WebSocket` wrapper for wasm32 targets.
    ///
    /// The JS callbacks push [`SyncEvent`]s into a shared queue. The application
    /// drains that queue with [`WasmWebSocket::poll_events`].
    pub struct WasmWebSocket {
        /// The socket created by `connect`, until it is torn down.
        ws: Option<WebSocket>,
        state: ConnectionState,
        /// Event queue shared with the JS callbacks.
        events: Rc<RefCell<Vec<SyncEvent>>>,
        // The closures must stay alive while registered on `ws`. They are only
        // dropped by `detach`, after the handlers have been unregistered.
        _on_open: Option<Closure<dyn Fn()>>,
        _on_message: Option<Closure<dyn Fn(MessageEvent)>>,
        _on_close: Option<Closure<dyn Fn(CloseEvent)>>,
        _on_error: Option<Closure<dyn Fn(ErrorEvent)>>,
    }

    /// Converts a parsed server message into the event reported to the application.
    ///
    /// Returns `None` when a `sync` payload is not valid base64; such messages
    /// are dropped. An undecodable `initial_sync` becomes `None`.
    fn server_message_to_event(msg: ServerMessage) -> Option<SyncEvent> {
        Some(match msg {
            ServerMessage::Joined { room, peer_count, initial_sync } => SyncEvent::JoinedRoom {
                room,
                peer_count,
                initial_sync: initial_sync.and_then(|s| base64_decode(&s)),
            },
            ServerMessage::PeerJoined { peer_id } => SyncEvent::PeerJoined { peer_id },
            ServerMessage::PeerLeft { peer_id } => SyncEvent::PeerLeft { peer_id },
            ServerMessage::Sync { from, data } => SyncEvent::SyncReceived {
                from,
                data: base64_decode(&data)?,
            },
            ServerMessage::Awareness { from, peer_id, state } => {
                SyncEvent::AwarenessReceived { from, peer_id, state }
            }
            ServerMessage::Error { message } => SyncEvent::Error { message },
        })
    }

    impl WasmWebSocket {
        /// Creates a client with no socket, in the `Disconnected` state.
        pub fn new() -> Self {
            Self {
                ws: None,
                state: ConnectionState::Disconnected,
                events: Rc::new(RefCell::new(Vec::new())),
                _on_open: None,
                _on_message: None,
                _on_close: None,
                _on_error: None,
            }
        }

        /// Opens a WebSocket to `url` and installs the event callbacks.
        ///
        /// # Errors
        /// Returns `"Already connected"` while a socket is held. Returns a
        /// description if the browser refuses to create the socket.
        pub fn connect(&mut self, url: &str) -> Result<(), String> {
            if self.ws.is_some() {
                return Err("Already connected".to_string());
            }

            let ws = WebSocket::new(url).map_err(|e| format!("Failed to create WebSocket: {:?}", e))?;
            ws.set_binary_type(web_sys::BinaryType::Arraybuffer);

            self.state = ConnectionState::Connecting;

            let events_open = Rc::clone(&self.events);
            let on_open = Closure::wrap(Box::new(move || {
                events_open.borrow_mut().push(SyncEvent::Connected);
            }) as Box<dyn Fn()>);
            ws.set_onopen(Some(on_open.as_ref().unchecked_ref()));

            let events_msg = Rc::clone(&self.events);
            let on_message = Closure::wrap(Box::new(move |e: MessageEvent| {
                // Only text frames are parsed. Other payloads, and text that is not a
                // valid `ServerMessage`, are ignored.
                if let Ok(txt) = e.data().dyn_into::<js_sys::JsString>() {
                    let text: String = txt.into();
                    if let Some(event) = serde_json::from_str::<ServerMessage>(&text)
                        .ok()
                        .and_then(server_message_to_event)
                    {
                        events_msg.borrow_mut().push(event);
                    }
                }
            }) as Box<dyn Fn(MessageEvent)>);
            ws.set_onmessage(Some(on_message.as_ref().unchecked_ref()));

            let events_close = Rc::clone(&self.events);
            let on_close = Closure::wrap(Box::new(move |_e: CloseEvent| {
                events_close.borrow_mut().push(SyncEvent::Disconnected);
            }) as Box<dyn Fn(CloseEvent)>);
            ws.set_onclose(Some(on_close.as_ref().unchecked_ref()));

            let events_err = Rc::clone(&self.events);
            let on_error = Closure::wrap(Box::new(move |_e: ErrorEvent| {
                events_err.borrow_mut().push(SyncEvent::Error {
                    message: "WebSocket error".to_string(),
                });
            }) as Box<dyn Fn(ErrorEvent)>);
            ws.set_onerror(Some(on_error.as_ref().unchecked_ref()));

            self.ws = Some(ws);
            self._on_open = Some(on_open);
            self._on_message = Some(on_message);
            self._on_close = Some(on_close);
            self._on_error = Some(on_error);

            Ok(())
        }

        /// Closes the socket (if any) and forgets it.
        ///
        /// Events queued but not yet polled are discarded. This stops a stale
        /// `Connected` from flipping the state back on the next poll. It also
        /// matches the native client, which drops its event channel on disconnect.
        pub fn disconnect(&mut self) {
            if let Some(ws) = &self.ws {
                let _ = ws.close();
            }
            self.detach();
            self.events.borrow_mut().clear();
            self.state = ConnectionState::Disconnected;
        }

        /// Unregisters the JS handlers, then releases the socket and the closures.
        ///
        /// The handlers must be cleared before the closures are dropped. The
        /// browser delivers `close` (and possibly `error`) asynchronously after
        /// `WebSocket::close()`, and a dropped wasm-bindgen `Closure` throws if JS
        /// still invokes it.
        fn detach(&mut self) {
            if let Some(ws) = self.ws.take() {
                ws.set_onopen(None);
                ws.set_onmessage(None);
                ws.set_onclose(None);
                ws.set_onerror(None);
            }
            self._on_open = None;
            self._on_message = None;
            self._on_close = None;
            self._on_error = None;
        }

        /// Sends `msg` as a text frame.
        ///
        /// # Errors
        /// Returns `"Not connected"` when no socket is held. Returns the browser's
        /// error if the send fails.
        pub fn send(&self, msg: &str) -> Result<(), String> {
            if let Some(ref ws) = self.ws {
                ws.send_with_str(msg)
                    .map_err(|e| format!("Send failed: {:?}", e))
            } else {
                Err("Not connected".to_string())
            }
        }

        /// Drains the queued events, updating the connection state from them.
        ///
        /// When a `Disconnected` event is seen, the closed socket is released so
        /// that `connect` can be called again.
        pub fn poll_events(&mut self) -> Vec<SyncEvent> {
            // Take the queue first so the RefCell borrow has ended before `detach` runs.
            let drained = std::mem::take(&mut *self.events.borrow_mut());

            let mut closed = false;
            for event in &drained {
                match event {
                    SyncEvent::Connected => self.state = ConnectionState::Connected,
                    SyncEvent::Disconnected => {
                        self.state = ConnectionState::Disconnected;
                        closed = true;
                    }
                    SyncEvent::Error { .. } => self.state = ConnectionState::Error,
                    _ => {}
                }
            }

            if closed {
                self.detach();
            }

            drained
        }

        /// Returns the state as of the last `connect`/`disconnect`/`poll_events` call.
        pub fn state(&self) -> ConnectionState {
            self.state
        }

        /// `true` when the tracked state is `Connected`.
        pub fn is_connected(&self) -> bool {
            self.state == ConnectionState::Connected
        }
    }

    impl Default for WasmWebSocket {
        fn default() -> Self {
            Self::new()
        }
    }
}
349
350#[cfg(target_arch = "wasm32")]
351pub use wasm_client::WasmWebSocket;
352
353#[cfg(not(target_arch = "wasm32"))]
358mod native_client {
359 use super::*;
360 use std::sync::mpsc::{channel, Receiver, Sender, TryRecvError};
361 use std::thread::{self, JoinHandle};
362 use std::time::Duration;
363 use tungstenite::{connect, Message};
364 use url::Url;
365
366 enum WsCommand {
368 Send(String),
369 Close,
370 }
371
372 pub struct NativeWebSocket {
376 state: ConnectionState,
377 events: Vec<SyncEvent>,
378 cmd_tx: Option<Sender<WsCommand>>,
380 event_rx: Option<Receiver<SyncEvent>>,
382 _thread: Option<JoinHandle<()>>,
384 }
385
386 impl NativeWebSocket {
387 pub fn new() -> Self {
389 Self {
390 state: ConnectionState::Disconnected,
391 events: Vec::new(),
392 cmd_tx: None,
393 event_rx: None,
394 _thread: None,
395 }
396 }
397
398 pub fn connect(&mut self, url: &str) -> Result<(), String> {
400 if self.cmd_tx.is_some() {
401 return Err("Already connected".to_string());
402 }
403
404 let parsed_url = Url::parse(url).map_err(|e| format!("Invalid URL: {}", e))?;
406 if parsed_url.scheme() != "ws" && parsed_url.scheme() != "wss" {
407 return Err(format!("Invalid WebSocket URL scheme: {}", parsed_url.scheme()));
408 }
409
410 self.state = ConnectionState::Connecting;
411
412 let (cmd_tx, cmd_rx) = channel::<WsCommand>();
413 let (event_tx, event_rx) = channel::<SyncEvent>();
414
415 let url = url.to_string();
416
417 let handle = thread::spawn(move || {
418 log::info!("WebSocket thread: connecting to {}", url);
419
420 let ws_result = connect(&url);
422
423 match ws_result {
424 Ok((mut socket, response)) => {
425 log::info!("WebSocket connected, status: {}", response.status());
426 let _ = event_tx.send(SyncEvent::Connected);
427
428 {
431 let stream = socket.get_mut();
432 match stream {
433 tungstenite::stream::MaybeTlsStream::Plain(tcp) => {
434 let _ = tcp.set_read_timeout(Some(Duration::from_millis(50)));
435 let _ = tcp.set_write_timeout(Some(Duration::from_secs(5)));
436 }
437 #[allow(unreachable_patterns)]
438 _ => {
439 log::debug!("TLS or other stream - using default timeout handling");
441 }
442 }
443 }
444
445 loop {
446 match cmd_rx.try_recv() {
448 Ok(WsCommand::Send(msg)) => {
449 log::debug!("WebSocket sending: {}", &msg[..msg.len().min(100)]);
450 if let Err(e) = socket.send(Message::Text(msg)) {
451 log::error!("WebSocket send error: {}", e);
452 break;
453 }
454 }
455 Ok(WsCommand::Close) => {
456 log::info!("WebSocket close requested");
457 let _ = socket.close(None);
458 break;
459 }
460 Err(TryRecvError::Disconnected) => {
461 log::info!("WebSocket command channel disconnected");
462 break;
463 }
464 Err(TryRecvError::Empty) => {}
465 }
466
467 match socket.read() {
469 Ok(Message::Text(txt)) => {
470 log::debug!("WebSocket received: {}", &txt[..txt.len().min(100)]);
471 if let Ok(server_msg) = serde_json::from_str::<ServerMessage>(&txt) {
472 let event = match server_msg {
473 ServerMessage::Joined { room, peer_count, initial_sync } => {
474 let data = initial_sync.and_then(|s| super::base64_decode(&s));
475 SyncEvent::JoinedRoom { room, peer_count, initial_sync: data }
476 }
477 ServerMessage::PeerJoined { peer_id } => SyncEvent::PeerJoined { peer_id },
478 ServerMessage::PeerLeft { peer_id } => SyncEvent::PeerLeft { peer_id },
479 ServerMessage::Sync { from, data } => {
480 if let Some(bytes) = super::base64_decode(&data) {
481 SyncEvent::SyncReceived { from, data: bytes }
482 } else {
483 continue;
484 }
485 }
486 ServerMessage::Awareness { from, peer_id, state } => {
487 SyncEvent::AwarenessReceived { from, peer_id, state }
488 }
489 ServerMessage::Error { message } => SyncEvent::Error { message },
490 };
491 let _ = event_tx.send(event);
492 } else {
493 log::warn!("Failed to parse server message: {}", txt);
494 }
495 }
496 Ok(Message::Ping(data)) => {
497 let _ = socket.send(Message::Pong(data));
499 }
500 Ok(Message::Close(_)) => {
501 log::info!("WebSocket received close frame");
502 break;
503 }
504 Ok(_) => {} Err(tungstenite::Error::Io(ref e))
506 if e.kind() == std::io::ErrorKind::WouldBlock
507 || e.kind() == std::io::ErrorKind::TimedOut => {
508 continue;
510 }
511 Err(e) => {
512 log::error!("WebSocket read error: {}", e);
513 break;
514 }
515 }
516 }
517
518 log::info!("WebSocket thread exiting");
519 let _ = event_tx.send(SyncEvent::Disconnected);
520 }
521 Err(e) => {
522 log::error!("WebSocket connection failed: {}", e);
523 let _ = event_tx.send(SyncEvent::Error {
524 message: format!("Connection failed: {}", e),
525 });
526 }
527 }
528 });
529
530 self.cmd_tx = Some(cmd_tx);
531 self.event_rx = Some(event_rx);
532 self._thread = Some(handle);
533
534 Ok(())
535 }
536
537 pub fn disconnect(&mut self) {
539 if let Some(tx) = self.cmd_tx.take() {
540 let _ = tx.send(WsCommand::Close);
541 }
542 self.event_rx = None;
543 self._thread = None;
544 self.state = ConnectionState::Disconnected;
545 }
546
547 pub fn send(&self, msg: &str) -> Result<(), String> {
549 if let Some(ref tx) = self.cmd_tx {
550 tx.send(WsCommand::Send(msg.to_string()))
551 .map_err(|e| format!("Send failed: {}", e))
552 } else {
553 Err("Not connected".to_string())
554 }
555 }
556
557 pub fn poll_events(&mut self) -> Vec<SyncEvent> {
559 if let Some(ref rx) = self.event_rx {
561 while let Ok(event) = rx.try_recv() {
562 match &event {
564 SyncEvent::Connected => self.state = ConnectionState::Connected,
565 SyncEvent::Disconnected => self.state = ConnectionState::Disconnected,
566 SyncEvent::Error { .. } => self.state = ConnectionState::Error,
567 _ => {}
568 }
569 self.events.push(event);
570 }
571 }
572
573 std::mem::take(&mut self.events)
574 }
575
576 pub fn state(&self) -> ConnectionState {
578 self.state
579 }
580
581 pub fn is_connected(&self) -> bool {
583 self.state == ConnectionState::Connected
584 }
585 }
586
587 impl Default for NativeWebSocket {
588 fn default() -> Self {
589 Self::new()
590 }
591 }
592
593 impl Drop for NativeWebSocket {
594 fn drop(&mut self) {
595 self.disconnect();
596 }
597 }
598}
599
600#[cfg(not(target_arch = "wasm32"))]
601pub use native_client::NativeWebSocket;
602
603#[cfg(target_arch = "wasm32")]
609pub type PlatformWebSocket = WasmWebSocket;
610
611#[cfg(not(target_arch = "wasm32"))]
612pub type PlatformWebSocket = NativeWebSocket;
613
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_base64_roundtrip() {
        let original: &[u8] = b"Hello, World!";
        let decoded = base64_decode(&base64_encode(original)).unwrap();
        assert_eq!(original.to_vec(), decoded);
    }

    #[test]
    fn test_base64_empty() {
        let encoded = base64_encode(&[]);
        assert_eq!(Vec::<u8>::new(), base64_decode(&encoded).unwrap());
    }

    #[test]
    fn test_base64_padding() {
        // (input, expected encoding): one, two and zero padding characters.
        let cases: [(&[u8], &str); 3] = [(b"a", "YQ=="), (b"ab", "YWI="), (b"abc", "YWJj")];
        for &(input, expected) in cases.iter() {
            assert_eq!(base64_encode(input), expected);
        }
    }

    #[test]
    fn test_client_message_serialize() {
        let join = ClientMessage::Join {
            room: "test-room".to_string(),
        };
        let json = serde_json::to_string(&join).unwrap();
        for needle in ["join", "test-room"].iter() {
            assert!(json.contains(needle));
        }
    }

    #[test]
    fn test_server_message_deserialize() {
        let json = r#"{"type":"joined","room":"test","peer_count":2}"#;
        let msg: ServerMessage = serde_json::from_str(json).unwrap();
        if let ServerMessage::Joined { room, peer_count, .. } = msg {
            assert_eq!(room, "test");
            assert_eq!(peer_count, 2);
        } else {
            panic!("Wrong message type");
        }
    }
}