1use std::time::Duration;
4
5use super::error::ProtocolError;
6use super::frame::Role;
7use super::frame_reader::{FrameReader, FrameReaderBuilder};
8use super::frame_writer::FrameWriter;
9use super::message::{CloseCode, Message};
10use nexus_net::buf::WriteBuf;
11
12use super::handshake;
13use super::handshake::HandshakeError;
14use std::io::{self, Read, Write};
15
16#[cfg(feature = "tls")]
17use nexus_net::tls::{TlsConfig, TlsError};
18
/// Components of a `ws://` / `wss://` URL, borrowed from the input string.
#[non_exhaustive]
pub struct ParsedUrl<'a> {
    /// `true` for the `wss://` scheme, `false` for `ws://`.
    pub tls: bool,
    /// Host name or IP literal. IPv6 literals are stored without their
    /// surrounding brackets.
    pub host: &'a str,
    /// Explicit port from the URL, or the scheme default (80 / 443).
    pub port: u16,
    /// Everything from the first `/` after the authority; `"/"` when absent.
    pub path: &'a str,
}

impl ParsedUrl<'_> {
    /// Build the value for the HTTP `Host` header.
    ///
    /// The port is omitted when it equals the scheme default (80 for `ws`,
    /// 443 for `wss`). A host containing `:` is an IPv6 literal and is
    /// wrapped in brackets so the trailing `:port` stays unambiguous.
    pub fn host_header(&self) -> String {
        let default_port = if self.tls { 443 } else { 80 };
        // The parser strips brackets from IPv6 literals, so they are put back
        // here; an already-bracketed host (constructed by hand) is left alone.
        let needs_brackets = self.host.contains(':') && !self.host.starts_with('[');
        match (needs_brackets, self.port == default_port) {
            (false, true) => self.host.to_string(),
            (false, false) => format!("{}:{}", self.host, self.port),
            (true, true) => format!("[{}]", self.host),
            (true, false) => format!("[{}]:{}", self.host, self.port),
        }
    }
}
49
50pub fn parse_ws_url(url: &str) -> Result<ParsedUrl<'_>, Error> {
54 let (tls, rest) = if let Some(r) = url.strip_prefix("wss://") {
55 (true, r)
56 } else if let Some(r) = url.strip_prefix("ws://") {
57 (false, r)
58 } else {
59 return Err(Error::InvalidUrl(url.to_string()));
60 };
61
62 let (host_port, path) = rest
63 .find('/')
64 .map_or((rest, "/"), |i| (&rest[..i], &rest[i..]));
65
66 if host_port.is_empty() {
67 return Err(Error::InvalidUrl(format!("empty host: {url}")));
68 }
69
70 let default_port = if tls { 443 } else { 80 };
71
72 let (host, port) = if host_port.starts_with('[') {
74 match host_port.find(']') {
75 Some(end) => {
76 let h = &host_port[1..end];
77 let rest = &host_port[end + 1..];
78 if let Some(port_str) = rest.strip_prefix(':') {
79 let p = port_str
80 .parse::<u16>()
81 .map_err(|_| Error::InvalidUrl(format!("invalid port: {url}")))?;
82 (h, p)
83 } else {
84 (h, default_port)
85 }
86 }
87 None => return Err(Error::InvalidUrl(format!("unclosed bracket: {url}"))),
88 }
89 } else {
90 match host_port.rfind(':') {
91 None => (host_port, default_port),
92 Some(i) => {
93 let port_str = &host_port[i + 1..];
94 if port_str.is_empty() {
95 (&host_port[..i], default_port)
96 } else {
97 let p = port_str
98 .parse::<u16>()
99 .map_err(|_| Error::InvalidUrl(format!("invalid port: {url}")))?;
100 (&host_port[..i], p)
101 }
102 }
103 }
104 };
105
106 Ok(ParsedUrl {
107 tls,
108 host,
109 port,
110 path,
111 })
112}
113
/// Errors produced by the WebSocket client API in this module.
#[derive(Debug)]
pub enum Error {
    /// Underlying stream I/O failed.
    Io(std::io::Error),
    /// A WebSocket protocol error (wraps [`ProtocolError`]).
    Protocol(ProtocolError),
    /// A frame could not be encoded.
    Encode(super::frame_writer::EncodeError),
    /// The HTTP upgrade handshake failed.
    Handshake(HandshakeError),
    /// A TLS error other than plain I/O (those are mapped to [`Error::Io`]
    /// by the `From<TlsError>` conversion).
    #[cfg(feature = "tls")]
    Tls(TlsError),
    /// The URL could not be parsed; the payload describes the problem.
    InvalidUrl(String),
    /// A `wss://` URL was used but the `tls` feature is not compiled in.
    TlsNotEnabled,
}
151
152impl std::fmt::Display for Error {
153 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154 match self {
155 Self::Io(e) => write!(f, "I/O error: {e}"),
156 Self::Protocol(e) => write!(f, "protocol error: {e}"),
157 Self::Encode(e) => write!(f, "encode error: {e}"),
158 Self::Handshake(e) => write!(f, "handshake error: {e}"),
159 #[cfg(feature = "tls")]
160 Self::Tls(e) => write!(f, "TLS error: {e}"),
161 Self::InvalidUrl(u) => write!(f, "invalid WebSocket URL: {u}"),
162 Self::TlsNotEnabled => write!(f, "wss:// requires the 'tls' feature"),
163 }
164 }
165}
166
// Marker impl so `Error` works with `Box<dyn std::error::Error>` and friends.
// `source()` is not overridden here; the `Display` impl already embeds the
// wrapped error's message.
impl std::error::Error for Error {}
168
169impl From<std::io::Error> for Error {
170 fn from(e: std::io::Error) -> Self {
171 Self::Io(e)
172 }
173}
174impl From<ProtocolError> for Error {
175 fn from(e: ProtocolError) -> Self {
176 Self::Protocol(e)
177 }
178}
179impl From<super::frame_writer::EncodeError> for Error {
180 fn from(e: super::frame_writer::EncodeError) -> Self {
181 Self::Encode(e)
182 }
183}
184impl From<HandshakeError> for Error {
185 fn from(e: HandshakeError) -> Self {
186 Self::Handshake(e)
187 }
188}
/// Converts a TLS-layer error, unwrapping plain I/O failures so callers see
/// them as [`Error::Io`] regardless of whether TLS was in use.
#[cfg(feature = "tls")]
impl From<TlsError> for Error {
    fn from(tls_err: TlsError) -> Self {
        match tls_err {
            TlsError::Io(inner) => Error::Io(inner),
            not_io => Error::Tls(not_io),
        }
    }
}
198
/// Configuration collected before establishing a WebSocket connection.
///
/// Consumed by `connect`, `connect_with`, or `accept`.
pub struct ClientBuilder {
    /// Builder for the inbound frame reader (buffer and size limits).
    pub(crate) reader_builder: FrameReaderBuilder,
    /// Capacity passed to `WriteBuf::new` for the outbound buffer.
    pub(crate) write_buf_capacity: usize,
    /// Headroom passed to `WriteBuf::new` for the outbound buffer.
    pub(crate) write_buf_headroom: usize,
    /// Explicit TLS configuration; when `None`, `connect` creates one with
    /// `TlsConfig::new()` for `wss://` URLs.
    #[cfg(feature = "tls")]
    pub(crate) tls_config: Option<TlsConfig>,
    /// When `true`, `TCP_NODELAY` is set on sockets opened by `connect`.
    pub(crate) tcp_nodelay: bool,
    /// Requested socket receive-buffer size, if any.
    #[cfg(feature = "socket-opts")]
    pub(crate) recv_buf_size: Option<usize>,
    /// Requested socket send-buffer size, if any.
    #[cfg(feature = "socket-opts")]
    pub(crate) send_buf_size: Option<usize>,
    /// Timeout for establishing the TCP connection in `connect`.
    pub(crate) connect_timeout: Option<Duration>,
    /// Read timeout installed on sockets opened by `connect`.
    pub(crate) read_timeout: Option<Duration>,
}
229
230impl ClientBuilder {
231 #[must_use]
233 pub fn new() -> Self {
234 Self {
235 reader_builder: FrameReader::builder(),
236 write_buf_capacity: 65_536,
237 write_buf_headroom: 14,
238 #[cfg(feature = "tls")]
239 tls_config: None,
240 tcp_nodelay: false,
241 #[cfg(feature = "socket-opts")]
242 recv_buf_size: None,
243 #[cfg(feature = "socket-opts")]
244 send_buf_size: None,
245 connect_timeout: None,
246 read_timeout: None,
247 }
248 }
249
250 #[must_use]
252 pub fn buffer_capacity(mut self, n: usize) -> Self {
253 self.reader_builder = self.reader_builder.buffer_capacity(n);
254 self
255 }
256
257 #[must_use]
259 pub fn max_frame_size(mut self, n: u64) -> Self {
260 self.reader_builder = self.reader_builder.max_frame_size(n);
261 self
262 }
263
264 #[must_use]
266 pub fn max_message_size(mut self, n: usize) -> Self {
267 self.reader_builder = self.reader_builder.max_message_size(n);
268 self
269 }
270
271 #[must_use]
273 pub fn write_buffer_capacity(mut self, n: usize) -> Self {
274 self.write_buf_capacity = n;
275 self
276 }
277
278 #[must_use]
280 pub fn disable_nagle(mut self) -> Self {
281 self.tcp_nodelay = true;
282 self
283 }
284
285 #[cfg(feature = "socket-opts")]
287 #[must_use]
288 pub fn recv_buffer_size(mut self, n: usize) -> Self {
289 self.recv_buf_size = Some(n);
290 self
291 }
292
293 #[cfg(feature = "socket-opts")]
295 #[must_use]
296 pub fn send_buffer_size(mut self, n: usize) -> Self {
297 self.send_buf_size = Some(n);
298 self
299 }
300
301 #[must_use]
303 pub fn connect_timeout(mut self, d: Duration) -> Self {
304 self.connect_timeout = Some(d);
305 self
306 }
307
308 #[must_use]
310 pub fn read_timeout(mut self, d: Duration) -> Self {
311 self.read_timeout = Some(d);
312 self
313 }
314
315 #[cfg(feature = "tls")]
319 #[must_use]
320 pub fn tls(mut self, config: &TlsConfig) -> Self {
321 self.tls_config = Some(config.clone());
322 self
323 }
324
325 #[cfg(feature = "tls")]
335 pub fn connect(
336 self,
337 url: &str,
338 ) -> Result<Client<nexus_net::MaybeTls<std::net::TcpStream>>, Error> {
339 let parsed = parse_ws_url(url)?;
340 let addr = format!("{}:{}", parsed.host, parsed.port);
341
342 let tcp = match self.connect_timeout {
343 Some(timeout) => {
344 let addrs: Vec<std::net::SocketAddr> =
345 std::net::ToSocketAddrs::to_socket_addrs(&addr)
346 .map_err(Error::Io)?
347 .collect();
348 let first = addrs
349 .first()
350 .ok_or_else(|| Error::Io(io::Error::other("DNS resolution failed")))?;
351 std::net::TcpStream::connect_timeout(first, timeout)?
352 }
353 None => std::net::TcpStream::connect(&addr)?,
354 };
355
356 self.apply_socket_opts(&tcp)?;
357
358 let stream = if parsed.tls {
359 let config = match self.tls_config {
360 Some(c) => c,
361 None => TlsConfig::new().map_err(Error::Tls)?,
362 };
363 let codec = nexus_net::tls::TlsCodec::new(&config, parsed.host)?;
364 let tls = nexus_net::tls::TlsStream::connect(tcp, codec).map_err(Error::Tls)?;
365 nexus_net::MaybeTls::Tls(Box::new(tls))
366 } else {
367 nexus_net::MaybeTls::Plain(tcp)
368 };
369
370 let host_header = parsed.host_header();
371 Client::connect_impl(
372 stream,
373 &host_header,
374 parsed.path,
375 self.reader_builder,
376 self.write_buf_capacity,
377 self.write_buf_headroom,
378 )
379 }
380
381 #[cfg(not(feature = "tls"))]
383 pub fn connect(self, url: &str) -> Result<Client<std::net::TcpStream>, Error> {
384 let parsed = parse_ws_url(url)?;
385 if parsed.tls {
386 return Err(Error::TlsNotEnabled);
387 }
388 let addr = format!("{}:{}", parsed.host, parsed.port);
389
390 let tcp = match self.connect_timeout {
391 Some(timeout) => {
392 let addrs: Vec<std::net::SocketAddr> =
393 std::net::ToSocketAddrs::to_socket_addrs(&addr)
394 .map_err(Error::Io)?
395 .collect();
396 let first = addrs
397 .first()
398 .ok_or_else(|| Error::Io(io::Error::other("DNS resolution failed")))?;
399 std::net::TcpStream::connect_timeout(first, timeout)?
400 }
401 None => std::net::TcpStream::connect(&addr)?,
402 };
403
404 self.apply_socket_opts(&tcp)?;
405
406 let host_header = parsed.host_header();
407 Client::connect_impl(
408 tcp,
409 &host_header,
410 parsed.path,
411 self.reader_builder,
412 self.write_buf_capacity,
413 self.write_buf_headroom,
414 )
415 }
416
417 pub fn connect_with<S: Read + Write>(self, stream: S, url: &str) -> Result<Client<S>, Error> {
423 let parsed = parse_ws_url(url)?;
424 let host_header = parsed.host_header();
425 Client::connect_impl(
426 stream,
427 &host_header,
428 parsed.path,
429 self.reader_builder,
430 self.write_buf_capacity,
431 self.write_buf_headroom,
432 )
433 }
434
435 pub fn accept<S: Read + Write>(self, stream: S) -> Result<Client<S>, Error> {
437 Client::accept_impl(
438 stream,
439 self.reader_builder,
440 self.write_buf_capacity,
441 self.write_buf_headroom,
442 )
443 }
444
445 fn apply_socket_opts(&self, tcp: &std::net::TcpStream) -> Result<(), Error> {
446 if self.tcp_nodelay {
447 tcp.set_nodelay(true)?;
448 }
449 if let Some(timeout) = self.read_timeout {
450 tcp.set_read_timeout(Some(timeout))?;
451 }
452 #[cfg(feature = "socket-opts")]
453 {
454 let sock = socket2::SockRef::from(tcp);
455 if let Some(size) = self.recv_buf_size {
456 sock.set_recv_buffer_size(size)?;
457 }
458 if let Some(size) = self.send_buf_size {
459 sock.set_send_buffer_size(size)?;
460 }
461 }
462 Ok(())
463 }
464}
465
466impl Default for ClientBuilder {
467 fn default() -> Self {
468 Self::new()
469 }
470}
471
/// A blocking WebSocket endpoint over a byte stream `S`.
///
/// Despite the name, the same type is returned by both the client-side
/// (`connect*`) and server-side (`accept`) handshakes.
pub struct Client<S> {
    /// Underlying transport.
    pub(crate) stream: S,
    /// Inbound frame decoder / message assembler.
    pub(crate) reader: FrameReader,
    /// Outbound frame encoder.
    pub(crate) writer: FrameWriter,
    /// Scratch buffer frames are encoded into before being written out.
    pub(crate) write_buf: WriteBuf,
    /// Set to `true` once a write to `stream` has failed.
    pub(crate) poisoned: bool,
}
507
508impl Client<std::net::TcpStream> {
509 #[must_use]
511 pub fn builder() -> ClientBuilder {
512 ClientBuilder::new()
513 }
514}
515
516impl<S> Client<S> {
519 pub fn from_parts(stream: S, reader: FrameReader, writer: FrameWriter) -> Self {
521 Self {
522 stream,
523 reader,
524 writer,
525 write_buf: WriteBuf::new(65_536, 14),
526 poisoned: false,
527 }
528 }
529
530 pub(crate) fn from_parts_internal(
532 stream: S,
533 reader: FrameReader,
534 writer: FrameWriter,
535 write_buf: WriteBuf,
536 ) -> Self {
537 Self {
538 stream,
539 reader,
540 writer,
541 write_buf,
542 poisoned: false,
543 }
544 }
545
546 pub fn is_poisoned(&self) -> bool {
551 self.poisoned
552 }
553
554 pub fn stream(&self) -> &S {
556 &self.stream
557 }
558
559 pub fn stream_mut(&mut self) -> &mut S {
561 &mut self.stream
562 }
563
564 pub fn reader(&self) -> &FrameReader {
566 &self.reader
567 }
568
569 pub fn frame_writer(&self) -> &FrameWriter {
571 &self.writer
572 }
573}
574
575impl<S: Read + Write> Client<S> {
578 pub fn connect_with(stream: S, url: &str) -> Result<Self, Error> {
582 ClientBuilder::new().connect_with(stream, url)
583 }
584
585 pub fn accept(stream: S) -> Result<Self, Error> {
587 ClientBuilder::new().accept(stream)
588 }
589
590 pub fn recv(&mut self) -> Result<Option<Message<'_>>, Error> {
594 loop {
595 if self.reader.poll()? {
596 return Ok(self.reader.next()?);
597 }
598 match self.read_into_reader() {
599 Ok(0) => return Ok(None),
600 Ok(_) => {}
601 Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None),
602 Err(e) => return Err(Error::Io(e)),
603 }
604 }
605 }
606
607 pub fn send_text(&mut self, text: &str) -> Result<(), Error> {
609 self.writer
610 .encode_text_into(text.as_bytes(), &mut self.write_buf);
611 self.flush_write_buf_or_poison()
612 }
613
614 pub fn send_binary(&mut self, data: &[u8]) -> Result<(), Error> {
616 self.writer.encode_binary_into(data, &mut self.write_buf);
617 self.flush_write_buf_or_poison()
618 }
619
620 pub fn send_ping(&mut self, data: &[u8]) -> Result<(), Error> {
622 self.writer
623 .encode_ping_into(data, &mut self.write_buf)
624 .map_err(Error::Encode)?;
625 self.flush_write_buf_or_poison()
626 }
627
628 pub fn send_pong(&mut self, data: &[u8]) -> Result<(), Error> {
630 self.writer
631 .encode_pong_into(data, &mut self.write_buf)
632 .map_err(Error::Encode)?;
633 self.flush_write_buf_or_poison()
634 }
635
636 pub fn close(&mut self, code: CloseCode, reason: &str) -> Result<(), Error> {
638 if code == CloseCode::NoStatus {
639 let mut dst = [0u8; 14];
640 let n = self.writer.encode_empty_close(&mut dst);
641 self.write_raw(&dst[..n]).inspect_err(|_| {
642 self.poisoned = true;
643 })
644 } else {
645 self.writer
646 .encode_close_into(code.as_u16(), reason.as_bytes(), &mut self.write_buf)
647 .map_err(Error::Encode)?;
648 self.flush_write_buf_or_poison()
649 }
650 }
651
652 fn read_into_reader(&mut self) -> io::Result<usize> {
661 self.reader.read_from(&mut self.stream)
662 }
663
664 fn flush_write_buf_or_poison(&mut self) -> Result<(), Error> {
666 self.flush_write_buf().inspect_err(|_| {
667 self.poisoned = true;
668 })
669 }
670
671 fn flush_write_buf(&mut self) -> Result<(), Error> {
673 self.stream.write_all(self.write_buf.data())?;
674 Ok(())
675 }
676
677 fn write_raw(&mut self, data: &[u8]) -> Result<(), Error> {
679 self.stream.write_all(data)?;
680 Ok(())
681 }
682
    /// Perform the client side of the WebSocket opening handshake on `stream`.
    ///
    /// Sends a `GET <path>` upgrade request with the given `host` header, then
    /// reads until a complete HTTP response is parsed and validates it:
    /// status 101, `Upgrade: websocket`, a `Connection` header mentioning
    /// `upgrade`, and a `Sec-WebSocket-Accept` matching the generated key.
    /// Any bytes received after the response head are handed to the frame
    /// reader so early frames are not lost.
    ///
    /// # Errors
    ///
    /// [`Error::Io`] on stream failures; [`Error::Handshake`] for a malformed
    /// or rejected response, or if the peer closes before responding.
    pub(crate) fn connect_impl(
        mut stream: S,
        host: &str,
        path: &str,
        reader_builder: FrameReaderBuilder,
        write_cap: usize,
        write_headroom: usize,
    ) -> Result<Self, Error> {
        let key = handshake::generate_key();
        let key_str = std::str::from_utf8(&key).expect("base64 output is valid ASCII");

        let headers = [
            ("Host", host),
            ("Upgrade", "websocket"),
            ("Connection", "Upgrade"),
            ("Sec-WebSocket-Key", key_str),
            ("Sec-WebSocket-Version", "13"),
        ];
        // Size the request buffer up front, then serialize into it.
        let req_size = crate::http::request_size("GET", path, &headers);
        let mut req_buf = vec![0u8; req_size];
        let n = crate::http::write_request("GET", path, &headers, &mut req_buf)
            .map_err(|_| HandshakeError::MalformedHttp)?;

        stream.write_all(&req_buf[..n])?;

        let mut resp_reader = crate::http::ResponseReader::new(4096);
        let mut tmp = [0u8; 4096];
        loop {
            let bytes_read = stream.read(&mut tmp)?;
            // EOF before a full response head arrived.
            if bytes_read == 0 {
                return Err(HandshakeError::MalformedHttp.into());
            }

            resp_reader
                .read(&tmp[..bytes_read])
                .map_err(|_| HandshakeError::MalformedHttp)?;
            match resp_reader.next() {
                Ok(Some(resp)) => {
                    if resp.status != 101 {
                        return Err(HandshakeError::UnexpectedStatus(resp.status).into());
                    }
                    let upgrade = resp
                        .header("Upgrade")
                        .ok_or(HandshakeError::MissingUpgrade)?;
                    if !upgrade.eq_ignore_ascii_case("websocket") {
                        return Err(HandshakeError::MissingUpgrade.into());
                    }
                    let conn = resp
                        .header("Connection")
                        .ok_or(HandshakeError::MissingConnection)?;
                    // NOTE(review): substring match, so a value such as
                    // "not-upgraded" also passes; a token comparison would
                    // be stricter.
                    if !contains_ignore_case(conn, "upgrade") {
                        return Err(HandshakeError::MissingConnection.into());
                    }
                    let accept = resp
                        .header("Sec-WebSocket-Accept")
                        .ok_or(HandshakeError::InvalidAcceptKey)?;
                    if !handshake::validate_accept(key_str, accept) {
                        return Err(HandshakeError::InvalidAcceptKey.into());
                    }

                    let mut reader = reader_builder.role(Role::Client).build();
                    // Bytes that arrived after the response head belong to
                    // the WebSocket stream.
                    let remainder = resp_reader.remainder();
                    if !remainder.is_empty() {
                        reader
                            .read(remainder)
                            .map_err(|_| HandshakeError::MalformedHttp)?;
                    }

                    return Ok(Self {
                        stream,
                        reader,
                        writer: FrameWriter::new(Role::Client),
                        write_buf: WriteBuf::new(write_cap, write_headroom),
                        poisoned: false,
                    });
                }
                // Response head incomplete — read more.
                Ok(None) => {}
                Err(_) => return Err(HandshakeError::MalformedHttp.into()),
            }
        }
    }
