nexus_web/ws/connecting.rs
1//! Non-blocking WebSocket connection handshake.
2
3use std::io::{self, Read, Write};
4
5use super::frame::Role;
6use super::frame_reader::{FrameReader, FrameReaderBuilder};
7use super::frame_writer::FrameWriter;
8use super::handshake::{self, HandshakeError};
9use super::stream::{Client, ClientBuilder, Error, parse_ws_url};
10use nexus_net::buf::WriteBuf;
11
12#[cfg(feature = "tls")]
13use nexus_net::tls::{TlsCodec, TlsError};
14
/// A WebSocket connection in the handshake phase.
///
/// Drive the handshake by calling [`poll()`](Self::poll) when the socket
/// is ready. Returns [`Client<S>`] when complete.
///
/// Check [`wants_read()`](Self::wants_read) / [`wants_write()`](Self::wants_write)
/// to determine which readiness event to wait for in your event loop.
///
/// Internally this is a small state machine:
/// 1. an optional TLS handshake (only for `wss://` URLs);
/// 2. sending the HTTP upgrade request;
/// 3. reading and validating the server's `101` response.
///
/// # Usage
///
/// ```ignore
/// use nexus_web::ws::{Connecting, ClientBuilder};
///
/// let tcp = TcpStream::connect("exchange.com:443")?;
/// tcp.set_nonblocking(true)?;
/// let mut connecting = ClientBuilder::new()
///     .begin_connect(tcp, "wss://exchange.com/ws")?;
///
/// // In your event loop:
/// loop {
///     // ... poll for socket readiness ...
///     if let Some(ws) = connecting.poll()? {
///         // Handshake complete — ws.recv() is now available
///         break;
///     }
/// }
/// ```
pub struct Connecting<S> {
    // The transport. Wrapped in `ManuallyDrop` because `finish()` moves it
    // into the `Client`. The `Drop` impl drops it only while `finished` is
    // still false.
    stream: std::mem::ManuallyDrop<S>,
    // Current handshake step (see `ConnectState`).
    state: ConnectState,
    // TLS session for `wss://` URLs; `None` for plain `ws://`.
    #[cfg(feature = "tls")]
    tls: Option<TlsCodec>,
    // Frame-reader configuration. `finish()` consumes it to build the
    // Client's reader (with `Role::Client`).
    reader_builder: FrameReaderBuilder,
    // Passed to `WriteBuf::new` when the Client is built.
    write_buf_capacity: usize,
    write_buf_headroom: usize,
    // --- Handshake data ---
    // Base64 `Sec-WebSocket-Key`. It is sent in the request and checked
    // against the response's `Sec-WebSocket-Accept` header.
    ws_key: [u8; 24],
    // Serialized HTTP upgrade request.
    req_buf: Vec<u8>,
    // Number of `req_buf` bytes already written (plain) or handed to the
    // TLS codec for encryption.
    req_offset: usize,
    // Incremental parser for the upgrade response. Any bytes left over after
    // the response are fed into the Client's frame reader by `finish()`.
    resp_reader: crate::http::ResponseReader,
    // Host from the URL; sent as the `Host` header.
    host: String,
    // Path from the URL; used as the request target of the `GET`.
    path: String,
    // Set to true by `finish()`. While true, `Drop` leaves `stream` alone.
    finished: bool,
}
61
/// Steps of the non-blocking handshake, listed in the order they are
/// normally visited. `Connecting::poll` dispatches on this value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ConnectState {
    /// TLS handshake: outgoing handshake records are waiting to be written.
    #[cfg(feature = "tls")]
    TlsWrite,
    /// TLS handshake: waiting for handshake records from the peer.
    #[cfg(feature = "tls")]
    TlsRead,
    /// Transmitting the HTTP upgrade request.
    HttpSend,
    /// Collecting and validating the HTTP upgrade response.
    HttpRecv,
    /// Upgrade accepted; `poll()` hands the stream over to a `Client` from here.
    Done,
}
77
78impl ClientBuilder {
79 /// Start a non-blocking connection handshake.
80 ///
81 /// Returns a [`Connecting`] that must be driven to completion
82 /// via [`poll()`](Connecting::poll) before messages can be sent/received.
83 ///
84 /// The caller is responsible for setting the socket to non-blocking
85 /// mode before calling this.
86 pub fn begin_connect<S: Read + Write>(
87 self,
88 stream: S,
89 url: &str,
90 ) -> Result<Connecting<S>, Error> {
91 let parsed = parse_ws_url(url)?;
92
93 #[cfg(feature = "tls")]
94 let tls = if parsed.tls {
95 let config = match self.tls_config {
96 Some(c) => c,
97 None => nexus_net::tls::TlsConfig::new().map_err(Error::Tls)?,
98 };
99 Some(TlsCodec::new(&config, parsed.host)?)
100 } else {
101 None
102 };
103
104 #[cfg(not(feature = "tls"))]
105 if parsed.tls {
106 return Err(Error::TlsNotEnabled);
107 }
108
109 let ws_key = handshake::generate_key();
110
111 #[cfg(feature = "tls")]
112 let initial_state = if tls.is_some() {
113 ConnectState::TlsWrite
114 } else {
115 ConnectState::HttpSend
116 };
117
118 #[cfg(not(feature = "tls"))]
119 let initial_state = ConnectState::HttpSend;
120
121 let mut connecting = Connecting {
122 stream: std::mem::ManuallyDrop::new(stream),
123 state: initial_state,
124 #[cfg(feature = "tls")]
125 tls,
126 reader_builder: self.reader_builder,
127 write_buf_capacity: self.write_buf_capacity,
128 write_buf_headroom: self.write_buf_headroom,
129 ws_key,
130 req_buf: Vec::new(),
131 req_offset: 0,
132 resp_reader: crate::http::ResponseReader::new(4096),
133 host: parsed.host.to_owned(),
134 path: parsed.path.to_owned(),
135 finished: false,
136 };
137
138 // Build the HTTP upgrade request for ws:// (no TLS step)
139 if matches!(initial_state, ConnectState::HttpSend) {
140 let path = connecting.path.clone();
141 connecting.prepare_http_request(&path);
142 }
143
144 Ok(connecting)
145 }
146}
147
148impl<S: Read + Write> Connecting<S> {
149 /// Drive the handshake forward. Non-blocking.
150 ///
151 /// Returns `Ok(None)` while in progress, `Ok(Some(ws))` when the
152 /// connection is ready and [`recv()`](Client::recv) can be called.
153 ///
154 /// Call when the socket is readable or writable (check
155 /// [`wants_read()`](Self::wants_read) / [`wants_write()`](Self::wants_write)).
156 ///
157 /// On `WouldBlock`, returns `Ok(None)` — call again when the socket
158 /// is ready.
159 pub fn poll(&mut self) -> Result<Option<Client<S>>, Error> {
160 loop {
161 match self.state {
162 #[cfg(feature = "tls")]
163 ConnectState::TlsWrite => {
164 let tls = self
165 .tls
166 .as_mut()
167 .expect("TLS codec must exist in TLS handshake state");
168 match tls.write_tls_to(&mut *self.stream) {
169 Ok(_) => {}
170 Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None),
171 Err(e) => return Err(e.into()),
172 }
173 if tls.is_handshaking() {
174 self.state = ConnectState::TlsRead;
175 } else {
176 self.state = ConnectState::HttpSend;
177 let path = self.path.clone();
178 self.prepare_http_request(&path);
179 }
180 }
181 #[cfg(feature = "tls")]
182 ConnectState::TlsRead => {
183 let tls = self
184 .tls
185 .as_mut()
186 .expect("TLS codec must exist in TLS handshake state");
187 match tls.read_tls_from(&mut *self.stream) {
188 Ok(0) => {
189 // Peer closed mid-TLS-handshake — not a
190 // malformed-HTTP condition (we haven't sent
191 // the HTTP upgrade yet).
192 return Err(Error::Io(io::Error::new(
193 io::ErrorKind::UnexpectedEof,
194 "connection closed during TLS handshake",
195 )));
196 }
197 Ok(_) => {}
198 Err(TlsError::Io(e)) if e.kind() == io::ErrorKind::WouldBlock => {
199 return Ok(None);
200 }
201 Err(e) => return Err(e.into()),
202 }
203 if tls.wants_write() {
204 self.state = ConnectState::TlsWrite;
205 } else if !tls.is_handshaking() {
206 self.state = ConnectState::HttpSend;
207 let path = self.path.clone();
208 self.prepare_http_request(&path);
209 }
210 }
211 ConnectState::HttpSend => {
212 if self.req_offset >= self.req_buf.len() {
213 self.state = ConnectState::HttpRecv;
214 return Ok(None);
215 }
216
217 #[cfg(feature = "tls")]
218 if let Some(tls) = &mut self.tls {
219 // TLS path: feed plaintext chunks until the
220 // request is consumed. The HTTP upgrade is
221 // small (always under rustls's 64 KiB plaintext
222 // queue cap) so a single `encrypt` typically
223 // accepts everything; the loop guards against
224 // partial acceptance defensively.
225 while self.req_offset < self.req_buf.len() {
226 let data = &self.req_buf[self.req_offset..];
227 let n = tls.encrypt(data)?;
228 if n == 0 {
229 break; // queue full; drain ciphertext below
230 }
231 self.req_offset += n;
232 }
233 // Flush whatever ciphertext we can
234 match tls.write_tls_to(&mut *self.stream) {
235 Ok(_) => {}
236 Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None),
237 Err(e) => return Err(e.into()),
238 }
239 // If TLS still has buffered ciphertext, come back later
240 if tls.wants_write() {
241 return Ok(None);
242 }
243 self.state = ConnectState::HttpRecv;
244 return Ok(None);
245 }
246
247 // Plain WS path: write plaintext directly
248 {
249 let data = &self.req_buf[self.req_offset..];
250 let n = match (*self.stream).write(data) {
251 Ok(n) => n,
252 Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None),
253 Err(e) => return Err(e.into()),
254 };
255 if n == 0 {
256 return Err(Error::Io(io::Error::new(
257 io::ErrorKind::WriteZero,
258 "write returned 0 during handshake",
259 )));
260 }
261 self.req_offset += n;
262 if self.req_offset >= self.req_buf.len() {
263 self.state = ConnectState::HttpRecv;
264 }
265 }
266 return Ok(None);
267 }
268 ConnectState::HttpRecv => {
269 let mut tmp = [0u8; 4096];
270 let n = self.read_bytes(&mut tmp)?;
271 if n == 0 {
272 return Ok(None);
273 }
274
275 self.resp_reader
276 .read(&tmp[..n])
277 .map_err(|_| HandshakeError::MalformedHttp)?;
278
279 // Check if we have a complete response.
280 // validate_upgrade borrows self immutably, so we
281 // can't call it while resp_reader is mutably borrowed.
282 // next() consumes the response, so we validate inline.
283 match self.resp_reader.next() {
284 Ok(Some(resp)) => {
285 if resp.status != 101 {
286 return Err(HandshakeError::UnexpectedStatus(resp.status).into());
287 }
288 let upgrade = resp
289 .header("Upgrade")
290 .ok_or(HandshakeError::MissingUpgrade)?;
291 if !upgrade.eq_ignore_ascii_case("websocket") {
292 return Err(HandshakeError::MissingUpgrade.into());
293 }
294 let conn = resp
295 .header("Connection")
296 .ok_or(HandshakeError::MissingConnection)?;
297 if !conn
298 .as_bytes()
299 .windows(7)
300 .any(|w| w.eq_ignore_ascii_case(b"upgrade"))
301 {
302 return Err(HandshakeError::MissingConnection.into());
303 }
304 let key_str = std::str::from_utf8(&self.ws_key)
305 .expect("base64 output is valid ASCII");
306 let accept = resp
307 .header("Sec-WebSocket-Accept")
308 .ok_or(HandshakeError::InvalidAcceptKey)?;
309 if !handshake::validate_accept(key_str, accept) {
310 return Err(HandshakeError::InvalidAcceptKey.into());
311 }
312 self.state = ConnectState::Done;
313 // Fall through to Done
314 }
315 Ok(None) => return Ok(None),
316 Err(_) => return Err(HandshakeError::MalformedHttp.into()),
317 }
318 }
319 ConnectState::Done => {
320 return Ok(Some(self.finish()?));
321 }
322 }
323 }
324 }
325
326 /// Whether the handshake needs to write to the socket.
327 pub fn wants_write(&self) -> bool {
328 matches!(
329 self.state,
330 ConnectState::HttpSend | if_tls!(ConnectState::TlsWrite)
331 )
332 }
333
334 /// Whether the handshake needs to read from the socket.
335 pub fn wants_read(&self) -> bool {
336 matches!(
337 self.state,
338 ConnectState::HttpRecv | if_tls!(ConnectState::TlsRead)
339 )
340 }
341
342 /// Access the underlying stream (for mio registration).
343 pub fn stream(&self) -> &S {
344 &self.stream
345 }
346
347 /// Mutable access to the underlying stream.
348 pub fn stream_mut(&mut self) -> &mut S {
349 &mut self.stream
350 }
351
352 // =========================================================================
353 // Internal
354 // =========================================================================
355
356 fn prepare_http_request(&mut self, path: &str) {
357 let key_str = std::str::from_utf8(&self.ws_key).expect("base64 output is valid ASCII");
358 let headers = [
359 ("Host", self.host.as_str()),
360 ("Upgrade", "websocket"),
361 ("Connection", "Upgrade"),
362 ("Sec-WebSocket-Key", key_str),
363 ("Sec-WebSocket-Version", "13"),
364 ];
365 let size = crate::http::request_size("GET", path, &headers);
366 let mut buf = vec![0u8; size];
367 // unwrap is safe: buffer is exactly the right size
368 let n = crate::http::write_request("GET", path, &headers, &mut buf)
369 .expect("request fits in handshake buffer");
370 self.req_buf = buf[..n].to_vec();
371 self.req_offset = 0;
372 }
373
374 fn finish(&mut self) -> Result<Client<S>, Error> {
375 self.finished = true;
376
377 let reader_builder = std::mem::replace(&mut self.reader_builder, FrameReader::builder());
378 let mut reader = reader_builder.role(Role::Client).build();
379 let remainder = self.resp_reader.remainder();
380 if !remainder.is_empty() {
381 reader
382 .read(remainder)
383 .map_err(|_| Error::Handshake(HandshakeError::MalformedHttp))?;
384 }
385
386 // SAFETY: stream is ManuallyDrop. We take ownership here.
387 // The `finished` flag prevents Drop from dropping it again.
388 // finish() is only called once (state == Done).
389 let stream = unsafe { std::mem::ManuallyDrop::take(&mut self.stream) };
390
391 Ok(Client::from_parts_internal(
392 stream,
393 reader,
394 FrameWriter::new(Role::Client),
395 WriteBuf::new(self.write_buf_capacity, self.write_buf_headroom),
396 ))
397 }
398
399 /// Read bytes through TLS or direct.
400 /// Returns Ok(n) for data, Err(WouldBlock) for non-blocking no-data,
401 /// Err(UnexpectedEof) for connection closed during handshake.
402 fn read_bytes(&mut self, dst: &mut [u8]) -> Result<usize, Error> {
403 #[cfg(feature = "tls")]
404 if let Some(tls) = &mut self.tls {
405 // Drain any plaintext rustls already has decrypted from a
406 // prior read. Skipping this and always reading more
407 // ciphertext first risks overflowing rustls's plaintext
408 // queue on bursty servers.
409 let n = tls.read_plaintext(dst).map_err(Error::Tls)?;
410 if n > 0 {
411 return Ok(n);
412 }
413 // No buffered plaintext — pull more ciphertext.
414 return match tls.read_tls_from(&mut *self.stream) {
415 Ok(0) => Err(Error::Io(io::Error::new(
416 io::ErrorKind::UnexpectedEof,
417 "connection closed during TLS handshake",
418 ))),
419 Ok(_) => tls.read_plaintext(dst).map_err(Error::Tls),
420 Err(TlsError::Io(e)) if e.kind() == io::ErrorKind::WouldBlock => Ok(0),
421 Err(e) => Err(e.into()),
422 };
423 }
424 match (*self.stream).read(dst) {
425 Ok(n) => Ok(n),
426 Err(e) if e.kind() == io::ErrorKind::WouldBlock => Ok(0),
427 Err(e) => Err(e.into()),
428 }
429 }
430}
431
432impl<S> Drop for Connecting<S> {
433 fn drop(&mut self) {
434 if !self.finished {
435 // finish() was never called — drop the stream manually.
436 // SAFETY: stream hasn't been taken via ManuallyDrop::take.
437 unsafe {
438 std::mem::ManuallyDrop::drop(&mut self.stream);
439 }
440 }
441 // tls is Option — dropped normally by the compiler.
442 }
443}
444
// `if_tls!(pattern)` lets `matches!` alternatives name the TLS-only
// `ConnectState` variants.
//
// With the `tls` feature, the pattern is passed through unchanged.
#[cfg(feature = "tls")]
macro_rules! if_tls {
    ($p:pat) => {
        $p
    };
}

// Without the `tls` feature those variants do not exist. The argument is
// discarded and `ConnectState::Done` is substituted in its place. This is NOT
// a never-matching placeholder: a `matches!` that uses it will also match
// `Done` in non-TLS builds, so callers must exclude `Done` themselves if that
// matters.
#[cfg(not(feature = "tls"))]
macro_rules! if_tls {
    ($p:pat) => {
        ConnectState::Done
    };
}

use if_tls;