1pub mod ext_close_reason;
2pub mod selective_ack;
3
4use tracing::trace;
5
6use crate::{Error, constants::UTP_HEADER, seq_nr::SeqNr};
7
/// Extension id meaning "no further extension"; a zero here ends the extension chain.
const NO_NEXT_EXT: u8 = 0x00;
/// Extension id of the selective-ACK extension.
const EXT_SELECTIVE_ACK: u8 = 0x01;
/// Extension id of the (libtorrent) close-reason extension.
const EXT_CLOSE_REASON: u8 = 0x03;
11
/// Packet type of a uTP header.
///
/// On the wire it occupies the upper four bits of the first header byte.
/// Defaults to [`Type::ST_STATE`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
// Variants keep the protocol's `ST_*` spelling, hence the lint exemption.
#[allow(non_camel_case_types)]
pub enum Type {
    ST_DATA = 0,
    ST_FIN = 1,
    #[default]
    ST_STATE = 2,
    ST_RESET = 3,
    ST_SYN = 4,
}

impl Type {
    /// Every variant, stored at the index equal to its wire number.
    const BY_NUMBER: [Type; 5] = [
        Type::ST_DATA,
        Type::ST_FIN,
        Type::ST_STATE,
        Type::ST_RESET,
        Type::ST_SYN,
    ];

    /// Decodes a wire type number; `None` when no variant has that number.
    fn from_number(num: u8) -> Option<Type> {
        Self::BY_NUMBER.get(usize::from(num)).copied()
    }

