librqbit_utp/
raw.rs

1pub mod ext_close_reason;
2pub mod selective_ack;
3
4use tracing::trace;
5
6use crate::{Error, constants::UTP_HEADER, seq_nr::SeqNr};
7
/// Value of a "next extension" byte that marks the end of the extension chain.
const NO_NEXT_EXT: u8 = 0;
/// Extension id used on the wire for the selective-ack extension.
const EXT_SELECTIVE_ACK: u8 = 1;
/// Extension id used on the wire for the close-reason extension.
const EXT_CLOSE_REASON: u8 = 3;
11
/// Packet type carried in the high 4 bits of the first header byte.
///
/// Each variant's discriminant is the numeric value used on the wire.
/// `ST_STATE` is the `Default` variant.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
// The `ST_*` names are kept verbatim rather than converted to CamelCase.
#[allow(non_camel_case_types)]
pub enum Type {
    ST_DATA = 0,
    ST_FIN = 1,
    #[default]
    ST_STATE = 2,
    ST_RESET = 3,
    ST_SYN = 4,
}

impl Type {
    /// Maps a wire value to its packet type, or `None` if the value is not
    /// one of the known discriminants.
    fn from_number(num: u8) -> Option<Type> {
        let known = [
            Type::ST_DATA,
            Type::ST_FIN,
            Type::ST_STATE,
            Type::ST_RESET,
            Type::ST_SYN,
        ];
        known
            .iter()
            .copied()
            .find(|candidate| candidate.to_number() == num)
    }

