1use std::future::Future;
2use std::time::Duration;
3
4use bytes::{Bytes, BytesMut};
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::time::timeout as tokio_timeout;
7use url::Url;
8
9use crate::transport::connector::MaybeHttpsStream;
10use crate::websocket::error::{WebSocketError, WebSocketResult};
11use crate::websocket::frame::{decode_frame, encode_frame, FrameConfig, FrameDecoder, OpCode};
12use crate::websocket::message::{CloseFrame, Message};
13use crate::websocket::WebSocketConfig;
14
/// A WebSocket connection layered over a `MaybeHttpsStream`.
///
/// Tracks both directions of the closing handshake (`close_sent` /
/// `close_received`) so reads and writes can behave correctly once either
/// side has started closing.
#[derive(Debug)]
pub struct WebSocket {
    // Underlying transport; every frame is read from and written to it.
    stream: MaybeHttpsStream,
    // URL of this connection; exposed via `url()` and attached to errors.
    url: Url,
    // Subprotocol associated with the connection, if any (presumably the one
    // agreed during the handshake — confirm at the construction site).
    protocol: Option<String>,
    // Bytes received but not yet decoded into complete frames.
    read_buffer: BytesMut,
    // Maximum frame size and maximum message size limits.
    frame_config: FrameConfig,
    // Optional timeout for each read; `None` waits indefinitely.
    read_timeout: Option<Duration>,
    // Optional timeout for each write/flush; `None` waits indefinitely.
    write_timeout: Option<Duration>,
    // Turns decoded frames into messages; a frame may not yield a message on
    // its own (NOTE(review): presumably for fragmented messages — confirm).
    decoder: FrameDecoder,
    // Set once a Close frame has been successfully written to the peer.
    close_sent: bool,
    // Set once a Close frame has been received from the peer.
    close_received: bool,
}
28
29impl WebSocket {
30 pub(crate) fn new(
31 stream: MaybeHttpsStream,
32 url: Url,
33 protocol: Option<String>,
34 config: WebSocketConfig,
35 initial_read_buffer: Bytes,
36 ) -> Self {
37 Self {
38 stream,
39 url,
40 protocol,
41 read_buffer: BytesMut::from(&initial_read_buffer[..]),
42 frame_config: FrameConfig::new(config.max_frame_size, config.max_message_size),
43 read_timeout: config.read_timeout,
44 write_timeout: config.write_timeout,
45 decoder: FrameDecoder::new(),
46 close_sent: false,
47 close_received: false,
48 }
49 }
50
51 pub fn url(&self) -> &Url {
52 &self.url
53 }
54
55 pub fn protocol(&self) -> Option<&str> {
56 self.protocol.as_deref()
57 }
58
59 pub async fn send(&mut self, msg: Message) -> WebSocketResult<()> {
60 if self.close_sent && !matches!(msg, Message::Close(_)) {
61 return Err(WebSocketError::protocol(
62 &self.url,
63 "cannot send data after close frame",
64 ));
65 }
66
67 match msg {
68 Message::Text(text) => self.write_frame(OpCode::Text, text.as_bytes()).await,
69 Message::Binary(bytes) => self.write_frame(OpCode::Binary, &bytes).await,
70 Message::Ping(bytes) => self.write_control(OpCode::Ping, &bytes).await,
71 Message::Pong(bytes) => self.write_control(OpCode::Pong, &bytes).await,
72 Message::Close(frame) => self.close(frame).await,
73 }
74 }
75
76 pub async fn send_text(&mut self, text: impl Into<String>) -> WebSocketResult<()> {
77 self.send(Message::Text(text.into())).await
78 }
79
80 pub async fn send_binary(&mut self, bytes: impl Into<Bytes>) -> WebSocketResult<()> {
81 self.send(Message::Binary(bytes.into())).await
82 }
83
84 pub async fn next(&mut self) -> WebSocketResult<Option<Message>> {
85 loop {
86 let frame = match decode_frame(&self.url, &mut self.read_buffer, self.frame_config) {
87 Ok(frame) => frame,
88 Err(error) => return Err(self.best_effort_close_for_error(error).await),
89 };
90
91 if let Some(frame) = frame {
92 let message = match self
93 .decoder
94 .decode_message(&self.url, frame, self.frame_config)
95 {
96 Ok(message) => message,
97 Err(error) => return Err(self.best_effort_close_for_error(error).await),
98 };
99
100 match message {
101 Some(Message::Ping(payload)) => {
102 if !self.close_received {
103 self.write_control(OpCode::Pong, &payload).await?;
104 }
105 return Ok(Some(Message::Ping(payload)));
106 }
107 Some(Message::Close(frame)) => {
108 self.close_received = true;
109 if !self.close_sent {
110 self.send_close_raw(frame.clone()).await?;
111 }
112 return Ok(None);
113 }
114 Some(other) => return Ok(Some(other)),
115 None => {}
116 }
117 } else {
118 let mut scratch = [0_u8; 8192];
119 let n = Self::io_with_timeout(
120 &self.url,
121 self.read_timeout,
122 "read",
123 self.stream.read(&mut scratch),
124 )
125 .await?;
126 if n == 0 {
127 return if self.close_sent || self.close_received {
128 Ok(None)
129 } else {
130 Err(WebSocketError::connection_closed(&self.url))
131 };
132 }
133 self.read_buffer.extend_from_slice(&scratch[..n]);
134 }
135 }
136 }
137
138 pub async fn close(&mut self, frame: Option<CloseFrame>) -> WebSocketResult<()> {
139 if !self.close_sent {
140 self.send_close_raw(frame).await?;
141 }
142 Ok(())
143 }
144
145 async fn write_frame(&mut self, opcode: OpCode, payload: &[u8]) -> WebSocketResult<()> {
146 if payload.len() > self.frame_config.max_frame_size {
147 return Err(WebSocketError::limit_exceeded(
148 &self.url,
149 format!("frame exceeds {} bytes", self.frame_config.max_frame_size),
150 ));
151 }
152 if matches!(opcode, OpCode::Text | OpCode::Binary)
153 && payload.len() > self.frame_config.max_message_size
154 {
155 return Err(WebSocketError::limit_exceeded(
156 &self.url,
157 format!(
158 "message exceeds {} bytes",
159 self.frame_config.max_message_size
160 ),
161 ));
162 }
163 let bytes = encode_frame(opcode, payload, true)?;
164 Self::io_with_timeout(
165 &self.url,
166 self.write_timeout,
167 "write",
168 self.stream.write_all(&bytes),
169 )
170 .await?;
171 Self::io_with_timeout(&self.url, self.write_timeout, "flush", self.stream.flush()).await
172 }
173
174 async fn write_control(&mut self, opcode: OpCode, payload: &[u8]) -> WebSocketResult<()> {
175 if payload.len() > 125 {
176 return Err(WebSocketError::protocol(
177 &self.url,
178 "control frame payload exceeds 125 bytes",
179 ));
180 }
181 self.write_frame(opcode, payload).await
182 }
183
184 async fn send_close_raw(&mut self, frame: Option<CloseFrame>) -> WebSocketResult<()> {
185 let payload = match frame {
186 Some(frame) => frame.encode(&self.url)?,
187 None => Vec::new(),
188 };
189 self.write_control(OpCode::Close, &payload).await?;
190 self.close_sent = true;
191 Ok(())
192 }
193
194 async fn best_effort_close_for_error(&mut self, error: WebSocketError) -> WebSocketError {
195 if let Some(code) = error.close_code() {
196 if !self.close_sent {
197 let frame = CloseFrame {
198 code,
199 reason: String::new(),
200 };
201 let _ = self.send_close_raw(Some(frame)).await;
202 }
203 }
204 error
205 }
206
207 async fn io_with_timeout<T, F>(
208 url: &Url,
209 timeout: Option<Duration>,
210 operation: &'static str,
211 future: F,
212 ) -> WebSocketResult<T>
213 where
214 F: Future<Output = std::io::Result<T>>,
215 {
216 let result = match timeout {
217 Some(duration) => {
218 tokio_timeout(duration, future)
219 .await
220 .map_err(|_| WebSocketError::Timeout {
221 url: url.to_string(),
222 operation: format!("{operation} after {:?}", duration),
223 })?
224 }
225 None => future.await,
226 };
227
228 result.map_err(|error| WebSocketError::io(url, error))
229 }
230}