1use super::common::{Headers, ProtocolError, ProtocolResult, Timeout, Uri};
15use std::time::Duration;
16
17#[derive(Debug, Clone, PartialEq, Eq)]
19pub enum Message {
20 Text(String),
22 Binary(Vec<u8>),
24 Ping(Vec<u8>),
26 Pong(Vec<u8>),
28 Close(Option<CloseFrame>),
30}
31
32impl Message {
33 pub fn text(s: impl Into<String>) -> Self {
35 Message::Text(s.into())
36 }
37
38 pub fn binary(data: impl Into<Vec<u8>>) -> Self {
40 Message::Binary(data.into())
41 }
42
43 pub fn ping(data: impl Into<Vec<u8>>) -> Self {
45 Message::Ping(data.into())
46 }
47
48 pub fn pong(data: impl Into<Vec<u8>>) -> Self {
50 Message::Pong(data.into())
51 }
52
53 pub fn close(code: CloseCode, reason: impl Into<String>) -> Self {
55 Message::Close(Some(CloseFrame {
56 code,
57 reason: reason.into(),
58 }))
59 }
60
61 pub fn is_text(&self) -> bool {
63 matches!(self, Message::Text(_))
64 }
65
66 pub fn is_binary(&self) -> bool {
68 matches!(self, Message::Binary(_))
69 }
70
71 pub fn is_ping(&self) -> bool {
73 matches!(self, Message::Ping(_))
74 }
75
76 pub fn is_pong(&self) -> bool {
78 matches!(self, Message::Pong(_))
79 }
80
81 pub fn is_close(&self) -> bool {
83 matches!(self, Message::Close(_))
84 }
85
86 pub fn is_data(&self) -> bool {
88 matches!(self, Message::Text(_) | Message::Binary(_))
89 }
90
91 pub fn as_text(&self) -> Option<&str> {
93 match self {
94 Message::Text(s) => Some(s),
95 _ => None,
96 }
97 }
98
99 pub fn as_binary(&self) -> Option<&[u8]> {
101 match self {
102 Message::Binary(b) => Some(b),
103 _ => None,
104 }
105 }
106
107 pub fn into_text(self) -> Option<String> {
109 match self {
110 Message::Text(s) => Some(s),
111 _ => None,
112 }
113 }
114
115 pub fn into_bytes(self) -> Option<Vec<u8>> {
117 match self {
118 Message::Binary(b) => Some(b),
119 _ => None,
120 }
121 }
122
123 pub fn len(&self) -> usize {
125 match self {
126 Message::Text(s) => s.len(),
127 Message::Binary(b) | Message::Ping(b) | Message::Pong(b) => b.len(),
128 Message::Close(Some(frame)) => 2 + frame.reason.len(),
129 Message::Close(None) => 0,
130 }
131 }
132
133 pub fn is_empty(&self) -> bool {
135 self.len() == 0
136 }
137}
138
/// Status code and reason carried by a WebSocket close frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CloseFrame {
    /// Close status code.
    pub code: CloseCode,
    /// Human-readable close reason; may be empty (`WebSocket::close_normal` sends `""`).
    pub reason: String,
}
147
/// Status code carried in a WebSocket close frame.
///
/// Named variants cover the codes this module recognises; any other value is
/// preserved verbatim in [`CloseCode::Custom`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
pub enum CloseCode {
    /// 1000: normal closure.
    Normal = 1000,
    /// 1001: endpoint going away.
    GoingAway = 1001,
    /// 1002: protocol error.
    Protocol = 1002,
    /// 1003: unsupported data type.
    Unsupported = 1003,
    /// 1005: no status code was present.
    NoStatus = 1005,
    /// 1006: connection closed abnormally.
    Abnormal = 1006,
    /// 1007: invalid payload data.
    InvalidData = 1007,
    /// 1008: policy violation.
    Policy = 1008,
    /// 1009: message too big.
    MessageTooBig = 1009,
    /// 1010: required extension missing.
    MissingExtension = 1010,
    /// 1011: internal error.
    InternalError = 1011,
    /// 1015: TLS handshake failure.
    TlsFailure = 1015,
    /// Any code without a named variant, kept as-is.
    Custom(u16),
}

