Skip to main content

io_imap/rfc3501/
starttls.rs

//! IMAP STARTTLS coroutine; returns any bytes received past the tagged
//! response. RFC 3501 §6.2.1 forbids trailing bytes, so a non-empty return
//! value is a STARTTLS-injection signal: refuse the upgrade.
//!
//! # Example
//!
//! ```rust,no_run
//! use std::{
//!     io::{Read, Write},
//!     net::TcpStream,
//! };
//!
//! use io_imap::{
//!     codec::fragmentizer::Fragmentizer,
//!     coroutine::{ImapCoroutine, ImapCoroutineState, ImapYield},
//!     rfc3501::starttls::ImapStartTls,
//! };
//!
//! // Needs a TCP-connected, plain-text (not yet TLS) IMAP stream.
//! let mut stream = TcpStream::connect("localhost:143").unwrap();
//!
//! let mut fragmentizer = Fragmentizer::new(50 * 1024 * 1024);
//! let mut buf = [0u8; 4096];
//!
//! let mut coroutine = ImapStartTls::new();
//! let mut arg = None;
//!
//! let leftover = loop {
//!     match coroutine.resume(&mut fragmentizer, arg.take()) {
//!         ImapCoroutineState::Yielded(ImapYield::WantsWrite(bytes)) => {
//!             stream.write_all(&bytes).unwrap();
//!         }
//!         ImapCoroutineState::Yielded(ImapYield::WantsRead) => {
//!             let n = stream.read(&mut buf).unwrap();
//!             arg = Some(&buf[..n]);
//!         }
//!         ImapCoroutineState::Complete(Ok(leftover)) => break leftover,
//!         ImapCoroutineState::Complete(Err(err)) => panic!("{err}"),
//!     }
//! };
//!
//! assert!(leftover.is_empty(), "STARTTLS-injection: refuse the upgrade");
//! // Now upgrade `stream` to TLS before sending further IMAP commands.
//! ```
45
46use core::{fmt, mem};
47
48use alloc::vec::Vec;
49
50use imap_codec::{
51    CommandCodec,
52    encode::{Encoder, Fragment},
53    fragmentizer::Fragmentizer,
54    imap_types::{
55        command::{Command, CommandBody},
56        core::{Tag, TagGenerator},
57        utils::escape_byte_string,
58    },
59};
60use log::trace;
61use thiserror::Error;
62
63use crate::coroutine::*;
64
65/// Failure causes during the IMAP STARTTLS handshake.
66#[derive(Clone, Debug, Error)]
67pub enum ImapStartTlsError {
68    #[error("IMAP STARTTLS failed: reached unexpected EOF on stream")]
69    Eof,
70}
71
72/// I/O-free IMAP STARTTLS coroutine.
73pub struct ImapStartTls {
74    tag_bytes: Vec<u8>,
75    state: State,
76    wants_read: bool,
77    wants_write: Option<Vec<u8>>,
78    buf: Vec<u8>,
79}
80
81impl ImapStartTls {
82    pub fn new() -> Self {
83        let tag_bytes = TagGenerator::new().generate().as_ref().as_bytes().to_vec();
84
85        Self {
86            tag_bytes,
87            state: State::DiscardGreeting,
88            wants_read: false,
89            wants_write: None,
90            buf: Vec::new(),
91        }
92    }
93}
94
95impl Default for ImapStartTls {
96    fn default() -> Self {
97        Self::new()
98    }
99}
100
101impl ImapCoroutine for ImapStartTls {
102    type Yield = ImapYield;
103    type Return = Result<Vec<u8>, ImapStartTlsError>;
104
105    fn resume(
106        &mut self,
107        _fragmentizer: &mut Fragmentizer,
108        mut arg: Option<&[u8]>,
109    ) -> ImapCoroutineState<Self::Yield, Self::Return> {
110        loop {
111            trace!("starttls: {}", self.state);
112
113            if let Some(bytes) = self.wants_write.take() {
114                return ImapCoroutineState::Yielded(ImapYield::WantsWrite(bytes));
115            }
116
117            if mem::take(&mut self.wants_read) {
118                return ImapCoroutineState::Yielded(ImapYield::WantsRead);
119            }
120
121            match self.state {
122                State::DiscardGreeting => match arg.take() {
123                    Some(&[]) => {
124                        return ImapCoroutineState::Complete(Err(ImapStartTlsError::Eof));
125                    }
126                    Some(data) => {
127                        self.buf.extend_from_slice(data);
128
129                        let Some(pos) = self.buf.iter().position(|&b| b == b'\n') else {
130                            self.wants_read = true;
131                            continue;
132                        };
133
134                        let line = self.buf.drain(..=pos).collect::<Vec<_>>();
135                        trace!("discard greeting line: {}", escape_byte_string(&line));
136
137                        let encoder = CommandCodec::new();
138                        // SAFETY: tag is always valid.
139                        let tag: Tag = self.tag_bytes.as_slice().try_into().unwrap();
140                        let starttls = Command {
141                            tag,
142                            body: CommandBody::StartTLS,
143                        };
144
145                        let Some(Fragment::Line { data }) = encoder.encode(&starttls).next() else {
146                            // SAFETY: STARTTLS is one simple line.
147                            unreachable!();
148                        };
149
150                        trace!("write starttls command: {}", escape_byte_string(&data));
151                        self.wants_write = Some(data);
152                        self.state = State::WriteStartTls;
153                    }
154                    None => {
155                        self.wants_read = true;
156                    }
157                },
158                State::WriteStartTls => {
159                    self.state = State::DiscardStartTls;
160                }
161                State::DiscardStartTls => match arg.take() {
162                    Some(&[]) => {
163                        return ImapCoroutineState::Complete(Err(ImapStartTlsError::Eof));
164                    }
165                    Some(data) => {
166                        self.buf.extend_from_slice(data);
167
168                        let mut tag_with_space = Vec::with_capacity(self.tag_bytes.len() + 1);
169                        tag_with_space.extend(&self.tag_bytes);
170                        tag_with_space.push(b' ');
171
172                        let Some(tag_pos) = self
173                            .buf
174                            .windows(tag_with_space.len())
175                            .position(|w| w == tag_with_space.as_slice())
176                        else {
177                            self.wants_read = true;
178                            continue;
179                        };
180
181                        let Some(rel) = self.buf[tag_pos..].iter().position(|&b| b == b'\n') else {
182                            self.wants_read = true;
183                            continue;
184                        };
185
186                        let end = tag_pos + rel + 1;
187                        let line = &self.buf[tag_pos..end];
188                        trace!(
189                            "discard STARTTLS response line: {}",
190                            escape_byte_string(line)
191                        );
192
193                        let remaining = self.buf.split_off(end);
194                        return ImapCoroutineState::Complete(Ok(remaining));
195                    }
196                    None => {
197                        self.wants_read = true;
198                    }
199                },
200            }
201        }
202    }
203}
204
/// Steps of the STARTTLS exchange, in the order the coroutine visits them.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum State {
    /// Waiting for, then dropping, the server greeting line.
    DiscardGreeting,
    /// The STARTTLS command has been queued for the caller to write.
    WriteStartTls,
    /// Waiting for, then dropping, the tagged STARTTLS response.
    DiscardStartTls,
}
210
211impl fmt::Display for State {
212    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213        match self {
214            Self::DiscardGreeting => f.write_str("discard greeting"),
215            Self::WriteStartTls => f.write_str("write starttls"),
216            Self::DiscardStartTls => f.write_str("discard starttls response"),
217        }
218    }
219}