Skip to main content

nexus_net/ws/
stream.rs

1//! WebSocket stream — I/O wrapper with HTTP upgrade handshake.
2
3use std::time::Duration;
4
5use super::error::ProtocolError;
6use super::frame::Role;
7use super::frame_reader::{FrameReader, FrameReaderBuilder};
8use super::frame_writer::FrameWriter;
9use super::message::{CloseCode, Message};
10use crate::buf::WriteBuf;
11
12use super::handshake;
13use super::handshake::HandshakeError;
14use std::io::{self, Read, Write};
15
16#[cfg(feature = "tls")]
17use crate::tls::{TlsConfig, TlsError};
18
19// =============================================================================
20// URL parsing
21// =============================================================================
22
/// Parsed WebSocket URL, borrowing `host` and `path` from the input string.
///
/// Produced by [`parse_ws_url`].
#[non_exhaustive]
pub struct ParsedUrl<'a> {
    /// Whether the URL is `wss://` (true) or `ws://` (false).
    pub tls: bool,
    /// Host portion (no port). IPv6 literals are stored without their
    /// surrounding brackets (`::1` for `[::1]`).
    pub host: &'a str,
    /// Port — explicit if present, otherwise the scheme default
    /// (80 for ws, 443 for wss).
    pub port: u16,
    /// Path portion (everything after the host:port, including the
    /// leading `/`). Defaults to `/` when absent in the input URL.
    pub path: &'a str,
}