770
    /// Perform the server side of the WebSocket opening handshake on `stream`.
    ///
    /// Reads until a complete HTTP request is parsed, checks that it is a
    /// `GET` carrying `Upgrade: websocket`, a `Connection` header mentioning
    /// `upgrade`, `Sec-WebSocket-Version: 13`, and a `Sec-WebSocket-Key`, then
    /// replies with `101 Switching Protocols` and the computed accept key.
    /// Any bytes received after the request head are handed to the frame
    /// reader. No HTTP error response is written when validation fails.
    ///
    /// # Errors
    ///
    /// [`Error::Io`] on stream failures; [`Error::Handshake`] for a malformed
    /// or unacceptable request, or if the peer closes before sending one.
    fn accept_impl(
        mut stream: S,
        reader_builder: FrameReaderBuilder,
        write_cap: usize,
        write_headroom: usize,
    ) -> Result<Self, Error> {
        let mut req_reader = crate::http::RequestReader::new(4096);
        let mut tmp = [0u8; 4096];

        // Owned copy of the client's key; assigned exactly once before `break`.
        let ws_key;
        loop {
            let n = stream.read(&mut tmp)?;
            // EOF before a full request head arrived.
            if n == 0 {
                return Err(HandshakeError::MalformedHttp.into());
            }
            req_reader
                .read(&tmp[..n])
                .map_err(|_| HandshakeError::MalformedHttp)?;
            match req_reader.next() {
                Ok(Some(req)) => {
                    if req.method != "GET" {
                        return Err(HandshakeError::MalformedHttp.into());
                    }
                    let upgrade = req
                        .header("Upgrade")
                        .ok_or(HandshakeError::MissingUpgrade)?;
                    if !upgrade.eq_ignore_ascii_case("websocket") {
                        return Err(HandshakeError::MissingUpgrade.into());
                    }
                    let conn = req
                        .header("Connection")
                        .ok_or(HandshakeError::MissingConnection)?;
                    // NOTE(review): substring match rather than a token
                    // comparison.
                    if !contains_ignore_case(conn, "upgrade") {
                        return Err(HandshakeError::MissingConnection.into());
                    }
                    let version = req
                        .header("Sec-WebSocket-Version")
                        .ok_or(HandshakeError::UnsupportedVersion)?;
                    if version != "13" {
                        return Err(HandshakeError::UnsupportedVersion.into());
                    }
                    let key = req
                        .header("Sec-WebSocket-Key")
                        .ok_or(HandshakeError::MissingKey)?;
                    // Copy the key out so the borrow of `req_reader` ends here.
                    ws_key = key.to_owned();
                    break;
                }
                // Request head incomplete — read more.
                Ok(None) => {}
                Err(_) => return Err(HandshakeError::MalformedHttp.into()),
            }
        }

        let accept = handshake::compute_accept_key(&ws_key);
        let accept_str = std::str::from_utf8(&accept).expect("base64 output is valid ASCII");

        let resp_headers = [
            ("Upgrade", "websocket"),
            ("Connection", "Upgrade"),
            ("Sec-WebSocket-Accept", accept_str),
        ];
        // Size the response buffer up front, then serialize into it.
        let resp_size = crate::http::response_size("Switching Protocols", &resp_headers);
        let mut resp_buf = vec![0u8; resp_size];
        let n =
            crate::http::write_response(101, "Switching Protocols", &resp_headers, &mut resp_buf)
                .map_err(|_| HandshakeError::MalformedHttp)?;
        stream.write_all(&resp_buf[..n])?;

        let mut reader = reader_builder.role(Role::Server).build();
        // Bytes that arrived after the request head belong to the WebSocket
        // stream.
        let remainder = req_reader.remainder();
        if !remainder.is_empty() {
            reader
                .read(remainder)
                .map_err(|_| HandshakeError::MalformedHttp)?;
        }

        Ok(Self {
            stream,
            reader,
            writer: FrameWriter::new(Role::Server),
            write_buf: WriteBuf::new(write_cap, write_headroom),
            poisoned: false,
        })
    }
