1use std::time::Duration;
4
5use super::error::ProtocolError;
6use super::frame::Role;
7use super::frame_reader::{FrameReader, FrameReaderBuilder};
8use super::frame_writer::FrameWriter;
9use super::message::{CloseCode, Message};
10use crate::buf::WriteBuf;
11
12use super::handshake;
13use super::handshake::HandshakeError;
14use std::io::{self, Read, Write};
15
16#[cfg(feature = "tls")]
17use crate::tls::{TlsConfig, TlsError};
18
/// Components of a `ws://` / `wss://` URL as produced by [`parse_ws_url`].
///
/// The string fields borrow from the URL that was parsed. The type is a small, `Copy` value, so
/// it derives the usual comparison and debugging traits.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParsedUrl<'a> {
    /// `true` for a `wss://` URL, `false` for `ws://`.
    pub tls: bool,
    /// Host name or IP address; IPv6 literals are stored without their surrounding brackets.
    pub host: &'a str,
    /// Explicit port from the URL, or the scheme default (80 for `ws`, 443 for `wss`).
    pub port: u16,
    /// Everything from the first `/` after the authority, or `"/"` when the URL has no path.
    pub path: &'a str,
}
37
38impl ParsedUrl<'_> {
39 pub fn host_header(&self) -> String {
41 let default = if self.tls { 443 } else { 80 };
42 if self.port == default {
43 self.host.to_string()
44 } else {
45 format!("{}:{}", self.host, self.port)
46 }
47 }
48}
49
50pub fn parse_ws_url(url: &str) -> Result<ParsedUrl<'_>, Error> {
54 let (tls, rest) = if let Some(r) = url.strip_prefix("wss://") {
55 (true, r)
56 } else if let Some(r) = url.strip_prefix("ws://") {
57 (false, r)
58 } else {
59 return Err(Error::InvalidUrl(url.to_string()));
60 };
61
62 let (host_port, path) = rest
63 .find('/')
64 .map_or((rest, "/"), |i| (&rest[..i], &rest[i..]));
65
66 if host_port.is_empty() {
67 return Err(Error::InvalidUrl(format!("empty host: {url}")));
68 }
69
70 let default_port = if tls { 443 } else { 80 };
71
72 let (host, port) = if host_port.starts_with('[') {
74 match host_port.find(']') {
75 Some(end) => {
76 let h = &host_port[1..end];
77 let rest = &host_port[end + 1..];
78 if let Some(port_str) = rest.strip_prefix(':') {
79 let p = port_str
80 .parse::<u16>()
81 .map_err(|_| Error::InvalidUrl(format!("invalid port: {url}")))?;
82 (h, p)
83 } else {
84 (h, default_port)
85 }
86 }
87 None => return Err(Error::InvalidUrl(format!("unclosed bracket: {url}"))),
88 }
89 } else {
90 match host_port.rfind(':') {
91 None => (host_port, default_port),
92 Some(i) => {
93 let port_str = &host_port[i + 1..];
94 if port_str.is_empty() {
95 (&host_port[..i], default_port)
96 } else {
97 let p = port_str
98 .parse::<u16>()
99 .map_err(|_| Error::InvalidUrl(format!("invalid port: {url}")))?;
100 (&host_port[..i], p)
101 }
102 }
103 }
104 };
105
106 Ok(ParsedUrl {
107 tls,
108 host,
109 port,
110 path,
111 })
112}
113
114#[derive(Debug)]
120pub enum Error {
121 Io(std::io::Error),
123 Protocol(ProtocolError),
125 Encode(super::frame_writer::EncodeError),
127 Handshake(HandshakeError),
129 #[cfg(feature = "tls")]
145 Tls(TlsError),
146 InvalidUrl(String),
148 TlsNotEnabled,
150}
151
152impl std::fmt::Display for Error {
153 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154 match self {
155 Self::Io(e) => write!(f, "I/O error: {e}"),
156 Self::Protocol(e) => write!(f, "protocol error: {e}"),
157 Self::Encode(e) => write!(f, "encode error: {e}"),
158 Self::Handshake(e) => write!(f, "handshake error: {e}"),
159 #[cfg(feature = "tls")]
160 Self::Tls(e) => write!(f, "TLS error: {e}"),
161 Self::InvalidUrl(u) => write!(f, "invalid WebSocket URL: {u}"),
162 Self::TlsNotEnabled => write!(f, "wss:// requires the 'tls' feature"),
163 }
164 }
165}
166
167impl std::error::Error for Error {}
168
169impl From<std::io::Error> for Error {
170 fn from(e: std::io::Error) -> Self {
171 Self::Io(e)
172 }
173}
174impl From<ProtocolError> for Error {
175 fn from(e: ProtocolError) -> Self {
176 Self::Protocol(e)
177 }
178}
179impl From<super::frame_writer::EncodeError> for Error {
180 fn from(e: super::frame_writer::EncodeError) -> Self {
181 Self::Encode(e)
182 }
183}
184impl From<HandshakeError> for Error {
185 fn from(e: HandshakeError) -> Self {
186 Self::Handshake(e)
187 }
188}
189#[cfg(feature = "tls")]
190impl From<TlsError> for Error {
191 fn from(e: TlsError) -> Self {
192 match e {
193 TlsError::Io(io) => Self::Io(io),
194 other => Self::Tls(other),
195 }
196 }
197}
198
199pub struct ClientBuilder {
216 pub(crate) reader_builder: FrameReaderBuilder,
217 pub(crate) write_buf_capacity: usize,
218 pub(crate) write_buf_headroom: usize,
219 #[cfg(feature = "tls")]
220 pub(crate) tls_config: Option<TlsConfig>,
221 pub(crate) tcp_nodelay: bool,
222 #[cfg(feature = "socket-opts")]
223 pub(crate) recv_buf_size: Option<usize>,
224 #[cfg(feature = "socket-opts")]
225 pub(crate) send_buf_size: Option<usize>,
226 pub(crate) connect_timeout: Option<Duration>,
227 pub(crate) read_timeout: Option<Duration>,
228}
229
230impl ClientBuilder {
231 #[must_use]
233 pub fn new() -> Self {
234 Self {
235 reader_builder: FrameReader::builder(),
236 write_buf_capacity: 65_536,
237 write_buf_headroom: 14,
238 #[cfg(feature = "tls")]
239 tls_config: None,
240 tcp_nodelay: false,
241 #[cfg(feature = "socket-opts")]
242 recv_buf_size: None,
243 #[cfg(feature = "socket-opts")]
244 send_buf_size: None,
245 connect_timeout: None,
246 read_timeout: None,
247 }
248 }
249
250 #[must_use]
252 pub fn buffer_capacity(mut self, n: usize) -> Self {
253 self.reader_builder = self.reader_builder.buffer_capacity(n);
254 self
255 }
256
257 #[must_use]
259 pub fn max_frame_size(mut self, n: u64) -> Self {
260 self.reader_builder = self.reader_builder.max_frame_size(n);
261 self
262 }
263
264 #[must_use]
266 pub fn max_message_size(mut self, n: usize) -> Self {
267 self.reader_builder = self.reader_builder.max_message_size(n);
268 self
269 }
270
271 #[must_use]
273 pub fn write_buffer_capacity(mut self, n: usize) -> Self {
274 self.write_buf_capacity = n;
275 self
276 }
277
278 #[must_use]
280 pub fn disable_nagle(mut self) -> Self {
281 self.tcp_nodelay = true;
282 self
283 }
284
285 #[cfg(feature = "socket-opts")]
287 #[must_use]
288 pub fn recv_buffer_size(mut self, n: usize) -> Self {
289 self.recv_buf_size = Some(n);
290 self
291 }
292
293 #[cfg(feature = "socket-opts")]
295 #[must_use]
296 pub fn send_buffer_size(mut self, n: usize) -> Self {
297 self.send_buf_size = Some(n);
298 self
299 }
300
301 #[must_use]
303 pub fn connect_timeout(mut self, d: Duration) -> Self {
304 self.connect_timeout = Some(d);
305 self
306 }
307
308 #[must_use]
310 pub fn read_timeout(mut self, d: Duration) -> Self {
311 self.read_timeout = Some(d);
312 self
313 }
314
315 #[cfg(feature = "tls")]
319 #[must_use]
320 pub fn tls(mut self, config: &TlsConfig) -> Self {
321 self.tls_config = Some(config.clone());
322 self
323 }
324
325 #[cfg(feature = "tls")]
335 pub fn connect(self, url: &str) -> Result<Client<crate::MaybeTls<std::net::TcpStream>>, Error> {
336 let parsed = parse_ws_url(url)?;
337 let addr = format!("{}:{}", parsed.host, parsed.port);
338
339 let tcp = match self.connect_timeout {
340 Some(timeout) => {
341 let addrs: Vec<std::net::SocketAddr> =
342 std::net::ToSocketAddrs::to_socket_addrs(&addr)
343 .map_err(Error::Io)?
344 .collect();
345 let first = addrs
346 .first()
347 .ok_or_else(|| Error::Io(io::Error::other("DNS resolution failed")))?;
348 std::net::TcpStream::connect_timeout(first, timeout)?
349 }
350 None => std::net::TcpStream::connect(&addr)?,
351 };
352
353 self.apply_socket_opts(&tcp)?;
354
355 let stream = if parsed.tls {
356 let config = match self.tls_config {
357 Some(c) => c,
358 None => TlsConfig::new().map_err(Error::Tls)?,
359 };
360 let codec = crate::tls::TlsCodec::new(&config, parsed.host)?;
361 let tls = crate::tls::TlsStream::connect(tcp, codec).map_err(Error::Tls)?;
362 crate::MaybeTls::Tls(Box::new(tls))
363 } else {
364 crate::MaybeTls::Plain(tcp)
365 };
366
367 let host_header = parsed.host_header();
368 Client::connect_impl(
369 stream,
370 &host_header,
371 parsed.path,
372 self.reader_builder,
373 self.write_buf_capacity,
374 self.write_buf_headroom,
375 )
376 }
377
378 #[cfg(not(feature = "tls"))]
380 pub fn connect(self, url: &str) -> Result<Client<std::net::TcpStream>, Error> {
381 let parsed = parse_ws_url(url)?;
382 if parsed.tls {
383 return Err(Error::TlsNotEnabled);
384 }
385 let addr = format!("{}:{}", parsed.host, parsed.port);
386
387 let tcp = match self.connect_timeout {
388 Some(timeout) => {
389 let addrs: Vec<std::net::SocketAddr> =
390 std::net::ToSocketAddrs::to_socket_addrs(&addr)
391 .map_err(Error::Io)?
392 .collect();
393 let first = addrs
394 .first()
395 .ok_or_else(|| Error::Io(io::Error::other("DNS resolution failed")))?;
396 std::net::TcpStream::connect_timeout(first, timeout)?
397 }
398 None => std::net::TcpStream::connect(&addr)?,
399 };
400
401 self.apply_socket_opts(&tcp)?;
402
403 let host_header = parsed.host_header();
404 Client::connect_impl(
405 tcp,
406 &host_header,
407 parsed.path,
408 self.reader_builder,
409 self.write_buf_capacity,
410 self.write_buf_headroom,
411 )
412 }
413
414 pub fn connect_with<S: Read + Write>(self, stream: S, url: &str) -> Result<Client<S>, Error> {
420 let parsed = parse_ws_url(url)?;
421 let host_header = parsed.host_header();
422 Client::connect_impl(
423 stream,
424 &host_header,
425 parsed.path,
426 self.reader_builder,
427 self.write_buf_capacity,
428 self.write_buf_headroom,
429 )
430 }
431
432 pub fn accept<S: Read + Write>(self, stream: S) -> Result<Client<S>, Error> {
434 Client::accept_impl(
435 stream,
436 self.reader_builder,
437 self.write_buf_capacity,
438 self.write_buf_headroom,
439 )
440 }
441
442 fn apply_socket_opts(&self, tcp: &std::net::TcpStream) -> Result<(), Error> {
443 if self.tcp_nodelay {
444 tcp.set_nodelay(true)?;
445 }
446 if let Some(timeout) = self.read_timeout {
447 tcp.set_read_timeout(Some(timeout))?;
448 }
449 #[cfg(feature = "socket-opts")]
450 {
451 let sock = socket2::SockRef::from(tcp);
452 if let Some(size) = self.recv_buf_size {
453 sock.set_recv_buffer_size(size)?;
454 }
455 if let Some(size) = self.send_buf_size {
456 sock.set_send_buffer_size(size)?;
457 }
458 }
459 Ok(())
460 }
461}
462
463impl Default for ClientBuilder {
464 fn default() -> Self {
465 Self::new()
466 }
467}
468
469pub struct Client<S> {
498 pub(crate) stream: S,
499 pub(crate) reader: FrameReader,
500 pub(crate) writer: FrameWriter,
501 pub(crate) write_buf: WriteBuf,
502 pub(crate) poisoned: bool,
503}
504
505impl Client<std::net::TcpStream> {
506 #[must_use]
508 pub fn builder() -> ClientBuilder {
509 ClientBuilder::new()
510 }
511}
512
513impl<S> Client<S> {
516 pub fn from_parts(stream: S, reader: FrameReader, writer: FrameWriter) -> Self {
518 Self {
519 stream,
520 reader,
521 writer,
522 write_buf: WriteBuf::new(65_536, 14),
523 poisoned: false,
524 }
525 }
526
527 pub(crate) fn from_parts_internal(
529 stream: S,
530 reader: FrameReader,
531 writer: FrameWriter,
532 write_buf: WriteBuf,
533 ) -> Self {
534 Self {
535 stream,
536 reader,
537 writer,
538 write_buf,
539 poisoned: false,
540 }
541 }
542
543 pub fn is_poisoned(&self) -> bool {
548 self.poisoned
549 }
550
551 pub fn stream(&self) -> &S {
553 &self.stream
554 }
555
556 pub fn stream_mut(&mut self) -> &mut S {
558 &mut self.stream
559 }
560
561 pub fn reader(&self) -> &FrameReader {
563 &self.reader
564 }
565
566 pub fn frame_writer(&self) -> &FrameWriter {
568 &self.writer
569 }
570}
571
572impl<S: Read + Write> Client<S> {
575 pub fn connect_with(stream: S, url: &str) -> Result<Self, Error> {
579 ClientBuilder::new().connect_with(stream, url)
580 }
581
582 pub fn accept(stream: S) -> Result<Self, Error> {
584 ClientBuilder::new().accept(stream)
585 }
586
587 pub fn recv(&mut self) -> Result<Option<Message<'_>>, Error> {
591 loop {
592 if self.reader.poll()? {
593 return Ok(self.reader.next()?);
594 }
595 match self.read_into_reader() {
596 Ok(0) => return Ok(None),
597 Ok(_) => {}
598 Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None),
599 Err(e) => return Err(Error::Io(e)),
600 }
601 }
602 }
603
604 pub fn send_text(&mut self, text: &str) -> Result<(), Error> {
606 self.writer
607 .encode_text_into(text.as_bytes(), &mut self.write_buf);
608 self.flush_write_buf_or_poison()
609 }
610
611 pub fn send_binary(&mut self, data: &[u8]) -> Result<(), Error> {
613 self.writer.encode_binary_into(data, &mut self.write_buf);
614 self.flush_write_buf_or_poison()
615 }
616
617 pub fn send_ping(&mut self, data: &[u8]) -> Result<(), Error> {
619 self.writer
620 .encode_ping_into(data, &mut self.write_buf)
621 .map_err(Error::Encode)?;
622 self.flush_write_buf_or_poison()
623 }
624
625 pub fn send_pong(&mut self, data: &[u8]) -> Result<(), Error> {
627 self.writer
628 .encode_pong_into(data, &mut self.write_buf)
629 .map_err(Error::Encode)?;
630 self.flush_write_buf_or_poison()
631 }
632
633 pub fn close(&mut self, code: CloseCode, reason: &str) -> Result<(), Error> {
635 if code == CloseCode::NoStatus {
636 let mut dst = [0u8; 14];
637 let n = self.writer.encode_empty_close(&mut dst);
638 self.write_raw(&dst[..n]).inspect_err(|_| {
639 self.poisoned = true;
640 })
641 } else {
642 self.writer
643 .encode_close_into(code.as_u16(), reason.as_bytes(), &mut self.write_buf)
644 .map_err(Error::Encode)?;
645 self.flush_write_buf_or_poison()
646 }
647 }
648
649 fn read_into_reader(&mut self) -> io::Result<usize> {
658 self.reader.read_from(&mut self.stream)
659 }
660
661 fn flush_write_buf_or_poison(&mut self) -> Result<(), Error> {
663 self.flush_write_buf().inspect_err(|_| {
664 self.poisoned = true;
665 })
666 }
667
668 fn flush_write_buf(&mut self) -> Result<(), Error> {
670 self.stream.write_all(self.write_buf.data())?;
671 Ok(())
672 }
673
674 fn write_raw(&mut self, data: &[u8]) -> Result<(), Error> {
676 self.stream.write_all(data)?;
677 Ok(())
678 }
679
680 pub(crate) fn connect_impl(
687 mut stream: S,
688 host: &str,
689 path: &str,
690 reader_builder: FrameReaderBuilder,
691 write_cap: usize,
692 write_headroom: usize,
693 ) -> Result<Self, Error> {
694 let key = handshake::generate_key();
695 let key_str = std::str::from_utf8(&key).expect("base64 output is valid ASCII");
696
697 let headers = [
698 ("Host", host),
699 ("Upgrade", "websocket"),
700 ("Connection", "Upgrade"),
701 ("Sec-WebSocket-Key", key_str),
702 ("Sec-WebSocket-Version", "13"),
703 ];
704 let req_size = crate::http::request_size("GET", path, &headers);
705 let mut req_buf = vec![0u8; req_size];
706 let n = crate::http::write_request("GET", path, &headers, &mut req_buf)
707 .map_err(|_| HandshakeError::MalformedHttp)?;
708
709 stream.write_all(&req_buf[..n])?;
710
711 let mut resp_reader = crate::http::ResponseReader::new(4096);
712 let mut tmp = [0u8; 4096];
713 loop {
714 let bytes_read = stream.read(&mut tmp)?;
715 if bytes_read == 0 {
716 return Err(HandshakeError::MalformedHttp.into());
717 }
718
719 resp_reader
720 .read(&tmp[..bytes_read])
721 .map_err(|_| HandshakeError::MalformedHttp)?;
722 match resp_reader.next() {
723 Ok(Some(resp)) => {
724 if resp.status != 101 {
725 return Err(HandshakeError::UnexpectedStatus(resp.status).into());
726 }
727 let upgrade = resp
728 .header("Upgrade")
729 .ok_or(HandshakeError::MissingUpgrade)?;
730 if !upgrade.eq_ignore_ascii_case("websocket") {
731 return Err(HandshakeError::MissingUpgrade.into());
732 }
733 let conn = resp
734 .header("Connection")
735 .ok_or(HandshakeError::MissingConnection)?;
736 if !contains_ignore_case(conn, "upgrade") {
737 return Err(HandshakeError::MissingConnection.into());
738 }
739 let accept = resp
740 .header("Sec-WebSocket-Accept")
741 .ok_or(HandshakeError::InvalidAcceptKey)?;
742 if !handshake::validate_accept(key_str, accept) {
743 return Err(HandshakeError::InvalidAcceptKey.into());
744 }
745
746 let mut reader = reader_builder.role(Role::Client).build();
747 let remainder = resp_reader.remainder();
748 if !remainder.is_empty() {
749 reader
750 .read(remainder)
751 .map_err(|_| HandshakeError::MalformedHttp)?;
752 }
753
754 return Ok(Self {
755 stream,
756 reader,
757 writer: FrameWriter::new(Role::Client),
758 write_buf: WriteBuf::new(write_cap, write_headroom),
759 poisoned: false,
760 });
761 }
762 Ok(None) => {} Err(_) => return Err(HandshakeError::MalformedHttp.into()),
764 }
765 }
766 }
767
768 fn accept_impl(
769 mut stream: S,
770 reader_builder: FrameReaderBuilder,
771 write_cap: usize,
772 write_headroom: usize,
773 ) -> Result<Self, Error> {
774 let mut req_reader = crate::http::RequestReader::new(4096);
775 let mut tmp = [0u8; 4096];
776
777 let ws_key;
778 loop {
779 let n = stream.read(&mut tmp)?;
780 if n == 0 {
781 return Err(HandshakeError::MalformedHttp.into());
782 }
783 req_reader
784 .read(&tmp[..n])
785 .map_err(|_| HandshakeError::MalformedHttp)?;
786 match req_reader.next() {
787 Ok(Some(req)) => {
788 if req.method != "GET" {
789 return Err(HandshakeError::MalformedHttp.into());
790 }
791 let upgrade = req
792 .header("Upgrade")
793 .ok_or(HandshakeError::MissingUpgrade)?;
794 if !upgrade.eq_ignore_ascii_case("websocket") {
795 return Err(HandshakeError::MissingUpgrade.into());
796 }
797 let conn = req
798 .header("Connection")
799 .ok_or(HandshakeError::MissingConnection)?;
800 if !contains_ignore_case(conn, "upgrade") {
801 return Err(HandshakeError::MissingConnection.into());
802 }
803 let version = req
804 .header("Sec-WebSocket-Version")
805 .ok_or(HandshakeError::UnsupportedVersion)?;
806 if version != "13" {
807 return Err(HandshakeError::UnsupportedVersion.into());
808 }
809 let key = req
810 .header("Sec-WebSocket-Key")
811 .ok_or(HandshakeError::MissingKey)?;
812 ws_key = key.to_owned();
813 break;
814 }
815 Ok(None) => {}
816 Err(_) => return Err(HandshakeError::MalformedHttp.into()),
817 }
818 }
819
820 let accept = handshake::compute_accept_key(&ws_key);
821 let accept_str = std::str::from_utf8(&accept).expect("base64 output is valid ASCII");
822
823 let resp_headers = [
824 ("Upgrade", "websocket"),
825 ("Connection", "Upgrade"),
826 ("Sec-WebSocket-Accept", accept_str),
827 ];
828 let resp_size = crate::http::response_size("Switching Protocols", &resp_headers);
829 let mut resp_buf = vec![0u8; resp_size];
830 let n =
831 crate::http::write_response(101, "Switching Protocols", &resp_headers, &mut resp_buf)
832 .map_err(|_| HandshakeError::MalformedHttp)?;
833 stream.write_all(&resp_buf[..n])?;
834
835 let mut reader = reader_builder.role(Role::Server).build();
836 let remainder = req_reader.remainder();
837 if !remainder.is_empty() {
838 reader
839 .read(remainder)
840 .map_err(|_| HandshakeError::MalformedHttp)?;
841 }
842
843 Ok(Self {
844 stream,
845 reader,
846 writer: FrameWriter::new(Role::Server),
847 write_buf: WriteBuf::new(write_cap, write_headroom),
848 poisoned: false,
849 })
850 }
851}
852
853pub fn pair(role: Role) -> (FrameReader, FrameWriter) {
857 (
858 FrameReader::builder().role(role).build(),
859 FrameWriter::new(role),
860 )
861}
862
863pub fn pair_with(role: Role, reader_builder: FrameReaderBuilder) -> (FrameReader, FrameWriter) {
865 (reader_builder.role(role).build(), FrameWriter::new(role))
866}
867
/// Returns `true` if `needle` occurs anywhere in `haystack`.
///
/// ASCII letters are compared case-insensitively; all other bytes must match exactly. An empty
/// `needle` matches every haystack.
fn contains_ignore_case(haystack: &str, needle: &str) -> bool {
    let needle = needle.as_bytes();
    if needle.is_empty() {
        // `slice::windows` panics on a zero window size.
        return true;
    }
    haystack
        .as_bytes()
        .windows(needle.len())
        .any(|window| window.eq_ignore_ascii_case(needle))
}
874
// Unit tests: `parse_ws_url`, plus `Client::recv` driven by in-memory mock streams.
#[cfg(test)]
mod tests {
    use super::*;

