1use std::future::Future;
2use std::time::Duration;
3
4use crate::url::Url;
5use bytes::{Bytes, BytesMut};
6use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
7use tokio::time::timeout as tokio_timeout;
8
9use crate::transport::connector::MaybeHttpsStream;
10use crate::websocket::error::{WebSocketError, WebSocketResult};
11use crate::websocket::frame::{
12 decode_frame, encode_frame_append, encode_frame_into, Frame, FrameConfig, FrameDecoder,
13 MaskRng, OpCode,
14};
15use crate::websocket::message::{CloseFrame, Message, PreparedMessage};
16use crate::websocket::WebSocketConfig;
17
/// Spare capacity (16 KiB) reserved in a read buffer before every socket read.
/// `WebSocket::new` also uses it as the initial capacity of the write buffer.
const READ_CHUNK_SIZE: usize = 16 << 10;
/// Starting capacity (16 KiB) of the read buffer; `WebSocket::new` grows it when
/// the pre-read bytes handed to it are larger.
const INITIAL_READ_CAPACITY: usize = 16 << 10;
20
/// Public view of a WebSocket frame opcode, as exposed on [`WebSocketFrame`].
///
/// Mirrors the internal frame-layer `OpCode` one-to-one (see the `From` impl).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WebSocketFrameOpcode {
    /// Continuation of a fragmented message.
    Continuation,
    /// Text data frame.
    Text,
    /// Binary data frame.
    Binary,
    /// Close control frame.
    Close,
    /// Ping control frame.
    Ping,
    /// Pong control frame.
    Pong,
}
30
31#[derive(Debug, Clone, PartialEq, Eq)]
32pub struct WebSocketFrame {
33 pub fin: bool,
34 pub opcode: WebSocketFrameOpcode,
35 pub payload: Bytes,
36}
37
38impl From<OpCode> for WebSocketFrameOpcode {
39 fn from(value: OpCode) -> Self {
40 match value {
41 OpCode::Continuation => Self::Continuation,
42 OpCode::Text => Self::Text,
43 OpCode::Binary => Self::Binary,
44 OpCode::Close => Self::Close,
45 OpCode::Ping => Self::Ping,
46 OpCode::Pong => Self::Pong,
47 }
48 }
49}
50
51impl From<Frame> for WebSocketFrame {
52 fn from(frame: Frame) -> Self {
53 Self {
54 fin: frame.fin,
55 opcode: WebSocketFrameOpcode::from(frame.opcode),
56 payload: frame.payload,
57 }
58 }
59}
60
61#[derive(Debug)]
62pub struct WebSocket {
63 stream: MaybeHttpsStream,
64 url: Url,
65 protocol: Option<String>,
66 read_buffer: BytesMut,
67 write_buffer: BytesMut,
68 frame_config: FrameConfig,
69 read_timeout: Option<Duration>,
70 write_timeout: Option<Duration>,
71 decoder: FrameDecoder,
72 mask_rng: MaskRng,
73 close_sent: bool,
74 close_received: bool,
75}
76
77#[derive(Debug)]
78pub struct WebSocketReader {
79 stream: ReadHalf<MaybeHttpsStream>,
80 url: Url,
81 read_buffer: BytesMut,
82 frame_config: FrameConfig,
83 read_timeout: Option<Duration>,
84 decoder: FrameDecoder,
85 close_received: bool,
86}
87
88#[derive(Debug)]
89pub struct WebSocketWriter {
90 stream: WriteHalf<MaybeHttpsStream>,
91 url: Url,
92 write_buffer: BytesMut,
93 frame_config: FrameConfig,
94 write_timeout: Option<Duration>,
95 mask_rng: MaskRng,
96 close_sent: bool,
97}
98
99impl WebSocket {
100 pub(crate) fn new(
101 stream: MaybeHttpsStream,
102 url: Url,
103 protocol: Option<String>,
104 config: WebSocketConfig,
105 initial_read_buffer: Bytes,
106 ) -> Self {
107 let mut read_buffer =
111 BytesMut::with_capacity(INITIAL_READ_CAPACITY.max(initial_read_buffer.len()));
112 read_buffer.extend_from_slice(&initial_read_buffer);
113 Self {
114 stream,
115 url,
116 protocol,
117 read_buffer,
118 write_buffer: BytesMut::with_capacity(READ_CHUNK_SIZE),
119 frame_config: FrameConfig::new(config.max_frame_size, config.max_message_size),
120 read_timeout: config.read_timeout,
121 write_timeout: config.write_timeout,
122 decoder: FrameDecoder::new(),
123 mask_rng: MaskRng::new(),
124 close_sent: false,
125 close_received: false,
126 }
127 }
128
129 pub fn url(&self) -> &Url {
130 &self.url
131 }
132
133 pub fn protocol(&self) -> Option<&str> {
134 self.protocol.as_deref()
135 }
136
137 pub fn split(self) -> (WebSocketReader, WebSocketWriter) {
138 let (read_stream, write_stream) = tokio::io::split(self.stream);
139 let reader = WebSocketReader {
140 stream: read_stream,
141 url: self.url.clone(),
142 read_buffer: self.read_buffer,
143 frame_config: self.frame_config,
144 read_timeout: self.read_timeout,
145 decoder: self.decoder,
146 close_received: self.close_received,
147 };
148 let writer = WebSocketWriter {
149 stream: write_stream,
150 url: self.url,
151 write_buffer: self.write_buffer,
152 frame_config: self.frame_config,
153 write_timeout: self.write_timeout,
154 mask_rng: self.mask_rng,
155 close_sent: self.close_sent,
156 };
157 (reader, writer)
158 }
159
160 pub async fn send(&mut self, msg: Message) -> WebSocketResult<()> {
161 if self.close_sent && !matches!(msg, Message::Close(_)) {
162 return Err(WebSocketError::protocol(
163 &self.url,
164 "cannot send data after close frame",
165 ));
166 }
167
168 match msg {
169 Message::Text(text) => self.write_frame(OpCode::Text, text.as_bytes()).await,
170 Message::Binary(bytes) => self.write_frame(OpCode::Binary, &bytes).await,
171 Message::Ping(bytes) => self.write_control(OpCode::Ping, &bytes).await,
172 Message::Pong(bytes) => self.write_control(OpCode::Pong, &bytes).await,
173 Message::Close(frame) => self.close(frame).await,
174 }
175 }
176
177 pub async fn send_text(&mut self, text: impl Into<String>) -> WebSocketResult<()> {
178 self.send(Message::Text(text.into())).await
179 }
180
181 pub async fn send_binary(&mut self, bytes: impl Into<Bytes>) -> WebSocketResult<()> {
182 self.send(Message::Binary(bytes.into())).await
183 }
184
185 pub async fn send_prepared(&mut self, message: &PreparedMessage) -> WebSocketResult<()> {
186 match message {
187 PreparedMessage::Text(bytes) => self.write_frame(OpCode::Text, bytes).await,
188 PreparedMessage::Binary(bytes) => self.write_frame(OpCode::Binary, bytes).await,
189 }
190 }
191
192 pub async fn send_prepared_batch<'a>(
193 &mut self,
194 messages: impl IntoIterator<Item = &'a PreparedMessage>,
195 ) -> WebSocketResult<()> {
196 self.write_prepared_batch(messages).await
197 }
198
199 pub async fn next_frame(&mut self) -> WebSocketResult<Option<WebSocketFrame>> {
200 Self::read_next_frame(
201 &self.url,
202 self.read_timeout,
203 &mut self.stream,
204 &mut self.read_buffer,
205 self.frame_config,
206 )
207 .await
208 }
209
210 pub async fn next(&mut self) -> WebSocketResult<Option<Message>> {
211 loop {
212 let frame = match decode_frame(&self.url, &mut self.read_buffer, self.frame_config) {
213 Ok(frame) => frame,
214 Err(error) => return Err(self.best_effort_close_for_error(error).await),
215 };
216
217 if let Some(frame) = frame {
218 let message = match self
219 .decoder
220 .decode_message(&self.url, frame, self.frame_config)
221 {
222 Ok(message) => message,
223 Err(error) => return Err(self.best_effort_close_for_error(error).await),
224 };
225
226 match message {
227 Some(Message::Ping(payload)) => {
228 if !self.close_received {
229 self.write_control(OpCode::Pong, &payload).await?;
230 }
231 return Ok(Some(Message::Ping(payload)));
232 }
233 Some(Message::Close(frame)) => {
234 self.close_received = true;
235 if !self.close_sent {
236 self.send_close_raw(frame.clone()).await?;
237 }
238 return Ok(None);
239 }
240 Some(other) => return Ok(Some(other)),
241 None => {}
242 }
243 } else {
244 self.read_buffer.reserve(READ_CHUNK_SIZE);
245 let n = Self::io_with_timeout(
246 &self.url,
247 self.read_timeout,
248 "read",
249 self.stream.read_buf(&mut self.read_buffer),
250 )
251 .await?;
252 if n == 0 {
253 return if self.close_sent || self.close_received {
254 Ok(None)
255 } else {
256 Err(WebSocketError::connection_closed(&self.url))
257 };
258 }
259 }
260 }
261 }
262
263 pub async fn close(&mut self, frame: Option<CloseFrame>) -> WebSocketResult<()> {
264 if !self.close_sent {
265 self.send_close_raw(frame).await?;
266 }
267 Ok(())
268 }
269
270 async fn write_frame(&mut self, opcode: OpCode, payload: &[u8]) -> WebSocketResult<()> {
271 validate_outbound_payload(&self.url, self.frame_config, opcode, payload)?;
272 encode_frame_into(opcode, payload, &mut self.mask_rng, &mut self.write_buffer);
273 Self::io_with_timeout(
274 &self.url,
275 self.write_timeout,
276 "write",
277 self.stream.write_all(&self.write_buffer),
278 )
279 .await
280 }
281
282 async fn write_prepared_batch<'a>(
283 &mut self,
284 messages: impl IntoIterator<Item = &'a PreparedMessage>,
285 ) -> WebSocketResult<()> {
286 self.write_buffer.clear();
292 for message in messages {
293 let (opcode, payload) = match message {
294 PreparedMessage::Text(bytes) => (OpCode::Text, bytes.as_ref()),
295 PreparedMessage::Binary(bytes) => (OpCode::Binary, bytes.as_ref()),
296 };
297 validate_outbound_payload(&self.url, self.frame_config, opcode, payload)?;
298 encode_frame_append(opcode, payload, &mut self.mask_rng, &mut self.write_buffer);
299 }
300 if self.write_buffer.is_empty() {
301 return Ok(());
302 }
303 Self::io_with_timeout(
304 &self.url,
305 self.write_timeout,
306 "write",
307 self.stream.write_all(&self.write_buffer),
308 )
309 .await
310 }
311
312 async fn write_control(&mut self, opcode: OpCode, payload: &[u8]) -> WebSocketResult<()> {
313 if payload.len() > 125 {
314 return Err(WebSocketError::protocol(
315 &self.url,
316 "control frame payload exceeds 125 bytes",
317 ));
318 }
319 self.write_frame(opcode, payload).await?;
320 Self::io_with_timeout(&self.url, self.write_timeout, "flush", self.stream.flush()).await
321 }
322
323 async fn send_close_raw(&mut self, frame: Option<CloseFrame>) -> WebSocketResult<()> {
324 let payload = match frame {
325 Some(frame) => frame.encode(&self.url)?,
326 None => Vec::new(),
327 };
328 self.write_control(OpCode::Close, &payload).await?;
329 self.close_sent = true;
330 Ok(())
331 }
332
333 async fn best_effort_close_for_error(&mut self, error: WebSocketError) -> WebSocketError {
334 if let Some(code) = error.close_code() {
335 if !self.close_sent {
336 let frame = CloseFrame {
337 code,
338 reason: String::new(),
339 };
340 let _ = self.send_close_raw(Some(frame)).await;
341 }
342 }
343 error
344 }
345
346 async fn io_with_timeout<T, F>(
347 url: &Url,
348 timeout: Option<Duration>,
349 operation: &'static str,
350 future: F,
351 ) -> WebSocketResult<T>
352 where
353 F: Future<Output = std::io::Result<T>>,
354 {
355 let result = match timeout {
356 Some(duration) => {
357 tokio_timeout(duration, future)
358 .await
359 .map_err(|_| WebSocketError::Timeout {
360 url: url.to_string(),
361 operation: format!("{operation} after {:?}", duration),
362 })?
363 }
364 None => future.await,
365 };
366
367 result.map_err(|error| WebSocketError::io(url, error))
368 }
369
370 async fn read_next_frame<S>(
371 url: &Url,
372 read_timeout: Option<Duration>,
373 stream: &mut S,
374 read_buffer: &mut BytesMut,
375 frame_config: FrameConfig,
376 ) -> WebSocketResult<Option<WebSocketFrame>>
377 where
378 S: tokio::io::AsyncRead + Unpin,
379 {
380 loop {
381 if let Some(frame) = decode_frame(url, read_buffer, frame_config)? {
382 return Ok(Some(WebSocketFrame {
383 fin: frame.fin,
384 opcode: frame.opcode.into(),
385 payload: frame.payload,
386 }));
387 }
388 read_buffer.reserve(READ_CHUNK_SIZE);
389 let n = Self::io_with_timeout(url, read_timeout, "read", stream.read_buf(read_buffer))
390 .await?;
391 if n == 0 {
392 return Ok(None);
393 }
394 }
395 }
396}
397
398impl WebSocketReader {
399 pub async fn next_frame(&mut self) -> WebSocketResult<Option<WebSocketFrame>> {
400 WebSocket::read_next_frame(
401 &self.url,
402 self.read_timeout,
403 &mut self.stream,
404 &mut self.read_buffer,
405 self.frame_config,
406 )
407 .await
408 }
409
410 pub async fn next(&mut self) -> WebSocketResult<Option<Message>> {
411 loop {
412 let frame = decode_frame(&self.url, &mut self.read_buffer, self.frame_config)?;
413 if let Some(frame) = frame {
414 let message = self
415 .decoder
416 .decode_message(&self.url, frame, self.frame_config)?;
417 match message {
418 Some(Message::Close(_)) => {
419 self.close_received = true;
420 return Ok(None);
421 }
422 Some(other) => return Ok(Some(other)),
423 None => {}
424 }
425 } else {
426 self.read_buffer.reserve(READ_CHUNK_SIZE);
427 let n = WebSocket::io_with_timeout(
428 &self.url,
429 self.read_timeout,
430 "read",
431 self.stream.read_buf(&mut self.read_buffer),
432 )
433 .await?;
434 if n == 0 {
435 return if self.close_received {
436 Ok(None)
437 } else {
438 Err(WebSocketError::connection_closed(&self.url))
439 };
440 }
441 }
442 }
443 }
444}
445
446impl WebSocketWriter {
447 pub async fn send(&mut self, msg: Message) -> WebSocketResult<()> {
448 if self.close_sent && !matches!(msg, Message::Close(_)) {
449 return Err(WebSocketError::protocol(
450 &self.url,
451 "cannot send data after close frame",
452 ));
453 }
454
455 match msg {
456 Message::Text(text) => self.write_frame(OpCode::Text, text.as_bytes()).await,
457 Message::Binary(bytes) => self.write_frame(OpCode::Binary, &bytes).await,
458 Message::Ping(bytes) => self.write_control(OpCode::Ping, &bytes).await,
459 Message::Pong(bytes) => self.write_control(OpCode::Pong, &bytes).await,
460 Message::Close(frame) => self.close(frame).await,
461 }
462 }
463
464 pub async fn send_text(&mut self, text: impl Into<String>) -> WebSocketResult<()> {
465 self.send(Message::Text(text.into())).await
466 }
467
468 pub async fn send_binary(&mut self, bytes: impl Into<Bytes>) -> WebSocketResult<()> {
469 self.send(Message::Binary(bytes.into())).await
470 }
471
472 pub async fn send_prepared(&mut self, message: &PreparedMessage) -> WebSocketResult<()> {
473 match message {
474 PreparedMessage::Text(bytes) => self.write_frame(OpCode::Text, bytes).await,
475 PreparedMessage::Binary(bytes) => self.write_frame(OpCode::Binary, bytes).await,
476 }
477 }
478
479 pub async fn send_prepared_batch<'a>(
480 &mut self,
481 messages: impl IntoIterator<Item = &'a PreparedMessage>,
482 ) -> WebSocketResult<()> {
483 self.write_prepared_batch(messages).await
484 }
485
486 pub async fn send_ping(&mut self, bytes: impl Into<Bytes>) -> WebSocketResult<()> {
487 self.send(Message::Ping(bytes.into())).await
488 }
489
490 pub async fn send_pong(&mut self, bytes: impl Into<Bytes>) -> WebSocketResult<()> {
491 self.send(Message::Pong(bytes.into())).await
492 }
493
494 pub async fn close(&mut self, frame: Option<CloseFrame>) -> WebSocketResult<()> {
495 if !self.close_sent {
496 self.send_close_raw(frame).await?;
497 }
498 Ok(())
499 }
500
501 async fn write_frame(&mut self, opcode: OpCode, payload: &[u8]) -> WebSocketResult<()> {
502 validate_outbound_payload(&self.url, self.frame_config, opcode, payload)?;
503 encode_frame_into(opcode, payload, &mut self.mask_rng, &mut self.write_buffer);
504 WebSocket::io_with_timeout(
505 &self.url,
506 self.write_timeout,
507 "write",
508 self.stream.write_all(&self.write_buffer),
509 )
510 .await
511 }
512
513 async fn write_prepared_batch<'a>(
514 &mut self,
515 messages: impl IntoIterator<Item = &'a PreparedMessage>,
516 ) -> WebSocketResult<()> {
517 self.write_buffer.clear();
522 for message in messages {
523 let (opcode, payload) = match message {
524 PreparedMessage::Text(bytes) => (OpCode::Text, bytes.as_ref()),
525 PreparedMessage::Binary(bytes) => (OpCode::Binary, bytes.as_ref()),
526 };
527 validate_outbound_payload(&self.url, self.frame_config, opcode, payload)?;
528 encode_frame_append(opcode, payload, &mut self.mask_rng, &mut self.write_buffer);
529 }
530 if self.write_buffer.is_empty() {
531 return Ok(());
532 }
533 WebSocket::io_with_timeout(
534 &self.url,
535 self.write_timeout,
536 "write",
537 self.stream.write_all(&self.write_buffer),
538 )
539 .await
540 }
541
542 async fn write_control(&mut self, opcode: OpCode, payload: &[u8]) -> WebSocketResult<()> {
543 if payload.len() > 125 {
544 return Err(WebSocketError::protocol(
545 &self.url,
546 "control frame payload exceeds 125 bytes",
547 ));
548 }
549 self.write_frame(opcode, payload).await?;
550 WebSocket::io_with_timeout(&self.url, self.write_timeout, "flush", self.stream.flush())
551 .await
552 }
553
554 async fn send_close_raw(&mut self, frame: Option<CloseFrame>) -> WebSocketResult<()> {
555 let payload = match frame {
556 Some(frame) => frame.encode(&self.url)?,
557 None => Vec::new(),
558 };
559 self.write_control(OpCode::Close, &payload).await?;
560 self.close_sent = true;
561 Ok(())
562 }
563}
564
565fn validate_outbound_payload(
566 url: &Url,
567 frame_config: FrameConfig,
568 opcode: OpCode,
569 payload: &[u8],
570) -> WebSocketResult<()> {
571 if payload.len() > frame_config.max_frame_size {
572 return Err(WebSocketError::limit_exceeded(
573 url,
574 format!("frame exceeds {} bytes", frame_config.max_frame_size),
575 ));
576 }
577 if matches!(opcode, OpCode::Text | OpCode::Binary)
578 && payload.len() > frame_config.max_message_size
579 {
580 return Err(WebSocketError::limit_exceeded(
581 url,
582 format!("message exceeds {} bytes", frame_config.max_message_size),
583 ));
584 }
585 Ok(())
586}