impl CloseCode {
    /// Every named (non-`Custom`) variant, used to map raw codes back to variants.
    const NAMED: [Self; 12] = [
        Self::Normal,
        Self::GoingAway,
        Self::Protocol,
        Self::Unsupported,
        Self::NoStatus,
        Self::Abnormal,
        Self::InvalidData,
        Self::Policy,
        Self::MessageTooBig,
        Self::MissingExtension,
        Self::InternalError,
        Self::TlsFailure,
    ];

    /// Maps a raw status code to its named variant, or `Custom(code)` if unnamed.
    pub fn from_u16(code: u16) -> Self {
        Self::NAMED
            .iter()
            .copied()
            .find(|named| named.as_u16() == code)
            .unwrap_or(Self::Custom(code))
    }

    /// Returns the raw numeric status code.
    pub fn as_u16(&self) -> u16 {
        match *self {
            Self::Normal => 1000,
            Self::GoingAway => 1001,
            Self::Protocol => 1002,
            Self::Unsupported => 1003,
            Self::NoStatus => 1005,
            Self::Abnormal => 1006,
            Self::InvalidData => 1007,
            Self::Policy => 1008,
            Self::MessageTooBig => 1009,
            Self::MissingExtension => 1010,
            Self::InternalError => 1011,
            Self::TlsFailure => 1015,
            Self::Custom(raw) => raw,
        }
    }

    /// Short English description of the code.
    pub fn description(&self) -> &'static str {
        match *self {
            Self::Normal => "Normal closure",
            Self::GoingAway => "Endpoint going away",
            Self::Protocol => "Protocol error",
            Self::Unsupported => "Unsupported data type",
            Self::NoStatus => "No status received",
            Self::Abnormal => "Abnormal closure",
            Self::InvalidData => "Invalid payload data",
            Self::Policy => "Policy violation",
            Self::MessageTooBig => "Message too big",
            Self::MissingExtension => "Missing extension",
            Self::InternalError => "Internal server error",
            Self::TlsFailure => "TLS handshake failure",
            Self::Custom(_) => "Custom close code",
        }
    }
}

