1use bytes::{Buf, BufMut, Bytes, BytesMut};
20use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
21
22use crate::VarInt;
23use crate::coding::{self, Codec};
24
/// Capsule type value identifying a [`CompressionAssign`] capsule.
pub const CAPSULE_COMPRESSION_ASSIGN: u64 = 0x11;

/// Capsule type value identifying a [`CompressionAck`] capsule.
pub const CAPSULE_COMPRESSION_ACK: u64 = 0x12;

/// Capsule type value identifying a [`CompressionClose`] capsule.
pub const CAPSULE_COMPRESSION_CLOSE: u64 = 0x13;
33
34#[derive(Debug, Clone, PartialEq, Eq)]
44pub struct CompressionAssign {
45 pub context_id: VarInt,
47 pub ip_version: u8,
49 pub ip_address: Option<IpAddr>,
51 pub udp_port: Option<u16>,
53}
54
55impl CompressionAssign {
56 pub fn uncompressed(context_id: VarInt) -> Self {
62 Self {
63 context_id,
64 ip_version: 0,
65 ip_address: None,
66 udp_port: None,
67 }
68 }
69
70 pub fn compressed_v4(context_id: VarInt, addr: Ipv4Addr, port: u16) -> Self {
75 Self {
76 context_id,
77 ip_version: 4,
78 ip_address: Some(IpAddr::V4(addr)),
79 udp_port: Some(port),
80 }
81 }
82
83 pub fn compressed_v6(context_id: VarInt, addr: Ipv6Addr, port: u16) -> Self {
88 Self {
89 context_id,
90 ip_version: 6,
91 ip_address: Some(IpAddr::V6(addr)),
92 udp_port: Some(port),
93 }
94 }
95
96 pub fn is_uncompressed(&self) -> bool {
98 self.ip_version == 0
99 }
100
101 pub fn target(&self) -> Option<std::net::SocketAddr> {
103 match (self.ip_address, self.udp_port) {
104 (Some(ip), Some(port)) => Some(std::net::SocketAddr::new(ip, port)),
105 _ => None,
106 }
107 }
108}
109
110impl Codec for CompressionAssign {
111 fn decode<B: Buf>(buf: &mut B) -> coding::Result<Self> {
112 let context_id = VarInt::decode(buf)?;
113
114 if buf.remaining() < 1 {
115 return Err(coding::UnexpectedEnd);
116 }
117 let ip_version = buf.get_u8();
118
119 let (ip_address, udp_port) = if ip_version == 0 {
120 (None, None)
121 } else {
122 let ip = match ip_version {
123 4 => {
124 if buf.remaining() < 4 {
125 return Err(coding::UnexpectedEnd);
126 }
127 let mut octets = [0u8; 4];
128 buf.copy_to_slice(&mut octets);
129 IpAddr::V4(Ipv4Addr::from(octets))
130 }
131 6 => {
132 if buf.remaining() < 16 {
133 return Err(coding::UnexpectedEnd);
134 }
135 let mut octets = [0u8; 16];
136 buf.copy_to_slice(&mut octets);
137 IpAddr::V6(Ipv6Addr::from(octets))
138 }
139 _ => return Err(coding::UnexpectedEnd),
140 };
141
142 if buf.remaining() < 2 {
143 return Err(coding::UnexpectedEnd);
144 }
145 let port = buf.get_u16();
146
147 (Some(ip), Some(port))
148 };
149
150 Ok(Self {
151 context_id,
152 ip_version,
153 ip_address,
154 udp_port,
155 })
156 }
157
158 fn encode<B: BufMut>(&self, buf: &mut B) {
159 self.context_id.encode(buf);
160 buf.put_u8(self.ip_version);
161
162 if let (Some(ip), Some(port)) = (&self.ip_address, self.udp_port) {
163 match ip {
164 IpAddr::V4(v4) => buf.put_slice(&v4.octets()),
165 IpAddr::V6(v6) => buf.put_slice(&v6.octets()),
166 }
167 buf.put_u16(port);
168 }
169 }
170}
171
172#[derive(Debug, Clone, PartialEq, Eq)]
177pub struct CompressionAck {
178 pub context_id: VarInt,
180}
181
182impl CompressionAck {
183 pub fn new(context_id: VarInt) -> Self {
185 Self { context_id }
186 }
187}
188
189impl Codec for CompressionAck {
190 fn decode<B: Buf>(buf: &mut B) -> coding::Result<Self> {
191 let context_id = VarInt::decode(buf)?;
192 Ok(Self { context_id })
193 }
194
195 fn encode<B: BufMut>(&self, buf: &mut B) {
196 self.context_id.encode(buf);
197 }
198}
199
200#[derive(Debug, Clone, PartialEq, Eq)]
206pub struct CompressionClose {
207 pub context_id: VarInt,
209}
210
211impl CompressionClose {
212 pub fn new(context_id: VarInt) -> Self {
214 Self { context_id }
215 }
216}
217
218impl Codec for CompressionClose {
219 fn decode<B: Buf>(buf: &mut B) -> coding::Result<Self> {
220 let context_id = VarInt::decode(buf)?;
221 Ok(Self { context_id })
222 }
223
224 fn encode<B: BufMut>(&self, buf: &mut B) {
225 self.context_id.encode(buf);
226 }
227}
228
229#[derive(Debug, Clone)]
234pub enum Capsule {
235 CompressionAssign(CompressionAssign),
237 CompressionAck(CompressionAck),
239 CompressionClose(CompressionClose),
241 Unknown {
243 capsule_type: VarInt,
245 data: Vec<u8>,
247 },
248}
249
250impl Capsule {
251 pub fn decode<B: Buf>(buf: &mut B) -> coding::Result<Self> {
256 let capsule_type = VarInt::decode(buf)?;
257 let length = VarInt::decode(buf)?;
258 let length_usize = length.into_inner() as usize;
259
260 if buf.remaining() < length_usize {
261 return Err(coding::UnexpectedEnd);
262 }
263
264 match capsule_type.into_inner() {
265 CAPSULE_COMPRESSION_ASSIGN => {
266 let capsule = CompressionAssign::decode(buf)?;
267 Ok(Capsule::CompressionAssign(capsule))
268 }
269 CAPSULE_COMPRESSION_ACK => {
270 let capsule = CompressionAck::decode(buf)?;
271 Ok(Capsule::CompressionAck(capsule))
272 }
273 CAPSULE_COMPRESSION_CLOSE => {
274 let capsule = CompressionClose::decode(buf)?;
275 Ok(Capsule::CompressionClose(capsule))
276 }
277 _ => {
278 let mut data = vec![0u8; length_usize];
279 buf.copy_to_slice(&mut data);
280 Ok(Capsule::Unknown { capsule_type, data })
281 }
282 }
283 }
284
285 pub fn encode(&self) -> Bytes {
289 let mut buf = BytesMut::new();
290 let mut payload = BytesMut::new();
291
292 let capsule_type = match self {
293 Capsule::CompressionAssign(c) => {
294 c.encode(&mut payload);
295 CAPSULE_COMPRESSION_ASSIGN
296 }
297 Capsule::CompressionAck(c) => {
298 c.encode(&mut payload);
299 CAPSULE_COMPRESSION_ACK
300 }
301 Capsule::CompressionClose(c) => {
302 c.encode(&mut payload);
303 CAPSULE_COMPRESSION_CLOSE
304 }
305 Capsule::Unknown { capsule_type, data } => {
306 payload.put_slice(data);
307 capsule_type.into_inner()
308 }
309 };
310
311 if let Ok(ct) = VarInt::from_u64(capsule_type) {
313 ct.encode(&mut buf);
314 }
315
316 if let Ok(len) = VarInt::from_u64(payload.len() as u64) {
318 len.encode(&mut buf);
319 }
320
321 buf.put(payload);
323
324 buf.freeze()
325 }
326
327 pub fn capsule_type(&self) -> u64 {
329 match self {
330 Capsule::CompressionAssign(_) => CAPSULE_COMPRESSION_ASSIGN,
331 Capsule::CompressionAck(_) => CAPSULE_COMPRESSION_ACK,
332 Capsule::CompressionClose(_) => CAPSULE_COMPRESSION_CLOSE,
333 Capsule::Unknown { capsule_type, .. } => capsule_type.into_inner(),
334 }
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compression_assign_uncompressed_roundtrip() {
        let original = CompressionAssign::uncompressed(VarInt::from_u32(2));
        let mut buf = BytesMut::new();
        original.encode(&mut buf);

        let decoded = CompressionAssign::decode(&mut buf.freeze()).unwrap();
        assert_eq!(original, decoded);
        assert!(decoded.is_uncompressed());
        assert!(decoded.target().is_none());
    }

    #[test]
    fn test_compression_assign_ipv4_roundtrip() {
        let addr = Ipv4Addr::new(192, 168, 1, 100);
        let original = CompressionAssign::compressed_v4(VarInt::from_u32(4), addr, 8080);
        let mut buf = BytesMut::new();
        original.encode(&mut buf);

        let decoded = CompressionAssign::decode(&mut buf.freeze()).unwrap();
        assert_eq!(original, decoded);
        assert!(!decoded.is_uncompressed());
        assert_eq!(
            decoded.target(),
            Some(std::net::SocketAddr::new(IpAddr::V4(addr), 8080))
        );
    }

    #[test]
    fn test_compression_assign_ipv6_roundtrip() {
        let addr = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let original = CompressionAssign::compressed_v6(VarInt::from_u32(6), addr, 443);
        let mut buf = BytesMut::new();
        original.encode(&mut buf);

        let decoded = CompressionAssign::decode(&mut buf.freeze()).unwrap();
        assert_eq!(original, decoded);
        assert_eq!(decoded.ip_version, 6);
    }

    #[test]
    fn test_compression_assign_rejects_bad_input() {
        // Context ID 1 followed by the unsupported IP version 5.
        let mut bad_version = Bytes::from_static(&[0x01, 0x05]);
        assert!(CompressionAssign::decode(&mut bad_version).is_err());

        // IP version 4, but only two of the four address bytes are present.
        let mut short_address = Bytes::from_static(&[0x01, 0x04, 10, 0]);
        assert!(CompressionAssign::decode(&mut short_address).is_err());
    }

    #[test]
    fn test_compression_ack_roundtrip() {
        let original = CompressionAck::new(VarInt::from_u32(42));
        let mut buf = BytesMut::new();
        original.encode(&mut buf);

        let decoded = CompressionAck::decode(&mut buf.freeze()).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_compression_close_roundtrip() {
        let original = CompressionClose::new(VarInt::from_u32(99));
        let mut buf = BytesMut::new();
        original.encode(&mut buf);

        let decoded = CompressionClose::decode(&mut buf.freeze()).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_capsule_wrapper_encoding() {
        let assign =
            CompressionAssign::compressed_v4(VarInt::from_u32(2), Ipv4Addr::new(10, 0, 0, 1), 9000);
        let capsule = Capsule::CompressionAssign(assign.clone());

        let encoded = capsule.encode();
        let mut buf = encoded;
        let decoded = Capsule::decode(&mut buf).unwrap();

        match decoded {
            Capsule::CompressionAssign(c) => assert_eq!(c, assign),
            _ => panic!("Expected CompressionAssign capsule"),
        }
    }

    #[test]
    fn test_unknown_capsule_roundtrip() {
        let capsule = Capsule::Unknown {
            capsule_type: VarInt::from_u32(0x2a),
            data: vec![1, 2, 3, 4],
        };
        assert_eq!(capsule.capsule_type(), 0x2a);

        let mut buf = capsule.encode();
        match Capsule::decode(&mut buf).unwrap() {
            Capsule::Unknown { capsule_type, data } => {
                assert_eq!(capsule_type.into_inner(), 0x2a);
                assert_eq!(data, vec![1, 2, 3, 4]);
            }
            other => panic!("Expected Unknown capsule, got {:?}", other),
        }
        assert!(!buf.has_remaining());
    }

    #[test]
    fn test_consecutive_capsules_decode_in_order() {
        let ack = CompressionAck::new(VarInt::from_u32(7));
        let close = CompressionClose::new(VarInt::from_u32(9));

        let mut stream = BytesMut::new();
        stream.extend_from_slice(&Capsule::CompressionAck(ack.clone()).encode());
        stream.extend_from_slice(&Capsule::CompressionClose(close.clone()).encode());
        let mut stream = stream.freeze();

        match Capsule::decode(&mut stream).unwrap() {
            Capsule::CompressionAck(c) => assert_eq!(c, ack),
            other => panic!("Expected CompressionAck capsule, got {:?}", other),
        }
        match Capsule::decode(&mut stream).unwrap() {
            Capsule::CompressionClose(c) => assert_eq!(c, close),
            other => panic!("Expected CompressionClose capsule, got {:?}", other),
        }
        assert!(!stream.has_remaining());
    }

    #[test]
    fn test_truncated_capsule_is_rejected() {
        let encoded = Capsule::CompressionAssign(CompressionAssign::compressed_v4(
            VarInt::from_u32(2),
            Ipv4Addr::LOCALHOST,
            53,
        ))
        .encode();
        let mut truncated = encoded.slice(..encoded.len() - 1);
        assert!(Capsule::decode(&mut truncated).is_err());
    }

    #[test]
    fn test_capsule_type_identifiers() {
        assert_eq!(
            Capsule::CompressionAssign(CompressionAssign::uncompressed(VarInt::from_u32(1)))
                .capsule_type(),
            CAPSULE_COMPRESSION_ASSIGN
        );
        assert_eq!(
            Capsule::CompressionAck(CompressionAck::new(VarInt::from_u32(1))).capsule_type(),
            CAPSULE_COMPRESSION_ACK
        );
        assert_eq!(
            Capsule::CompressionClose(CompressionClose::new(VarInt::from_u32(1))).capsule_type(),
            CAPSULE_COMPRESSION_CLOSE
        );
    }
}