impl ParsedUrl<'_> {
    /// Value for the HTTP `Host` request header.
    ///
    /// The port is appended only when it differs from the scheme default
    /// (80 for `ws://`, 443 for `wss://`). IPv6 literals are re-wrapped in
    /// brackets (`[::1]`, `[::1]:8080`) as the `Host` grammar requires
    /// (RFC 7230 §5.4 / RFC 3986 §3.2.2); without them `::1:8080` would be
    /// ambiguous.
    pub fn host_header(&self) -> String {
        let default = if self.tls { 443 } else { 80 };
        // Bracketed IPv6 hosts have their brackets stripped during parsing,
        // so a ':' in the host means it must be bracketed again here.
        let needs_brackets = self.host.contains(':');
        match (needs_brackets, self.port == default) {
            (false, true) => self.host.to_string(),
            (false, false) => format!("{}:{}", self.host, self.port),
            (true, true) => format!("[{}]", self.host),
            (true, false) => format!("[{}]:{}", self.host, self.port),
        }
    }
}
49
50/// Parse a `ws://` or `wss://` URL into its scheme, host, port, and
51/// path. Supports IPv6 bracket notation (`[::1]:8080`). Returns
52/// [`Error::InvalidUrl`] on a malformed input or missing scheme.
53pub fn parse_ws_url(url: &str) -> Result<ParsedUrl<'_>, Error> {
54    let (tls, rest) = if let Some(r) = url.strip_prefix("wss://") {
55        (true, r)
56    } else if let Some(r) = url.strip_prefix("ws://") {
57        (false, r)
58    } else {
59        return Err(Error::InvalidUrl(url.to_string()));
60    };
61
62    let (host_port, path) = rest
63        .find('/')
64        .map_or((rest, "/"), |i| (&rest[..i], &rest[i..]));
65
66    if host_port.is_empty() {
67        return Err(Error::InvalidUrl(format!("empty host: {url}")));
68    }
69
70    let default_port = if tls { 443 } else { 80 };
71
72    // IPv6 bracket notation: [::1]:8080
73    let (host, port) = if host_port.starts_with('[') {
74        match host_port.find(']') {
75            Some(end) => {
76                let h = &host_port[1..end];
77                let rest = &host_port[end + 1..];
78                if let Some(port_str) = rest.strip_prefix(':') {
79                    let p = port_str
80                        .parse::<u16>()
81                        .map_err(|_| Error::InvalidUrl(format!("invalid port: {url}")))?;
82                    (h, p)
83                } else {
84                    (h, default_port)
85                }
86            }
87            None => return Err(Error::InvalidUrl(format!("unclosed bracket: {url}"))),
88        }
89    } else {
90        match host_port.rfind(':') {
91            None => (host_port, default_port),
92            Some(i) => {
93                let port_str = &host_port[i + 1..];
94                if port_str.is_empty() {
95                    (&host_port[..i], default_port)
96                } else {
97                    let p = port_str
98                        .parse::<u16>()
99                        .map_err(|_| Error::InvalidUrl(format!("invalid port: {url}")))?;
100                    (&host_port[..i], p)
101                }
102            }
103        }
104    };
105
106    Ok(ParsedUrl {
107        tls,
108        host,
109        port,
110        path,
111    })
112}
113
114// =============================================================================
115// Error
116// =============================================================================
117
118/// Unified error type for WebSocket stream operations.
119#[derive(Debug)]
120pub enum Error {
121    /// I/O error from the underlying stream.
122    Io(std::io::Error),
123    /// WebSocket protocol error.
124    Protocol(ProtocolError),
125    /// Encoding error (e.g., control frame payload too large).
126    Encode(super::frame_writer::EncodeError),
127    /// HTTP handshake failed.
128    Handshake(HandshakeError),
129    /// TLS error during connection setup (handshake, certificate
130    /// validation, SNI hostname verification).
131    ///
132    /// **Steady-state TLS protocol errors** (decrypt failure, peer
133    /// alert, malformed record received during a frame) on the async
134    /// `nexus-async-net` paths surface as [`Error::Io`](Self::Io)
135    /// instead — the underlying [`TlsError`](crate::tls::TlsError) is
136    /// wrapped via `io::Error::other` and reachable via
137    /// `io_err.source()` or `io_err.get_ref()`. This asymmetry stems
138    /// from the [`WireStream`](crate::WireStream) trait returning
139    /// `io::Result` for poll methods. Sync WS surfaces `Tls` directly
140    /// because its `TlsStream` exposes `TlsError` natively. Pattern-
141    /// match on both `Io` and `Tls` if you need to distinguish TLS-
142    /// protocol failures from generic transport failures across both
143    /// surfaces.
144    #[cfg(feature = "tls")]
145    Tls(TlsError),
146    /// Invalid WebSocket URL.
147    InvalidUrl(String),
148    /// `wss://` URL used without the `tls` feature enabled.
149    TlsNotEnabled,
150}
151
152impl std::fmt::Display for Error {
153    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154        match self {
155            Self::Io(e) => write!(f, "I/O error: {e}"),
156            Self::Protocol(e) => write!(f, "protocol error: {e}"),
157            Self::Encode(e) => write!(f, "encode error: {e}"),
158            Self::Handshake(e) => write!(f, "handshake error: {e}"),
159            #[cfg(feature = "tls")]
160            Self::Tls(e) => write!(f, "TLS error: {e}"),
161            Self::InvalidUrl(u) => write!(f, "invalid WebSocket URL: {u}"),
162            Self::TlsNotEnabled => write!(f, "wss:// requires the 'tls' feature"),
163        }
164    }
165}
166
167impl std::error::Error for Error {}
168
169impl From<std::io::Error> for Error {
170    fn from(e: std::io::Error) -> Self {
171        Self::Io(e)
172    }
173}
174impl From<ProtocolError> for Error {
175    fn from(e: ProtocolError) -> Self {
176        Self::Protocol(e)
177    }
178}
179impl From<super::frame_writer::EncodeError> for Error {
180    fn from(e: super::frame_writer::EncodeError) -> Self {
181        Self::Encode(e)
182    }
183}
184impl From<HandshakeError> for Error {
185    fn from(e: HandshakeError) -> Self {
186        Self::Handshake(e)
187    }
188}
189#[cfg(feature = "tls")]
190impl From<TlsError> for Error {
191    fn from(e: TlsError) -> Self {
192        match e {
193            TlsError::Io(io) => Self::Io(io),
194            other => Self::Tls(other),
195        }
196    }
197}
198
199// =============================================================================
200// ClientBuilder
201// =============================================================================
202
203/// Builder for [`Client`].
204///
205/// Configures buffer sizes, socket options, and optional TLS.
206///
207/// # Examples
208///
209/// ```ignore
210/// let mut ws = Client::builder()
211///     .disable_nagle()
212///     .buffer_capacity(2 * 1024 * 1024)
213///     .connect("wss://exchange.com/ws")?;
214/// ```
215pub struct ClientBuilder {
216    pub(crate) reader_builder: FrameReaderBuilder,
217    pub(crate) write_buf_capacity: usize,
218    pub(crate) write_buf_headroom: usize,
219    #[cfg(feature = "tls")]
220    pub(crate) tls_config: Option<TlsConfig>,
221    pub(crate) tcp_nodelay: bool,
222    #[cfg(feature = "socket-opts")]
223    pub(crate) recv_buf_size: Option<usize>,
224    #[cfg(feature = "socket-opts")]
225    pub(crate) send_buf_size: Option<usize>,
226    pub(crate) connect_timeout: Option<Duration>,
227    pub(crate) read_timeout: Option<Duration>,
228}
229
230impl ClientBuilder {
231    /// Create a new builder with defaults.
232    #[must_use]
233    pub fn new() -> Self {
234        Self {
235            reader_builder: FrameReader::builder(),
236            write_buf_capacity: 65_536,
237            write_buf_headroom: 14,
238            #[cfg(feature = "tls")]
239            tls_config: None,
240            tcp_nodelay: false,
241            #[cfg(feature = "socket-opts")]
242            recv_buf_size: None,
243            #[cfg(feature = "socket-opts")]
244            send_buf_size: None,
245            connect_timeout: None,
246            read_timeout: None,
247        }
248    }
249
250    /// ReadBuf capacity. Default: 1MB.
251    #[must_use]
252    pub fn buffer_capacity(mut self, n: usize) -> Self {
253        self.reader_builder = self.reader_builder.buffer_capacity(n);
254        self
255    }
256
257    /// Maximum single frame payload. Default: 16MB.
258    #[must_use]
259    pub fn max_frame_size(mut self, n: u64) -> Self {
260        self.reader_builder = self.reader_builder.max_frame_size(n);
261        self
262    }
263
264    /// Maximum assembled message size. Default: 16MB.
265    #[must_use]
266    pub fn max_message_size(mut self, n: usize) -> Self {
267        self.reader_builder = self.reader_builder.max_message_size(n);
268        self
269    }
270
271    /// Write buffer capacity. Default: 64KB.
272    #[must_use]
273    pub fn write_buffer_capacity(mut self, n: usize) -> Self {
274        self.write_buf_capacity = n;
275        self
276    }
277
278    /// Set `TCP_NODELAY` (disable Nagle's algorithm).
279    #[must_use]
280    pub fn disable_nagle(mut self) -> Self {
281        self.tcp_nodelay = true;
282        self
283    }
284
285    /// Set `SO_RCVBUF` (socket receive buffer size).
286    #[cfg(feature = "socket-opts")]
287    #[must_use]
288    pub fn recv_buffer_size(mut self, n: usize) -> Self {
289        self.recv_buf_size = Some(n);
290        self
291    }
292
293    /// Set `SO_SNDBUF` (socket send buffer size).
294    #[cfg(feature = "socket-opts")]
295    #[must_use]
296    pub fn send_buffer_size(mut self, n: usize) -> Self {
297        self.send_buf_size = Some(n);
298        self
299    }
300
301    /// TCP connect timeout.
302    #[must_use]
303    pub fn connect_timeout(mut self, d: Duration) -> Self {
304        self.connect_timeout = Some(d);
305        self
306    }
307
308    /// Socket read timeout.
309    #[must_use]
310    pub fn read_timeout(mut self, d: Duration) -> Self {
311        self.read_timeout = Some(d);
312        self
313    }
314
315    /// Set a custom TLS configuration.
316    ///
317    /// If not set, `wss://` URLs use [`TlsConfig::new()`] (system defaults).
318    #[cfg(feature = "tls")]
319    #[must_use]
320    pub fn tls(mut self, config: &TlsConfig) -> Self {
321        self.tls_config = Some(config.clone());
322        self
323    }
324
325    /// Connect to a WebSocket server (blocking).
326    ///
327    /// Creates a TCP socket, applies socket options, and performs the
328    /// full handshake (TLS if `wss://`, then HTTP upgrade).
329    ///
330    /// When the `tls` feature is enabled, returns `Client<MaybeTls<TcpStream>>`
331    /// regardless of scheme — `ws://` uses `MaybeTls::Plain`, `wss://` uses
332    /// `MaybeTls::Tls`. Without the `tls` feature, returns `Client<TcpStream>`
333    /// and errors on `wss://`.
334    #[cfg(feature = "tls")]
335    pub fn connect(self, url: &str) -> Result<Client<crate::MaybeTls<std::net::TcpStream>>, Error> {
336        let parsed = parse_ws_url(url)?;
337        let addr = format!("{}:{}", parsed.host, parsed.port);
338
339        let tcp = match self.connect_timeout {
340            Some(timeout) => {
341                let addrs: Vec<std::net::SocketAddr> =
342                    std::net::ToSocketAddrs::to_socket_addrs(&addr)
343                        .map_err(Error::Io)?
344                        .collect();
345                let first = addrs
346                    .first()
347                    .ok_or_else(|| Error::Io(io::Error::other("DNS resolution failed")))?;
348                std::net::TcpStream::connect_timeout(first, timeout)?
349            }
350            None => std::net::TcpStream::connect(&addr)?,
351        };
352
353        self.apply_socket_opts(&tcp)?;
354
355        let stream = if parsed.tls {
356            let config = match self.tls_config {
357                Some(c) => c,
358                None => TlsConfig::new().map_err(Error::Tls)?,
359            };
360            let codec = crate::tls::TlsCodec::new(&config, parsed.host)?;
361            let tls = crate::tls::TlsStream::connect(tcp, codec).map_err(Error::Tls)?;
362            crate::MaybeTls::Tls(Box::new(tls))
363        } else {
364            crate::MaybeTls::Plain(tcp)
365        };
366
367        let host_header = parsed.host_header();
368        Client::connect_impl(
369            stream,
370            &host_header,
371            parsed.path,
372            self.reader_builder,
373            self.write_buf_capacity,
374            self.write_buf_headroom,
375        )
376    }
377
378    /// Connect to a WebSocket server (blocking, no TLS feature).
379    #[cfg(not(feature = "tls"))]
380    pub fn connect(self, url: &str) -> Result<Client<std::net::TcpStream>, Error> {
381        let parsed = parse_ws_url(url)?;
382        if parsed.tls {
383            return Err(Error::TlsNotEnabled);
384        }
385        let addr = format!("{}:{}", parsed.host, parsed.port);
386
387        let tcp = match self.connect_timeout {
388            Some(timeout) => {
389                let addrs: Vec<std::net::SocketAddr> =
390                    std::net::ToSocketAddrs::to_socket_addrs(&addr)
391                        .map_err(Error::Io)?
392                        .collect();
393                let first = addrs
394                    .first()
395                    .ok_or_else(|| Error::Io(io::Error::other("DNS resolution failed")))?;
396                std::net::TcpStream::connect_timeout(first, timeout)?
397            }
398            None => std::net::TcpStream::connect(&addr)?,
399        };
400
401        self.apply_socket_opts(&tcp)?;
402
403        let host_header = parsed.host_header();
404        Client::connect_impl(
405            tcp,
406            &host_header,
407            parsed.path,
408            self.reader_builder,
409            self.write_buf_capacity,
410            self.write_buf_headroom,
411        )
412    }
413
414    /// Connect using a pre-connected stream.
415    ///
416    /// The stream must already handle TLS if connecting to `wss://`.
417    /// For example, pass a `TlsStream<TcpStream>` or `MaybeTls<TcpStream>`.
418    /// This method only performs the HTTP upgrade handshake.
419    pub fn connect_with<S: Read + Write>(self, stream: S, url: &str) -> Result<Client<S>, Error> {
420        let parsed = parse_ws_url(url)?;
421        let host_header = parsed.host_header();
422        Client::connect_impl(
423            stream,
424            &host_header,
425            parsed.path,
426            self.reader_builder,
427            self.write_buf_capacity,
428            self.write_buf_headroom,
429        )
430    }
431
432    /// Accept an incoming WebSocket connection (server-side).
433    pub fn accept<S: Read + Write>(self, stream: S) -> Result<Client<S>, Error> {
434        Client::accept_impl(
435            stream,
436            self.reader_builder,
437            self.write_buf_capacity,
438            self.write_buf_headroom,
439        )
440    }
441
442    fn apply_socket_opts(&self, tcp: &std::net::TcpStream) -> Result<(), Error> {
443        if self.tcp_nodelay {
444            tcp.set_nodelay(true)?;
445        }
446        if let Some(timeout) = self.read_timeout {
447            tcp.set_read_timeout(Some(timeout))?;
448        }
449        #[cfg(feature = "socket-opts")]
450        {
451            let sock = socket2::SockRef::from(tcp);
452            if let Some(size) = self.recv_buf_size {
453                sock.set_recv_buffer_size(size)?;
454            }
455            if let Some(size) = self.send_buf_size {
456                sock.set_send_buffer_size(size)?;
457            }
458        }
459        Ok(())
460    }
461}
462
463impl Default for ClientBuilder {
464    fn default() -> Self {
465        Self::new()
466    }
467}
468
469// =============================================================================
470// Client
471// =============================================================================
472
473/// WebSocket stream — owns a socket, reader, writer, and buffers.
474///
475/// Handles both plain `ws://` and encrypted `wss://` connections.
476/// The URL scheme determines whether TLS is used — no separate type needed.
477///
478/// # Usage
479///
480/// ```ignore
481/// use nexus_net::ws::Client;
482/// use nexus_net::tls::TlsConfig;
483///
484/// // Plain WebSocket
485/// let mut ws = Client::builder().connect("ws://localhost:8080/ws")?;
486///
487/// // TLS WebSocket (requires 'tls' feature)
488/// let tls = TlsConfig::new()?;
489/// let mut ws = Client::builder().tls(&tls).connect("wss://exchange.com/ws")?;
490///
491/// // Same API for both:
492/// ws.send_text("Hello!")?;
493/// while let Some(msg) = ws.recv()? {
494///     // ...
495/// }
496/// ```
497pub struct Client<S> {
498    pub(crate) stream: S,
499    pub(crate) reader: FrameReader,
500    pub(crate) writer: FrameWriter,
501    pub(crate) write_buf: WriteBuf,
502    pub(crate) poisoned: bool,
503}
504
505impl Client<std::net::TcpStream> {
506    /// Create a builder for configuring buffer sizes, socket options, and TLS.
507    #[must_use]
508    pub fn builder() -> ClientBuilder {
509        ClientBuilder::new()
510    }
511}
512
513// -- Unbounded impl: accessors and constructors that need no I/O traits -------
514
515impl<S> Client<S> {
516    /// Create from pre-existing parts. For testing or custom handshakes.
517    pub fn from_parts(stream: S, reader: FrameReader, writer: FrameWriter) -> Self {
518        Self {
519            stream,
520            reader,
521            writer,
522            write_buf: WriteBuf::new(65_536, 14),
523            poisoned: false,
524        }
525    }
526
527    /// Internal constructor with all fields. Used by Connecting::finish().
528    pub(crate) fn from_parts_internal(
529        stream: S,
530        reader: FrameReader,
531        writer: FrameWriter,
532        write_buf: WriteBuf,
533    ) -> Self {
534        Self {
535            stream,
536            reader,
537            writer,
538            write_buf,
539            poisoned: false,
540        }
541    }
542
543    /// Whether the stream is poisoned (I/O error occurred during send).
544    ///
545    /// A poisoned stream should not be reused — the connection may be
546    /// in an indeterminate state (partial frame written).
547    pub fn is_poisoned(&self) -> bool {
548        self.poisoned
549    }
550
551    /// Access the underlying stream.
552    pub fn stream(&self) -> &S {
553        &self.stream
554    }
555
556    /// Mutable access to the underlying stream.
557    pub fn stream_mut(&mut self) -> &mut S {
558        &mut self.stream
559    }
560
561    /// Access the FrameReader.
562    pub fn reader(&self) -> &FrameReader {
563        &self.reader
564    }
565
566    /// Access the FrameWriter.
567    pub fn frame_writer(&self) -> &FrameWriter {
568        &self.writer
569    }
570}
571
572// -- Blocking I/O impl --------------------------------------------------------
573
574impl<S: Read + Write> Client<S> {
575    /// Connect using a pre-connected socket with default configuration.
576    ///
577    /// IPv6 addresses must use bracket notation: `ws://[::1]:8080/path`.
578    pub fn connect_with(stream: S, url: &str) -> Result<Self, Error> {
579        ClientBuilder::new().connect_with(stream, url)
580    }
581
582    /// Accept an incoming WebSocket connection (server-side).
583    pub fn accept(stream: S) -> Result<Self, Error> {
584        ClientBuilder::new().accept(stream)
585    }
586
587    /// Receive the next message. Reads from the socket as needed.
588    ///
589    /// Returns `Ok(None)` on EOF, buffer full, or `WouldBlock` (non-blocking sockets).
590    pub fn recv(&mut self) -> Result<Option<Message<'_>>, Error> {
591        loop {
592            if self.reader.poll()? {
593                return Ok(self.reader.next()?);
594            }
595            match self.read_into_reader() {
596                Ok(0) => return Ok(None),
597                Ok(_) => {}
598                Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None),
599                Err(e) => return Err(Error::Io(e)),
600            }
601        }
602    }
603
604    /// Send a text message.
605    pub fn send_text(&mut self, text: &str) -> Result<(), Error> {
606        self.writer
607            .encode_text_into(text.as_bytes(), &mut self.write_buf);
608        self.flush_write_buf_or_poison()
609    }
610
611    /// Send a binary message.
612    pub fn send_binary(&mut self, data: &[u8]) -> Result<(), Error> {
613        self.writer.encode_binary_into(data, &mut self.write_buf);
614        self.flush_write_buf_or_poison()
615    }
616
617    /// Send a ping.
618    pub fn send_ping(&mut self, data: &[u8]) -> Result<(), Error> {
619        self.writer
620            .encode_ping_into(data, &mut self.write_buf)
621            .map_err(Error::Encode)?;
622        self.flush_write_buf_or_poison()
623    }
624
625    /// Send a pong.
626    pub fn send_pong(&mut self, data: &[u8]) -> Result<(), Error> {
627        self.writer
628            .encode_pong_into(data, &mut self.write_buf)
629            .map_err(Error::Encode)?;
630        self.flush_write_buf_or_poison()
631    }
632
633    /// Initiate close handshake.
634    pub fn close(&mut self, code: CloseCode, reason: &str) -> Result<(), Error> {
635        if code == CloseCode::NoStatus {
636            let mut dst = [0u8; 14];
637            let n = self.writer.encode_empty_close(&mut dst);
638            self.write_raw(&dst[..n]).inspect_err(|_| {
639                self.poisoned = true;
640            })
641        } else {
642            self.writer
643                .encode_close_into(code.as_u16(), reason.as_bytes(), &mut self.write_buf)
644                .map_err(Error::Encode)?;
645            self.flush_write_buf_or_poison()
646        }
647    }
648
649    // =========================================================================
650    // Internal — read/write with optional TLS
651    // =========================================================================
652
653    /// Read bytes into the FrameReader.
654    ///
655    /// TLS is now handled at the stream level (`TlsStream<S>` or
656    /// `MaybeTls<S>`), so this always reads plaintext from `S`.
657    fn read_into_reader(&mut self) -> io::Result<usize> {
658        self.reader.read_from(&mut self.stream)
659    }
660
661    /// Flush write_buf, poisoning on I/O error.
662    fn flush_write_buf_or_poison(&mut self) -> Result<(), Error> {
663        self.flush_write_buf().inspect_err(|_| {
664            self.poisoned = true;
665        })
666    }
667
668    /// Flush the write_buf to the socket.
669    fn flush_write_buf(&mut self) -> Result<(), Error> {
670        self.stream.write_all(self.write_buf.data())?;
671        Ok(())
672    }
673
674    /// Write raw bytes to the socket.
675    fn write_raw(&mut self, data: &[u8]) -> Result<(), Error> {
676        self.stream.write_all(data)?;
677        Ok(())
678    }
679
680    // =========================================================================
681    // Internal — handshake
682    // =========================================================================
683
684    /// Perform the HTTP upgrade handshake on a stream that is already
685    /// plaintext-ready (TLS handled at the stream level).
686    pub(crate) fn connect_impl(
687        mut stream: S,
688        host: &str,
689        path: &str,
690        reader_builder: FrameReaderBuilder,
691        write_cap: usize,
692        write_headroom: usize,
693    ) -> Result<Self, Error> {
694        let key = handshake::generate_key();
695        let key_str = std::str::from_utf8(&key).expect("base64 output is valid ASCII");
696
697        let headers = [
698            ("Host", host),
699            ("Upgrade", "websocket"),
700            ("Connection", "Upgrade"),
701            ("Sec-WebSocket-Key", key_str),
702            ("Sec-WebSocket-Version", "13"),
703        ];
704        let req_size = crate::http::request_size("GET", path, &headers);
705        let mut req_buf = vec![0u8; req_size];
706        let n = crate::http::write_request("GET", path, &headers, &mut req_buf)
707            .map_err(|_| HandshakeError::MalformedHttp)?;
708
709        stream.write_all(&req_buf[..n])?;
710
711        let mut resp_reader = crate::http::ResponseReader::new(4096);
712        let mut tmp = [0u8; 4096];
713        loop {
714            let bytes_read = stream.read(&mut tmp)?;
715            if bytes_read == 0 {
716                return Err(HandshakeError::MalformedHttp.into());
717            }
718
719            resp_reader
720                .read(&tmp[..bytes_read])
721                .map_err(|_| HandshakeError::MalformedHttp)?;
722            match resp_reader.next() {
723                Ok(Some(resp)) => {
724                    if resp.status != 101 {
725                        return Err(HandshakeError::UnexpectedStatus(resp.status).into());
726                    }
727                    let upgrade = resp
728                        .header("Upgrade")
729                        .ok_or(HandshakeError::MissingUpgrade)?;
730                    if !upgrade.eq_ignore_ascii_case("websocket") {
731                        return Err(HandshakeError::MissingUpgrade.into());
732                    }
733                    let conn = resp
734                        .header("Connection")
735                        .ok_or(HandshakeError::MissingConnection)?;
736                    if !contains_ignore_case(conn, "upgrade") {
737                        return Err(HandshakeError::MissingConnection.into());
738                    }
739                    let accept = resp
740                        .header("Sec-WebSocket-Accept")
741                        .ok_or(HandshakeError::InvalidAcceptKey)?;
742                    if !handshake::validate_accept(key_str, accept) {
743                        return Err(HandshakeError::InvalidAcceptKey.into());
744                    }
745
746                    let mut reader = reader_builder.role(Role::Client).build();
747                    let remainder = resp_reader.remainder();
748                    if !remainder.is_empty() {
749                        reader
750                            .read(remainder)
751                            .map_err(|_| HandshakeError::MalformedHttp)?;
752                    }
753
754                    return Ok(Self {
755                        stream,
756                        reader,
757                        writer: FrameWriter::new(Role::Client),
758                        write_buf: WriteBuf::new(write_cap, write_headroom),
759                        poisoned: false,
760                    });
761                }
762                Ok(None) => {} // need more bytes
763                Err(_) => return Err(HandshakeError::MalformedHttp.into()),
764            }
765        }
766    }
767
768    fn accept_impl(
769        mut stream: S,
770        reader_builder: FrameReaderBuilder,
771        write_cap: usize,
772        write_headroom: usize,
773    ) -> Result<Self, Error> {
774        let mut req_reader = crate::http::RequestReader::new(4096);
775        let mut tmp = [0u8; 4096];
776
777        let ws_key;
778        loop {
779            let n = stream.read(&mut tmp)?;
780            if n == 0 {
781                return Err(HandshakeError::MalformedHttp.into());
782            }
783            req_reader
784                .read(&tmp[..n])
785                .map_err(|_| HandshakeError::MalformedHttp)?;
786            match req_reader.next() {
787                Ok(Some(req)) => {
788                    if req.method != "GET" {
789                        return Err(HandshakeError::MalformedHttp.into());
790                    }
791                    let upgrade = req
792                        .header("Upgrade")
793                        .ok_or(HandshakeError::MissingUpgrade)?;
794                    if !upgrade.eq_ignore_ascii_case("websocket") {
795                        return Err(HandshakeError::MissingUpgrade.into());
796                    }
797                    let conn = req
798                        .header("Connection")
799                        .ok_or(HandshakeError::MissingConnection)?;
800                    if !contains_ignore_case(conn, "upgrade") {
801                        return Err(HandshakeError::MissingConnection.into());
802                    }
803                    let version = req
804                        .header("Sec-WebSocket-Version")
805                        .ok_or(HandshakeError::UnsupportedVersion)?;
806                    if version != "13" {
807                        return Err(HandshakeError::UnsupportedVersion.into());
808                    }
809                    let key = req
810                        .header("Sec-WebSocket-Key")
811                        .ok_or(HandshakeError::MissingKey)?;
812                    ws_key = key.to_owned();
813                    break;
814                }
815                Ok(None) => {}
816                Err(_) => return Err(HandshakeError::MalformedHttp.into()),
817            }
818        }
819
820        let accept = handshake::compute_accept_key(&ws_key);
821        let accept_str = std::str::from_utf8(&accept).expect("base64 output is valid ASCII");
822
823        let resp_headers = [
824            ("Upgrade", "websocket"),
825            ("Connection", "Upgrade"),
826            ("Sec-WebSocket-Accept", accept_str),
827        ];
828        let resp_size = crate::http::response_size("Switching Protocols", &resp_headers);
829        let mut resp_buf = vec![0u8; resp_size];
830        let n =
831            crate::http::write_response(101, "Switching Protocols", &resp_headers, &mut resp_buf)
832                .map_err(|_| HandshakeError::MalformedHttp)?;
833        stream.write_all(&resp_buf[..n])?;
834
835        let mut reader = reader_builder.role(Role::Server).build();
836        let remainder = req_reader.remainder();
837        if !remainder.is_empty() {
838            reader
839                .read(remainder)
840                .map_err(|_| HandshakeError::MalformedHttp)?;
841        }
842
843        Ok(Self {
844            stream,
845            reader,
846            writer: FrameWriter::new(Role::Server),
847            write_buf: WriteBuf::new(write_cap, write_headroom),
848            poisoned: false,
849        })
850    }
851}
852
853/// Create a matched FrameReader + FrameWriter pair.
854///
855/// Prevents mismatched roles between reader and writer.
856pub fn pair(role: Role) -> (FrameReader, FrameWriter) {
857    (
858        FrameReader::builder().role(role).build(),
859        FrameWriter::new(role),
860    )
861}
862
863/// Create a pair with a configured FrameReader.
864pub fn pair_with(role: Role, reader_builder: FrameReaderBuilder) -> (FrameReader, FrameWriter) {
865    (reader_builder.role(role).build(), FrameWriter::new(role))
866}
867
868fn contains_ignore_case(haystack: &str, needle: &str) -> bool {
869    haystack
870        .as_bytes()
871        .windows(needle.len())
872        .any(|w| w.eq_ignore_ascii_case(needle.as_bytes()))
873}
874
#[cfg(test)]
mod tests {
    use super::*;

