Skip to main content

rift_wasm/
lib.rs

//! WebAssembly bindings for the Rift protocol.
//!
//! This module exposes a minimal API for:
//! - invite creation/inspection
//! - session bootstrap
//! - encrypted text encode/decode using protocol framing
//! - voice frame encode/decode for browser audio integration
//! - audio utilities (level metering, VAD)

10use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
11use aes_gcm::aead::Aead;
12use js_sys::{Date, Uint8Array};
13use rift_core::{
14    invite::{decode_invite, encode_invite, generate_invite, Invite},
15    Identity,
16};
17use rift_protocol::{
18    decode_frame, encode_frame, ChatMessage, CodecId, EncryptedPayload, ProtocolVersion,
19    RiftFrameHeader, RiftPayload, SessionId, StreamKind, VoicePacket,
20};
21use serde::Serialize;
22use thiserror::Error;
23use wasm_bindgen::prelude::*;
24
25#[derive(Debug, Error)]
26enum WasmError {
27    #[error("invalid invite: {0}")]
28    InvalidInvite(String),
29    #[error("frame decode failed: {0}")]
30    FrameDecode(String),
31    #[error("cipher error")]
32    Cipher,
33    #[error("payload decode failed: {0}")]
34    PayloadDecode(String),
35}
36
37impl From<WasmError> for JsValue {
38    fn from(err: WasmError) -> Self {
39        JsValue::from_str(&err.to_string())
40    }
41}
42
43#[wasm_bindgen]
44pub struct WasmClient {
45    /// Ephemeral identity for the session.
46    identity: Identity,
47    /// Session identifier derived from the invite.
48    session: SessionId,
49    /// Symmetric channel key for AES-GCM.
50    channel_key: [u8; 32],
51    /// Local sequence counter for frames.
52    seq: u32,
53}
54
55#[wasm_bindgen]
56pub struct InviteInfo {
57    /// Channel name embedded in the invite.
58    channel_name: String,
59    /// Whether a password was set.
60    has_password: bool,
61    /// Protocol version.
62    version: u8,
63    /// Invite creation timestamp.
64    created_at: u64,
65}
66
67#[wasm_bindgen]
68impl InviteInfo {
69    #[wasm_bindgen(getter)]
70    pub fn channel_name(&self) -> String {
71        self.channel_name.clone()
72    }
73
74    #[wasm_bindgen(getter)]
75    pub fn has_password(&self) -> bool {
76        self.has_password
77    }
78
79    #[wasm_bindgen(getter)]
80    pub fn version(&self) -> u8 {
81        self.version
82    }
83
84    #[wasm_bindgen(getter)]
85    pub fn created_at(&self) -> u64 {
86        self.created_at
87    }
88}
89
90#[derive(Serialize)]
91struct DecodedTextMessage {
92    from: String,
93    timestamp: u64,
94    text: String,
95}
96
97#[derive(Serialize)]
98struct DecodedVoiceFrame {
99    from: String,
100    timestamp: u64,
101    seq: u32,
102    codec: String,
103    payload: Vec<u8>,
104}
105
106/// Audio configuration for browser integration.
107#[wasm_bindgen]
108pub struct AudioConfig {
109    /// Sample rate in Hz (typically 48000 for Opus).
110    sample_rate: u32,
111    /// Number of channels (1 for mono, 2 for stereo).
112    channels: u8,
113    /// Frame size in samples per channel.
114    frame_size: u32,
115}
116
117#[wasm_bindgen]
118impl AudioConfig {
119    /// Create a new audio configuration.
120    #[wasm_bindgen(constructor)]
121    pub fn new(sample_rate: u32, channels: u8, frame_size: u32) -> Self {
122        Self {
123            sample_rate,
124            channels,
125            frame_size,
126        }
127    }
128
129    /// Create default config for Opus (48kHz, mono, 20ms frame).
130    #[wasm_bindgen]
131    pub fn opus_default() -> Self {
132        Self {
133            sample_rate: 48000,
134            channels: 1,
135            frame_size: 960, // 20ms at 48kHz
136        }
137    }
138
139    #[wasm_bindgen(getter)]
140    pub fn sample_rate(&self) -> u32 {
141        self.sample_rate
142    }
143
144    #[wasm_bindgen(getter)]
145    pub fn channels(&self) -> u8 {
146        self.channels
147    }
148
149    #[wasm_bindgen(getter)]
150    pub fn frame_size(&self) -> u32 {
151        self.frame_size
152    }
153
154    /// Calculate frame duration in milliseconds.
155    #[wasm_bindgen]
156    pub fn frame_duration_ms(&self) -> f64 {
157        (self.frame_size as f64 / self.sample_rate as f64) * 1000.0
158    }
159}
160
161#[wasm_bindgen]
162pub fn create_invite(channel_name: String, password: Option<String>) -> Result<String, JsValue> {
163    let invite = generate_invite(
164        &channel_name,
165        password.as_deref(),
166        Vec::new(),
167        Vec::new(),
168    );
169    Ok(encode_invite(&invite))
170}
171
172#[wasm_bindgen]
173pub fn inspect_invite(invite_url: String) -> Result<InviteInfo, JsValue> {
174    let invite = decode_invite(&invite_url)
175        .map_err(|err| WasmError::InvalidInvite(err.to_string()))?;
176    Ok(InviteInfo {
177        channel_name: invite.channel_name,
178        has_password: invite.password.is_some(),
179        version: invite.version,
180        created_at: invite.created_at,
181    })
182}
183
184#[wasm_bindgen]
185pub fn join_invite(invite_url: String) -> Result<WasmClient, JsValue> {
186    let invite = decode_invite(&invite_url)
187        .map_err(|err| WasmError::InvalidInvite(err.to_string()))?;
188    Ok(WasmClient::from_invite(invite))
189}
190
191#[wasm_bindgen]
192impl WasmClient {
193    /// Construct a client from an invite.
194    fn from_invite(invite: Invite) -> Self {
195        let identity = Identity::generate();
196        let session = SessionId::from_channel(&invite.channel_name, invite.password.as_deref());
197        Self {
198            identity,
199            session,
200            channel_key: invite.channel_key,
201            seq: 0,
202        }
203    }
204
205    /// Return this client's peer id as hex.
206    #[wasm_bindgen(getter)]
207    pub fn peer_id(&self) -> String {
208        self.identity.peer_id.to_hex()
209    }
210
211    /// Return the session id as hex.
212    #[wasm_bindgen(getter)]
213    pub fn session_id(&self) -> String {
214        self.session.to_hex()
215    }
216
217    /// Encode a text message into an encrypted Rift frame.
218    #[wasm_bindgen]
219    pub fn encode_text(&mut self, text: String) -> Result<Uint8Array, JsValue> {
220        let timestamp = now_ms();
221        let message = ChatMessage::new(self.identity.peer_id, timestamp, text);
222        let payload = RiftPayload::Text(message);
223        let encrypted = self.encrypt_payload(&payload)?;
224        let header = RiftFrameHeader {
225            version: ProtocolVersion::V2,
226            stream: StreamKind::Text,
227            flags: 0,
228            seq: self.seq,
229            timestamp,
230            source: self.identity.peer_id,
231            session: self.session,
232        };
233        self.seq = self.seq.wrapping_add(1);
234        let frame = encode_frame(&header, &encrypted);
235        Ok(Uint8Array::from(frame.as_slice()))
236    }
237
238    /// Decode an encrypted Rift frame into a JSON-compatible JS object.
239    #[wasm_bindgen]
240    pub fn decode_text(&self, bytes: Uint8Array) -> Result<JsValue, JsValue> {
241        let data = bytes.to_vec();
242        let (_, payload) =
243            decode_frame(&data).map_err(|err| WasmError::FrameDecode(err.to_string()))?;
244        let decrypted = self.decrypt_payload(&payload)?;
245        let RiftPayload::Text(message) = decrypted else {
246            return Err(WasmError::PayloadDecode("not a text payload".to_string()).into());
247        };
248        let decoded = DecodedTextMessage {
249            from: message.from.to_hex(),
250            timestamp: message.timestamp,
251            text: message.text,
252        };
253        serde_wasm_bindgen::to_value(&decoded).map_err(|err| err.into())
254    }
255
256    /// Encrypt a payload using the channel key.
257    fn encrypt_payload(&self, payload: &RiftPayload) -> Result<RiftPayload, JsValue> {
258        let serialized = bincode::serialize(payload)
259            .map_err(|err| WasmError::PayloadDecode(err.to_string()))?;
260        let cipher = Aes256Gcm::new_from_slice(&self.channel_key)
261            .map_err(|_| WasmError::Cipher)?;
262        let nonce_bytes = random_nonce();
263        let nonce = Nonce::from_slice(&nonce_bytes);
264        let ciphertext = cipher
265            .encrypt(nonce, serialized.as_ref())
266            .map_err(|_| WasmError::Cipher)?;
267        Ok(RiftPayload::Encrypted(EncryptedPayload {
268            nonce: nonce_bytes,
269            ciphertext,
270        }))
271    }
272
273    /// Decrypt a payload using the channel key.
274    fn decrypt_payload(&self, payload: &RiftPayload) -> Result<RiftPayload, JsValue> {
275        let RiftPayload::Encrypted(encrypted) = payload else {
276            return Err(WasmError::PayloadDecode("missing encrypted payload".to_string()).into());
277        };
278        let cipher = Aes256Gcm::new_from_slice(&self.channel_key)
279            .map_err(|_| WasmError::Cipher)?;
280        let nonce = Nonce::from_slice(&encrypted.nonce);
281        let plaintext = cipher
282            .decrypt(nonce, encrypted.ciphertext.as_ref())
283            .map_err(|_| WasmError::Cipher)?;
284        let decoded: RiftPayload = bincode::deserialize(&plaintext)
285            .map_err(|err| WasmError::PayloadDecode(err.to_string()))?;
286        Ok(decoded)
287    }
288
289    /// Encode a voice frame into an encrypted Rift frame.
290    ///
291    /// The `opus_payload` should be pre-encoded Opus data from a browser
292    /// Opus encoder (e.g., via AudioWorklet + opus-wasm).
293    #[wasm_bindgen]
294    pub fn encode_voice(&mut self, opus_payload: Uint8Array) -> Result<Uint8Array, JsValue> {
295        let timestamp = now_ms();
296        let voice = VoicePacket {
297            codec_id: CodecId::Opus,
298            payload: opus_payload.to_vec(),
299        };
300        let payload = RiftPayload::Voice(voice);
301        let encrypted = self.encrypt_payload(&payload)?;
302        let header = RiftFrameHeader {
303            version: ProtocolVersion::V2,
304            stream: StreamKind::Voice,
305            flags: 0,
306            seq: self.seq,
307            timestamp,
308            source: self.identity.peer_id,
309            session: self.session,
310        };
311        self.seq = self.seq.wrapping_add(1);
312        let frame = encode_frame(&header, &encrypted);
313        Ok(Uint8Array::from(frame.as_slice()))
314    }
315
316    /// Encode a voice frame with raw PCM16 samples.
317    ///
318    /// Use this when you have raw Int16Array samples from Web Audio API.
319    /// The samples will be wrapped in a PCM16 codec frame.
320    #[wasm_bindgen]
321    pub fn encode_voice_pcm(&mut self, pcm_samples: Uint8Array) -> Result<Uint8Array, JsValue> {
322        let timestamp = now_ms();
323        let voice = VoicePacket {
324            codec_id: CodecId::PCM16,
325            payload: pcm_samples.to_vec(),
326        };
327        let payload = RiftPayload::Voice(voice);
328        let encrypted = self.encrypt_payload(&payload)?;
329        let header = RiftFrameHeader {
330            version: ProtocolVersion::V2,
331            stream: StreamKind::Voice,
332            flags: 0,
333            seq: self.seq,
334            timestamp,
335            source: self.identity.peer_id,
336            session: self.session,
337        };
338        self.seq = self.seq.wrapping_add(1);
339        let frame = encode_frame(&header, &encrypted);
340        Ok(Uint8Array::from(frame.as_slice()))
341    }
342
343    /// Decode an encrypted Rift voice frame.
344    ///
345    /// Returns a JS object with: from, timestamp, seq, codec, payload.
346    /// The payload is the encoded audio data (Opus or PCM16).
347    #[wasm_bindgen]
348    pub fn decode_voice(&self, bytes: Uint8Array) -> Result<JsValue, JsValue> {
349        let data = bytes.to_vec();
350        let (header, payload) =
351            decode_frame(&data).map_err(|err| WasmError::FrameDecode(err.to_string()))?;
352        let decrypted = self.decrypt_payload(&payload)?;
353        let RiftPayload::Voice(voice) = decrypted else {
354            return Err(WasmError::PayloadDecode("not a voice payload".to_string()).into());
355        };
356        let codec = match voice.codec_id {
357            CodecId::Opus => "opus".to_string(),
358            CodecId::PCM16 => "pcm16".to_string(),
359            CodecId::Experimental(id) => format!("experimental-{}", id),
360        };
361        let decoded = DecodedVoiceFrame {
362            from: header.source.to_hex(),
363            timestamp: header.timestamp,
364            seq: header.seq,
365            codec,
366            payload: voice.payload,
367        };
368        serde_wasm_bindgen::to_value(&decoded).map_err(|err| err.into())
369    }
370
371    /// Get the voice payload bytes from a decoded frame.
372    ///
373    /// This is a convenience method for extracting just the audio payload
374    /// without the metadata, useful for feeding directly to a decoder.
375    #[wasm_bindgen]
376    pub fn extract_voice_payload(&self, bytes: Uint8Array) -> Result<Uint8Array, JsValue> {
377        let data = bytes.to_vec();
378        let (_, payload) =
379            decode_frame(&data).map_err(|err| WasmError::FrameDecode(err.to_string()))?;
380        let decrypted = self.decrypt_payload(&payload)?;
381        let RiftPayload::Voice(voice) = decrypted else {
382            return Err(WasmError::PayloadDecode("not a voice payload".to_string()).into());
383        };
384        Ok(Uint8Array::from(voice.payload.as_slice()))
385    }
386
387    /// Get the current sequence number.
388    #[wasm_bindgen(getter)]
389    pub fn seq(&self) -> u32 {
390        self.seq
391    }
392}
393
394/// Current time in milliseconds (JS Date).
395fn now_ms() -> u64 {
396    Date::now() as u64
397}
398
399/// Generate a random AES-GCM nonce.
400fn random_nonce() -> [u8; 12] {
401    let mut nonce = [0u8; 12];
402    getrandom::getrandom(&mut nonce).expect("random nonce");
403    nonce
404}
405
// ============================================================
// Audio Utility Functions for Browser Integration
// ============================================================

