awc/
ws.rs

1//! Websockets client
2//!
3//! Type definitions required to use [`awc::Client`](super::Client) as a WebSocket client.
4//!
5//! # Examples
6//!
7//! ```no_run
8//! use awc::{Client, ws};
9//! use futures_util::{sink::SinkExt as _, stream::StreamExt as _};
10//!
11//! #[actix_rt::main]
12//! async fn main() {
13//!     let (_resp, mut connection) = Client::new()
14//!         .ws("ws://echo.websocket.org")
15//!         .connect()
16//!         .await
17//!         .unwrap();
18//!
19//!     connection
20//!         .send(ws::Message::Text("Echo".into()))
21//!         .await
22//!         .unwrap();
23//!     let response = connection.next().await.unwrap().unwrap();
24//!
25//!     assert_eq!(response, ws::Frame::Text("Echo".as_bytes().into()));
26//! }
27//! ```
28
29use std::{convert::TryFrom, fmt, net::SocketAddr, str};
30
31use actix_codec::Framed;
32use actix_http::{ws, Payload, RequestHead};
33use actix_rt::time::timeout;
34use actix_service::Service as _;
35
36pub use actix_http::ws::{CloseCode, CloseReason, Codec, Frame, Message};
37
38use crate::{
39    client::ClientConfig,
40    connect::{BoxedSocket, ConnectRequest},
41    error::{HttpError, InvalidUrl, SendRequestError, WsClientError},
42    http::{
43        header::{self, HeaderName, HeaderValue, TryIntoHeaderValue, AUTHORIZATION},
44        ConnectionType, Method, StatusCode, Uri, Version,
45    },
46    ClientResponse,
47};
48
49#[cfg(feature = "cookies")]
50use crate::cookie::{Cookie, CookieJar};
51
/// WebSocket handshake request builder.
///
/// Values are collected through the builder methods and consumed by `connect`, which sends
/// the upgrade request and yields the framed connection.
pub struct WebsocketsRequest {
    /// Head of the handshake request; created as `GET` over HTTP/1.1.
    pub(crate) head: RequestHead,
    /// Most recent error recorded by a builder method; returned by `connect`.
    err: Option<HttpError>,
    /// Value written to the `Origin` header on connect, if set.
    origin: Option<HeaderValue>,
    /// Comma-separated sub-protocol list written to `Sec-WebSocket-Protocol` on connect.
    protocols: Option<String>,
    /// Explicit socket address handed to the connector together with the request head.
    addr: Option<SocketAddr>,
    /// Maximum frame size passed to the codec (65,536 unless overridden).
    max_size: usize,
    /// When `true`, the codec is not switched into client mode.
    server_mode: bool,
    /// Client configuration supplying the connector and the optional timeout.
    config: ClientConfig,