854}
855
856pub fn pair(role: Role) -> (FrameReader, FrameWriter) {
860 (
861 FrameReader::builder().role(role).build(),
862 FrameWriter::new(role),
863 )
864}
865
866pub fn pair_with(role: Role, reader_builder: FrameReaderBuilder) -> (FrameReader, FrameWriter) {
868 (reader_builder.role(role).build(), FrameWriter::new(role))
869}
870
/// ASCII case-insensitive substring test.
///
/// Returns `true` when `needle` occurs anywhere in `haystack`, comparing
/// bytes without regard to ASCII case. An empty `needle` is contained in
/// every string.
fn contains_ignore_case(haystack: &str, needle: &str) -> bool {
    let needle = needle.as_bytes();
    // `slice::windows` panics on a window size of 0, so handle that first.
    needle.is_empty()
        || haystack
            .as_bytes()
            .windows(needle.len())
            .any(|window| window.eq_ignore_ascii_case(needle))
}
877
#[cfg(test)]
mod tests {
    use super::*;

    // ---- URL parsing ----

    /// Explicit port and path on a plain `ws://` URL.
    #[test]
    fn parse_ws_url_plain() {
        let p = parse_ws_url("ws://localhost:8080/ws").unwrap();
        assert!(!p.tls);
        assert_eq!(p.host, "localhost");
        assert_eq!(p.port, 8080);
        assert_eq!(p.path, "/ws");
    }

