miltr_common/commands/connect.rs

1use std::borrow::Cow;
2
3use bytes::{BufMut, BytesMut};
4use num_enum::{IntoPrimitive, TryFromPrimitive};
5
6use crate::decoding::Parsable;
7use crate::encoding::Writable;
8use crate::ProtocolError;
9use crate::{error::STAGE_DECODING, InvalidData, NotEnoughData};
10use miltr_utils::ByteParsing;
11
12/// A marker for the connection family
13#[allow(missing_docs)]
14#[derive(Copy, Clone, PartialEq, Debug, IntoPrimitive, TryFromPrimitive)]
15#[repr(u8)]
16pub enum Family {
17    Unknown = b'U',
18    Unix = b'L',
19    Inet = b'4',
20    Inet6 = b'6',
21}
22
23impl Family {
24    fn parse(buffer: &[u8]) -> Result<Self, ProtocolError> {
25        match Family::try_from(buffer[0]) {
26            Ok(f) => Ok(f),
27            Err(_) => Err(InvalidData {
28                msg: "Received unknown protocol family for connection info",
29                offending_bytes: BytesMut::from_iter(&[buffer[0]]),
30            }
31            .into()),
32        }
33    }
34}
35
36/// Connect information about the smtp client
37#[derive(Clone, PartialEq, Debug)]
38pub struct Connect {
39    hostname: BytesMut,
40    /// The connection type connected to the milter client
41    pub family: Family,
42    /// On an IP connection, the port of the connection
43    pub port: Option<u16>,
44    address: BytesMut,
45}
46
47impl Connect {
48    const CODE: u8 = b'C';
49    /// Create a new connect package
50    #[must_use]
51    pub fn new(hostname: &[u8], family: Family, port: Option<u16>, address: &[u8]) -> Self {
52        Self {
53            hostname: BytesMut::from_iter(hostname),
54            family,
55            port,
56            address: BytesMut::from_iter(address),
57        }
58    }
59    /// Get the received hostname as as string-like type.
60    #[must_use]
61    pub fn hostname(&self) -> Cow<str> {
62        String::from_utf8_lossy(&self.hostname)
63    }
64
65    /// Get the received address as a string-like type.
66    ///
67    /// Remember, this can contain an IP-Address or a unix socket.
68    #[must_use]
69    pub fn address(&self) -> Cow<str> {
70        String::from_utf8_lossy(&self.address)
71    }
72}
73
74impl Parsable for Connect {
75    const CODE: u8 = Self::CODE;
76
77    fn parse(mut buffer: BytesMut) -> Result<Self, ProtocolError> {
78        let Some(hostname) = buffer.delimited(0) else {
79            return Err(InvalidData::new(
80                "Null-byte missing in connection package to delimit hostname",
81                buffer,
82            )
83            .into());
84        };
85
86        let Some(family) = buffer.safe_split_to(1) else {
87            return Err(NotEnoughData::new(
88                STAGE_DECODING,
89                "Connect",
90                "Family missing",
91                1,
92                2,
93                buffer,
94            )
95            .into());
96        };
97        let family = Family::parse(&family)?;
98
99        let port = {
100            match family {
101                Family::Inet | Family::Inet6 => {
102                    let Some(buf) = buffer.safe_split_to(2) else {
103                        return Err(NotEnoughData::new(
104                            STAGE_DECODING,
105                            "Connect",
106                            "Port missing",
107                            2,
108                            buffer.len(),
109                            buffer,
110                        )
111                        .into());
112                    };
113                    let mut raw: [u8; 2] = [0; 2];
114                    raw.copy_from_slice(&buf);
115
116                    Some(u16::from_be_bytes(raw))
117                }
118                _ => None,
119            }
120        };
121
122        let address;
123        if let Some(b'\0') = buffer.last() {
124            address = buffer.split_to(buffer.len() - 1);
125        } else {
126            address = buffer;
127        }
128
129        let connect = Connect {
130            hostname,
131            family,
132            port,
133            address,
134        };
135
136        Ok(connect)
137    }
138}
139
140impl Writable for Connect {
141    fn write(&self, buffer: &mut BytesMut) {
142        buffer.extend_from_slice(&self.hostname);
143        buffer.put_u8(0);
144
145        buffer.put_u8(self.family.into());
146
147        buffer.put_u16(self.port.unwrap_or_default());
148
149        buffer.extend_from_slice(&self.address);
150        buffer.put_u8(0);
151    }
152
153    fn len(&self) -> usize {
154        self.hostname.len() + 1 + 1 + 2 + self.address.len() + 1
155    }
156
157    fn code(&self) -> u8 {
158        Self::CODE
159    }
160
161    fn is_empty(&self) -> bool {
162        false
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::Family;
    use crate::{commands::Connect, decoding::Parsable, encoding::Writable};
    use bytes::BytesMut;
    use pretty_assertions::assert_eq;

    /// Raw connect body: hostname `localhost`, IPv4, port 1234, address `127.0.0.1`.
    fn initialize() -> BytesMut {
        let hostname = b"localhost";
        let family = b'4';
        let port = 1234u16.to_be_bytes();
        let address = b"127.0.0.1";

        let mut read_buffer = Vec::new();
        read_buffer.extend(hostname);
        read_buffer.push(0);
        read_buffer.push(family);
        read_buffer.extend(port);
        read_buffer.extend(address);
        read_buffer.push(0);

        BytesMut::from_iter(read_buffer)
    }

    // Parsing is synchronous, so no async runtime is needed here.
    #[test]
    fn test_create_connect() {
        let connect = Connect::parse(initialize()).expect("Failed parsing connect");

        assert_eq!(b"localhost", connect.hostname.to_vec().as_slice());
        assert_eq!(Family::Inet, connect.family);
        assert_eq!(Some(1234), connect.port);
        assert_eq!(b"127.0.0.1", connect.address.to_vec().as_slice());
    }

    #[test]
    fn test_write_matches_wire_format_and_round_trips() {
        let connect = Connect::new(b"localhost", Family::Inet, Some(1234), b"127.0.0.1");
        let mut buffer = BytesMut::new();
        connect.write(&mut buffer);

        assert_eq!(initialize(), buffer);
        assert_eq!(connect.len(), buffer.len());

        let parsed = Connect::parse(buffer).expect("Failed parsing written connect");
        assert_eq!(connect, parsed);
    }

    #[test]
    fn test_parse_rejects_truncated_packages() {
        // Hostname is not null-terminated.
        assert!(Connect::parse(BytesMut::from(&b"localhost"[..])).is_err());

        // Hostname present, family byte missing.
        assert!(Connect::parse(BytesMut::from(&b"localhost\0"[..])).is_err());

        // IPv4 family, but the two port bytes are missing.
        let mut missing_port = BytesMut::from(&b"localhost\0"[..]);
        missing_port.extend_from_slice(&[u8::from(Family::Inet)]);
        assert!(Connect::parse(missing_port).is_err());
    }

    #[cfg(feature = "count-allocations")]
    #[test]
    fn test_parse_connect() {
        let buffer = initialize();

        let info = allocation_counter::measure(|| {
            let res = Connect::parse(buffer);
            allocation_counter::opt_out(|| {
                println!("{res:?}");
                assert!(res.is_ok());
            });
        });

        println!("{}", &info.count_total);
        // Parsing is expected to perform exactly one allocation.
        assert_eq!(info.count_total, 1);
    }
}