    // =========================================================================
    // URL parsing
    // =========================================================================

    #[test]
    fn parse_ws_url_plain() {
        let p = parse_ws_url("ws://localhost:8080/ws").unwrap();
        assert!(!p.tls);
        assert_eq!(p.host, "localhost");
        assert_eq!(p.port, 8080);
        assert_eq!(p.path, "/ws");
    }

    #[test]
    fn parse_ws_url_tls() {
        let p = parse_ws_url("wss://exchange.com/ws/v1").unwrap();
        assert!(p.tls);
        assert_eq!(p.host, "exchange.com");
        assert_eq!(p.port, 443);
        assert_eq!(p.path, "/ws/v1");
    }

    #[test]
    fn parse_ws_url_default_port() {
        let p = parse_ws_url("ws://host/path").unwrap();
        assert_eq!(p.port, 80);

        let p = parse_ws_url("wss://host/path").unwrap();
        assert_eq!(p.port, 443);
    }

    #[test]
    fn parse_ws_url_no_path() {
        let p = parse_ws_url("ws://host").unwrap();
        assert_eq!(p.path, "/");
    }

    #[test]
    fn parse_ws_url_invalid_scheme() {
        assert!(parse_ws_url("http://host").is_err());
        assert!(parse_ws_url("host/path").is_err());
    }