    /// `wss://` sets the TLS flag and defaults the port to 443.
    #[test]
    fn parse_ws_url_tls() {
        let p = parse_ws_url("wss://exchange.com/ws/v1").unwrap();
        assert!(p.tls);
        assert_eq!(p.host, "exchange.com");
        assert_eq!(p.port, 443);
        assert_eq!(p.path, "/ws/v1");
    }

    /// Scheme-dependent default ports when none is given.
    #[test]
    fn parse_ws_url_default_port() {
        let p = parse_ws_url("ws://host/path").unwrap();
        assert_eq!(p.port, 80);

        let p = parse_ws_url("wss://host/path").unwrap();
        assert_eq!(p.port, 443);
    }

    /// A URL without a path yields `"/"`.
    #[test]
    fn parse_ws_url_no_path() {
        let p = parse_ws_url("ws://host").unwrap();
        assert_eq!(p.path, "/");
    }

    /// Anything other than `ws://` / `wss://` is rejected.
    #[test]
    fn parse_ws_url_invalid_scheme() {
        assert!(parse_ws_url("http://host").is_err());
        assert!(parse_ws_url("host/path").is_err());
    }

    /// Tests that drive `Client::recv` over in-memory mock streams.
    mod sync_tests {
        use super::*;
        use std::io::{self, Read, Write};

