Skip to main content

nexus_web/ws/
stream.rs

1//! WebSocket stream — I/O wrapper with HTTP upgrade handshake.
2
3use std::time::Duration;
4
5use super::error::ProtocolError;
6use super::frame::Role;
7use super::frame_reader::{FrameReader, FrameReaderBuilder};
8use super::frame_writer::FrameWriter;
9use super::message::{CloseCode, Message};
10use nexus_net::buf::WriteBuf;
11
12use super::handshake;
13use super::handshake::HandshakeError;
14use std::io::{self, Read, Write};
15
16#[cfg(feature = "tls")]
17use nexus_net::tls::{TlsConfig, TlsError};
18
19// =============================================================================
20// URL parsing
21// =============================================================================
22
/// Parsed WebSocket URL, borrowing `host` and `path` from the input string.
///
/// Produced by [`parse_ws_url`].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParsedUrl<'a> {
    /// Whether the URL is `wss://` (true) or `ws://` (false).
    pub tls: bool,
    /// Host portion (no port). IPv6 literals are stored without their
    /// surrounding brackets (`::1`, not `[::1]`).
    pub host: &'a str,
    /// Port — explicit if present, otherwise the scheme default
    /// (80 for ws, 443 for wss).
    pub port: u16,
    /// Path portion (everything after the host:port, including the
    /// leading `/`). Defaults to `/` when absent in the input URL.
    pub path: &'a str,
}

