indymilter/message/
mod.rs1pub mod command;
22pub mod reply;
23
24use bytes::Bytes;
25use std::{
26 ascii,
27 error::Error,
28 ffi::{CStr, CString},
29 fmt::{self, Debug, Display, Formatter, Write},
30 io::{self, ErrorKind},
31};
32use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
33
/// Integer type used to represent protocol version numbers.
pub type Version = u32;

/// The protocol version implemented by this crate.
pub const PROTOCOL_VERSION: Version = 6;
39
40#[derive(Clone, Eq, Hash, PartialEq)]
42pub struct Message {
43 pub kind: u8,
45 pub buffer: Bytes,
47}
48
49impl Message {
50 const MAX_BUFFER_LEN: usize = 1024 * 1024 - 1;
57
58 pub fn new(kind: impl Into<u8>, buffer: impl Into<Bytes>) -> Self {
60 Self {
61 kind: kind.into(),
62 buffer: buffer.into(),
63 }
64 }
65}
66
67impl Debug for Message {
68 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
69 f.debug_struct("Message")
70 .field("kind", &Byte(self.kind))
71 .field("buffer", &self.buffer)
72 .finish()
73 }
74}
75
/// Wrapper whose `Debug` output is a Rust byte literal such as `b'a'` or `b'\n'`.
pub(crate) struct Byte(pub u8);

impl Debug for Byte {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let value = self.0;

        f.write_str("b'")?;

        if value == b'\0' {
            // `escape_default` would produce `\x00`; prefer the shorter `\0`.
            f.write_str("\\0")?;
        } else if value == b'"' {
            // A double quote needs no escaping inside single quotes.
            f.write_char('"')?;
        } else {
            ascii::escape_default(value).try_for_each(|c| f.write_char(char::from(c)))?;
        }

        f.write_char('\'')
    }
}
97
/// Error indicating that a byte could not be converted.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TryFromByteError(u8);

impl TryFromByteError {
    pub(crate) fn new(byte: u8) -> Self {
        TryFromByteError(byte)
    }

    /// Returns the byte that could not be converted.
    pub fn byte(&self) -> u8 {
        let TryFromByteError(value) = *self;
        value
    }
}
112
113impl Display for TryFromByteError {
114 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
115 write!(f, "failed to convert byte {:?}", Byte(self.0))
116 }
117}
118
119impl Error for TryFromByteError {}
120
/// Error indicating that an integer index could not be converted.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TryFromIndexError(i32);

impl TryFromIndexError {
    pub(crate) fn new(index: i32) -> Self {
        TryFromIndexError(index)
    }

    /// Returns the index that could not be converted.
    pub fn index(&self) -> i32 {
        let TryFromIndexError(value) = *self;
        value
    }
}

/// Renders as `failed to convert index N`.
impl Display for TryFromIndexError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Formatting through `write!` keeps the caller's width/sign flags
        // from being applied to the number.
        write!(f, "failed to convert index {}", self.index())
    }
}

impl Error for TryFromIndexError {}
143
144pub async fn read<S>(stream: &mut S) -> io::Result<Message>
146where
147 S: AsyncRead + Unpin,
148{
149 let len = stream.read_u32().await?;
150 let kind = stream.read_u8().await?;
151
152 let len = usize::try_from(len)
153 .expect("unsupported pointer size")
154 .saturating_sub(1);
155
156 if len > Message::MAX_BUFFER_LEN {
157 return Err(ErrorKind::InvalidData.into());
158 }
159
160 let mut buffer = vec![0; len];
161 stream.read_exact(&mut buffer).await?;
162
163 Ok(Message::new(kind, buffer))
164}
165
166pub async fn write<S>(stream: &mut S, msg: Message) -> io::Result<()>
168where
169 S: AsyncWrite + Unpin,
170{
171 let len = msg.buffer.len();
172
173 if len > Message::MAX_BUFFER_LEN {
174 return Err(ErrorKind::InvalidData.into());
175 }
176
177 let len = u32::try_from(len).unwrap().checked_add(1).unwrap();
178
179 stream.write_u32(len).await?;
180 stream.write_u8(msg.kind).await?;
181 stream.write_all(msg.buffer.as_ref()).await?;
182 stream.flush().await?;
183
184 Ok(())
185}
186
187struct NoCStringFoundError;
188
189fn get_c_string(buf: &mut Bytes) -> Result<CString, NoCStringFoundError> {
190 if let Some(i) = buf.iter().position(|&x| x == 0) {
191 let b = buf.split_to(i + 1);
192 return Ok(CStr::from_bytes_with_nul(b.as_ref()).unwrap().into());
193 }
194 Err(NoCStringFoundError)
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn byte_debug_ok() {
        let cases: &[(u8, &str)] = &[
            (b'\0', r"b'\0'"),
            (b'\n', r"b'\n'"),
            (b'\'', r"b'\''"),
            (b'"', r#"b'"'"#),
            (b'a', r"b'a'"),
            (b' ', r"b' '"),
            (b'\x07', r"b'\x07'"),
            (b'\xef', r"b'\xef'"),
        ];

        for &(byte, expected) in cases {
            assert_eq!(format!("{:?}", Byte(byte)), expected);
        }
    }

    #[test]
    fn message_debug_ok() {
        let with_payload = Message::new(b'A', "abc\0");
        assert_eq!(
            format!("{:?}", with_payload),
            r#"Message { kind: b'A', buffer: b"abc\0" }"#
        );

        let empty = Message::new(b'\0', "");
        assert_eq!(
            format!("{:?}", empty),
            r#"Message { kind: b'\0', buffer: b"" }"#
        );
    }

    #[test]
    fn try_from_byte_error_debug() {
        let error = TryFromByteError(b'x');
        assert_eq!(error.to_string(), "failed to convert byte b'x'");
    }

    #[tokio::test]
    async fn read_message() {
        let (mut peer, mut conn) = tokio::io::duplex(100);
        peer.write_all(b"\0\0\0\x02xyz").await.unwrap();
        // Closing the writing end makes the trailing partial frame hit EOF.
        drop(peer);

        let msg = read(&mut conn).await.unwrap();
        assert_eq!(msg, Message::new(b'x', "y"));

        let error = read(&mut conn).await.unwrap_err();
        assert_eq!(error.kind(), ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn write_message() {
        let (mut peer, mut conn) = tokio::io::duplex(100);
        write(&mut conn, Message::new(b'x', "abc")).await.unwrap();
        conn.write_u8(1).await.unwrap();
        // Closing the writing end lets the reader observe EOF afterwards.
        drop(conn);

        let mut frame = vec![0; 8];
        peer.read_exact(&mut frame).await.unwrap();
        assert_eq!(frame, b"\0\0\0\x04xabc");

        assert_eq!(peer.read_u8().await.unwrap(), 1);

        let error = peer.read_u8().await.unwrap_err();
        assert_eq!(error.kind(), ErrorKind::UnexpectedEof);
    }
}