        /// `pair` returns a reader that decodes an unmasked text frame.
        #[test]
        fn pair_creates_matching_roles() {
            let (mut reader, _writer) = pair(Role::Client);
            let frame = make_frame(true, 0x1, b"test");
            reader.read(&frame).unwrap();
            let msg = reader.next().unwrap().unwrap();
            assert!(matches!(msg, Message::Text(s) if s == "test"));
        }

        /// Mock stream that yields its data one byte per `read` call and
        /// discards all writes; exercises partial-read handling.
        struct ByteAtATimeStream {
            data: Vec<u8>,
            pos: usize,
        }

        impl ByteAtATimeStream {
            fn new(data: Vec<u8>) -> Self {
                Self { data, pos: 0 }
            }
        }

        impl Read for ByteAtATimeStream {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                // Exhausted: report EOF.
                if self.pos >= self.data.len() {
                    return Ok(0);
                }
                buf[0] = self.data[self.pos];
                self.pos += 1;
                Ok(1)
            }
        }

        impl Write for ByteAtATimeStream {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                Ok(buf.len())
            }
            fn flush(&mut self) -> io::Result<()> {
                Ok(())
            }
        }

        /// Build an unmasked WebSocket frame with the given FIN bit, opcode,
        /// and payload, choosing the 7-bit / 16-bit / 64-bit length encoding.
        fn make_frame(fin: bool, opcode: u8, payload: &[u8]) -> Vec<u8> {
            let mut frame = Vec::new();
            let byte0 = if fin { 0x80 } else { 0x00 } | opcode;
            frame.push(byte0);
            if payload.len() <= 125 {
                frame.push(payload.len() as u8);
            } else if payload.len() <= 65535 {
                frame.push(126);
                frame.extend_from_slice(&(payload.len() as u16).to_be_bytes());
            } else {
                frame.push(127);
                frame.extend_from_slice(&(payload.len() as u64).to_be_bytes());
            }
            frame.extend_from_slice(payload);
            frame
        }

        /// Client-role `Client` reading `data` through `ByteAtATimeStream`.
        fn ws_from_bytes(data: Vec<u8>) -> Client<ByteAtATimeStream> {
            let mock = ByteAtATimeStream::new(data);
            let reader = FrameReader::builder().role(Role::Client).build();
            let writer = FrameWriter::new(Role::Client);
            Client::from_parts(mock, reader, writer)
        }

        /// A single text frame is delivered as `Message::Text`.
        #[test]
        fn recv_text() {
            let frame = make_frame(true, 0x1, b"Hello");
            let mut ws = ws_from_bytes(frame);
            match ws.recv().unwrap().unwrap() {
                Message::Text(s) => assert_eq!(s, "Hello"),
                other => panic!("expected Text, got {other:?}"),
            }
        }

        /// A ping with the maximum 125-byte control payload.
        #[test]
        fn recv_ping() {
            let frame = make_frame(true, 0x9, &[0x42; 125]);
            let mut ws = ws_from_bytes(frame);
            match ws.recv().unwrap().unwrap() {
                Message::Ping(p) => assert_eq!(p.len(), 125),
                other => panic!("expected Ping, got {other:?}"),
            }
        }

        /// Two fragments are reassembled into one text message.
        #[test]
        fn recv_fragmented_text() {
            let mut data = make_frame(false, 0x1, b"Hel");
            data.extend_from_slice(&make_frame(true, 0x0, b"lo"));
            let mut ws = ws_from_bytes(data);
            match ws.recv().unwrap().unwrap() {
                Message::Text(s) => assert_eq!(s, "Hello"),
                other => panic!("expected Text, got {other:?}"),
            }
        }

        /// A ping interleaved between fragments is delivered first, then the
        /// reassembled text message.
        #[test]
        fn recv_fragment_with_ping() {
            let mut data = make_frame(false, 0x1, b"Hel");
            data.extend_from_slice(&make_frame(true, 0x9, b"ping"));
            data.extend_from_slice(&make_frame(true, 0x0, b"lo"));
            let mut ws = ws_from_bytes(data);
            match ws.recv().unwrap().unwrap() {
                Message::Ping(p) => assert_eq!(p, b"ping"),
                other => panic!("expected Ping, got {other:?}"),
            }
            match ws.recv().unwrap().unwrap() {
                Message::Text(s) => assert_eq!(s, "Hello"),
                other => panic!("expected Text, got {other:?}"),
            }
        }

        /// A close frame exposes its status code and reason.
        #[test]
        fn recv_close() {
            let mut payload = vec![];
            payload.extend_from_slice(&1000u16.to_be_bytes());
            payload.extend_from_slice(b"bye");
            let frame = make_frame(true, 0x8, &payload);
            let mut ws = ws_from_bytes(frame);
            match ws.recv().unwrap().unwrap() {
                Message::Close(cf) => {
                    assert_eq!(cf.code, CloseCode::Normal);
                    assert_eq!(cf.reason, "bye");
                }
                other => panic!("expected Close, got {other:?}"),
            }
        }

        /// EOF with no buffered data yields `Ok(None)`.
        #[test]
        fn eof_returns_none() {
            let mut ws = ws_from_bytes(Vec::new());
            assert!(ws.recv().unwrap().is_none());
        }

        /// `WouldBlock` from the stream yields `Ok(None)` rather than an error.
        #[test]
        fn would_block_returns_none() {
            struct WouldBlockStream;
            impl Read for WouldBlockStream {
                fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
                    Err(io::Error::new(io::ErrorKind::WouldBlock, "would block"))
                }
            }
            impl Write for WouldBlockStream {
                fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                    Ok(buf.len())
                }
                fn flush(&mut self) -> io::Result<()> {
                    Ok(())
                }
            }

            let reader = FrameReader::builder().role(Role::Client).build();
            let writer = FrameWriter::new(Role::Client);
            let mut ws = Client::from_parts(WouldBlockStream, reader, writer);
            assert!(ws.recv().unwrap().is_none());
        }
    }

    // ---- Error conversions and Display ----

    /// `io::Error` converts to `Error::Io` and its message is shown.
    #[test]
    fn ws_error_io() {
        let io_err = std::io::Error::new(std::io::ErrorKind::BrokenPipe, "broken");
        let err = Error::from(io_err);
        assert!(matches!(err, Error::Io(_)));
        assert!(err.to_string().contains("broken"));
    }

    /// `ProtocolError` converts to `Error::Protocol`.
    #[test]
    fn ws_error_protocol() {
        let proto = ProtocolError::InvalidUtf8;
        let err = Error::from(proto);
        assert!(matches!(err, Error::Protocol(ProtocolError::InvalidUtf8)));
        assert!(err.to_string().contains("protocol error"));
    }

    /// `EncodeError` converts to `Error::Encode`.
    #[test]
    fn ws_error_encode() {
        let enc = crate::ws::EncodeError::ControlPayloadTooLarge(200);
        let err = Error::from(enc);
        assert!(matches!(err, Error::Encode(_)));
        assert!(err.to_string().contains("encode error"));
    }

    /// `HandshakeError` converts to `Error::Handshake`.
    #[test]
    fn ws_error_handshake() {
        let hs = HandshakeError::MissingUpgrade;
        let err = Error::from(hs);
        assert!(matches!(
            err,
            Error::Handshake(HandshakeError::MissingUpgrade)
        ));
        assert!(err.to_string().contains("handshake error"));
    }

    /// `InvalidUrl` displays the offending URL text.
    #[test]
    fn ws_error_invalid_url() {
        let err = Error::InvalidUrl("bad://url".into());
        assert!(matches!(err, Error::InvalidUrl(_)));
        assert!(err.to_string().contains("bad://url"));
    }

    /// `TlsNotEnabled` mentions the `tls` feature in its message.
    #[test]
    fn ws_error_tls_not_enabled() {
        let err = Error::TlsNotEnabled;
        assert!(matches!(err, Error::TlsNotEnabled));
        assert!(err.to_string().contains("tls"));
    }
}