    /// An explicit port and path are both preserved.
    #[test]
    fn parse_ws_url_plain() {
        let p = parse_ws_url("ws://localhost:8080/ws").unwrap();
        assert!(!p.tls);
        assert_eq!(p.host, "localhost");
        assert_eq!(p.port, 8080);
        assert_eq!(p.path, "/ws");
    }

    /// `wss://` sets the TLS flag and defaults the port to 443.
    #[test]
    fn parse_ws_url_tls() {
        let p = parse_ws_url("wss://exchange.com/ws/v1").unwrap();
        assert!(p.tls);
        assert_eq!(p.host, "exchange.com");
        assert_eq!(p.port, 443);
        assert_eq!(p.path, "/ws/v1");
    }

    /// Without an explicit port, the scheme default (80 / 443) is used.
    #[test]
    fn parse_ws_url_default_port() {
        let p = parse_ws_url("ws://host/path").unwrap();
        assert_eq!(p.port, 80);

        let p = parse_ws_url("wss://host/path").unwrap();
        assert_eq!(p.port, 443);
    }

    /// A URL with no path gets `/`.
    #[test]
    fn parse_ws_url_no_path() {
        let p = parse_ws_url("ws://host").unwrap();
        assert_eq!(p.path, "/");
    }

    /// Non-WebSocket schemes and scheme-less strings are rejected.
    #[test]
    fn parse_ws_url_invalid_scheme() {
        assert!(parse_ws_url("http://host").is_err());
        assert!(parse_ws_url("host/path").is_err());
    }

