1use std::pin::Pin;
7
8use nexus_net::WireStream;
9use nexus_net::buf::WriteBuf;
10use nexus_net::http::HTTP_HANDSHAKE_BUFFER;
11#[cfg(feature = "tls")]
12use nexus_net::tls::TlsConfig;
13use nexus_net::ws::{
14 CloseCode, Error as WsError, FrameReader, FrameReaderBuilder, FrameWriter, HandshakeError,
15 Role, parse_ws_url,
16};
17use tokio::net::TcpStream;
18
19use crate::maybe_tls::MaybeTls;
20use crate::ws::parts::{WsReader, WsWriter, fill_async, write_all_async};
21
22async fn connect_handshake<S: WireStream + Unpin>(
27 mut stream: S,
28 url: &str,
29 reader_builder: FrameReaderBuilder,
30 write_cap: usize,
31 max_read_size: usize,
32) -> Result<(WsReader, WsWriter, S), WsError> {
33 let parsed = parse_ws_url(url)?;
34 let host_header = parsed.host_header();
35
36 let key = nexus_net::ws::handshake::generate_key();
37 let key_str =
38 std::str::from_utf8(&key).expect("base64-encoded key is always valid ASCII/UTF-8");
39
40 let headers: [(&str, &str); 5] = [
41 ("Host", &host_header),
42 ("Upgrade", "websocket"),
43 ("Connection", "Upgrade"),
44 ("Sec-WebSocket-Key", key_str),
45 ("Sec-WebSocket-Version", "13"),
46 ];
47 let req_size = nexus_net::http::request_size("GET", parsed.path, &headers);
48 let mut req_buf = vec![0u8; req_size];
49 let n = nexus_net::http::write_request("GET", parsed.path, &headers, &mut req_buf)
50 .map_err(|_| HandshakeError::MalformedHttp)?;
51
52 write_all_async(&mut stream, &req_buf[..n]).await?;
53
54 let mut resp_reader = nexus_net::http::ResponseReader::new(HTTP_HANDSHAKE_BUFFER);
55 loop {
56 if resp_reader.spare().is_empty() {
57 return Err(HandshakeError::MalformedHttp.into());
58 }
59 let n = fill_async(&mut stream, &mut resp_reader, HTTP_HANDSHAKE_BUFFER).await?;
60 if n == 0 {
61 return Err(HandshakeError::MalformedHttp.into());
62 }
63 match resp_reader.next() {
64 Ok(Some(resp)) => {
65 if resp.status != 101 {
66 return Err(HandshakeError::UnexpectedStatus(resp.status).into());
67 }
68 let upgrade = resp
69 .header("Upgrade")
70 .ok_or(HandshakeError::MissingUpgrade)?;
71 if !upgrade.eq_ignore_ascii_case("websocket") {
72 return Err(HandshakeError::MissingUpgrade.into());
73 }
74 let conn = resp
75 .header("Connection")
76 .ok_or(HandshakeError::MissingConnection)?;
77 if !conn
78 .as_bytes()
79 .windows(7)
80 .any(|w| w.eq_ignore_ascii_case(b"upgrade"))
81 {
82 return Err(HandshakeError::MissingConnection.into());
83 }
84 let accept = resp
85 .header("Sec-WebSocket-Accept")
86 .ok_or(HandshakeError::InvalidAcceptKey)?;
87 if !nexus_net::ws::handshake::validate_accept(key_str, accept) {
88 return Err(HandshakeError::InvalidAcceptKey.into());
89 }
90
91 let mut reader = reader_builder.role(Role::Client).build();
92 let remainder = resp_reader.remainder();
93 if !remainder.is_empty() {
94 reader
95 .read(remainder)
96 .map_err(|_| HandshakeError::MalformedHttp)?;
97 }
98
99 return Ok((
100 WsReader {
101 reader,
102 max_read_size,
103 },
104 WsWriter {
105 writer: FrameWriter::new(Role::Client),
106 write_buf: WriteBuf::new(write_cap, 14),
107 },
108 stream,
109 ));
110 }
111 Ok(None) => {}
112 Err(_) => return Err(HandshakeError::MalformedHttp.into()),
113 }
114 }
115}
116
117async fn accept_handshake<S: WireStream + Unpin>(
118 mut stream: S,
119 reader_builder: FrameReaderBuilder,
120 write_cap: usize,
121 max_read_size: usize,
122) -> Result<(WsReader, WsWriter, S), WsError> {
123 let mut req_reader = nexus_net::http::RequestReader::new(HTTP_HANDSHAKE_BUFFER);
124
125 let ws_key;
126 loop {
127 if req_reader.spare().is_empty() {
128 return Err(HandshakeError::MalformedHttp.into());
129 }
130 let n = fill_async(&mut stream, &mut req_reader, HTTP_HANDSHAKE_BUFFER).await?;
131 if n == 0 {
132 return Err(HandshakeError::MalformedHttp.into());
133 }
134 match req_reader.next() {
135 Ok(Some(req)) => {
136 if req.method != "GET" {
137 return Err(HandshakeError::MalformedHttp.into());
138 }
139 let upgrade = req
140 .header("Upgrade")
141 .ok_or(HandshakeError::MissingUpgrade)?;
142 if !upgrade.eq_ignore_ascii_case("websocket") {
143 return Err(HandshakeError::MissingUpgrade.into());
144 }
145 let conn = req
146 .header("Connection")
147 .ok_or(HandshakeError::MissingConnection)?;
148 if !conn
149 .as_bytes()
150 .windows(7)
151 .any(|w| w.eq_ignore_ascii_case(b"upgrade"))
152 {
153 return Err(HandshakeError::MissingConnection.into());
154 }
155 let version = req
156 .header("Sec-WebSocket-Version")
157 .ok_or(HandshakeError::UnsupportedVersion)?;
158 if version != "13" {
159 return Err(HandshakeError::UnsupportedVersion.into());
160 }
161 let key = req
162 .header("Sec-WebSocket-Key")
163 .ok_or(HandshakeError::MissingKey)?;
164 ws_key = key.to_owned();
165 break;
166 }
167 Ok(None) => {}
168 Err(_) => return Err(HandshakeError::MalformedHttp.into()),
169 }
170 }
171
172 let accept = nexus_net::ws::handshake::compute_accept_key(&ws_key);
173 let accept_str = std::str::from_utf8(&accept).expect("base64 output is valid ASCII");
174
175 let resp_headers = [
176 ("Upgrade", "websocket"),
177 ("Connection", "Upgrade"),
178 ("Sec-WebSocket-Accept", accept_str),
179 ];
180 let resp_size = nexus_net::http::response_size("Switching Protocols", &resp_headers);
181 let mut resp_buf = vec![0u8; resp_size];
182 let n =
183 nexus_net::http::write_response(101, "Switching Protocols", &resp_headers, &mut resp_buf)
184 .map_err(|_| HandshakeError::MalformedHttp)?;
185 write_all_async(&mut stream, &resp_buf[..n]).await?;
186
187 let mut reader = reader_builder.role(Role::Server).build();
188 let remainder = req_reader.remainder();
189 if !remainder.is_empty() {
190 reader
191 .read(remainder)
192 .map_err(|_| HandshakeError::MalformedHttp)?;
193 }
194
195 Ok((
196 WsReader {
197 reader,
198 max_read_size,
199 },
200 WsWriter {
201 writer: FrameWriter::new(Role::Server),
202 write_buf: WriteBuf::new(write_cap, 14),
203 },
204 stream,
205 ))
206}
207
208pub struct WsStream<S> {
231 stream: S,
232 reader: FrameReader,
233 writer: FrameWriter,
234 write_buf: WriteBuf,
235 max_read_size: usize,
236}
237
238impl<S> WsStream<S> {
239 pub fn from_parts(reader: WsReader, writer: WsWriter, stream: S) -> Self {
243 Self {
244 stream,
245 max_read_size: reader.max_read_size,
246 reader: reader.reader,
247 writer: writer.writer,
248 write_buf: writer.write_buf,
249 }
250 }
251
252 pub fn into_parts(self) -> (WsReader, WsWriter, S) {
254 (
255 WsReader {
256 reader: self.reader,
257 max_read_size: self.max_read_size,
258 },
259 WsWriter {
260 writer: self.writer,
261 write_buf: self.write_buf,
262 },
263 self.stream,
264 )
265 }
266
267 pub fn from_raw_parts(stream: S, reader: FrameReader, writer: FrameWriter) -> Self {
272 Self {
273 stream,
274 reader,
275 writer,
276 write_buf: WriteBuf::new(65_536, 14),
277 max_read_size: usize::MAX,
278 }
279 }
280}
281
282pub struct WsStreamBuilder {
299 reader_builder: FrameReaderBuilder,
300 write_buf_capacity: usize,
301 buffer_capacity: usize,
302 max_read_size: Option<usize>,
303 #[cfg(feature = "tls")]
304 tls_config: Option<TlsConfig>,
305 nodelay: bool,
306 connect_timeout: Option<std::time::Duration>,
307 #[cfg(feature = "socket-opts")]
308 tcp_keepalive: Option<std::time::Duration>,
309 #[cfg(feature = "socket-opts")]
310 recv_buf_size: Option<usize>,
311 #[cfg(feature = "socket-opts")]
312 send_buf_size: Option<usize>,
313}
314
315const DEFAULT_BUFFER_CAPACITY: usize = 1024 * 1024;
316
317impl WsStreamBuilder {
318 #[must_use]
320 pub fn new() -> Self {
321 Self {
322 reader_builder: FrameReader::builder(),
323 write_buf_capacity: 65_536,
324 buffer_capacity: DEFAULT_BUFFER_CAPACITY,
325 max_read_size: None,
326 #[cfg(feature = "tls")]
327 tls_config: None,
328 nodelay: false,
329 connect_timeout: None,
330 #[cfg(feature = "socket-opts")]
331 tcp_keepalive: None,
332 #[cfg(feature = "socket-opts")]
333 recv_buf_size: None,
334 #[cfg(feature = "socket-opts")]
335 send_buf_size: None,
336 }
337 }
338
339 fn resolved_max_read_size(&self) -> usize {
340 self.max_read_size.map_or_else(
341 || (self.buffer_capacity / 8).max(1),
342 |n| n.min(self.buffer_capacity).max(1),
343 )
344 }
345
346 #[must_use]
348 pub fn buffer_capacity(mut self, n: usize) -> Self {
349 self.buffer_capacity = n;
350 self.reader_builder = self.reader_builder.buffer_capacity(n);
351 self
352 }
353
354 #[must_use]
367 pub fn max_read_size(mut self, n: usize) -> Self {
368 self.max_read_size = Some(n);
369 self
370 }
371
372 #[must_use]
380 pub fn compact_at(mut self, fraction: f64) -> Self {
381 self.reader_builder = self.reader_builder.compact_at(fraction);
382 self
383 }
384
385 #[must_use]
387 pub fn max_frame_size(mut self, n: u64) -> Self {
388 self.reader_builder = self.reader_builder.max_frame_size(n);
389 self
390 }
391
392 #[must_use]
394 pub fn max_message_size(mut self, n: usize) -> Self {
395 self.reader_builder = self.reader_builder.max_message_size(n);
396 self
397 }
398
399 #[must_use]
401 pub fn write_buffer_capacity(mut self, n: usize) -> Self {
402 self.write_buf_capacity = n;
403 self
404 }
405
406 #[cfg(feature = "tls")]
408 #[must_use]
409 pub fn tls(mut self, config: &TlsConfig) -> Self {
410 self.tls_config = Some(config.clone());
411 self
412 }
413
414 #[must_use]
416 pub fn disable_nagle(mut self) -> Self {
417 self.nodelay = true;
418 self
419 }
420
421 #[must_use]
423 pub fn connect_timeout(mut self, d: std::time::Duration) -> Self {
424 self.connect_timeout = Some(d);
425 self
426 }
427
428 #[cfg(feature = "socket-opts")]
433 #[must_use]
434 pub fn tcp_keepalive(mut self, idle: std::time::Duration) -> Self {
435 self.tcp_keepalive = Some(idle);
436 self
437 }
438
439 #[cfg(feature = "socket-opts")]
441 #[must_use]
442 pub fn recv_buffer_size(mut self, n: usize) -> Self {
443 self.recv_buf_size = Some(n);
444 self
445 }
446
447 #[cfg(feature = "socket-opts")]
449 #[must_use]
450 pub fn send_buffer_size(mut self, n: usize) -> Self {
451 self.send_buf_size = Some(n);
452 self
453 }
454
455 pub async fn connect(self, url: &str) -> Result<(WsReader, WsWriter, MaybeTls), WsError> {
457 let parsed = parse_ws_url(url)?;
458 let addr = format!("{}:{}", parsed.host, parsed.port);
459
460 let tcp = match self.connect_timeout {
461 Some(timeout) => tokio::time::timeout(timeout, TcpStream::connect(&addr))
462 .await
463 .map_err(|_| {
464 WsError::Io(std::io::Error::new(
465 std::io::ErrorKind::TimedOut,
466 "connect timeout",
467 ))
468 })??,
469 None => TcpStream::connect(&addr).await?,
470 };
471 if self.nodelay {
472 tcp.set_nodelay(true)?;
473 }
474 #[cfg(feature = "socket-opts")]
475 self.apply_socket_opts(&tcp)?;
476
477 let stream = if parsed.tls {
478 #[cfg(feature = "tls")]
479 {
480 let tls_config = match &self.tls_config {
481 Some(c) => c.clone(),
482 None => TlsConfig::new().map_err(WsError::Tls)?,
483 };
484
485 let connector =
486 tokio_rustls::TlsConnector::from(tls_config.client_config().clone());
487 let server_name =
488 tokio_rustls::rustls::pki_types::ServerName::try_from(parsed.host.to_owned())
489 .map_err(|_| {
490 WsError::Tls(nexus_net::tls::TlsError::InvalidHostname(
491 parsed.host.to_string(),
492 ))
493 })?;
494 let tls_stream = connector
495 .connect(server_name, tcp)
496 .await
497 .map_err(|e| WsError::Tls(nexus_net::tls::TlsError::Io(e)))?;
498 MaybeTls::Tls(Box::new(tls_stream))
499 }
500 #[cfg(not(feature = "tls"))]
501 {
502 return Err(WsError::TlsNotEnabled);
503 }
504 } else {
505 MaybeTls::Plain(tcp)
506 };
507
508 let max_read_size = self.resolved_max_read_size();
509 connect_handshake(
510 stream,
511 url,
512 self.reader_builder,
513 self.write_buf_capacity,
514 max_read_size,
515 )
516 .await
517 }
518
519 pub async fn connect_with<S: WireStream + Unpin>(
521 self,
522 stream: S,
523 url: &str,
524 ) -> Result<(WsReader, WsWriter, S), WsError> {
525 let max_read_size = self.resolved_max_read_size();
526 connect_handshake(
527 stream,
528 url,
529 self.reader_builder,
530 self.write_buf_capacity,
531 max_read_size,
532 )
533 .await
534 }
535
536 pub async fn accept<S: WireStream + Unpin>(
538 self,
539 stream: S,
540 ) -> Result<(WsReader, WsWriter, S), WsError> {
541 let max_read_size = self.resolved_max_read_size();
542 accept_handshake(
543 stream,
544 self.reader_builder,
545 self.write_buf_capacity,
546 max_read_size,
547 )
548 .await
549 }
550}
551
552#[cfg(feature = "socket-opts")]
553impl WsStreamBuilder {
554 fn apply_socket_opts(&self, tcp: &TcpStream) -> Result<(), WsError> {
555 let sock = socket2::SockRef::from(tcp);
556 if let Some(idle) = self.tcp_keepalive {
557 let keepalive = socket2::TcpKeepalive::new().with_time(idle);
558 sock.set_tcp_keepalive(&keepalive)?;
559 }
560 if let Some(size) = self.recv_buf_size {
561 sock.set_recv_buffer_size(size)?;
562 }
563 if let Some(size) = self.send_buf_size {
564 sock.set_send_buffer_size(size)?;
565 }
566 Ok(())
567 }
568}
569
570impl Default for WsStreamBuilder {
571 fn default() -> Self {
572 Self::new()
573 }
574}
575
576use std::task::{Context, Poll};
581
582use futures_core::Stream;
583use futures_sink::Sink;
584use nexus_net::ws::OwnedMessage;
585
586impl<S: WireStream + Unpin> Stream for WsStream<S> {
587 type Item = Result<OwnedMessage, WsError>;
588
589 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
590 let this = self.get_mut();
591
592 loop {
593 match this.reader.poll() {
594 Ok(true) => {
595 return match this.reader.next() {
596 Ok(Some(msg)) => Poll::Ready(Some(Ok(msg.into_owned()))),
597 Ok(None) => Poll::Ready(None),
598 Err(e) => Poll::Ready(Some(Err(e.into()))),
599 };
600 }
601 Ok(false) => {}
602 Err(e) => return Poll::Ready(Some(Err(e.into()))),
603 }
604
605 if this.reader.should_compact() {
606 this.reader.compact();
607 }
608 if this.reader.spare().is_empty() {
609 this.reader.compact();
610 if this.reader.spare().is_empty() {
611 return Poll::Ready(Some(Err(std::io::Error::other(
612 "websocket read buffer full",
613 )
614 .into())));
615 }
616 }
617
618 match Pin::new(&mut this.stream).poll_fill_into(
619 cx,
620 &mut this.reader,
621 this.max_read_size,
622 ) {
623 Poll::Ready(Ok(0)) => return Poll::Ready(None),
624 Poll::Ready(Ok(_)) => {}
625 Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
626 Poll::Pending => return Poll::Pending,
627 }
628 }
629 }
630}
631
632impl<S: WireStream + Unpin> Sink<OwnedMessage> for WsStream<S> {
633 type Error = WsError;
634
635 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
636 Poll::Ready(Ok(()))
637 }
638
639 fn start_send(self: Pin<&mut Self>, item: OwnedMessage) -> Result<(), Self::Error> {
640 let this = self.get_mut();
641 match &item {
642 OwnedMessage::Text(s) => {
643 this.writer
644 .encode_text_into(s.as_bytes(), &mut this.write_buf);
645 }
646 OwnedMessage::Binary(b) => {
647 this.writer.encode_binary_into(b, &mut this.write_buf);
648 }
649 OwnedMessage::Ping(b) => {
650 this.writer
651 .encode_ping_into(b, &mut this.write_buf)
652 .map_err(WsError::Encode)?;
653 }
654 OwnedMessage::Pong(b) => {
655 this.writer
656 .encode_pong_into(b, &mut this.write_buf)
657 .map_err(WsError::Encode)?;
658 }
659 OwnedMessage::Close(cf) => {
660 if cf.code == CloseCode::NoStatus {
661 let mut dst = [0u8; 14];
662 let n = this.writer.encode_empty_close(&mut dst);
663 this.write_buf.clear();
664 this.write_buf.append(&dst[..n]);
665 } else {
666 this.writer
667 .encode_close_into(
668 cf.code.as_u16(),
669 cf.reason.as_bytes(),
670 &mut this.write_buf,
671 )
672 .map_err(WsError::Encode)?;
673 }
674 }
675 }
676 Ok(())
677 }
678
679 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
680 let this = self.get_mut();
681 while !this.write_buf.is_empty() {
682 let data = this.write_buf.data();
683 match Pin::new(&mut this.stream).poll_write(cx, data) {
684 Poll::Ready(Ok(0)) => {
685 return Poll::Ready(Err(WsError::Io(std::io::Error::new(
686 std::io::ErrorKind::WriteZero,
687 "write returned 0",
688 ))));
689 }
690 Poll::Ready(Ok(n)) => {
691 this.write_buf.advance(n);
692 }
693 Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())),
694 Poll::Pending => return Poll::Pending,
695 }
696 }
697 Pin::new(&mut this.stream)
698 .poll_flush(cx)
699 .map_err(WsError::Io)
700 }
701
702 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
703 match <Self as Sink<OwnedMessage>>::poll_flush(self.as_mut(), cx) {
704 Poll::Pending => return Poll::Pending,
705 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
706 Poll::Ready(Ok(())) => {}
707 }
708 let this = self.get_mut();
709 Pin::new(&mut this.stream)
710 .poll_shutdown(cx)
711 .map_err(WsError::Io)
712 }
713}
714
715#[cfg(test)]
720mod tests {
721 use super::*;
722 use crate::AsyncReadAdapter;
723 use nexus_net::ws::Message;
724 use std::io::Cursor;
725 use std::pin::Pin;
726 use std::task::{Context, Poll};
727 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
728
729 struct MockStream(Cursor<Vec<u8>>);
730
731 impl AsyncRead for MockStream {
732 fn poll_read(
733 mut self: Pin<&mut Self>,
734 _cx: &mut Context<'_>,
735 buf: &mut ReadBuf<'_>,
736 ) -> Poll<std::io::Result<()>> {
737 let n = std::io::Read::read(&mut self.0, buf.initialize_unfilled())?;
738 buf.advance(n);
739 Poll::Ready(Ok(()))
740 }
741 }
742
743 impl AsyncWrite for MockStream {
744 fn poll_write(
745 self: Pin<&mut Self>,
746 _cx: &mut Context<'_>,
747 buf: &[u8],
748 ) -> Poll<std::io::Result<usize>> {
749 Poll::Ready(Ok(buf.len()))
750 }
751 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
752 Poll::Ready(Ok(()))
753 }
754 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
755 Poll::Ready(Ok(()))
756 }
757 }
758
759 fn make_frame(fin: bool, opcode: u8, payload: &[u8]) -> Vec<u8> {
760 let mut frame = Vec::new();
761 let byte0 = if fin { 0x80 } else { 0x00 } | opcode;
762 frame.push(byte0);
763 if payload.len() <= 125 {
764 frame.push(payload.len() as u8);
765 } else if payload.len() <= 65535 {
766 frame.push(126);
767 frame.extend_from_slice(&(payload.len() as u16).to_be_bytes());
768 } else {
769 frame.push(127);
770 frame.extend_from_slice(&(payload.len() as u64).to_be_bytes());
771 }
772 frame.extend_from_slice(payload);
773 frame
774 }
775
776 fn parts_from_bytes(data: Vec<u8>) -> (WsReader, WsWriter, AsyncReadAdapter<MockStream>) {
777 let mock = AsyncReadAdapter::new(MockStream(Cursor::new(data)));
778 let reader = FrameReader::builder().role(Role::Client).build();
779 let writer = FrameWriter::new(Role::Client);
780 let ws = WsStream::from_raw_parts(mock, reader, writer);
781 ws.into_parts()
782 }
783
784 #[tokio::test]
787 async fn recv_text() {
788 let frame = make_frame(true, 0x1, b"Hello");
789 let (mut reader, _writer, mut conn) = parts_from_bytes(frame);
790 match reader.recv(&mut conn).await.unwrap().unwrap() {
791 Message::Text(s) => assert_eq!(s, "Hello"),
792 other => panic!("expected Text, got {other:?}"),
793 }
794 }
795
796 #[tokio::test]
797 async fn recv_binary() {
798 let frame = make_frame(true, 0x2, &[0x42; 100]);
799 let (mut reader, _writer, mut conn) = parts_from_bytes(frame);
800 match reader.recv(&mut conn).await.unwrap().unwrap() {
801 Message::Binary(b) => assert_eq!(b.len(), 100),
802 other => panic!("expected Binary, got {other:?}"),
803 }
804 }
805
806 #[tokio::test]
807 async fn recv_ping() {
808 let frame = make_frame(true, 0x9, b"ping");
809 let (mut reader, _writer, mut conn) = parts_from_bytes(frame);
810 match reader.recv(&mut conn).await.unwrap().unwrap() {
811 Message::Ping(p) => assert_eq!(p, b"ping"),
812 other => panic!("expected Ping, got {other:?}"),
813 }
814 }
815
816 #[tokio::test]
817 async fn recv_fragmented_text() {
818 let mut data = make_frame(false, 0x1, b"Hel");
819 data.extend_from_slice(&make_frame(true, 0x0, b"lo"));
820 let (mut reader, _writer, mut conn) = parts_from_bytes(data);
821 match reader.recv(&mut conn).await.unwrap().unwrap() {
822 Message::Text(s) => assert_eq!(s, "Hello"),
823 other => panic!("expected Text, got {other:?}"),
824 }
825 }
826
827 #[tokio::test]
828 async fn recv_fragment_with_control() {
829 let mut data = make_frame(false, 0x1, b"Hel");
830 data.extend_from_slice(&make_frame(true, 0x9, b"ping"));
831 data.extend_from_slice(&make_frame(true, 0x0, b"lo"));
832 let (mut reader, _writer, mut conn) = parts_from_bytes(data);
833 match reader.recv(&mut conn).await.unwrap().unwrap() {
834 Message::Ping(p) => assert_eq!(p, b"ping"),
835 other => panic!("expected Ping, got {other:?}"),
836 }
837 match reader.recv(&mut conn).await.unwrap().unwrap() {
838 Message::Text(s) => assert_eq!(s, "Hello"),
839 other => panic!("expected Text, got {other:?}"),
840 }
841 }
842
843 #[tokio::test]
844 async fn recv_close() {
845 let mut payload = vec![];
846 payload.extend_from_slice(&1000u16.to_be_bytes());
847 payload.extend_from_slice(b"bye");
848 let frame = make_frame(true, 0x8, &payload);
849 let (mut reader, _writer, mut conn) = parts_from_bytes(frame);
850 match reader.recv(&mut conn).await.unwrap().unwrap() {
851 Message::Close(cf) => {
852 assert_eq!(cf.code, CloseCode::Normal);
853 assert_eq!(cf.reason, "bye");
854 }
855 other => panic!("expected Close, got {other:?}"),
856 }
857 }
858
859 #[tokio::test]
860 async fn eof_returns_none() {
861 let (mut reader, _writer, mut conn) = parts_from_bytes(Vec::new());
862 assert!(reader.recv(&mut conn).await.unwrap().is_none());
863 }
864
865 #[tokio::test]
866 async fn fifo_three_messages() {
867 let mut data = make_frame(true, 0x1, b"first");
868 data.extend_from_slice(&make_frame(true, 0x1, b"second"));
869 data.extend_from_slice(&make_frame(true, 0x1, b"third"));
870 let (mut reader, _writer, mut conn) = parts_from_bytes(data);
871
872 match reader.recv(&mut conn).await.unwrap().unwrap() {
873 Message::Text(s) => assert_eq!(s, "first"),
874 other => panic!("expected first, got {other:?}"),
875 }
876 match reader.recv(&mut conn).await.unwrap().unwrap() {
877 Message::Text(s) => assert_eq!(s, "second"),
878 other => panic!("expected second, got {other:?}"),
879 }
880 match reader.recv(&mut conn).await.unwrap().unwrap() {
881 Message::Text(s) => assert_eq!(s, "third"),
882 other => panic!("expected third, got {other:?}"),
883 }
884 }
885
886 #[tokio::test]
887 async fn ping_echo_split_borrow() {
888 let mut data = make_frame(true, 0x9, b"ping-data");
889 data.extend_from_slice(&make_frame(true, 0x1, b"hello"));
890 let (mut reader, mut writer, mut conn) = parts_from_bytes(data);
891
892 match reader.recv(&mut conn).await.unwrap().unwrap() {
893 Message::Ping(payload) => {
894 writer.send_pong(&mut conn, payload).await.unwrap();
895 }
896 other => panic!("expected Ping, got {other:?}"),
897 }
898
899 match reader.recv(&mut conn).await.unwrap().unwrap() {
900 Message::Text(s) => assert_eq!(s, "hello"),
901 other => panic!("expected Text, got {other:?}"),
902 }
903 }
904
905 #[tokio::test]
906 async fn text_response_while_holding_message() {
907 let data = make_frame(true, 0x1, b"request");
908 let (mut reader, mut writer, mut conn) = parts_from_bytes(data);
909
910 match reader.recv(&mut conn).await.unwrap().unwrap() {
911 Message::Text(req) => {
912 assert_eq!(req, "request");
913 let response = format!("echo: {req}");
914 writer.send_text(&mut conn, &response).await.unwrap();
915 }
916 other => panic!("expected Text, got {other:?}"),
917 }
918 }
919
920 #[tokio::test]
923 async fn stream_yields_owned_messages() {
924 use std::pin::pin;
925
926 let mut data = make_frame(true, 0x1, b"hello");
927 data.extend_from_slice(&make_frame(true, 0x2, &[0x42]));
928 let (reader, writer, conn) = parts_from_bytes(data);
929 let ws = WsStream::from_parts(reader, writer, conn);
930 let mut ws = pin!(ws);
931
932 let poll_result = futures_core::Stream::poll_next(ws.as_mut(), &mut noop_cx());
933 match poll_result {
934 Poll::Ready(Some(Ok(OwnedMessage::Text(s)))) => assert_eq!(s, "hello"),
935 other => panic!("expected Text, got {other:?}"),
936 }
937 let poll_result = futures_core::Stream::poll_next(ws.as_mut(), &mut noop_cx());
938 match poll_result {
939 Poll::Ready(Some(Ok(OwnedMessage::Binary(b)))) => assert_eq!(b, vec![0x42]),
940 other => panic!("expected Binary, got {other:?}"),
941 }
942 let poll_result = futures_core::Stream::poll_next(ws.as_mut(), &mut noop_cx());
943 assert!(matches!(poll_result, Poll::Ready(None)));
944 }
945
946 fn noop_cx() -> Context<'static> {
947 use std::task::{RawWaker, RawWakerVTable, Waker};
948 const VTABLE: RawWakerVTable =
949 RawWakerVTable::new(|p| RawWaker::new(p, &VTABLE), |_| {}, |_| {}, |_| {});
950 let waker = unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) };
954 let waker = Box::leak(Box::new(waker));
955 Context::from_waker(waker)
956 }
957
958 #[tokio::test]
959 async fn accept_server_side() {
960 use tokio::net::TcpListener;
961
962 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
963 let addr = listener.local_addr().unwrap();
964
965 let server = tokio::spawn(async move {
966 let (tcp, _) = listener.accept().await.unwrap();
967 let (mut reader, mut writer, mut conn) = WsStreamBuilder::new()
968 .accept(AsyncReadAdapter::new(tcp))
969 .await
970 .unwrap();
971 match reader.recv(&mut conn).await.unwrap().unwrap() {
972 Message::Text(s) => assert_eq!(s, "hello from client"),
973 other => panic!("expected Text, got {other:?}"),
974 }
975 writer
976 .send_text(&mut conn, "hello from server")
977 .await
978 .unwrap();
979 });
980
981 let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
982 let url = format!("ws://127.0.0.1:{}/ws", addr.port());
983 let (mut reader, mut writer, mut conn) = WsStreamBuilder::new()
984 .connect_with(AsyncReadAdapter::new(tcp), &url)
985 .await
986 .unwrap();
987
988 writer
989 .send_text(&mut conn, "hello from client")
990 .await
991 .unwrap();
992
993 match reader.recv(&mut conn).await.unwrap().unwrap() {
994 Message::Text(s) => assert_eq!(s, "hello from server"),
995 other => panic!("expected Text, got {other:?}"),
996 }
997
998 server.await.unwrap();
999 }
1000
1001 struct BrokenWriteStream(Cursor<Vec<u8>>);
1002
1003 impl AsyncRead for BrokenWriteStream {
1004 fn poll_read(
1005 mut self: Pin<&mut Self>,
1006 _cx: &mut Context<'_>,
1007 buf: &mut ReadBuf<'_>,
1008 ) -> Poll<std::io::Result<()>> {
1009 let n = std::io::Read::read(&mut self.0, buf.initialize_unfilled())?;
1010 buf.advance(n);
1011 Poll::Ready(Ok(()))
1012 }
1013 }
1014
1015 impl AsyncWrite for BrokenWriteStream {
1016 fn poll_write(
1017 self: Pin<&mut Self>,
1018 _cx: &mut Context<'_>,
1019 _buf: &[u8],
1020 ) -> Poll<std::io::Result<usize>> {
1021 Poll::Ready(Err(std::io::Error::new(
1022 std::io::ErrorKind::BrokenPipe,
1023 "connection lost",
1024 )))
1025 }
1026 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1027 Poll::Ready(Ok(()))
1028 }
1029 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1030 Poll::Ready(Ok(()))
1031 }
1032 }
1033
1034 #[tokio::test]
1035 async fn send_on_broken_stream_returns_error() {
1036 let mock = AsyncReadAdapter::new(BrokenWriteStream(Cursor::new(Vec::new())));
1037 let reader = FrameReader::builder().role(Role::Client).build();
1038 let writer = FrameWriter::new(Role::Client);
1039 let (_, mut ws_writer, mut conn) =
1040 WsStream::from_raw_parts(mock, reader, writer).into_parts();
1041
1042 let result = ws_writer.send_text(&mut conn, "hello").await;
1043 assert!(result.is_err(), "send on broken stream should return error");
1044
1045 let result = ws_writer.send_binary(&mut conn, &[1, 2, 3]).await;
1046 assert!(result.is_err(), "subsequent send should also fail");
1047 }
1048
1049 #[tokio::test]
1050 async fn from_parts_roundtrip() {
1051 let data = make_frame(true, 0x1, b"test");
1052 let (reader, writer, conn) = parts_from_bytes(data);
1053 let ws = WsStream::from_parts(reader, writer, conn);
1054 let (mut reader, _writer, mut conn) = ws.into_parts();
1055
1056 match reader.recv(&mut conn).await.unwrap().unwrap() {
1057 Message::Text(s) => assert_eq!(s, "test"),
1058 other => panic!("expected Text, got {other:?}"),
1059 }
1060 }
1061}