Skip to main content

nexus_async_net/ws/tokio/
stream.rs

//! Async WebSocket — tokio backend.
//!
//! Builder and handshake logic. The primary API, `WsReader`/`WsWriter`,
//! lives in the shared `ws::parts` module.
5
6use std::pin::Pin;
7
8use nexus_net::WireStream;
9use nexus_net::buf::WriteBuf;
10use nexus_net::http::HTTP_HANDSHAKE_BUFFER;
11#[cfg(feature = "tls")]
12use nexus_net::tls::TlsConfig;
13use nexus_net::ws::{
14    CloseCode, Error as WsError, FrameReader, FrameReaderBuilder, FrameWriter, HandshakeError,
15    Role, parse_ws_url,
16};
17use tokio::net::TcpStream;
18
19use crate::maybe_tls::MaybeTls;
20use crate::ws::parts::{WsReader, WsWriter, fill_async, write_all_async};
21
22// =============================================================================
23// Handshake — standalone async functions
24// =============================================================================
25
26async fn connect_handshake<S: WireStream + Unpin>(
27    mut stream: S,
28    url: &str,
29    reader_builder: FrameReaderBuilder,
30    write_cap: usize,
31    max_read_size: usize,
32) -> Result<(WsReader, WsWriter, S), WsError> {
33    let parsed = parse_ws_url(url)?;
34    let host_header = parsed.host_header();
35
36    let key = nexus_net::ws::handshake::generate_key();
37    let key_str =
38        std::str::from_utf8(&key).expect("base64-encoded key is always valid ASCII/UTF-8");
39
40    let headers: [(&str, &str); 5] = [
41        ("Host", &host_header),
42        ("Upgrade", "websocket"),
43        ("Connection", "Upgrade"),
44        ("Sec-WebSocket-Key", key_str),
45        ("Sec-WebSocket-Version", "13"),
46    ];
47    let req_size = nexus_net::http::request_size("GET", parsed.path, &headers);
48    let mut req_buf = vec![0u8; req_size];
49    let n = nexus_net::http::write_request("GET", parsed.path, &headers, &mut req_buf)
50        .map_err(|_| HandshakeError::MalformedHttp)?;
51
52    write_all_async(&mut stream, &req_buf[..n]).await?;
53
54    let mut resp_reader = nexus_net::http::ResponseReader::new(HTTP_HANDSHAKE_BUFFER);
55    loop {
56        if resp_reader.spare().is_empty() {
57            return Err(HandshakeError::MalformedHttp.into());
58        }
59        let n = fill_async(&mut stream, &mut resp_reader, HTTP_HANDSHAKE_BUFFER).await?;
60        if n == 0 {
61            return Err(HandshakeError::MalformedHttp.into());
62        }
63        match resp_reader.next() {
64            Ok(Some(resp)) => {
65                if resp.status != 101 {
66                    return Err(HandshakeError::UnexpectedStatus(resp.status).into());
67                }
68                let upgrade = resp
69                    .header("Upgrade")
70                    .ok_or(HandshakeError::MissingUpgrade)?;
71                if !upgrade.eq_ignore_ascii_case("websocket") {
72                    return Err(HandshakeError::MissingUpgrade.into());
73                }
74                let conn = resp
75                    .header("Connection")
76                    .ok_or(HandshakeError::MissingConnection)?;
77                if !conn
78                    .as_bytes()
79                    .windows(7)
80                    .any(|w| w.eq_ignore_ascii_case(b"upgrade"))
81                {
82                    return Err(HandshakeError::MissingConnection.into());
83                }
84                let accept = resp
85                    .header("Sec-WebSocket-Accept")
86                    .ok_or(HandshakeError::InvalidAcceptKey)?;
87                if !nexus_net::ws::handshake::validate_accept(key_str, accept) {
88                    return Err(HandshakeError::InvalidAcceptKey.into());
89                }
90
91                let mut reader = reader_builder.role(Role::Client).build();
92                let remainder = resp_reader.remainder();
93                if !remainder.is_empty() {
94                    reader
95                        .read(remainder)
96                        .map_err(|_| HandshakeError::MalformedHttp)?;
97                }
98
99                return Ok((
100                    WsReader {
101                        reader,
102                        max_read_size,
103                    },
104                    WsWriter {
105                        writer: FrameWriter::new(Role::Client),
106                        write_buf: WriteBuf::new(write_cap, 14),
107                    },
108                    stream,
109                ));
110            }
111            Ok(None) => {}
112            Err(_) => return Err(HandshakeError::MalformedHttp.into()),
113        }
114    }
115}
116
117async fn accept_handshake<S: WireStream + Unpin>(
118    mut stream: S,
119    reader_builder: FrameReaderBuilder,
120    write_cap: usize,
121    max_read_size: usize,
122) -> Result<(WsReader, WsWriter, S), WsError> {
123    let mut req_reader = nexus_net::http::RequestReader::new(HTTP_HANDSHAKE_BUFFER);
124
125    let ws_key;
126    loop {
127        if req_reader.spare().is_empty() {
128            return Err(HandshakeError::MalformedHttp.into());
129        }
130        let n = fill_async(&mut stream, &mut req_reader, HTTP_HANDSHAKE_BUFFER).await?;
131        if n == 0 {
132            return Err(HandshakeError::MalformedHttp.into());
133        }
134        match req_reader.next() {
135            Ok(Some(req)) => {
136                if req.method != "GET" {
137                    return Err(HandshakeError::MalformedHttp.into());
138                }
139                let upgrade = req
140                    .header("Upgrade")
141                    .ok_or(HandshakeError::MissingUpgrade)?;
142                if !upgrade.eq_ignore_ascii_case("websocket") {
143                    return Err(HandshakeError::MissingUpgrade.into());
144                }
145                let conn = req
146                    .header("Connection")
147                    .ok_or(HandshakeError::MissingConnection)?;
148                if !conn
149                    .as_bytes()
150                    .windows(7)
151                    .any(|w| w.eq_ignore_ascii_case(b"upgrade"))
152                {
153                    return Err(HandshakeError::MissingConnection.into());
154                }
155                let version = req
156                    .header("Sec-WebSocket-Version")
157                    .ok_or(HandshakeError::UnsupportedVersion)?;
158                if version != "13" {
159                    return Err(HandshakeError::UnsupportedVersion.into());
160                }
161                let key = req
162                    .header("Sec-WebSocket-Key")
163                    .ok_or(HandshakeError::MissingKey)?;
164                ws_key = key.to_owned();
165                break;
166            }
167            Ok(None) => {}
168            Err(_) => return Err(HandshakeError::MalformedHttp.into()),
169        }
170    }
171
172    let accept = nexus_net::ws::handshake::compute_accept_key(&ws_key);
173    let accept_str = std::str::from_utf8(&accept).expect("base64 output is valid ASCII");
174
175    let resp_headers = [
176        ("Upgrade", "websocket"),
177        ("Connection", "Upgrade"),
178        ("Sec-WebSocket-Accept", accept_str),
179    ];
180    let resp_size = nexus_net::http::response_size("Switching Protocols", &resp_headers);
181    let mut resp_buf = vec![0u8; resp_size];
182    let n =
183        nexus_net::http::write_response(101, "Switching Protocols", &resp_headers, &mut resp_buf)
184            .map_err(|_| HandshakeError::MalformedHttp)?;
185    write_all_async(&mut stream, &resp_buf[..n]).await?;
186
187    let mut reader = reader_builder.role(Role::Server).build();
188    let remainder = req_reader.remainder();
189    if !remainder.is_empty() {
190        reader
191            .read(remainder)
192            .map_err(|_| HandshakeError::MalformedHttp)?;
193    }
194
195    Ok((
196        WsReader {
197            reader,
198            max_read_size,
199        },
200        WsWriter {
201            writer: FrameWriter::new(Role::Server),
202            write_buf: WriteBuf::new(write_cap, 14),
203        },
204        stream,
205    ))
206}
207
208// =============================================================================
209// WsStream — Stream/Sink ecosystem adapter
210// =============================================================================
211
212/// Bundled WebSocket stream for `Stream`/`Sink` ecosystem compatibility.
213///
214/// This type exists solely to implement `futures_core::Stream` and
215/// `futures_sink::Sink<OwnedMessage>`. It uses owned messages
216/// (allocates per message) and cannot overlap read/write borrows.
217///
218/// For performance-sensitive code, use [`WsReader`] and [`WsWriter`]
219/// directly — they provide zero-copy messages and independent borrows.
220///
221/// # Construction
222///
223/// ```ignore
224/// // Connect, then bundle for Stream/Sink usage
225/// let (reader, writer, conn) = WsStreamBuilder::new()
226///     .connect("ws://localhost:8080/ws")
227///     .await?;
228/// let ws = WsStream::from_parts(reader, writer, conn);
229/// ```
230pub struct WsStream<S> {
231    stream: S,
232    reader: FrameReader,
233    writer: FrameWriter,
234    write_buf: WriteBuf,
235    max_read_size: usize,
236}
237
238impl<S> WsStream<S> {
239    /// Construct from decomposed parts.
240    ///
241    /// Inverse of [`into_parts`](Self::into_parts).
242    pub fn from_parts(reader: WsReader, writer: WsWriter, stream: S) -> Self {
243        Self {
244            stream,
245            max_read_size: reader.max_read_size,
246            reader: reader.reader,
247            writer: writer.writer,
248            write_buf: writer.write_buf,
249        }
250    }
251
252    /// Decompose into reader, writer, and transport stream.
253    pub fn into_parts(self) -> (WsReader, WsWriter, S) {
254        (
255            WsReader {
256                reader: self.reader,
257                max_read_size: self.max_read_size,
258            },
259            WsWriter {
260                writer: self.writer,
261                write_buf: self.write_buf,
262            },
263            self.stream,
264        )
265    }
266
267    /// Low-level construction from raw nexus-net types.
268    ///
269    /// For testing or custom handshakes. Prefer [`from_parts`](Self::from_parts)
270    /// with builder-produced components.
271    pub fn from_raw_parts(stream: S, reader: FrameReader, writer: FrameWriter) -> Self {
272        Self {
273            stream,
274            reader,
275            writer,
276            write_buf: WriteBuf::new(65_536, 14),
277            max_read_size: usize::MAX,
278        }
279    }
280}
281
282// =============================================================================
283// Builder
284// =============================================================================
285
286/// Builder for WebSocket connections.
287///
288/// Returns `(WsReader, WsWriter, S)` — the decomposed sans-IO types.
289///
290/// # Example
291///
292/// ```ignore
293/// let (mut reader, mut writer, mut conn) = WsStreamBuilder::new()
294///     .disable_nagle()
295///     .connect("ws://localhost:8080/ws")
296///     .await?;
297/// ```
298pub struct WsStreamBuilder {
299    reader_builder: FrameReaderBuilder,
300    write_buf_capacity: usize,
301    buffer_capacity: usize,
302    max_read_size: Option<usize>,
303    #[cfg(feature = "tls")]
304    tls_config: Option<TlsConfig>,
305    nodelay: bool,
306    connect_timeout: Option<std::time::Duration>,
307    #[cfg(feature = "socket-opts")]
308    tcp_keepalive: Option<std::time::Duration>,
309    #[cfg(feature = "socket-opts")]
310    recv_buf_size: Option<usize>,
311    #[cfg(feature = "socket-opts")]
312    send_buf_size: Option<usize>,
313}
314
/// Default read-buffer capacity used by `WsStreamBuilder`: 1 MiB.
const DEFAULT_BUFFER_CAPACITY: usize = 1 << 20;
316
317impl WsStreamBuilder {
318    /// Create a new builder with defaults.
319    #[must_use]
320    pub fn new() -> Self {
321        Self {
322            reader_builder: FrameReader::builder(),
323            write_buf_capacity: 65_536,
324            buffer_capacity: DEFAULT_BUFFER_CAPACITY,
325            max_read_size: None,
326            #[cfg(feature = "tls")]
327            tls_config: None,
328            nodelay: false,
329            connect_timeout: None,
330            #[cfg(feature = "socket-opts")]
331            tcp_keepalive: None,
332            #[cfg(feature = "socket-opts")]
333            recv_buf_size: None,
334            #[cfg(feature = "socket-opts")]
335            send_buf_size: None,
336        }
337    }
338
339    fn resolved_max_read_size(&self) -> usize {
340        self.max_read_size.map_or_else(
341            || (self.buffer_capacity / 8).max(1),
342            |n| n.min(self.buffer_capacity).max(1),
343        )
344    }
345
346    /// ReadBuf capacity. Default: 1MB.
347    #[must_use]
348    pub fn buffer_capacity(mut self, n: usize) -> Self {
349        self.buffer_capacity = n;
350        self.reader_builder = self.reader_builder.buffer_capacity(n);
351        self
352    }
353
354    /// Maximum bytes to read from the transport per recv call.
355    ///
356    /// Caps the slice passed to the underlying read, bounding the worst-case
357    /// memcpy per message. Lower values reduce tail latency at the cost of
358    /// more frequent reads.
359    ///
360    /// Default: 1/8 of buffer capacity. Clamped to `[1, buffer_capacity]`.
361    ///
362    /// **Note:** This only affects `WsReader::recv()`. The `Stream`
363    /// implementation on [`WsStream`] uses the full spare slice for
364    /// compatibility with `StreamExt` combinators. For latency-sensitive
365    /// code, use `WsReader::recv()` directly.
366    #[must_use]
367    pub fn max_read_size(mut self, n: usize) -> Self {
368        self.max_read_size = Some(n);
369        self
370    }
371
372    /// Fraction of buffer capacity consumed before proactive compaction.
373    ///
374    /// See [`FrameReaderBuilder::compact_at`](nexus_net::ws::FrameReaderBuilder::compact_at)
375    /// for details. Default: 0.5.
376    ///
377    /// **Note:** This only affects `WsReader::recv()`. The `Stream`
378    /// implementation does not use proactive compaction.
379    #[must_use]
380    pub fn compact_at(mut self, fraction: f64) -> Self {
381        self.reader_builder = self.reader_builder.compact_at(fraction);
382        self
383    }
384
385    /// Maximum single frame payload. Default: 16MB.
386    #[must_use]
387    pub fn max_frame_size(mut self, n: u64) -> Self {
388        self.reader_builder = self.reader_builder.max_frame_size(n);
389        self
390    }
391
392    /// Maximum assembled message size. Default: 16MB.
393    #[must_use]
394    pub fn max_message_size(mut self, n: usize) -> Self {
395        self.reader_builder = self.reader_builder.max_message_size(n);
396        self
397    }
398
399    /// Write buffer capacity. Default: 64KB.
400    #[must_use]
401    pub fn write_buffer_capacity(mut self, n: usize) -> Self {
402        self.write_buf_capacity = n;
403        self
404    }
405
406    /// Custom TLS configuration.
407    #[cfg(feature = "tls")]
408    #[must_use]
409    pub fn tls(mut self, config: &TlsConfig) -> Self {
410        self.tls_config = Some(config.clone());
411        self
412    }
413
414    /// Set TCP_NODELAY.
415    #[must_use]
416    pub fn disable_nagle(mut self) -> Self {
417        self.nodelay = true;
418        self
419    }
420
421    /// TCP connect timeout.
422    #[must_use]
423    pub fn connect_timeout(mut self, d: std::time::Duration) -> Self {
424        self.connect_timeout = Some(d);
425        self
426    }
427
428    /// Set TCP keepalive idle time.
429    ///
430    /// Enables OS-level dead connection detection. The kernel sends
431    /// probes after `idle` of inactivity.
432    #[cfg(feature = "socket-opts")]
433    #[must_use]
434    pub fn tcp_keepalive(mut self, idle: std::time::Duration) -> Self {
435        self.tcp_keepalive = Some(idle);
436        self
437    }
438
439    /// Set `SO_RCVBUF` (socket receive buffer size).
440    #[cfg(feature = "socket-opts")]
441    #[must_use]
442    pub fn recv_buffer_size(mut self, n: usize) -> Self {
443        self.recv_buf_size = Some(n);
444        self
445    }
446
447    /// Set `SO_SNDBUF` (socket send buffer size).
448    #[cfg(feature = "socket-opts")]
449    #[must_use]
450    pub fn send_buffer_size(mut self, n: usize) -> Self {
451        self.send_buf_size = Some(n);
452        self
453    }
454
455    /// Connect to a WebSocket server. Creates TCP socket, handles TLS.
456    pub async fn connect(self, url: &str) -> Result<(WsReader, WsWriter, MaybeTls), WsError> {
457        let parsed = parse_ws_url(url)?;
458        let addr = format!("{}:{}", parsed.host, parsed.port);
459
460        let tcp = match self.connect_timeout {
461            Some(timeout) => tokio::time::timeout(timeout, TcpStream::connect(&addr))
462                .await
463                .map_err(|_| {
464                    WsError::Io(std::io::Error::new(
465                        std::io::ErrorKind::TimedOut,
466                        "connect timeout",
467                    ))
468                })??,
469            None => TcpStream::connect(&addr).await?,
470        };
471        if self.nodelay {
472            tcp.set_nodelay(true)?;
473        }
474        #[cfg(feature = "socket-opts")]
475        self.apply_socket_opts(&tcp)?;
476
477        let stream = if parsed.tls {
478            #[cfg(feature = "tls")]
479            {
480                let tls_config = match &self.tls_config {
481                    Some(c) => c.clone(),
482                    None => TlsConfig::new().map_err(WsError::Tls)?,
483                };
484
485                let connector =
486                    tokio_rustls::TlsConnector::from(tls_config.client_config().clone());
487                let server_name =
488                    tokio_rustls::rustls::pki_types::ServerName::try_from(parsed.host.to_owned())
489                        .map_err(|_| {
490                        WsError::Tls(nexus_net::tls::TlsError::InvalidHostname(
491                            parsed.host.to_string(),
492                        ))
493                    })?;
494                let tls_stream = connector
495                    .connect(server_name, tcp)
496                    .await
497                    .map_err(|e| WsError::Tls(nexus_net::tls::TlsError::Io(e)))?;
498                MaybeTls::Tls(Box::new(tls_stream))
499            }
500            #[cfg(not(feature = "tls"))]
501            {
502                return Err(WsError::TlsNotEnabled);
503            }
504        } else {
505            MaybeTls::Plain(tcp)
506        };
507
508        let max_read_size = self.resolved_max_read_size();
509        connect_handshake(
510            stream,
511            url,
512            self.reader_builder,
513            self.write_buf_capacity,
514            max_read_size,
515        )
516        .await
517    }
518
519    /// Connect with a pre-connected async stream.
520    pub async fn connect_with<S: WireStream + Unpin>(
521        self,
522        stream: S,
523        url: &str,
524    ) -> Result<(WsReader, WsWriter, S), WsError> {
525        let max_read_size = self.resolved_max_read_size();
526        connect_handshake(
527            stream,
528            url,
529            self.reader_builder,
530            self.write_buf_capacity,
531            max_read_size,
532        )
533        .await
534    }
535
536    /// Accept an incoming WebSocket connection (server-side).
537    pub async fn accept<S: WireStream + Unpin>(
538        self,
539        stream: S,
540    ) -> Result<(WsReader, WsWriter, S), WsError> {
541        let max_read_size = self.resolved_max_read_size();
542        accept_handshake(
543            stream,
544            self.reader_builder,
545            self.write_buf_capacity,
546            max_read_size,
547        )
548        .await
549    }
550}
551
#[cfg(feature = "socket-opts")]
impl WsStreamBuilder {
    /// Apply the configured keepalive and socket buffer sizes to `tcp`.
    ///
    /// Options are applied in this order: keepalive, `SO_RCVBUF`, `SO_SNDBUF`.
    /// Unset options are skipped. The first failure is returned, and any
    /// later options are not applied.
    fn apply_socket_opts(&self, tcp: &TcpStream) -> Result<(), WsError> {
        let socket = socket2::SockRef::from(tcp);

        self.tcp_keepalive
            .map(|idle| socket.set_tcp_keepalive(&socket2::TcpKeepalive::new().with_time(idle)))
            .transpose()?;
        self.recv_buf_size
            .map(|bytes| socket.set_recv_buffer_size(bytes))
            .transpose()?;
        self.send_buf_size
            .map(|bytes| socket.set_send_buffer_size(bytes))
            .transpose()?;

        Ok(())
    }
}
569
570impl Default for WsStreamBuilder {
571    fn default() -> Self {
572        Self::new()
573    }
574}
575
576// =============================================================================
577// Stream + Sink (ecosystem compat — allocates per message)
578// =============================================================================
579
580use std::task::{Context, Poll};
581
582use futures_core::Stream;
583use futures_sink::Sink;
584use nexus_net::ws::OwnedMessage;
585
586impl<S: WireStream + Unpin> Stream for WsStream<S> {
587    type Item = Result<OwnedMessage, WsError>;
588
589    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
590        let this = self.get_mut();
591
592        loop {
593            match this.reader.poll() {
594                Ok(true) => {
595                    return match this.reader.next() {
596                        Ok(Some(msg)) => Poll::Ready(Some(Ok(msg.into_owned()))),
597                        Ok(None) => Poll::Ready(None),
598                        Err(e) => Poll::Ready(Some(Err(e.into()))),
599                    };
600                }
601                Ok(false) => {}
602                Err(e) => return Poll::Ready(Some(Err(e.into()))),
603            }
604
605            if this.reader.should_compact() {
606                this.reader.compact();
607            }
608            if this.reader.spare().is_empty() {
609                this.reader.compact();
610                if this.reader.spare().is_empty() {
611                    return Poll::Ready(Some(Err(std::io::Error::other(
612                        "websocket read buffer full",
613                    )
614                    .into())));
615                }
616            }
617
618            match Pin::new(&mut this.stream).poll_fill_into(
619                cx,
620                &mut this.reader,
621                this.max_read_size,
622            ) {
623                Poll::Ready(Ok(0)) => return Poll::Ready(None),
624                Poll::Ready(Ok(_)) => {}
625                Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
626                Poll::Pending => return Poll::Pending,
627            }
628        }
629    }
630}
631
632impl<S: WireStream + Unpin> Sink<OwnedMessage> for WsStream<S> {
633    type Error = WsError;
634
635    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
636        Poll::Ready(Ok(()))
637    }
638
639    fn start_send(self: Pin<&mut Self>, item: OwnedMessage) -> Result<(), Self::Error> {
640        let this = self.get_mut();
641        match &item {
642            OwnedMessage::Text(s) => {
643                this.writer
644                    .encode_text_into(s.as_bytes(), &mut this.write_buf);
645            }
646            OwnedMessage::Binary(b) => {
647                this.writer.encode_binary_into(b, &mut this.write_buf);
648            }
649            OwnedMessage::Ping(b) => {
650                this.writer
651                    .encode_ping_into(b, &mut this.write_buf)
652                    .map_err(WsError::Encode)?;
653            }
654            OwnedMessage::Pong(b) => {
655                this.writer
656                    .encode_pong_into(b, &mut this.write_buf)
657                    .map_err(WsError::Encode)?;
658            }
659            OwnedMessage::Close(cf) => {
660                if cf.code == CloseCode::NoStatus {
661                    let mut dst = [0u8; 14];
662                    let n = this.writer.encode_empty_close(&mut dst);
663                    this.write_buf.clear();
664                    this.write_buf.append(&dst[..n]);
665                } else {
666                    this.writer
667                        .encode_close_into(
668                            cf.code.as_u16(),
669                            cf.reason.as_bytes(),
670                            &mut this.write_buf,
671                        )
672                        .map_err(WsError::Encode)?;
673                }
674            }
675        }
676        Ok(())
677    }
678
679    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
680        let this = self.get_mut();
681        while !this.write_buf.is_empty() {
682            let data = this.write_buf.data();
683            match Pin::new(&mut this.stream).poll_write(cx, data) {
684                Poll::Ready(Ok(0)) => {
685                    return Poll::Ready(Err(WsError::Io(std::io::Error::new(
686                        std::io::ErrorKind::WriteZero,
687                        "write returned 0",
688                    ))));
689                }
690                Poll::Ready(Ok(n)) => {
691                    this.write_buf.advance(n);
692                }
693                Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())),
694                Poll::Pending => return Poll::Pending,
695            }
696        }
697        Pin::new(&mut this.stream)
698            .poll_flush(cx)
699            .map_err(WsError::Io)
700    }
701
702    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
703        match <Self as Sink<OwnedMessage>>::poll_flush(self.as_mut(), cx) {
704            Poll::Pending => return Poll::Pending,
705            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
706            Poll::Ready(Ok(())) => {}
707        }
708        let this = self.get_mut();
709        Pin::new(&mut this.stream)
710            .poll_shutdown(cx)
711            .map_err(WsError::Io)
712    }
713}
714
715// =============================================================================
716// Tests
717// =============================================================================
718
719#[cfg(test)]
720mod tests {
721    use super::*;
722    use crate::AsyncReadAdapter;
723    use nexus_net::ws::Message;
724    use std::io::Cursor;
725    use std::pin::Pin;
726    use std::task::{Context, Poll};
727    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
728
729    struct MockStream(Cursor<Vec<u8>>);
730
731    impl AsyncRead for MockStream {
732        fn poll_read(
733            mut self: Pin<&mut Self>,
734            _cx: &mut Context<'_>,
735            buf: &mut ReadBuf<'_>,
736        ) -> Poll<std::io::Result<()>> {
737            let n = std::io::Read::read(&mut self.0, buf.initialize_unfilled())?;
738            buf.advance(n);
739            Poll::Ready(Ok(()))
740        }
741    }
742
743    impl AsyncWrite for MockStream {
744        fn poll_write(
745            self: Pin<&mut Self>,
746            _cx: &mut Context<'_>,
747            buf: &[u8],
748        ) -> Poll<std::io::Result<usize>> {
749            Poll::Ready(Ok(buf.len()))
750        }
751        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
752            Poll::Ready(Ok(()))
753        }
754        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
755            Poll::Ready(Ok(()))
756        }
757    }
758
759    fn make_frame(fin: bool, opcode: u8, payload: &[u8]) -> Vec<u8> {
760        let mut frame = Vec::new();
761        let byte0 = if fin { 0x80 } else { 0x00 } | opcode;
762        frame.push(byte0);
763        if payload.len() <= 125 {
764            frame.push(payload.len() as u8);
765        } else if payload.len() <= 65535 {
766            frame.push(126);
767            frame.extend_from_slice(&(payload.len() as u16).to_be_bytes());
768        } else {
769            frame.push(127);
770            frame.extend_from_slice(&(payload.len() as u64).to_be_bytes());
771        }
772        frame.extend_from_slice(payload);
773        frame
774    }
775
776    fn parts_from_bytes(data: Vec<u8>) -> (WsReader, WsWriter, AsyncReadAdapter<MockStream>) {
777        let mock = AsyncReadAdapter::new(MockStream(Cursor::new(data)));
778        let reader = FrameReader::builder().role(Role::Client).build();
779        let writer = FrameWriter::new(Role::Client);
780        let ws = WsStream::from_raw_parts(mock, reader, writer);
781        ws.into_parts()
782    }
783
784    // -- Primary API tests (WsReader / WsWriter) -----------------------------
785
786    #[tokio::test]
787    async fn recv_text() {
788        let frame = make_frame(true, 0x1, b"Hello");
789        let (mut reader, _writer, mut conn) = parts_from_bytes(frame);
790        match reader.recv(&mut conn).await.unwrap().unwrap() {
791            Message::Text(s) => assert_eq!(s, "Hello"),
792            other => panic!("expected Text, got {other:?}"),
793        }
794    }
795
796    #[tokio::test]
797    async fn recv_binary() {
798        let frame = make_frame(true, 0x2, &[0x42; 100]);
799        let (mut reader, _writer, mut conn) = parts_from_bytes(frame);
800        match reader.recv(&mut conn).await.unwrap().unwrap() {
801            Message::Binary(b) => assert_eq!(b.len(), 100),
802            other => panic!("expected Binary, got {other:?}"),
803        }
804    }
805
806    #[tokio::test]
807    async fn recv_ping() {
808        let frame = make_frame(true, 0x9, b"ping");
809        let (mut reader, _writer, mut conn) = parts_from_bytes(frame);
810        match reader.recv(&mut conn).await.unwrap().unwrap() {
811            Message::Ping(p) => assert_eq!(p, b"ping"),
812            other => panic!("expected Ping, got {other:?}"),
813        }
814    }
815
816    #[tokio::test]
817    async fn recv_fragmented_text() {
818        let mut data = make_frame(false, 0x1, b"Hel");
819        data.extend_from_slice(&make_frame(true, 0x0, b"lo"));
820        let (mut reader, _writer, mut conn) = parts_from_bytes(data);
821        match reader.recv(&mut conn).await.unwrap().unwrap() {
822            Message::Text(s) => assert_eq!(s, "Hello"),
823            other => panic!("expected Text, got {other:?}"),
824        }
825    }
826
827    #[tokio::test]
828    async fn recv_fragment_with_control() {
829        let mut data = make_frame(false, 0x1, b"Hel");
830        data.extend_from_slice(&make_frame(true, 0x9, b"ping"));
831        data.extend_from_slice(&make_frame(true, 0x0, b"lo"));
832        let (mut reader, _writer, mut conn) = parts_from_bytes(data);
833        match reader.recv(&mut conn).await.unwrap().unwrap() {
834            Message::Ping(p) => assert_eq!(p, b"ping"),
835            other => panic!("expected Ping, got {other:?}"),
836        }
837        match reader.recv(&mut conn).await.unwrap().unwrap() {
838            Message::Text(s) => assert_eq!(s, "Hello"),
839            other => panic!("expected Text, got {other:?}"),
840        }
841    }
842
843    #[tokio::test]
844    async fn recv_close() {
845        let mut payload = vec![];
846        payload.extend_from_slice(&1000u16.to_be_bytes());
847        payload.extend_from_slice(b"bye");
848        let frame = make_frame(true, 0x8, &payload);
849        let (mut reader, _writer, mut conn) = parts_from_bytes(frame);
850        match reader.recv(&mut conn).await.unwrap().unwrap() {
851            Message::Close(cf) => {
852                assert_eq!(cf.code, CloseCode::Normal);
853                assert_eq!(cf.reason, "bye");
854            }
855            other => panic!("expected Close, got {other:?}"),
856        }
857    }
858
859    #[tokio::test]
860    async fn eof_returns_none() {
861        let (mut reader, _writer, mut conn) = parts_from_bytes(Vec::new());
862        assert!(reader.recv(&mut conn).await.unwrap().is_none());
863    }
864
865    #[tokio::test]
866    async fn fifo_three_messages() {
867        let mut data = make_frame(true, 0x1, b"first");
868        data.extend_from_slice(&make_frame(true, 0x1, b"second"));
869        data.extend_from_slice(&make_frame(true, 0x1, b"third"));
870        let (mut reader, _writer, mut conn) = parts_from_bytes(data);
871
872        match reader.recv(&mut conn).await.unwrap().unwrap() {
873            Message::Text(s) => assert_eq!(s, "first"),
874            other => panic!("expected first, got {other:?}"),
875        }
876        match reader.recv(&mut conn).await.unwrap().unwrap() {
877            Message::Text(s) => assert_eq!(s, "second"),
878            other => panic!("expected second, got {other:?}"),
879        }
880        match reader.recv(&mut conn).await.unwrap().unwrap() {
881            Message::Text(s) => assert_eq!(s, "third"),
882            other => panic!("expected third, got {other:?}"),
883        }
884    }
885
886    #[tokio::test]
887    async fn ping_echo_split_borrow() {
888        let mut data = make_frame(true, 0x9, b"ping-data");
889        data.extend_from_slice(&make_frame(true, 0x1, b"hello"));
890        let (mut reader, mut writer, mut conn) = parts_from_bytes(data);
891
892        match reader.recv(&mut conn).await.unwrap().unwrap() {
893            Message::Ping(payload) => {
894                writer.send_pong(&mut conn, payload).await.unwrap();
895            }
896            other => panic!("expected Ping, got {other:?}"),
897        }
898
899        match reader.recv(&mut conn).await.unwrap().unwrap() {
900            Message::Text(s) => assert_eq!(s, "hello"),
901            other => panic!("expected Text, got {other:?}"),
902        }
903    }
904
905    #[tokio::test]
906    async fn text_response_while_holding_message() {
907        let data = make_frame(true, 0x1, b"request");
908        let (mut reader, mut writer, mut conn) = parts_from_bytes(data);
909
910        match reader.recv(&mut conn).await.unwrap().unwrap() {
911            Message::Text(req) => {
912                assert_eq!(req, "request");
913                let response = format!("echo: {req}");
914                writer.send_text(&mut conn, &response).await.unwrap();
915            }
916            other => panic!("expected Text, got {other:?}"),
917        }
918    }
919
920    // -- Stream/Sink tests (WsStream) ----------------------------------------
921
922    #[tokio::test]
923    async fn stream_yields_owned_messages() {
924        use std::pin::pin;
925
926        let mut data = make_frame(true, 0x1, b"hello");
927        data.extend_from_slice(&make_frame(true, 0x2, &[0x42]));
928        let (reader, writer, conn) = parts_from_bytes(data);
929        let ws = WsStream::from_parts(reader, writer, conn);
930        let mut ws = pin!(ws);
931
932        let poll_result = futures_core::Stream::poll_next(ws.as_mut(), &mut noop_cx());
933        match poll_result {
934            Poll::Ready(Some(Ok(OwnedMessage::Text(s)))) => assert_eq!(s, "hello"),
935            other => panic!("expected Text, got {other:?}"),
936        }
937        let poll_result = futures_core::Stream::poll_next(ws.as_mut(), &mut noop_cx());
938        match poll_result {
939            Poll::Ready(Some(Ok(OwnedMessage::Binary(b)))) => assert_eq!(b, vec![0x42]),
940            other => panic!("expected Binary, got {other:?}"),
941        }
942        let poll_result = futures_core::Stream::poll_next(ws.as_mut(), &mut noop_cx());
943        assert!(matches!(poll_result, Poll::Ready(None)));
944    }
945
946    fn noop_cx() -> Context<'static> {
947        use std::task::{RawWaker, RawWakerVTable, Waker};
948        const VTABLE: RawWakerVTable =
949            RawWakerVTable::new(|p| RawWaker::new(p, &VTABLE), |_| {}, |_| {}, |_| {});
950        // SAFETY: The vtable functions (clone/wake/wake_by_ref/drop) are all no-ops
951        // that never dereference the data pointer, so the null data pointer is sound.
952        // The vtable is 'static (const) and correctly returns a valid RawWaker on clone.
953        let waker = unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) };
954        let waker = Box::leak(Box::new(waker));
955        Context::from_waker(waker)
956    }
957
958    #[tokio::test]
959    async fn accept_server_side() {
960        use tokio::net::TcpListener;
961
962        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
963        let addr = listener.local_addr().unwrap();
964
965        let server = tokio::spawn(async move {
966            let (tcp, _) = listener.accept().await.unwrap();
967            let (mut reader, mut writer, mut conn) = WsStreamBuilder::new()
968                .accept(AsyncReadAdapter::new(tcp))
969                .await
970                .unwrap();
971            match reader.recv(&mut conn).await.unwrap().unwrap() {
972                Message::Text(s) => assert_eq!(s, "hello from client"),
973                other => panic!("expected Text, got {other:?}"),
974            }
975            writer
976                .send_text(&mut conn, "hello from server")
977                .await
978                .unwrap();
979        });
980
981        let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
982        let url = format!("ws://127.0.0.1:{}/ws", addr.port());
983        let (mut reader, mut writer, mut conn) = WsStreamBuilder::new()
984            .connect_with(AsyncReadAdapter::new(tcp), &url)
985            .await
986            .unwrap();
987
988        writer
989            .send_text(&mut conn, "hello from client")
990            .await
991            .unwrap();
992
993        match reader.recv(&mut conn).await.unwrap().unwrap() {
994            Message::Text(s) => assert_eq!(s, "hello from server"),
995            other => panic!("expected Text, got {other:?}"),
996        }
997
998        server.await.unwrap();
999    }
1000
1001    struct BrokenWriteStream(Cursor<Vec<u8>>);
1002
1003    impl AsyncRead for BrokenWriteStream {
1004        fn poll_read(
1005            mut self: Pin<&mut Self>,
1006            _cx: &mut Context<'_>,
1007            buf: &mut ReadBuf<'_>,
1008        ) -> Poll<std::io::Result<()>> {
1009            let n = std::io::Read::read(&mut self.0, buf.initialize_unfilled())?;
1010            buf.advance(n);
1011            Poll::Ready(Ok(()))
1012        }
1013    }
1014
1015    impl AsyncWrite for BrokenWriteStream {
1016        fn poll_write(
1017            self: Pin<&mut Self>,
1018            _cx: &mut Context<'_>,
1019            _buf: &[u8],
1020        ) -> Poll<std::io::Result<usize>> {
1021            Poll::Ready(Err(std::io::Error::new(
1022                std::io::ErrorKind::BrokenPipe,
1023                "connection lost",
1024            )))
1025        }
1026        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1027            Poll::Ready(Ok(()))
1028        }
1029        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1030            Poll::Ready(Ok(()))
1031        }
1032    }
1033
1034    #[tokio::test]
1035    async fn send_on_broken_stream_returns_error() {
1036        let mock = AsyncReadAdapter::new(BrokenWriteStream(Cursor::new(Vec::new())));
1037        let reader = FrameReader::builder().role(Role::Client).build();
1038        let writer = FrameWriter::new(Role::Client);
1039        let (_, mut ws_writer, mut conn) =
1040            WsStream::from_raw_parts(mock, reader, writer).into_parts();
1041
1042        let result = ws_writer.send_text(&mut conn, "hello").await;
1043        assert!(result.is_err(), "send on broken stream should return error");
1044
1045        let result = ws_writer.send_binary(&mut conn, &[1, 2, 3]).await;
1046        assert!(result.is_err(), "subsequent send should also fail");
1047    }
1048
1049    #[tokio::test]
1050    async fn from_parts_roundtrip() {
1051        let data = make_frame(true, 0x1, b"test");
1052        let (reader, writer, conn) = parts_from_bytes(data);
1053        let ws = WsStream::from_parts(reader, writer, conn);
1054        let (mut reader, _writer, mut conn) = ws.into_parts();
1055
1056        match reader.recv(&mut conn).await.unwrap().unwrap() {
1057            Message::Text(s) => assert_eq!(s, "test"),
1058            other => panic!("expected Text, got {other:?}"),
1059        }
1060    }
1061}