indymilter/message/
mod.rs

1// indymilter – asynchronous milter library
2// Copyright © 2021–2024 David Bürgin <dbuergin@gluet.ch>
3//
4// This program is free software: you can redistribute it and/or modify it under
5// the terms of the GNU General Public License as published by the Free Software
6// Foundation, either version 3 of the License, or (at your option) any later
7// version.
8//
9// This program is distributed in the hope that it will be useful, but WITHOUT
10// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
11// FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
12// details.
13//
14// You should have received a copy of the GNU General Public License along with
15// this program. If not, see <https://www.gnu.org/licenses/>.
16
17//! Milter protocol messages.
18//!
19//! This module contains low-level protocol helpers.
20
21pub mod command;
22pub mod reply;
23
24use bytes::Bytes;
25use std::{
26    ascii,
27    error::Error,
28    ffi::{CStr, CString},
29    fmt::{self, Debug, Display, Formatter, Write},
30    io::{self, ErrorKind},
31};
32use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
33
/// Integer type used to carry a milter protocol version number.
pub type Version = u32;

/// Version number of the milter protocol (currently 6).
pub const PROTOCOL_VERSION: Version = 6;
39
40/// A milter protocol message.
41#[derive(Clone, Eq, Hash, PartialEq)]
42pub struct Message {
43    /// The message kind.
44    pub kind: u8,
45    /// The message payload buffer.
46    pub buffer: Bytes,
47}
48
49impl Message {
50    // Limit length of message buffer somewhat arbitrarily to 1 MB. Else if
51    // someone erroneously sends a jumbo message we would spend a lot of time
52    // and memory reading/writing it.
53    //
54    // Support for the experimental, negotiable custom `Maxdatasize` feature of
55    // libmilter is not implemented.
56    const MAX_BUFFER_LEN: usize = 1024 * 1024 - 1;
57
58    /// Creates a new message with the given kind and payload buffer.
59    pub fn new(kind: impl Into<u8>, buffer: impl Into<Bytes>) -> Self {
60        Self {
61            kind: kind.into(),
62            buffer: buffer.into(),
63        }
64    }
65}
66
67impl Debug for Message {
68    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
69        f.debug_struct("Message")
70            .field("kind", &Byte(self.kind))
71            .field("buffer", &self.buffer)
72            .finish()
73    }
74}
75
/// Wrapper whose `Debug` output renders a byte as a Rust byte literal,
/// e.g. `b'a'`, `b'\n'`, `b'\0'` or `b'\xef'`.
pub(crate) struct Byte(pub u8);

impl Debug for Byte {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let value = self.0;
        match value {
            // `escape_default` would render NUL as `\x00`; use the shorter form.
            b'\0' => f.write_str("b'\\0'"),
            // A double quote needs no escaping inside a single-quoted literal.
            b'"' => f.write_str("b'\"'"),
            _ => write!(f, "b'{}'", ascii::escape_default(value)),
        }
    }
}
97
98/// An error that occurs when conversion from a wire format byte fails.
99#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
100pub struct TryFromByteError(u8);
101
102impl TryFromByteError {
103    pub(crate) fn new(byte: u8) -> Self {
104        Self(byte)
105    }
106
107    /// Returns the byte that caused the conversion error.
108    pub fn byte(&self) -> u8 {
109        self.0
110    }
111}
112
113impl Display for TryFromByteError {
114    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
115        write!(f, "failed to convert byte {:?}", Byte(self.0))
116    }
117}
118
119impl Error for TryFromByteError {}
120
/// Error returned when an index read from the wire cannot be converted into
/// the requested type.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TryFromIndexError(i32);

impl TryFromIndexError {
    pub(crate) fn new(index: i32) -> Self {
        TryFromIndexError(index)
    }

    /// Returns the index value that could not be converted.
    pub fn index(&self) -> i32 {
        let Self(value) = *self;
        value
    }
}

impl Display for TryFromIndexError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let value = self.index();
        write!(f, "failed to convert index {}", value)
    }
}