impl std::fmt::Display for CloseCode {
    /// Formats as `"<description> (<code>)"`, e.g. `Normal closure (1000)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (text, code) = (self.description(), self.as_u16());
        write!(f, "{} ({})", text, code)
    }
}
244
245#[derive(Debug, Clone)]
247pub struct WebSocketConfig {
248 pub max_message_size: usize,
250 pub max_frame_size: usize,
252 pub accept_unmasked_frames: bool,
254 pub subprotocols: Vec<String>,
256 pub headers: Headers,
258 pub connect_timeout: Duration,
260 pub ping_interval: Option<Duration>,
262 pub pong_timeout: Option<Duration>,
264}
265
266impl Default for WebSocketConfig {
267 fn default() -> Self {
268 WebSocketConfig {
269 max_message_size: 64 * 1024 * 1024, max_frame_size: 16 * 1024 * 1024, accept_unmasked_frames: false,
272 subprotocols: Vec::new(),
273 headers: Headers::new(),
274 connect_timeout: Duration::from_secs(30),
275 ping_interval: Some(Duration::from_secs(30)),
276 pong_timeout: Some(Duration::from_secs(10)),
277 }
278 }
279}
280
/// Lifecycle state of a [`WebSocket`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// Handshake in progress. NOTE(review): not assigned anywhere in this module as shown.
    Connecting,
    /// Handshake completed; `send` and `recv` only operate in this state.
    Open,
    /// A close has been initiated but the connection is not yet fully closed.
    Closing,
    /// Connection terminated; `send` errors and `recv` returns `Ok(None)`.
    Closed,
}
293
/// Client-side WebSocket connection.
///
/// Created through [`WebSocket::connect`] or [`WebSocket::builder`].
#[derive(Debug)]
pub struct WebSocket {
    /// URL the connection was opened against.
    url: String,
    /// Configuration captured from the builder.
    config: WebSocketConfig,
    /// Current lifecycle state; `send`/`recv` check it before touching the stream.
    state: ConnectionState,
    /// Subprotocol recorded for this connection, if any.
    subprotocol: Option<String>,
    /// Underlying tokio-tungstenite stream (only with the `tokio-tungstenite` feature).
    #[cfg(feature = "tokio-tungstenite")]
    inner: Option<
        tokio_tungstenite::WebSocketStream<
            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
        >,
    >,
}
312
313impl WebSocket {
314 pub fn builder(url: impl Into<String>) -> WebSocketBuilder {
316 WebSocketBuilder {
317 url: url.into(),
318 config: WebSocketConfig::default(),
319 }
320 }
321
322 #[cfg(feature = "tokio-tungstenite")]
324 pub async fn connect(url: impl Into<String>) -> ProtocolResult<Self> {
325 Self::builder(url).connect().await
326 }
327
328 #[cfg(not(feature = "tokio-tungstenite"))]
330 pub async fn connect(url: impl Into<String>) -> ProtocolResult<Self> {
331 let _ = url;
332 Err(ProtocolError::Protocol(
333 "WebSocket requires 'websocket' feature".to_string(),
334 ))
335 }
336
337 pub fn url(&self) -> &str {
339 &self.url
340 }
341
342 pub fn state(&self) -> ConnectionState {
344 self.state
345 }
346
347 pub fn is_open(&self) -> bool {
349 self.state == ConnectionState::Open
350 }
351
352 pub fn subprotocol(&self) -> Option<&str> {
354 self.subprotocol.as_deref()
355 }
356
357 #[cfg(feature = "tokio-tungstenite")]
359 pub async fn send(&mut self, message: Message) -> ProtocolResult<()> {
360 use futures_util::SinkExt;
361 use tokio_tungstenite::tungstenite::Message as TMessage;
362
363 if self.state != ConnectionState::Open {
364 return Err(ProtocolError::ChannelClosed);
365 }
366
367 let msg = match message {
368 Message::Text(s) => TMessage::Text(s),
369 Message::Binary(b) => TMessage::Binary(b),
370 Message::Ping(b) => TMessage::Ping(b),
371 Message::Pong(b) => TMessage::Pong(b),
372 Message::Close(frame) => {
373 let close_frame = frame.map(|f| {
374 tokio_tungstenite::tungstenite::protocol::CloseFrame {
375 code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::from(f.code.as_u16()),
376 reason: f.reason.into(),
377 }
378 });
379 TMessage::Close(close_frame)
380 }
381 };
382
383 if let Some(ref mut inner) = self.inner {
384 inner
385 .send(msg)
386 .await
387 .map_err(|e| ProtocolError::Protocol(e.to_string()))?;
388 }
389
390 Ok(())
391 }
392
393 #[cfg(not(feature = "tokio-tungstenite"))]
395 pub async fn send(&mut self, message: Message) -> ProtocolResult<()> {
396 let _ = message;
397 Err(ProtocolError::Protocol(
398 "WebSocket requires 'websocket' feature".to_string(),
399 ))
400 }
401
402 #[cfg(feature = "tokio-tungstenite")]
404 pub async fn recv(&mut self) -> ProtocolResult<Option<Message>> {
405 use futures_util::StreamExt;
406 use tokio_tungstenite::tungstenite::Message as TMessage;
407
408 if self.state != ConnectionState::Open {
409 return Ok(None);
410 }
411
412 if let Some(ref mut inner) = self.inner {
413 loop {
415 match inner.next().await {
416 Some(Ok(msg)) => {
417 let message = match msg {
418 TMessage::Text(s) => Message::Text(s),
419 TMessage::Binary(b) => Message::Binary(b),
420 TMessage::Ping(b) => Message::Ping(b),
421 TMessage::Pong(b) => Message::Pong(b),
422 TMessage::Close(frame) => {
423 self.state = ConnectionState::Closed;
424 Message::Close(frame.map(|f| CloseFrame {
425 code: CloseCode::from_u16(f.code.into()),
426 reason: f.reason.to_string(),
427 }))
428 }
429 TMessage::Frame(_) => continue, };
431 return Ok(Some(message));
432 }
433 Some(Err(e)) => {
434 self.state = ConnectionState::Closed;
435 return Err(ProtocolError::Protocol(e.to_string()));
436 }
437 None => {
438 self.state = ConnectionState::Closed;
439 return Ok(None);
440 }
441 }
442 }
443 } else {
444 Ok(None)
445 }
446 }
447
448 #[cfg(not(feature = "tokio-tungstenite"))]
450 pub async fn recv(&mut self) -> ProtocolResult<Option<Message>> {
451 Err(ProtocolError::Protocol(
452 "WebSocket requires 'websocket' feature".to_string(),
453 ))
454 }
455
456 pub async fn send_text(&mut self, text: impl Into<String>) -> ProtocolResult<()> {
458 self.send(Message::Text(text.into())).await
459 }
460
461 pub async fn send_binary(&mut self, data: impl Into<Vec<u8>>) -> ProtocolResult<()> {
463 self.send(Message::Binary(data.into())).await
464 }
465
466 pub async fn ping(&mut self, data: impl Into<Vec<u8>>) -> ProtocolResult<()> {
468 self.send(Message::Ping(data.into())).await
469 }
470
471 pub async fn close(
473 &mut self,
474 code: CloseCode,
475 reason: impl Into<String>,
476 ) -> ProtocolResult<()> {
477 if self.state != ConnectionState::Open {
478 return Ok(());
479 }
480
481 self.state = ConnectionState::Closing;
482 self.send(Message::close(code, reason)).await?;
483 self.state = ConnectionState::Closed;
484 Ok(())
485 }
486
487 pub async fn close_normal(&mut self) -> ProtocolResult<()> {
489 self.close(CloseCode::Normal, "").await
490 }
491}
492
/// Fluent builder for a [`WebSocket`]; obtain one with [`WebSocket::builder`].
#[derive(Debug, Clone)]
pub struct WebSocketBuilder {
    /// Target URL.
    url: String,
    /// Configuration accumulated by the builder methods.
    config: WebSocketConfig,
}
499
500impl WebSocketBuilder {
501 pub fn max_message_size(mut self, size: usize) -> Self {
503 self.config.max_message_size = size;
504 self
505 }
506
507 pub fn max_frame_size(mut self, size: usize) -> Self {
509 self.config.max_frame_size = size;
510 self
511 }
512
513 pub fn subprotocol(mut self, protocol: impl Into<String>) -> Self {
515 self.config.subprotocols.push(protocol.into());
516 self
517 }
518
519 pub fn subprotocols(mut self, protocols: Vec<String>) -> Self {
521 self.config.subprotocols.extend(protocols);
522 self
523 }
524
525 pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
527 self.config.headers.insert(key, value);
528 self
529 }
530
531 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
533 self.config.connect_timeout = timeout;
534 self
535 }
536
537 pub fn ping_interval(mut self, interval: Duration) -> Self {
539 self.config.ping_interval = Some(interval);
540 self
541 }
542
543 pub fn no_ping(mut self) -> Self {
545 self.config.ping_interval = None;
546 self
547 }
548
549 pub fn bearer_auth(self, token: impl Into<String>) -> Self {
551 self.header("Authorization", format!("Bearer {}", token.into()))
552 }
553
554 #[cfg(feature = "tokio-tungstenite")]
556 pub async fn connect(self) -> ProtocolResult<WebSocket> {
557 use tokio_tungstenite::connect_async;
558
559 let (ws_stream, _response) = connect_async(&self.url)
560 .await
561 .map_err(|e| ProtocolError::ConnectionFailed(e.to_string()))?;
562
563 Ok(WebSocket {
564 url: self.url,
565 config: self.config,
566 state: ConnectionState::Open,
567 subprotocol: None, inner: Some(ws_stream),
569 })
570 }
571
572 #[cfg(not(feature = "tokio-tungstenite"))]
574 pub async fn connect(self) -> ProtocolResult<WebSocket> {
575 Err(ProtocolError::Protocol(
576 "WebSocket requires 'websocket' feature".to_string(),
577 ))
578 }
579}
580
/// Wrapper that can re-establish a [`WebSocket`] with exponential backoff.
#[derive(Debug)]
pub struct ReconnectingWebSocket {
    /// Builder cloned for every (re)connection attempt.
    builder: WebSocketBuilder,
    /// Current connection, if one has been established.
    connection: Option<WebSocket>,
    /// Backoff and attempt-limit settings.
    reconnect_config: ReconnectConfig,
    /// Reconnect attempts since the last successful connection; drives the backoff.
    attempt_count: u32,
}
593
/// Exponential-backoff settings for [`ReconnectingWebSocket`].
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Delay before the first reconnect attempt.
    pub initial_delay: Duration,
    /// Cap applied to the computed delay.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by per failed attempt.
    pub multiplier: f64,
    /// Maximum reconnect attempts; `None` means unlimited.
    pub max_attempts: Option<u32>,
}

