Skip to main content

wscall_protocol/
lib.rs

//! Shared protocol definitions for WSCALL.
//!
//! This crate contains the transport envelope, frame codec, encryption modes,
//! and inline attachment model used by both the server and client crates.

6use aes_gcm::{Aes256Gcm, KeyInit as AesKeyInit, Nonce as AesNonce, aead::Aead as AesAead};
7use base64::Engine;
8use base64::engine::general_purpose::STANDARD as BASE64;
9use chacha20poly1305::{ChaCha20Poly1305, Nonce};
10use getrandom::getrandom;
11use serde::{Deserialize, Serialize};
12use serde_json::{Value, json};
13use thiserror::Error;
14
/// Nonce length, in bytes, prefixed to every AES256-GCM payload.
const AES256_NONCE_LEN: usize = 12;
/// Nonce length, in bytes, prefixed to every ChaCha20-Poly1305 payload.
const CHACHA20_NONCE_LEN: usize = 12;
/// Upper bound on a complete frame (10 MiB), length prefix included.
const MAX_FRAME_BYTES: usize = 10 << 20;
/// Largest payload that still fits in `MAX_FRAME_BYTES` once the fixed prefix is
/// subtracted: 4-byte big-endian length, 1 message-type byte, 1 encryption byte.
const MAX_PAYLOAD_BYTES: usize = MAX_FRAME_BYTES - (4 + 1 + 1);
19
20/// Distinguishes API messages from event messages inside a WSCALL frame.
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
22#[repr(u8)]
23pub enum MessageType {
24    /// API request or API response.
25    Api = 0x00,
26    /// Event emit or event acknowledgement.
27    Event = 0x01,
28}
29
30impl TryFrom<u8> for MessageType {
31    type Error = ProtocolError;
32
33    fn try_from(value: u8) -> Result<Self, Self::Error> {
34        match value {
35            0x00 => Ok(Self::Api),
36            0x01 => Ok(Self::Event),
37            _ => Err(ProtocolError::UnknownMessageType(value)),
38        }
39    }
40}
41
42/// Selects how the payload section of a frame is encoded.
43#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
44#[repr(u8)]
45pub enum EncryptionKind {
46    /// Unencrypted JSON payload.
47    None = 0x00,
48    /// ChaCha20-Poly1305 encrypted payload.
49    ChaCha20 = 0x01,
50    /// AES256-GCM encrypted payload.
51    Aes256 = 0x02,
52}
53
54impl TryFrom<u8> for EncryptionKind {
55    type Error = ProtocolError;
56
57    fn try_from(value: u8) -> Result<Self, Self::Error> {
58        match value {
59            0x00 => Ok(Self::None),
60            0x01 => Ok(Self::ChaCha20),
61            0x02 => Ok(Self::Aes256),
62            _ => Err(ProtocolError::UnknownEncryption(value)),
63        }
64    }
65}
66
67/// Inline attachment carried alongside JSON params or event data.
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct FileAttachment {
70    /// Attachment identifier referenced from JSON using `{ "$file": "..." }`.
71    pub id: String,
72    /// Original file name.
73    pub name: String,
74    /// MIME type supplied by the sender.
75    pub content_type: String,
76    /// Content transfer encoding. Current implementation uses Base64.
77    pub encoding: String,
78    /// Encoded attachment payload.
79    pub data: String,
80    /// Original byte length before encoding.
81    pub size: usize,
82}
83
84impl FileAttachment {
85    /// Builds an inline text attachment and encodes it as Base64.
86    pub fn inline_text(
87        id: impl Into<String>,
88        name: impl Into<String>,
89        content_type: impl Into<String>,
90        text: impl AsRef<str>,
91    ) -> Self {
92        Self::inline_bytes(id, name, content_type, text.as_ref().as_bytes().to_vec())
93    }
94
95    /// Builds an inline binary attachment and encodes it as Base64.
96    pub fn inline_bytes(
97        id: impl Into<String>,
98        name: impl Into<String>,
99        content_type: impl Into<String>,
100        bytes: Vec<u8>,
101    ) -> Self {
102        let size = bytes.len();
103        Self {
104            id: id.into(),
105            name: name.into(),
106            content_type: content_type.into(),
107            encoding: "base64".to_string(),
108            data: BASE64.encode(bytes),
109            size,
110        }
111    }
112
113    /// Decodes the attachment payload back into raw bytes.
114    pub fn decode_bytes(&self) -> Result<Vec<u8>, ProtocolError> {
115        BASE64
116            .decode(self.data.as_bytes())
117            .map_err(|source| ProtocolError::InvalidAttachmentEncoding(source.to_string()))
118    }
119
120    /// Returns a JSON reference object that points to an attachment by id.
121    pub fn param_ref(id: impl Into<String>) -> Value {
122        json!({ "$file": id.into() })
123    }
124}
125
126/// Standard error payload embedded in API responses and event acknowledgements.
127#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct ErrorPayload {
129    pub code: String,
130    pub message: String,
131    pub status: u16,
132    #[serde(skip_serializing_if = "Option::is_none")]
133    pub details: Option<Value>,
134}
135
136/// JSON-level message body transported inside a WSCALL frame.
137#[derive(Debug, Clone, Serialize, Deserialize)]
138#[serde(tag = "kind", rename_all = "snake_case")]
139pub enum PacketBody {
140    /// Client-to-server API request.
141    ApiRequest {
142        request_id: String,
143        route: String,
144        params: Value,
145        attachments: Vec<FileAttachment>,
146        metadata: Value,
147    },
148    /// Server-to-client API response.
149    ApiResponse {
150        request_id: String,
151        ok: bool,
152        status: u16,
153        data: Value,
154        #[serde(skip_serializing_if = "Option::is_none")]
155        error: Option<ErrorPayload>,
156        metadata: Value,
157    },
158    /// Event emission in either direction.
159    EventEmit {
160        event_id: String,
161        name: String,
162        data: Value,
163        attachments: Vec<FileAttachment>,
164        metadata: Value,
165        expect_ack: bool,
166    },
167    /// Acknowledgement for an emitted event.
168    EventAck {
169        event_id: String,
170        ok: bool,
171        receipt: Value,
172        #[serde(skip_serializing_if = "Option::is_none")]
173        error: Option<ErrorPayload>,
174    },
175}
176
177impl PacketBody {
178    pub fn message_type(&self) -> MessageType {
179        match self {
180            Self::ApiRequest { .. } | Self::ApiResponse { .. } => MessageType::Api,
181            Self::EventEmit { .. } | Self::EventAck { .. } => MessageType::Event,
182        }
183    }
184}
185
186/// Full transport envelope before frame encoding.
187#[derive(Debug, Clone, Serialize, Deserialize)]
188pub struct PacketEnvelope {
189    /// Message category declared in the frame header.
190    pub message_type: MessageType,
191    /// Encryption mode declared in the frame header.
192    pub encryption: EncryptionKind,
193    /// JSON body payload.
194    pub body: PacketBody,
195}
196
197impl PacketEnvelope {
198    /// Builds a plaintext envelope from a body.
199    pub fn new(body: PacketBody) -> Self {
200        Self {
201            message_type: body.message_type(),
202            encryption: EncryptionKind::None,
203            body,
204        }
205    }
206
207    /// Builds an envelope with an explicit encryption mode.
208    pub fn with_encryption(body: PacketBody, encryption: EncryptionKind) -> Self {
209        Self {
210            message_type: body.message_type(),
211            encryption,
212            body,
213        }
214    }
215}
216
/// Encodes and decodes WSCALL binary frames.
///
/// Holds an optional 256-bit key for each supported AEAD. The `Debug` output
/// reports only whether each key is configured, never its bytes, so a codec can
/// be logged without leaking key material.
#[derive(Clone, Default)]
pub struct FrameCodec {
    aes256_key: Option<[u8; 32]>,
    chacha20_key: Option<[u8; 32]>,
}