410/// Calculate the RMS audio level from PCM16 samples.
411///
412/// Takes a Uint8Array of little-endian i16 samples (as from Web Audio).
413/// Returns a normalized level from 0.0 (silent) to 1.0 (max).
414#[wasm_bindgen]
415pub fn audio_level(samples: &[i16]) -> f32 {
416    if samples.is_empty() {
417        return 0.0;
418    }
419    let mut sum = 0f64;
420    for s in samples {
421        let v = *s as f64;
422        sum += v * v;
423    }
424    let rms = (sum / samples.len() as f64).sqrt();
425    (rms / i16::MAX as f64) as f32
426}
427
428/// Calculate the RMS audio level from a Uint8Array of PCM16 bytes.
429///
430/// The bytes should be little-endian i16 samples.
431#[wasm_bindgen]
432pub fn audio_level_bytes(bytes: Uint8Array) -> f32 {
433    let data = bytes.to_vec();
434    if data.len() < 2 {
435        return 0.0;
436    }
437    let samples: Vec<i16> = data
438        .chunks_exact(2)
439        .map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]))
440        .collect();
441    audio_level(&samples)
442}
443
444/// Check if an audio frame is "active" (contains speech).
445///
446/// Simple energy-based VAD: returns true if average amplitude exceeds threshold.
447/// Threshold is tuned for typical voice activity.
448#[wasm_bindgen]
449pub fn is_voice_active(samples: &[i16]) -> bool {
450    if samples.is_empty() {
451        return false;
452    }
453    let mut sum = 0i64;
454    for s in samples {
455        sum += (*s as i64).abs();
456    }
457    let avg = sum / samples.len() as i64;
458    avg > 250
459}
460
461/// Check if an audio frame is "active" from PCM16 bytes.
462#[wasm_bindgen]
463pub fn is_voice_active_bytes(bytes: Uint8Array) -> bool {
464    let data = bytes.to_vec();
465    if data.len() < 2 {
466        return false;
467    }
468    let samples: Vec<i16> = data
469        .chunks_exact(2)
470        .map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]))
471        .collect();
472    is_voice_active(&samples)
473}
474
475/// Convert Float32Array audio samples to PCM16 bytes.
476///
477/// Useful for converting Web Audio API float samples to PCM16 format.
478/// Input should be normalized floats in range [-1.0, 1.0].
479#[wasm_bindgen]
480pub fn float32_to_pcm16(samples: &[f32]) -> Uint8Array {
481    let mut bytes = Vec::with_capacity(samples.len() * 2);
482    for s in samples {
483        let clamped = s.clamp(-1.0, 1.0);
484        let pcm = (clamped * i16::MAX as f32) as i16;
485        bytes.extend_from_slice(&pcm.to_le_bytes());
486    }
487    Uint8Array::from(bytes.as_slice())
488}
489
490/// Convert PCM16 bytes to Float32 samples.
491///
492/// Useful for feeding decoded audio to Web Audio API.
493/// Returns normalized floats in range [-1.0, 1.0].
494#[wasm_bindgen]
495pub fn pcm16_to_float32(bytes: Uint8Array) -> js_sys::Float32Array {
496    let data = bytes.to_vec();
497    let samples: Vec<f32> = data
498        .chunks_exact(2)
499        .map(|chunk| {
500            let pcm = i16::from_le_bytes([chunk[0], chunk[1]]);
501            pcm as f32 / i16::MAX as f32
502        })
503        .collect();
504    js_sys::Float32Array::from(samples.as_slice())
505}
506
507/// Compute audio level in decibels (dB) from RMS level.
508///
509/// Returns dB relative to full scale (0 dB = max amplitude).
510/// Silent audio returns -100.0 dB.
511#[wasm_bindgen]
512pub fn level_to_db(level: f32) -> f32 {
513    if level <= 0.0 {
514        return -100.0;
515    }
516    20.0 * level.log10()
517}
518
519/// Apply a simple gain to PCM16 samples.
520///
521/// Gain of 1.0 = no change, 2.0 = double amplitude, 0.5 = half amplitude.
522/// Values are clamped to prevent clipping.
523#[wasm_bindgen]
524pub fn apply_gain(bytes: Uint8Array, gain: f32) -> Uint8Array {
525    let data = bytes.to_vec();
526    let mut out = Vec::with_capacity(data.len());
527    for chunk in data.chunks_exact(2) {
528        let pcm = i16::from_le_bytes([chunk[0], chunk[1]]);
529        let amplified = (pcm as f32 * gain).clamp(i16::MIN as f32, i16::MAX as f32) as i16;
530        out.extend_from_slice(&amplified.to_le_bytes());
531    }
532    Uint8Array::from(out.as_slice())
533}
534
535/// Mix two audio frames together.
536///
537/// Both frames must be the same length. Result is averaged to prevent clipping.
538#[wasm_bindgen]
539pub fn mix_frames(frame_a: Uint8Array, frame_b: Uint8Array) -> Result<Uint8Array, JsValue> {
540    let a = frame_a.to_vec();
541    let b = frame_b.to_vec();
542    if a.len() != b.len() {
543        return Err(JsValue::from_str("frames must be same length"));
544    }
545    let mut out = Vec::with_capacity(a.len());
546    for (chunk_a, chunk_b) in a.chunks_exact(2).zip(b.chunks_exact(2)) {
547        let pcm_a = i16::from_le_bytes([chunk_a[0], chunk_a[1]]) as i32;
548        let pcm_b = i16::from_le_bytes([chunk_b[0], chunk_b[1]]) as i32;
549        let mixed = ((pcm_a + pcm_b) / 2).clamp(i16::MIN as i32, i16::MAX as i32) as i16;
550        out.extend_from_slice(&mixed.to_le_bytes());
551    }
552    Ok(Uint8Array::from(out.as_slice()))
553}