daaki_imap/connection/lifecycle.rs
1#![allow(clippy::wildcard_imports)]
2use super::*;
3
4impl ImapConnection {
5 // -----------------------------------------------------------------------
6 // Connection lifecycle
7 // -----------------------------------------------------------------------
8
9 /// Connect to an IMAP server.
10 ///
11 /// Performs the TCP connection (with optional TLS), reads the server greeting,
12 /// and parses initial capabilities per RFC 3501 Section 6.1.1.
13 pub async fn connect(
14 host: &str,
15 port: u16,
16 tls_mode: TlsMode,
17 timeout: Duration,
18 ) -> Result<Self, Error> {
19 let tls_config = build_default_tls_config();
20 Self::connect_with_tls_config(host, port, tls_mode, tls_config, timeout).await
21 }
22
23 /// Connect with a custom TLS configuration (RFC 3501 Section 6.1.1 /
24 /// RFC 9051 Section 6.1.1).
25 ///
26 /// Accepts an `Arc<rustls::ClientConfig>` for use cases like self-signed
27 /// certificates in integration tests.
28 ///
29 /// Establishes the TCP/TLS connection, reads the server greeting,
30 /// fetches initial capabilities, then spawns the driver task and
31 /// returns a handle-based `ImapConnection`.
32 #[allow(clippy::too_many_lines)]
33 pub async fn connect_with_tls_config(
34 host: &str,
35 port: u16,
36 tls_mode: TlsMode,
37 tls_config: Arc<rustls::ClientConfig>,
38 timeout: Duration,
39 ) -> Result<Self, Error> {
40 use super::dispatch::CapabilityConsumer;
41 use super::driver;
42
43 debug!(host, port, ?tls_mode, "connecting to IMAP server");
44
45 let tcp = tokio::time::timeout(timeout, TcpStream::connect((host, port)))
46 .await
47 .map_err(|_| Error::Timeout)?
48 .map_err(|e| Error::Io(std::sync::Arc::new(e)))?;
49
50 let stream = if tls_mode == TlsMode::Implicit {
51 let connector = TlsConnector::from(tls_config.clone());
52 let server_name = rustls_pki_types::ServerName::try_from(host.to_owned())
53 .map_err(|e| Error::Protocol(format!("invalid TLS server name: {e}")))?;
54 let tls_stream = tokio::time::timeout(timeout, connector.connect(server_name, tcp))
55 .await
56 .map_err(|_| Error::Timeout)?
57 .map_err(|e| Error::Io(std::sync::Arc::new(std::io::Error::other(e))))?;
58 ImapStream::Tls(tls_stream)
59 } else {
60 ImapStream::Plain(tcp)
61 };
62
63 // --- Pre-driver phase: read greeting and initial capabilities ---
64 //
65 // The WireReader and ProtocolState live on the stack here. After
66 // the greeting and capability fetch they are moved into the
67 // driver task, which owns them for the rest of the connection.
68 let mut wire_reader = wire::WireReader::new(stream);
69 let mut proto_state = state::ProtocolState::new();
70 let mut tag_gen = tag::TagGenerator::new();
71
72 // Set up the event channel.
73 let (events_tx, events_rx) = tokio::sync::mpsc::channel::<typed_event::TypedEvent>(256);
74
75 // Read and parse the server greeting (RFC 3501 Section 7.1).
76 // This MUST happen before wrapping events_tx in DriverEventSink
77 // so we can send a greeting ALERT on the raw channel (emit() is
78 // pub(in crate::connection::driver) and inaccessible from here).
79 let greeting = tokio::time::timeout(timeout, wire_reader.read_greeting())
80 .await
81 .map_err(|_| Error::Timeout)??;
82
83 let Response::Greeting(g) = &greeting else {
84 return Err(Error::Protocol(
85 "expected greeting from server (RFC 3501 Section 7.1)".into(),
86 ));
87 };
88
89 // Apply greeting to protocol state (sets session state, caches
90 // capabilities, returns Err on BYE).
91 let greeting_alert = proto_state.apply_greeting(g)?;
92
93 // RFC 3501 §7.1: if the greeting carried an [ALERT], deliver it
94 // on the raw channel before wrapping in DriverEventSink.
95 if let Some(alert_text) = greeting_alert {
96 // Best-effort: if the channel is full we drop the alert.
97 let _ = events_tx.try_send(typed_event::TypedEvent::Alert(alert_text));
98 }
99
100 // Wrap the channel in DriverEventSink — events_tx moves here.
101 let mut event_sink = driver::event_sink::DriverEventSink::new(events_tx, None);
102
103 // If we didn't get capabilities from the greeting, request them
104 // explicitly using the driver's run_one_command (RFC 3501 §6.1.1).
105 if proto_state.capabilities().is_empty() {
106 let consumer = driver::DriverConsumer::Regular(
107 Box::new(CapabilityConsumer::default()) as Box<dyn driver::ConsumerErased>
108 );
109 let result = tokio::time::timeout(
110 timeout,
111 driver::run_one_command(
112 &mut wire_reader,
113 &mut proto_state,
114 &mut tag_gen,
115 &mut event_sink,
116 Command::Capability,
117 consumer,
118 ),
119 )
120 .await
121 .map_err(|_| Error::Timeout)??;
122
123 // Downcast the erased output back to Vec<Capability>.
124 // CapabilityConsumer::Output is Vec<Capability>, so the downcast
125 // is provably correct. An Internal error here would indicate a
126 // library bug in the consumer/erased-output machinery.
127 let caps = result
128 .downcast::<Vec<Capability>>()
129 .map_err(|_| Error::Internal("CapabilityConsumer output downcast failed".into()))?;
130 proto_state.apply_capability_fetch(*caps);
131 }
132
133 // --- STARTTLS upgrade (RFC 3501 §6.2.1) ---
134 //
135 // When the caller requests `TlsMode::StartTls`, negotiate the
136 // upgrade before spawning the driver task. All required state
137 // (`wire_reader`, `proto_state`, `tag_gen`, `event_sink`) is
138 // still on the stack. `run_starttls_upgrade` sends the STARTTLS
139 // command, verifies the buffer is empty, performs the TLS
140 // handshake, installs a fresh `WireReader`, and re-fetches
141 // capabilities — so `proto_state` reflects post-TLS caps by the
142 // time we snapshot it into `state_tx` below.
143 if tls_mode == TlsMode::StartTls {
144 // The server must advertise STARTTLS (RFC 3501 §6.2.1).
145 if !proto_state
146 .capabilities()
147 .iter()
148 .any(|c| matches!(c, Capability::StartTls))
149 {
150 return Err(Error::StartTlsUnavailable);
151 }
152
153 let server_name = rustls_pki_types::ServerName::try_from(host.to_owned())
154 .map_err(|e| Error::Protocol(format!("invalid TLS server name: {e}")))?;
155
156 tokio::time::timeout(
157 timeout,
158 driver::run_starttls_upgrade(
159 &mut wire_reader,
160 &mut proto_state,
161 &mut tag_gen,
162 &mut event_sink,
163 tls_config,
164 server_name,
165 ),
166 )
167 .await
168 .map_err(|_| Error::Timeout)??;
169 }
170
171 // --- Spawn the driver task ---
172 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::channel(16);
173 let (state_tx, state_rx) = tokio::sync::watch::channel(proto_state.snapshot());
174
175 let handle = tokio::spawn(driver::driver_task(
176 wire_reader,
177 proto_state,
178 tag_gen,
179 cmd_rx,
180 state_tx,
181 event_sink,
182 ));
183
184 Ok(Self {
185 cmd_tx,
186 state_rx,
187 events_rx: tokio::sync::Mutex::new(events_rx),
188 driver_handle: tokio::sync::Mutex::new(Some(handle)),
189 prebuilt_tag_counter: std::sync::atomic::AtomicU32::new(0),
190 host: host.to_owned(),
191 })
192 }
193
194 /// Set TCP keepalive on the underlying socket (RFC 1122 Section 4.2.3.6).
195 ///
196 /// Configures the operating system's TCP keepalive probes via the
197 /// driver task. This does not send any data on the IMAP wire — it
198 /// sets socket options on the underlying file descriptor via
199 /// `setsockopt(2)`.
200 ///
201 /// # Errors
202 ///
203 /// Returns [`Error::Io`] if the `setsockopt` call fails.
204 /// Returns [`Error::DriverGone`] if the driver task has exited.
205 pub async fn set_keepalive(&self, keepalive: TcpKeepalive) -> Result<(), Error> {
206 let (result_tx, result_rx) = tokio::sync::oneshot::channel();
207 let dcmd = super::driver::DriverCommand::SetKeepalive {
208 keepalive,
209 result_tx,
210 };
211 if self.cmd_tx.send(dcmd).await.is_err() {
212 return Err(self.observe_driver_panic().await);
213 }
214 match result_rx.await {
215 Ok(result) => result,
216 Err(_) => Err(self.observe_driver_panic().await),
217 }
218 }
219
220 /// If the driver task has terminated, return an error describing
221 /// why (panic message if panicked, or `DriverGone` if exited
222 /// cleanly). Called from `submit` when `cmd_tx.send` fails so the
223 /// caller sees the real reason instead of a generic `Disconnected`.
224 pub(super) async fn observe_driver_panic(&self) -> Error {
225 let mut guard = self.driver_handle.lock().await;
226 let Some(handle) = guard.take() else {
227 return Error::DriverGone;
228 };
229 // `handle.is_finished()` avoids blocking if still running.
230 if !handle.is_finished() {
231 // Put it back — still running. The cmd_tx failure may be
232 // a TOCTOU race with the driver exiting.
233 *guard = Some(handle);
234 drop(guard);
235 return Error::DriverGone;
236 }
237 match handle.await {
238 Err(join_err) if join_err.is_panic() => {
239 let panic_msg = join_err
240 .into_panic()
241 .downcast::<String>()
242 .map(|s| *s)
243 .or_else(|p| p.downcast::<&'static str>().map(|s| s.to_string()))
244 .unwrap_or_else(|_| "driver panicked (payload not a String)".to_string());
245 Error::DriverPanicked(panic_msg)
246 }
247 Ok(()) | Err(_) => Error::DriverGone,
248 }
249 }
250
251 /// Upgrade to TLS via STARTTLS (RFC 3501 Section 6.2.1).
252 ///
253 /// Only valid if `TlsMode::StartTls` was used and `connect` did not already
254 /// perform the upgrade. Errors if the server doesn't advertise STARTTLS.
255 ///
256 /// The upgrade is atomic via the `Poisoned` sentinel pattern (I9, I10)
257 /// — handled entirely by the driver task.
258 pub async fn starttls(&self, timeout: Duration) -> Result<(), Error> {
259 self.starttls_with_config(build_default_tls_config(), timeout)
260 .await
261 }
262
263 /// STARTTLS with a custom TLS configuration (RFC 3501 Section 6.2.1 /
264 /// RFC 9051 Section 6.2.1).
265 ///
266 /// The upgrade is atomic via the `Poisoned` sentinel pattern (I9, I10):
267 ///
268 /// 1. Send STARTTLS, await tagged OK.
269 /// 2. Verify the wire buffer is empty (no injected bytes — B10 fix).
270 /// 3. `mem::replace` the reader with a `Poisoned`-stream reader.
271 /// 4. TLS handshake (may suspend). If it fails or the future is
272 /// cancelled, the reader stays wrapping `Poisoned` forever.
273 /// 5. Install a fresh `WireReader` on the new TLS stream.
274 /// 6. Re-fetch capabilities (RFC 3501 §6.2.1).
275 ///
276 /// All steps are executed by the driver task. The caller submits
277 /// the upgrade command and awaits the result.
278 pub async fn starttls_with_config(
279 &self,
280 tls_config: Arc<rustls::ClientConfig>,
281 timeout: Duration,
282 ) -> Result<(), Error> {
283 self.require_state(&[SessionState::NotAuthenticated])?;
284
285 // Check STARTTLS capability from the snapshot.
286 {
287 let snap = self.state_rx.borrow();
288 if !snap.capabilities.is_empty()
289 && !snap
290 .capabilities
291 .iter()
292 .any(|c| matches!(c, Capability::StartTls))
293 {
294 return Err(Error::StartTlsUnavailable);
295 }
296 }
297
298 // Validate server name before submitting — fail early without
299 // involving the driver (RFC 3501 §6.2.1).
300 let server_name = rustls_pki_types::ServerName::try_from(self.host.clone())
301 .map_err(|e| Error::Protocol(format!("invalid TLS server name: {e}")))?;
302
303 debug!("upgrading to TLS via STARTTLS");
304 tokio::time::timeout(
305 timeout,
306 self.submit_upgrade(driver::UpgradePayload::StartTls {
307 tls_config,
308 server_name,
309 }),
310 )
311 .await
312 .map_err(|_| Error::Timeout)?
313 }
314}