Skip to main content

nexus_web/ws/
connecting.rs

1//! Non-blocking WebSocket connection handshake.
2
3use std::io::{self, Read, Write};
4
5use super::frame::Role;
6use super::frame_reader::{FrameReader, FrameReaderBuilder};
7use super::frame_writer::FrameWriter;
8use super::handshake::{self, HandshakeError};
9use super::stream::{Client, ClientBuilder, Error, parse_ws_url};
10use nexus_net::buf::WriteBuf;
11
12#[cfg(feature = "tls")]
13use nexus_net::tls::{TlsCodec, TlsError};
14
15/// A WebSocket connection in the handshake phase.
16///
17/// Drive the handshake by calling [`poll()`](Self::poll) when the socket
18/// is ready. Returns [`Client<S>`] when complete.
19///
20/// Check [`wants_read()`](Self::wants_read) / [`wants_write()`](Self::wants_write)
21/// to determine which readiness event to wait for in your event loop.
22///
23/// # Usage
24///
25/// ```ignore
26/// use nexus_web::ws::{Connecting, ClientBuilder};
27///
28/// let tcp = TcpStream::connect("exchange.com:443")?;
29/// tcp.set_nonblocking(true)?;
30/// let mut connecting = ClientBuilder::new()
31///     .begin_connect(tcp, "wss://exchange.com/ws")?;
32///
33/// // In your event loop:
34/// loop {
35///     // ... poll for socket readiness ...
36///     if let Some(ws) = connecting.poll()? {
37///         // Handshake complete — ws.recv() is now available
38///         break;
39///     }
40/// }
41/// ```
42pub struct Connecting<S> {
43    // ManuallyDrop: ownership transferred to Client in finish().
44    // Drop impl handles cleanup if finish() is never called (error path).
45    stream: std::mem::ManuallyDrop<S>,
46    state: ConnectState,
47    #[cfg(feature = "tls")]
48    tls: Option<TlsCodec>,
49    reader_builder: FrameReaderBuilder,
50    write_buf_capacity: usize,
51    write_buf_headroom: usize,
52    // Handshake data
53    ws_key: [u8; 24],
54    req_buf: Vec<u8>,
55    req_offset: usize,
56    resp_reader: crate::http::ResponseReader,
57    host: String,
58    path: String,
59    finished: bool, // true after finish() called — suppress Drop
60}
61
/// Phase of the client-side opening handshake.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConnectState {
    /// TLS handshake: outgoing handshake records must be flushed.
    #[cfg(feature = "tls")]
    TlsWrite,
    /// TLS handshake: waiting for the server's handshake records.
    #[cfg(feature = "tls")]
    TlsRead,
    /// Writing the HTTP `Upgrade` request.
    HttpSend,
    /// Receiving and parsing the server's HTTP upgrade response.
    HttpRecv,
    /// Response validated; ready to be turned into a `Client`.
    Done,
}
77
78impl ClientBuilder {
79    /// Start a non-blocking connection handshake.
80    ///
81    /// Returns a [`Connecting`] that must be driven to completion
82    /// via [`poll()`](Connecting::poll) before messages can be sent/received.
83    ///
84    /// The caller is responsible for setting the socket to non-blocking
85    /// mode before calling this.
86    pub fn begin_connect<S: Read + Write>(
87        self,
88        stream: S,
89        url: &str,
90    ) -> Result<Connecting<S>, Error> {
91        let parsed = parse_ws_url(url)?;
92
93        #[cfg(feature = "tls")]
94        let tls = if parsed.tls {
95            let config = match self.tls_config {
96                Some(c) => c,
97                None => nexus_net::tls::TlsConfig::new().map_err(Error::Tls)?,
98            };
99            Some(TlsCodec::new(&config, parsed.host)?)
100        } else {
101            None
102        };
103
104        #[cfg(not(feature = "tls"))]
105        if parsed.tls {
106            return Err(Error::TlsNotEnabled);
107        }
108
109        let ws_key = handshake::generate_key();
110
111        #[cfg(feature = "tls")]
112        let initial_state = if tls.is_some() {
113            ConnectState::TlsWrite
114        } else {
115            ConnectState::HttpSend
116        };
117
118        #[cfg(not(feature = "tls"))]
119        let initial_state = ConnectState::HttpSend;
120
121        let mut connecting = Connecting {
122            stream: std::mem::ManuallyDrop::new(stream),
123            state: initial_state,
124            #[cfg(feature = "tls")]
125            tls,
126            reader_builder: self.reader_builder,
127            write_buf_capacity: self.write_buf_capacity,
128            write_buf_headroom: self.write_buf_headroom,
129            ws_key,
130            req_buf: Vec::new(),
131            req_offset: 0,
132            resp_reader: crate::http::ResponseReader::new(4096),
133            host: parsed.host.to_owned(),
134            path: parsed.path.to_owned(),
135            finished: false,
136        };
137
138        // Build the HTTP upgrade request for ws:// (no TLS step)
139        if matches!(initial_state, ConnectState::HttpSend) {
140            let path = connecting.path.clone();
141            connecting.prepare_http_request(&path);
142        }
143
144        Ok(connecting)
145    }
146}
147
148impl<S: Read + Write> Connecting<S> {
149    /// Drive the handshake forward. Non-blocking.
150    ///
151    /// Returns `Ok(None)` while in progress, `Ok(Some(ws))` when the
152    /// connection is ready and [`recv()`](Client::recv) can be called.
153    ///
154    /// Call when the socket is readable or writable (check
155    /// [`wants_read()`](Self::wants_read) / [`wants_write()`](Self::wants_write)).
156    ///
157    /// On `WouldBlock`, returns `Ok(None)` — call again when the socket
158    /// is ready.
159    pub fn poll(&mut self) -> Result<Option<Client<S>>, Error> {
160        loop {
161            match self.state {
162                #[cfg(feature = "tls")]
163                ConnectState::TlsWrite => {
164                    let tls = self
165                        .tls
166                        .as_mut()
167                        .expect("TLS codec must exist in TLS handshake state");
168                    match tls.write_tls_to(&mut *self.stream) {
169                        Ok(_) => {}
170                        Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None),
171                        Err(e) => return Err(e.into()),
172                    }
173                    if tls.is_handshaking() {
174                        self.state = ConnectState::TlsRead;
175                    } else {
176                        self.state = ConnectState::HttpSend;
177                        let path = self.path.clone();
178                        self.prepare_http_request(&path);
179                    }
180                }
181                #[cfg(feature = "tls")]
182                ConnectState::TlsRead => {
183                    let tls = self
184                        .tls
185                        .as_mut()
186                        .expect("TLS codec must exist in TLS handshake state");
187                    match tls.read_tls_from(&mut *self.stream) {
188                        Ok(0) => {
189                            // Peer closed mid-TLS-handshake — not a
190                            // malformed-HTTP condition (we haven't sent
191                            // the HTTP upgrade yet).
192                            return Err(Error::Io(io::Error::new(
193                                io::ErrorKind::UnexpectedEof,
194                                "connection closed during TLS handshake",
195                            )));
196                        }
197                        Ok(_) => {}
198                        Err(TlsError::Io(e)) if e.kind() == io::ErrorKind::WouldBlock => {
199                            return Ok(None);
200                        }
201                        Err(e) => return Err(e.into()),
202                    }
203                    if tls.wants_write() {
204                        self.state = ConnectState::TlsWrite;
205                    } else if !tls.is_handshaking() {
206                        self.state = ConnectState::HttpSend;
207                        let path = self.path.clone();
208                        self.prepare_http_request(&path);
209                    }
210                }
211                ConnectState::HttpSend => {
212                    if self.req_offset >= self.req_buf.len() {
213                        self.state = ConnectState::HttpRecv;
214                        return Ok(None);
215                    }
216
217                    #[cfg(feature = "tls")]
218                    if let Some(tls) = &mut self.tls {
219                        // TLS path: feed plaintext chunks until the
220                        // request is consumed. The HTTP upgrade is
221                        // small (always under rustls's 64 KiB plaintext
222                        // queue cap) so a single `encrypt` typically
223                        // accepts everything; the loop guards against
224                        // partial acceptance defensively.
225                        while self.req_offset < self.req_buf.len() {
226                            let data = &self.req_buf[self.req_offset..];
227                            let n = tls.encrypt(data)?;
228                            if n == 0 {
229                                break; // queue full; drain ciphertext below
230                            }
231                            self.req_offset += n;
232                        }
233                        // Flush whatever ciphertext we can
234                        match tls.write_tls_to(&mut *self.stream) {
235                            Ok(_) => {}
236                            Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None),
237                            Err(e) => return Err(e.into()),
238                        }
239                        // If TLS still has buffered ciphertext, come back later
240                        if tls.wants_write() {
241                            return Ok(None);
242                        }
243                        self.state = ConnectState::HttpRecv;
244                        return Ok(None);
245                    }
246
247                    // Plain WS path: write plaintext directly
248                    {
249                        let data = &self.req_buf[self.req_offset..];
250                        let n = match (*self.stream).write(data) {
251                            Ok(n) => n,
252                            Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None),
253                            Err(e) => return Err(e.into()),
254                        };
255                        if n == 0 {
256                            return Err(Error::Io(io::Error::new(
257                                io::ErrorKind::WriteZero,
258                                "write returned 0 during handshake",
259                            )));
260                        }
261                        self.req_offset += n;
262                        if self.req_offset >= self.req_buf.len() {
263                            self.state = ConnectState::HttpRecv;
264                        }
265                    }
266                    return Ok(None);
267                }
268                ConnectState::HttpRecv => {
269                    let mut tmp = [0u8; 4096];
270                    let n = self.read_bytes(&mut tmp)?;
271                    if n == 0 {
272                        return Ok(None);
273                    }
274
275                    self.resp_reader
276                        .read(&tmp[..n])
277                        .map_err(|_| HandshakeError::MalformedHttp)?;
278
279                    // Check if we have a complete response.
280                    // validate_upgrade borrows self immutably, so we
281                    // can't call it while resp_reader is mutably borrowed.
282                    // next() consumes the response, so we validate inline.
283                    match self.resp_reader.next() {
284                        Ok(Some(resp)) => {
285                            if resp.status != 101 {
286                                return Err(HandshakeError::UnexpectedStatus(resp.status).into());
287                            }
288                            let upgrade = resp
289                                .header("Upgrade")
290                                .ok_or(HandshakeError::MissingUpgrade)?;
291                            if !upgrade.eq_ignore_ascii_case("websocket") {
292                                return Err(HandshakeError::MissingUpgrade.into());
293                            }
294                            let conn = resp
295                                .header("Connection")
296                                .ok_or(HandshakeError::MissingConnection)?;
297                            if !conn
298                                .as_bytes()
299                                .windows(7)
300                                .any(|w| w.eq_ignore_ascii_case(b"upgrade"))
301                            {
302                                return Err(HandshakeError::MissingConnection.into());
303                            }
304                            let key_str = std::str::from_utf8(&self.ws_key)
305                                .expect("base64 output is valid ASCII");
306                            let accept = resp
307                                .header("Sec-WebSocket-Accept")
308                                .ok_or(HandshakeError::InvalidAcceptKey)?;
309                            if !handshake::validate_accept(key_str, accept) {
310                                return Err(HandshakeError::InvalidAcceptKey.into());
311                            }
312                            self.state = ConnectState::Done;
313                            // Fall through to Done
314                        }
315                        Ok(None) => return Ok(None),
316                        Err(_) => return Err(HandshakeError::MalformedHttp.into()),
317                    }
318                }
319                ConnectState::Done => {
320                    return Ok(Some(self.finish()?));
321                }
322            }
323        }
324    }
325
326    /// Whether the handshake needs to write to the socket.
327    pub fn wants_write(&self) -> bool {
328        matches!(
329            self.state,
330            ConnectState::HttpSend | if_tls!(ConnectState::TlsWrite)
331        )
332    }
333
334    /// Whether the handshake needs to read from the socket.
335    pub fn wants_read(&self) -> bool {
336        matches!(
337            self.state,
338            ConnectState::HttpRecv | if_tls!(ConnectState::TlsRead)
339        )
340    }
341
342    /// Access the underlying stream (for mio registration).
343    pub fn stream(&self) -> &S {
344        &self.stream
345    }
346
347    /// Mutable access to the underlying stream.
348    pub fn stream_mut(&mut self) -> &mut S {
349        &mut self.stream
350    }
351
352    // =========================================================================
353    // Internal
354    // =========================================================================
355
356    fn prepare_http_request(&mut self, path: &str) {
357        let key_str = std::str::from_utf8(&self.ws_key).expect("base64 output is valid ASCII");
358        let headers = [
359            ("Host", self.host.as_str()),
360            ("Upgrade", "websocket"),
361            ("Connection", "Upgrade"),
362            ("Sec-WebSocket-Key", key_str),
363            ("Sec-WebSocket-Version", "13"),
364        ];
365        let size = crate::http::request_size("GET", path, &headers);
366        let mut buf = vec![0u8; size];
367        // unwrap is safe: buffer is exactly the right size
368        let n = crate::http::write_request("GET", path, &headers, &mut buf)
369            .expect("request fits in handshake buffer");
370        self.req_buf = buf[..n].to_vec();
371        self.req_offset = 0;
372    }
373
374    fn finish(&mut self) -> Result<Client<S>, Error> {
375        self.finished = true;
376
377        let reader_builder = std::mem::replace(&mut self.reader_builder, FrameReader::builder());
378        let mut reader = reader_builder.role(Role::Client).build();
379        let remainder = self.resp_reader.remainder();
380        if !remainder.is_empty() {
381            reader
382                .read(remainder)
383                .map_err(|_| Error::Handshake(HandshakeError::MalformedHttp))?;
384        }
385
386        // SAFETY: stream is ManuallyDrop. We take ownership here.
387        // The `finished` flag prevents Drop from dropping it again.
388        // finish() is only called once (state == Done).
389        let stream = unsafe { std::mem::ManuallyDrop::take(&mut self.stream) };
390
391        Ok(Client::from_parts_internal(
392            stream,
393            reader,
394            FrameWriter::new(Role::Client),
395            WriteBuf::new(self.write_buf_capacity, self.write_buf_headroom),
396        ))
397    }
398
399    /// Read bytes through TLS or direct.
400    /// Returns Ok(n) for data, Err(WouldBlock) for non-blocking no-data,
401    /// Err(UnexpectedEof) for connection closed during handshake.
402    fn read_bytes(&mut self, dst: &mut [u8]) -> Result<usize, Error> {
403        #[cfg(feature = "tls")]
404        if let Some(tls) = &mut self.tls {
405            // Drain any plaintext rustls already has decrypted from a
406            // prior read. Skipping this and always reading more
407            // ciphertext first risks overflowing rustls's plaintext
408            // queue on bursty servers.
409            let n = tls.read_plaintext(dst).map_err(Error::Tls)?;
410            if n > 0 {
411                return Ok(n);
412            }
413            // No buffered plaintext — pull more ciphertext.
414            return match tls.read_tls_from(&mut *self.stream) {
415                Ok(0) => Err(Error::Io(io::Error::new(
416                    io::ErrorKind::UnexpectedEof,
417                    "connection closed during TLS handshake",
418                ))),
419                Ok(_) => tls.read_plaintext(dst).map_err(Error::Tls),
420                Err(TlsError::Io(e)) if e.kind() == io::ErrorKind::WouldBlock => Ok(0),
421                Err(e) => Err(e.into()),
422            };
423        }
424        match (*self.stream).read(dst) {
425            Ok(n) => Ok(n),
426            Err(e) if e.kind() == io::ErrorKind::WouldBlock => Ok(0),
427            Err(e) => Err(e.into()),
428        }
429    }
430}
431
432impl<S> Drop for Connecting<S> {
433    fn drop(&mut self) {
434        if !self.finished {
435            // finish() was never called — drop the stream manually.
436            // SAFETY: stream hasn't been taken via ManuallyDrop::take.
437            unsafe {
438                std::mem::ManuallyDrop::drop(&mut self.stream);
439            }
440        }
441        // tls is Option — dropped normally by the compiler.
442    }
443}
444
// `if_tls!(PAT)` lets a `matches!` pattern mention TLS-only `ConnectState`
// variants without breaking builds where the `tls` feature is off.
//
// With `tls` enabled it expands to `PAT` unchanged.
#[cfg(feature = "tls")]
macro_rules! if_tls {
    ($p:pat) => {
        $p
    };
}
// Without `tls` it ignores `PAT` and expands to `ConnectState::Done`.
// NOTE(review): this is NOT an inert placeholder — `Done` is a real state, so
// `X | if_tls!(Y)` also matches when the state is `Done`. Prefer
// `#[cfg(feature = "tls")]` match arms at call sites.
#[cfg(not(feature = "tls"))]
macro_rules! if_tls {
    ($p:pat) => {
        ConnectState::Done
    };
}
// Import by path so uses that appear earlier in this module resolve. The allow
// keeps builds quiet if call sites switch to cfg'd match arms instead.
#[allow(unused_imports)]
use if_tls;