    /// Encodes the type as its wire number (its explicit discriminant).
    fn to_number(self) -> u8 {
        self as u8
    }
}
45
/// Optional extensions carried in a uTP header's extension chain.
///
/// A field is `None` when the corresponding extension was absent (or, on
/// parse, not recognised).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Extensions {
    /// Selective-ACK extension (extension id `EXT_SELECTIVE_ACK`).
    pub selective_ack: Option<selective_ack::SelectiveAck>,
    /// Close-reason extension (extension id `EXT_CLOSE_REASON`); the parser only
    /// accepts it when its payload is exactly 4 bytes.
    pub close_reason: Option<ext_close_reason::LibTorrentCloseReason>,
}
51
/// A uTP packet header: the fixed 20-byte part plus any extensions that follow it.
///
/// All multi-byte fields are big-endian on the wire.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct UtpHeader {
    /// Packet type; high nibble of byte 0 (the low nibble holds the version, 1).
    pub htype: Type,
    /// Bytes 2..4 (16 bits).
    pub connection_id: SeqNr,
    /// Bytes 4..8.
    pub timestamp_microseconds: u32,
    /// Bytes 8..12.
    pub timestamp_difference_microseconds: u32,
    /// Bytes 12..16.
    pub wnd_size: u32,
    /// Bytes 16..18 (16 bits).
    pub seq_nr: SeqNr,
    /// Bytes 18..20 (16 bits).
    pub ack_nr: SeqNr,
    /// Extensions, encoded as a chain starting at byte 20; byte 1 names the first one.
    pub extensions: Extensions,
}
63
64impl UtpHeader {
65 pub fn set_type(&mut self, packet_type: Type) {
66 self.htype = packet_type;
69 }
70
71 pub fn get_type(&self) -> Type {
72 self.htype
74 }
75
76 pub fn short_repr(&self) -> impl std::fmt::Display + '_ {
77 struct D<'a>(&'a UtpHeader);
78 impl std::fmt::Display for D<'_> {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 write!(
81 f,
82 "{:?}:seq_nr={}:ack_nr={}:wnd_size={}",
83 self.0.get_type(),
84 self.0.seq_nr,
85 self.0.ack_nr,
86 self.0.wnd_size,
87 )
88 }
89 }
90 D(self)
91 }
92
93 pub fn serialize(&self, buffer: &mut [u8]) -> crate::Result<usize> {
94 if buffer.len() < UTP_HEADER as usize {
95 return Err(Error::SerializeTooSmallBuffer);
96 }
97 const VERSION: u8 = 1;
98 const NEXT_EXT_IDX: usize = 1;
99 let typever = (self.htype.to_number() << 4) | VERSION;
100 buffer[0] = typever;
101 buffer[NEXT_EXT_IDX] = NO_NEXT_EXT; buffer[2..4].copy_from_slice(&self.connection_id.to_be_bytes());
103 buffer[4..8].copy_from_slice(&self.timestamp_microseconds.to_be_bytes());
104 buffer[8..12].copy_from_slice(&self.timestamp_difference_microseconds.to_be_bytes());
105 buffer[12..16].copy_from_slice(&self.wnd_size.to_be_bytes());
106 buffer[16..18].copy_from_slice(&self.seq_nr.to_be_bytes());
107 buffer[18..20].copy_from_slice(&self.ack_nr.to_be_bytes());
108
109 let mut next_ext_pos = NEXT_EXT_IDX;
110 let mut offset = 20;
111
112 macro_rules! add_ext {
113 ($id:expr, $payload:expr) => {
114 let payload = $payload;
115 if buffer.len() >= offset + 2 + payload.len() {
116 buffer[next_ext_pos] = $id;
117 buffer[offset] = NO_NEXT_EXT;
118 buffer[offset + 1] = payload.len() as u8;
119 buffer[offset + 2..offset + 2 + payload.len()].copy_from_slice(payload);
120
121 #[allow(unused)]
122 {
123 next_ext_pos = offset + 1;
124 }
125 offset += 2 + payload.len();
126 }
127 };
128 }
129
130 if let Some(sack) = self.extensions.selective_ack {
131 add_ext!(EXT_SELECTIVE_ACK, sack.as_bytes());
132 }
133 if let Some(close_reason) = self.extensions.close_reason {
134 add_ext!(EXT_CLOSE_REASON, &close_reason.as_bytes());
135 }
136
137 Ok(offset)
138 }
139
140 pub fn serialize_with_payload(
141 &self,
142 out_buf: &mut [u8],
143 payload_serialize: impl FnOnce(&mut [u8]) -> crate::Result<usize>,
144 ) -> crate::Result<usize> {
145 let sz = self.serialize(out_buf)?;
146 let payload_sz = payload_serialize(
147 out_buf
148 .get_mut(sz..)
149 .ok_or(Error::SerializeTooSmallBuffer)?,
150 )?;
151 Ok(sz + payload_sz)
152 }
153
154 pub fn deserialize(orig_buffer: &[u8]) -> Option<(Self, usize)> {
155 let mut buffer = orig_buffer;
156 if buffer.len() < UTP_HEADER as usize {
157 return None;
158 }
159 let mut header = UtpHeader::default();
160
161 let typenum = buffer[0] >> 4;
162 let version = buffer[0] & 0xf;
163 if version != 1 {
164 trace!(version, "wrong version");
165 return None;
166 }
167 header.htype = Type::from_number(typenum)?;
168 let mut next_ext = buffer[1];
169 header.connection_id = u16::from_be_bytes(buffer[2..4].try_into().unwrap()).into();
170 header.timestamp_microseconds = u32::from_be_bytes(buffer[4..8].try_into().unwrap());
171 header.timestamp_difference_microseconds =
172 u32::from_be_bytes(buffer[8..12].try_into().unwrap());
173 header.wnd_size = u32::from_be_bytes(buffer[12..16].try_into().unwrap());
174 header.seq_nr = u16::from_be_bytes(buffer[16..18].try_into().unwrap()).into();
175 header.ack_nr = u16::from_be_bytes(buffer[18..20].try_into().unwrap()).into();
176
177 buffer = &buffer[20..];
178
179 let mut total_ext_size = 0usize;
180
181 while next_ext > 0 {
182 total_ext_size += 2;
183 let ext = next_ext;
184 next_ext = *buffer.first()?;
185 let ext_len = *buffer.get(1)? as usize;
186
187 let ext_data = buffer.get(2..2 + ext_len)?;
188 match (ext, ext_len) {
189 (EXT_SELECTIVE_ACK, _) => {
190 header.extensions.selective_ack =
191 Some(selective_ack::SelectiveAck::deserialize(ext_data));
192 }
193 (EXT_CLOSE_REASON, 4) => {
194 header.extensions.close_reason =
195 Some(ext_close_reason::LibTorrentCloseReason::parse(
196 ext_data.try_into().unwrap(),
197 ));
198 }
199 _ => {
200 trace!(
201 ext,
202 next_ext, ext_len, "unsupported extension for deserializing, skipping"
203 );
204 }
205 }
206
207 total_ext_size += ext_len;
208 buffer = buffer.get(2 + ext_len..)?;
209 }
210
211 Some((header, 20 + total_ext_size))
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::{Extensions, Type, UtpHeader, ext_close_reason::LibTorrentCloseReason};
    use crate::test_util::setup_test_logging;

    /// A captured ST_FIN packet carrying a close-reason extension must parse to
    /// the expected header, consume the whole packet, and re-serialize to the
    /// exact same bytes.
    #[test]
    fn test_parse_fin_with_extension() {
        setup_test_logging();
        let wire: &[u8] = include_bytes!("../test/resources/packet_fin_with_extension.bin");

        let expected = UtpHeader {
            htype: Type::ST_FIN,
            connection_id: 30796.into(),
            timestamp_microseconds: 2293274188,
            timestamp_difference_microseconds: 1967430273,
            wnd_size: 1048576,
            seq_nr: 54661.into(),
            ack_nr: 54397.into(),
            extensions: Extensions {
                selective_ack: None,
                close_reason: Some(LibTorrentCloseReason(15)),
            },
        };

        let (parsed, consumed) = UtpHeader::deserialize(wire).unwrap();
        assert_eq!(parsed, expected);
        assert_eq!(consumed, wire.len());

        let mut out = [0u8; 1024];
        let written = parsed.serialize(&mut out).unwrap();
        assert_eq!(&out[..written], wire);
    }
}
248}