impl Default for ReconnectConfig {
    /// 1 s initial delay, doubling per attempt, capped at 30 s, unlimited attempts.
    fn default() -> Self {
        let initial_delay = Duration::from_secs(1);
        let max_delay = Duration::from_millis(30_000);
        Self {
            initial_delay,
            max_delay,
            multiplier: 2.0,
            max_attempts: None,
        }
    }
}
617
618impl ReconnectingWebSocket {
619 pub fn new(url: impl Into<String>) -> Self {
621 ReconnectingWebSocket {
622 builder: WebSocket::builder(url),
623 connection: None,
624 reconnect_config: ReconnectConfig::default(),
625 attempt_count: 0,
626 }
627 }
628
629 pub fn reconnect_config(mut self, config: ReconnectConfig) -> Self {
631 self.reconnect_config = config;
632 self
633 }
634
635 pub fn initial_delay(mut self, delay: Duration) -> Self {
637 self.reconnect_config.initial_delay = delay;
638 self
639 }
640
641 pub fn max_delay(mut self, delay: Duration) -> Self {
643 self.reconnect_config.max_delay = delay;
644 self
645 }
646
647 pub fn max_attempts(mut self, attempts: u32) -> Self {
649 self.reconnect_config.max_attempts = Some(attempts);
650 self
651 }
652
653 pub async fn connect(&mut self) -> ProtocolResult<()> {
655 self.connection = Some(self.builder.clone().connect().await?);
656 self.attempt_count = 0;
657 Ok(())
658 }
659
660 pub fn is_connected(&self) -> bool {
662 self.connection
663 .as_ref()
664 .map(|c| c.is_open())
665 .unwrap_or(false)
666 }
667
668 pub fn connection(&mut self) -> Option<&mut WebSocket> {
670 self.connection.as_mut()
671 }
672
673 fn next_delay(&self) -> Duration {
675 let delay = self.reconnect_config.initial_delay.mul_f64(
676 self.reconnect_config
677 .multiplier
678 .powi(self.attempt_count as i32),
679 );
680 delay.min(self.reconnect_config.max_delay)
681 }
682
683 pub async fn reconnect(&mut self) -> ProtocolResult<()> {
685 if let Some(max) = self.reconnect_config.max_attempts {
686 if self.attempt_count >= max {
687 return Err(ProtocolError::ConnectionFailed(format!(
688 "Max reconnection attempts ({}) exceeded",
689 max
690 )));
691 }
692 }
693
694 let delay = self.next_delay();
695 self.attempt_count += 1;
696
697 #[cfg(feature = "tokio")]
698 tokio::time::sleep(delay).await;
699
700 self.connection = Some(self.builder.clone().connect().await?);
701 self.attempt_count = 0;
702 Ok(())
703 }
704}
705
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructors and predicates agree on message kind and payload access.
    #[test]
    fn test_message_types() {
        let greeting = Message::text("hello");
        assert!(greeting.is_text());
        assert!(greeting.is_data());
        assert_eq!(greeting.as_text(), Some("hello"));

        let payload = Message::binary(vec![1u8, 2, 3]);
        assert!(payload.is_binary());
        assert!(payload.is_data());
        assert_eq!(payload.as_binary(), Some([1u8, 2, 3].as_slice()));

        let heartbeat = Message::ping(vec![1u8, 2]);
        assert!(heartbeat.is_ping());
        assert!(!heartbeat.is_data());
    }

    /// Known codes map to named variants; unknown codes become `Custom`.
    #[test]
    fn test_close_codes() {
        assert_eq!(CloseCode::Normal.as_u16(), 1000);
        for (raw, expected) in [(1000, CloseCode::Normal), (4000, CloseCode::Custom(4000))] {
            assert_eq!(CloseCode::from_u16(raw), expected);
        }
    }

    /// Builder methods accumulate into the stored URL and config.
    #[test]
    fn test_websocket_builder() {
        let url = "wss://example.com/socket";
        let protocol = "graphql-transport-ws";
        let builder = WebSocket::builder(url)
            .subprotocol(protocol)
            .bearer_auth("token123")
            .connect_timeout(Duration::from_secs(10));

        assert_eq!(builder.url, url);
        assert!(builder.config.subprotocols.iter().any(|p| p == protocol));
    }

    /// A supplied reconnect configuration replaces the default one.
    #[test]
    fn test_reconnect_config() {
        let config = ReconnectConfig {
            max_attempts: Some(5),
            ..ReconnectConfig::default()
        };

        let ws = ReconnectingWebSocket::new("wss://example.com").reconnect_config(config);

        assert_eq!(ws.reconnect_config.max_attempts, Some(5));
    }
}