zendo-protocol 0.1.3

Wire-protocol constants and binary frame decoders for the Zendo motion-tracking WebSocket stream.
Documentation
//! Pure decoding of binary frames into typed [`Message`] values.

use crate::constants::{
    BODY_JOINT_COUNT, BODY_LANDMARK_COUNT, F64_BYTES, HAND_JOINT_COUNT, HAND_LANDMARK_COUNT,
    MSG_BODY_LANDMARK, MSG_BODY_QUATERNION, MSG_HAND_LANDMARK, MSG_HAND_QUATERNION, MSG_HELLO,
    TUPLE_BYTES,
};
use crate::error::ProtocolError;
use crate::frames::{
    BodyLandmarkFrame, BodyQuaternionFrame, HandLandmarkFrame, HandQuaternionFrame,
};
use crate::types::{HandSide, Landmark, Quaternion};

/// One decoded message from the Zendo stream.
///
/// Values are produced by [`decode`]. The hello frame is not a variant here;
/// it is read separately by [`decode_hello`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Message {
    /// Body-joint orientations (`0x02`).
    BodyQuaternions(BodyQuaternionFrame),
    /// Body-landmark positions (`0x03`).
    BodyLandmarks(BodyLandmarkFrame),
    /// Hand-joint orientations for one hand (`0x04`).
    HandQuaternions {
        /// Which hand this frame describes, decoded from the first payload byte.
        side: HandSide,
        /// Orientation of each hand joint, in payload order.
        frame: HandQuaternionFrame,
    },
    /// Hand-landmark positions for one hand (`0x05`).
    HandLandmarks {
        /// Which hand this frame describes, decoded from the first payload byte.
        side: HandSide,
        /// Position and confidence of each hand landmark, in payload order.
        frame: HandLandmarkFrame,
    },
}

/// Decodes one binary WebSocket frame.
///
/// `frame` is the whole message: byte 0 is the type tag, the rest is payload.
pub fn decode(frame: &[u8]) -> Result<Message, ProtocolError> {
    let (&tag, payload) = frame.split_first().ok_or(ProtocolError::EmptyFrame)?;
    match tag {
        MSG_BODY_QUATERNION => decode_body_quaternions(payload),
        MSG_BODY_LANDMARK => decode_body_landmarks(payload),
        MSG_HAND_QUATERNION => decode_hand_quaternions(payload),
        MSG_HAND_LANDMARK => decode_hand_landmarks(payload),
        other => Err(ProtocolError::UnknownMessageType(other)),
    }
}

/// Decodes a hello frame and returns the protocol version the server announced.
///
/// `frame` is the whole message: the `0x01` type tag followed by the version as
/// a little-endian `u16`. Hello frames are protocol metadata rather than
/// streamed data, so they are decoded here instead of by [`decode`].
///
/// # Errors
///
/// - [`ProtocolError::EmptyFrame`] if `frame` has no bytes.
/// - [`ProtocolError::UnknownMessageType`] if the tag is not the hello tag.
///   This check happens before the length check.
/// - [`ProtocolError::InvalidLength`] if the payload is not exactly two bytes.
pub fn decode_hello(frame: &[u8]) -> Result<u16, ProtocolError> {
    match frame {
        [] => Err(ProtocolError::EmptyFrame),
        [tag, ..] if *tag != MSG_HELLO => Err(ProtocolError::UnknownMessageType(*tag)),
        [_, lo, hi] => Ok(u16::from_le_bytes([*lo, *hi])),
        [_, payload @ ..] => Err(ProtocolError::InvalidLength {
            message_type: MSG_HELLO,
            expected: 2,
            actual: payload.len(),
        }),
    }
}

/// Decodes a body-quaternion payload (type tag already stripped).
///
/// The payload must be exactly `BODY_JOINT_COUNT` back-to-back tuples of
/// `TUPLE_BYTES` bytes; each tuple is read as one quaternion.
fn decode_body_quaternions(payload: &[u8]) -> Result<Message, ProtocolError> {
    check_len(MSG_BODY_QUATERNION, payload, BODY_JOINT_COUNT * TUPLE_BYTES)?;

    let mut joints = [Quaternion::default(); BODY_JOINT_COUNT];
    let offsets = (0..BODY_JOINT_COUNT).map(|joint| joint * TUPLE_BYTES);
    for (slot, offset) in joints.iter_mut().zip(offsets) {
        *slot = read_quaternion(payload, offset);
    }

    let frame = BodyQuaternionFrame::from_array(joints);
    Ok(Message::BodyQuaternions(frame))
}