impl std::fmt::Debug for FrameCodec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A derived `Debug` would print raw key bytes into logs and panic messages.
        f.debug_struct("FrameCodec")
            .field("aes256_key", &self.aes256_key.map(|_| "<redacted>"))
            .field("chacha20_key", &self.chacha20_key.map(|_| "<redacted>"))
            .finish()
    }
}
223
224impl FrameCodec {
225    /// Builds a codec configured for plaintext transport.
226    pub fn plaintext() -> Self {
227        Self::default()
228    }
229
230    /// Configures a ChaCha20-Poly1305 key.
231    pub fn with_chacha20_key(mut self, key: [u8; 32]) -> Self {
232        self.chacha20_key = Some(key);
233        self
234    }
235
236    /// Configures an AES256-GCM key.
237    pub fn with_aes256_key(mut self, key: [u8; 32]) -> Self {
238        self.aes256_key = Some(key);
239        self
240    }
241
242    /// Encodes an envelope into a binary WSCALL frame.
243    pub fn encode(&self, packet: &PacketEnvelope) -> Result<Vec<u8>, ProtocolError> {
244        let payload = serde_json::to_vec(&packet.body)?;
245        let payload = match packet.encryption {
246            EncryptionKind::None => payload,
247            EncryptionKind::ChaCha20 => self.encrypt_chacha20(&payload)?,
248            EncryptionKind::Aes256 => self.encrypt_aes256(&payload)?,
249        };
250
251        if payload.len() > MAX_PAYLOAD_BYTES {
252            return Err(ProtocolError::PayloadTooLarge {
253                actual: payload.len(),
254                max: MAX_PAYLOAD_BYTES,
255            });
256        }
257
258        let frame_len = 2 + payload.len();
259        let mut frame = Vec::with_capacity(4 + frame_len);
260        frame.extend_from_slice(&(frame_len as u32).to_be_bytes());
261        frame.push(packet.message_type as u8);
262        frame.push(packet.encryption as u8);
263        frame.extend_from_slice(&payload);
264        Ok(frame)
265    }
266
267    /// Decodes a binary WSCALL frame back into an envelope.
268    pub fn decode(&self, frame: &[u8]) -> Result<PacketEnvelope, ProtocolError> {
269        if frame.len() < 6 {
270            return Err(ProtocolError::FrameTooShort);
271        }
272
273        let declared = u32::from_be_bytes([frame[0], frame[1], frame[2], frame[3]]) as usize;
274        let actual = frame.len() - 4;
275        if declared != actual {
276            return Err(ProtocolError::FrameLengthMismatch { declared, actual });
277        }
278
279        let payload_len = actual - 2;
280        if payload_len > MAX_PAYLOAD_BYTES {
281            return Err(ProtocolError::PayloadTooLarge {
282                actual: payload_len,
283                max: MAX_PAYLOAD_BYTES,
284            });
285        }
286
287        let message_type = MessageType::try_from(frame[4])?;
288        let encryption = EncryptionKind::try_from(frame[5])?;
289        let payload = match encryption {
290            EncryptionKind::None => frame[6..].to_vec(),
291            EncryptionKind::ChaCha20 => self.decrypt_chacha20(&frame[6..])?,
292            EncryptionKind::Aes256 => self.decrypt_aes256(&frame[6..])?,
293        };
294
295        let body: PacketBody = serde_json::from_slice(&payload)?;
296        if body.message_type() != message_type {
297            return Err(ProtocolError::MessageTypeMismatch);
298        }
299
300        Ok(PacketEnvelope {
301            message_type,
302            encryption,
303            body,
304        })
305    }
306
307    fn encrypt_chacha20(&self, payload: &[u8]) -> Result<Vec<u8>, ProtocolError> {
308        let key = self
309            .chacha20_key
310            .ok_or(ProtocolError::MissingEncryptionKey("chacha20"))?;
311        let cipher = ChaCha20Poly1305::new_from_slice(&key)
312            .map_err(|_| ProtocolError::InvalidEncryptionKey("chacha20"))?;
313        let mut nonce_bytes = [0_u8; CHACHA20_NONCE_LEN];
314        getrandom(&mut nonce_bytes).map_err(|source| ProtocolError::Random(source.to_string()))?;
315        let ciphertext = cipher
316            .encrypt(Nonce::from_slice(&nonce_bytes), payload)
317            .map_err(|_| ProtocolError::EncryptionFailed("chacha20"))?;
318
319        let mut encoded = Vec::with_capacity(CHACHA20_NONCE_LEN + ciphertext.len());
320        encoded.extend_from_slice(&nonce_bytes);
321        encoded.extend_from_slice(&ciphertext);
322        Ok(encoded)
323    }
324
325    fn decrypt_chacha20(&self, payload: &[u8]) -> Result<Vec<u8>, ProtocolError> {
326        if payload.len() < CHACHA20_NONCE_LEN {
327            return Err(ProtocolError::EncryptedPayloadTooShort {
328                algorithm: "chacha20",
329                expected_min: CHACHA20_NONCE_LEN,
330                actual: payload.len(),
331            });
332        }
333
334        let key = self
335            .chacha20_key
336            .ok_or(ProtocolError::MissingEncryptionKey("chacha20"))?;
337        let cipher = ChaCha20Poly1305::new_from_slice(&key)
338            .map_err(|_| ProtocolError::InvalidEncryptionKey("chacha20"))?;
339        let (nonce_bytes, ciphertext) = payload.split_at(CHACHA20_NONCE_LEN);
340        cipher
341            .decrypt(Nonce::from_slice(nonce_bytes), ciphertext)
342            .map_err(|_| ProtocolError::DecryptionFailed("chacha20"))
343    }
344
345    fn encrypt_aes256(&self, payload: &[u8]) -> Result<Vec<u8>, ProtocolError> {
346        let key = self
347            .aes256_key
348            .ok_or(ProtocolError::MissingEncryptionKey("aes256"))?;
349        let cipher = Aes256Gcm::new_from_slice(&key)
350            .map_err(|_| ProtocolError::InvalidEncryptionKey("aes256"))?;
351        let mut nonce_bytes = [0_u8; AES256_NONCE_LEN];
352        getrandom(&mut nonce_bytes).map_err(|source| ProtocolError::Random(source.to_string()))?;
353        let ciphertext = cipher
354            .encrypt(AesNonce::from_slice(&nonce_bytes), payload)
355            .map_err(|_| ProtocolError::EncryptionFailed("aes256"))?;
356
357        let mut encoded = Vec::with_capacity(AES256_NONCE_LEN + ciphertext.len());
358        encoded.extend_from_slice(&nonce_bytes);
359        encoded.extend_from_slice(&ciphertext);
360        Ok(encoded)
361    }
362
363    fn decrypt_aes256(&self, payload: &[u8]) -> Result<Vec<u8>, ProtocolError> {
364        if payload.len() < AES256_NONCE_LEN {
365            return Err(ProtocolError::EncryptedPayloadTooShort {
366                algorithm: "aes256",
367                expected_min: AES256_NONCE_LEN,
368                actual: payload.len(),
369            });
370        }
371
372        let key = self
373            .aes256_key
374            .ok_or(ProtocolError::MissingEncryptionKey("aes256"))?;
375        let cipher = Aes256Gcm::new_from_slice(&key)
376            .map_err(|_| ProtocolError::InvalidEncryptionKey("aes256"))?;
377        let (nonce_bytes, ciphertext) = payload.split_at(AES256_NONCE_LEN);
378        cipher
379            .decrypt(AesNonce::from_slice(nonce_bytes), ciphertext)
380            .map_err(|_| ProtocolError::DecryptionFailed("aes256"))
381    }
382}
383
384/// Helper for encoding a plaintext frame without constructing a custom codec.
385pub fn encode_frame(packet: &PacketEnvelope) -> Result<Vec<u8>, ProtocolError> {
386    FrameCodec::plaintext().encode(packet)
387}
388
389/// Helper for decoding a plaintext frame without constructing a custom codec.
390pub fn decode_frame(frame: &[u8]) -> Result<PacketEnvelope, ProtocolError> {
391    FrameCodec::plaintext().decode(frame)
392}
393
394/// Errors returned while encoding or decoding WSCALL frames.
395#[derive(Debug, Error)]
396pub enum ProtocolError {
397    #[error("frame too short")]
398    FrameTooShort,
399    #[error("frame length mismatch: declared={declared}, actual={actual}")]
400    FrameLengthMismatch { declared: usize, actual: usize },
401    #[error("payload too large: actual={actual}, max={max}")]
402    PayloadTooLarge { actual: usize, max: usize },
403    #[error("unknown message type: {0:#x}")]
404    UnknownMessageType(u8),
405    #[error("unknown encryption kind: {0:#x}")]
406    UnknownEncryption(u8),
407    #[error("unsupported encryption kind: {0:#x}")]
408    UnsupportedEncryption(u8),
409    #[error("missing encryption key for {0}")]
410    MissingEncryptionKey(&'static str),
411    #[error("invalid encryption key for {0}")]
412    InvalidEncryptionKey(&'static str),
413    #[error("secure random generation failed: {0}")]
414    Random(String),
415    #[error(
416        "encrypted payload too short for {algorithm}: expected at least {expected_min}, actual={actual}"
417    )]
418    EncryptedPayloadTooShort {
419        algorithm: &'static str,
420        expected_min: usize,
421        actual: usize,
422    },
423    #[error("encryption failed for {0}")]
424    EncryptionFailed(&'static str),
425    #[error("decryption failed for {0}")]
426    DecryptionFailed(&'static str),
427    #[error("message type does not match packet body")]
428    MessageTypeMismatch,
429    #[error("invalid attachment encoding: {0}")]
430    InvalidAttachmentEncoding(String),
431    #[error("json error: {0}")]
432    Json(#[from] serde_json::Error),
433}
434
#[cfg(test)]
mod tests {
    use super::{
        EncryptionKind, FrameCodec, MAX_PAYLOAD_BYTES, MessageType, PacketBody, PacketEnvelope,
        ProtocolError, decode_frame, encode_frame,
    };
    use serde_json::json;

    const TEST_KEY: [u8; 32] = [0x11; 32];

    /// Small event acknowledgement, used wherever the body content does not matter.
    fn sample_ack() -> PacketBody {
        PacketBody::EventAck {
            event_id: "evt-1".to_string(),
            ok: true,
            receipt: json!({ "ok": true }),
            error: None,
        }
    }

    /// Assembles a frame by hand.
    ///
    /// Lets tests exercise malformed input that `FrameCodec::encode` never produces.
    fn raw_frame(message_type: u8, encryption: u8, payload: &[u8]) -> Vec<u8> {
        let frame_len = 2 + payload.len();
        let mut frame = Vec::with_capacity(4 + frame_len);
        frame.extend_from_slice(&(frame_len as u32).to_be_bytes());
        frame.push(message_type);
        frame.push(encryption);
        frame.extend_from_slice(payload);
        frame
    }

    #[test]
    fn plaintext_helpers_still_work() {
        let packet = PacketEnvelope::new(sample_ack());

        let encoded = encode_frame(&packet).expect("encode plaintext");
        let decoded = decode_frame(&encoded).expect("decode plaintext");
        assert!(matches!(decoded.encryption, EncryptionKind::None));
        assert_eq!(decoded.message_type, MessageType::Event);
    }

    #[test]
    fn aes256_roundtrip_works() {
        let codec = FrameCodec::plaintext().with_aes256_key(TEST_KEY);
        let packet = PacketEnvelope::with_encryption(
            PacketBody::ApiResponse {
                request_id: "req-1".to_string(),
                ok: true,
                status: 200,
                data: json!({ "message": "encrypted" }),
                error: None,
                metadata: json!({}),
            },
            EncryptionKind::Aes256,
        );

        let encoded = codec.encode(&packet).expect("encode aes256");
        let decoded = codec.decode(&encoded).expect("decode aes256");
        assert!(matches!(decoded.encryption, EncryptionKind::Aes256));
    }

    #[test]
    fn chacha20_roundtrip_works() {
        let codec = FrameCodec::plaintext().with_chacha20_key(TEST_KEY);
        let packet = PacketEnvelope::with_encryption(
            PacketBody::EventEmit {
                event_id: "evt-2".to_string(),
                name: "ping".to_string(),
                data: json!({ "n": 1 }),
                attachments: Vec::new(),
                metadata: json!({}),
                expect_ack: true,
            },
            EncryptionKind::ChaCha20,
        );

        let encoded = codec.encode(&packet).expect("encode chacha20");
        let decoded = codec.decode(&encoded).expect("decode chacha20");
        assert!(matches!(decoded.encryption, EncryptionKind::ChaCha20));
        assert_eq!(decoded.message_type, MessageType::Event);
        assert!(matches!(
            decoded.body,
            PacketBody::EventEmit { ref name, expect_ack: true, .. } if name == "ping"
        ));
    }

    #[test]
    fn encrypted_encode_without_key_fails() {
        let packet = PacketEnvelope::with_encryption(sample_ack(), EncryptionKind::Aes256);

        let error = FrameCodec::plaintext()
            .encode(&packet)
            .expect_err("no aes256 key configured");
        assert!(matches!(error, ProtocolError::MissingEncryptionKey("aes256")));
    }

    #[test]
    fn tampered_ciphertext_is_rejected() {
        let codec = FrameCodec::plaintext().with_chacha20_key(TEST_KEY);
        let packet = PacketEnvelope::with_encryption(sample_ack(), EncryptionKind::ChaCha20);
        let mut encoded = codec.encode(&packet).expect("encode chacha20");
        let last = encoded.len() - 1;
        encoded[last] ^= 0x01;

        let error = codec
            .decode(&encoded)
            .expect_err("tampered frame should fail");
        assert!(matches!(error, ProtocolError::DecryptionFailed("chacha20")));
    }

    #[test]
    fn wrong_key_is_rejected() {
        let sender = FrameCodec::plaintext().with_aes256_key(TEST_KEY);
        let receiver = FrameCodec::plaintext().with_aes256_key([0x22; 32]);
        let packet = PacketEnvelope::with_encryption(sample_ack(), EncryptionKind::Aes256);
        let encoded = sender.encode(&packet).expect("encode aes256");

        let error = receiver
            .decode(&encoded)
            .expect_err("wrong key should fail");
        assert!(matches!(error, ProtocolError::DecryptionFailed("aes256")));
    }

    #[test]
    fn encode_rejects_payloads_over_limit() {
        let codec = FrameCodec::plaintext();
        let packet = PacketEnvelope::new(PacketBody::ApiResponse {
            request_id: "req-oversize".to_string(),
            ok: true,
            status: 200,
            data: json!({ "blob": "a".repeat(10 * 1024 * 1024) }),
            error: None,
            metadata: json!({}),
        });

        let error = codec
            .encode(&packet)
            .expect_err("oversized payload should fail");
        assert!(matches!(error, ProtocolError::PayloadTooLarge { .. }));
    }

    #[test]
    fn decode_rejects_payloads_over_limit() {
        let payload = vec![0_u8; MAX_PAYLOAD_BYTES + 1];
        let frame = raw_frame(MessageType::Api as u8, EncryptionKind::None as u8, &payload);

        let error = FrameCodec::plaintext()
            .decode(&frame)
            .expect_err("oversized payload should fail");
        assert!(matches!(error, ProtocolError::PayloadTooLarge { .. }));
    }

    #[test]
    fn decode_rejects_short_frames() {
        let error = decode_frame(&[0, 0, 0, 1, 0]).expect_err("five bytes is too short");
        assert!(matches!(error, ProtocolError::FrameTooShort));
    }

    #[test]
    fn decode_rejects_length_mismatch() {
        let mut frame = encode_frame(&PacketEnvelope::new(sample_ack())).expect("encode");
        frame.push(b' ');

        let error = decode_frame(&frame).expect_err("trailing byte should fail");
        assert!(matches!(error, ProtocolError::FrameLengthMismatch { .. }));
    }

    #[test]
    fn decode_rejects_truncated_encrypted_payload() {
        let frame = raw_frame(
            MessageType::Api as u8,
            EncryptionKind::ChaCha20 as u8,
            &[0_u8; 5],
        );

        let error = FrameCodec::plaintext()
            .with_chacha20_key(TEST_KEY)
            .decode(&frame)
            .expect_err("five ciphertext bytes cannot hold a nonce");
        assert!(matches!(
            error,
            ProtocolError::EncryptedPayloadTooShort { algorithm: "chacha20", .. }
        ));
    }

    #[test]
    fn decode_rejects_header_body_mismatch() {
        let mut frame = encode_frame(&PacketEnvelope::new(sample_ack())).expect("encode");
        frame[4] = MessageType::Api as u8;

        let error = decode_frame(&frame).expect_err("header/body mismatch should fail");
        assert!(matches!(error, ProtocolError::MessageTypeMismatch));
    }
}