    // =========================================================================
    // Header helpers
    // =========================================================================

    #[test]
    fn contains_ignore_case_matches_ascii_case_insensitively() {
        assert!(contains_ignore_case("keep-alive, Upgrade", "upgrade"));
        assert!(contains_ignore_case("WEBSOCKET", "websocket"));
        assert!(!contains_ignore_case("keep-alive", "upgrade"));
        // Needle longer than haystack: no windows, so no match.
        assert!(!contains_ignore_case("up", "upgrade"));
    }

    // =========================================================================
    // Blocking Client tests
    // =========================================================================

    mod sync_tests {
        use super::*;
        use std::io::{self, Read, Write};

        #[test]
        fn pair_creates_matching_roles() {
            let (mut reader, _writer) = pair(Role::Client);
            let frame = make_frame(true, 0x1, b"test");
            reader.read(&frame).unwrap();
            let msg = reader.next().unwrap().unwrap();
            assert!(matches!(msg, Message::Text(s) if s == "test"));
        }

        /// Mock stream that yields its data one byte per `read` call, to
        /// exercise the client's handling of short reads. Writes are
        /// accepted and discarded.
        struct ByteAtATimeStream {
            data: Vec<u8>,
            pos: usize,
        }

        impl ByteAtATimeStream {
            fn new(data: Vec<u8>) -> Self {
                Self { data, pos: 0 }
            }
        }