/// Decodes a body-landmark payload (type tag already stripped).
///
/// The payload must be exactly `BODY_LANDMARK_COUNT` back-to-back tuples of
/// `TUPLE_BYTES` bytes; each tuple is read as one landmark.
fn decode_body_landmarks(payload: &[u8]) -> Result<Message, ProtocolError> {
    let expected_len = BODY_LANDMARK_COUNT * TUPLE_BYTES;
    check_len(MSG_BODY_LANDMARK, payload, expected_len)?;

    let mut points = [Landmark::default(); BODY_LANDMARK_COUNT];
    let mut offset = 0;
    for point in points.iter_mut() {
        *point = read_landmark(payload, offset);
        offset += TUPLE_BYTES;
    }

    Ok(Message::BodyLandmarks(BodyLandmarkFrame::from_array(points)))
}

/// Decodes a hand-quaternion payload (type tag already stripped).
///
/// The payload must be one side byte followed by exactly `HAND_JOINT_COUNT`
/// tuples of `TUPLE_BYTES` bytes. The length is validated before the side
/// byte, so a wrongly sized payload reports `InvalidLength` even if its side
/// byte is also invalid.
fn decode_hand_quaternions(payload: &[u8]) -> Result<Message, ProtocolError> {
    let expected_len = 1 + HAND_JOINT_COUNT * TUPLE_BYTES;
    check_len(MSG_HAND_QUATERNION, payload, expected_len)?;

    let side = read_hand_side(payload)?;
    let (_side_byte, tuples) = payload.split_at(1);

    let mut joints = [Quaternion::default(); HAND_JOINT_COUNT];
    let offsets = (0..HAND_JOINT_COUNT).map(|joint| joint * TUPLE_BYTES);
    for (slot, offset) in joints.iter_mut().zip(offsets) {
        *slot = read_quaternion(tuples, offset);
    }

    let frame = HandQuaternionFrame::from_array(joints);
    Ok(Message::HandQuaternions { side, frame })
}

/// Decodes a hand-landmark payload (type tag already stripped).
///
/// The payload must be one side byte followed by exactly
/// `HAND_LANDMARK_COUNT` tuples of `TUPLE_BYTES` bytes. The length is
/// validated before the side byte.
fn decode_hand_landmarks(payload: &[u8]) -> Result<Message, ProtocolError> {
    const EXPECTED_LEN: usize = 1 + HAND_LANDMARK_COUNT * TUPLE_BYTES;
    check_len(MSG_HAND_LANDMARK, payload, EXPECTED_LEN)?;

    let side = read_hand_side(payload)?;
    let (_side_byte, tuples) = payload.split_at(1);

    let mut points = [Landmark::default(); HAND_LANDMARK_COUNT];
    let mut offset = 0;
    for point in &mut points {
        *point = read_landmark(tuples, offset);
        offset += TUPLE_BYTES;
    }

    let frame = HandLandmarkFrame::from_array(points);
    Ok(Message::HandLandmarks { side, frame })
}

/// Confirms that `payload` is exactly `expected` bytes long.
///
/// On a mismatch, returns [`ProtocolError::InvalidLength`] tagged with
/// `message_type` and carrying both the expected and the actual length.
fn check_len(message_type: u8, payload: &[u8], expected: usize) -> Result<(), ProtocolError> {
    let actual = payload.len();
    if actual != expected {
        return Err(ProtocolError::InvalidLength {
            message_type,
            expected,
            actual,
        });
    }
    Ok(())
}

/// Interprets the first byte of `payload` as a [`HandSide`].
///
/// Returns [`ProtocolError::InvalidHandSide`] carrying the byte when
/// `HandSide::from_byte` rejects it.
///
/// # Panics
///
/// Panics if `payload` is empty; callers validate the length first.
fn read_hand_side(payload: &[u8]) -> Result<HandSide, ProtocolError> {
    let side_byte = payload[0];
    match HandSide::from_byte(side_byte) {
        Some(side) => Ok(side),
        None => Err(ProtocolError::InvalidHandSide(side_byte)),
    }
}

/// Reads a quaternion stored as four consecutive little-endian `f64`s, in
/// `w, x, y, z` order, starting at byte `offset` of `buf`.
///
/// # Panics
///
/// Panics if `buf` has fewer than `offset + 4 * F64_BYTES` bytes.
fn read_quaternion(buf: &[u8], offset: usize) -> Quaternion {
    let w = read_f64(buf, offset);
    let x = read_f64(buf, offset + F64_BYTES);
    let y = read_f64(buf, offset + 2 * F64_BYTES);
    let z = read_f64(buf, offset + 3 * F64_BYTES);
    Quaternion { w, x, y, z }
}

