commucat_proto/lib.rs

1use serde::{Deserialize, Serialize};
2use std::borrow::Cow;
3use std::convert::TryFrom;
4use std::error::Error;
5use std::fmt::{Display, Formatter};
6
/// Call-signalling types; the `CallOffer` round-trip test converts them to and from
/// [`ControlEnvelope`].
pub mod call;

// Optional obfuscation module, compiled only when the `obfuscation` cargo feature is enabled.
#[cfg(feature = "obfuscation")]
mod obfuscation;

// Re-export the obfuscation API at the crate root behind the same feature gate, so the
// private `obfuscation` module path never has to be named by downstream crates.
#[cfg(feature = "obfuscation")]
pub use obfuscation::{
    AdaptiveMimicPolicy, AdaptiveObfuscator, AmnesiaSignature, CensorshipSignal, DaitaProfile,
    DnsPacketSnapshot, ObfuscatedPacket, ObfuscationError, ObfuscationKey, ProtocolFlavor,
    ProtocolMimicry, ProtocolSnapshot, QuicHandshakeSnapshot, RealityTicket, SipMessageSnapshot,
    SipMethod, TlsHandshakeSnapshot, WebRtcDataChannelSnapshot,
};
19
/// Wire protocol version implemented by this codec.
pub const PROTOCOL_VERSION: u16 = 1;
/// Every protocol version this codec accepts during negotiation
/// (see [`is_supported_protocol_version`] and [`negotiate_protocol_version`]).
pub const SUPPORTED_PROTOCOL_VERSIONS: &[u16] = &[PROTOCOL_VERSION];
/// Maximum size in bytes of both a frame payload and a frame body (type byte, header varints
/// and payload, excluding the outer length prefix). Enforced by [`Frame::encode`] and
/// [`Frame::decode`].
pub const MAX_FRAME_LEN: usize = 16 * 1024 * 1024;
/// Maximum size in bytes of a control payload's serialized JSON, checked on encode and decode.
pub const MAX_CONTROL_JSON_LEN: usize = 256 * 1024;
/// Largest `channel_id` accepted on the wire (the `u32` range, widened to `u64`).
pub const MAX_CHANNEL_ID: u64 = u32::MAX as u64;
/// Largest `sequence` accepted on the wire (the `u32` range, widened to `u64`).
pub const MAX_SEQUENCE: u64 = u32::MAX as u64;
26
27/// Returns true when the provided protocol version is supported by the current codec.
28pub fn is_supported_protocol_version(version: u16) -> bool {
29    SUPPORTED_PROTOCOL_VERSIONS
30        .iter()
31        .copied()
32        .any(|v| v == version)
33}
34
35/// Picks the highest mutually supported protocol version between peers.
36pub fn negotiate_protocol_version(peer_versions: &[u16]) -> Option<u16> {
37    let mut negotiated = None;
38    for version in peer_versions.iter().copied() {
39        if is_supported_protocol_version(version) {
40            negotiated = match negotiated {
41                Some(current) if current >= version => Some(current),
42                _ => Some(version),
43            };
44        }
45    }
46    negotiated
47}
48
/// Kind of a [`Frame`].
///
/// The explicit discriminant is the single tag byte written on the wire (`Frame::encode`
/// emits `frame_type as u8`), so these values are part of the wire format and must not be
/// renumbered. Unknown tags are rejected on decode with [`CodecError::InvalidFrameType`].
///
/// `Msg`, `VoiceFrame` and `VideoFrame` carry opaque byte payloads; every other kind carries
/// a JSON control payload.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
pub enum FrameType {
    Hello = 0x01,
    Auth = 0x02,
    Join = 0x03,
    Leave = 0x04,
    Msg = 0x05,
    Ack = 0x06,
    Typing = 0x07,
    Presence = 0x08,
    KeyUpdate = 0x09,
    GroupCreate = 0x0a,
    GroupInvite = 0x0b,
    GroupEvent = 0x0c,
    Error = 0x0d,
    CallOffer = 0x0e,
    CallAnswer = 0x0f,
    CallEnd = 0x10,
    VoiceFrame = 0x11,
    VideoFrame = 0x12,
    CallStats = 0x13,
    TransportUpdate = 0x14,
}
73
74impl FrameType {
75    fn from_u8(value: u8) -> Option<Self> {
76        match value {
77            0x01 => Some(Self::Hello),
78            0x02 => Some(Self::Auth),
79            0x03 => Some(Self::Join),
80            0x04 => Some(Self::Leave),
81            0x05 => Some(Self::Msg),
82            0x06 => Some(Self::Ack),
83            0x07 => Some(Self::Typing),
84            0x08 => Some(Self::Presence),
85            0x09 => Some(Self::KeyUpdate),
86            0x0a => Some(Self::GroupCreate),
87            0x0b => Some(Self::GroupInvite),
88            0x0c => Some(Self::GroupEvent),
89            0x0d => Some(Self::Error),
90            0x0e => Some(Self::CallOffer),
91            0x0f => Some(Self::CallAnswer),
92            0x10 => Some(Self::CallEnd),
93            0x11 => Some(Self::VoiceFrame),
94            0x12 => Some(Self::VideoFrame),
95            0x13 => Some(Self::CallStats),
96            0x14 => Some(Self::TransportUpdate),
97            _ => None,
98        }
99    }
100}
101
/// Errors produced while encoding or decoding [`Frame`]s.
///
/// All variants are fieldless, so the error is cheap to clone and can be compared directly
/// (e.g. `assert_eq!(err, CodecError::UnexpectedEof)`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CodecError {
    /// The frame-type byte does not match any [`FrameType`] discriminant.
    InvalidFrameType,
    /// A control payload could not be serialized to, or parsed from, JSON.
    InvalidControlJson,
    /// The input ended before a complete varint or frame could be read, or a declared
    /// payload length runs past the end of the frame body.
    UnexpectedEof,
    /// A varint does not fit in 64 bits.
    VarintOverflow,
    /// A payload length exceeds [`MAX_FRAME_LEN`] (or does not fit in `usize`).
    PayloadTooLarge,
    /// A frame body length exceeds [`MAX_FRAME_LEN`] (or does not fit in `usize`).
    FrameTooLarge,
    /// A control payload's JSON exceeds [`MAX_CONTROL_JSON_LEN`].
    ControlTooLarge,
    /// A channel id exceeds [`MAX_CHANNEL_ID`].
    ChannelIdTooLarge,
    /// A sequence number exceeds [`MAX_SEQUENCE`].
    SequenceTooLarge,
}
114
115impl Display for CodecError {
116    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
117        match self {
118            Self::InvalidFrameType => write!(f, "invalid frame type"),
119            Self::InvalidControlJson => write!(f, "invalid control payload"),
120            Self::UnexpectedEof => write!(f, "unexpected end of frame"),
121            Self::VarintOverflow => write!(f, "varint overflow"),
122            Self::PayloadTooLarge => write!(f, "payload exceeds limits"),
123            Self::FrameTooLarge => write!(f, "frame exceeds limits"),
124            Self::ControlTooLarge => write!(f, "control payload exceeds limits"),
125            Self::ChannelIdTooLarge => write!(f, "channel id exceeds limits"),
126            Self::SequenceTooLarge => write!(f, "sequence exceeds limits"),
127        }
128    }
129}
130
131impl Error for CodecError {}
132
/// JSON body carried by control frames.
///
/// `#[serde(transparent)]` makes the envelope (de)serialize exactly as its inner
/// `properties` value, with no wrapping object.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ControlEnvelope {
    /// Arbitrary JSON value; this codec only parses it as JSON and does not check its schema.
    pub properties: serde_json::Value,
}
138
/// Contents of a [`Frame`].
///
/// On decode, `Msg`, `VoiceFrame` and `VideoFrame` frames always yield `Opaque`; every other
/// frame type yields `Control`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FramePayload {
    /// Structured JSON control payload (encoded with `serde_json`).
    Control(ControlEnvelope),
    /// Raw bytes written to the wire unchanged.
    Opaque(Vec<u8>),
}
144
145impl FramePayload {
146    fn bytes(&self) -> Result<Cow<'_, [u8]>, CodecError> {
147        match self {
148            FramePayload::Control(ctrl) => {
149                let encoded =
150                    serde_json::to_vec(ctrl).map_err(|_| CodecError::InvalidControlJson)?;
151                if encoded.len() > MAX_CONTROL_JSON_LEN {
152                    return Err(CodecError::ControlTooLarge);
153                }
154                Ok(Cow::Owned(encoded))
155            }
156            FramePayload::Opaque(data) => Ok(Cow::Borrowed(data)),
157        }
158    }
159
160    fn from_bytes(frame_type: FrameType, data: &[u8]) -> Result<Self, CodecError> {
161        match frame_type {
162            FrameType::Msg | FrameType::VoiceFrame | FrameType::VideoFrame => {
163                Ok(FramePayload::Opaque(data.to_vec()))
164            }
165            FrameType::Hello
166            | FrameType::Auth
167            | FrameType::Join
168            | FrameType::Leave
169            | FrameType::Ack
170            | FrameType::Typing
171            | FrameType::Presence
172            | FrameType::GroupCreate
173            | FrameType::GroupInvite
174            | FrameType::GroupEvent
175            | FrameType::Error
176            | FrameType::CallOffer
177            | FrameType::CallAnswer
178            | FrameType::CallEnd
179            | FrameType::CallStats
180            | FrameType::TransportUpdate
181            | FrameType::KeyUpdate => {
182                if data.len() > MAX_CONTROL_JSON_LEN {
183                    return Err(CodecError::ControlTooLarge);
184                }
185                serde_json::from_slice::<ControlEnvelope>(data)
186                    .map(FramePayload::Control)
187                    .map_err(|_| CodecError::InvalidControlJson)
188            }
189        }
190    }
191}
192
/// A single protocol frame: channel id, sequence number, frame type and payload.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Frame {
    /// Channel the frame belongs to; must not exceed [`MAX_CHANNEL_ID`] to encode or decode.
    pub channel_id: u64,
    /// Sequence number; must not exceed [`MAX_SEQUENCE`] to encode or decode.
    pub sequence: u64,
    /// Frame kind, written as the first byte of the frame body.
    pub frame_type: FrameType,
    /// Frame contents.
    pub payload: FramePayload,
}
200
201impl Frame {
202    /// Serializes a frame into a length prefixed binary representation.
203    pub fn encode(&self) -> Result<Vec<u8>, CodecError> {
204        if self.channel_id > MAX_CHANNEL_ID {
205            return Err(CodecError::ChannelIdTooLarge);
206        }
207        if self.sequence > MAX_SEQUENCE {
208            return Err(CodecError::SequenceTooLarge);
209        }
210        let payload = self.payload.bytes()?;
211        if payload.len() > MAX_FRAME_LEN {
212            return Err(CodecError::PayloadTooLarge);
213        }
214        let mut body = Vec::new();
215        body.push(self.frame_type as u8);
216        encode_varint(self.channel_id, &mut body);
217        encode_varint(self.sequence, &mut body);
218        encode_varint(payload.len() as u64, &mut body);
219        body.extend_from_slice(payload.as_ref());
220        if body.len() > MAX_FRAME_LEN {
221            return Err(CodecError::FrameTooLarge);
222        }
223        let mut encoded = Vec::new();
224        encode_varint(body.len() as u64, &mut encoded);
225        encoded.extend_from_slice(&body);
226        Ok(encoded)
227    }
228
229    /// Attempts to decode a frame from a contiguous buffer.
230    pub fn decode(buffer: &[u8]) -> Result<(Self, usize), CodecError> {
231        let (frame_len_raw, header_len) = decode_varint(buffer)?;
232        let frame_len = usize::try_from(frame_len_raw).map_err(|_| CodecError::FrameTooLarge)?;
233        if frame_len > MAX_FRAME_LEN {
234            return Err(CodecError::FrameTooLarge);
235        }
236        if buffer.len() < header_len + frame_len {
237            return Err(CodecError::UnexpectedEof);
238        }
239        let frame_slice = &buffer[header_len..header_len + frame_len];
240        if frame_slice.is_empty() {
241            return Err(CodecError::UnexpectedEof);
242        }
243        let frame_type = FrameType::from_u8(frame_slice[0]).ok_or(CodecError::InvalidFrameType)?;
244        let mut cursor = 1;
245        let (channel_id, read) = decode_varint(&frame_slice[cursor..])?;
246        cursor += read;
247        if channel_id > MAX_CHANNEL_ID {
248            return Err(CodecError::ChannelIdTooLarge);
249        }
250        let (sequence, read) = decode_varint(&frame_slice[cursor..])?;
251        cursor += read;
252        if sequence > MAX_SEQUENCE {
253            return Err(CodecError::SequenceTooLarge);
254        }
255        let (payload_len_raw, read) = decode_varint(&frame_slice[cursor..])?;
256        cursor += read;
257        let payload_len =
258            usize::try_from(payload_len_raw).map_err(|_| CodecError::PayloadTooLarge)?;
259        if payload_len > MAX_FRAME_LEN {
260            return Err(CodecError::PayloadTooLarge);
261        }
262        if frame_slice.len() < cursor + payload_len {
263            return Err(CodecError::UnexpectedEof);
264        }
265        let payload_slice = &frame_slice[cursor..cursor + payload_len];
266        let payload = FramePayload::from_bytes(frame_type, payload_slice)?;
267        let total = header_len + frame_len;
268        Ok((
269            Frame {
270                channel_id,
271                sequence,
272                frame_type,
273                payload,
274            },
275            total,
276        ))
277    }
278}
279
/// Appends `value` to `buffer` as a little-endian base-128 varint.
///
/// Each output byte holds seven data bits, least-significant group first; the high bit is
/// set on every byte except the last. Values below 0x80 therefore take a single byte.
fn encode_varint(value: u64, buffer: &mut Vec<u8>) {
    let mut remaining = value;
    loop {
        let low_bits = (remaining & 0x7f) as u8;
        remaining >>= 7;
        if remaining == 0 {
            buffer.push(low_bits);
            break;
        }
        buffer.push(low_bits | 0x80);
    }
}
287
288fn decode_varint(buffer: &[u8]) -> Result<(u64, usize), CodecError> {
289    let mut value = 0u64;
290    let mut shift = 0u32;
291    for (index, byte) in buffer.iter().enumerate() {
292        let part = (byte & 0x7f) as u64;
293        value |= part << shift;
294        if byte & 0x80 == 0 {
295            return Ok((value, index + 1));
296        }
297        shift += 7;
298        if shift > 63 {
299            return Err(CodecError::VarintOverflow);
300        }
301    }
302    Err(CodecError::UnexpectedEof)
303}
304
// Unit tests for version negotiation, the frame codec and the varint helpers.
// Several tests patch encoded bytes by hand to trigger specific decode errors.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn negotiate_version_prefers_highest_supported() {
        let negotiated = negotiate_protocol_version(&[0, PROTOCOL_VERSION, PROTOCOL_VERSION + 1]);
        assert_eq!(negotiated, Some(PROTOCOL_VERSION));
    }

    #[test]
    fn negotiate_version_none_when_disjoint() {
        let negotiated = negotiate_protocol_version(&[42, 43]);
        assert_eq!(negotiated, None);
    }

    #[test]
    fn supported_version_predicate_matches() {
        assert!(is_supported_protocol_version(PROTOCOL_VERSION));
        assert!(!is_supported_protocol_version(PROTOCOL_VERSION + 5));
    }

    #[test]
    fn encode_roundtrip_control_frame() {
        let frame = Frame {
            channel_id: 12,
            sequence: 34,
            frame_type: FrameType::Hello,
            payload: FramePayload::Control(ControlEnvelope {
                properties: serde_json::json!({
                    "protocol_version": PROTOCOL_VERSION,
                    "capabilities": ["noise", "zstd"],
                }),
            }),
        };
        let encoded = frame.encode().unwrap();
        let (decoded, read) = Frame::decode(&encoded).unwrap();
        // A single frame must consume the whole encoded buffer.
        assert_eq!(read, encoded.len());
        assert_eq!(decoded.channel_id, 12);
        assert_eq!(decoded.sequence, 34);
        assert_eq!(decoded.frame_type, FrameType::Hello);
        match decoded.payload {
            FramePayload::Control(ctrl) => {
                let version = ctrl.properties.get("protocol_version").unwrap();
                assert_eq!(version.as_u64(), Some(PROTOCOL_VERSION as u64));
            }
            _ => panic!("unexpected payload"),
        }
    }

    #[test]
    fn key_update_roundtrip_returns_control() {
        let frame = Frame {
            channel_id: 5,
            sequence: 2,
            frame_type: FrameType::KeyUpdate,
            payload: FramePayload::Control(ControlEnvelope {
                properties: serde_json::json!({
                    "type": "device-key-rotated",
                    "rotation_id": "rot-123",
                }),
            }),
        };
        let encoded = frame.encode().expect("encode");
        let (decoded, _) = Frame::decode(&encoded).expect("decode");
        assert_eq!(decoded.frame_type, FrameType::KeyUpdate);
        match decoded.payload {
            FramePayload::Control(ctrl) => {
                assert_eq!(
                    ctrl.properties["type"],
                    serde_json::json!("device-key-rotated")
                );
            }
            other => panic!("expected control payload, got {:?}", other),
        }
    }

    #[test]
    fn encode_roundtrip_opaque_frame() {
        let frame = Frame {
            channel_id: 9,
            sequence: 1,
            frame_type: FrameType::Msg,
            payload: FramePayload::Opaque(vec![1, 2, 3, 4]),
        };
        let encoded = frame.encode().unwrap();
        let (decoded, _read) = Frame::decode(&encoded).unwrap();
        assert_eq!(decoded.payload, FramePayload::Opaque(vec![1, 2, 3, 4]));
    }

    #[test]
    fn encode_roundtrip_voice_frame() {
        let frame = Frame {
            channel_id: 42,
            sequence: 5,
            frame_type: FrameType::VoiceFrame,
            payload: FramePayload::Opaque(vec![0xaa, 0xbb, 0xcc, 0xdd]),
        };
        let encoded = frame.encode().unwrap();
        let (decoded, _read) = Frame::decode(&encoded).unwrap();
        assert_eq!(decoded.frame_type, FrameType::VoiceFrame);
        assert_eq!(
            decoded.payload,
            FramePayload::Opaque(vec![0xaa, 0xbb, 0xcc, 0xdd])
        );
    }

    // Drives a fully populated `call::CallOffer` through its ControlEnvelope conversions and
    // the frame codec, then checks that key fields survive the round trip.
    #[test]
    fn encode_roundtrip_call_offer_frame() {
        use crate::call::{AudioParameters, VideoParameters};
        use crate::call::{
            CallMediaProfile, IceCandidateType, IceCredentials, TransportCandidate,
            TransportProtocol,
        };
        use crate::call::{CallMode, CallOffer, CallTransport};
        use commucat_media_types::{VideoCodec, VideoResolution};
        use std::convert::TryInto;

        let video = VideoParameters {
            codec: VideoCodec::Vp8,
            max_bitrate: 500_000,
            max_resolution: VideoResolution::default(),
            frame_rate: 24,
            adaptive: true,
            ..VideoParameters::default()
        };

        let offer = CallOffer {
            call_id: "call-xyz".to_string(),
            from: "alice:device".to_string(),
            to: vec!["bob:device".to_string()],
            media: CallMediaProfile {
                audio: AudioParameters::default(),
                video: Some(video),
                mode: CallMode::FullDuplex,
                capabilities: None,
            },
            metadata: serde_json::json!({"mode": "voice"}),
            transport: Some(CallTransport {
                prefer_relay: false,
                candidates: vec![TransportCandidate {
                    address: "198.51.100.12".to_string(),
                    port: 60_000,
                    protocol: TransportProtocol::Udp,
                    foundation: Some("f1".to_string()),
                    component: Some(1),
                    priority: Some(12_345_678),
                    candidate_type: Some(IceCandidateType::Srflx),
                    related_address: Some("10.0.0.5".to_string()),
                    related_port: Some(52_333),
                    tcp_type: None,
                    sdp_mid: Some("0".to_string()),
                    sdp_mline_index: Some(0),
                    url: None,
                }],
                fingerprints: vec!["deadbeef".to_string()],
                ice_credentials: Some(IceCredentials {
                    username_fragment: "ufrag".to_string(),
                    password: "pwd".to_string(),
                    expires_at: None,
                }),
                trickle: false,
                consent_interval_secs: None,
            }),
            expires_at: Some(1_690_000_000),
            ephemeral_key: None,
        };
        let envelope: ControlEnvelope = (&offer).try_into().unwrap();
        let frame = Frame {
            channel_id: 7,
            sequence: 9,
            frame_type: FrameType::CallOffer,
            payload: FramePayload::Control(envelope.clone()),
        };
        let encoded = frame.encode().unwrap();
        let (decoded, _) = Frame::decode(&encoded).unwrap();
        assert_eq!(decoded.frame_type, FrameType::CallOffer);
        match decoded.payload {
            FramePayload::Control(ctrl) => {
                let decoded_offer = CallOffer::try_from(&ctrl).unwrap();
                assert_eq!(decoded_offer.call_id, offer.call_id);
                assert_eq!(decoded_offer.media.mode, CallMode::FullDuplex);
            }
            _ => panic!("expected control payload"),
        }
    }

    // Two frames back to back: the length returned by the first decode must point exactly at
    // the start of the second frame.
    #[test]
    fn decode_multiple_frames_in_sequence() {
        let frame1 = Frame {
            channel_id: 7,
            sequence: 11,
            frame_type: FrameType::Hello,
            payload: FramePayload::Control(ControlEnvelope {
                properties: serde_json::json!({
                    "protocol_version": PROTOCOL_VERSION,
                }),
            }),
        };
        let frame2 = Frame {
            channel_id: 7,
            sequence: 12,
            frame_type: FrameType::Msg,
            payload: FramePayload::Opaque(vec![9, 8, 7]),
        };
        let mut concatenated = frame1.encode().unwrap();
        let second = frame2.encode().unwrap();
        let first_len = concatenated.len();
        concatenated.extend_from_slice(&second);
        let (decoded1, read1) = Frame::decode(&concatenated).unwrap();
        assert_eq!(read1, first_len);
        assert_eq!(decoded1.sequence, 11);
        let (decoded2, read2) = Frame::decode(&concatenated[read1..]).unwrap();
        assert_eq!(read1 + read2, concatenated.len());
        assert_eq!(decoded2.payload, FramePayload::Opaque(vec![9, 8, 7]));
    }

    #[test]
    fn decode_rejects_payload_length_mismatch() {
        let frame = Frame {
            channel_id: 1,
            sequence: 2,
            frame_type: FrameType::Msg,
            payload: FramePayload::Opaque(vec![0xaa, 0xbb, 0xcc]),
        };
        let mut encoded = frame.encode().unwrap();
        // Walk the layout: outer length varint, one frame-type byte, channel varint,
        // sequence varint, then the payload-length varint that gets patched.
        let (_, header_len) = decode_varint(&encoded).unwrap();
        let mut cursor = header_len + 1;
        let (_, read) = decode_varint(&encoded[cursor..]).unwrap();
        cursor += read;
        let (_, read) = decode_varint(&encoded[cursor..]).unwrap();
        cursor += read;
        let (payload_len, read) = decode_varint(&encoded[cursor..]).unwrap();
        // Declare one more payload byte than the (unchanged) frame length leaves room for.
        let mut new_len = Vec::new();
        encode_varint(payload_len + 1, &mut new_len);
        let start = cursor;
        let end = cursor + read;
        encoded.splice(start..end, new_len);
        assert!(matches!(
            Frame::decode(&encoded),
            Err(CodecError::UnexpectedEof)
        ));
    }

    #[test]
    fn decode_rejects_varint_overflow() {
        // Ten bytes that all carry the continuation bit never terminate within 64 bits.
        let buffer = vec![0xff; 10];
        assert!(matches!(
            Frame::decode(&buffer),
            Err(CodecError::VarintOverflow)
        ));
    }

    #[test]
    fn decode_rejects_unknown_frame_type() {
        let frame = Frame {
            channel_id: 3,
            sequence: 4,
            frame_type: FrameType::Ack,
            payload: FramePayload::Control(ControlEnvelope {
                properties: serde_json::json!({"status": "ok"}),
            }),
        };
        let mut encoded = frame.encode().unwrap();
        // The frame-type byte sits immediately after the outer length prefix.
        let (_, header_len) = decode_varint(&encoded).unwrap();
        encoded[header_len] = 0xff;
        assert!(matches!(
            Frame::decode(&encoded),
            Err(CodecError::InvalidFrameType)
        ));
    }

    #[test]
    fn decode_rejects_oversized_frame() {
        // Only the length prefix is supplied; the size check fires before any body is read.
        let mut buffer = Vec::new();
        encode_varint((MAX_FRAME_LEN + 1) as u64, &mut buffer);
        assert!(matches!(
            Frame::decode(&buffer),
            Err(CodecError::FrameTooLarge)
        ));
    }

    #[test]
    fn encode_group_control_frame() {
        let frame = Frame {
            channel_id: 99,
            sequence: 1,
            frame_type: FrameType::GroupInvite,
            payload: FramePayload::Control(ControlEnvelope {
                properties: serde_json::json!({
                    "group_id": "grp-1",
                    "device": "dev-1",
                }),
            }),
        };
        let encoded = frame.encode().unwrap();
        let (decoded, _) = Frame::decode(&encoded).unwrap();
        assert_eq!(decoded.frame_type, FrameType::GroupInvite);
    }

    #[test]
    fn encode_large_batch() {
        let mut buffer = Vec::new();
        for index in 0..512u64 {
            let frame = Frame {
                channel_id: index,
                sequence: index,
                frame_type: FrameType::Msg,
                payload: FramePayload::Opaque(vec![0u8; 16]),
            };
            let encoded = frame.encode().unwrap();
            buffer.extend_from_slice(&encoded);
        }
        // Decode the concatenated stream frame by frame, advancing by each consumed length.
        let mut cursor = buffer.as_slice();
        let mut decoded = 0;
        while !cursor.is_empty() {
            let (frame, read) = Frame::decode(cursor).unwrap();
            assert_eq!(frame.frame_type, FrameType::Msg);
            cursor = &cursor[read..];
            decoded += 1;
        }
        assert_eq!(decoded, 512);
    }
}