miltr_common/commands/
connect.rs1use std::borrow::Cow;
2
3use bytes::{BufMut, BytesMut};
4use num_enum::{IntoPrimitive, TryFromPrimitive};
5
6use crate::decoding::Parsable;
7use crate::encoding::Writable;
8use crate::ProtocolError;
9use crate::{error::STAGE_DECODING, InvalidData, NotEnoughData};
10use miltr_utils::ByteParsing;
11
12#[allow(missing_docs)]
14#[derive(Copy, Clone, PartialEq, Debug, IntoPrimitive, TryFromPrimitive)]
15#[repr(u8)]
16pub enum Family {
17 Unknown = b'U',
18 Unix = b'L',
19 Inet = b'4',
20 Inet6 = b'6',
21}
22
23impl Family {
24 fn parse(buffer: &[u8]) -> Result<Self, ProtocolError> {
25 match Family::try_from(buffer[0]) {
26 Ok(f) => Ok(f),
27 Err(_) => Err(InvalidData {
28 msg: "Received unknown protocol family for connection info",
29 offending_bytes: BytesMut::from_iter(&[buffer[0]]),
30 }
31 .into()),
32 }
33 }
34}
35
36#[derive(Clone, PartialEq, Debug)]
38pub struct Connect {
39 hostname: BytesMut,
40 pub family: Family,
42 pub port: Option<u16>,
44 address: BytesMut,
45}
46
47impl Connect {
48 const CODE: u8 = b'C';
49 #[must_use]
51 pub fn new(hostname: &[u8], family: Family, port: Option<u16>, address: &[u8]) -> Self {
52 Self {
53 hostname: BytesMut::from_iter(hostname),
54 family,
55 port,
56 address: BytesMut::from_iter(address),
57 }
58 }
59 #[must_use]
61 pub fn hostname(&self) -> Cow<str> {
62 String::from_utf8_lossy(&self.hostname)
63 }
64
65 #[must_use]
69 pub fn address(&self) -> Cow<str> {
70 String::from_utf8_lossy(&self.address)
71 }
72}
73
74impl Parsable for Connect {
75 const CODE: u8 = Self::CODE;
76
77 fn parse(mut buffer: BytesMut) -> Result<Self, ProtocolError> {
78 let Some(hostname) = buffer.delimited(0) else {
79 return Err(InvalidData::new(
80 "Null-byte missing in connection package to delimit hostname",
81 buffer,
82 )
83 .into());
84 };
85
86 let Some(family) = buffer.safe_split_to(1) else {
87 return Err(NotEnoughData::new(
88 STAGE_DECODING,
89 "Connect",
90 "Family missing",
91 1,
92 2,
93 buffer,
94 )
95 .into());
96 };
97 let family = Family::parse(&family)?;
98
99 let port = {
100 match family {
101 Family::Inet | Family::Inet6 => {
102 let Some(buf) = buffer.safe_split_to(2) else {
103 return Err(NotEnoughData::new(
104 STAGE_DECODING,
105 "Connect",
106 "Port missing",
107 2,
108 buffer.len(),
109 buffer,
110 )
111 .into());
112 };
113 let mut raw: [u8; 2] = [0; 2];
114 raw.copy_from_slice(&buf);
115
116 Some(u16::from_be_bytes(raw))
117 }
118 _ => None,
119 }
120 };
121
122 let address;
123 if let Some(b'\0') = buffer.last() {
124 address = buffer.split_to(buffer.len() - 1);
125 } else {
126 address = buffer;
127 }
128
129 let connect = Connect {
130 hostname,
131 family,
132 port,
133 address,
134 };
135
136 Ok(connect)
137 }
138}
139
140impl Writable for Connect {
141 fn write(&self, buffer: &mut BytesMut) {
142 buffer.extend_from_slice(&self.hostname);
143 buffer.put_u8(0);
144
145 buffer.put_u8(self.family.into());
146
147 buffer.put_u16(self.port.unwrap_or_default());
148
149 buffer.extend_from_slice(&self.address);
150 buffer.put_u8(0);
151 }
152
153 fn len(&self) -> usize {
154 self.hostname.len() + 1 + 1 + 2 + self.address.len() + 1
155 }
156
157 fn code(&self) -> u8 {
158 Self::CODE
159 }
160
161 fn is_empty(&self) -> bool {
162 false
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::Family;
    use crate::{commands::Connect, decoding::Parsable, encoding::Writable};
    use bytes::BytesMut;
    use pretty_assertions::assert_eq;

    /// Well-formed IPv4 payload: `localhost\0`, `4`, port 1234, `127.0.0.1\0`.
    fn initialize() -> BytesMut {
        let hostname = b"localhost";
        let family = b'4';
        let port = 1234u16.to_be_bytes();
        let address = b"127.0.0.1";

        let mut read_buffer = Vec::new();
        read_buffer.extend(hostname);
        read_buffer.push(0);
        read_buffer.push(family);
        read_buffer.extend(port);
        read_buffer.extend(address);
        read_buffer.push(0);

        BytesMut::from_iter(read_buffer)
    }

    #[tokio::test]
    async fn test_create_connect() {
        let connect = Connect::parse(initialize()).expect("Failed parsing connect");

        assert_eq!(b"localhost", connect.hostname.to_vec().as_slice());
        assert_eq!(Family::Inet, connect.family);
        assert_eq!(Some(1234), connect.port);
        assert_eq!(b"127.0.0.1", connect.address.to_vec().as_slice());
    }

    #[test]
    fn test_inet_roundtrip() {
        let connect = Connect::new(b"localhost", Family::Inet, Some(1234), b"127.0.0.1");

        let mut buffer = BytesMut::new();
        connect.write(&mut buffer);
        assert_eq!(initialize(), buffer);

        let parsed = Connect::parse(buffer).expect("Failed parsing connect");
        assert_eq!(connect, parsed);
    }

    #[test]
    fn test_len_matches_written_bytes() {
        for family in [Family::Unknown, Family::Unix, Family::Inet, Family::Inet6] {
            let connect = Connect::new(b"host", family, Some(25), b"addr");
            let mut buffer = BytesMut::new();
            connect.write(&mut buffer);
            assert_eq!(connect.len(), buffer.len(), "family {family:?}");
        }
    }

    #[test]
    fn test_parse_unknown_family_without_address() {
        let connect = Connect::parse(BytesMut::from(&b"localhost\0U"[..]))
            .expect("Failed parsing connect");

        assert_eq!(b"localhost", connect.hostname.to_vec().as_slice());
        assert_eq!(Family::Unknown, connect.family);
        assert_eq!(None, connect.port);
        assert!(connect.address.is_empty());
    }

    #[cfg(feature = "count-allocations")]
    #[test]
    fn test_parse_connect() {
        let buffer = initialize();

        let info = allocation_counter::measure(|| {
            let res = Connect::parse(buffer);
            allocation_counter::opt_out(|| {
                println!("{res:?}");
                assert!(res.is_ok());
            });
        });

        println!("{}", &info.count_total);
        assert_eq!(info.count_total, 1);
    }
}