1use std::{convert::TryFrom, fmt, net::SocketAddr, str};
30
31use actix_codec::Framed;
32use actix_http::{ws, Payload, RequestHead};
33use actix_rt::time::timeout;
34use actix_service::Service as _;
35
36pub use actix_http::ws::{CloseCode, CloseReason, Codec, Frame, Message};
37
38use crate::{
39 client::ClientConfig,
40 connect::{BoxedSocket, ConnectRequest},
41 error::{HttpError, InvalidUrl, SendRequestError, WsClientError},
42 http::{
43 header::{self, HeaderName, HeaderValue, TryIntoHeaderValue, AUTHORIZATION},
44 ConnectionType, Method, StatusCode, Uri, Version,
45 },
46 ClientResponse,
47};
48
49#[cfg(feature = "cookies")]
50use crate::cookie::{Cookie, CookieJar};
51
52pub struct WebsocketsRequest {
54 pub(crate) head: RequestHead,
55 err: Option<HttpError>,
56 origin: Option<HeaderValue>,
57 protocols: Option<String>,
58 addr: Option<SocketAddr>,
59 max_size: usize,
60 server_mode: bool,
61 config: ClientConfig,
62
63 #[cfg(feature = "cookies")]
64 cookies: Option<CookieJar>,
65}
66
67impl WebsocketsRequest {
68 pub(crate) fn new<U>(uri: U, config: ClientConfig) -> Self
70 where
71 Uri: TryFrom<U>,
72 <Uri as TryFrom<U>>::Error: Into<HttpError>,
73 {
74 let mut err = None;
75
76 #[allow(clippy::field_reassign_with_default)]
77 let mut head = {
78 let mut head = RequestHead::default();
79 head.method = Method::GET;
80 head.version = Version::HTTP_11;
81 head
82 };
83
84 match Uri::try_from(uri) {
85 Ok(uri) => head.uri = uri,
86 Err(e) => err = Some(e.into()),
87 }
88
89 WebsocketsRequest {
90 head,
91 err,
92 config,
93 addr: None,
94 origin: None,
95 protocols: None,
96 max_size: 65_536,
97 server_mode: false,
98 #[cfg(feature = "cookies")]
99 cookies: None,
100 }
101 }
102
103 pub fn address(mut self, addr: SocketAddr) -> Self {
108 self.addr = Some(addr);
109 self
110 }
111
112 pub fn protocols<U, V>(mut self, protos: U) -> Self
114 where
115 U: IntoIterator<Item = V>,
116 V: AsRef<str>,
117 {
118 let mut protos = protos
119 .into_iter()
120 .fold(String::new(), |acc, s| acc + s.as_ref() + ",");
121 protos.pop();
122 self.protocols = Some(protos);
123 self
124 }
125
126 #[cfg(feature = "cookies")]
128 pub fn cookie(mut self, cookie: Cookie<'_>) -> Self {
129 if self.cookies.is_none() {
130 let mut jar = CookieJar::new();
131 jar.add(cookie.into_owned());
132 self.cookies = Some(jar)
133 } else {
134 self.cookies.as_mut().unwrap().add(cookie.into_owned());
135 }
136 self
137 }
138
139 pub fn origin<V, E>(mut self, origin: V) -> Self
141 where
142 HeaderValue: TryFrom<V, Error = E>,
143 HttpError: From<E>,
144 {
145 match HeaderValue::try_from(origin) {
146 Ok(value) => self.origin = Some(value),
147 Err(e) => self.err = Some(e.into()),
148 }
149 self
150 }
151
152 pub fn max_frame_size(mut self, size: usize) -> Self {
156 self.max_size = size;
157 self
158 }
159
160 pub fn server_mode(mut self) -> Self {
162 self.server_mode = true;
163 self
164 }
165
166 pub fn header<K, V>(mut self, key: K, value: V) -> Self
171 where
172 HeaderName: TryFrom<K>,
173 <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
174 V: TryIntoHeaderValue,
175 {
176 match HeaderName::try_from(key) {
177 Ok(key) => match value.try_into_value() {
178 Ok(value) => {
179 self.head.headers.append(key, value);
180 }
181 Err(e) => self.err = Some(e.into()),
182 },
183 Err(e) => self.err = Some(e.into()),
184 }
185 self
186 }
187
188 pub fn set_header<K, V>(mut self, key: K, value: V) -> Self
190 where
191 HeaderName: TryFrom<K>,
192 <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
193 V: TryIntoHeaderValue,
194 {
195 match HeaderName::try_from(key) {
196 Ok(key) => match value.try_into_value() {
197 Ok(value) => {
198 self.head.headers.insert(key, value);
199 }
200 Err(e) => self.err = Some(e.into()),
201 },
202 Err(e) => self.err = Some(e.into()),
203 }
204 self
205 }
206
207 pub fn set_header_if_none<K, V>(mut self, key: K, value: V) -> Self
209 where
210 HeaderName: TryFrom<K>,
211 <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
212 V: TryIntoHeaderValue,
213 {
214 match HeaderName::try_from(key) {
215 Ok(key) => {
216 if !self.head.headers.contains_key(&key) {
217 match value.try_into_value() {
218 Ok(value) => {
219 self.head.headers.insert(key, value);
220 }
221 Err(e) => self.err = Some(e.into()),
222 }
223 }
224 }
225 Err(e) => self.err = Some(e.into()),
226 }
227 self
228 }
229
230 pub fn basic_auth<U>(self, username: U, password: Option<&str>) -> Self
232 where
233 U: fmt::Display,
234 {
235 let auth = match password {
236 Some(password) => format!("{}:{}", username, password),
237 None => format!("{}:", username),
238 };
239 self.header(AUTHORIZATION, format!("Basic {}", base64::encode(&auth)))
240 }
241
242 pub fn bearer_auth<T>(self, token: T) -> Self
244 where
245 T: fmt::Display,
246 {
247 self.header(AUTHORIZATION, format!("Bearer {}", token))
248 }
249
250 pub async fn connect(
252 mut self,
253 ) -> Result<(ClientResponse, Framed<BoxedSocket, Codec>), WsClientError> {
254 if let Some(e) = self.err.take() {
255 return Err(e.into());
256 }
257
258 let uri = &self.head.uri;
260 if uri.host().is_none() {
261 return Err(InvalidUrl::MissingHost.into());
262 } else if uri.scheme().is_none() {
263 return Err(InvalidUrl::MissingScheme.into());
264 } else if let Some(scheme) = uri.scheme() {
265 match scheme.as_str() {
266 "http" | "ws" | "https" | "wss" => {}
267 _ => return Err(InvalidUrl::UnknownScheme.into()),
268 }
269 } else {
270 return Err(InvalidUrl::UnknownScheme.into());
271 }
272
273 if !self.head.headers.contains_key(header::HOST) {
274 self.head.headers.insert(
275 header::HOST,
276 HeaderValue::from_str(uri.host().unwrap()).unwrap(),
277 );
278 }
279
280 #[cfg(feature = "cookies")]
282 if let Some(ref mut jar) = self.cookies {
283 let cookie: String = jar
284 .delta()
285 .map(|c| c.stripped().encoded().to_string())
287 .collect::<Vec<_>>()
288 .join("; ");
289
290 if !cookie.is_empty() {
291 self.head
292 .headers
293 .insert(header::COOKIE, HeaderValue::from_str(&cookie).unwrap());
294 }
295 }
296
297 if let Some(origin) = self.origin.take() {
299 self.head.headers.insert(header::ORIGIN, origin);
300 }
301
302 self.head.set_connection_type(ConnectionType::Upgrade);
303
304 #[allow(clippy::declare_interior_mutable_const)]
305 const HV_WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
306 self.head.headers.insert(header::UPGRADE, HV_WEBSOCKET);
307
308 #[allow(clippy::declare_interior_mutable_const)]
309 const HV_THIRTEEN: HeaderValue = HeaderValue::from_static("13");
310 self.head
311 .headers
312 .insert(header::SEC_WEBSOCKET_VERSION, HV_THIRTEEN);
313
314 if let Some(protocols) = self.protocols.take() {
315 self.head.headers.insert(
316 header::SEC_WEBSOCKET_PROTOCOL,
317 HeaderValue::try_from(protocols.as_str()).unwrap(),
318 );
319 }
320
321 let sec_key: [u8; 16] = rand::random();
324 let key = base64::encode(&sec_key);
325
326 self.head.headers.insert(
327 header::SEC_WEBSOCKET_KEY,
328 HeaderValue::try_from(key.as_str()).unwrap(),
329 );
330
331 let head = self.head;
332 let max_size = self.max_size;
333 let server_mode = self.server_mode;
334
335 let req = ConnectRequest::Tunnel(head, self.addr);
336
337 let fut = self.config.connector.call(req);
338
339 let res = if let Some(to) = self.config.timeout {
341 timeout(to, fut)
342 .await
343 .map_err(|_| SendRequestError::Timeout)??
344 } else {
345 fut.await?
346 };
347
348 let (head, framed) = res.into_tunnel_response();
349
350 if head.status != StatusCode::SWITCHING_PROTOCOLS {
352 return Err(WsClientError::InvalidResponseStatus(head.status));
353 }
354
355 let has_hdr = if let Some(hdr) = head.headers.get(&header::UPGRADE) {
357 if let Ok(s) = hdr.to_str() {
358 s.to_ascii_lowercase().contains("websocket")
359 } else {
360 false
361 }
362 } else {
363 false
364 };
365 if !has_hdr {
366 log::trace!("Invalid upgrade header");
367 return Err(WsClientError::InvalidUpgradeHeader);
368 }
369
370 if let Some(conn) = head.headers.get(&header::CONNECTION) {
372 if let Ok(s) = conn.to_str() {
373 if !s.to_ascii_lowercase().contains("upgrade") {
374 log::trace!("Invalid connection header: {}", s);
375 return Err(WsClientError::InvalidConnectionHeader(conn.clone()));
376 }
377 } else {
378 log::trace!("Invalid connection header: {:?}", conn);
379 return Err(WsClientError::InvalidConnectionHeader(conn.clone()));
380 }
381 } else {
382 log::trace!("Missing connection header");
383 return Err(WsClientError::MissingConnectionHeader);
384 }
385
386 if let Some(hdr_key) = head.headers.get(&header::SEC_WEBSOCKET_ACCEPT) {
387 let encoded = ws::hash_key(key.as_ref());
388
389 if hdr_key.as_bytes() != encoded {
390 log::trace!(
391 "Invalid challenge response: expected: {:?} received: {:?}",
392 &encoded,
393 key
394 );
395
396 return Err(WsClientError::InvalidChallengeResponse(
397 encoded,
398 hdr_key.clone(),
399 ));
400 }
401 } else {
402 log::trace!("Missing SEC-WEBSOCKET-ACCEPT header");
403 return Err(WsClientError::MissingWebSocketAcceptHeader);
404 };
405
406 Ok((
408 ClientResponse::new(head, Payload::None),
409 framed.into_map_codec(|_| {
410 if server_mode {
411 ws::Codec::new().max_size(max_size)
412 } else {
413 ws::Codec::new().max_size(max_size).client_mode()
414 }
415 }),
416 ))
417 }
418}
419
420impl fmt::Debug for WebsocketsRequest {
421 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
422 writeln!(
423 f,
424 "\nWebsocketsRequest {}:{}",
425 self.head.method, self.head.uri
426 )?;
427 writeln!(f, " headers:")?;
428 for (key, val) in self.head.headers.iter() {
429 writeln!(f, " {:?}: {:?}", key, val)?;
430 }
431 Ok(())
432 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Client;

    /// The `Debug` output names the type and lists request headers.
    #[actix_rt::test]
    async fn test_debug() {
        let request = Client::new().ws("/").header("x-test", "111");
        let repr = format!("{:?}", request);
        assert!(repr.contains("WebsocketsRequest"));
        assert!(repr.contains("x-test"));
    }

    /// `set_header` replaces a client default header of the same name.
    #[actix_rt::test]
    async fn test_header_override() {
        let req = Client::builder()
            .add_default_header((header::CONTENT_TYPE, "111"))
            .finish()
            .ws("/")
            .set_header(header::CONTENT_TYPE, "222");

        assert_eq!(
            req.head
                .headers
                .get(header::CONTENT_TYPE)
                .unwrap()
                .to_str()
                .unwrap(),
            "222"
        );
    }

    /// Basic credentials are base64-encoded, with and without a password.
    #[actix_rt::test]
    async fn basic_auth() {
        let req = Client::new()
            .ws("/")
            .basic_auth("username", Some("password"));
        assert_eq!(
            req.head
                .headers
                .get(header::AUTHORIZATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "Basic dXNlcm5hbWU6cGFzc3dvcmQ="
        );

        let req = Client::new().ws("/").basic_auth("username", None);
        assert_eq!(
            req.head
                .headers
                .get(header::AUTHORIZATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "Basic dXNlcm5hbWU6"
        );
    }

    /// A bearer token is set verbatim; connecting to a host-less URI fails up front.
    #[actix_rt::test]
    async fn bearer_auth() {
        let req = Client::new().ws("/").bearer_auth("someS3cr3tAutht0k3n");
        assert_eq!(
            req.head
                .headers
                .get(header::AUTHORIZATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "Bearer someS3cr3tAutht0k3n"
        );

        // The previous `let _ = req.connect();` built a future and dropped it without
        // polling. "/" has no host, so the awaited handshake is rejected before any I/O.
        assert!(req.connect().await.is_err());
    }

    /// Builder options are recorded, and malformed URIs are rejected by `connect`.
    #[actix_rt::test]
    async fn basics() {
        let req = Client::new()
            .ws("http://localhost/")
            .origin("test-origin")
            .max_frame_size(100)
            .server_mode()
            .protocols(&["v1", "v2"])
            .set_header_if_none(header::CONTENT_TYPE, "json")
            .set_header_if_none(header::CONTENT_TYPE, "text");

        // `cookie` and `Cookie` only exist with the `cookies` feature enabled.
        #[cfg(feature = "cookies")]
        let req = req.cookie(Cookie::build("cookie1", "value1").finish());

        assert_eq!(
            req.origin.as_ref().unwrap().to_str().unwrap(),
            "test-origin"
        );
        assert_eq!(req.max_size, 100);
        assert!(req.server_mode);
        assert_eq!(req.protocols, Some("v1,v2".to_string()));
        assert_eq!(
            req.head.headers.get(header::CONTENT_TYPE).unwrap(),
            header::HeaderValue::from_static("json")
        );

        let _ = req.connect().await;

        assert!(Client::new().ws("/").connect().await.is_err());
        assert!(Client::new().ws("http:///test").connect().await.is_err());
        assert!(Client::new().ws("hmm://test.com/").connect().await.is_err());
    }
}