1#[cfg(feature = "json")]
4mod json;
5mod message;
6
7use std::{
8 borrow::Cow,
9 fmt,
10 net::{IpAddr, Ipv4Addr, Ipv6Addr},
11 ops::{Deref, DerefMut},
12 pin::Pin,
13 task::{Context, Poll, ready},
14};
15
16use crate::{Error, RequestBuilder, Response, error, proxy::IntoProxy};
17use futures_util::{Sink, SinkExt, Stream, StreamExt};
18use http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode, Version, header, uri::Scheme};
19use hyper2::ext::Protocol;
20use serde::Serialize;
21use tokio_tungstenite::tungstenite::{self, protocol};
22use tungstenite::protocol::WebSocketConfig;
23
24pub use message::{CloseCode, CloseFrame, Message, Utf8Bytes};
25
/// The tokio-tungstenite stream type wrapped by [`WebSocket`], running over this
/// crate's `Upgraded` connection.
pub type WebSocketStream = tokio_tungstenite::WebSocketStream<crate::Upgraded>;
28
/// Builder for a WebSocket handshake request.
///
/// Wraps a [`RequestBuilder`] and carries the WebSocket-specific state consumed by
/// `send`: an optional `Sec-WebSocket-Key` nonce, the subprotocols to offer, and the
/// tungstenite [`WebSocketConfig`] for the resulting connection.
#[derive(Debug)]
pub struct WebSocketRequestBuilder {
    /// The underlying HTTP request builder.
    inner: RequestBuilder,
    /// Nonce for `Sec-WebSocket-Key`; when `None`, `send` generates one (HTTP/1.x only).
    accept_key: Option<Cow<'static, str>>,
    /// Subprotocols offered via `Sec-WebSocket-Protocol`; `None` or empty sends no header.
    protocols: Option<Vec<Cow<'static, str>>>,
    /// Configuration handed to the WebSocket stream once the handshake succeeds.
    config: WebSocketConfig,
}
38
39impl WebSocketRequestBuilder {
40 pub fn new(inner: RequestBuilder) -> Self {
42 Self {
43 inner,
44 accept_key: None,
45 protocols: None,
46 config: WebSocketConfig::default(),
47 }
48 }
49
50 pub fn accept_key<K>(mut self, key: K) -> Self
62 where
63 K: Into<Cow<'static, str>>,
64 {
65 self.accept_key = Some(key.into());
66 self
67 }
68
69 pub fn protocols<P>(mut self, protocols: P) -> Self
91 where
92 P: IntoIterator,
93 P::Item: Into<Cow<'static, str>>,
94 {
95 let protocols = protocols.into_iter().map(Into::into).collect();
96 self.protocols = Some(protocols);
97 self
98 }
99
100 pub fn max_frame_size(mut self, max_frame_size: usize) -> Self {
102 self.config.max_frame_size = Some(max_frame_size);
103 self
104 }
105
106 pub fn read_buffer_size(mut self, read_buffer_size: usize) -> Self {
108 self.config.read_buffer_size = read_buffer_size;
109 self
110 }
111
112 pub fn write_buffer_size(mut self, write_buffer_size: usize) -> Self {
114 self.config.write_buffer_size = write_buffer_size;
115 self
116 }
117
118 pub fn max_write_buffer_size(mut self, max_write_buffer_size: usize) -> Self {
120 self.config.max_write_buffer_size = max_write_buffer_size;
121 self
122 }
123
124 pub fn max_message_size(mut self, max_message_size: usize) -> Self {
126 self.config.max_message_size = Some(max_message_size);
127 self
128 }
129
130 pub fn accept_unmasked_frames(mut self, accept_unmasked_frames: bool) -> Self {
132 self.config.accept_unmasked_frames = accept_unmasked_frames;
133 self
134 }
135
136 pub fn use_http2(mut self) -> Self {
145 self.inner = self.inner.version(Version::HTTP_2);
146 self
147 }
148
149 pub fn header<K, V>(mut self, key: K, value: V) -> Self
151 where
152 HeaderName: TryFrom<K>,
153 <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
154 HeaderValue: TryFrom<V>,
155 <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
156 {
157 self.inner = self.inner.header(key, value);
158 self
159 }
160
161 pub fn header_append<K, V>(mut self, key: K, value: V) -> Self
163 where
164 HeaderName: TryFrom<K>,
165 <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
166 HeaderValue: TryFrom<V>,
167 <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
168 {
169 self.inner = self.inner.header_append(key, value);
170 self
171 }
172
173 pub fn headers(mut self, headers: HeaderMap) -> Self {
177 self.inner = self.inner.headers(headers);
178 self
179 }
180
181 pub fn auth<V>(mut self, value: V) -> Self
183 where
184 HeaderValue: TryFrom<V>,
185 <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
186 {
187 self.inner = self.inner.auth(value);
188 self
189 }
190
191 pub fn basic_auth<U, P>(mut self, username: U, password: Option<P>) -> Self
193 where
194 U: fmt::Display,
195 P: fmt::Display,
196 {
197 self.inner = self.inner.basic_auth(username, password);
198 self
199 }
200
201 pub fn bearer_auth<T>(mut self, token: T) -> Self
203 where
204 T: fmt::Display,
205 {
206 self.inner = self.inner.bearer_auth(token);
207 self
208 }
209
210 pub fn query<T: Serialize + ?Sized>(mut self, query: &T) -> Self {
212 self.inner = self.inner.query(query);
213 self
214 }
215
216 pub fn proxy<U: IntoProxy>(mut self, proxy: U) -> Self {
218 self.inner = self.inner.proxy(proxy);
219 self
220 }
221
222 pub fn local_address<V>(mut self, local_address: V) -> Self
224 where
225 V: Into<Option<IpAddr>>,
226 {
227 self.inner = self.inner.local_address(local_address);
228 self
229 }
230
231 pub fn local_addresses<V4, V6>(mut self, ipv4: V4, ipv6: V6) -> Self
233 where
234 V4: Into<Option<Ipv4Addr>>,
235 V6: Into<Option<Ipv6Addr>>,
236 {
237 self.inner = self.inner.local_addresses(ipv4, ipv6);
238 self
239 }
240
241 #[cfg(any(
243 target_os = "android",
244 target_os = "fuchsia",
245 target_os = "linux",
246 all(
247 feature = "apple-network-device-binding",
248 any(
249 target_os = "ios",
250 target_os = "visionos",
251 target_os = "macos",
252 target_os = "tvos",
253 target_os = "watchos",
254 )
255 )
256 ))]
257 #[cfg_attr(docsrs, doc(cfg(feature = "apple-network-device-binding")))]
258 pub fn interface<I>(mut self, interface: I) -> Self
259 where
260 I: Into<std::borrow::Cow<'static, str>>,
261 {
262 self.inner = self.inner.interface(interface);
263 self
264 }
265
266 pub async fn send(self) -> Result<WebSocketResponse, Error> {
268 let (client, request) = self.inner.build_split();
269 let mut request = request?;
270
271 let url = request.url_mut();
273 let new_scheme = match url.scheme() {
274 "ws" => Scheme::HTTP,
275 "wss" => Scheme::HTTPS,
276 _ => {
277 return Err(error::url_bad_scheme(url.clone()));
278 }
279 };
280
281 url.set_scheme(new_scheme.as_str())
283 .map_err(|_| error::url_bad_scheme(url.clone()))?;
284
285 let version = request.version().unwrap_or(Version::HTTP_11);
288
289 let headers = request.headers_mut();
291 headers.insert(
292 header::SEC_WEBSOCKET_VERSION,
293 HeaderValue::from_static("13"),
294 );
295
296 let accept_key = match version {
298 Version::HTTP_10 | Version::HTTP_11 => {
299 let nonce = self
301 .accept_key
302 .unwrap_or_else(|| Cow::Owned(tungstenite::handshake::client::generate_key()));
303
304 headers.insert(header::UPGRADE, HeaderValue::from_static("websocket"));
305 headers.insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
306 headers.insert(header::SEC_WEBSOCKET_KEY, HeaderValue::from_str(&nonce)?);
307
308 *request.method_mut() = Method::GET;
309 *request.version_mut() = Some(Version::HTTP_11);
310 Some(nonce)
311 }
312 Version::HTTP_2 => {
313 *request.method_mut() = Method::CONNECT;
314 *request.version_mut() = Some(Version::HTTP_2);
315 *request.protocol_mut() = Some(Protocol::from_static("websocket"));
316 None
317 }
318 _ => {
319 return Err(error::upgrade(format!(
320 "unsupported version: {:?}",
321 version
322 )));
323 }
324 };
325
326 if let Some(ref protocols) = self.protocols {
328 if !protocols.is_empty() {
330 let subprotocols = protocols
331 .iter()
332 .map(|s| s.as_ref())
333 .collect::<Vec<&str>>()
334 .join(", ");
335
336 request
337 .headers_mut()
338 .insert(header::SEC_WEBSOCKET_PROTOCOL, subprotocols.parse()?);
339 }
340 }
341
342 client
343 .execute(request)
344 .await
345 .map(|inner| WebSocketResponse {
346 inner,
347 accept_key,
348 protocols: self.protocols,
349 config: self.config,
350 version,
351 })
352 }
353}
354
/// The HTTP response to a WebSocket handshake request.
///
/// Dereferences to the underlying [`Response`] so its status and headers can be
/// inspected; call `into_websocket` to validate the handshake and obtain a
/// [`WebSocket`].
#[derive(Debug)]
pub struct WebSocketResponse {
    /// The raw handshake response.
    inner: Response,
    /// Nonce sent as `Sec-WebSocket-Key` on HTTP/1.x handshakes; `None` for HTTP/2.
    accept_key: Option<Cow<'static, str>>,
    /// Subprotocols offered in the request, used to validate the server's selection.
    protocols: Option<Vec<Cow<'static, str>>>,
    /// Configuration for the WebSocket stream created on upgrade.
    config: WebSocketConfig,
    /// HTTP version the handshake was requested with (before HTTP/1.0 was sent as
    /// HTTP/1.1); selects which response checks apply.
    version: Version,
}
367
368impl Deref for WebSocketResponse {
369 type Target = Response;
370
371 fn deref(&self) -> &Self::Target {
372 &self.inner
373 }
374}
375
376impl DerefMut for WebSocketResponse {
377 fn deref_mut(&mut self) -> &mut Self::Target {
378 &mut self.inner
379 }
380}
381
382impl WebSocketResponse {
383 pub async fn into_websocket(self) -> Result<WebSocket, Error> {
386 let (inner, protocol) = {
387 let status = self.inner.status();
388 let headers = self.inner.headers();
389
390 if !matches!(
391 self.inner.version(),
392 Version::HTTP_10 | Version::HTTP_11 | Version::HTTP_2
393 ) {
394 return Err(error::upgrade(format!(
395 "unexpected version: {:?}",
396 self.inner.version()
397 )));
398 }
399
400 match self.version {
401 Version::HTTP_10 | Version::HTTP_11 => {
402 if status != StatusCode::SWITCHING_PROTOCOLS {
403 let body = self.inner.text().await?;
404 return Err(error::upgrade(format!("unexpected status code: {}", body)));
405 }
406
407 if !header_contains(self.inner.headers(), header::CONNECTION, "upgrade") {
408 return Err(error::upgrade("missing connection header"));
409 }
410
411 if !header_eq(self.inner.headers(), header::UPGRADE, "websocket") {
412 return Err(error::upgrade("invalid upgrade header"));
413 }
414
415 match self
416 .accept_key
417 .zip(headers.get(header::SEC_WEBSOCKET_ACCEPT))
418 {
419 Some((nonce, header)) => {
420 if !header.to_str().is_ok_and(|s| {
421 s == tungstenite::handshake::derive_accept_key(nonce.as_bytes())
422 }) {
423 return Err(error::upgrade(format!(
424 "invalid accept key: {:?}",
425 header
426 )));
427 }
428 }
429 None => {
430 return Err(error::upgrade("missing accept key"));
431 }
432 }
433 }
434 Version::HTTP_2 => {
435 if status != StatusCode::OK {
436 return Err(error::upgrade(format!(
437 "unexpected status code: {}",
438 status
439 )));
440 }
441 }
442 _ => {
443 return Err(error::upgrade(format!(
444 "unsupported version: {:?}",
445 self.version
446 )));
447 }
448 }
449
450 let protocol = headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
451
452 match (
453 self.protocols.as_ref().is_none_or(|p| p.is_empty()),
454 &protocol,
455 ) {
456 (true, None) => {
457 }
460 (false, None) => {
461 return Err(error::upgrade("missing protocol"));
463 }
464 (false, Some(protocol)) => {
465 if let Some((protocols, protocol)) = self.protocols.zip(protocol.to_str().ok())
466 {
467 if !protocols.contains(&Cow::Borrowed(protocol)) {
468 return Err(error::upgrade(format!("invalid protocol: {}", protocol)));
470 }
471 } else {
472 return Err(error::upgrade("invalid protocol"));
474 }
475 }
476 (true, Some(_)) => {
477 return Err(error::upgrade("invalid protocol"));
479 }
480 }
481
482 let upgraded = self.inner.upgrade().await?;
483 let inner = WebSocketStream::from_raw_socket(
484 upgraded,
485 protocol::Role::Client,
486 Some(self.config),
487 )
488 .await;
489
490 (inner, protocol)
491 };
492
493 Ok(WebSocket { inner, protocol })
494 }
495}
496
497fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
499 if let Some(header) = headers.get(&key) {
500 header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
501 } else {
502 false
503 }
504}
505
506fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
508 let header = if let Some(header) = headers.get(&key) {
509 header
510 } else {
511 return false;
512 };
513
514 if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
515 header.to_ascii_lowercase().contains(value)
516 } else {
517 false
518 }
519}
520
/// An established client-side WebSocket connection.
///
/// Incoming messages are read through the [`Stream`] implementation (or `recv`), and
/// outgoing ones are written through the [`Sink`] implementation (or `send`).
#[derive(Debug)]
pub struct WebSocket {
    /// The underlying tokio-tungstenite stream.
    inner: WebSocketStream,
    /// `Sec-WebSocket-Protocol` value returned by the server during the handshake, if any.
    protocol: Option<HeaderValue>,
}
527
528impl WebSocket {
529 pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
533 self.next().await
534 }
535
536 pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
538 self.inner
539 .send(msg.into_tungstenite())
540 .await
541 .map_err(Into::into)
542 }
543
544 pub fn protocol(&self) -> Option<&HeaderValue> {
546 self.protocol.as_ref()
547 }
548
549 pub async fn close(self, code: CloseCode, reason: Option<Utf8Bytes>) -> Result<(), Error> {
551 let mut inner = self.inner;
552 inner
553 .close(Some(tungstenite::protocol::CloseFrame {
554 code: code.0.into(),
555 reason: reason
556 .unwrap_or(Utf8Bytes::from_static("Goodbye"))
557 .into_tungstenite(),
558 }))
559 .await
560 .map_err(Into::into)
561 }
562}
563
564impl Stream for WebSocket {
565 type Item = Result<Message, Error>;
566
567 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
568 loop {
569 match ready!(self.inner.poll_next_unpin(cx)) {
570 Some(Ok(msg)) => {
571 if let Some(msg) = Message::from_tungstenite(msg) {
572 return Poll::Ready(Some(Ok(msg)));
573 }
574 }
575 Some(Err(err)) => return Poll::Ready(Some(Err(error::body(err)))),
576 None => return Poll::Ready(None),
577 }
578 }
579 }
580}
581
582impl Sink<Message> for WebSocket {
583 type Error = Error;
584
585 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
586 Pin::new(&mut self.inner).poll_ready(cx).map_err(Into::into)
587 }
588
589 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
590 Pin::new(&mut self.inner)
591 .start_send(item.into_tungstenite())
592 .map_err(Into::into)
593 }
594
595 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
596 Pin::new(&mut self.inner).poll_flush(cx).map_err(Into::into)
597 }
598
599 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
600 Pin::new(&mut self.inner).poll_close(cx).map_err(Into::into)
601 }
602}