1use std::sync::Arc;
2use std::sync::atomic::{AtomicU64, Ordering};
3
4use fastwebsockets::{FragmentCollectorRead, WebSocketWrite};
5use http::{
6 Method,
7 header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE},
8};
9use http_body_util::Empty;
10use hyper::{Request, Uri, body::Bytes};
11use hyper_util::rt::TokioIo;
12use rustls_platform_verifier::ConfigVerifierExt;
13use tokio::net::TcpStream;
14use tokio::sync::{Mutex, watch};
15use tokio_rustls::{TlsConnector, client::TlsStream, rustls};
16
17use crate::error::{Error, Result};
18use crate::secrets::CustomerId;
19use crate::streamer::events::{ConnectionEvent, DisconnectReason};
20use crate::streamer::protocol::{ResponseCode, Service, StreamerCommand};
21use crate::streamer::request::{RequestPayload, StreamerRequest};
22use crate::streamer::response::{RawStreamerResponse, StreamerResponse};
23use crate::streamer::subscription::SubscribeRequest;
24use crate::streamer::{account_activity, admin, book, chart, level_one, screener};
25use crate::token::TokenProvider;
26use crate::user_preferences::StreamerInfo;
27
28type Upgraded = TokioIo<hyper::upgrade::Upgraded>;
29type WsReadHalf = FragmentCollectorRead<tokio::io::ReadHalf<Upgraded>>;
30type WsWriteHalf = WebSocketWrite<tokio::io::WriteHalf<Upgraded>>;
31type WebSocket = fastwebsockets::WebSocket<Upgraded>;
32
33#[derive(Debug, thiserror::Error)]
36pub enum WebSocketError {
37 #[error("failed to connect to server: {0}")]
39 Connect(#[source] std::io::Error),
40 #[error("failed to perform websocket handshake: {0}")]
42 Handshake(#[source] fastwebsockets::WebSocketError),
43 #[error("invalid domain: {0}")]
45 InvalidDomain(#[source] rustls_pki_types::InvalidDnsNameError),
46 #[error("host is required")]
48 MissingHost,
49 #[error("failed to create TLS stream: {0}")]
51 TlsStream(#[source] std::io::Error),
52 #[error("failed to configure TLS: {0}")]
54 TlsConfig(#[source] rustls::Error),
55 #[error("failed to build upgrade request: {0}")]
57 BuildRequest(#[source] http::Error),
58 #[error("unsupported websocket scheme: {0}")]
64 UnsupportedScheme(String),
65 #[error("websocket runtime error: {0}")]
68 Runtime(#[from] fastwebsockets::WebSocketError),
69}
70
71impl WebSocketError {
72 pub fn is_retryable(&self) -> bool {
82 match self {
83 WebSocketError::Connect(_)
84 | WebSocketError::TlsStream(_)
85 | WebSocketError::Handshake(_)
86 | WebSocketError::Runtime(_) => true,
87 WebSocketError::InvalidDomain(_)
88 | WebSocketError::MissingHost
89 | WebSocketError::TlsConfig(_)
90 | WebSocketError::BuildRequest(_)
91 | WebSocketError::UnsupportedScheme(_) => false,
92 }
93 }
94}
95
96impl From<fastwebsockets::WebSocketError> for Error {
97 fn from(value: fastwebsockets::WebSocketError) -> Self {
98 Error::WebSocket(WebSocketError::Runtime(value))
99 }
100}
101
102struct SpawnExecutor;
103
104impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
105where
106 Fut: Future + Send + 'static,
107 Fut::Output: Send + 'static,
108{
109 fn execute(&self, fut: Fut) {
110 tokio::task::spawn(fut);
111 }
112}
113
114async fn connect_tls(uri: &Uri) -> std::result::Result<TlsStream<TcpStream>, WebSocketError> {
115 let host = uri.host().ok_or(WebSocketError::MissingHost)?;
116 let port = uri.port_u16().unwrap_or(443);
117 let addr = format!("{}:{}", host, port);
118
119 let socket = TcpStream::connect(addr)
120 .await
121 .map_err(WebSocketError::Connect)?;
122
123 let domain = rustls_pki_types::ServerName::try_from(host.to_string())
124 .map_err(WebSocketError::InvalidDomain)?;
125 let config =
126 rustls::ClientConfig::with_platform_verifier().map_err(WebSocketError::TlsConfig)?;
127 let connector = TlsConnector::from(Arc::new(config));
128 connector
129 .connect(domain, socket)
130 .await
131 .map_err(WebSocketError::TlsStream)
132}
133
134async fn connect_tcp(uri: &Uri) -> std::result::Result<TcpStream, WebSocketError> {
135 let host = uri.host().ok_or(WebSocketError::MissingHost)?;
136 let port = uri.port_u16().unwrap_or(80);
137 TcpStream::connect(format!("{}:{}", host, port))
138 .await
139 .map_err(WebSocketError::Connect)
140}
141
/// Transport chosen for a websocket URI by [`check_websocket_scheme`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum WsTransport {
    /// `wss://`: TCP wrapped in TLS.
    Tls,
    /// `ws://`: plaintext TCP, only selected when insecure transport is allowed.
    Plain,
}
152
153fn check_websocket_scheme(
159 scheme: Option<&str>,
160 allow_insecure: bool,
161) -> std::result::Result<WsTransport, WebSocketError> {
162 match scheme {
163 Some("wss") => Ok(WsTransport::Tls),
164 Some("ws") if allow_insecure => Ok(WsTransport::Plain),
165 Some("ws") => Err(WebSocketError::UnsupportedScheme("ws".to_string())),
166 Some(other) => Err(WebSocketError::UnsupportedScheme(other.to_string())),
167 None => Err(WebSocketError::UnsupportedScheme(String::new())),
168 }
169}
170
171async fn connect_websocket(uri: &Uri) -> std::result::Result<WebSocket, WebSocketError> {
172 let transport = check_websocket_scheme(uri.scheme_str(), cfg!(debug_assertions))?;
173
174 let req = Request::builder()
175 .method(Method::GET)
176 .uri(uri)
177 .header(HOST, uri.host().ok_or(WebSocketError::MissingHost)?)
178 .header(UPGRADE, "websocket")
179 .header(CONNECTION, "upgrade")
180 .header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key())
181 .header(SEC_WEBSOCKET_VERSION, "13")
182 .body(Empty::<Bytes>::new())
183 .map_err(WebSocketError::BuildRequest)?;
184
185 match transport {
186 WsTransport::Tls => {
187 let stream = connect_tls(uri).await?;
188 let (ws, _) = fastwebsockets::handshake::client(&SpawnExecutor, req, stream)
189 .await
190 .map_err(WebSocketError::Handshake)?;
191 Ok(ws)
192 }
193 WsTransport::Plain => {
194 let stream = connect_tcp(uri).await?;
195 let (ws, _) = fastwebsockets::handshake::client(&SpawnExecutor, req, stream)
196 .await
197 .map_err(WebSocketError::Handshake)?;
198 Ok(ws)
199 }
200 }
201}
202
203pub async fn connect(
218 streamer_info: StreamerInfo,
219 token_provider: Arc<dyn TokenProvider + Send + Sync>,
220) -> Result<(ReadHalf, WriteHalf)> {
221 let validated = ValidatedStreamerInfo::try_from(streamer_info)?;
222 let websocket = connect_websocket(&validated.socket_url).await?;
223 Ok(split(websocket, validated, token_provider))
224}
225
/// A [`StreamerInfo`] whose required fields have all been checked for
/// presence, with the socket URL parsed into a [`Uri`].
#[derive(Debug)]
struct ValidatedStreamerInfo {
    /// Websocket URL to connect to (from `streamerSocketUrl`).
    socket_url: Uri,
    /// Sent with every request (from `schwabClientCustomerId`).
    customer_id: CustomerId,
    /// Sent with every request (from `schwabClientCorrelId`).
    correlation_id: String,
    /// Sent with the login request (from `schwabClientChannel`).
    channel: String,
    /// Sent with the login request (from `schwabClientFunctionId`).
    function_id: String,
}
236
237impl TryFrom<StreamerInfo> for ValidatedStreamerInfo {
238 type Error = Error;
239
240 fn try_from(info: StreamerInfo) -> Result<Self> {
241 fn required<T>(field: &'static str, value: Option<T>) -> Result<T> {
242 value.ok_or(Error::InvalidPreference {
243 field,
244 reason: "missing".to_string(),
245 })
246 }
247
248 let socket_url = required("streamerSocketUrl", info.streamer_socket_url)?
249 .parse::<Uri>()
250 .map_err(|e| Error::InvalidPreference {
251 field: "streamerSocketUrl",
252 reason: e.to_string(),
253 })?;
254
255 Ok(Self {
256 socket_url,
257 customer_id: required("schwabClientCustomerId", info.schwab_client_customer_id)?,
258 correlation_id: required("schwabClientCorrelId", info.schwab_client_correlation_id)?,
259 channel: required("schwabClientChannel", info.schwab_client_channel)?,
260 function_id: required("schwabClientFunctionId", info.schwab_client_function_id)?,
261 })
262 }
263}
264
265fn split(
274 websocket: WebSocket,
275 streamer_info: ValidatedStreamerInfo,
276 token_provider: Arc<dyn TokenProvider + Send + Sync>,
277) -> (ReadHalf, WriteHalf) {
278 let (read_half, write_half) = websocket.split(tokio::io::split);
279 let write_half = Arc::new(Mutex::new(write_half));
280 let (events_tx, _) = watch::channel(ConnectionEvent::Connected);
281
282 let reader = ReadHalf {
283 read_half: FragmentCollectorRead::new(read_half),
284 write_half: write_half.clone(),
285 events_tx,
286 };
287
288 let writer = WriteHalf {
289 write_half,
290 customer_id: streamer_info.customer_id,
291 correlation_id: streamer_info.correlation_id,
292 channel: streamer_info.channel,
293 function_id: streamer_info.function_id,
294 request_id: Arc::new(AtomicU64::new(0)),
295 token_provider,
296 };
297
298 (reader, writer)
299}
300
301async fn write_one(
307 write_half: Arc<Mutex<WsWriteHalf>>,
308 frame: fastwebsockets::Frame<'_>,
309) -> std::result::Result<(), fastwebsockets::WebSocketError> {
310 write_half.lock().await.write_frame(frame).await
311}
312
/// Receiving half of a streamer connection, produced by [`connect`].
///
/// [`ReadHalf::recv`] returns the next streamer response, and
/// [`ReadHalf::events`] subscribes to connection state changes.
pub struct ReadHalf {
    /// Websocket read side.
    read_half: WsReadHalf,
    /// Write side shared with [`WriteHalf`]; used when reading needs to send a
    /// frame back on the connection.
    write_half: Arc<Mutex<WsWriteHalf>>,
    /// Publishes [`ConnectionEvent`]s derived from received frames and errors.
    events_tx: watch::Sender<ConnectionEvent>,
}
322
323impl std::fmt::Debug for ReadHalf {
324 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325 f.debug_struct("ReadHalf").finish_non_exhaustive()
326 }
327}
328
329impl ReadHalf {
330 pub async fn recv(&mut self) -> Result<StreamerResponse> {
342 let write_half = self.write_half.clone();
343 let mut send_fn = move |frame| write_one(write_half.clone(), frame);
344 loop {
345 let frame = match self.read_half.read_frame(&mut send_fn).await {
346 Ok(f) => f,
347 Err(e) => {
348 self.events_tx.send_replace(ConnectionEvent::Disconnected(
349 DisconnectReason::Transport(e.to_string()),
350 ));
351 return Err(e.into());
352 }
353 };
354 if frame.opcode == fastwebsockets::OpCode::Text {
355 let raw_response: RawStreamerResponse = match serde_json::from_slice(&frame.payload)
356 {
357 Ok(r) => r,
358 Err(e) => {
359 self.events_tx.send_replace(ConnectionEvent::StreamError {
360 message: e.to_string(),
361 });
362 return Err(Error::Codec {
363 context: "streamer response frame".to_string(),
364 reason: e.to_string(),
365 });
366 }
367 };
368 let response = StreamerResponse::try_from(raw_response)?;
369 classify_and_emit(&self.events_tx, &response);
370 return Ok(response);
371 }
372 }
373 }
374
375 pub fn events(&self) -> watch::Receiver<ConnectionEvent> {
402 self.events_tx.subscribe()
403 }
404}
405
406fn classify_and_emit(events_tx: &watch::Sender<ConnectionEvent>, response: &StreamerResponse) {
409 let StreamerResponse::Response(responses) = response else {
410 return;
411 };
412 for r in responses {
413 let is_login = r.service == Service::Admin && r.command == StreamerCommand::Login;
414 match r.content.code {
415 ResponseCode::Ok if is_login => {
416 events_tx.send_replace(ConnectionEvent::LoggedIn);
417 }
418 ResponseCode::LoginDenied => {
419 events_tx.send_replace(ConnectionEvent::Disconnected(
420 DisconnectReason::LoginDenied(r.content.message.clone()),
421 ));
422 }
423 ResponseCode::CloseConnection => {
424 events_tx.send_replace(ConnectionEvent::Disconnected(
425 DisconnectReason::ServerClose(r.content.message.clone()),
426 ));
427 }
428 ResponseCode::StopStreaming => {
429 events_tx.send_replace(ConnectionEvent::Disconnected(
430 DisconnectReason::StopStreaming(r.content.message.clone()),
431 ));
432 }
433 _ => {}
434 }
435 }
436}
437
/// Sending half of a streamer connection, produced by [`connect`].
///
/// Clones share the websocket writer, the request-id counter and the token
/// provider; the id strings are copied.
#[derive(Clone)]
pub struct WriteHalf {
    /// Write side shared with [`ReadHalf`]; the mutex serialises frame writes.
    write_half: Arc<Mutex<WsWriteHalf>>,
    /// Stamped on every request envelope.
    customer_id: CustomerId,
    /// Stamped on every request envelope.
    correlation_id: String,
    /// Sent with the login request.
    channel: String,
    /// Sent with the login request.
    function_id: String,
    /// Source of request ids, shared among clones.
    request_id: Arc<AtomicU64>,
    /// Supplies the access token used by `login`.
    token_provider: Arc<dyn TokenProvider + Send + Sync>,
}
452
453impl std::fmt::Debug for WriteHalf {
454 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
455 f.debug_struct("WriteHalf")
456 .field("channel", &self.channel)
457 .field("function_id", &self.function_id)
458 .finish_non_exhaustive()
459 }
460}
461
462impl WriteHalf {
463 pub async fn login(&self) -> Result<()> {
475 let auth_token = self.token_provider.access_token().await?;
476 let request = admin::Login {
477 authorization: auth_token,
478 schwab_client_channel: self.channel.clone(),
479 schwab_client_function_id: self.function_id.clone(),
480 };
481 self.send(request).await
482 }
483
484 pub async fn logout(&self) -> Result<()> {
486 self.send(admin::Logout).await
487 }
488
489 pub fn equities(&self) -> SubscribeRequest<'_, level_one::equities::Field> {
491 SubscribeRequest::new(self)
492 }
493
494 pub fn options(&self) -> SubscribeRequest<'_, level_one::options::Field> {
496 SubscribeRequest::new(self)
497 }
498
499 pub fn futures(&self) -> SubscribeRequest<'_, level_one::futures::Field> {
501 SubscribeRequest::new(self)
502 }
503
504 pub fn futures_options(&self) -> SubscribeRequest<'_, level_one::futures_options::Field> {
506 SubscribeRequest::new(self)
507 }
508
509 pub fn forex(&self) -> SubscribeRequest<'_, level_one::forex::Field> {
511 SubscribeRequest::new(self)
512 }
513
514 pub fn nyse_book(&self) -> SubscribeRequest<'_, book::nyse::Field> {
516 SubscribeRequest::new(self)
517 }
518
519 pub fn nasdaq_book(&self) -> SubscribeRequest<'_, book::nasdaq::Field> {
521 SubscribeRequest::new(self)
522 }
523
524 pub fn options_book(&self) -> SubscribeRequest<'_, book::options::Field> {
526 SubscribeRequest::new(self)
527 }
528
529 pub fn chart_equity(&self) -> SubscribeRequest<'_, chart::equity::Field> {
531 SubscribeRequest::new(self)
532 }
533
534 pub fn chart_futures(&self) -> SubscribeRequest<'_, chart::futures::Field> {
536 SubscribeRequest::new(self)
537 }
538
539 pub fn screener_equity(&self) -> SubscribeRequest<'_, screener::equity::Field> {
541 SubscribeRequest::new(self)
542 }
543
544 pub fn screener_option(&self) -> SubscribeRequest<'_, screener::option::Field> {
546 SubscribeRequest::new(self)
547 }
548
549 pub fn account_activity(&self) -> SubscribeRequest<'_, account_activity::Field> {
551 SubscribeRequest::new(self)
552 }
553
554 pub(crate) async fn send<T: Into<StreamerRequest>>(&self, request: T) -> Result<()> {
559 let request: StreamerRequest = request.into();
560 let request_id = self.request_id.fetch_add(1, Ordering::Relaxed);
561 let request = RequestPayload {
562 request_id,
563 service: request.service,
564 command: request.command,
565 parameters: request.parameters,
566 schwab_client_customer_id: self.customer_id.clone(),
567 schwab_client_correlation_id: self.correlation_id.clone(),
568 };
569
570 let serialized = serde_json::to_string(&request).map_err(|e| Error::Codec {
571 context: "streamer request envelope".to_string(),
572 reason: e.to_string(),
573 })?;
574 write_one(
575 self.write_half.clone(),
576 fastwebsockets::Frame::text(fastwebsockets::Payload::Borrowed(serialized.as_bytes())),
577 )
578 .await?;
579 Ok(())
580 }
581}
582
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streamer::events::{ConnectionEvent, DisconnectReason};
    use crate::streamer::protocol::{ResponseCode, Service, StreamerCommand};
    use crate::streamer::response::{ResponseContent, ResponsePayload};

    /// Builds an admin-service response holding a single payload.
    fn response(code: ResponseCode, command: StreamerCommand, msg: &str) -> StreamerResponse {
        StreamerResponse::Response(vec![ResponsePayload {
            request_id: 1,
            service: Service::Admin,
            timestamp: 1,
            command,
            schwab_client_correlation_id: "x".into(),
            content: ResponseContent {
                code,
                message: msg.into(),
            },
        }])
    }

    /// A `StreamerInfo` with every required field populated.
    fn full_streamer_info() -> StreamerInfo {
        StreamerInfo {
            streamer_socket_url: Some("wss://streamer-api.schwab.com/ws".into()),
            schwab_client_customer_id: Some(CustomerId::from("CUSTID")),
            schwab_client_correlation_id: Some("abc-123".into()),
            schwab_client_channel: Some("N9".into()),
            schwab_client_function_id: Some("APIAPP".into()),
        }
    }

    #[test]
    fn validates_complete_streamer_info() {
        let validated =
            ValidatedStreamerInfo::try_from(full_streamer_info()).expect("complete info validates");
        assert_eq!(validated.socket_url, "wss://streamer-api.schwab.com/ws");
        assert_eq!(validated.correlation_id, "abc-123");
        assert_eq!(validated.channel, "N9");
        assert_eq!(validated.function_id, "APIAPP");
    }

    #[test]
    fn missing_socket_url_reports_field() {
        let mut info = full_streamer_info();
        info.streamer_socket_url = None;
        match ValidatedStreamerInfo::try_from(info) {
            Err(Error::InvalidPreference { field, .. }) => {
                assert_eq!(field, "streamerSocketUrl");
            }
            other => panic!("expected InvalidPreference, got {other:?}"),
        }
    }

    #[test]
    fn missing_customer_id_reports_field() {
        let mut info = full_streamer_info();
        info.schwab_client_customer_id = None;
        match ValidatedStreamerInfo::try_from(info) {
            Err(Error::InvalidPreference { field, .. }) => {
                assert_eq!(field, "schwabClientCustomerId");
            }
            other => panic!("expected InvalidPreference, got {other:?}"),
        }
    }

    #[test]
    fn missing_correlation_id_reports_field() {
        let mut info = full_streamer_info();
        info.schwab_client_correlation_id = None;
        match ValidatedStreamerInfo::try_from(info) {
            Err(Error::InvalidPreference { field, .. }) => {
                assert_eq!(field, "schwabClientCorrelId");
            }
            other => panic!("expected InvalidPreference, got {other:?}"),
        }
    }

    #[test]
    fn missing_channel_reports_field() {
        let mut info = full_streamer_info();
        info.schwab_client_channel = None;
        match ValidatedStreamerInfo::try_from(info) {
            Err(Error::InvalidPreference { field, .. }) => {
                assert_eq!(field, "schwabClientChannel");
            }
            other => panic!("expected InvalidPreference, got {other:?}"),
        }
    }

    #[test]
    fn missing_function_id_reports_field() {
        let mut info = full_streamer_info();
        info.schwab_client_function_id = None;
        match ValidatedStreamerInfo::try_from(info) {
            Err(Error::InvalidPreference { field, .. }) => {
                assert_eq!(field, "schwabClientFunctionId");
            }
            other => panic!("expected InvalidPreference, got {other:?}"),
        }
    }

    #[test]
    fn login_ok_emits_logged_in() {
        let (tx, mut rx) = watch::channel(ConnectionEvent::Connected);
        classify_and_emit(&tx, &response(ResponseCode::Ok, StreamerCommand::Login, ""));
        assert!(rx.has_changed().unwrap());
        assert_eq!(*rx.borrow_and_update(), ConnectionEvent::LoggedIn);
    }

    #[test]
    fn login_denied_emits_disconnected() {
        let (tx, mut rx) = watch::channel(ConnectionEvent::Connected);
        classify_and_emit(
            &tx,
            &response(
                ResponseCode::LoginDenied,
                StreamerCommand::Login,
                "token expired",
            ),
        );
        match rx.borrow_and_update().clone() {
            ConnectionEvent::Disconnected(DisconnectReason::LoginDenied(msg)) => {
                assert!(msg.contains("token expired"), "msg = {msg}");
            }
            other => panic!("expected Disconnected(LoginDenied), got {other:?}"),
        }
    }

    #[test]
    fn close_connection_emits_disconnected_server_close() {
        let (tx, mut rx) = watch::channel(ConnectionEvent::Connected);
        classify_and_emit(
            &tx,
            &response(
                ResponseCode::CloseConnection,
                StreamerCommand::Subs,
                "max connections",
            ),
        );
        assert!(matches!(
            *rx.borrow_and_update(),
            ConnectionEvent::Disconnected(DisconnectReason::ServerClose(_))
        ));
    }

    #[test]
    fn stop_streaming_emits_disconnected_stop_streaming() {
        let (tx, mut rx) = watch::channel(ConnectionEvent::Connected);
        classify_and_emit(
            &tx,
            &response(
                ResponseCode::StopStreaming,
                StreamerCommand::Subs,
                "inactivity",
            ),
        );
        assert!(matches!(
            *rx.borrow_and_update(),
            ConnectionEvent::Disconnected(DisconnectReason::StopStreaming(_))
        ));
    }

    #[test]
    fn non_admin_ok_response_does_not_emit() {
        let (tx, rx) = watch::channel(ConnectionEvent::Connected);
        let r = StreamerResponse::Response(vec![ResponsePayload {
            request_id: 1,
            service: Service::LevelOneEquities,
            timestamp: 1,
            command: StreamerCommand::Subs,
            schwab_client_correlation_id: "x".into(),
            content: ResponseContent {
                code: ResponseCode::Ok,
                message: "".into(),
            },
        }]);
        classify_and_emit(&tx, &r);
        assert!(!rx.has_changed().unwrap());
    }

    #[test]
    fn data_payload_does_not_emit() {
        let (tx, rx) = watch::channel(ConnectionEvent::Connected);
        let r = StreamerResponse::Notify(vec![]);
        classify_and_emit(&tx, &r);
        assert!(!rx.has_changed().unwrap());
    }

    #[test]
    fn wss_is_accepted_in_both_modes() {
        assert_eq!(
            check_websocket_scheme(Some("wss"), false).unwrap(),
            WsTransport::Tls
        );
        assert_eq!(
            check_websocket_scheme(Some("wss"), true).unwrap(),
            WsTransport::Tls
        );
    }

    #[test]
    fn ws_is_rejected_when_insecure_disallowed() {
        match check_websocket_scheme(Some("ws"), false) {
            Err(WebSocketError::UnsupportedScheme(scheme)) => assert_eq!(scheme, "ws"),
            other => panic!("expected UnsupportedScheme(ws), got {other:?}"),
        }
    }

    #[test]
    fn ws_is_accepted_when_insecure_permitted() {
        assert_eq!(
            check_websocket_scheme(Some("ws"), true).unwrap(),
            WsTransport::Plain
        );
    }

    #[test]
    fn other_schemes_are_always_rejected() {
        for scheme in ["http", "https", "ftp", "file", ""] {
            assert!(
                matches!(
                    check_websocket_scheme(Some(scheme), true).unwrap_err(),
                    WebSocketError::UnsupportedScheme(_)
                ),
                "scheme {scheme:?} should be rejected with insecure mode on"
            );
            assert!(
                matches!(
                    check_websocket_scheme(Some(scheme), false).unwrap_err(),
                    WebSocketError::UnsupportedScheme(_)
                ),
                "scheme {scheme:?} should be rejected with insecure mode off"
            );
        }
    }

    #[test]
    fn no_scheme_is_rejected() {
        assert!(matches!(
            check_websocket_scheme(None, true).unwrap_err(),
            WebSocketError::UnsupportedScheme(s) if s.is_empty()
        ));
        assert!(matches!(
            check_websocket_scheme(None, false).unwrap_err(),
            WebSocketError::UnsupportedScheme(s) if s.is_empty()
        ));
    }

    #[test]
    fn case_sensitive_scheme_match() {
        assert!(check_websocket_scheme(Some("Wss"), false).is_err());
        assert!(check_websocket_scheme(Some("WSS"), false).is_err());
    }

    #[test]
    fn is_retryable_classifies_transport_failures_as_retryable() {
        assert!(WebSocketError::Connect(std::io::Error::other("x")).is_retryable());
        assert!(WebSocketError::TlsStream(std::io::Error::other("x")).is_retryable());
        assert!(
            WebSocketError::Handshake(fastwebsockets::WebSocketError::ConnectionClosed)
                .is_retryable()
        );
        assert!(
            WebSocketError::Runtime(fastwebsockets::WebSocketError::ConnectionClosed)
                .is_retryable()
        );
    }

    #[test]
    fn is_retryable_classifies_config_failures_as_terminal() {
        assert!(!WebSocketError::MissingHost.is_retryable());
        assert!(!WebSocketError::UnsupportedScheme("ws".to_string()).is_retryable());
        assert!(
            !WebSocketError::InvalidDomain(
                rustls_pki_types::ServerName::try_from("not a dns name").unwrap_err()
            )
            .is_retryable()
        );
        // Cover the remaining terminal arms so a reclassification is caught.
        assert!(
            !WebSocketError::TlsConfig(rustls::Error::General("bad config".to_string()))
                .is_retryable()
        );
        let build_error = Request::builder()
            .uri("not a uri")
            .body(())
            .expect_err("a URI containing spaces is rejected");
        assert!(!WebSocketError::BuildRequest(build_error).is_retryable());
    }

    #[test]
    fn raw_websocket_error_converts_to_runtime_variant() {
        let err: Error = fastwebsockets::WebSocketError::ConnectionClosed.into();
        assert!(matches!(err, Error::WebSocket(WebSocketError::Runtime(_))));
        assert!(err.is_retryable());
    }

    #[test]
    fn error_is_retryable_delegates_to_websocket_error() {
        let terminal = Error::WebSocket(WebSocketError::UnsupportedScheme("ws".to_string()));
        assert!(!terminal.is_retryable());
        let transient = Error::WebSocket(WebSocketError::Connect(std::io::Error::other(
            "conn refused",
        )));
        assert!(transient.is_retryable());
    }
}