// moq_transport/data/extension_headers.rs

1use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePair};
2use bytes::Buf;
3use std::fmt;
4
5/// A collection of KeyValuePair entries, where the length in bytes of key-value-pairs are encoded/decoded first.
6/// This structure is appropriate for Data plane extension headers.
7/// Since duplicate parameters are allowed for unknown extension headers, we don't do duplicate checking here.
8#[derive(Default, Clone, Eq, PartialEq)]
9pub struct ExtensionHeaders(pub Vec<KeyValuePair>);
10
11// TODO: These set/get API's all assume no duplicate keys. We can add API's to support duplicates if needed.
12impl ExtensionHeaders {
13    pub fn new() -> Self {
14        Self::default()
15    }
16
17    /// Insert or replace a KeyValuePair with the same key.
18    pub fn set(&mut self, kvp: KeyValuePair) {
19        if let Some(existing) = self.0.iter_mut().find(|k| k.key == kvp.key) {
20            *existing = kvp;
21        } else {
22            self.0.push(kvp);
23        }
24    }
25
26    pub fn set_intvalue(&mut self, key: u64, value: u64) {
27        self.set(KeyValuePair::new_int(key, value));
28    }
29
30    pub fn set_bytesvalue(&mut self, key: u64, value: Vec<u8>) {
31        self.set(KeyValuePair::new_bytes(key, value));
32    }
33
34    pub fn has(&self, key: u64) -> bool {
35        self.0.iter().any(|k| k.key == key)
36    }
37
38    pub fn get(&self, key: u64) -> Option<&KeyValuePair> {
39        self.0.iter().find(|k| k.key == key)
40    }
41
42    pub fn is_empty(&self) -> bool {
43        self.0.is_empty()
44    }
45}
46
47impl Decode for ExtensionHeaders {
48    fn decode<R: bytes::Buf>(r: &mut R) -> Result<Self, DecodeError> {
49        // Read total byte length of the encoded kvps
50        // Note: this is the difference between KeyValuePairs and ExtensionHeaders.
51        // KeyValuePairs encodes the count of kvps, whereas ExtensionHeaders encodes the total byte length.
52        let length = usize::decode(r)?;
53
54        // Ensure we have that many bytes available in the input
55        Self::decode_remaining(r, length)?;
56
57        // If zero length, return empty map
58        if length == 0 {
59            return Ok(ExtensionHeaders::new());
60        }
61
62        // Copy the exact slice that contains the encoded kvps and decode from it
63        let mut buf = vec![0u8; length];
64        r.copy_to_slice(&mut buf);
65        let mut kvps_bytes = bytes::Bytes::from(buf);
66
67        let mut kvps = Vec::new();
68        while kvps_bytes.has_remaining() {
69            let kvp = KeyValuePair::decode(&mut kvps_bytes)?;
70            kvps.push(kvp);
71        }
72
73        Ok(ExtensionHeaders(kvps))
74    }
75}
76
77impl Encode for ExtensionHeaders {
78    fn encode<W: bytes::BufMut>(&self, w: &mut W) -> Result<(), EncodeError> {
79        // Encode all KeyValuePair entries into a temporary buffer to compute total byte length
80        let mut tmp = bytes::BytesMut::new();
81        for kvp in &self.0 {
82            kvp.encode(&mut tmp)?;
83        }
84
85        // Write total byte length (u64) followed by the encoded bytes
86        (tmp.len() as u64).encode(w)?;
87        w.put_slice(&tmp);
88
89        Ok(())
90    }
91}
92
93impl fmt::Debug for ExtensionHeaders {
94    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95        write!(f, "{{ ")?;
96        for (i, kv) in self.0.iter().enumerate() {
97            if i > 0 {
98                write!(f, ", ")?;
99            }
100            write!(f, "{:?}", kv)?;
101        }
102        write!(f, " }}")
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Round-trips headers through encode/decode and checks the exact wire bytes.
    #[test]
    fn encode_decode_extension_headers() {
        let mut buf = BytesMut::new();

        let mut ext_hdrs = ExtensionHeaders::new();
        ext_hdrs.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]);
        ext_hdrs.encode(&mut buf).unwrap();
        assert_eq!(
            buf.to_vec(),
            vec![
                0x07, // 7 bytes total length
                0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, // Key=1, Value=[1,2,3,4,5]
            ]
        );
        let decoded = ExtensionHeaders::decode(&mut buf).unwrap();
        assert_eq!(decoded, ext_hdrs);
        // Decoding must consume exactly the encoded bytes; `buf` is reused below.
        assert!(buf.is_empty());

        let mut ext_hdrs = ExtensionHeaders::new();
        ext_hdrs.set_intvalue(0, 0); // 2 bytes
        ext_hdrs.set_intvalue(100, 100); // 4 bytes
        ext_hdrs.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); // 1 byte key, 1 byte length, 5 bytes data = 7 bytes
        ext_hdrs.encode(&mut buf).unwrap();
        let buf_vec = buf.to_vec();
        // Validate the encoded length and the KeyValuePair's length.
        assert_eq!(14, buf_vec.len()); // 14 bytes total (length + 3 kvps)
        assert_eq!(13, buf_vec[0]); // 13 bytes for the 3 KeyValuePairs data
        let decoded = ExtensionHeaders::decode(&mut buf).unwrap();
        assert_eq!(decoded, ext_hdrs);
        assert!(buf.is_empty());
    }

    /// Empty headers encode as a single zero length byte and decode back to empty.
    #[test]
    fn empty_extension_headers() {
        let mut buf = BytesMut::new();
        ExtensionHeaders::new().encode(&mut buf).unwrap();
        assert_eq!(buf.to_vec(), vec![0x00]);

        let decoded = ExtensionHeaders::decode(&mut buf).unwrap();
        assert!(decoded.is_empty());
        assert!(buf.is_empty());
    }

    /// Duplicate keys are allowed for extension headers and must survive a round trip.
    #[test]
    fn duplicate_keys_round_trip() {
        let hdrs = ExtensionHeaders(vec![
            KeyValuePair::new_int(2, 1),
            KeyValuePair::new_int(2, 2),
        ]);
        let mut buf = BytesMut::new();
        hdrs.encode(&mut buf).unwrap();

        let decoded = ExtensionHeaders::decode(&mut buf).unwrap();
        assert_eq!(decoded, hdrs);
        assert_eq!(decoded.0.len(), 2);
    }

    /// A length prefix larger than the available bytes must be an error, not a panic.
    #[test]
    fn truncated_input_is_rejected() {
        let mut truncated = BytesMut::from(&[0x07u8, 0x01][..]);
        assert!(ExtensionHeaders::decode(&mut truncated).is_err());
    }

    /// The hand-written `Debug` impl renders an empty list as `{  }`.
    #[test]
    fn debug_format_empty() {
        assert_eq!(format!("{:?}", ExtensionHeaders::new()), "{  }");
    }
}