    /// Returns the wire value of this packet type (its enum discriminant).
    fn to_number(self) -> u8 {
        self as u8
    }
}
45
/// Optional header extensions understood by this module.
///
/// Both fields default to `None`, meaning the extension is absent.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Extensions {
    /// Selective-ack extension (wire id `EXT_SELECTIVE_ACK`), if present.
    pub selective_ack: Option<selective_ack::SelectiveAck>,
    /// Close-reason extension (wire id `EXT_CLOSE_REASON`), if present.
    pub close_reason: Option<ext_close_reason::LibTorrentCloseReason>,
}
51
/// In-memory form of a packet header: the fixed 20-byte part plus any parsed
/// extensions. Multi-byte fields are written to the wire big-endian.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct UtpHeader {
    /// Packet type. On the wire it occupies the high 4 bits of byte 0; the low
    /// 4 bits hold the protocol version, which is not stored here.
    pub htype: Type,
    /// Connection ID (wire bytes 2..4).
    pub connection_id: SeqNr,
    /// Timestamp in microseconds (wire bytes 4..8).
    pub timestamp_microseconds: u32,
    /// Timestamp difference in microseconds (wire bytes 8..12).
    pub timestamp_difference_microseconds: u32,
    /// Window size (wire bytes 12..16).
    pub wnd_size: u32,
    /// Sequence number (wire bytes 16..18).
    pub seq_nr: SeqNr,
    /// Acknowledgment number (wire bytes 18..20).
    pub ack_nr: SeqNr,
    /// Extensions that follow the fixed header.
    pub extensions: Extensions,
}
63
64impl UtpHeader {
65    pub fn set_type(&mut self, packet_type: Type) {
66        // let packet_type = packet_type.to_number();
67        // self.type_ver = (self.type_ver & 0xF0) | (packet_type & 0x0F);
68        self.htype = packet_type;
69    }
70
71    pub fn get_type(&self) -> Type {
72        // Type::from_number(self.type_ver & 0x0F)
73        self.htype
74    }
75
76    pub fn short_repr(&self) -> impl std::fmt::Display + '_ {
77        struct D<'a>(&'a UtpHeader);
78        impl std::fmt::Display for D<'_> {
79            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80                write!(
81                    f,
82                    "{:?}:seq_nr={}:ack_nr={}:wnd_size={}",
83                    self.0.get_type(),
84                    self.0.seq_nr,
85                    self.0.ack_nr,
86                    self.0.wnd_size,
87                )
88            }
89        }
90        D(self)
91    }
92
93    pub fn serialize(&self, buffer: &mut [u8]) -> crate::Result<usize> {
94        if buffer.len() < UTP_HEADER as usize {
95            return Err(Error::SerializeTooSmallBuffer);
96        }
97        const VERSION: u8 = 1;
98        const NEXT_EXT_IDX: usize = 1;
99        let typever = (self.htype.to_number() << 4) | VERSION;
100        buffer[0] = typever;
101        buffer[NEXT_EXT_IDX] = NO_NEXT_EXT; // this will be overwritten down below if extensions are present.
102        buffer[2..4].copy_from_slice(&self.connection_id.to_be_bytes());
103        buffer[4..8].copy_from_slice(&self.timestamp_microseconds.to_be_bytes());
104        buffer[8..12].copy_from_slice(&self.timestamp_difference_microseconds.to_be_bytes());
105        buffer[12..16].copy_from_slice(&self.wnd_size.to_be_bytes());
106        buffer[16..18].copy_from_slice(&self.seq_nr.to_be_bytes());
107        buffer[18..20].copy_from_slice(&self.ack_nr.to_be_bytes());
108
109        let mut next_ext_pos = NEXT_EXT_IDX;
110        let mut offset = 20;
111
112        macro_rules! add_ext {
113            ($id:expr, $payload:expr) => {
114                let payload = $payload;
115                if buffer.len() >= offset + 2 + payload.len() {
116                    buffer[next_ext_pos] = $id;
117                    buffer[offset] = NO_NEXT_EXT;
118                    buffer[offset + 1] = payload.len() as u8;
119                    buffer[offset + 2..offset + 2 + payload.len()].copy_from_slice(payload);
120
121                    #[allow(unused)]
122                    {
123                        next_ext_pos = offset + 1;
124                    }
125                    offset += 2 + payload.len();
126                }
127            };
128        }
129
130        if let Some(sack) = self.extensions.selective_ack {
131            add_ext!(EXT_SELECTIVE_ACK, sack.as_bytes());
132        }
133        if let Some(close_reason) = self.extensions.close_reason {
134            add_ext!(EXT_CLOSE_REASON, &close_reason.as_bytes());
135        }
136
137        Ok(offset)
138    }
139
140    pub fn serialize_with_payload(
141        &self,
142        out_buf: &mut [u8],
143        payload_serialize: impl FnOnce(&mut [u8]) -> crate::Result<usize>,
144    ) -> crate::Result<usize> {
145        let sz = self.serialize(out_buf)?;
146        let payload_sz = payload_serialize(
147            out_buf
148                .get_mut(sz..)
149                .ok_or(Error::SerializeTooSmallBuffer)?,
150        )?;
151        Ok(sz + payload_sz)
152    }
153
154    pub fn deserialize(orig_buffer: &[u8]) -> Option<(Self, usize)> {
155        let mut buffer = orig_buffer;
156        if buffer.len() < UTP_HEADER as usize {
157            return None;
158        }
159        let mut header = UtpHeader::default();
160
161        let typenum = buffer[0] >> 4;
162        let version = buffer[0] & 0xf;
163        if version != 1 {
164            trace!(version, "wrong version");
165            return None;
166        }
167        header.htype = Type::from_number(typenum)?;
168        let mut next_ext = buffer[1];
169        header.connection_id = u16::from_be_bytes(buffer[2..4].try_into().unwrap()).into();
170        header.timestamp_microseconds = u32::from_be_bytes(buffer[4..8].try_into().unwrap());
171        header.timestamp_difference_microseconds =
172            u32::from_be_bytes(buffer[8..12].try_into().unwrap());
173        header.wnd_size = u32::from_be_bytes(buffer[12..16].try_into().unwrap());
174        header.seq_nr = u16::from_be_bytes(buffer[16..18].try_into().unwrap()).into();
175        header.ack_nr = u16::from_be_bytes(buffer[18..20].try_into().unwrap()).into();
176
177        buffer = &buffer[20..];
178
179        let mut total_ext_size = 0usize;
180
181        while next_ext > 0 {
182            total_ext_size += 2;
183            let ext = next_ext;
184            next_ext = *buffer.first()?;
185            let ext_len = *buffer.get(1)? as usize;
186
187            let ext_data = buffer.get(2..2 + ext_len)?;
188            match (ext, ext_len) {
189                (EXT_SELECTIVE_ACK, _) => {
190                    header.extensions.selective_ack =
191                        Some(selective_ack::SelectiveAck::deserialize(ext_data));
192                }
193                (EXT_CLOSE_REASON, 4) => {
194                    header.extensions.close_reason =
195                        Some(ext_close_reason::LibTorrentCloseReason::parse(
196                            ext_data.try_into().unwrap(),
197                        ));
198                }
199                _ => {
200                    trace!(
201                        ext,
202                        next_ext, ext_len, "unsupported extension for deserializing, skipping"
203                    );
204                }
205            }
206
207            total_ext_size += ext_len;
208            buffer = buffer.get(2 + ext_len..)?;
209        }
210
211        Some((header, 20 + total_ext_size))
212    }
213}
214
#[cfg(test)]
mod tests {
    use crate::{raw::Type, test_util::setup_test_logging};

    use super::UtpHeader;

    /// Parses a captured FIN packet that carries a close-reason extension,
    /// checks every decoded field, then re-serializes the header and expects
    /// the exact original bytes back.
    #[test]
    fn test_parse_fin_with_extension() {
        setup_test_logging();

        let raw_packet = include_bytes!("../test/resources/packet_fin_with_extension.bin");

        let expected = UtpHeader {
            htype: Type::ST_FIN,
            connection_id: 30796.into(),
            timestamp_microseconds: 2293274188,
            timestamp_difference_microseconds: 1967430273,
            wnd_size: 1048576,
            seq_nr: 54661.into(),
            ack_nr: 54397.into(),
            extensions: crate::raw::Extensions {
                close_reason: Some(crate::raw::ext_close_reason::LibTorrentCloseReason(15)),
                selective_ack: None,
            },
        };

        // Decoding must consume the whole packet and yield the expected header.
        let (parsed, consumed) = UtpHeader::deserialize(raw_packet).unwrap();
        assert_eq!(parsed, expected);
        assert_eq!(consumed, raw_packet.len());

        // Encoding the parsed header must reproduce the captured bytes.
        let mut scratch = [0u8; 1024];
        let written = parsed.serialize(&mut scratch).unwrap();
        assert_eq!(&scratch[..written], raw_packet);
    }
}