impl ParsedUrl<'_> {
    /// Value for the HTTP `Host` header: the host, plus `:port` when the
    /// port differs from the scheme default (80 for ws, 443 for wss).
    ///
    /// IPv6 literals (any host containing `:`) are re-wrapped in brackets,
    /// as the `Host` grammar requires (RFC 3986 §3.2.2 `IP-literal`).
    /// Without this, `::1` on port 8080 would render as the ambiguous
    /// `::1:8080`.
    pub fn host_header(&self) -> String {
        let default_port = if self.tls { 443 } else { 80 };
        let is_ipv6 = self.host.contains(':');
        match (is_ipv6, self.port == default_port) {
            (false, true) => self.host.to_string(),
            (false, false) => format!("{}:{}", self.host, self.port),
            (true, true) => format!("[{}]", self.host),
            (true, false) => format!("[{}]:{}", self.host, self.port),
        }
    }
}
49
50/// Parse a `ws://` or `wss://` URL into its scheme, host, port, and
51/// path. Supports IPv6 bracket notation (`[::1]:8080`). Returns
52/// [`Error::InvalidUrl`] on a malformed input or missing scheme.
53pub fn parse_ws_url(url: &str) -> Result<ParsedUrl<'_>, Error> {
54    let (tls, rest) = if let Some(r) = url.strip_prefix("wss://") {
55        (true, r)
56    } else if let Some(r) = url.strip_prefix("ws://") {
57        (false, r)
58    } else {
59        return Err(Error::InvalidUrl(url.to_string()));
60    };
61
62    let (host_port, path) = rest
63        .find('/')
64        .map_or((rest, "/"), |i| (&rest[..i], &rest[i..]));
65
66    if host_port.is_empty() {
67        return Err(Error::InvalidUrl(format!("empty host: {url}")));
68    }
69
70    let default_port = if tls { 443 } else { 80 };
71
72    // IPv6 bracket notation: [::1]:8080
73    let (host, port) = if host_port.starts_with('[') {
74        match host_port.find(']') {
75            Some(end) => {
76                let h = &host_port[1..end];
77                let rest = &host_port[end + 1..];
78                if let Some(port_str) = rest.strip_prefix(':') {
79                    let p = port_str
80                        .parse::<u16>()
81                        .map_err(|_| Error::InvalidUrl(format!("invalid port: {url}")))?;
82                    (h, p)
83                } else {
84                    (h, default_port)
85                }
86            }
87            None => return Err(Error::InvalidUrl(format!("unclosed bracket: {url}"))),
88        }
89    } else {
90        match host_port.rfind(':') {
91            None => (host_port, default_port),
92            Some(i) => {
93                let port_str = &host_port[i + 1..];
94                if port_str.is_empty() {
95                    (&host_port[..i], default_port)
96                } else {
97                    let p = port_str
98                        .parse::<u16>()
99                        .map_err(|_| Error::InvalidUrl(format!("invalid port: {url}")))?;
100                    (&host_port[..i], p)
101                }
102            }
103        }
104    };
105
106    Ok(ParsedUrl {
107        tls,
108        host,
109        port,
110        path,
111    })
112}
113
114// =============================================================================
115// Error
116// =============================================================================
117
118/// Unified error type for WebSocket stream operations.
119#[derive(Debug)]
120pub enum Error {
121    /// I/O error from the underlying stream.
122    Io(std::io::Error),
123    /// WebSocket protocol error.
124    Protocol(ProtocolError),
125    /// Encoding error (e.g., control frame payload too large).
126    Encode(super::frame_writer::EncodeError),
127    /// HTTP handshake failed.
128    Handshake(HandshakeError),
129    /// TLS error during connection setup (handshake, certificate
130    /// validation, SNI hostname verification).
131    ///
132    /// **Steady-state TLS protocol errors** (decrypt failure, peer
133    /// alert, malformed record received during a frame) on the async
134    /// `nexus-async-web` paths surface as [`Error::Io`](Self::Io)
135    /// instead — the underlying [`TlsError`](nexus_net::tls::TlsError) is
136    /// wrapped via `io::Error::other` and reachable via
137    /// `io_err.source()` or `io_err.get_ref()`. This asymmetry stems
138    /// from the [`WireStream`](crate::WireStream) trait returning
139    /// `io::Result` for poll methods. Sync WS surfaces `Tls` directly
140    /// because its `TlsStream` exposes `TlsError` natively. Pattern-
141    /// match on both `Io` and `Tls` if you need to distinguish TLS-
142    /// protocol failures from generic transport failures across both
143    /// surfaces.
144    #[cfg(feature = "tls")]
145    Tls(TlsError),
146    /// Invalid WebSocket URL.
147    InvalidUrl(String),
148    /// `wss://` URL used without the `tls` feature enabled.
149    TlsNotEnabled,
150}
151
152impl std::fmt::Display for Error {
153    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154        match self {
155            Self::Io(e) => write!(f, "I/O error: {e}"),
156            Self::Protocol(e) => write!(f, "protocol error: {e}"),
157            Self::Encode(e) => write!(f, "encode error: {e}"),
158            Self::Handshake(e) => write!(f, "handshake error: {e}"),
159            #[cfg(feature = "tls")]
160            Self::Tls(e) => write!(f, "TLS error: {e}"),
161            Self::InvalidUrl(u) => write!(f, "invalid WebSocket URL: {u}"),
162            Self::TlsNotEnabled => write!(f, "wss:// requires the 'tls' feature"),
163        }
164    }
165}
166
167impl std::error::Error for Error {}
168
169impl From<std::io::Error> for Error {
170    fn from(e: std::io::Error) -> Self {
171        Self::Io(e)
172    }
173}
174impl From<ProtocolError> for Error {
175    fn from(e: ProtocolError) -> Self {
176        Self::Protocol(e)
177    }
178}
179impl From<super::frame_writer::EncodeError> for Error {
180    fn from(e: super::frame_writer::EncodeError) -> Self {
181        Self::Encode(e)
182    }
183}
184impl From<HandshakeError> for Error {
185    fn from(e: HandshakeError) -> Self {
186        Self::Handshake(e)
187    }
188}
189#[cfg(feature = "tls")]
190impl From<TlsError> for Error {
191    fn from(e: TlsError) -> Self {
192        match e {
193            TlsError::Io(io) => Self::Io(io),
194            other => Self::Tls(other),
195        }
196    }
197}
198
199// =============================================================================
200// ClientBuilder
201// =============================================================================
202
203/// Builder for [`Client`].
204///
205/// Configures buffer sizes, socket options, and optional TLS.
206///
207/// # Examples
208///
209/// ```ignore
210/// let mut ws = Client::builder()
211///     .disable_nagle()
212///     .buffer_capacity(2 * 1024 * 1024)
213///     .connect("wss://exchange.com/ws")?;
214/// ```
215pub struct ClientBuilder {
216    pub(crate) reader_builder: FrameReaderBuilder,
217    pub(crate) write_buf_capacity: usize,
218    pub(crate) write_buf_headroom: usize,
219    #[cfg(feature = "tls")]
220    pub(crate) tls_config: Option<TlsConfig>,
221    pub(crate) tcp_nodelay: bool,
222    #[cfg(feature = "socket-opts")]
223    pub(crate) recv_buf_size: Option<usize>,
224    #[cfg(feature = "socket-opts")]
225    pub(crate) send_buf_size: Option<usize>,
226    pub(crate) connect_timeout: Option<Duration>,
227    pub(crate) read_timeout: Option<Duration>,
228}
229
230impl ClientBuilder {
231    /// Create a new builder with defaults.
232    #[must_use]
233    pub fn new() -> Self {
234        Self {
235            reader_builder: FrameReader::builder(),
236            write_buf_capacity: 65_536,
237            write_buf_headroom: 14,
238            #[cfg(feature = "tls")]
239            tls_config: None,
240            tcp_nodelay: false,
241            #[cfg(feature = "socket-opts")]
242            recv_buf_size: None,
243            #[cfg(feature = "socket-opts")]
244            send_buf_size: None,
245            connect_timeout: None,
246            read_timeout: None,
247        }
248    }
249
250    /// ReadBuf capacity. Default: 1MB.
251    #[must_use]
252    pub fn buffer_capacity(mut self, n: usize) -> Self {
253        self.reader_builder = self.reader_builder.buffer_capacity(n);
254        self
255    }
256
257    /// Maximum single frame payload. Default: 16MB.
258    #[must_use]
259    pub fn max_frame_size(mut self, n: u64) -> Self {
260        self.reader_builder = self.reader_builder.max_frame_size(n);
261        self
262    }
263
264    /// Maximum assembled message size. Default: 16MB.
265    #[must_use]
266    pub fn max_message_size(mut self, n: usize) -> Self {
267        self.reader_builder = self.reader_builder.max_message_size(n);
268        self
269    }
270
271    /// Write buffer capacity. Default: 64KB.
272    #[must_use]
273    pub fn write_buffer_capacity(mut self, n: usize) -> Self {
274        self.write_buf_capacity = n;
275        self
276    }
277
278    /// Set `TCP_NODELAY` (disable Nagle's algorithm).
279    #[must_use]
280    pub fn disable_nagle(mut self) -> Self {
281        self.tcp_nodelay = true;
282        self
283    }
284
285    /// Set `SO_RCVBUF` (socket receive buffer size).
286    #[cfg(feature = "socket-opts")]
287    #[must_use]
288    pub fn recv_buffer_size(mut self, n: usize) -> Self {
289        self.recv_buf_size = Some(n);
290        self
291    }
292
293    /// Set `SO_SNDBUF` (socket send buffer size).
294    #[cfg(feature = "socket-opts")]
295    #[must_use]
296    pub fn send_buffer_size(mut self, n: usize) -> Self {
297        self.send_buf_size = Some(n);
298        self
299    }
300
301    /// TCP connect timeout.
302    #[must_use]
303    pub fn connect_timeout(mut self, d: Duration) -> Self {
304        self.connect_timeout = Some(d);
305        self
306    }
307
308    /// Socket read timeout.
309    #[must_use]
310    pub fn read_timeout(mut self, d: Duration) -> Self {
311        self.read_timeout = Some(d);
312        self
313    }
314
315    /// Set a custom TLS configuration.
316    ///
317    /// If not set, `wss://` URLs use [`TlsConfig::new()`] (system defaults).
318    #[cfg(feature = "tls")]
319    #[must_use]
320    pub fn tls(mut self, config: &TlsConfig) -> Self {
321        self.tls_config = Some(config.clone());
322        self
323    }
324
325    /// Connect to a WebSocket server (blocking).
326    ///
327    /// Creates a TCP socket, applies socket options, and performs the
328    /// full handshake (TLS if `wss://`, then HTTP upgrade).
329    ///
330    /// When the `tls` feature is enabled, returns `Client<MaybeTls<TcpStream>>`
331    /// regardless of scheme — `ws://` uses `MaybeTls::Plain`, `wss://` uses
332    /// `MaybeTls::Tls`. Without the `tls` feature, returns `Client<TcpStream>`
333    /// and errors on `wss://`.
334    #[cfg(feature = "tls")]
335    pub fn connect(
336        self,
337        url: &str,
338    ) -> Result<Client<nexus_net::MaybeTls<std::net::TcpStream>>, Error> {
339        let parsed = parse_ws_url(url)?;
340        let addr = format!("{}:{}", parsed.host, parsed.port);
341
342        let tcp = match self.connect_timeout {
343            Some(timeout) => {
344                let addrs: Vec<std::net::SocketAddr> =
345                    std::net::ToSocketAddrs::to_socket_addrs(&addr)
346                        .map_err(Error::Io)?
347                        .collect();
348                let first = addrs
349                    .first()
350                    .ok_or_else(|| Error::Io(io::Error::other("DNS resolution failed")))?;
351                std::net::TcpStream::connect_timeout(first, timeout)?
352            }
353            None => std::net::TcpStream::connect(&addr)?,
354        };
355
356        self.apply_socket_opts(&tcp)?;
357
358        let stream = if parsed.tls {
359            let config = match self.tls_config {
360                Some(c) => c,
361                None => TlsConfig::new().map_err(Error::Tls)?,
362            };
363            let codec = nexus_net::tls::TlsCodec::new(&config, parsed.host)?;
364            let tls = nexus_net::tls::TlsStream::connect(tcp, codec).map_err(Error::Tls)?;
365            nexus_net::MaybeTls::Tls(Box::new(tls))
366        } else {
367            nexus_net::MaybeTls::Plain(tcp)
368        };
369
370        let host_header = parsed.host_header();
371        Client::connect_impl(
372            stream,
373            &host_header,
374            parsed.path,
375            self.reader_builder,
376            self.write_buf_capacity,
377            self.write_buf_headroom,
378        )
379    }
380
381    /// Connect to a WebSocket server (blocking, no TLS feature).
382    #[cfg(not(feature = "tls"))]
383    pub fn connect(self, url: &str) -> Result<Client<std::net::TcpStream>, Error> {
384        let parsed = parse_ws_url(url)?;
385        if parsed.tls {
386            return Err(Error::TlsNotEnabled);
387        }
388        let addr = format!("{}:{}", parsed.host, parsed.port);
389
390        let tcp = match self.connect_timeout {
391            Some(timeout) => {
392                let addrs: Vec<std::net::SocketAddr> =
393                    std::net::ToSocketAddrs::to_socket_addrs(&addr)
394                        .map_err(Error::Io)?
395                        .collect();
396                let first = addrs
397                    .first()
398                    .ok_or_else(|| Error::Io(io::Error::other("DNS resolution failed")))?;
399                std::net::TcpStream::connect_timeout(first, timeout)?
400            }
401            None => std::net::TcpStream::connect(&addr)?,
402        };
403
404        self.apply_socket_opts(&tcp)?;
405
406        let host_header = parsed.host_header();
407        Client::connect_impl(
408            tcp,
409            &host_header,
410            parsed.path,
411            self.reader_builder,
412            self.write_buf_capacity,
413            self.write_buf_headroom,
414        )
415    }
416
417    /// Connect using a pre-connected stream.
418    ///
419    /// The stream must already handle TLS if connecting to `wss://`.
420    /// For example, pass a `TlsStream<TcpStream>` or `MaybeTls<TcpStream>`.
421    /// This method only performs the HTTP upgrade handshake.
422    pub fn connect_with<S: Read + Write>(self, stream: S, url: &str) -> Result<Client<S>, Error> {
423        let parsed = parse_ws_url(url)?;
424        let host_header = parsed.host_header();
425        Client::connect_impl(
426            stream,
427            &host_header,
428            parsed.path,
429            self.reader_builder,
430            self.write_buf_capacity,
431            self.write_buf_headroom,
432        )
433    }
434
435    /// Accept an incoming WebSocket connection (server-side).
436    pub fn accept<S: Read + Write>(self, stream: S) -> Result<Client<S>, Error> {
437        Client::accept_impl(
438            stream,
439            self.reader_builder,
440            self.write_buf_capacity,
441            self.write_buf_headroom,
442        )
443    }
444
445    fn apply_socket_opts(&self, tcp: &std::net::TcpStream) -> Result<(), Error> {
446        if self.tcp_nodelay {
447            tcp.set_nodelay(true)?;
448        }
449        if let Some(timeout) = self.read_timeout {
450            tcp.set_read_timeout(Some(timeout))?;
451        }
452        #[cfg(feature = "socket-opts")]
453        {
454            let sock = socket2::SockRef::from(tcp);
455            if let Some(size) = self.recv_buf_size {
456                sock.set_recv_buffer_size(size)?;
457            }
458            if let Some(size) = self.send_buf_size {
459                sock.set_send_buffer_size(size)?;
460            }
461        }
462        Ok(())
463    }
464}
465
466impl Default for ClientBuilder {
467    fn default() -> Self {
468        Self::new()
469    }
470}
471
472// =============================================================================
473// Client
474// =============================================================================
475
476/// WebSocket stream — owns a socket, reader, writer, and buffers.
477///
478/// Handles both plain `ws://` and encrypted `wss://` connections.
479/// The URL scheme determines whether TLS is used — no separate type needed.
480///
481/// # Usage
482///
483/// ```ignore
484/// use nexus_web::ws::Client;
485/// use nexus_web::tls::TlsConfig;
486///
487/// // Plain WebSocket
488/// let mut ws = Client::builder().connect("ws://localhost:8080/ws")?;
489///
490/// // TLS WebSocket (requires 'tls' feature)
491/// let tls = TlsConfig::new()?;
492/// let mut ws = Client::builder().tls(&tls).connect("wss://exchange.com/ws")?;
493///
494/// // Same API for both:
495/// ws.send_text("Hello!")?;
496/// while let Some(msg) = ws.recv()? {
497///     // ...
498/// }
499/// ```
500pub struct Client<S> {
501    pub(crate) stream: S,
502    pub(crate) reader: FrameReader,
503    pub(crate) writer: FrameWriter,
504    pub(crate) write_buf: WriteBuf,
505    pub(crate) poisoned: bool,
506}
507
508impl Client<std::net::TcpStream> {
509    /// Create a builder for configuring buffer sizes, socket options, and TLS.
510    #[must_use]
511    pub fn builder() -> ClientBuilder {
512        ClientBuilder::new()
513    }
514}
515
516// -- Unbounded impl: accessors and constructors that need no I/O traits -------
517
518impl<S> Client<S> {
519    /// Create from pre-existing parts. For testing or custom handshakes.
520    pub fn from_parts(stream: S, reader: FrameReader, writer: FrameWriter) -> Self {
521        Self {
522            stream,
523            reader,
524            writer,
525            write_buf: WriteBuf::new(65_536, 14),
526            poisoned: false,
527        }
528    }
529
530    /// Internal constructor with all fields. Used by Connecting::finish().
531    pub(crate) fn from_parts_internal(
532        stream: S,
533        reader: FrameReader,
534        writer: FrameWriter,
535        write_buf: WriteBuf,
536    ) -> Self {
537        Self {
538            stream,
539            reader,
540            writer,
541            write_buf,
542            poisoned: false,
543        }
544    }
545
546    /// Whether the stream is poisoned (I/O error occurred during send).
547    ///
548    /// A poisoned stream should not be reused — the connection may be
549    /// in an indeterminate state (partial frame written).
550    pub fn is_poisoned(&self) -> bool {
551        self.poisoned
552    }
553
554    /// Access the underlying stream.
555    pub fn stream(&self) -> &S {
556        &self.stream
557    }
558
559    /// Mutable access to the underlying stream.
560    pub fn stream_mut(&mut self) -> &mut S {
561        &mut self.stream
562    }
563
564    /// Access the FrameReader.
565    pub fn reader(&self) -> &FrameReader {
566        &self.reader
567    }
568
569    /// Access the FrameWriter.
570    pub fn frame_writer(&self) -> &FrameWriter {
571        &self.writer
572    }
573}
574
575// -- Blocking I/O impl --------------------------------------------------------
576
577impl<S: Read + Write> Client<S> {
578    /// Connect using a pre-connected socket with default configuration.
579    ///
580    /// IPv6 addresses must use bracket notation: `ws://[::1]:8080/path`.
581    pub fn connect_with(stream: S, url: &str) -> Result<Self, Error> {
582        ClientBuilder::new().connect_with(stream, url)
583    }
584
585    /// Accept an incoming WebSocket connection (server-side).
586    pub fn accept(stream: S) -> Result<Self, Error> {
587        ClientBuilder::new().accept(stream)
588    }
589
590    /// Receive the next message. Reads from the socket as needed.
591    ///
592    /// Returns `Ok(None)` on EOF, buffer full, or `WouldBlock` (non-blocking sockets).
593    pub fn recv(&mut self) -> Result<Option<Message<'_>>, Error> {
594        loop {
595            if self.reader.poll()? {
596                return Ok(self.reader.next()?);
597            }
598            match self.read_into_reader() {
599                Ok(0) => return Ok(None),
600                Ok(_) => {}
601                Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None),
602                Err(e) => return Err(Error::Io(e)),
603            }
604        }
605    }
606
607    /// Send a text message.
608    pub fn send_text(&mut self, text: &str) -> Result<(), Error> {
609        self.writer
610            .encode_text_into(text.as_bytes(), &mut self.write_buf);
611        self.flush_write_buf_or_poison()
612    }
613
614    /// Send a binary message.
615    pub fn send_binary(&mut self, data: &[u8]) -> Result<(), Error> {
616        self.writer.encode_binary_into(data, &mut self.write_buf);
617        self.flush_write_buf_or_poison()
618    }
619
620    /// Send a ping.
621    pub fn send_ping(&mut self, data: &[u8]) -> Result<(), Error> {
622        self.writer
623            .encode_ping_into(data, &mut self.write_buf)
624            .map_err(Error::Encode)?;
625        self.flush_write_buf_or_poison()
626    }
627
628    /// Send a pong.
629    pub fn send_pong(&mut self, data: &[u8]) -> Result<(), Error> {
630        self.writer
631            .encode_pong_into(data, &mut self.write_buf)
632            .map_err(Error::Encode)?;
633        self.flush_write_buf_or_poison()
634    }
635
636    /// Initiate close handshake.
637    pub fn close(&mut self, code: CloseCode, reason: &str) -> Result<(), Error> {
638        if code == CloseCode::NoStatus {
639            let mut dst = [0u8; 14];
640            let n = self.writer.encode_empty_close(&mut dst);
641            self.write_raw(&dst[..n]).inspect_err(|_| {
642                self.poisoned = true;
643            })
644        } else {
645            self.writer
646                .encode_close_into(code.as_u16(), reason.as_bytes(), &mut self.write_buf)
647                .map_err(Error::Encode)?;
648            self.flush_write_buf_or_poison()
649        }
650    }
651
652    // =========================================================================
653    // Internal — read/write with optional TLS
654    // =========================================================================
655
656    /// Read bytes into the FrameReader.
657    ///
658    /// TLS is now handled at the stream level (`TlsStream<S>` or
659    /// `MaybeTls<S>`), so this always reads plaintext from `S`.
660    fn read_into_reader(&mut self) -> io::Result<usize> {
661        self.reader.read_from(&mut self.stream)
662    }
663
664    /// Flush write_buf, poisoning on I/O error.
665    fn flush_write_buf_or_poison(&mut self) -> Result<(), Error> {
666        self.flush_write_buf().inspect_err(|_| {
667            self.poisoned = true;
668        })
669    }
670
671    /// Flush the write_buf to the socket.
672    fn flush_write_buf(&mut self) -> Result<(), Error> {
673        self.stream.write_all(self.write_buf.data())?;
674        Ok(())
675    }
676
677    /// Write raw bytes to the socket.
678    fn write_raw(&mut self, data: &[u8]) -> Result<(), Error> {
679        self.stream.write_all(data)?;
680        Ok(())
681    }
682
683    // =========================================================================
684    // Internal — handshake
685    // =========================================================================
686
687    /// Perform the HTTP upgrade handshake on a stream that is already
688    /// plaintext-ready (TLS handled at the stream level).
689    pub(crate) fn connect_impl(
690        mut stream: S,
691        host: &str,
692        path: &str,
693        reader_builder: FrameReaderBuilder,
694        write_cap: usize,
695        write_headroom: usize,
696    ) -> Result<Self, Error> {
697        let key = handshake::generate_key();
698        let key_str = std::str::from_utf8(&key).expect("base64 output is valid ASCII");
699
700        let headers = [
701            ("Host", host),
702            ("Upgrade", "websocket"),
703            ("Connection", "Upgrade"),
704            ("Sec-WebSocket-Key", key_str),
705            ("Sec-WebSocket-Version", "13"),
706        ];
707        let req_size = crate::http::request_size("GET", path, &headers);
708        let mut req_buf = vec![0u8; req_size];
709        let n = crate::http::write_request("GET", path, &headers, &mut req_buf)
710            .map_err(|_| HandshakeError::MalformedHttp)?;
711
712        stream.write_all(&req_buf[..n])?;
713
714        let mut resp_reader = crate::http::ResponseReader::new(4096);
715        let mut tmp = [0u8; 4096];
716        loop {
717            let bytes_read = stream.read(&mut tmp)?;
718            if bytes_read == 0 {
719                return Err(HandshakeError::MalformedHttp.into());
720            }
721
722            resp_reader
723                .read(&tmp[..bytes_read])
724                .map_err(|_| HandshakeError::MalformedHttp)?;
725            match resp_reader.next() {
726                Ok(Some(resp)) => {
727                    if resp.status != 101 {
728                        return Err(HandshakeError::UnexpectedStatus(resp.status).into());
729                    }
730                    let upgrade = resp
731                        .header("Upgrade")
732                        .ok_or(HandshakeError::MissingUpgrade)?;
733                    if !upgrade.eq_ignore_ascii_case("websocket") {
734                        return Err(HandshakeError::MissingUpgrade.into());
735                    }
736                    let conn = resp
737                        .header("Connection")
738                        .ok_or(HandshakeError::MissingConnection)?;
739                    if !contains_ignore_case(conn, "upgrade") {
740                        return Err(HandshakeError::MissingConnection.into());
741                    }
742                    let accept = resp
743                        .header("Sec-WebSocket-Accept")
744                        .ok_or(HandshakeError::InvalidAcceptKey)?;
745                    if !handshake::validate_accept(key_str, accept) {
746                        return Err(HandshakeError::InvalidAcceptKey.into());
747                    }
748
749                    let mut reader = reader_builder.role(Role::Client).build();
750                    let remainder = resp_reader.remainder();
751                    if !remainder.is_empty() {
752                        reader
753                            .read(remainder)
754                            .map_err(|_| HandshakeError::MalformedHttp)?;
755                    }
756
757                    return Ok(Self {
758                        stream,
759                        reader,
760                        writer: FrameWriter::new(Role::Client),
761                        write_buf: WriteBuf::new(write_cap, write_headroom),
762                        poisoned: false,
763                    });
764                }
765                Ok(None) => {} // need more bytes
766                Err(_) => return Err(HandshakeError::MalformedHttp.into()),
767            }
768        }
769    }
770
771    fn accept_impl(
772        mut stream: S,
773        reader_builder: FrameReaderBuilder,
774        write_cap: usize,
775        write_headroom: usize,
776    ) -> Result<Self, Error> {
777        let mut req_reader = crate::http::RequestReader::new(4096);
778        let mut tmp = [0u8; 4096];
779
780        let ws_key;
781        loop {
782            let n = stream.read(&mut tmp)?;
783            if n == 0 {
784                return Err(HandshakeError::MalformedHttp.into());
785            }
786            req_reader
787                .read(&tmp[..n])
788                .map_err(|_| HandshakeError::MalformedHttp)?;
789            match req_reader.next() {
790                Ok(Some(req)) => {
791                    if req.method != "GET" {
792                        return Err(HandshakeError::MalformedHttp.into());
793                    }
794                    let upgrade = req
795                        .header("Upgrade")
796                        .ok_or(HandshakeError::MissingUpgrade)?;
797                    if !upgrade.eq_ignore_ascii_case("websocket") {
798                        return Err(HandshakeError::MissingUpgrade.into());
799                    }
800                    let conn = req
801                        .header("Connection")
802                        .ok_or(HandshakeError::MissingConnection)?;
803                    if !contains_ignore_case(conn, "upgrade") {
804                        return Err(HandshakeError::MissingConnection.into());
805                    }
806                    let version = req
807                        .header("Sec-WebSocket-Version")
808                        .ok_or(HandshakeError::UnsupportedVersion)?;
809                    if version != "13" {
810                        return Err(HandshakeError::UnsupportedVersion.into());
811                    }
812                    let key = req
813                        .header("Sec-WebSocket-Key")
814                        .ok_or(HandshakeError::MissingKey)?;
815                    ws_key = key.to_owned();
816                    break;
817                }
818                Ok(None) => {}
819                Err(_) => return Err(HandshakeError::MalformedHttp.into()),
820            }
821        }
822
823        let accept = handshake::compute_accept_key(&ws_key);
824        let accept_str = std::str::from_utf8(&accept).expect("base64 output is valid ASCII");
825
826        let resp_headers = [
827            ("Upgrade", "websocket"),
828            ("Connection", "Upgrade"),
829            ("Sec-WebSocket-Accept", accept_str),
830        ];
831        let resp_size = crate::http::response_size("Switching Protocols", &resp_headers);
832        let mut resp_buf = vec![0u8; resp_size];
833        let n =
834            crate::http::write_response(101, "Switching Protocols", &resp_headers, &mut resp_buf)
835                .map_err(|_| HandshakeError::MalformedHttp)?;
836        stream.write_all(&resp_buf[..n])?;
837
838        let mut reader = reader_builder.role(Role::Server).build();
839        let remainder = req_reader.remainder();
840        if !remainder.is_empty() {
841            reader
842                .read(remainder)
843                .map_err(|_| HandshakeError::MalformedHttp)?;
844        }
845
846        Ok(Self {
847            stream,
848            reader,
849            writer: FrameWriter::new(Role::Server),
850            write_buf: WriteBuf::new(write_cap, write_headroom),
851            poisoned: false,
852        })
853    }
854}
855
856/// Create a matched FrameReader + FrameWriter pair.
857///
858/// Prevents mismatched roles between reader and writer.
859pub fn pair(role: Role) -> (FrameReader, FrameWriter) {
860    (
861        FrameReader::builder().role(role).build(),
862        FrameWriter::new(role),
863    )
864}
865
866/// Create a pair with a configured FrameReader.
867pub fn pair_with(role: Role, reader_builder: FrameReaderBuilder) -> (FrameReader, FrameWriter) {
868    (reader_builder.role(role).build(), FrameWriter::new(role))
869}
870
/// ASCII case-insensitive substring test: does `haystack` contain `needle`?
///
/// An empty `needle` matches every haystack, as with `str::contains`. It
/// must be special-cased because `slice::windows(0)` panics.
fn contains_ignore_case(haystack: &str, needle: &str) -> bool {
    let needle = needle.as_bytes();
    if needle.is_empty() {
        return true;
    }
    haystack
        .as_bytes()
        .windows(needle.len())
        .any(|window| window.eq_ignore_ascii_case(needle))
}
877
#[cfg(test)]
mod tests {
    // Unit tests for URL parsing, the blocking `Client` receive path (driven
    // over in-memory mock streams), and `Error` conversions/`Display` text.
    use super::*;

