1use std::io::{BufRead, BufReader, Read, Write};
6
/// Error returned by [`handle_connection`].
///
/// `ErrorType` is the error type produced by the [`PolicyRequestHandler`] implementation in use.
#[derive(Debug)]
pub enum PostfixPolicyError<ErrorType> {
    /// Reading from, writing to, or flushing the socket failed.
    IoError(std::io::Error),
    /// A request line could not be parsed (for example it has no `=`). Holds the raw line as it
    /// was received, including its trailing newline if there was one.
    ProtocolError(Vec<u8>),
    /// The request handler reported an error.
    HandlerError(ErrorType),
}

impl<ErrorType: std::fmt::Display> std::fmt::Display for PostfixPolicyError<ErrorType> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PostfixPolicyError::IoError(err) => write!(f, "I/O error: {}", err),
            // The raw line may not be valid UTF-8; render it lossily and quoted so control
            // characters such as the trailing newline stay visible.
            PostfixPolicyError::ProtocolError(line) => write!(
                f,
                "protocol error: malformed line {:?}",
                String::from_utf8_lossy(line)
            ),
            PostfixPolicyError::HandlerError(err) => write!(f, "handler error: {}", err),
        }
    }
}

impl<ErrorType: std::fmt::Debug + std::fmt::Display> std::error::Error
    for PostfixPolicyError<ErrorType>
{
    /// Exposes the underlying I/O error; handler errors are not required to implement
    /// `std::error::Error`, so they are not chained.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            PostfixPolicyError::IoError(err) => Some(err),
            _ => None,
        }
    }
}
19
20impl<ErrorType> std::convert::From<std::io::Error> for PostfixPolicyError<ErrorType> {
21 fn from(e: std::io::Error) -> Self {
22 PostfixPolicyError::IoError(e)
23 }
24}
25
/// Decision a [`PolicyRequestHandler`] returns for one policy request.
///
/// Each variant is sent back as `action=<ACTION>` or, when its payload is non-empty,
/// `action=<ACTION> <payload>`. See Postfix `access(5)` for the meaning of each action.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PolicyResponse {
    /// `OK` (no payload).
    Ok,
    /// `REJECT [text]`.
    Reject(Vec<u8>),
    /// `DEFER [text]`.
    Defer(Vec<u8>),
    /// `DEFER_IF_REJECT [text]`.
    DeferIfReject(Vec<u8>),
    /// `DEFER_IF_PERMIT [text]`.
    DeferIfPermit(Vec<u8>),
    /// `BCC <payload>`; the payload is presumably an e-mail address.
    Bcc(Vec<u8>),
    /// `DISCARD [text]`.
    Discard(Vec<u8>),
    /// `DUNNO` (no payload).
    Dunno,
    /// `HOLD [text]`.
    Hold(Vec<u8>),
    /// `REDIRECT <payload>`; the payload is presumably the destination address.
    Redirect(Vec<u8>),
    /// `INFO [text]`.
    Info(Vec<u8>),
    /// `WARN [text]`.
    Warn(Vec<u8>),
}
44
45pub trait PolicyRequestHandler<'l, ContextType, ErrorType> {
52 fn new(ctx: &'l ContextType) -> Self;
54 fn attribute(&mut self, name: &[u8], value: &[u8]) -> Option<ErrorType>;
60 fn response(self) -> Result<PolicyResponse, ErrorType>;
66}
67
68fn serialize_response(resp: PolicyResponse) -> Vec<u8> {
69 let mut message = Vec::new();
70 let action: &[u8] = match resp {
71 PolicyResponse::Ok => b"OK",
72 PolicyResponse::Reject(msg) => {
73 message = msg;
74 b"REJECT"
75 }
76 PolicyResponse::Defer(msg) => {
77 message = msg;
78 b"DEFER"
79 }
80 PolicyResponse::DeferIfReject(msg) => {
81 message = msg;
82 b"DEFER_IF_REJECT"
83 }
84 PolicyResponse::DeferIfPermit(msg) => {
85 message = msg;
86 b"DEFER_IF_PERMIT"
87 }
88 PolicyResponse::Bcc(email) => {
89 message = email;
90 b"BCC"
91 }
92 PolicyResponse::Discard(msg) => {
93 message = msg;
94 b"DISCARD"
95 }
96 PolicyResponse::Dunno => b"DUNNO",
97 PolicyResponse::Hold(msg) => {
98 message = msg;
99 b"HOLD"
100 }
101 PolicyResponse::Redirect(dst) => {
102 message = dst;
103 b"REDIRECT"
104 }
105 PolicyResponse::Info(msg) => {
106 message = msg;
107 b"INFO"
108 }
109 PolicyResponse::Warn(msg) => {
110 message = msg;
111 b"WARN"
112 }
113 };
114 let mut resp = Vec::from(action);
115 if !message.is_empty() {
116 resp.push(b' ');
117 resp.extend_from_slice(&message);
118 }
119 resp
120}
121
#[test]
fn test_serialize_response() {
    /// Asserts that `response` renders to exactly `expected`.
    fn check(response: PolicyResponse, expected: &[u8]) {
        assert_eq!(expected, &serialize_response(response)[..]);
    }

    check(PolicyResponse::Ok, b"OK");
    check(PolicyResponse::Reject(Vec::new()), b"REJECT");
    check(PolicyResponse::Reject(b"asdf".to_vec()), b"REJECT asdf");
    check(PolicyResponse::Defer(Vec::new()), b"DEFER");
    check(PolicyResponse::Defer(b"fdas".to_vec()), b"DEFER fdas");
    check(PolicyResponse::DeferIfReject(Vec::new()), b"DEFER_IF_REJECT");
    check(
        PolicyResponse::DeferIfReject(b"blblblbl".to_vec()),
        b"DEFER_IF_REJECT blblblbl",
    );
    check(PolicyResponse::DeferIfPermit(Vec::new()), b"DEFER_IF_PERMIT");
    check(
        PolicyResponse::DeferIfPermit(b"gsdk jf".to_vec()),
        b"DEFER_IF_PERMIT gsdk jf",
    );
    check(PolicyResponse::Bcc(b"a@b.c".to_vec()), b"BCC a@b.c");
    check(PolicyResponse::Discard(Vec::new()), b"DISCARD");
    check(PolicyResponse::Discard(b"asdffdas".to_vec()), b"DISCARD asdffdas");
    check(PolicyResponse::Dunno, b"DUNNO");
    check(PolicyResponse::Hold(Vec::new()), b"HOLD");
    check(PolicyResponse::Hold(b"cmn,sd".to_vec()), b"HOLD cmn,sd");
    check(PolicyResponse::Redirect(b"a@b.c".to_vec()), b"REDIRECT a@b.c");
    check(
        PolicyResponse::Info(b"some message trololol".to_vec()),
        b"INFO some message trololol",
    );
    check(
        PolicyResponse::Warn(
            b"writing something to logs because logging is great and everyone should log everything".to_vec(),
        ),
        b"WARN writing something to logs because logging is great and everyone should log everything",
    );
}
187
188pub fn handle_connection<'socket, 'ctx, HandlerType, ContextType, ErrorType, SocketType>(
210 mut socket: &'socket SocketType,
211 ctx: &'ctx ContextType,
212) -> Result<(), PostfixPolicyError<ErrorType>>
213where
214 HandlerType: PolicyRequestHandler<'ctx, ContextType, ErrorType>,
215 &'socket SocketType: Read + Write,
216{
217 let mut handler: HandlerType = HandlerType::new(ctx);
218 let mut reader = BufReader::new(socket);
219
220 loop {
221 let mut buf: Vec<u8> = vec![];
222 if reader.read_until(b'\n', &mut buf)? == 0 {
223 return Ok(());
224 }
225
226 if buf == b"\n" {
227 let result = match handler.response() {
228 Ok(result) => result,
229 Err(e) => return Err(PostfixPolicyError::HandlerError(e)),
230 };
231 socket.write_all(b"action=")?;
232 socket.write_all(&serialize_response(result))?;
233 socket.write_all(b"\n\n")?;
234 socket.flush()?;
235 handler = HandlerType::new(ctx);
236 continue;
237 }
238
239 match buf.iter().position(|&c| c == b'=') {
240 None => return Err(PostfixPolicyError::ProtocolError(buf)),
241 Some(pos) => {
242 let (left, mut right) = buf.split_at(pos);
243 if left.is_empty() || right.len() < 2 {
244 return Err(PostfixPolicyError::ProtocolError(buf));
245 }
246 right = &right[1..right.len() - 1];
247 if let Some(error) = handler.attribute(left, right) {
248 return Err(PostfixPolicyError::HandlerError(error));
249 }
250 }
251 }
252 }
253}
254
255pub mod test_helper {
257 use super::{handle_connection, PolicyRequestHandler, PostfixPolicyError};
258 use std::cell::RefCell;
259 use std::io::Cursor;
260 use std::io::{Read, Write};
261
262 pub struct DummySocket<'lt> {
264 input: RefCell<Cursor<&'lt [u8]>>,
265 output: RefCell<Vec<u8>>,
266 }
267
268 impl<'lt> DummySocket<'lt> {
269 pub fn new(input: &'lt [u8]) -> Self {
271 DummySocket {
272 input: RefCell::new(Cursor::new(input)),
273 output: RefCell::new(vec![]),
274 }
275 }
276
277 pub fn get_output(self) -> Vec<u8> {
279 self.output.into_inner()
280 }
281 }
282
283 impl<'lt> Read for &DummySocket<'lt> {
284 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
285 self.input.borrow_mut().read(buf)
286 }
287 }
288
289 impl<'lt> Write for &DummySocket<'lt> {
290 fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
291 self.output.borrow_mut().write(buf)
292 }
293 fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
294 self.output.borrow_mut().flush()
295 }
296 }
297
298 pub fn handle_connection_response<'l, HandlerType, ContextType, ErrorType>(
314 input: &'l [u8],
315 ctx: &'l ContextType,
316 ) -> Result<Vec<u8>, PostfixPolicyError<ErrorType>>
317 where
318 HandlerType: PolicyRequestHandler<'l, ContextType, ErrorType>,
319 {
320 let socket = DummySocket::new(input);
321 handle_connection::<HandlerType, ContextType, ErrorType, _>(&socket, ctx)?;
322 Ok(socket.get_output())
323 }
324}
325
326#[cfg(test)]
327mod tests {
328
329 use super::test_helper::handle_connection_response;
330 use super::{PolicyRequestHandler, PolicyResponse, PostfixPolicyError};
331
332 struct DummyRequestHandler {
333 found_request: bool,
334 client_address: Vec<u8>,
335 }
336 impl<'l> PolicyRequestHandler<'l, (), ()> for DummyRequestHandler {
337 fn new(_: &()) -> Self {
338 Self {
339 found_request: false,
340 client_address: vec![],
341 }
342 }
343 fn attribute(&mut self, name: &[u8], value: &[u8]) -> Option<()> {
344 match name {
345 b"request" => self.found_request = true,
346 b"client_address" => self.client_address = value.to_vec(),
347 _ => {}
348 }
349 None
350 }
351
352 fn response(self) -> Result<PolicyResponse, ()> {
353 if !self.found_request {
354 return Ok(PolicyResponse::Reject(Vec::new()));
355 }
356 Ok(PolicyResponse::Defer(self.client_address.clone()))
357 }
358 }
359
360 #[test]
361 fn test_handle_connection_valid() {
362 let input =
363 b"request=smtpd_access_policy\nprotocol_state=RCPT\nprotocol_name=ESMTP\nclient_address=131.234.189.14\n\n";
364 assert_eq!(
365 handle_connection_response::<DummyRequestHandler, _, _>(input, &()).unwrap(),
366 b"action=DEFER 131.234.189.14\n\n"
367 );
368 }
369
370 #[test]
371 fn test_handle_connection_empty() {
372 let input = b"\n";
373 assert_eq!(
374 handle_connection_response::<DummyRequestHandler, _, _>(input, &()).unwrap(),
375 b"action=REJECT\n\n"
376 );
377 }
378
379 #[test]
380 fn test_handle_connection_line_without_eq() {
381 let input = b"asdf\n\n";
382
383 assert!(
384 match handle_connection_response::<DummyRequestHandler, _, _>(input, &()) {
385 Err(PostfixPolicyError::ProtocolError(l)) => {
386 assert_eq!(&l, b"asdf\n");
387 true
388 }
389 _ => false,
390 }
391 );
392 }
393
394 #[test]
395 fn test_handle_connection_line_empty_name() {
396 let input = b"=a\n\n";
397
398 assert!(
399 match handle_connection_response::<DummyRequestHandler, _, _>(input, &()) {
400 Err(PostfixPolicyError::ProtocolError(l)) => {
401 assert_eq!(&l, b"=a\n");
402 true
403 }
404 _ => false,
405 }
406 );
407 }
408}