/// Reads a landmark stored as four consecutive little-endian `f64`s, in
/// `x, y, z, confidence` order, starting at byte `offset` of `buf`.
///
/// # Panics
///
/// Panics if `buf` has fewer than `offset + 4 * F64_BYTES` bytes.
fn read_landmark(buf: &[u8], offset: usize) -> Landmark {
    // `component(k)` is the k-th f64 of the tuple.
    let component = |k: usize| read_f64(buf, offset + k * F64_BYTES);
    Landmark {
        x: component(0),
        y: component(1),
        z: component(2),
        confidence: component(3),
    }
}

/// Reads one little-endian `f64` from `buf[offset..offset + F64_BYTES]`.
///
/// # Panics
///
/// Panics if that byte range is out of bounds for `buf`.
fn read_f64(buf: &[u8], offset: usize) -> f64 {
    let end = offset + F64_BYTES;
    let window = &buf[offset..end];
    let mut raw = [0u8; F64_BYTES];
    raw.copy_from_slice(window);
    f64::from_le_bytes(raw)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::{HAND_SIDE_LEFT, HAND_SIDE_RIGHT};

    /// Writes `values` as consecutive little-endian `f64`s into `buf`,
    /// starting at byte `start`.
    fn write_f64s(buf: &mut [u8], start: usize, values: &[f64]) {
        for (i, value) in values.iter().enumerate() {
            let at = start + i * F64_BYTES;
            buf[at..at + F64_BYTES].copy_from_slice(&value.to_le_bytes());
        }
    }

    /// A valid body-quaternion frame (tag included) whose first joint has
    /// `w = 1.0`; every other component is zero.
    fn body_quaternion_frame() -> [u8; 1 + BODY_JOINT_COUNT * TUPLE_BYTES] {
        let mut frame = [0u8; 1 + BODY_JOINT_COUNT * TUPLE_BYTES];
        frame[0] = MSG_BODY_QUATERNION;
        // First joint: w = 1.0 (w is the first component of a tuple).
        write_f64s(&mut frame, 1, &[1.0]);
        frame
    }

    #[test]
    fn decode_rejects_empty_frame() {
        // Arrange / Act / Assert
        assert_eq!(decode(&[]), Err(ProtocolError::EmptyFrame));
    }

    #[test]
    fn decode_rejects_unknown_tag() {
        // Arrange / Act / Assert
        assert_eq!(
            decode(&[0xFF]),
            Err(ProtocolError::UnknownMessageType(0xFF))
        );
    }

    #[test]
    fn decode_rejects_hello_frame() {
        // Arrange / Act / Assert: hello frames go through `decode_hello`.
        assert_eq!(
            decode(&[MSG_HELLO, 1, 0]),
            Err(ProtocolError::UnknownMessageType(MSG_HELLO))
        );
    }

    #[test]
    fn decode_rejects_wrong_length() {
        // Arrange
        let frame = [MSG_BODY_QUATERNION, 0, 0, 0];

        // Act
        let result = decode(&frame);

        // Assert
        assert_eq!(
            result,
            Err(ProtocolError::InvalidLength {
                message_type: MSG_BODY_QUATERNION,
                expected: BODY_JOINT_COUNT * TUPLE_BYTES,
                actual: 3,
            })
        );
    }

    #[test]
    fn decode_body_quaternions_reads_first_joint() {
        // Arrange
        let frame = body_quaternion_frame();

        // Act
        let message = decode(&frame).expect("valid frame");

        // Assert
        match message {
            Message::BodyQuaternions(q) => {
                assert_eq!(q.hips.w, 1.0);
                assert_eq!(q.hips.x, 0.0);
                assert_eq!(q.left_foot, Quaternion::default());
            }
            other => panic!("unexpected message: {other:?}"),
        }
    }

    #[test]
    fn decode_body_landmarks_reads_first_landmark() {
        // Arrange: first landmark is (1, 2, 3) with confidence 0.5.
        let mut frame = [0u8; 1 + BODY_LANDMARK_COUNT * TUPLE_BYTES];
        frame[0] = MSG_BODY_LANDMARK;
        write_f64s(&mut frame, 1, &[1.0, 2.0, 3.0, 0.5]);

        let zero = Landmark {
            x: 0.0,
            y: 0.0,
            z: 0.0,
            confidence: 0.0,
        };
        let mut expected = [zero; BODY_LANDMARK_COUNT];
        expected[0] = Landmark {
            x: 1.0,
            y: 2.0,
            z: 3.0,
            confidence: 0.5,
        };

        // Act / Assert
        assert_eq!(
            decode(&frame),
            Ok(Message::BodyLandmarks(BodyLandmarkFrame::from_array(
                expected
            )))
        );
    }

    #[test]
    fn decode_hand_quaternions_reads_side() {
        // Arrange
        let mut frame = [0u8; 2 + HAND_JOINT_COUNT * TUPLE_BYTES];
        frame[0] = MSG_HAND_QUATERNION;
        frame[1] = HAND_SIDE_LEFT;

        // Act
        let message = decode(&frame).expect("valid frame");

        // Assert
        match message {
            Message::HandQuaternions { side, .. } => assert_eq!(side, HandSide::Left),
            other => panic!("unexpected message: {other:?}"),
        }
    }

    #[test]
    fn decode_hand_quaternions_rejects_wrong_length() {
        // Arrange: tag + side byte + joint tuples, one byte short.
        let mut frame = [0u8; 1 + HAND_JOINT_COUNT * TUPLE_BYTES];
        frame[0] = MSG_HAND_QUATERNION;
        frame[1] = HAND_SIDE_LEFT;

        // Act / Assert
        assert_eq!(
            decode(&frame),
            Err(ProtocolError::InvalidLength {
                message_type: MSG_HAND_QUATERNION,
                expected: 1 + HAND_JOINT_COUNT * TUPLE_BYTES,
                actual: HAND_JOINT_COUNT * TUPLE_BYTES,
            })
        );
    }

    #[test]
    fn decode_hand_landmarks_rejects_bad_side_byte() {
        // Arrange
        let mut frame = [0u8; 2 + HAND_LANDMARK_COUNT * TUPLE_BYTES];
        frame[0] = MSG_HAND_LANDMARK;
        frame[1] = 9; // not a valid side byte

        // Act / Assert
        assert_eq!(decode(&frame), Err(ProtocolError::InvalidHandSide(9)));
    }

    #[test]
    fn decode_hand_landmarks_accepts_right_side() {
        // Arrange
        let mut frame = [0u8; 2 + HAND_LANDMARK_COUNT * TUPLE_BYTES];
        frame[0] = MSG_HAND_LANDMARK;
        frame[1] = HAND_SIDE_RIGHT;

        // Act
        let message = decode(&frame).expect("valid frame");

        // Assert
        match message {
            Message::HandLandmarks { side, .. } => assert_eq!(side, HandSide::Right),
            other => panic!("unexpected message: {other:?}"),
        }
    }

    #[test]
    fn decode_hello_reads_version() {
        // Arrange — tag 0x01 followed by a u16 LE version.
        let frame = [MSG_HELLO, 1, 0];

        // Act / Assert
        assert_eq!(decode_hello(&frame), Ok(1));
    }

    #[test]
    fn decode_hello_rejects_empty_frame() {
        // Arrange / Act / Assert
        assert_eq!(decode_hello(&[]), Err(ProtocolError::EmptyFrame));
    }

    #[test]
    fn decode_hello_rejects_wrong_tag() {
        // Arrange / Act / Assert
        assert_eq!(
            decode_hello(&[MSG_BODY_QUATERNION, 1, 0]),
            Err(ProtocolError::UnknownMessageType(MSG_BODY_QUATERNION))
        );
    }

    #[test]
    fn decode_hello_rejects_wrong_length() {
        // Arrange / Act / Assert
        assert_eq!(
            decode_hello(&[MSG_HELLO, 1]),
            Err(ProtocolError::InvalidLength {
                message_type: MSG_HELLO,
                expected: 2,
                actual: 1,
            })
        );
    }

    #[test]
    fn read_quaternion_uses_wxyz_order() {
        // Arrange: four padding f64s, then w, x, y, z = 1, 2, 3, 4.
        let mut buf = [0u8; 8 * F64_BYTES];
        write_f64s(&mut buf, 4 * F64_BYTES, &[1.0, 2.0, 3.0, 4.0]);

        // Act
        let q = read_quaternion(&buf, 4 * F64_BYTES);

        // Assert
        assert_eq!(
            q,
            Quaternion {
                w: 1.0,
                x: 2.0,
                y: 3.0,
                z: 4.0,
            }
        );
    }

    #[test]
    fn read_landmark_uses_xyz_confidence_order() {
        // Arrange: four padding f64s, then x, y, z, confidence.
        let mut buf = [0u8; 8 * F64_BYTES];
        write_f64s(&mut buf, 4 * F64_BYTES, &[-1.0, 0.25, 7.0, 0.9]);

        // Act
        let lm = read_landmark(&buf, 4 * F64_BYTES);

        // Assert
        assert_eq!(lm.x, -1.0);
        assert_eq!(lm.y, 0.25);
        assert_eq!(lm.z, 7.0);
        assert_eq!(lm.confidence, 0.9);
    }
}