Skip to main content

rift_web_chat/
lib.rs

1//! WebAssembly browser chat client for Rift over WebSocket relay.
2//!
3//! This crate provides a simple API for text-only chat in the browser,
4//! using encrypted Rift protocol frames over a WebSocket relay.
5//!
6//! # Example (JavaScript)
7//!
8//! ```js
9//! import init, { WebChat, create_invite } from 'rift-web-chat';
10//!
11//! await init();
12//!
13//! const invite = create_invite("my-room", null);
14//! const chat = new WebChat("ws://localhost:8787/ws", invite);
15//!
16//! chat.on_message((msg) => {
17//!     console.log(`${msg.from}: ${msg.text}`);
18//! });
19//!
20//! chat.on_connect(() => {
21//!     chat.send("Hello, world!");
22//! });
23//! ```
24
25mod callbacks;
26mod relay;
27
28use aes_gcm::{aead::Aead, Aes256Gcm, KeyInit, Nonce};
29use callbacks::{invoke_callback_event, ChatMessageEvent, ConnectionEvent, PeerEvent};
30use js_sys::Date;
31use relay::RelayEnvelope;
32use rift_core::{
33    invite::{decode_invite, encode_invite, generate_invite, Invite},
34    Identity,
35};
36use rift_protocol::{
37    decode_frame, encode_frame, ChatMessage, EncryptedPayload, ProtocolVersion, RiftFrameHeader,
38    RiftPayload, SessionId, StreamKind,
39};
40use std::cell::RefCell;
41use std::rc::Rc;
42use thiserror::Error;
43use wasm_bindgen::prelude::*;
44use web_sys::{CloseEvent, ErrorEvent, MessageEvent, WebSocket};
45
46#[derive(Debug, Error)]
47#[allow(dead_code)]
48enum ChatError {
49    #[error("invalid invite: {0}")]
50    InvalidInvite(String),
51    #[error("websocket error: {0}")]
52    WebSocket(String),
53    #[error("not connected")]
54    NotConnected,
55    #[error("frame decode error: {0}")]
56    FrameDecode(String),
57    #[error("encryption error")]
58    Cipher,
59    #[error("payload decode error: {0}")]
60    PayloadDecode(String),
61}
62
63impl From<ChatError> for JsValue {
64    fn from(err: ChatError) -> Self {
65        JsValue::from_str(&err.to_string())
66    }
67}
68
69/// Internal state for the WebChat client.
70struct ChatState {
71    identity: Identity,
72    session: SessionId,
73    channel_key: [u8; 32],
74    room: String,
75    seq: u32,
76    on_message: Option<js_sys::Function>,
77    on_peer_event: Option<js_sys::Function>,
78    on_connect: Option<js_sys::Function>,
79    on_disconnect: Option<js_sys::Function>,
80    on_error: Option<js_sys::Function>,
81}
82
83impl ChatState {
84    fn from_invite(invite: Invite) -> Self {
85        let identity = Identity::generate();
86        let session = SessionId::from_channel(&invite.channel_name, invite.password.as_deref());
87        Self {
88            identity,
89            session,
90            channel_key: invite.channel_key,
91            room: invite.channel_name,
92            seq: 0,
93        on_message: None,
94            on_peer_event: None,
95            on_connect: None,
96            on_disconnect: None,
97            on_error: None,
98        }
99    }
100
101    fn encrypt_payload(&self, payload: &RiftPayload) -> Result<RiftPayload, ChatError> {
102        let serialized = bincode::serialize(payload).map_err(|e| ChatError::PayloadDecode(e.to_string()))?;
103        let cipher = Aes256Gcm::new_from_slice(&self.channel_key).map_err(|_| ChatError::Cipher)?;
104        let nonce_bytes = random_nonce();
105        let nonce = Nonce::from_slice(&nonce_bytes);
106        let ciphertext = cipher
107            .encrypt(nonce, serialized.as_ref())
108            .map_err(|_| ChatError::Cipher)?;
109        Ok(RiftPayload::Encrypted(EncryptedPayload {
110            nonce: nonce_bytes,
111            ciphertext,
112        }))
113    }
114
115    fn decrypt_payload(&self, payload: &RiftPayload) -> Result<RiftPayload, ChatError> {
116        let RiftPayload::Encrypted(encrypted) = payload else {
117            return Err(ChatError::PayloadDecode("expected encrypted payload".into()));
118        };
119        let cipher = Aes256Gcm::new_from_slice(&self.channel_key).map_err(|_| ChatError::Cipher)?;
120        let nonce = Nonce::from_slice(&encrypted.nonce);
121        let plaintext = cipher
122            .decrypt(nonce, encrypted.ciphertext.as_ref())
123            .map_err(|_| ChatError::Cipher)?;
124        bincode::deserialize(&plaintext).map_err(|e| ChatError::PayloadDecode(e.to_string()))
125    }
126
127    fn encode_text(&mut self, text: &str) -> Result<Vec<u8>, ChatError> {
128        let timestamp = now_ms();
129        let message = ChatMessage::new(self.identity.peer_id, timestamp, text.to_string());
130        let payload = RiftPayload::Text(message);
131        let encrypted = self.encrypt_payload(&payload)?;
132        let header = RiftFrameHeader {
133            version: ProtocolVersion::V2,
134            stream: StreamKind::Text,
135            flags: 0,
136            seq: self.seq,
137            timestamp,
138            source: self.identity.peer_id,
139            session: self.session,
140        };
141        self.seq = self.seq.wrapping_add(1);
142        Ok(encode_frame(&header, &encrypted))
143    }
144
145    fn decode_text(&self, data: &[u8]) -> Result<ChatMessageEvent, ChatError> {
146        let (_, payload) = decode_frame(data).map_err(|e| ChatError::FrameDecode(e.to_string()))?;
147        let decrypted = self.decrypt_payload(&payload)?;
148        let RiftPayload::Text(message) = decrypted else {
149            return Err(ChatError::PayloadDecode("not a text message".into()));
150        };
151        Ok(ChatMessageEvent {
152            from: message.from.to_hex(),
153            timestamp: message.timestamp,
154            text: message.text,
155        })
156    }
157}
158
159/// WebSocket-based chat client for the browser.
160///
161/// Connects to a `rift-ws-relay` server and provides encrypted text chat.
162#[wasm_bindgen]
163pub struct WebChat {
164    socket: WebSocket,
165    state: Rc<RefCell<ChatState>>,
166    // Keep closures alive
167    _on_open: Closure<dyn FnMut()>,
168    _on_message: Closure<dyn FnMut(MessageEvent)>,
169    _on_error: Closure<dyn FnMut(ErrorEvent)>,
170    _on_close: Closure<dyn FnMut(CloseEvent)>,
171}
172
173#[wasm_bindgen]
174impl WebChat {
175    /// Create a new WebChat client and connect to the relay.
176    ///
177    /// # Arguments
178    /// * `relay_url` - WebSocket URL of the relay (e.g., "ws://localhost:8787/ws")
179    /// * `invite_url` - Rift invite URL (e.g., "rift://z/...")
180    #[wasm_bindgen(constructor)]
181    pub fn new(relay_url: &str, invite_url: &str) -> Result<WebChat, JsValue> {
182        let invite =
183            decode_invite(invite_url).map_err(|e| ChatError::InvalidInvite(e.to_string()))?;
184
185        let state = Rc::new(RefCell::new(ChatState::from_invite(invite)));
186        let socket = WebSocket::new(relay_url).map_err(|e| ChatError::WebSocket(format!("{:?}", e)))?;
187        socket.set_binary_type(web_sys::BinaryType::Arraybuffer);
188
189        // Clone state for closures
190        let state_open = state.clone();
191        let socket_open = socket.clone();
192        let on_open = Closure::new(move || {
193            let s = state_open.borrow();
194            // Send join message
195            let join = RelayEnvelope::join(&s.session.to_hex(), &s.identity.peer_id.to_hex());
196            if let Ok(json) = join.to_json() {
197                let _ = socket_open.send_with_str(&json);
198            }
199            // Notify connect callback
200            if let Some(cb) = &s.on_connect {
201                invoke_callback_event(cb, &ConnectionEvent {
202                    state: "connected".into(),
203                    error: None,
204                });
205            }
206        });
207
208        let state_msg = state.clone();
209        let on_message = Closure::new(move |event: MessageEvent| {
210            if let Some(text) = event.data().as_string() {
211                handle_relay_message(&state_msg, &text);
212            }
213        });
214
215        let state_err = state.clone();
216        let on_error = Closure::new(move |event: ErrorEvent| {
217            let s = state_err.borrow();
218            if let Some(cb) = &s.on_error {
219                invoke_callback_event(cb, &ConnectionEvent {
220                    state: "error".into(),
221                    error: Some(event.message()),
222                });
223            }
224        });
225
226        let state_close = state.clone();
227        let on_close = Closure::new(move |_event: CloseEvent| {
228            let s = state_close.borrow();
229            if let Some(cb) = &s.on_disconnect {
230                invoke_callback_event(cb, &ConnectionEvent {
231                    state: "disconnected".into(),
232                    error: None,
233                });
234            }
235        });
236
237        socket.set_onopen(Some(on_open.as_ref().unchecked_ref()));
238        socket.set_onmessage(Some(on_message.as_ref().unchecked_ref()));
239        socket.set_onerror(Some(on_error.as_ref().unchecked_ref()));
240        socket.set_onclose(Some(on_close.as_ref().unchecked_ref()));
241
242        Ok(WebChat {
243            socket,
244            state,
245            _on_open: on_open,
246            _on_message: on_message,
247            _on_error: on_error,
248            _on_close: on_close,
249        })
250    }
251
252    /// Send a text message to the chat room.
253    pub fn send(&mut self, text: &str) -> Result<(), JsValue> {
254        let frame = {
255            let mut state = self.state.borrow_mut();
256            state.encode_text(text)?
257        };
258
259        let state = self.state.borrow();
260        let envelope = RelayEnvelope::data(
261            &state.session.to_hex(),
262            &state.identity.peer_id.to_hex(),
263            &frame,
264        );
265        let json = envelope.to_json().map_err(|e| JsValue::from_str(&e.to_string()))?;
266
267        self.socket
268            .send_with_str(&json)
269            .map_err(|e| ChatError::WebSocket(format!("{:?}", e)))?;
270
271        Ok(())
272    }
273
274    /// Set callback for incoming chat messages.
275    ///
276    /// Callback receives: `{ from: string, timestamp: number, text: string }`
277    pub fn on_message(&mut self, callback: js_sys::Function) {
278        self.state.borrow_mut().on_message = Some(callback);
279    }
280
281    /// Set callback for peer join/leave events.
282    ///
283    /// Callback receives: `{ peer_id: string, event: "join" | "leave" }`
284    pub fn on_peer_event(&mut self, callback: js_sys::Function) {
285        self.state.borrow_mut().on_peer_event = Some(callback);
286    }
287
288    /// Set callback for successful connection.
289    ///
290    /// Callback receives: `{ state: "connected", error: null }`
291    pub fn on_connect(&mut self, callback: js_sys::Function) {
292        self.state.borrow_mut().on_connect = Some(callback);
293    }
294
295    /// Set callback for disconnection.
296    ///
297    /// Callback receives: `{ state: "disconnected", error: null }`
298    pub fn on_disconnect(&mut self, callback: js_sys::Function) {
299        self.state.borrow_mut().on_disconnect = Some(callback);
300    }
301
302    /// Set callback for errors.
303    ///
304    /// Callback receives: `{ state: "error", error: string }`
305    pub fn on_error(&mut self, callback: js_sys::Function) {
306        self.state.borrow_mut().on_error = Some(callback);
307    }
308
309    /// Get this client's peer ID (hex-encoded).
310    #[wasm_bindgen(getter)]
311    pub fn peer_id(&self) -> String {
312        self.state.borrow().identity.peer_id.to_hex()
313    }
314
315    /// Get the room/channel name.
316    #[wasm_bindgen(getter)]
317    pub fn room(&self) -> String {
318        self.state.borrow().room.clone()
319    }
320
321    /// Get the session ID (hex-encoded).
322    #[wasm_bindgen(getter)]
323    pub fn session_id(&self) -> String {
324        self.state.borrow().session.to_hex()
325    }
326
327    /// Disconnect from the relay.
328    pub fn disconnect(&self) {
329        let _ = self.socket.close();
330    }
331
332    /// Check if the WebSocket is open.
333    #[wasm_bindgen(getter)]
334    pub fn is_connected(&self) -> bool {
335        self.socket.ready_state() == WebSocket::OPEN
336    }
337}
338
339fn handle_relay_message(state: &Rc<RefCell<ChatState>>, text: &str) {
340    let Ok(envelope) = RelayEnvelope::from_json(text) else {
341        return;
342    };
343
344    let s = state.borrow();
345    let my_peer_id = s.identity.peer_id.to_hex();
346
347    match &envelope {
348        RelayEnvelope::Data { peer_id, .. } => {
349            // Skip our own messages
350            if peer_id == &my_peer_id {
351                return;
352            }
353
354            if let Some(data) = envelope.decode_data() {
355                if let Ok(msg) = s.decode_text(&data) {
356                    if let Some(cb) = &s.on_message {
357                        invoke_callback_event(cb, &msg);
358                    }
359                }
360            }
361        }
362        RelayEnvelope::Status { peer_id, status, .. } => {
363            if let Some(cb) = &s.on_peer_event {
364                invoke_callback_event(cb, &PeerEvent {
365                    peer_id: peer_id.clone(),
366                    event: status.clone(),
367                });
368            }
369        }
370        RelayEnvelope::Join { peer_id, .. } => {
371            // Another peer joined
372            if peer_id != &my_peer_id {
373                if let Some(cb) = &s.on_peer_event {
374                    invoke_callback_event(cb, &PeerEvent {
375                        peer_id: peer_id.clone(),
376                        event: "join".into(),
377                    });
378                }
379            }
380        }
381    }
382}
383
384/// Generate a random 12-byte nonce for AES-GCM.
385fn random_nonce() -> [u8; 12] {
386    let mut nonce = [0u8; 12];
387    getrandom::getrandom(&mut nonce).expect("random nonce");
388    nonce
389}
390
391/// Current time in milliseconds.
392fn now_ms() -> u64 {
393    Date::now() as u64
394}
395
396// ============================================================
397// Standalone utility functions
398// ============================================================
399
400/// Create a new invite URL for a chat room.
401#[wasm_bindgen]
402pub fn create_invite(channel_name: &str, password: Option<String>) -> String {
403    let invite = generate_invite(channel_name, password.as_deref(), Vec::new(), Vec::new());
404    encode_invite(&invite)
405}
406
407/// Inspect an invite URL without joining.
408///
409/// Returns a JS object with: `{ channel_name, has_password, version, created_at }`
410#[wasm_bindgen]
411pub fn inspect_invite(invite_url: &str) -> Result<JsValue, JsValue> {
412    let invite =
413        decode_invite(invite_url).map_err(|e| JsValue::from_str(&format!("Invalid invite: {}", e)))?;
414
415    #[derive(serde::Serialize)]
416    struct InviteInfo {
417        channel_name: String,
418        has_password: bool,
419        version: u8,
420        created_at: u64,
421    }
422
423    let info = InviteInfo {
424        channel_name: invite.channel_name,
425        has_password: invite.password.is_some(),
426        version: invite.version,
427        created_at: invite.created_at,
428    };
429
430    serde_wasm_bindgen::to_value(&info).map_err(|e| JsValue::from_str(&e.to_string()))
431}
432
433/// Generate a new random identity and return the peer ID (hex-encoded).
434#[wasm_bindgen]
435pub fn generate_peer_id() -> String {
436    Identity::generate().peer_id.to_hex()
437}