Skip to main content

daaki_imap/connection/
lifecycle.rs

1#![allow(clippy::wildcard_imports)]
2use super::*;
3
4impl ImapConnection {
5    // -----------------------------------------------------------------------
6    // Connection lifecycle
7    // -----------------------------------------------------------------------
8
9    /// Connect to an IMAP server.
10    ///
11    /// Performs the TCP connection (with optional TLS), reads the server greeting,
12    /// and parses initial capabilities per RFC 3501 Section 6.1.1.
13    pub async fn connect(
14        host: &str,
15        port: u16,
16        tls_mode: TlsMode,
17        timeout: Duration,
18    ) -> Result<Self, Error> {
19        let tls_config = build_default_tls_config();
20        Self::connect_with_tls_config(host, port, tls_mode, tls_config, timeout).await
21    }
22
23    /// Connect with a custom TLS configuration (RFC 3501 Section 6.1.1 /
24    /// RFC 9051 Section 6.1.1).
25    ///
26    /// Accepts an `Arc<rustls::ClientConfig>` for use cases like self-signed
27    /// certificates in integration tests.
28    ///
29    /// Establishes the TCP/TLS connection, reads the server greeting,
30    /// fetches initial capabilities, then spawns the driver task and
31    /// returns a handle-based `ImapConnection`.
32    #[allow(clippy::too_many_lines)]
33    pub async fn connect_with_tls_config(
34        host: &str,
35        port: u16,
36        tls_mode: TlsMode,
37        tls_config: Arc<rustls::ClientConfig>,
38        timeout: Duration,
39    ) -> Result<Self, Error> {
40        use super::dispatch::CapabilityConsumer;
41        use super::driver;
42
43        debug!(host, port, ?tls_mode, "connecting to IMAP server");
44
45        let tcp = tokio::time::timeout(timeout, TcpStream::connect((host, port)))
46            .await
47            .map_err(|_| Error::Timeout)?
48            .map_err(|e| Error::Io(std::sync::Arc::new(e)))?;
49
50        let stream = if tls_mode == TlsMode::Implicit {
51            let connector = TlsConnector::from(tls_config.clone());
52            let server_name = rustls_pki_types::ServerName::try_from(host.to_owned())
53                .map_err(|e| Error::Protocol(format!("invalid TLS server name: {e}")))?;
54            let tls_stream = tokio::time::timeout(timeout, connector.connect(server_name, tcp))
55                .await
56                .map_err(|_| Error::Timeout)?
57                .map_err(|e| Error::Io(std::sync::Arc::new(std::io::Error::other(e))))?;
58            ImapStream::Tls(tls_stream)
59        } else {
60            ImapStream::Plain(tcp)
61        };
62
63        // --- Pre-driver phase: read greeting and initial capabilities ---
64        //
65        // The WireReader and ProtocolState live on the stack here. After
66        // the greeting and capability fetch they are moved into the
67        // driver task, which owns them for the rest of the connection.
68        let mut wire_reader = wire::WireReader::new(stream);
69        let mut proto_state = state::ProtocolState::new();
70        let mut tag_gen = tag::TagGenerator::new();
71
72        // Set up the event channel.
73        let (events_tx, events_rx) = tokio::sync::mpsc::channel::<typed_event::TypedEvent>(256);
74
75        // Read and parse the server greeting (RFC 3501 Section 7.1).
76        // This MUST happen before wrapping events_tx in DriverEventSink
77        // so we can send a greeting ALERT on the raw channel (emit() is
78        // pub(in crate::connection::driver) and inaccessible from here).
79        let greeting = tokio::time::timeout(timeout, wire_reader.read_greeting())
80            .await
81            .map_err(|_| Error::Timeout)??;
82
83        let Response::Greeting(g) = &greeting else {
84            return Err(Error::Protocol(
85                "expected greeting from server (RFC 3501 Section 7.1)".into(),
86            ));
87        };
88
89        // Apply greeting to protocol state (sets session state, caches
90        // capabilities, returns Err on BYE).
91        let greeting_alert = proto_state.apply_greeting(g)?;
92
93        // RFC 3501 §7.1: if the greeting carried an [ALERT], deliver it
94        // on the raw channel before wrapping in DriverEventSink.
95        if let Some(alert_text) = greeting_alert {
96            // Best-effort: if the channel is full we drop the alert.
97            let _ = events_tx.try_send(typed_event::TypedEvent::Alert(alert_text));
98        }
99
100        // Wrap the channel in DriverEventSink — events_tx moves here.
101        let mut event_sink = driver::event_sink::DriverEventSink::new(events_tx, None);
102
103        // If we didn't get capabilities from the greeting, request them
104        // explicitly using the driver's run_one_command (RFC 3501 §6.1.1).
105        if proto_state.capabilities().is_empty() {
106            let consumer = driver::DriverConsumer::Regular(
107                Box::new(CapabilityConsumer::default()) as Box<dyn driver::ConsumerErased>
108            );
109            let result = tokio::time::timeout(
110                timeout,
111                driver::run_one_command(
112                    &mut wire_reader,
113                    &mut proto_state,
114                    &mut tag_gen,
115                    &mut event_sink,
116                    Command::Capability,
117                    consumer,
118                ),
119            )
120            .await
121            .map_err(|_| Error::Timeout)??;
122
123            // Downcast the erased output back to Vec<Capability>.
124            // CapabilityConsumer::Output is Vec<Capability>, so the downcast
125            // is provably correct. An Internal error here would indicate a
126            // library bug in the consumer/erased-output machinery.
127            let caps = result
128                .downcast::<Vec<Capability>>()
129                .map_err(|_| Error::Internal("CapabilityConsumer output downcast failed".into()))?;
130            proto_state.apply_capability_fetch(*caps);
131        }
132
133        // --- STARTTLS upgrade (RFC 3501 §6.2.1) ---
134        //
135        // When the caller requests `TlsMode::StartTls`, negotiate the
136        // upgrade before spawning the driver task.  All required state
137        // (`wire_reader`, `proto_state`, `tag_gen`, `event_sink`) is
138        // still on the stack.  `run_starttls_upgrade` sends the STARTTLS
139        // command, verifies the buffer is empty, performs the TLS
140        // handshake, installs a fresh `WireReader`, and re-fetches
141        // capabilities — so `proto_state` reflects post-TLS caps by the
142        // time we snapshot it into `state_tx` below.
143        if tls_mode == TlsMode::StartTls {
144            // The server must advertise STARTTLS (RFC 3501 §6.2.1).
145            if !proto_state
146                .capabilities()
147                .iter()
148                .any(|c| matches!(c, Capability::StartTls))
149            {
150                return Err(Error::StartTlsUnavailable);
151            }
152
153            let server_name = rustls_pki_types::ServerName::try_from(host.to_owned())
154                .map_err(|e| Error::Protocol(format!("invalid TLS server name: {e}")))?;
155
156            tokio::time::timeout(
157                timeout,
158                driver::run_starttls_upgrade(
159                    &mut wire_reader,
160                    &mut proto_state,
161                    &mut tag_gen,
162                    &mut event_sink,
163                    tls_config,
164                    server_name,
165                ),
166            )
167            .await
168            .map_err(|_| Error::Timeout)??;
169        }
170
171        // --- Spawn the driver task ---
172        let (cmd_tx, cmd_rx) = tokio::sync::mpsc::channel(16);
173        let (state_tx, state_rx) = tokio::sync::watch::channel(proto_state.snapshot());
174
175        let handle = tokio::spawn(driver::driver_task(
176            wire_reader,
177            proto_state,
178            tag_gen,
179            cmd_rx,
180            state_tx,
181            event_sink,
182        ));
183
184        Ok(Self {
185            cmd_tx,
186            state_rx,
187            events_rx: tokio::sync::Mutex::new(events_rx),
188            driver_handle: tokio::sync::Mutex::new(Some(handle)),
189            prebuilt_tag_counter: std::sync::atomic::AtomicU32::new(0),
190            host: host.to_owned(),
191        })
192    }
193
194    /// Set TCP keepalive on the underlying socket (RFC 1122 Section 4.2.3.6).
195    ///
196    /// Configures the operating system's TCP keepalive probes via the
197    /// driver task. This does not send any data on the IMAP wire — it
198    /// sets socket options on the underlying file descriptor via
199    /// `setsockopt(2)`.
200    ///
201    /// # Errors
202    ///
203    /// Returns [`Error::Io`] if the `setsockopt` call fails.
204    /// Returns [`Error::DriverGone`] if the driver task has exited.
205    pub async fn set_keepalive(&self, keepalive: TcpKeepalive) -> Result<(), Error> {
206        let (result_tx, result_rx) = tokio::sync::oneshot::channel();
207        let dcmd = super::driver::DriverCommand::SetKeepalive {
208            keepalive,
209            result_tx,
210        };
211        if self.cmd_tx.send(dcmd).await.is_err() {
212            return Err(self.observe_driver_panic().await);
213        }
214        match result_rx.await {
215            Ok(result) => result,
216            Err(_) => Err(self.observe_driver_panic().await),
217        }
218    }
219
220    /// If the driver task has terminated, return an error describing
221    /// why (panic message if panicked, or `DriverGone` if exited
222    /// cleanly). Called from `submit` when `cmd_tx.send` fails so the
223    /// caller sees the real reason instead of a generic `Disconnected`.
224    pub(super) async fn observe_driver_panic(&self) -> Error {
225        let mut guard = self.driver_handle.lock().await;
226        let Some(handle) = guard.take() else {
227            return Error::DriverGone;
228        };
229        // `handle.is_finished()` avoids blocking if still running.
230        if !handle.is_finished() {
231            // Put it back — still running. The cmd_tx failure may be
232            // a TOCTOU race with the driver exiting.
233            *guard = Some(handle);
234            drop(guard);
235            return Error::DriverGone;
236        }
237        match handle.await {
238            Err(join_err) if join_err.is_panic() => {
239                let panic_msg = join_err
240                    .into_panic()
241                    .downcast::<String>()
242                    .map(|s| *s)
243                    .or_else(|p| p.downcast::<&'static str>().map(|s| s.to_string()))
244                    .unwrap_or_else(|_| "driver panicked (payload not a String)".to_string());
245                Error::DriverPanicked(panic_msg)
246            }
247            Ok(()) | Err(_) => Error::DriverGone,
248        }
249    }
250
251    /// Upgrade to TLS via STARTTLS (RFC 3501 Section 6.2.1).
252    ///
253    /// Only valid if `TlsMode::StartTls` was used and `connect` did not already
254    /// perform the upgrade. Errors if the server doesn't advertise STARTTLS.
255    ///
256    /// The upgrade is atomic via the `Poisoned` sentinel pattern (I9, I10)
257    /// — handled entirely by the driver task.
258    pub async fn starttls(&self, timeout: Duration) -> Result<(), Error> {
259        self.starttls_with_config(build_default_tls_config(), timeout)
260            .await
261    }
262
263    /// STARTTLS with a custom TLS configuration (RFC 3501 Section 6.2.1 /
264    /// RFC 9051 Section 6.2.1).
265    ///
266    /// The upgrade is atomic via the `Poisoned` sentinel pattern (I9, I10):
267    ///
268    /// 1. Send STARTTLS, await tagged OK.
269    /// 2. Verify the wire buffer is empty (no injected bytes — B10 fix).
270    /// 3. `mem::replace` the reader with a `Poisoned`-stream reader.
271    /// 4. TLS handshake (may suspend). If it fails or the future is
272    ///    cancelled, the reader stays wrapping `Poisoned` forever.
273    /// 5. Install a fresh `WireReader` on the new TLS stream.
274    /// 6. Re-fetch capabilities (RFC 3501 §6.2.1).
275    ///
276    /// All steps are executed by the driver task. The caller submits
277    /// the upgrade command and awaits the result.
278    pub async fn starttls_with_config(
279        &self,
280        tls_config: Arc<rustls::ClientConfig>,
281        timeout: Duration,
282    ) -> Result<(), Error> {
283        self.require_state(&[SessionState::NotAuthenticated])?;
284
285        // Check STARTTLS capability from the snapshot.
286        {
287            let snap = self.state_rx.borrow();
288            if !snap.capabilities.is_empty()
289                && !snap
290                    .capabilities
291                    .iter()
292                    .any(|c| matches!(c, Capability::StartTls))
293            {
294                return Err(Error::StartTlsUnavailable);
295            }
296        }
297
298        // Validate server name before submitting — fail early without
299        // involving the driver (RFC 3501 §6.2.1).
300        let server_name = rustls_pki_types::ServerName::try_from(self.host.clone())
301            .map_err(|e| Error::Protocol(format!("invalid TLS server name: {e}")))?;
302
303        debug!("upgrading to TLS via STARTTLS");
304        tokio::time::timeout(
305            timeout,
306            self.submit_upgrade(driver::UpgradePayload::StartTls {
307                tls_config,
308                server_name,
309            }),
310        )
311        .await
312        .map_err(|_| Error::Timeout)?
313    }
314}