    // =========================================================================
    // URL parsing
    // =========================================================================

    /// Explicit host, port, and path on a plain `ws://` URL are parsed verbatim.
    #[test]
    fn parse_ws_url_plain() {
        let p = parse_ws_url("ws://localhost:8080/ws").unwrap();
        assert!(!p.tls);
        assert_eq!(p.host, "localhost");
        assert_eq!(p.port, 8080);
        assert_eq!(p.path, "/ws");
    }

    /// `wss://` sets `tls`; with no explicit port the port defaults to 443.
    #[test]
    fn parse_ws_url_tls() {
        let p = parse_ws_url("wss://exchange.com/ws/v1").unwrap();
        assert!(p.tls);
        assert_eq!(p.host, "exchange.com");
        assert_eq!(p.port, 443);
        assert_eq!(p.path, "/ws/v1");
    }

    /// Scheme default ports: 80 for `ws://`, 443 for `wss://`.
    #[test]
    fn parse_ws_url_default_port() {
        let p = parse_ws_url("ws://host/path").unwrap();
        assert_eq!(p.port, 80);

        let p = parse_ws_url("wss://host/path").unwrap();
        assert_eq!(p.port, 443);
    }

    /// A URL with no path component yields the root path `/`.
    #[test]
    fn parse_ws_url_no_path() {
        let p = parse_ws_url("ws://host").unwrap();
        assert_eq!(p.path, "/");
    }

