postbus/
session.rs

1use std::{io::ErrorKind, net::SocketAddr, sync::Arc};
2use tokio::net::TcpStream;
3
4use crate::{
5    command::{Command, Domain, Mailbox},
6    Handler, Response,
7};
8
9/// Struct holding data about the session.
10pub struct SmtpSession {
11    stream: TcpStream,
12    server_name: String,
13    remaining: String,
14    addr: SocketAddr,
15    handler: Arc<dyn Handler>,
16    state: SmtpState,
17}
18
19/// Struct holding the current state of an transaction.
20#[derive(Debug)]
21pub struct SmtpState {
22    pub receiving_data: bool,
23    pub domain: Option<Domain>,
24    pub from: Option<Mailbox>,
25    pub recipients: Vec<Mailbox>,
26    pub data: String,
27}
28
29impl Default for SmtpState {
30    fn default() -> Self {
31        Self {
32            receiving_data: false,
33            domain: None,
34            from: None,
35            recipients: Vec::new(),
36            data: String::new(),
37        }
38    }
39}
40
41impl SmtpSession {
42    /// Create a new session.
43    pub(crate) fn new(
44        stream: TcpStream,
45        addr: SocketAddr,
46        server_name: String,
47        handler: Arc<dyn Handler>,
48    ) -> Self {
49        SmtpSession {
50            stream,
51            server_name,
52            handler,
53            addr,
54            remaining: String::with_capacity(128),
55            state: SmtpState::default(),
56        }
57    }
58
59    /// Handle the session, reading and writing.
60    /// Should only be called once, returns when the connection should be dropped.
61    pub(crate) async fn handle(mut self) -> () {
62        let mut buff = vec![0; 1024];
63
64        debug!("Accepted new client {}.", self.addr);
65
66        match self
67            .send(&Response::Greeting(self.server_name.clone()))
68            .await
69        {
70            Ok(_) => (),
71            Err(_) => return (),
72        };
73
74        loop {
75            match self.stream.readable().await {
76                Ok(_) => (),
77                Err(e) => {
78                    error!(
79                        "Encountered error while waiting for socket to get ready to read: {}.",
80                        e
81                    );
82                    break;
83                }
84            }
85
86            match self.stream.try_read(&mut buff) {
87                Ok(0) => break,
88                Ok(n) => {
89                    let msg = match std::str::from_utf8(&buff[..n]) {
90                        Ok(m) => m,
91                        Err(_) => {
92                            debug!("Received non-utf8 characters.");
93                            break;
94                        }
95                    };
96
97                    let should_quit = self.input(msg).await;
98                    if should_quit {
99                        debug!("Server indicated to quit.");
100                        break;
101                    }
102                }
103                Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
104                    continue;
105                }
106                Err(e) => {
107                    warn!("Received error while reading socket: {}.", e);
108                    break;
109                }
110            }
111        }
112    }
113
114    /// Handle new incoming input.
115    async fn input(&mut self, input: &str) -> bool {
116        let full_input = format!("{}{}", self.remaining.as_str(), input);
117
118        let full_input = if self.state.receiving_data {
119            let (has_ended, res, rem) = super::parser::parse_data_lines(full_input.as_str());
120
121            self.state.data.push_str(res.as_str());
122
123            if has_ended {
124                let resp = match self.handler.save(&self.state).await {
125                    true => Response::Ok,
126                    false => Response::TransactionFailed,
127                };
128
129                match self.send(&resp).await {
130                    Ok(_) => (),
131                    Err(_) => return true,
132                }
133
134                self.state.receiving_data = false;
135                self.state.data = String::new();
136            }
137
138            rem
139        } else {
140            full_input
141        };
142
143        let (cmds, rem) = super::parser::parse(full_input.as_str());
144        self.remaining = rem.to_owned();
145
146        for (_, command) in cmds {
147            debug!("Processing command {:?}.", command);
148            match command {
149                Some(c) => match self.process_command(c).await {
150                    Ok(cmd) => {
151                        let resp = self.send(&cmd).await;
152
153                        if cmd == Response::Goodbye || resp.is_err() {
154                            return true;
155                        }
156                    }
157                    Err(_) => return true,
158                },
159                None => match self.send(&Response::SyntaxError).await {
160                    Ok(_) => (),
161                    Err(_) => return true,
162                },
163            }
164        }
165
166        false
167    }
168
169    async fn process_command(&mut self, command: Command) -> Result<Response, std::io::Error> {
170        Ok(match command {
171            Command::HELO(domain) => self.process_helo(domain),
172            Command::EHLO(domain) => self.process_ehlo(domain),
173            Command::FROM(sender) => self.process_from(sender),
174            Command::RCPT(recipient) => self.process_rcpt(recipient).await,
175            Command::DATA => self.process_data(),
176            Command::RSET => self.process_reset(),
177            Command::QUIT => Response::Goodbye,
178        })
179    }
180
181    fn process_helo(&mut self, domain: Domain) -> Response {
182        debug!("Processing HELO for {:?}.", domain);
183
184        self.state.domain = Some(domain.clone());
185        Response::Helo(self.server_name.clone())
186    }
187
188    fn process_ehlo(&mut self, domain: Domain) -> Response {
189        debug!("Processing EHLO for {:?}.", domain);
190
191        self.state.domain = Some(domain.clone());
192        Response::Ehlo(self.server_name.clone())
193    }
194
195    fn process_from(&mut self, sender: Mailbox) -> Response {
196        debug!("Processing FROM for {:?}.", sender);
197
198        if self.state.domain == None {
199            debug!("MAIL command was out of sequence.");
200            return Response::OutOfSequence;
201        }
202
203        debug!("Sender accepted.");
204        self.state.from = Some(sender.clone());
205        Response::Ok
206    }
207
208    async fn process_rcpt(&mut self, recipient: Mailbox) -> Response {
209        debug!("Processing recipient for {:?}.", recipient);
210
211        if self.state.domain == None {
212            debug!("RCPT command was send out of sequence.");
213            return Response::OutOfSequence;
214        }
215
216        if self.state.recipients.len() >= 100 {
217            debug!("Received 100 or more recipients.");
218            return Response::TooManyRecipients;
219        }
220
221        if !self.handler.recipient_local(&recipient).await {
222            debug!("Handler indicated the recipient was not local.");
223            return Response::RecipientNotLocal;
224        }
225
226        debug!("Recipient accepted.");
227        self.state.recipients.push(recipient);
228        Response::Ok
229    }
230
231    fn process_data(&mut self) -> Response {
232        if self.state.domain == None {
233            debug!("Received DATA without EHLO.");
234            return Response::OutOfSequence;
235        }
236
237        if self.state.from == None {
238            debug!("Received DATA without FROM.");
239            return Response::OutOfSequence;
240        }
241
242        if self.state.recipients.len() <= 0 {
243            debug!("Received DATA without RCPT.");
244            return Response::OutOfSequence;
245        }
246
247        self.state.receiving_data = true;
248        Response::StartData
249    }
250
251    fn process_reset(&mut self) -> Response {
252        self.state.from = None;
253        self.state.recipients = Vec::new();
254        self.state.data = String::new();
255
256        Response::Ok
257    }
258
259    /// Send a response to the client.
260    async fn send(&self, res: &Response) -> Result<(), std::io::Error> {
261        debug!("Sending `{:?}`.", res);
262
263        match self.stream.writable().await {
264            Ok(_) => (),
265            Err(e) => {
266                error!(
267                    "Encountered error while waiting for socket to get ready to write: {}.",
268                    e
269                );
270            }
271        }
272
273        self.stream.try_write(res.to_response().as_bytes())?;
274
275        Ok(())
276    }
277}