livekit-datatrack 0.1.4

Data track core for LiveKit
Documentation
// Copyright 2025 LiveKit, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use super::{
    consts::*, E2eeExt, ExtensionTag, Extensions, FrameMarker, Handle, HandleError, Header, Packet,
    Timestamp, UserTimestampExt,
};
use bytes::{Buf, Bytes};
use thiserror::Error;

/// Errors produced while parsing a data track packet from its wire form.
#[derive(Error, Debug)]
pub enum DeserializeError {
    /// Fewer than `BASE_HEADER_LEN` bytes were available for the fixed header.
    #[error("too short to contain a valid header")]
    TooShort,

    /// The extension block size implied by the extension word count is larger
    /// than the number of bytes remaining after the base header.
    #[error("header exceeds total packet length")]
    HeaderOverrun,

    /// The extension flag is set, but fewer than two bytes remain to hold the
    /// `u16` extension word count.
    #[error("extension word indicator is missing")]
    MissingExtWords,

    /// The header's version field is greater than `SUPPORTED_VERSION`.
    /// Carries the version that was read.
    #[error("unsupported version {0}")]
    UnsupportedVersion(u8),

    /// The track handle field could not be converted into a [`Handle`].
    #[error("invalid track handle: {0}")]
    InvalidHandle(#[from] HandleError),

    /// An extension entry declared a length larger than the bytes left in the
    /// extension block. Carries the entry's tag.
    #[error("extension with tag {0} is malformed")]
    MalformedExt(ExtensionTag),
}

impl Packet {
    /// Parses a complete packet: a header followed by an opaque payload.
    ///
    /// Everything left in `raw` after the header (and its extension block, if
    /// any) becomes the payload. The payload shares `raw`'s underlying buffer.
    ///
    /// # Errors
    ///
    /// Propagates any [`DeserializeError`] raised while parsing the header; the
    /// payload itself is not inspected.
    pub fn deserialize(mut raw: Bytes) -> Result<Self, DeserializeError> {
        let header = Header::deserialize(&mut raw)?;
        // The header parser has advanced `raw` past the header, so what is left
        // is exactly the payload.
        Ok(Self { header, payload: raw })
    }
}

impl Header {
    /// Parses the fixed base header and, when flagged, the extension block.
    ///
    /// Fields are read in this order:
    /// one byte packing version, frame marker and extension flag; one reserved
    /// byte; then big-endian (`Buf::get_u16` / `Buf::get_u32`) track handle
    /// (`u16`), sequence (`u16`), frame number (`u16`) and timestamp ticks
    /// (`u32`). If the extension flag is set, a `u16` word count follows and
    /// `4 * (words + 1) - EXT_WORDS_INDICATOR_SIZE` bytes of extension data are
    /// consumed and parsed.
    ///
    /// On success `raw` is left positioned immediately after the header.
    ///
    /// # Errors
    ///
    /// - [`DeserializeError::TooShort`] if fewer than `BASE_HEADER_LEN` bytes are available.
    /// - [`DeserializeError::UnsupportedVersion`] if the version exceeds `SUPPORTED_VERSION`.
    /// - [`DeserializeError::InvalidHandle`] if the track handle is rejected.
    /// - [`DeserializeError::MissingExtWords`] / [`DeserializeError::HeaderOverrun`]
    ///   if the extension section is truncated.
    /// - Any error returned while parsing the extension entries.
    fn deserialize(raw: &mut impl Buf) -> Result<Self, DeserializeError> {
        if raw.remaining() < BASE_HEADER_LEN {
            return Err(DeserializeError::TooShort);
        }

        let first = raw.get_u8();
        let version = (first >> VERSION_SHIFT) & VERSION_MASK;
        if version > SUPPORTED_VERSION {
            return Err(DeserializeError::UnsupportedVersion(version));
        }
        let marker = match (first >> FRAME_MARKER_SHIFT) & FRAME_MARKER_MASK {
            FRAME_MARKER_START => FrameMarker::Start,
            FRAME_MARKER_FINAL => FrameMarker::Final,
            FRAME_MARKER_SINGLE => FrameMarker::Single,
            _ => FrameMarker::Inter,
        };
        let has_extensions = ((first >> EXT_FLAG_SHIFT) & EXT_FLAG_MASK) != 0;

        // The second byte is reserved and carries no information.
        raw.advance(1);

        let track_handle: Handle = raw.get_u16().try_into()?;
        let sequence = raw.get_u16();
        let frame_number = raw.get_u16();
        let timestamp = Timestamp::from_ticks(raw.get_u32());

        let extensions = if has_extensions {
            if raw.remaining() < 2 {
                return Err(DeserializeError::MissingExtWords);
            }
            // The word count covers the 2-byte indicator itself, so subtract it
            // to get the number of extension data bytes that follow.
            let word_count = usize::from(raw.get_u16());
            let block_len = 4 * (word_count + 1) - EXT_WORDS_INDICATOR_SIZE;
            if raw.remaining() < block_len {
                return Err(DeserializeError::HeaderOverrun);
            }
            Extensions::deserialize(raw.copy_to_bytes(block_len))?
        } else {
            Extensions::default()
        };

        Ok(Self { marker, track_handle, sequence, frame_number, timestamp, extensions })
    }
}

/// Decodes one fixed-size extension of type `$ext_type` from `$raw`, whose
/// entry declared a length of `$len` bytes.
///
/// Evaluates to `Some(decoded)`. The first `<$ext_type>::LEN` bytes are passed
/// to `<$ext_type>::deserialize`; any further declared bytes (possibly added by
/// a newer protocol version) are skipped. If fewer than `$len` bytes remain,
/// the enclosing function returns `Err(DeserializeError::MalformedExt(TAG))`.
///
/// The caller must guarantee `$len >= <$ext_type>::LEN`.
macro_rules! deserialize_ext {
    ($ext_type:ty, $raw:expr, $len:expr) => {{
        if $len > $raw.remaining() {
            return Err(DeserializeError::MalformedExt(<$ext_type>::TAG));
        }
        let mut fixed = [0u8; <$ext_type>::LEN];
        $raw.copy_to_slice(&mut fixed);

        // Bytes beyond the known layout are ignored for forward compatibility.
        let trailing = $len - <$ext_type>::LEN;
        if trailing != 0 {
            $raw.advance(trailing);
        }
        Some(<$ext_type>::deserialize(fixed))
    }};
}

impl Extensions {
    /// Parses a sequence of tag/length/value extension entries.
    ///
    /// Each entry begins with a `u8` tag and a `u8` length. The E2EE and user
    /// timestamp extensions are decoded when their declared length is at least
    /// their known size; extra trailing bytes are skipped. Unknown tags, and
    /// known tags declaring a shorter-than-known length, are skipped entirely.
    /// Parsing stops once fewer than two bytes remain, and those leftover bytes
    /// are ignored. If a tag occurs more than once, the last occurrence wins.
    ///
    /// # Errors
    ///
    /// Returns [`DeserializeError::MalformedExt`] carrying the entry's tag when
    /// its declared length exceeds the bytes remaining in `raw`.
    fn deserialize(mut raw: impl Buf) -> Result<Self, DeserializeError> {
        let mut parsed = Self::default();

        // An entry needs at least its tag byte and its length byte.
        while raw.remaining() >= 2 {
            let tag = raw.get_u8();
            let declared_len = usize::from(raw.get_u8());

            match tag {
                // NOTE(review): a padding tag consumes two bytes (tag + length)
                // and its length value is ignored — confirm this matches how
                // the encoder lays out padding.
                EXT_TAG_PADDING => {}
                E2eeExt::TAG if declared_len >= E2eeExt::LEN => {
                    parsed.e2ee = deserialize_ext!(E2eeExt, raw, declared_len);
                }
                UserTimestampExt::TAG if declared_len >= UserTimestampExt::LEN => {
                    parsed.user_timestamp =
                        deserialize_ext!(UserTimestampExt, raw, declared_len);
                }
                // Unknown tag, or a known tag with an undersized length: skip
                // its bytes so newer encoders stay compatible.
                other => {
                    if declared_len > raw.remaining() {
                        return Err(DeserializeError::MalformedExt(other));
                    }
                    raw.advance(declared_len);
                }
            }
        }

        Ok(parsed)
    }
}

impl UserTimestampExt {
    /// Decodes the user timestamp, stored as a big-endian `u64`.
    fn deserialize(raw: [u8; Self::LEN]) -> Self {
        Self(u64::from_be_bytes(raw))
    }
}

impl E2eeExt {
    fn deserialize(raw: [u8; Self::LEN]) -> Self {
        let key_index = raw[0];
        let mut iv = [0u8; 12];
        iv.copy_from_slice(&raw[1..13]);
        Self { key_index, iv }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use bytes::{BufMut, BytesMut};
    use test_case::test_matrix;

    /// Returns the simplest valid packet to use in test.
    fn valid_packet() -> BytesMut {
        let mut raw = BytesMut::zeroed(12); // Base header
        raw[3] = 1; // Non-zero track handle
        raw
    }

    #[test]
    fn test_short_buffer() {
        let mut raw = valid_packet();
        raw.truncate(11);

        let packet = Packet::deserialize(raw.freeze());
        assert!(matches!(packet, Err(DeserializeError::TooShort)));
    }

    #[test]
    fn test_missing_ext_words() {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT; // Extension flag
        // The extension word indicator that should follow is deliberately omitted.

        let packet = Packet::deserialize(raw.freeze());
        assert!(matches!(packet, Err(DeserializeError::MissingExtWords)));
    }

    #[test]
    fn test_header_overrun() {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT; // Extension flag
        raw.put_u16(1); // One extension word

        let packet = Packet::deserialize(raw.freeze());
        assert!(matches!(packet, Err(DeserializeError::HeaderOverrun)));
    }

    #[test]
    fn test_unsupported_version() {
        let mut raw = valid_packet();
        raw[0] = 0x20; // Version 1 (not supported yet)

        let packet = Packet::deserialize(raw.freeze());
        assert!(matches!(packet, Err(DeserializeError::UnsupportedVersion(1))));
    }

    #[test]
    fn test_base_header() {
        let mut raw = BytesMut::new();
        raw.put_u8(0x8); // Version 0, final flag set, no extensions
        raw.put_u8(0x0); // Reserved
        raw.put_slice(&[0x88, 0x11]); // Track handle
        raw.put_slice(&[0x44, 0x22]); // Sequence
        raw.put_slice(&[0x44, 0x11]); // Frame number
        raw.put_slice(&[0x44, 0x22, 0x11, 0x88]); // Timestamp

        let packet = Packet::deserialize(raw.freeze()).unwrap();
        assert_eq!(packet.header.marker, FrameMarker::Final);
        assert_eq!(packet.header.track_handle, 0x8811u32.try_into().unwrap());
        assert_eq!(packet.header.sequence, 0x4422);
        assert_eq!(packet.header.frame_number, 0x4411);
        assert_eq!(packet.header.timestamp, Timestamp::from_ticks(0x44221188));
        assert_eq!(packet.header.extensions.user_timestamp, None);
        assert_eq!(packet.header.extensions.e2ee, None);
    }

    #[test]
    fn test_payload_follows_header() {
        let mut raw = valid_packet();
        raw.put_slice(&[0xDE, 0xAD, 0xBE, 0xEF]); // Payload

        let packet = Packet::deserialize(raw.freeze()).unwrap();
        assert_eq!(packet.payload, Bytes::from_static(&[0xDE, 0xAD, 0xBE, 0xEF]));
    }

    #[test_matrix([0, 1, 24])]
    fn test_ext_skips_padding(ext_words: usize) {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT; // Extension flag

        raw.put_u16(ext_words as u16); // Extension words
        let data_len = (ext_words + 1) * 4 - EXT_WORDS_INDICATOR_SIZE;
        raw.put_bytes(0, data_len); // Padding

        let packet = Packet::deserialize(raw.freeze()).unwrap();
        assert_eq!(packet.payload.len(), 0);
    }

    #[test]
    fn test_ext_e2ee() {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT; // Extension flag
        raw.put_u16(4); // Extension words (includes 2-byte indicator in word count)

        raw.put_u8(1); // ID 1
        raw.put_u8(13); // Length
        raw.put_u8(0xFA); // Key index
        raw.put_bytes(0x3C, 12); // IV
        raw.put_bytes(0, 3); // Padding

        let packet = Packet::deserialize(raw.freeze()).unwrap();
        let e2ee = packet.header.extensions.e2ee.unwrap();
        assert_eq!(e2ee.key_index, 0xFA);
        assert_eq!(e2ee.iv, [0x3C; 12]);
    }

    #[test]
    fn test_ext_user_timestamp() {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT; // Extension flag
        raw.put_u16(2); // Extension words

        raw.put_u8(2);
        raw.put_u8(8); // Length
        raw.put_slice(&[0x44, 0x11, 0x22, 0x11, 0x11, 0x11, 0x88, 0x11]); // User timestamp

        let packet = Packet::deserialize(raw.freeze()).unwrap();
        assert_eq!(
            packet.header.extensions.user_timestamp,
            UserTimestampExt(0x4411221111118811).into()
        );
    }

    #[test]
    fn test_ext_forward_compat_longer_length() {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT; // Extension flag
        raw.put_u16(3); // Extension words

        raw.put_u8(2); // User timestamp
        raw.put_u8(12); // Longer than known length (8), extra bytes are skipped
        raw.put_slice(&[0x44, 0x11, 0x22, 0x11, 0x11, 0x11, 0x88, 0x11]); // Known 8 bytes
        raw.put_bytes(0xFF, 4); // 4 extra bytes from a future version

        let packet = Packet::deserialize(raw.freeze()).unwrap();
        assert_eq!(
            packet.header.extensions.user_timestamp,
            UserTimestampExt(0x4411221111118811).into()
        );
    }

    #[test]
    fn test_ext_shorter_than_known_length_skipped() {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT; // Extension flag
        raw.put_u16(1); // Extension words

        raw.put_u8(2); // User timestamp tag
        raw.put_u8(4); // Shorter than known length (8), treated as unknown
        raw.put_bytes(0x3C, 4);

        let packet = Packet::deserialize(raw.freeze()).unwrap();
        assert!(packet.header.extensions.user_timestamp.is_none());
    }

    #[test]
    fn test_ext_known_length_exceeds_block() {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT; // Extension flag
        raw.put_u16(1); // Extension words (6 bytes of extension data)

        raw.put_u8(2); // User timestamp tag
        raw.put_u8(8); // Declares 8 bytes, but only 4 remain in the block
        raw.put_bytes(0x3C, 4);

        let packet = Packet::deserialize(raw.freeze());
        assert!(matches!(
            packet,
            Err(DeserializeError::MalformedExt(tag)) if tag == UserTimestampExt::TAG
        ));
    }

    #[test]
    fn test_ext_unknown() {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT; // Extension flag
        raw.put_u16(1); // Extension words

        raw.put_u8(8); // ID 8 (unknown)
        raw.put_u8(0); // Length 0
        raw.put_bytes(0, 4); // Remaining padding

        Packet::deserialize(raw.freeze()).expect("Should skip unknown extension");
    }

    #[test]
    fn test_ext_unknown_length_exceeds_block() {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT; // Extension flag
        raw.put_u16(1); // Extension words (6 bytes of extension data)

        raw.put_u8(8); // ID 8 (unknown)
        raw.put_u8(10); // Declares 10 bytes, but only 4 remain in the block
        raw.put_bytes(0, 4);

        let packet = Packet::deserialize(raw.freeze());
        assert!(matches!(packet, Err(DeserializeError::MalformedExt(8))));
    }

    #[test]
    fn test_ext_required_word_alignment() {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT; // Extension flag
        raw.put_u16(0); // Extension words (2 bytes of extension data expected)
        raw.put_bytes(0, 1); // Only 1 byte, but 2 needed

        assert!(Packet::deserialize(raw.freeze()).is_err());
    }
}