impl Error for TryFromIndexError {}
143
144/// Reads a milter protocol message from a stream.
145pub async fn read<S>(stream: &mut S) -> io::Result<Message>
146where
147    S: AsyncRead + Unpin,
148{
149    let len = stream.read_u32().await?;
150    let kind = stream.read_u8().await?;
151
152    let len = usize::try_from(len)
153        .expect("unsupported pointer size")
154        .saturating_sub(1);
155
156    if len > Message::MAX_BUFFER_LEN {
157        return Err(ErrorKind::InvalidData.into());
158    }
159
160    let mut buffer = vec![0; len];
161    stream.read_exact(&mut buffer).await?;
162
163    Ok(Message::new(kind, buffer))
164}
165
166/// Writes a milter protocol message to a stream.
167pub async fn write<S>(stream: &mut S, msg: Message) -> io::Result<()>
168where
169    S: AsyncWrite + Unpin,
170{
171    let len = msg.buffer.len();
172
173    if len > Message::MAX_BUFFER_LEN {
174        return Err(ErrorKind::InvalidData.into());
175    }
176
177    let len = u32::try_from(len).unwrap().checked_add(1).unwrap();
178
179    stream.write_u32(len).await?;
180    stream.write_u8(msg.kind).await?;
181    stream.write_all(msg.buffer.as_ref()).await?;
182    stream.flush().await?;
183
184    Ok(())
185}
186
187struct NoCStringFoundError;
188
189fn get_c_string(buf: &mut Bytes) -> Result<CString, NoCStringFoundError> {
190    if let Some(i) = buf.iter().position(|&x| x == 0) {
191        let b = buf.split_to(i + 1);
192        return Ok(CStr::from_bytes_with_nul(b.as_ref()).unwrap().into());
193    }
194    Err(NoCStringFoundError)
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn byte_debug_ok() {
        assert_eq!(format!("{:?}", Byte(b'\0')), r"b'\0'");
        assert_eq!(format!("{:?}", Byte(b'\n')), r"b'\n'");
        assert_eq!(format!("{:?}", Byte(b'\'')), r"b'\''");
        assert_eq!(format!("{:?}", Byte(b'"')), r#"b'"'"#);
        assert_eq!(format!("{:?}", Byte(b'a')), r"b'a'");
        assert_eq!(format!("{:?}", Byte(b' ')), r"b' '");
        assert_eq!(format!("{:?}", Byte(b'\x07')), r"b'\x07'");
        assert_eq!(format!("{:?}", Byte(b'\xef')), r"b'\xef'");
    }

    #[test]
    fn message_debug_ok() {
        assert_eq!(
            format!("{:?}", Message::new(b'A', "abc\0")),
            r#"Message { kind: b'A', buffer: b"abc\0" }"#
        );
        assert_eq!(
            format!("{:?}", Message::new(b'\0', "")),
            r#"Message { kind: b'\0', buffer: b"" }"#
        );
    }

    // Exercises the `Display` impl (via `to_string`), hence the name.
    #[test]
    fn try_from_byte_error_display() {
        assert_eq!(
            TryFromByteError(b'x').to_string(),
            "failed to convert byte b'x'"
        );
    }

    #[test]
    fn try_from_index_error_display() {
        assert_eq!(
            TryFromIndexError::new(-3).to_string(),
            "failed to convert index -3"
        );
        assert_eq!(TryFromIndexError::new(7).index(), 7);
    }

    #[tokio::test]
    async fn read_message() {
        let mut stream = {
            let (mut client, stream) = tokio::io::duplex(100);

            client.write_all(b"\0\0\0\x02xyz").await.unwrap();

            stream
        };

        let msg = read(&mut stream).await.unwrap();
        assert_eq!(msg, Message::new(b'x', "y"));

        let error = read(&mut stream).await.unwrap_err();
        assert_eq!(error.kind(), ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn read_message_too_large() {
        let mut stream = {
            let (mut client, stream) = tokio::io::duplex(100);

            // Length 0x100001 announces a payload one byte over the maximum.
            client.write_all(b"\x00\x10\x00\x01x").await.unwrap();

            stream
        };

        let error = read(&mut stream).await.unwrap_err();
        assert_eq!(error.kind(), ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn write_message() {
        let mut client = {
            let (client, mut stream) = tokio::io::duplex(100);

            let msg = Message::new(b'x', "abc");
            write(&mut stream, msg).await.unwrap();
            stream.write_u8(1).await.unwrap();

            client
        };

        let mut buffer = vec![0; 8];
        client.read_exact(&mut buffer).await.unwrap();
        assert_eq!(buffer, b"\0\0\0\x04xabc");

        let byte = client.read_u8().await.unwrap();
        assert_eq!(byte, 1);

        let error = client.read_u8().await.unwrap_err();
        assert_eq!(error.kind(), ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn write_message_too_large() {
        let mut client = {
            let (client, mut stream) = tokio::io::duplex(100);

            let msg = Message::new(b'x', vec![0u8; Message::MAX_BUFFER_LEN + 1]);
            let error = write(&mut stream, msg).await.unwrap_err();
            assert_eq!(error.kind(), ErrorKind::InvalidData);

            client
        };

        // Nothing may have been written before the size check failed.
        let error = client.read_u8().await.unwrap_err();
        assert_eq!(error.kind(), ErrorKind::UnexpectedEof);
    }
}
272}