        impl Read for ByteAtATimeStream {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                // A zero-length `buf` must return Ok(0) per the `Read` contract
                // instead of panicking on `buf[0]`. An exhausted stream also
                // returns Ok(0), signalling EOF.
                if buf.is_empty() || self.pos >= self.data.len() {
                    return Ok(0);
                }
                buf[0] = self.data[self.pos];
                self.pos += 1;
                Ok(1)
            }
        }

        impl Write for ByteAtATimeStream {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                Ok(buf.len())
            }
            fn flush(&mut self) -> io::Result<()> {
                Ok(())
            }
        }

        /// Encode a single unmasked frame (as a server would send it) with
        /// the given FIN bit, opcode and payload. It picks the 7-bit, 16-bit
        /// or 64-bit length encoding from the payload size.
        fn make_frame(fin: bool, opcode: u8, payload: &[u8]) -> Vec<u8> {
            let mut frame = Vec::new();
            // Parenthesised so the FIN-bit selection is clearly combined with
            // the opcode, rather than relying on `if`-expression precedence.
            let byte0 = (if fin { 0x80 } else { 0x00 }) | opcode;
            frame.push(byte0);
            if payload.len() <= 125 {
                frame.push(payload.len() as u8);
            } else if payload.len() <= 65535 {
                frame.push(126);
                frame.extend_from_slice(&(payload.len() as u16).to_be_bytes());
            } else {
                frame.push(127);
                frame.extend_from_slice(&(payload.len() as u64).to_be_bytes());
            }
            frame.extend_from_slice(payload);
            frame
        }

        /// Wrap `data` in a byte-at-a-time mock and build a client-role
        /// `Client` around it.
        fn ws_from_bytes(data: Vec<u8>) -> Client<ByteAtATimeStream> {
            let mock = ByteAtATimeStream::new(data);
            let reader = FrameReader::builder().role(Role::Client).build();
            let writer = FrameWriter::new(Role::Client);
            Client::from_parts(mock, reader, writer)
        }

        #[test]
        fn recv_text() {
            let frame = make_frame(true, 0x1, b"Hello");
            let mut ws = ws_from_bytes(frame);
            match ws.recv().unwrap().unwrap() {
                Message::Text(s) => assert_eq!(s, "Hello"),
                other => panic!("expected Text, got {other:?}"),
            }
        }

        #[test]
        fn recv_ping() {
            let frame = make_frame(true, 0x9, &[0x42; 125]);
            let mut ws = ws_from_bytes(frame);
            match ws.recv().unwrap().unwrap() {
                Message::Ping(p) => assert_eq!(p.len(), 125),
                other => panic!("expected Ping, got {other:?}"),
            }
        }

        #[test]
        fn recv_fragmented_text() {
            let mut data = make_frame(false, 0x1, b"Hel");
            data.extend_from_slice(&make_frame(true, 0x0, b"lo"));
            let mut ws = ws_from_bytes(data);
            match ws.recv().unwrap().unwrap() {
                Message::Text(s) => assert_eq!(s, "Hello"),
                other => panic!("expected Text, got {other:?}"),
            }
        }

        #[test]
        fn recv_fragment_with_ping() {
            // A control frame interleaved within a fragmented message is
            // delivered first; the reassembled text follows.
            let mut data = make_frame(false, 0x1, b"Hel");
            data.extend_from_slice(&make_frame(true, 0x9, b"ping"));
            data.extend_from_slice(&make_frame(true, 0x0, b"lo"));
            let mut ws = ws_from_bytes(data);
            match ws.recv().unwrap().unwrap() {
                Message::Ping(p) => assert_eq!(p, b"ping"),
                other => panic!("expected Ping, got {other:?}"),
            }
            match ws.recv().unwrap().unwrap() {
                Message::Text(s) => assert_eq!(s, "Hello"),
                other => panic!("expected Text, got {other:?}"),
            }
        }

        #[test]
        fn recv_close() {
            let mut payload = vec![];
            payload.extend_from_slice(&1000u16.to_be_bytes());
            payload.extend_from_slice(b"bye");
            let frame = make_frame(true, 0x8, &payload);
            let mut ws = ws_from_bytes(frame);
            match ws.recv().unwrap().unwrap() {
                Message::Close(cf) => {
                    assert_eq!(cf.code, CloseCode::Normal);
                    assert_eq!(cf.reason, "bye");
                }
                other => panic!("expected Close, got {other:?}"),
            }
        }

        #[test]
        fn eof_returns_none() {
            let mut ws = ws_from_bytes(Vec::new());
            assert!(ws.recv().unwrap().is_none());
        }

        #[test]
        fn would_block_returns_none() {
            /// Mock stream whose reads always fail with `WouldBlock`.
            struct WouldBlockStream;
            impl Read for WouldBlockStream {
                fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
                    Err(io::Error::new(io::ErrorKind::WouldBlock, "would block"))
                }
            }
            impl Write for WouldBlockStream {
                fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                    Ok(buf.len())
                }
                fn flush(&mut self) -> io::Result<()> {
                    Ok(())
                }
            }

            let reader = FrameReader::builder().role(Role::Client).build();
            let writer = FrameWriter::new(Role::Client);
            let mut ws = Client::from_parts(WouldBlockStream, reader, writer);
            assert!(ws.recv().unwrap().is_none());
        }
    }
}