Skip to main content

zendo_protocol/
decode.rs

1//! Pure decoding of binary frames into typed [`Message`] values.
2
3use crate::constants::{
4    BODY_ISB_ANGLE_COUNT, BODY_JOINT_COUNT, BODY_LANDMARK_COUNT, F64_BYTES, HAND_JOINT_COUNT,
5    HAND_LANDMARK_COUNT, MSG_BODY_ISB_ANGLES, MSG_BODY_LANDMARK, MSG_BODY_QUATERNION,
6    MSG_HAND_LANDMARK, MSG_HAND_QUATERNION, MSG_HELLO, TUPLE_BYTES,
7};
8use crate::error::ProtocolError;
9use crate::frames::{
10    BodyIsbAnglesFrame, BodyLandmarkFrame, BodyQuaternionFrame, HandLandmarkFrame,
11    HandQuaternionFrame,
12};
13use crate::types::{HandSide, Landmark, Quaternion};
14
15/// One decoded message from the Zendo stream.
16#[derive(Clone, Copy, Debug, PartialEq)]
17pub enum Message {
18    /// Body-joint orientations (`0x02`).
19    BodyQuaternions(BodyQuaternionFrame),
20    /// Body-landmark positions (`0x03`).
21    BodyLandmarks(BodyLandmarkFrame),
22    /// Hand-joint orientations for one hand (`0x04`).
23    HandQuaternions {
24        side: HandSide,
25        frame: HandQuaternionFrame,
26    },
27    /// Hand-landmark positions for one hand (`0x05`).
28    HandLandmarks {
29        side: HandSide,
30        frame: HandLandmarkFrame,
31    },
32    /// Body ISB joint angles in radians (`0x06`).
33    BodyIsbAngles(BodyIsbAnglesFrame),
34}
35
36/// Decodes one binary WebSocket frame.
37///
38/// `frame` is the whole message: byte 0 is the type tag, the rest is payload.
39pub fn decode(frame: &[u8]) -> Result<Message, ProtocolError> {
40    let (&tag, payload) = frame.split_first().ok_or(ProtocolError::EmptyFrame)?;
41    match tag {
42        MSG_BODY_QUATERNION => decode_body_quaternions(payload),
43        MSG_BODY_LANDMARK => decode_body_landmarks(payload),
44        MSG_HAND_QUATERNION => decode_hand_quaternions(payload),
45        MSG_HAND_LANDMARK => decode_hand_landmarks(payload),
46        MSG_BODY_ISB_ANGLES => decode_body_isb_angles(payload),
47        other => Err(ProtocolError::UnknownMessageType(other)),
48    }
49}
50
51/// Decodes a hello frame, returning the server's protocol version.
52///
53/// `frame` is the whole message, including the `0x01` type tag. The hello frame
54/// is separate from [`decode`] because it is protocol metadata, not a
55/// [`Message`] a consumer streams.
56pub fn decode_hello(frame: &[u8]) -> Result<u16, ProtocolError> {
57    let (&tag, payload) = frame.split_first().ok_or(ProtocolError::EmptyFrame)?;
58    if tag != MSG_HELLO {
59        return Err(ProtocolError::UnknownMessageType(tag));
60    }
61    if payload.len() != 2 {
62        return Err(ProtocolError::InvalidLength {
63            message_type: MSG_HELLO,
64            expected: 2,
65            actual: payload.len(),
66        });
67    }
68    Ok(u16::from_le_bytes([payload[0], payload[1]]))
69}
70
71fn decode_body_quaternions(payload: &[u8]) -> Result<Message, ProtocolError> {
72    const EXPECTED: usize = BODY_JOINT_COUNT * TUPLE_BYTES;
73    check_len(MSG_BODY_QUATERNION, payload, EXPECTED)?;
74
75    let mut quats = [Quaternion::default(); BODY_JOINT_COUNT];
76    for (i, q) in quats.iter_mut().enumerate() {
77        *q = read_quaternion(payload, i * TUPLE_BYTES);
78    }
79    Ok(Message::BodyQuaternions(BodyQuaternionFrame::from_array(
80        quats,
81    )))
82}
83
84fn decode_body_landmarks(payload: &[u8]) -> Result<Message, ProtocolError> {
85    const EXPECTED: usize = BODY_LANDMARK_COUNT * TUPLE_BYTES;
86    check_len(MSG_BODY_LANDMARK, payload, EXPECTED)?;
87
88    let mut landmarks = [Landmark::default(); BODY_LANDMARK_COUNT];
89    for (i, lm) in landmarks.iter_mut().enumerate() {
90        *lm = read_landmark(payload, i * TUPLE_BYTES);
91    }
92    Ok(Message::BodyLandmarks(BodyLandmarkFrame::from_array(
93        landmarks,
94    )))
95}
96
97fn decode_hand_quaternions(payload: &[u8]) -> Result<Message, ProtocolError> {
98    const EXPECTED: usize = 1 + HAND_JOINT_COUNT * TUPLE_BYTES;
99    check_len(MSG_HAND_QUATERNION, payload, EXPECTED)?;
100
101    let side = read_hand_side(payload)?;
102    let body = &payload[1..];
103    let mut quats = [Quaternion::default(); HAND_JOINT_COUNT];
104    for (i, q) in quats.iter_mut().enumerate() {
105        *q = read_quaternion(body, i * TUPLE_BYTES);
106    }
107    Ok(Message::HandQuaternions {
108        side,
109        frame: HandQuaternionFrame::from_array(quats),
110    })
111}
112
113fn decode_hand_landmarks(payload: &[u8]) -> Result<Message, ProtocolError> {
114    const EXPECTED: usize = 1 + HAND_LANDMARK_COUNT * TUPLE_BYTES;
115    check_len(MSG_HAND_LANDMARK, payload, EXPECTED)?;
116
117    let side = read_hand_side(payload)?;
118    let body = &payload[1..];
119    let mut landmarks = [Landmark::default(); HAND_LANDMARK_COUNT];
120    for (i, lm) in landmarks.iter_mut().enumerate() {
121        *lm = read_landmark(body, i * TUPLE_BYTES);
122    }
123    Ok(Message::HandLandmarks {
124        side,
125        frame: HandLandmarkFrame::from_array(landmarks),
126    })
127}
128
129fn decode_body_isb_angles(payload: &[u8]) -> Result<Message, ProtocolError> {
130    const EXPECTED: usize = BODY_ISB_ANGLE_COUNT * F64_BYTES;
131    check_len(MSG_BODY_ISB_ANGLES, payload, EXPECTED)?;
132
133    let mut values = [0.0f64; BODY_ISB_ANGLE_COUNT];
134    for (i, v) in values.iter_mut().enumerate() {
135        *v = read_f64(payload, i * F64_BYTES);
136    }
137    Ok(Message::BodyIsbAngles(BodyIsbAnglesFrame::from_array(
138        values,
139    )))
140}
141
142fn check_len(message_type: u8, payload: &[u8], expected: usize) -> Result<(), ProtocolError> {
143    if payload.len() == expected {
144        Ok(())
145    } else {
146        Err(ProtocolError::InvalidLength {
147            message_type,
148            expected,
149            actual: payload.len(),
150        })
151    }
152}
153
154/// Reads the side byte. The caller must have validated the payload length, so
155/// `payload[0]` is in bounds.
156fn read_hand_side(payload: &[u8]) -> Result<HandSide, ProtocolError> {
157    let byte = payload[0];
158    HandSide::from_byte(byte).ok_or(ProtocolError::InvalidHandSide(byte))
159}
160
161/// Reads four consecutive little-endian `f64`s as a quaternion. The caller must
162/// guarantee `buf[offset..offset + TUPLE_BYTES]` is in bounds.
163fn read_quaternion(buf: &[u8], offset: usize) -> Quaternion {
164    Quaternion {
165        w: read_f64(buf, offset),
166        x: read_f64(buf, offset + F64_BYTES),
167        y: read_f64(buf, offset + 2 * F64_BYTES),
168        z: read_f64(buf, offset + 3 * F64_BYTES),
169    }
170}
171
172/// Reads four consecutive little-endian `f64`s as a landmark. The caller must
173/// guarantee `buf[offset..offset + TUPLE_BYTES]` is in bounds.
174fn read_landmark(buf: &[u8], offset: usize) -> Landmark {
175    Landmark {
176        x: read_f64(buf, offset),
177        y: read_f64(buf, offset + F64_BYTES),
178        z: read_f64(buf, offset + 2 * F64_BYTES),
179        confidence: read_f64(buf, offset + 3 * F64_BYTES),
180    }
181}
182
183/// Reads one little-endian `f64`. The caller must guarantee
184/// `buf[offset..offset + F64_BYTES]` is in bounds.
185fn read_f64(buf: &[u8], offset: usize) -> f64 {
186    let mut bytes = [0u8; F64_BYTES];
187    bytes.copy_from_slice(&buf[offset..offset + F64_BYTES]);
188    f64::from_le_bytes(bytes)
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::{HAND_SIDE_LEFT, HAND_SIDE_RIGHT};

    fn body_quaternion_frame() -> [u8; 1 + BODY_JOINT_COUNT * TUPLE_BYTES] {
        let mut frame = [0u8; 1 + BODY_JOINT_COUNT * TUPLE_BYTES];
        frame[0] = MSG_BODY_QUATERNION;
        // First joint: w = 1.0
        frame[1..9].copy_from_slice(&1.0f64.to_le_bytes());
        frame
    }

    /// Writes `values` as consecutive little-endian `f64`s at the start of `buf`.
    fn write_f64s(buf: &mut [u8], values: &[f64]) {
        for (chunk, value) in buf.chunks_exact_mut(F64_BYTES).zip(values) {
            chunk.copy_from_slice(&value.to_le_bytes());
        }
    }

    #[test]
    fn decode_rejects_empty_frame() {
        // Arrange / Act / Assert
        assert_eq!(decode(&[]), Err(ProtocolError::EmptyFrame));
    }

    #[test]
    fn decode_rejects_unknown_tag() {
        // Arrange / Act / Assert
        assert_eq!(
            decode(&[0xFF]),
            Err(ProtocolError::UnknownMessageType(0xFF))
        );
    }

    #[test]
    fn decode_rejects_hello_tag() {
        // Arrange / Act / Assert — hello frames are only accepted by `decode_hello`.
        assert_eq!(
            decode(&[MSG_HELLO, 1, 0]),
            Err(ProtocolError::UnknownMessageType(MSG_HELLO))
        );
    }

    #[test]
    fn decode_rejects_wrong_length() {
        // Arrange
        let frame = [MSG_BODY_QUATERNION, 0, 0, 0];

        // Act
        let result = decode(&frame);

        // Assert
        assert_eq!(
            result,
            Err(ProtocolError::InvalidLength {
                message_type: MSG_BODY_QUATERNION,
                expected: BODY_JOINT_COUNT * TUPLE_BYTES,
                actual: 3,
            })
        );
    }

    #[test]
    fn decode_rejects_oversized_frame() {
        // Arrange — one trailing byte past a valid body-landmark payload.
        let mut frame = [0u8; 2 + BODY_LANDMARK_COUNT * TUPLE_BYTES];
        frame[0] = MSG_BODY_LANDMARK;

        // Act
        let result = decode(&frame);

        // Assert
        assert_eq!(
            result,
            Err(ProtocolError::InvalidLength {
                message_type: MSG_BODY_LANDMARK,
                expected: BODY_LANDMARK_COUNT * TUPLE_BYTES,
                actual: BODY_LANDMARK_COUNT * TUPLE_BYTES + 1,
            })
        );
    }

    #[test]
    fn decode_body_quaternions_reads_first_joint() {
        // Arrange
        let frame = body_quaternion_frame();

        // Act
        let message = decode(&frame).expect("valid frame");

        // Assert
        match message {
            Message::BodyQuaternions(q) => {
                assert_eq!(q.hips.w, 1.0);
                assert_eq!(q.hips.x, 0.0);
                assert_eq!(q.left_foot, Quaternion::default());
            }
            other => panic!("unexpected message: {other:?}"),
        }
    }

    #[test]
    fn decode_body_landmarks_reads_fields_in_order() {
        // Arrange — first landmark is (x, y, z, confidence) = (1, 2, 3, 0.5).
        let mut frame = [0u8; 1 + BODY_LANDMARK_COUNT * TUPLE_BYTES];
        frame[0] = MSG_BODY_LANDMARK;
        write_f64s(&mut frame[1..], &[1.0, 2.0, 3.0, 0.5]);
        let mut expected = [Landmark::default(); BODY_LANDMARK_COUNT];
        expected[0] = Landmark {
            x: 1.0,
            y: 2.0,
            z: 3.0,
            confidence: 0.5,
        };

        // Act
        let result = decode(&frame);

        // Assert
        assert_eq!(
            result,
            Ok(Message::BodyLandmarks(BodyLandmarkFrame::from_array(expected)))
        );
    }

    #[test]
    fn decode_hand_quaternions_reads_side() {
        // Arrange
        let mut frame = [0u8; 2 + HAND_JOINT_COUNT * TUPLE_BYTES];
        frame[0] = MSG_HAND_QUATERNION;
        frame[1] = HAND_SIDE_LEFT;

        // Act
        let message = decode(&frame).expect("valid frame");

        // Assert
        match message {
            Message::HandQuaternions { side, .. } => assert_eq!(side, HandSide::Left),
            other => panic!("unexpected message: {other:?}"),
        }
    }

    #[test]
    fn decode_hand_quaternions_reads_body_after_side_byte() {
        // Arrange — the first quaternion starts right after the side byte.
        let mut frame = [0u8; 2 + HAND_JOINT_COUNT * TUPLE_BYTES];
        frame[0] = MSG_HAND_QUATERNION;
        frame[1] = HAND_SIDE_RIGHT;
        write_f64s(&mut frame[2..], &[1.0, 0.5, -0.5, 0.25]);
        let mut expected = [Quaternion::default(); HAND_JOINT_COUNT];
        expected[0] = Quaternion {
            w: 1.0,
            x: 0.5,
            y: -0.5,
            z: 0.25,
        };

        // Act
        let result = decode(&frame);

        // Assert
        assert_eq!(
            result,
            Ok(Message::HandQuaternions {
                side: HandSide::Right,
                frame: HandQuaternionFrame::from_array(expected),
            })
        );
    }

    #[test]
    fn decode_hand_quaternions_rejects_bad_side_byte() {
        // Arrange
        let mut frame = [0u8; 2 + HAND_JOINT_COUNT * TUPLE_BYTES];
        frame[0] = MSG_HAND_QUATERNION;
        frame[1] = 9; // neither 0 nor 1

        // Act / Assert
        assert_eq!(decode(&frame), Err(ProtocolError::InvalidHandSide(9)));
    }

    #[test]
    fn decode_hand_landmarks_rejects_bad_side_byte() {
        // Arrange
        let mut frame = [0u8; 2 + HAND_LANDMARK_COUNT * TUPLE_BYTES];
        frame[0] = MSG_HAND_LANDMARK;
        frame[1] = 9; // neither 0 nor 1

        // Act / Assert
        assert_eq!(decode(&frame), Err(ProtocolError::InvalidHandSide(9)));
    }

    #[test]
    fn decode_hand_landmarks_accepts_right_side() {
        // Arrange
        let mut frame = [0u8; 2 + HAND_LANDMARK_COUNT * TUPLE_BYTES];
        frame[0] = MSG_HAND_LANDMARK;
        frame[1] = HAND_SIDE_RIGHT;

        // Act
        let message = decode(&frame).expect("valid frame");

        // Assert
        match message {
            Message::HandLandmarks { side, .. } => assert_eq!(side, HandSide::Right),
            other => panic!("unexpected message: {other:?}"),
        }
    }

    #[test]
    fn decode_hand_landmarks_reads_body_after_side_byte() {
        // Arrange — the first landmark starts right after the side byte.
        let mut frame = [0u8; 2 + HAND_LANDMARK_COUNT * TUPLE_BYTES];
        frame[0] = MSG_HAND_LANDMARK;
        frame[1] = HAND_SIDE_LEFT;
        write_f64s(&mut frame[2..], &[0.25, 0.5, 0.75, 1.0]);
        let mut expected = [Landmark::default(); HAND_LANDMARK_COUNT];
        expected[0] = Landmark {
            x: 0.25,
            y: 0.5,
            z: 0.75,
            confidence: 1.0,
        };

        // Act
        let result = decode(&frame);

        // Assert
        assert_eq!(
            result,
            Ok(Message::HandLandmarks {
                side: HandSide::Left,
                frame: HandLandmarkFrame::from_array(expected),
            })
        );
    }

    #[test]
    fn decode_hello_reads_version() {
        // Arrange — tag 0x01 followed by a u16 LE version.
        let frame = [MSG_HELLO, 1, 0];

        // Act / Assert
        assert_eq!(decode_hello(&frame), Ok(1));
    }

    #[test]
    fn decode_hello_rejects_empty_frame() {
        // Arrange / Act / Assert
        assert_eq!(decode_hello(&[]), Err(ProtocolError::EmptyFrame));
    }

    #[test]
    fn decode_hello_rejects_wrong_tag() {
        // Arrange / Act / Assert
        assert_eq!(
            decode_hello(&[MSG_BODY_QUATERNION, 1, 0]),
            Err(ProtocolError::UnknownMessageType(MSG_BODY_QUATERNION))
        );
    }

    #[test]
    fn decode_hello_rejects_wrong_length() {
        // Arrange / Act / Assert
        assert_eq!(
            decode_hello(&[MSG_HELLO, 1]),
            Err(ProtocolError::InvalidLength {
                message_type: MSG_HELLO,
                expected: 2,
                actual: 1,
            })
        );
    }

    #[test]
    fn decode_body_isb_angles_reads_first_value() {
        // Arrange
        let mut frame = [0u8; 1 + BODY_ISB_ANGLE_COUNT * F64_BYTES];
        frame[0] = MSG_BODY_ISB_ANGLES;
        frame[1..9].copy_from_slice(&1.5f64.to_le_bytes());

        // Act
        let message = decode(&frame).expect("valid frame");

        // Assert
        match message {
            Message::BodyIsbAngles(f) => assert_eq!(f.to_array()[0], 1.5),
            other => panic!("unexpected message: {other:?}"),
        }
    }

    #[test]
    fn decode_body_isb_angles_rejects_short_frame() {
        // Arrange
        let frame = [MSG_BODY_ISB_ANGLES, 0, 0, 0];

        // Act
        let result = decode(&frame);

        // Assert
        assert_eq!(
            result,
            Err(ProtocolError::InvalidLength {
                message_type: MSG_BODY_ISB_ANGLES,
                expected: BODY_ISB_ANGLE_COUNT * F64_BYTES,
                actual: 3,
            })
        );
    }
}