    // Tests that drive a client-role `Client` over mock `Read + Write` streams, skipping the
    // handshake via `Client::from_parts`.
    mod sync_tests {
        use super::*;
        use std::io::{self, Read, Write};

        /// A reader created by `pair(Role::Client)` decodes an unmasked text frame.
        #[test]
        fn pair_creates_matching_roles() {
            let (mut reader, _writer) = pair(Role::Client);
            let frame = make_frame(true, 0x1, b"test");
            reader.read(&frame).unwrap();
            let msg = reader.next().unwrap().unwrap();
            assert!(matches!(msg, Message::Text(s) if s == "test"));
        }

        /// Mock stream that returns at most one byte per `read`, forcing `recv` to loop until a
        /// frame is complete. Writes are discarded.
        struct ByteAtATimeStream {
            data: Vec<u8>,
            pos: usize,
        }

        impl ByteAtATimeStream {
            fn new(data: Vec<u8>) -> Self {
                Self { data, pos: 0 }
            }
        }

        impl Read for ByteAtATimeStream {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                // Reports EOF once every byte has been handed out.
                if self.pos >= self.data.len() {
                    return Ok(0);
                }
                buf[0] = self.data[self.pos];
                self.pos += 1;
                Ok(1)
            }
        }

        impl Write for ByteAtATimeStream {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                Ok(buf.len())
            }
            fn flush(&mut self) -> io::Result<()> {
                Ok(())
            }
        }

        /// Builds an unmasked frame (no mask bit, no masking key) using the shortest length
        /// encoding: 7-bit, 16-bit (126), or 64-bit (127).
        fn make_frame(fin: bool, opcode: u8, payload: &[u8]) -> Vec<u8> {
            let mut frame = Vec::new();
            let byte0 = if fin { 0x80 } else { 0x00 } | opcode;
            frame.push(byte0);
            if payload.len() <= 125 {
                frame.push(payload.len() as u8);
            } else if payload.len() <= 65535 {
                frame.push(126);
                frame.extend_from_slice(&(payload.len() as u16).to_be_bytes());
            } else {
                frame.push(127);
                frame.extend_from_slice(&(payload.len() as u64).to_be_bytes());
            }
            frame.extend_from_slice(payload);
            frame
        }

        /// Wraps `data` in a client-role `Client` over a `ByteAtATimeStream`.
        fn ws_from_bytes(data: Vec<u8>) -> Client<ByteAtATimeStream> {
            let mock = ByteAtATimeStream::new(data);
            let reader = FrameReader::builder().role(Role::Client).build();
            let writer = FrameWriter::new(Role::Client);
            Client::from_parts(mock, reader, writer)
        }

        /// A single final text frame comes back as `Message::Text`.
        #[test]
        fn recv_text() {
            let frame = make_frame(true, 0x1, b"Hello");
            let mut ws = ws_from_bytes(frame);
            match ws.recv().unwrap().unwrap() {
                Message::Text(s) => assert_eq!(s, "Hello"),
                other => panic!("expected Text, got {other:?}"),
            }
        }

        /// A ping whose payload is 125 bytes (the 7-bit length maximum) is delivered intact.
        #[test]
        fn recv_ping() {
            let frame = make_frame(true, 0x9, &[0x42; 125]);
            let mut ws = ws_from_bytes(frame);
            match ws.recv().unwrap().unwrap() {
                Message::Ping(p) => assert_eq!(p.len(), 125),
                other => panic!("expected Ping, got {other:?}"),
            }
        }

        /// A text message split across a non-final frame and a continuation is reassembled.
        #[test]
        fn recv_fragmented_text() {
            let mut data = make_frame(false, 0x1, b"Hel");
            data.extend_from_slice(&make_frame(true, 0x0, b"lo"));
            let mut ws = ws_from_bytes(data);
            match ws.recv().unwrap().unwrap() {
                Message::Text(s) => assert_eq!(s, "Hello"),
                other => panic!("expected Text, got {other:?}"),
            }
        }

        /// A ping arriving between fragments is returned first; the text message is still
        /// reassembled afterwards.
        #[test]
        fn recv_fragment_with_ping() {
            let mut data = make_frame(false, 0x1, b"Hel");
            data.extend_from_slice(&make_frame(true, 0x9, b"ping"));
            data.extend_from_slice(&make_frame(true, 0x0, b"lo"));
            let mut ws = ws_from_bytes(data);
            match ws.recv().unwrap().unwrap() {
                Message::Ping(p) => assert_eq!(p, b"ping"),
                other => panic!("expected Ping, got {other:?}"),
            }
            match ws.recv().unwrap().unwrap() {
                Message::Text(s) => assert_eq!(s, "Hello"),
                other => panic!("expected Text, got {other:?}"),
            }
        }

        /// A close frame's big-endian status code and UTF-8 reason are decoded.
        #[test]
        fn recv_close() {
            let mut payload = vec![];
            payload.extend_from_slice(&1000u16.to_be_bytes());
            payload.extend_from_slice(b"bye");
            let frame = make_frame(true, 0x8, &payload);
            let mut ws = ws_from_bytes(frame);
            match ws.recv().unwrap().unwrap() {
                Message::Close(cf) => {
                    assert_eq!(cf.code, CloseCode::Normal);
                    assert_eq!(cf.reason, "bye");
                }
                other => panic!("expected Close, got {other:?}"),
            }
        }

        /// End-of-file with no buffered frame yields `Ok(None)`.
        #[test]
        fn eof_returns_none() {
            let mut ws = ws_from_bytes(Vec::new());
            assert!(ws.recv().unwrap().is_none());
        }

        /// A `WouldBlock` read surfaces as `Ok(None)` rather than an error.
        #[test]
        fn would_block_returns_none() {
            // Mock stream whose every read fails with `WouldBlock`.
            struct WouldBlockStream;
            impl Read for WouldBlockStream {
                fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
                    Err(io::Error::new(io::ErrorKind::WouldBlock, "would block"))
                }
            }
            impl Write for WouldBlockStream {
                fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                    Ok(buf.len())
                }
                fn flush(&mut self) -> io::Result<()> {
                    Ok(())
                }
            }

            let reader = FrameReader::builder().role(Role::Client).build();
            let writer = FrameWriter::new(Role::Client);
            let mut ws = Client::from_parts(WouldBlockStream, reader, writer);
            assert!(ws.recv().unwrap().is_none());
        }
    }
}