// takproto 0.4.2
//
// Rust library for the TAK (Team Awareness Kit) protocol: send CoT messages
// to TAK servers with mTLS support.
// Documentation
use crate::error::{Result, TakError};
use bytes::{Buf, BufMut, BytesMut};
use prost::Message;

/// Magic byte (`0xbf`) that opens every TAK streaming frame header.
///
/// `encode_tak_message` writes it as the first byte of a frame and
/// `decode_tak_header` rejects any buffer that does not start with it.
const TAK_MAGIC_BYTE: u8 = 0xbf;

/// Appends `value` to `buf` as a Protocol Buffers unsigned varint.
///
/// The value is written least-significant group first, seven bits per byte.
/// Every byte except the last has its high bit (`0x80`) set to signal that
/// more bytes follow; zero is written as the single byte `0x00`.
pub fn encode_varint(value: u64, buf: &mut BytesMut) {
    let mut remaining = value;

    // Emit full 7-bit groups while more than seven significant bits remain.
    while remaining >= 0x80 {
        buf.put_u8(((remaining & 0x7F) as u8) | 0x80);
        remaining >>= 7;
    }

    // The final group fits in seven bits, so the continuation bit stays clear.
    buf.put_u8(remaining as u8);
}

/// Decodes a Protocol Buffers unsigned varint from the front of `buf`.
///
/// Bytes are read one at a time, seven payload bits per byte, until a byte
/// with a clear continuation bit (`0x80`) is found. Every byte that is read is
/// consumed from `buf`, including on the error paths.
///
/// # Errors
///
/// Returns [`TakError::InvalidVarint`] when:
/// - `buf` runs out of bytes before the terminating byte is seen,
/// - the varint is longer than 10 bytes, or
/// - the tenth byte carries payload bits that do not fit in a `u64`.
pub fn decode_varint(buf: &mut impl Buf) -> Result<u64> {
    // A u64 needs at most ceil(64 / 7) = 10 varint bytes.
    const MAX_VARINT_BYTES: u32 = 10;

    let mut result: u64 = 0;

    for index in 0..MAX_VARINT_BYTES {
        if !buf.has_remaining() {
            return Err(TakError::InvalidVarint);
        }

        let byte = buf.get_u8();
        let payload = u64::from(byte & 0x7F);

        // The tenth byte is shifted by 63, so only its lowest payload bit fits
        // in a u64. Reject larger payloads instead of silently dropping bits.
        if index == MAX_VARINT_BYTES - 1 && payload > 1 {
            return Err(TakError::InvalidVarint);
        }

        result |= payload << (7 * index);

        if byte & 0x80 == 0 {
            return Ok(result);
        }
    }

    // Ten bytes were read and the last one still had its continuation bit set.
    Err(TakError::InvalidVarint)
}

/// Frames a protobuf message for a TAK streaming connection.
///
/// Wire layout: `<magic byte 0xbf> <payload length as varint> <protobuf payload>`.
///
/// # Errors
///
/// Propagates the error from `Message::encode` if the message cannot be
/// serialized.
pub fn encode_tak_message<M: Message>(message: &M) -> Result<BytesMut> {
    // Serialize first: the header needs the payload length.
    let mut body = BytesMut::new();
    message.encode(&mut body)?;

    // 1 magic byte + at most 10 varint bytes + the payload itself.
    let mut framed = BytesMut::with_capacity(1 + 10 + body.len());
    framed.put_u8(TAK_MAGIC_BYTE);
    encode_varint(body.len() as u64, &mut framed);
    framed.put_slice(&body);

    Ok(framed)
}

/// Decodes a TAK streaming message header and returns the payload length.
///
/// The header is the magic byte `0xbf` followed by the payload length encoded
/// as an unsigned varint.
///
/// - Returns `Ok(Some(length))` when a complete header is present. The header
///   bytes are then removed from the front of `buf`; the payload is left in
///   place for the caller to read.
/// - Returns `Ok(None)` when `buf` is empty or the length varint is still
///   incomplete. `buf` is left untouched so the call can be retried once more
///   data has arrived.
///
/// # Errors
///
/// Returns [`TakError::InvalidMessage`] when:
/// - the first byte is not the magic byte,
/// - the length varint is malformed (decoding failed even though the maximum
///   of 10 varint bytes was already buffered, so waiting for more data cannot
///   help), or
/// - the decoded length does not fit in `usize` on this platform.
pub fn decode_tak_header(buf: &mut BytesMut) -> Result<Option<usize>> {
    // A u64 varint never occupies more than 10 bytes.
    const MAX_VARINT_LEN: usize = 10;

    let magic = match buf.first() {
        Some(&byte) => byte,
        None => return Ok(None),
    };

    if magic != TAK_MAGIC_BYTE {
        return Err(TakError::InvalidMessage(format!(
            "Expected magic byte 0xbf, got 0x{:02x}",
            magic
        )));
    }

    // Decode from a borrowed slice so `buf` is only advanced once the whole
    // header is known to be present and valid.
    let mut cursor = &buf[1..];
    let available = cursor.remaining();

    match decode_varint(&mut cursor) {
        Ok(length) => {
            // Refuse lengths that would be truncated by a narrowing cast.
            let payload_len =
                <usize as std::convert::TryFrom<u64>>::try_from(length).map_err(|_| {
                    TakError::InvalidMessage(format!(
                        "Payload length {} does not fit in usize",
                        length
                    ))
                })?;

            let header_len = 1 + (available - cursor.remaining());
            buf.advance(header_len);

            Ok(Some(payload_len))
        }
        // Fewer bytes than a maximal varint: the rest may still be in flight.
        Err(TakError::InvalidVarint) if available < MAX_VARINT_LEN => Ok(None),
        // Enough bytes were buffered, so the varint itself is bad. Reporting
        // "need more data" here would stall the stream forever.
        Err(TakError::InvalidVarint) => Err(TakError::InvalidMessage(
            "Malformed length varint in TAK header".to_string(),
        )),
        Err(e) => Err(e),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `value` encodes to exactly `wire` and decodes back to `value`.
    fn assert_varint_round_trip(value: u64, wire: &[u8]) {
        let mut encoded = BytesMut::new();
        encode_varint(value, &mut encoded);
        assert_eq!(encoded.to_vec(), wire.to_vec(), "Failed for value {}", value);

        // Decode from a copy so the encoded bytes stay intact.
        let mut reader = encoded.clone();
        let decoded = decode_varint(&mut reader).unwrap();
        assert_eq!(decoded, value, "Failed decoding value {}", value);
    }

    #[test]
    fn test_varint_encoding() {
        assert_varint_round_trip(0, &[0x00]);
        assert_varint_round_trip(1, &[0x01]);
        assert_varint_round_trip(127, &[0x7F]);
        assert_varint_round_trip(128, &[0x80, 0x01]);
        assert_varint_round_trip(300, &[0xAC, 0x02]);
    }

    #[test]
    fn test_tak_header() {
        let mut header = BytesMut::new();
        header.put_u8(TAK_MAGIC_BYTE);
        encode_varint(42, &mut header);

        let parsed = decode_tak_header(&mut header).unwrap();
        assert_eq!(parsed, Some(42));
    }
}