    /// Non-WebSocket schemes and scheme-less inputs are rejected.
    #[test]
    fn parse_ws_url_invalid_scheme() {
        assert!(parse_ws_url("http://host").is_err());
        assert!(parse_ws_url("host/path").is_err());
    }

    // =========================================================================
    // Blocking Client tests
    // =========================================================================

    /// Tests that drive `Client` (and the free `pair` helper) over in-memory
    /// mock streams instead of real sockets.
    mod sync_tests {
        use super::*;
        use std::io::{self, Read, Write};

        /// `pair(Role::Client)` yields a reader that decodes an unmasked
        /// text frame, as produced by `make_frame`.
        #[test]
        fn pair_creates_matching_roles() {
            let (mut reader, _writer) = pair(Role::Client);
            let frame = make_frame(true, 0x1, b"test");
            reader.read(&frame).unwrap();
            let msg = reader.next().unwrap().unwrap();
            assert!(matches!(msg, Message::Text(s) if s == "test"));
        }

        /// Mock transport that returns at most one byte per `read` call.
        /// This forces the client to reassemble frames across many partial
        /// reads. Writes are accepted and discarded.
        struct ByteAtATimeStream {
            // Bytes the mock will serve, in order.
            data: Vec<u8>,
            // Index of the next byte to hand out.
            pos: usize,
        }