    /// Cookies to send with the handshake; only present with the `cookies` feature.
    #[cfg(feature = "cookies")]
    cookies: Option<CookieJar>,
}
66
67impl WebsocketsRequest {
68    /// Create new WebSocket connection
69    pub(crate) fn new<U>(uri: U, config: ClientConfig) -> Self
70    where
71        Uri: TryFrom<U>,
72        <Uri as TryFrom<U>>::Error: Into<HttpError>,
73    {
74        let mut err = None;
75
76        #[allow(clippy::field_reassign_with_default)]
77        let mut head = {
78            let mut head = RequestHead::default();
79            head.method = Method::GET;
80            head.version = Version::HTTP_11;
81            head
82        };
83
84        match Uri::try_from(uri) {
85            Ok(uri) => head.uri = uri,
86            Err(e) => err = Some(e.into()),
87        }
88
89        WebsocketsRequest {
90            head,
91            err,
92            config,
93            addr: None,
94            origin: None,
95            protocols: None,
96            max_size: 65_536,
97            server_mode: false,
98            #[cfg(feature = "cookies")]
99            cookies: None,
100        }
101    }
102
103    /// Set socket address of the server.
104    ///
105    /// This address is used for connection. If address is not
106    /// provided url's host name get resolved.
107    pub fn address(mut self, addr: SocketAddr) -> Self {
108        self.addr = Some(addr);
109        self
110    }
111
112    /// Set supported WebSocket protocols
113    pub fn protocols<U, V>(mut self, protos: U) -> Self
114    where
115        U: IntoIterator<Item = V>,
116        V: AsRef<str>,
117    {
118        let mut protos = protos
119            .into_iter()
120            .fold(String::new(), |acc, s| acc + s.as_ref() + ",");
121        protos.pop();
122        self.protocols = Some(protos);
123        self
124    }
125
126    /// Set a cookie
127    #[cfg(feature = "cookies")]
128    pub fn cookie(mut self, cookie: Cookie<'_>) -> Self {
129        if self.cookies.is_none() {
130            let mut jar = CookieJar::new();
131            jar.add(cookie.into_owned());
132            self.cookies = Some(jar)
133        } else {
134            self.cookies.as_mut().unwrap().add(cookie.into_owned());
135        }
136        self
137    }
138
139    /// Set request Origin
140    pub fn origin<V, E>(mut self, origin: V) -> Self
141    where
142        HeaderValue: TryFrom<V, Error = E>,
143        HttpError: From<E>,
144    {
145        match HeaderValue::try_from(origin) {
146            Ok(value) => self.origin = Some(value),
147            Err(e) => self.err = Some(e.into()),
148        }
149        self
150    }
151
152    /// Set max frame size
153    ///
154    /// By default max size is set to 64kB
155    pub fn max_frame_size(mut self, size: usize) -> Self {
156        self.max_size = size;
157        self
158    }
159
160    /// Disable payload masking. By default ws client masks frame payload.
161    pub fn server_mode(mut self) -> Self {
162        self.server_mode = true;
163        self
164    }
165
166    /// Append a header.
167    ///
168    /// Header gets appended to existing header.
169    /// To override header use `set_header()` method.
170    pub fn header<K, V>(mut self, key: K, value: V) -> Self
171    where
172        HeaderName: TryFrom<K>,
173        <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
174        V: TryIntoHeaderValue,
175    {
176        match HeaderName::try_from(key) {
177            Ok(key) => match value.try_into_value() {
178                Ok(value) => {
179                    self.head.headers.append(key, value);
180                }
181                Err(e) => self.err = Some(e.into()),
182            },
183            Err(e) => self.err = Some(e.into()),
184        }
185        self
186    }
187
188    /// Insert a header, replaces existing header.
189    pub fn set_header<K, V>(mut self, key: K, value: V) -> Self
190    where
191        HeaderName: TryFrom<K>,
192        <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
193        V: TryIntoHeaderValue,
194    {
195        match HeaderName::try_from(key) {
196            Ok(key) => match value.try_into_value() {
197                Ok(value) => {
198                    self.head.headers.insert(key, value);
199                }
200                Err(e) => self.err = Some(e.into()),
201            },
202            Err(e) => self.err = Some(e.into()),
203        }
204        self
205    }
206
207    /// Insert a header only if it is not yet set.
208    pub fn set_header_if_none<K, V>(mut self, key: K, value: V) -> Self
209    where
210        HeaderName: TryFrom<K>,
211        <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
212        V: TryIntoHeaderValue,
213    {
214        match HeaderName::try_from(key) {
215            Ok(key) => {
216                if !self.head.headers.contains_key(&key) {
217                    match value.try_into_value() {
218                        Ok(value) => {
219                            self.head.headers.insert(key, value);
220                        }
221                        Err(e) => self.err = Some(e.into()),
222                    }
223                }
224            }
225            Err(e) => self.err = Some(e.into()),
226        }
227        self
228    }
229
230    /// Set HTTP basic authorization header
231    pub fn basic_auth<U>(self, username: U, password: Option<&str>) -> Self
232    where
233        U: fmt::Display,
234    {
235        let auth = match password {
236            Some(password) => format!("{}:{}", username, password),
237            None => format!("{}:", username),
238        };
239        self.header(AUTHORIZATION, format!("Basic {}", base64::encode(&auth)))
240    }
241
242    /// Set HTTP bearer authentication header
243    pub fn bearer_auth<T>(self, token: T) -> Self
244    where
245        T: fmt::Display,
246    {
247        self.header(AUTHORIZATION, format!("Bearer {}", token))
248    }
249
250    /// Complete request construction and connect to a WebSocket server.
251    pub async fn connect(
252        mut self,
253    ) -> Result<(ClientResponse, Framed<BoxedSocket, Codec>), WsClientError> {
254        if let Some(e) = self.err.take() {
255            return Err(e.into());
256        }
257
258        // validate uri
259        let uri = &self.head.uri;
260        if uri.host().is_none() {
261            return Err(InvalidUrl::MissingHost.into());
262        } else if uri.scheme().is_none() {
263            return Err(InvalidUrl::MissingScheme.into());
264        } else if let Some(scheme) = uri.scheme() {
265            match scheme.as_str() {
266                "http" | "ws" | "https" | "wss" => {}
267                _ => return Err(InvalidUrl::UnknownScheme.into()),
268            }
269        } else {
270            return Err(InvalidUrl::UnknownScheme.into());
271        }
272
273        if !self.head.headers.contains_key(header::HOST) {
274            self.head.headers.insert(
275                header::HOST,
276                HeaderValue::from_str(uri.host().unwrap()).unwrap(),
277            );
278        }
279
280        // set cookies
281        #[cfg(feature = "cookies")]
282        if let Some(ref mut jar) = self.cookies {
283            let cookie: String = jar
284                .delta()
285                // ensure only name=value is written to cookie header
286                .map(|c| c.stripped().encoded().to_string())
287                .collect::<Vec<_>>()
288                .join("; ");
289
290            if !cookie.is_empty() {
291                self.head
292                    .headers
293                    .insert(header::COOKIE, HeaderValue::from_str(&cookie).unwrap());
294            }
295        }
296
297        // origin
298        if let Some(origin) = self.origin.take() {
299            self.head.headers.insert(header::ORIGIN, origin);
300        }
301
302        self.head.set_connection_type(ConnectionType::Upgrade);
303
304        #[allow(clippy::declare_interior_mutable_const)]
305        const HV_WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
306        self.head.headers.insert(header::UPGRADE, HV_WEBSOCKET);
307
308        #[allow(clippy::declare_interior_mutable_const)]
309        const HV_THIRTEEN: HeaderValue = HeaderValue::from_static("13");
310        self.head
311            .headers
312            .insert(header::SEC_WEBSOCKET_VERSION, HV_THIRTEEN);
313
314        if let Some(protocols) = self.protocols.take() {
315            self.head.headers.insert(
316                header::SEC_WEBSOCKET_PROTOCOL,
317                HeaderValue::try_from(protocols.as_str()).unwrap(),
318            );
319        }
320
321        // Generate a random key for the `Sec-WebSocket-Key` header which is a base64-encoded
322        // (see RFC 4648 §4) value that, when decoded, is 16 bytes in length (RFC 6455 §1.3).
323        let sec_key: [u8; 16] = rand::random();
324        let key = base64::encode(&sec_key);
325
326        self.head.headers.insert(
327            header::SEC_WEBSOCKET_KEY,
328            HeaderValue::try_from(key.as_str()).unwrap(),
329        );
330
331        let head = self.head;
332        let max_size = self.max_size;
333        let server_mode = self.server_mode;
334
335        let req = ConnectRequest::Tunnel(head, self.addr);
336
337        let fut = self.config.connector.call(req);
338
339        // set request timeout
340        let res = if let Some(to) = self.config.timeout {
341            timeout(to, fut)
342                .await
343                .map_err(|_| SendRequestError::Timeout)??
344        } else {
345            fut.await?
346        };
347
348        let (head, framed) = res.into_tunnel_response();
349
350        // verify response
351        if head.status != StatusCode::SWITCHING_PROTOCOLS {
352            return Err(WsClientError::InvalidResponseStatus(head.status));
353        }
354
355        // check for "UPGRADE" to WebSocket header
356        let has_hdr = if let Some(hdr) = head.headers.get(&header::UPGRADE) {
357            if let Ok(s) = hdr.to_str() {
358                s.to_ascii_lowercase().contains("websocket")
359            } else {
360                false
361            }
362        } else {
363            false
364        };
365        if !has_hdr {
366            log::trace!("Invalid upgrade header");
367            return Err(WsClientError::InvalidUpgradeHeader);
368        }
369
370        // Check for "CONNECTION" header
371        if let Some(conn) = head.headers.get(&header::CONNECTION) {
372            if let Ok(s) = conn.to_str() {
373                if !s.to_ascii_lowercase().contains("upgrade") {
374                    log::trace!("Invalid connection header: {}", s);
375                    return Err(WsClientError::InvalidConnectionHeader(conn.clone()));
376                }
377            } else {
378                log::trace!("Invalid connection header: {:?}", conn);
379                return Err(WsClientError::InvalidConnectionHeader(conn.clone()));
380            }
381        } else {
382            log::trace!("Missing connection header");
383            return Err(WsClientError::MissingConnectionHeader);
384        }
385
386        if let Some(hdr_key) = head.headers.get(&header::SEC_WEBSOCKET_ACCEPT) {
387            let encoded = ws::hash_key(key.as_ref());
388
389            if hdr_key.as_bytes() != encoded {
390                log::trace!(
391                    "Invalid challenge response: expected: {:?} received: {:?}",
392                    &encoded,
393                    key
394                );
395
396                return Err(WsClientError::InvalidChallengeResponse(
397                    encoded,
398                    hdr_key.clone(),
399                ));
400            }
401        } else {
402            log::trace!("Missing SEC-WEBSOCKET-ACCEPT header");
403            return Err(WsClientError::MissingWebSocketAcceptHeader);
404        };
405
406        // response and ws framed
407        Ok((
408            ClientResponse::new(head, Payload::None),
409            framed.into_map_codec(|_| {
410                if server_mode {
411                    ws::Codec::new().max_size(max_size)
412                } else {
413                    ws::Codec::new().max_size(max_size).client_mode()
414                }
415            }),
416        ))
417    }
418}
419
420impl fmt::Debug for WebsocketsRequest {
421    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
422        writeln!(
423            f,
424            "\nWebsocketsRequest {}:{}",
425            self.head.method, self.head.uri
426        )?;
427        writeln!(f, "  headers:")?;
428        for (key, val) in self.head.headers.iter() {
429            writeln!(f, "    {:?}: {:?}", key, val)?;
430        }
431        Ok(())
432    }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Client;

    /// The `Debug` output names the type and lists the configured headers.
    #[actix_rt::test]
    async fn test_debug() {
        let request = Client::new().ws("/").header("x-test", "111");
        let repr = format!("{:?}", request);
        assert!(repr.contains("WebsocketsRequest"));
        assert!(repr.contains("x-test"));
    }

    /// `set_header` replaces a default header configured on the client.
    #[actix_rt::test]
    async fn test_header_override() {
        let req = Client::builder()
            .add_default_header((header::CONTENT_TYPE, "111"))
            .finish()
            .ws("/")
            .set_header(header::CONTENT_TYPE, "222");

        assert_eq!(
            req.head
                .headers
                .get(header::CONTENT_TYPE)
                .unwrap()
                .to_str()
                .unwrap(),
            "222"
        );
    }

    /// `basic_auth` encodes the credentials with and without a password.
    #[actix_rt::test]
    async fn basic_auth() {
        let req = Client::new()
            .ws("/")
            .basic_auth("username", Some("password"));
        assert_eq!(
            req.head
                .headers
                .get(header::AUTHORIZATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "Basic dXNlcm5hbWU6cGFzc3dvcmQ="
        );

        let req = Client::new().ws("/").basic_auth("username", None);
        assert_eq!(
            req.head
                .headers
                .get(header::AUTHORIZATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "Basic dXNlcm5hbWU6"
        );
    }

    /// `bearer_auth` writes the token verbatim after the `Bearer ` prefix.
    #[actix_rt::test]
    async fn bearer_auth() {
        let req = Client::new().ws("/").bearer_auth("someS3cr3tAutht0k3n");
        assert_eq!(
            req.head
                .headers
                .get(header::AUTHORIZATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "Bearer someS3cr3tAutht0k3n"
        );

        // The URI has no host, so `connect` fails during validation. The future must be
        // awaited; merely creating it runs nothing.
        assert!(req.connect().await.is_err());
    }

    /// Builder methods store their values and `connect` rejects invalid URIs.
    #[actix_rt::test]
    async fn basics() {
        let req = Client::new()
            .ws("http://localhost/")
            .origin("test-origin")
            .max_frame_size(100)
            .server_mode()
            .protocols(&["v1", "v2"])
            .set_header_if_none(header::CONTENT_TYPE, "json")
            .set_header_if_none(header::CONTENT_TYPE, "text");

        // `cookie()` and `Cookie` only exist with the `cookies` feature; gate the call so
        // the test module still compiles without it.
        #[cfg(feature = "cookies")]
        let req = req.cookie(Cookie::build("cookie1", "value1").finish());

        assert_eq!(
            req.origin.as_ref().unwrap().to_str().unwrap(),
            "test-origin"
        );
        assert_eq!(req.max_size, 100);
        assert!(req.server_mode);
        assert_eq!(req.protocols, Some("v1,v2".to_string()));
        assert_eq!(
            req.head.headers.get(header::CONTENT_TYPE).unwrap(),
            header::HeaderValue::from_static("json")
        );

        let _ = req.connect().await;

        assert!(Client::new().ws("/").connect().await.is_err());
        assert!(Client::new().ws("http:///test").connect().await.is_err());
        assert!(Client::new().ws("hmm://test.com/").connect().await.is_err());
    }
}