drafftink_core/
sync.rs

1//! WebSocket client for collaboration.
2//!
3//! Provides a platform-agnostic WebSocket client interface for connecting
4//! to the relay server.
5
6use serde::{Deserialize, Serialize};
7
8/// Messages sent to the server
9#[derive(Debug, Clone, Serialize, Deserialize)]
10#[serde(tag = "type", rename_all = "snake_case")]
11pub enum ClientMessage {
12    /// Join a room
13    Join { room: String },
14    /// Leave current room
15    Leave,
16    /// Sync CRDT data (base64 encoded Loro bytes)
17    Sync { data: String },
18    /// Awareness update (cursor position, selection, etc.)
19    Awareness {
20        peer_id: u64,
21        #[serde(flatten)]
22        state: AwarenessState,
23    },
24}
25
26/// Messages received from the server
27#[derive(Debug, Clone, Serialize, Deserialize)]
28#[serde(tag = "type", rename_all = "snake_case")]
29pub enum ServerMessage {
30    /// Confirm room join with current state
31    Joined {
32        room: String,
33        peer_count: usize,
34        /// Initial sync data (if room has history)
35        #[serde(skip_serializing_if = "Option::is_none")]
36        initial_sync: Option<String>,
37    },
38    /// Peer joined the room
39    PeerJoined { peer_id: String },
40    /// Peer left the room
41    PeerLeft { peer_id: String },
42    /// Sync data from another peer
43    Sync { from: String, data: String },
44    /// Awareness update from another peer
45    Awareness {
46        from: String,
47        peer_id: u64,
48        #[serde(flatten)]
49        state: AwarenessState,
50    },
51    /// Error message
52    Error { message: String },
53}
54
55/// Awareness state for a peer
56#[derive(Debug, Clone, Serialize, Deserialize, Default)]
57pub struct AwarenessState {
58    /// Cursor position (if any)
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub cursor: Option<CursorPosition>,
61    /// User name/color
62    #[serde(skip_serializing_if = "Option::is_none")]
63    pub user: Option<UserInfo>,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct CursorPosition {
68    pub x: f64,
69    pub y: f64,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct UserInfo {
74    pub name: String,
75    pub color: String,
76}
77
/// Connection state reported by the WebSocket clients in this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No connection; the state of a new client and after `disconnect`.
    Disconnected,
    /// `connect` was called and no outcome has been polled yet.
    Connecting,
    /// A `SyncEvent::Connected` event has been polled.
    Connected,
    /// A `SyncEvent::Error` event has been polled.
    Error,
}
86
87/// Events from the WebSocket client
88#[derive(Debug, Clone)]
89pub enum SyncEvent {
90    /// Connected to server
91    Connected,
92    /// Disconnected from server
93    Disconnected,
94    /// Joined a room
95    JoinedRoom { room: String, peer_count: usize, initial_sync: Option<Vec<u8>> },
96    /// A peer joined the room
97    PeerJoined { peer_id: String },
98    /// A peer left the room
99    PeerLeft { peer_id: String },
100    /// Received sync data from a peer
101    SyncReceived { from: String, data: Vec<u8> },
102    /// Received awareness update from a peer
103    AwarenessReceived { from: String, peer_id: u64, state: AwarenessState },
104    /// Error occurred
105    Error { message: String },
106}
107
/// Decodes a base64 string that uses the standard alphabet (`A-Z a-z 0-9 + /`).
///
/// Trailing `=` padding is stripped before decoding, so padded and unpadded
/// input are both accepted.
///
/// Returns `None` when:
/// - the input contains any byte outside the alphabet (including whitespace,
///   non-ASCII bytes, or `=` anywhere but the end), or
/// - the unpadded length is `4k + 1`, which no byte sequence encodes to.
pub fn base64_decode(input: &str) -> Option<Vec<u8>> {
    // Maps an ASCII byte to its 6-bit value; -1 marks bytes outside the alphabet.
    const DECODE_TABLE: [i8; 128] = [
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, -1, 63,
        52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, -1,
        -1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,
        15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, -1, -1, -1, -1, -1,
        -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
        41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1,
    ];

    let input = input.trim_end_matches('=');

    // A single leftover symbol carries only 6 bits and cannot complete a byte,
    // so such input is truncated or corrupt rather than valid base64.
    if input.len() % 4 == 1 {
        return None;
    }

    let mut result = Vec::with_capacity(input.len() * 3 / 4);
    let mut buf = 0u32;
    let mut bits = 0u32;

    for c in input.bytes() {
        // `get` fails for non-ASCII bytes (>= 128); a negative table entry
        // marks an ASCII byte that is not part of the alphabet.
        let val = match DECODE_TABLE.get(usize::from(c)) {
            Some(&v) if v >= 0 => v as u32,
            _ => return None,
        };

        buf = (buf << 6) | val;
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            result.push((buf >> bits) as u8);
            // Keep only the bits that have not been emitted yet.
            buf &= (1 << bits) - 1;
        }
    }

    Some(result)
}
145
/// Encodes `data` as base64 using the standard alphabet, padding the output
/// with `=` so its length is always a multiple of four.
pub fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    // Looks up the output character for one 6-bit group.
    let symbol = |sextet: u8| char::from(ALPHABET[usize::from(sextet)]);

    let mut encoded = String::with_capacity((data.len() + 2) / 3 * 4);

    for group in data.chunks(3) {
        match *group {
            // One input byte: two symbols plus two padding characters.
            [a] => {
                encoded.push(symbol(a >> 2));
                encoded.push(symbol((a & 0x03) << 4));
                encoded.push_str("==");
            }
            // Two input bytes: three symbols plus one padding character.
            [a, b] => {
                encoded.push(symbol(a >> 2));
                encoded.push(symbol(((a & 0x03) << 4) | (b >> 4)));
                encoded.push(symbol((b & 0x0f) << 2));
                encoded.push('=');
            }
            // Full group: three bytes become four symbols.
            [a, b, c] => {
                encoded.push(symbol(a >> 2));
                encoded.push(symbol(((a & 0x03) << 4) | (b >> 4)));
                encoded.push(symbol(((b & 0x0f) << 2) | (c >> 6)));
                encoded.push(symbol(c & 0x3f));
            }
            _ => unreachable!("chunks(3) yields between one and three bytes"),
        }
    }

    encoded
}
175
176// ============================================================================
177// WASM WebSocket Client
178// ============================================================================
179
180#[cfg(target_arch = "wasm32")]
181mod wasm_client {
182    use super::*;
183    use std::cell::RefCell;
184    use std::rc::Rc;
185    use wasm_bindgen::prelude::*;
186    use wasm_bindgen::JsCast;
187    use web_sys::{MessageEvent, WebSocket, CloseEvent, ErrorEvent};
188
189    /// WebSocket client for WASM.
190    /// 
191    /// Events are collected and must be polled via `poll_events()`.
192    pub struct WasmWebSocket {
193        ws: Option<WebSocket>,
194        state: ConnectionState,
195        events: Rc<RefCell<Vec<SyncEvent>>>,
196        // Store closures to prevent them from being dropped
197        _on_open: Option<Closure<dyn Fn()>>,
198        _on_message: Option<Closure<dyn Fn(MessageEvent)>>,
199        _on_close: Option<Closure<dyn Fn(CloseEvent)>>,
200        _on_error: Option<Closure<dyn Fn(ErrorEvent)>>,
201    }
202
203    impl WasmWebSocket {
204        /// Create a new disconnected WebSocket client.
205        pub fn new() -> Self {
206            Self {
207                ws: None,
208                state: ConnectionState::Disconnected,
209                events: Rc::new(RefCell::new(Vec::new())),
210                _on_open: None,
211                _on_message: None,
212                _on_close: None,
213                _on_error: None,
214            }
215        }
216
217        /// Connect to a WebSocket server.
218        pub fn connect(&mut self, url: &str) -> Result<(), String> {
219            if self.ws.is_some() {
220                return Err("Already connected".to_string());
221            }
222
223            let ws = WebSocket::new(url).map_err(|e| format!("Failed to create WebSocket: {:?}", e))?;
224            ws.set_binary_type(web_sys::BinaryType::Arraybuffer);
225
226            self.state = ConnectionState::Connecting;
227            let events = self.events.clone();
228
229            // onopen
230            let events_open = events.clone();
231            let on_open = Closure::wrap(Box::new(move || {
232                events_open.borrow_mut().push(SyncEvent::Connected);
233            }) as Box<dyn Fn()>);
234            ws.set_onopen(Some(on_open.as_ref().unchecked_ref()));
235
236            // onmessage
237            let events_msg = events.clone();
238            let on_message = Closure::wrap(Box::new(move |e: MessageEvent| {
239                if let Ok(txt) = e.data().dyn_into::<js_sys::JsString>() {
240                    let s: String = txt.into();
241                    // Parse and convert to SyncEvent
242                    if let Ok(server_msg) = serde_json::from_str::<ServerMessage>(&s) {
243                        let event = match server_msg {
244                            ServerMessage::Joined { room, peer_count, initial_sync } => {
245                                let data = initial_sync.and_then(|s| super::base64_decode(&s));
246                                SyncEvent::JoinedRoom { room, peer_count, initial_sync: data }
247                            }
248                            ServerMessage::PeerJoined { peer_id } => SyncEvent::PeerJoined { peer_id },
249                            ServerMessage::PeerLeft { peer_id } => SyncEvent::PeerLeft { peer_id },
250                            ServerMessage::Sync { from, data } => {
251                                if let Some(bytes) = super::base64_decode(&data) {
252                                    SyncEvent::SyncReceived { from, data: bytes }
253                                } else {
254                                    return;
255                                }
256                            }
257                            ServerMessage::Awareness { from, peer_id, state } => {
258                                SyncEvent::AwarenessReceived { from, peer_id, state }
259                            }
260                            ServerMessage::Error { message } => SyncEvent::Error { message },
261                        };
262                        events_msg.borrow_mut().push(event);
263                    }
264                }
265            }) as Box<dyn Fn(MessageEvent)>);
266            ws.set_onmessage(Some(on_message.as_ref().unchecked_ref()));
267
268            // onclose
269            let events_close = events.clone();
270            let on_close = Closure::wrap(Box::new(move |_e: CloseEvent| {
271                events_close.borrow_mut().push(SyncEvent::Disconnected);
272            }) as Box<dyn Fn(CloseEvent)>);
273            ws.set_onclose(Some(on_close.as_ref().unchecked_ref()));
274
275            // onerror
276            let events_err = events;
277            let on_error = Closure::wrap(Box::new(move |_e: ErrorEvent| {
278                events_err.borrow_mut().push(SyncEvent::Error {
279                    message: "WebSocket error".to_string(),
280                });
281            }) as Box<dyn Fn(ErrorEvent)>);
282            ws.set_onerror(Some(on_error.as_ref().unchecked_ref()));
283
284            self.ws = Some(ws);
285            self._on_open = Some(on_open);
286            self._on_message = Some(on_message);
287            self._on_close = Some(on_close);
288            self._on_error = Some(on_error);
289
290            Ok(())
291        }
292
293        /// Disconnect from the server.
294        pub fn disconnect(&mut self) {
295            if let Some(ws) = self.ws.take() {
296                let _ = ws.close();
297            }
298            self.state = ConnectionState::Disconnected;
299            self._on_open = None;
300            self._on_message = None;
301            self._on_close = None;
302            self._on_error = None;
303        }
304
305        /// Send a text message.
306        pub fn send(&self, msg: &str) -> Result<(), String> {
307            if let Some(ref ws) = self.ws {
308                ws.send_with_str(msg)
309                    .map_err(|e| format!("Send failed: {:?}", e))
310            } else {
311                Err("Not connected".to_string())
312            }
313        }
314
315        /// Poll for pending events (non-blocking).
316        pub fn poll_events(&mut self) -> Vec<SyncEvent> {
317            let mut events = self.events.borrow_mut();
318            
319            // Update state based on events
320            for event in events.iter() {
321                match event {
322                    SyncEvent::Connected => self.state = ConnectionState::Connected,
323                    SyncEvent::Disconnected => self.state = ConnectionState::Disconnected,
324                    SyncEvent::Error { .. } => self.state = ConnectionState::Error,
325                    _ => {}
326                }
327            }
328            
329            std::mem::take(&mut *events)
330        }
331
332        /// Get current connection state.
333        pub fn state(&self) -> ConnectionState {
334            self.state
335        }
336
337        /// Check if connected.
338        pub fn is_connected(&self) -> bool {
339            self.state == ConnectionState::Connected
340        }
341    }
342
343    impl Default for WasmWebSocket {
344        fn default() -> Self {
345            Self::new()
346        }
347    }
348}
349
350#[cfg(target_arch = "wasm32")]
351pub use wasm_client::WasmWebSocket;
352
353// ============================================================================
354// Native WebSocket Client
355// ============================================================================
356
357#[cfg(not(target_arch = "wasm32"))]
358mod native_client {
359    use super::*;
360    use std::sync::mpsc::{channel, Receiver, Sender, TryRecvError};
361    use std::thread::{self, JoinHandle};
362    use std::time::Duration;
363    use tungstenite::{connect, Message};
364    use url::Url;
365
366    /// Commands sent to the WebSocket thread.
367    enum WsCommand {
368        Send(String),
369        Close,
370    }
371
372    /// WebSocket client for native platforms.
373    /// 
374    /// Uses a background thread for non-blocking operation.
375    pub struct NativeWebSocket {
376        state: ConnectionState,
377        events: Vec<SyncEvent>,
378        /// Channel to send commands to the WebSocket thread.
379        cmd_tx: Option<Sender<WsCommand>>,
380        /// Channel to receive events from the WebSocket thread.
381        event_rx: Option<Receiver<SyncEvent>>,
382        /// Handle to the WebSocket thread.
383        _thread: Option<JoinHandle<()>>,
384    }
385
386    impl NativeWebSocket {
387        /// Create a new disconnected WebSocket client.
388        pub fn new() -> Self {
389            Self {
390                state: ConnectionState::Disconnected,
391                events: Vec::new(),
392                cmd_tx: None,
393                event_rx: None,
394                _thread: None,
395            }
396        }
397
398        /// Connect to a WebSocket server.
399        pub fn connect(&mut self, url: &str) -> Result<(), String> {
400            if self.cmd_tx.is_some() {
401                return Err("Already connected".to_string());
402            }
403
404            // Validate URL
405            let parsed_url = Url::parse(url).map_err(|e| format!("Invalid URL: {}", e))?;
406            if parsed_url.scheme() != "ws" && parsed_url.scheme() != "wss" {
407                return Err(format!("Invalid WebSocket URL scheme: {}", parsed_url.scheme()));
408            }
409
410            self.state = ConnectionState::Connecting;
411            
412            let (cmd_tx, cmd_rx) = channel::<WsCommand>();
413            let (event_tx, event_rx) = channel::<SyncEvent>();
414            
415            let url = url.to_string();
416            
417            let handle = thread::spawn(move || {
418                log::info!("WebSocket thread: connecting to {}", url);
419                
420                // Connect to WebSocket with timeout
421                let ws_result = connect(&url);
422                
423                match ws_result {
424                    Ok((mut socket, response)) => {
425                        log::info!("WebSocket connected, status: {}", response.status());
426                        let _ = event_tx.send(SyncEvent::Connected);
427                        
428                        // Set read timeout on the underlying TCP stream for non-blocking behavior
429                        // This is more reliable for tunneled/forwarded connections
430                        {
431                            let stream = socket.get_mut();
432                            match stream {
433                                tungstenite::stream::MaybeTlsStream::Plain(tcp) => {
434                                    let _ = tcp.set_read_timeout(Some(Duration::from_millis(50)));
435                                    let _ = tcp.set_write_timeout(Some(Duration::from_secs(5)));
436                                }
437                                #[allow(unreachable_patterns)]
438                                _ => {
439                                    // For TLS streams, we'll rely on WouldBlock/TimedOut errors
440                                    log::debug!("TLS or other stream - using default timeout handling");
441                                }
442                            }
443                        }
444                        
445                        loop {
446                            // Check for commands (non-blocking)
447                            match cmd_rx.try_recv() {
448                                Ok(WsCommand::Send(msg)) => {
449                                    log::debug!("WebSocket sending: {}", &msg[..msg.len().min(100)]);
450                                    if let Err(e) = socket.send(Message::Text(msg)) {
451                                        log::error!("WebSocket send error: {}", e);
452                                        break;
453                                    }
454                                }
455                                Ok(WsCommand::Close) => {
456                                    log::info!("WebSocket close requested");
457                                    let _ = socket.close(None);
458                                    break;
459                                }
460                                Err(TryRecvError::Disconnected) => {
461                                    log::info!("WebSocket command channel disconnected");
462                                    break;
463                                }
464                                Err(TryRecvError::Empty) => {}
465                            }
466                            
467                            // Check for incoming messages (with timeout)
468                            match socket.read() {
469                                Ok(Message::Text(txt)) => {
470                                    log::debug!("WebSocket received: {}", &txt[..txt.len().min(100)]);
471                                    if let Ok(server_msg) = serde_json::from_str::<ServerMessage>(&txt) {
472                                        let event = match server_msg {
473                                            ServerMessage::Joined { room, peer_count, initial_sync } => {
474                                                let data = initial_sync.and_then(|s| super::base64_decode(&s));
475                                                SyncEvent::JoinedRoom { room, peer_count, initial_sync: data }
476                                            }
477                                            ServerMessage::PeerJoined { peer_id } => SyncEvent::PeerJoined { peer_id },
478                                            ServerMessage::PeerLeft { peer_id } => SyncEvent::PeerLeft { peer_id },
479                                            ServerMessage::Sync { from, data } => {
480                                                if let Some(bytes) = super::base64_decode(&data) {
481                                                    SyncEvent::SyncReceived { from, data: bytes }
482                                                } else {
483                                                    continue;
484                                                }
485                                            }
486                                            ServerMessage::Awareness { from, peer_id, state } => {
487                                                SyncEvent::AwarenessReceived { from, peer_id, state }
488                                            }
489                                            ServerMessage::Error { message } => SyncEvent::Error { message },
490                                        };
491                                        let _ = event_tx.send(event);
492                                    } else {
493                                        log::warn!("Failed to parse server message: {}", txt);
494                                    }
495                                }
496                                Ok(Message::Ping(data)) => {
497                                    // Respond to ping with pong
498                                    let _ = socket.send(Message::Pong(data));
499                                }
500                                Ok(Message::Close(_)) => {
501                                    log::info!("WebSocket received close frame");
502                                    break;
503                                }
504                                Ok(_) => {} // Ignore binary, pong
505                                Err(tungstenite::Error::Io(ref e)) 
506                                    if e.kind() == std::io::ErrorKind::WouldBlock 
507                                    || e.kind() == std::io::ErrorKind::TimedOut => {
508                                    // Timeout on read, continue loop
509                                    continue;
510                                }
511                                Err(e) => {
512                                    log::error!("WebSocket read error: {}", e);
513                                    break;
514                                }
515                            }
516                        }
517                        
518                        log::info!("WebSocket thread exiting");
519                        let _ = event_tx.send(SyncEvent::Disconnected);
520                    }
521                    Err(e) => {
522                        log::error!("WebSocket connection failed: {}", e);
523                        let _ = event_tx.send(SyncEvent::Error {
524                            message: format!("Connection failed: {}", e),
525                        });
526                    }
527                }
528            });
529            
530            self.cmd_tx = Some(cmd_tx);
531            self.event_rx = Some(event_rx);
532            self._thread = Some(handle);
533            
534            Ok(())
535        }
536
537        /// Disconnect from the server.
538        pub fn disconnect(&mut self) {
539            if let Some(tx) = self.cmd_tx.take() {
540                let _ = tx.send(WsCommand::Close);
541            }
542            self.event_rx = None;
543            self._thread = None;
544            self.state = ConnectionState::Disconnected;
545        }
546
547        /// Send a text message.
548        pub fn send(&self, msg: &str) -> Result<(), String> {
549            if let Some(ref tx) = self.cmd_tx {
550                tx.send(WsCommand::Send(msg.to_string()))
551                    .map_err(|e| format!("Send failed: {}", e))
552            } else {
553                Err("Not connected".to_string())
554            }
555        }
556
557        /// Poll for pending events (non-blocking).
558        pub fn poll_events(&mut self) -> Vec<SyncEvent> {
559            // Drain events from channel
560            if let Some(ref rx) = self.event_rx {
561                while let Ok(event) = rx.try_recv() {
562                    // Update state based on event
563                    match &event {
564                        SyncEvent::Connected => self.state = ConnectionState::Connected,
565                        SyncEvent::Disconnected => self.state = ConnectionState::Disconnected,
566                        SyncEvent::Error { .. } => self.state = ConnectionState::Error,
567                        _ => {}
568                    }
569                    self.events.push(event);
570                }
571            }
572            
573            std::mem::take(&mut self.events)
574        }
575
576        /// Get current connection state.
577        pub fn state(&self) -> ConnectionState {
578            self.state
579        }
580
581        /// Check if connected.
582        pub fn is_connected(&self) -> bool {
583            self.state == ConnectionState::Connected
584        }
585    }
586
587    impl Default for NativeWebSocket {
588        fn default() -> Self {
589            Self::new()
590        }
591    }
592    
593    impl Drop for NativeWebSocket {
594        fn drop(&mut self) {
595            self.disconnect();
596        }
597    }
598}
599
600#[cfg(not(target_arch = "wasm32"))]
601pub use native_client::NativeWebSocket;
602
603// ============================================================================
604// Platform type alias
605// ============================================================================
606
607/// Platform-specific WebSocket client type.
608#[cfg(target_arch = "wasm32")]
609pub type PlatformWebSocket = WasmWebSocket;
610
611#[cfg(not(target_arch = "wasm32"))]
612pub type PlatformWebSocket = NativeWebSocket;
613
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding and then decoding returns the original bytes.
    #[test]
    fn test_base64_roundtrip() {
        let original: &[u8] = b"Hello, World!";
        let restored = base64_decode(&base64_encode(original)).unwrap();
        assert_eq!(original.to_vec(), restored);
    }

    /// Empty input survives a round trip as empty output.
    #[test]
    fn test_base64_empty() {
        let original: &[u8] = b"";
        let restored = base64_decode(&base64_encode(original)).unwrap();
        assert_eq!(original.to_vec(), restored);
    }

    /// Output is padded with `=` up to a multiple of four characters.
    #[test]
    fn test_base64_padding() {
        let cases: [(&[u8], &str); 3] = [
            // 1 byte -> 2 chars + 2 padding
            (b"a", "YQ=="),
            // 2 bytes -> 3 chars + 1 padding
            (b"ab", "YWI="),
            // 3 bytes -> 4 chars, no padding
            (b"abc", "YWJj"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(base64_encode(input), expected);
        }
    }

    /// A `Join` message serializes with its tag and room name.
    #[test]
    fn test_client_message_serialize() {
        let join = ClientMessage::Join { room: "test-room".to_string() };
        let json = serde_json::to_string(&join).unwrap();
        assert!(json.contains("join"));
        assert!(json.contains("test-room"));
    }

    /// A `joined` message parses even when `initial_sync` is absent.
    #[test]
    fn test_server_message_deserialize() {
        let parsed: ServerMessage =
            serde_json::from_str(r#"{"type":"joined","room":"test","peer_count":2}"#).unwrap();
        if let ServerMessage::Joined { room, peer_count, .. } = parsed {
            assert_eq!(room, "test");
            assert_eq!(peer_count, 2);
        } else {
            panic!("Wrong message type");
        }
    }
}
664}