        impl ByteAtATimeStream {
            /// Create a stream that will serve `data` from the beginning.
            fn new(data: Vec<u8>) -> Self {
                Self { data, pos: 0 }
            }
        }

        impl Read for ByteAtATimeStream {
            /// Copy the next single byte into `buf[0]`, or return `Ok(0)`
            /// (EOF) once all bytes have been served.
            ///
            /// NOTE: indexes `buf[0]` unconditionally, so an empty `buf`
            /// would panic; callers are assumed to pass non-empty buffers.
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                if self.pos >= self.data.len() {
                    return Ok(0);
                }
                buf[0] = self.data[self.pos];
                self.pos += 1;
                Ok(1)
            }
        }

        impl Write for ByteAtATimeStream {
            /// Report every write as fully successful; the bytes are dropped.
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                Ok(buf.len())
            }
            /// No-op: nothing is buffered.
            fn flush(&mut self) -> io::Result<()> {
                Ok(())
            }
        }

        /// Encode a single WebSocket frame with the mask bit clear and no
        /// masking key.
        ///
        /// - `fin` sets the FIN bit (0x80) of the first byte.
        /// - `opcode` is OR-ed into the first byte and is expected to fit in
        ///   its low 4 bits.
        /// - The payload length uses the shortest encoding: 7-bit inline
        ///   (<= 125), the 126 marker plus a big-endian `u16`, or the 127
        ///   marker plus a big-endian `u64`.
        fn make_frame(fin: bool, opcode: u8, payload: &[u8]) -> Vec<u8> {
            let mut frame = Vec::new();
            let byte0 = if fin { 0x80 } else { 0x00 } | opcode;
            frame.push(byte0);
            if payload.len() <= 125 {
                frame.push(payload.len() as u8);
            } else if payload.len() <= 65535 {
                frame.push(126);
                frame.extend_from_slice(&(payload.len() as u16).to_be_bytes());
            } else {
                frame.push(127);
                frame.extend_from_slice(&(payload.len() as u64).to_be_bytes());
            }
            frame.extend_from_slice(payload);
            frame
        }

        /// Build a client-role `Client` (client-role reader and writer) that
        /// reads `data` one byte at a time through `ByteAtATimeStream`.
        fn ws_from_bytes(data: Vec<u8>) -> Client<ByteAtATimeStream> {
            let mock = ByteAtATimeStream::new(data);
            let reader = FrameReader::builder().role(Role::Client).build();
            let writer = FrameWriter::new(Role::Client);
            Client::from_parts(mock, reader, writer)
        }

        /// A single FIN text frame (opcode 0x1) is delivered as `Message::Text`.
        #[test]
        fn recv_text() {
            let frame = make_frame(true, 0x1, b"Hello");
            let mut ws = ws_from_bytes(frame);
            match ws.recv().unwrap().unwrap() {
                Message::Text(s) => assert_eq!(s, "Hello"),
                other => panic!("expected Text, got {other:?}"),
            }
        }

        /// A ping (opcode 0x9) with a 125-byte payload is delivered intact.
        /// 125 is the largest length that fits the 7-bit length field, and
        /// also the RFC 6455 control-frame limit.
        #[test]
        fn recv_ping() {
            let frame = make_frame(true, 0x9, &[0x42; 125]);
            let mut ws = ws_from_bytes(frame);
            match ws.recv().unwrap().unwrap() {
                Message::Ping(p) => assert_eq!(p.len(), 125),
                other => panic!("expected Ping, got {other:?}"),
            }
        }

        /// A text message split into a non-FIN text frame and a FIN
        /// continuation frame (opcode 0x0) is reassembled into one
        /// `Message::Text`.
        #[test]
        fn recv_fragmented_text() {
            let mut data = make_frame(false, 0x1, b"Hel");
            data.extend_from_slice(&make_frame(true, 0x0, b"lo"));
            let mut ws = ws_from_bytes(data);
            match ws.recv().unwrap().unwrap() {
                Message::Text(s) => assert_eq!(s, "Hello"),
                other => panic!("expected Text, got {other:?}"),
            }
        }

        /// A ping interleaved between fragments of a text message is
        /// returned first. The reassembled text follows on the next `recv`.
        #[test]
        fn recv_fragment_with_ping() {
            let mut data = make_frame(false, 0x1, b"Hel");
            data.extend_from_slice(&make_frame(true, 0x9, b"ping"));
            data.extend_from_slice(&make_frame(true, 0x0, b"lo"));
            let mut ws = ws_from_bytes(data);
            match ws.recv().unwrap().unwrap() {
                Message::Ping(p) => assert_eq!(p, b"ping"),
                other => panic!("expected Ping, got {other:?}"),
            }
            match ws.recv().unwrap().unwrap() {
                Message::Text(s) => assert_eq!(s, "Hello"),
                other => panic!("expected Text, got {other:?}"),
            }
        }

        /// A close frame (opcode 0x8) whose payload is a big-endian status
        /// code followed by a reason string is decoded. Code 1000 maps to
        /// `CloseCode::Normal`.
        #[test]
        fn recv_close() {
            let mut payload = vec![];
            payload.extend_from_slice(&1000u16.to_be_bytes());
            payload.extend_from_slice(b"bye");
            let frame = make_frame(true, 0x8, &payload);
            let mut ws = ws_from_bytes(frame);
            match ws.recv().unwrap().unwrap() {
                Message::Close(cf) => {
                    assert_eq!(cf.code, CloseCode::Normal);
                    assert_eq!(cf.reason, "bye");
                }
                other => panic!("expected Close, got {other:?}"),
            }
        }

        /// An immediate EOF (the stream's `read` returns 0) makes `recv`
        /// return `Ok(None)`.
        #[test]
        fn eof_returns_none() {
            let mut ws = ws_from_bytes(Vec::new());
            assert!(ws.recv().unwrap().is_none());
        }

        /// A stream whose `read` always fails with `WouldBlock` makes `recv`
        /// return `Ok(None)` rather than an error.
        #[test]
        fn would_block_returns_none() {
            // Mock non-blocking transport that never has data available.
            struct WouldBlockStream;
            impl Read for WouldBlockStream {
                fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
                    Err(io::Error::new(io::ErrorKind::WouldBlock, "would block"))
                }
            }
            impl Write for WouldBlockStream {
                fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                    Ok(buf.len())
                }
                fn flush(&mut self) -> io::Result<()> {
                    Ok(())
                }
            }

            let reader = FrameReader::builder().role(Role::Client).build();
            let writer = FrameWriter::new(Role::Client);
            let mut ws = Client::from_parts(WouldBlockStream, reader, writer);
            assert!(ws.recv().unwrap().is_none());
        }
    }

    // =========================================================================
    // ws::Error variant coverage
    // =========================================================================

    /// `io::Error` converts into `Error::Io`; its message appears in `Display`.
    #[test]
    fn ws_error_io() {
        let io_err = std::io::Error::new(std::io::ErrorKind::BrokenPipe, "broken");
        let err = Error::from(io_err);
        assert!(matches!(err, Error::Io(_)));
        assert!(err.to_string().contains("broken"));
    }

    /// `ProtocolError` converts into `Error::Protocol` (variant preserved);
    /// `Display` mentions "protocol error".
    #[test]
    fn ws_error_protocol() {
        let proto = ProtocolError::InvalidUtf8;
        let err = Error::from(proto);
        assert!(matches!(err, Error::Protocol(ProtocolError::InvalidUtf8)));
        assert!(err.to_string().contains("protocol error"));
    }

    /// `EncodeError` converts into `Error::Encode`; `Display` mentions
    /// "encode error".
    #[test]
    fn ws_error_encode() {
        let enc = crate::ws::EncodeError::ControlPayloadTooLarge(200);
        let err = Error::from(enc);
        assert!(matches!(err, Error::Encode(_)));
        assert!(err.to_string().contains("encode error"));
    }

    /// `HandshakeError` converts into `Error::Handshake` (variant preserved);
    /// `Display` mentions "handshake error".
    #[test]
    fn ws_error_handshake() {
        let hs = HandshakeError::MissingUpgrade;
        let err = Error::from(hs);
        assert!(matches!(
            err,
            Error::Handshake(HandshakeError::MissingUpgrade)
        ));
        assert!(err.to_string().contains("handshake error"));
    }

    /// `Error::InvalidUrl` includes the offending URL in its `Display` output.
    #[test]
    fn ws_error_invalid_url() {
        let err = Error::InvalidUrl("bad://url".into());
        assert!(matches!(err, Error::InvalidUrl(_)));
        assert!(err.to_string().contains("bad://url"));
    }

    /// `Error::TlsNotEnabled` mentions "tls" in its `Display` output.
    #[test]
    fn ws_error_tls_not_enabled() {
        let err = Error::TlsNotEnabled;
        assert!(matches!(err, Error::TlsNotEnabled));
        assert!(err.to_string().contains("tls"));
    }
}
1138        let err = Error::TlsNotEnabled;
1139        assert!(matches!(err, Error::TlsNotEnabled));
1140        assert!(err.to_string().contains("tls"));
1141    }
1142}