use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use fastwebsockets::{FragmentCollectorRead, WebSocketWrite};
use http::{
Method,
header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE},
};
use http_body_util::Empty;
use hyper::{Request, Uri, body::Bytes};
use hyper_util::rt::TokioIo;
use rustls_platform_verifier::ConfigVerifierExt;
use tokio::net::TcpStream;
use tokio::sync::{Mutex, watch};
use tokio_rustls::{TlsConnector, client::TlsStream, rustls};
use crate::error::{Error, Result};
use crate::secrets::CustomerId;
use crate::streamer::events::{ConnectionEvent, DisconnectReason};
use crate::streamer::protocol::{ResponseCode, Service, StreamerCommand};
use crate::streamer::request::{RequestPayload, StreamerRequest};
use crate::streamer::response::{RawStreamerResponse, StreamerResponse};
use crate::streamer::subscription::SubscribeRequest;
use crate::streamer::{account_activity, admin, book, chart, level_one, screener};
use crate::token::TokenProvider;
use crate::user_preferences::StreamerInfo;
/// Upgraded HTTP connection (`hyper::upgrade::Upgraded`) wrapped in the `TokioIo` adapter.
type Upgraded = TokioIo<hyper::upgrade::Upgraded>;
/// Read side of the split websocket, wrapped in a `FragmentCollectorRead`.
type WsReadHalf = FragmentCollectorRead<tokio::io::ReadHalf<Upgraded>>;
/// Write side of the split websocket.
type WsWriteHalf = WebSocketWrite<tokio::io::WriteHalf<Upgraded>>;
/// Unsplit websocket running over the upgraded connection.
type WebSocket = fastwebsockets::WebSocket<Upgraded>;
/// Errors produced while connecting to a websocket server or using the
/// resulting connection.
///
/// See [`WebSocketError::is_retryable`] for how each variant is classified.
#[derive(Debug, thiserror::Error)]
pub enum WebSocketError {
/// Opening the TCP connection to the server failed.
#[error("failed to connect to server: {0}")]
Connect(#[source] std::io::Error),
/// The websocket upgrade handshake returned an error.
#[error("failed to perform websocket handshake: {0}")]
Handshake(#[source] fastwebsockets::WebSocketError),
/// The URI host could not be converted into a TLS server name.
#[error("invalid domain: {0}")]
InvalidDomain(#[source] rustls_pki_types::InvalidDnsNameError),
/// The URI has no host component.
#[error("host is required")]
MissingHost,
/// Establishing the TLS session over the connected socket failed.
#[error("failed to create TLS stream: {0}")]
TlsStream(#[source] std::io::Error),
/// Building the rustls client configuration failed.
#[error("failed to configure TLS: {0}")]
TlsConfig(#[source] rustls::Error),
/// The HTTP upgrade request could not be built.
#[error("failed to build upgrade request: {0}")]
BuildRequest(#[source] http::Error),
/// The URI scheme is not accepted; carries the rejected scheme, or an empty
/// string when the URI had no scheme at all.
#[error("unsupported websocket scheme: {0}")]
UnsupportedScheme(String),
/// A `fastwebsockets` error on an established connection; this is the
/// variant produced by the automatic `From` conversion.
#[error("websocket runtime error: {0}")]
Runtime(#[from] fastwebsockets::WebSocketError),
}
impl WebSocketError {
pub fn is_retryable(&self) -> bool {
match self {
WebSocketError::Connect(_)
| WebSocketError::TlsStream(_)
| WebSocketError::Handshake(_)
| WebSocketError::Runtime(_) => true,
WebSocketError::InvalidDomain(_)
| WebSocketError::MissingHost
| WebSocketError::TlsConfig(_)
| WebSocketError::BuildRequest(_)
| WebSocketError::UnsupportedScheme(_) => false,
}
}
}
/// Lifts a raw `fastwebsockets` error into the crate error as
/// `Error::WebSocket(WebSocketError::Runtime(..))`.
impl From<fastwebsockets::WebSocketError> for Error {
    fn from(value: fastwebsockets::WebSocketError) -> Self {
        let runtime = WebSocketError::Runtime(value);
        Self::WebSocket(runtime)
    }
}
/// Executor passed to the websocket handshake; every future it receives is
/// spawned onto the Tokio runtime.
struct SpawnExecutor;

impl<F> hyper::rt::Executor<F> for SpawnExecutor
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    fn execute(&self, future: F) {
        // The join handle is intentionally discarded; the task is not awaited here.
        drop(tokio::spawn(future));
    }
}
/// Opens a TCP connection to the host in `uri` and establishes a TLS session
/// over it, verifying the server with the platform certificate verifier.
///
/// The port defaults to 443 when `uri` does not carry one.
///
/// # Errors
///
/// * [`WebSocketError::MissingHost`] if `uri` has no host.
/// * [`WebSocketError::InvalidDomain`] if the host is not a valid TLS server name.
/// * [`WebSocketError::TlsConfig`] if the TLS client configuration cannot be built.
/// * [`WebSocketError::Connect`] if the TCP connection fails.
/// * [`WebSocketError::TlsStream`] if the TLS handshake fails.
async fn connect_tls(uri: &Uri) -> std::result::Result<TlsStream<TcpStream>, WebSocketError> {
    let host = uri.host().ok_or(WebSocketError::MissingHost)?;
    let port = uri.port_u16().unwrap_or(443);

    // Check the terminal (non-retryable) conditions before touching the
    // network. Otherwise a failed TCP connect would be reported first and mask
    // a permanently invalid host or TLS configuration as a retryable error.
    let domain = rustls_pki_types::ServerName::try_from(host.to_string())
        .map_err(WebSocketError::InvalidDomain)?;
    let config =
        rustls::ClientConfig::with_platform_verifier().map_err(WebSocketError::TlsConfig)?;
    let connector = TlsConnector::from(Arc::new(config));

    let socket = TcpStream::connect(format!("{host}:{port}"))
        .await
        .map_err(WebSocketError::Connect)?;
    connector
        .connect(domain, socket)
        .await
        .map_err(WebSocketError::TlsStream)
}
/// Opens a plain (unencrypted) TCP connection to the host in `uri`.
///
/// The port defaults to 80 when `uri` does not carry one. Returns
/// [`WebSocketError::MissingHost`] if `uri` has no host and
/// [`WebSocketError::Connect`] if the connection attempt fails.
async fn connect_tcp(uri: &Uri) -> std::result::Result<TcpStream, WebSocketError> {
    let Some(host) = uri.host() else {
        return Err(WebSocketError::MissingHost);
    };
    let port = uri.port_u16().unwrap_or(80);
    let addr = format!("{host}:{port}");
    TcpStream::connect(addr)
        .await
        .map_err(WebSocketError::Connect)
}
/// Transport selected for a websocket URI by `check_websocket_scheme`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum WsTransport {
/// TLS-protected connection, chosen for the `wss` scheme.
Tls,
/// Unencrypted TCP connection, chosen for the `ws` scheme when insecure
/// connections are allowed.
Plain,
}
/// Maps a URI scheme to the transport that should carry the websocket.
///
/// * `"wss"` always selects [`WsTransport::Tls`].
/// * `"ws"` selects [`WsTransport::Plain`] only when `allow_insecure` is true.
/// * Everything else is rejected with [`WebSocketError::UnsupportedScheme`]
///   carrying the offending scheme (an empty string when `scheme` is `None`).
///
/// The comparison is exact, so differently-cased spellings such as `"WSS"`
/// are rejected.
fn check_websocket_scheme(
    scheme: Option<&str>,
    allow_insecure: bool,
) -> std::result::Result<WsTransport, WebSocketError> {
    let Some(name) = scheme else {
        return Err(WebSocketError::UnsupportedScheme(String::new()));
    };
    if name == "wss" {
        return Ok(WsTransport::Tls);
    }
    if name == "ws" && allow_insecure {
        return Ok(WsTransport::Plain);
    }
    Err(WebSocketError::UnsupportedScheme(name.to_string()))
}
async fn connect_websocket(uri: &Uri) -> std::result::Result<WebSocket, WebSocketError> {
let transport = check_websocket_scheme(uri.scheme_str(), cfg!(debug_assertions))?;
let req = Request::builder()
.method(Method::GET)
.uri(uri)
.header(HOST, uri.host().ok_or(WebSocketError::MissingHost)?)
.header(UPGRADE, "websocket")
.header(CONNECTION, "upgrade")
.header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key())
.header(SEC_WEBSOCKET_VERSION, "13")
.body(Empty::<Bytes>::new())
.map_err(WebSocketError::BuildRequest)?;
match transport {
WsTransport::Tls => {
let stream = connect_tls(uri).await?;
let (ws, _) = fastwebsockets::handshake::client(&SpawnExecutor, req, stream)
.await
.map_err(WebSocketError::Handshake)?;
Ok(ws)
}
WsTransport::Plain => {
let stream = connect_tcp(uri).await?;
let (ws, _) = fastwebsockets::handshake::client(&SpawnExecutor, req, stream)
.await
.map_err(WebSocketError::Handshake)?;
Ok(ws)
}
}
}
pub async fn connect(
streamer_info: StreamerInfo,
token_provider: Arc<dyn TokenProvider + Send + Sync>,
) -> Result<(ReadHalf, WriteHalf)> {
let validated = ValidatedStreamerInfo::try_from(streamer_info)?;
let websocket = connect_websocket(&validated.socket_url).await?;
Ok(split(websocket, validated, token_provider))
}
/// `StreamerInfo` after validation: every field is known to be present and
/// the socket URL has been parsed into a `Uri`.
#[derive(Debug)]
struct ValidatedStreamerInfo {
// Parsed from `StreamerInfo::streamer_socket_url`.
socket_url: Uri,
// Taken from `StreamerInfo::schwab_client_customer_id`.
customer_id: CustomerId,
// Taken from `StreamerInfo::schwab_client_correlation_id`.
correlation_id: String,
// Taken from `StreamerInfo::schwab_client_channel`.
channel: String,
// Taken from `StreamerInfo::schwab_client_function_id`.
function_id: String,
}
impl TryFrom<StreamerInfo> for ValidatedStreamerInfo {
type Error = Error;
fn try_from(info: StreamerInfo) -> Result<Self> {
fn required<T>(field: &'static str, value: Option<T>) -> Result<T> {
value.ok_or(Error::InvalidPreference {
field,
reason: "missing".to_string(),
})
}
let socket_url = required("streamerSocketUrl", info.streamer_socket_url)?
.parse::<Uri>()
.map_err(|e| Error::InvalidPreference {
field: "streamerSocketUrl",
reason: e.to_string(),
})?;
Ok(Self {
socket_url,
customer_id: required("schwabClientCustomerId", info.schwab_client_customer_id)?,
correlation_id: required("schwabClientCorrelId", info.schwab_client_correlation_id)?,
channel: required("schwabClientChannel", info.schwab_client_channel)?,
function_id: required("schwabClientFunctionId", info.schwab_client_function_id)?,
})
}
}
/// Splits a connected websocket into a [`ReadHalf`] and a [`WriteHalf`].
///
/// Both halves share the same mutex-guarded writer. The event channel starts
/// at `ConnectionEvent::Connected`, and the request id counter starts at 0.
fn split(
    websocket: WebSocket,
    streamer_info: ValidatedStreamerInfo,
    token_provider: Arc<dyn TokenProvider + Send + Sync>,
) -> (ReadHalf, WriteHalf) {
    let (raw_reader, raw_writer) = websocket.split(tokio::io::split);
    let shared_writer = Arc::new(Mutex::new(raw_writer));

    // Only the sender is kept; receivers are created on demand via `subscribe`.
    let (events_tx, _) = watch::channel(ConnectionEvent::Connected);

    let writer = WriteHalf {
        write_half: Arc::clone(&shared_writer),
        customer_id: streamer_info.customer_id,
        correlation_id: streamer_info.correlation_id,
        channel: streamer_info.channel,
        function_id: streamer_info.function_id,
        request_id: Arc::new(AtomicU64::new(0)),
        token_provider,
    };
    let reader = ReadHalf {
        read_half: FragmentCollectorRead::new(raw_reader),
        write_half: shared_writer,
        events_tx,
    };
    (reader, writer)
}
/// Locks the shared writer and writes a single frame to it.
///
/// The lock is held until the write has completed.
async fn write_one(
    write_half: Arc<Mutex<WsWriteHalf>>,
    frame: fastwebsockets::Frame<'_>,
) -> std::result::Result<(), fastwebsockets::WebSocketError> {
    let mut writer = write_half.lock().await;
    writer.write_frame(frame).await
}
/// Receiving side of a streamer connection.
///
/// Yields decoded responses through `recv` and publishes connection state
/// changes on a watch channel obtainable through `events`.
pub struct ReadHalf {
// Fragment-collecting reader over the websocket's read side.
read_half: WsReadHalf,
// Writer shared with `WriteHalf`; `recv` hands it to `read_frame` as the
// callback used to write frames while reading.
write_half: Arc<Mutex<WsWriteHalf>>,
// Sender side of the connection-event watch channel.
events_tx: watch::Sender<ConnectionEvent>,
}
/// Opaque `Debug` output: prints the type name only, without any field.
impl std::fmt::Debug for ReadHalf {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ReadHalf");
        out.finish_non_exhaustive()
    }
}
impl ReadHalf {
pub async fn recv(&mut self) -> Result<StreamerResponse> {
let write_half = self.write_half.clone();
let mut send_fn = move |frame| write_one(write_half.clone(), frame);
loop {
let frame = match self.read_half.read_frame(&mut send_fn).await {
Ok(f) => f,
Err(e) => {
self.events_tx.send_replace(ConnectionEvent::Disconnected(
DisconnectReason::Transport(e.to_string()),
));
return Err(e.into());
}
};
if frame.opcode == fastwebsockets::OpCode::Text {
let raw_response: RawStreamerResponse = match serde_json::from_slice(&frame.payload)
{
Ok(r) => r,
Err(e) => {
self.events_tx.send_replace(ConnectionEvent::StreamError {
message: e.to_string(),
});
return Err(Error::Codec {
context: "streamer response frame".to_string(),
reason: e.to_string(),
});
}
};
let response = StreamerResponse::try_from(raw_response)?;
classify_and_emit(&self.events_tx, &response);
return Ok(response);
}
}
}
pub fn events(&self) -> watch::Receiver<ConnectionEvent> {
self.events_tx.subscribe()
}
}
/// Publishes connection events derived from a streamer response.
///
/// Only `StreamerResponse::Response` is inspected; every other variant is
/// ignored. For each contained response, in order:
///
/// * code `Ok` on the admin login command publishes `LoggedIn`;
/// * code `LoginDenied`, `CloseConnection` or `StopStreaming` publishes
///   `Disconnected` with the matching reason and the response message;
/// * any other combination publishes nothing.
fn classify_and_emit(events_tx: &watch::Sender<ConnectionEvent>, response: &StreamerResponse) {
    let responses = match response {
        StreamerResponse::Response(responses) => responses,
        _ => return,
    };

    for item in responses {
        let login_ack = item.service == Service::Admin && item.command == StreamerCommand::Login;
        let event = match item.content.code {
            ResponseCode::Ok if login_ack => ConnectionEvent::LoggedIn,
            ResponseCode::LoginDenied => ConnectionEvent::Disconnected(
                DisconnectReason::LoginDenied(item.content.message.clone()),
            ),
            ResponseCode::CloseConnection => ConnectionEvent::Disconnected(
                DisconnectReason::ServerClose(item.content.message.clone()),
            ),
            ResponseCode::StopStreaming => ConnectionEvent::Disconnected(
                DisconnectReason::StopStreaming(item.content.message.clone()),
            ),
            // Nothing to report for this response.
            _ => continue,
        };
        events_tx.send_replace(event);
    }
}
/// Sending side of a streamer connection.
///
/// Cloning yields another handle: the writer and the request id counter are
/// behind `Arc`s and are therefore shared between clones.
#[derive(Clone)]
pub struct WriteHalf {
// Mutex-guarded websocket writer, shared with `ReadHalf`.
write_half: Arc<Mutex<WsWriteHalf>>,
// Copied into every request payload as `schwab_client_customer_id`.
customer_id: CustomerId,
// Copied into every request payload as `schwab_client_correlation_id`.
correlation_id: String,
// Sent as `schwab_client_channel` in the login request.
channel: String,
// Sent as `schwab_client_function_id` in the login request.
function_id: String,
// Counter supplying the `request_id` of each outgoing request.
request_id: Arc<AtomicU64>,
// Source of the access token used by `login`.
token_provider: Arc<dyn TokenProvider + Send + Sync>,
}
/// `Debug` output limited to `channel` and `function_id`; all other fields
/// are omitted from the output.
impl std::fmt::Debug for WriteHalf {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("WriteHalf");
        out.field("channel", &self.channel);
        out.field("function_id", &self.function_id);
        out.finish_non_exhaustive()
    }
}
impl WriteHalf {
    /// Sends the admin login request.
    ///
    /// Fetches an access token from the token provider and sends it together
    /// with this connection's channel and function id.
    ///
    /// # Errors
    ///
    /// Returns the token provider's error, or any error from [`Self::send`].
    pub async fn login(&self) -> Result<()> {
        let authorization = self.token_provider.access_token().await?;
        let login = admin::Login {
            authorization,
            schwab_client_channel: self.channel.clone(),
            schwab_client_function_id: self.function_id.clone(),
        };
        self.send(login).await
    }

    /// Sends the admin logout request.
    pub async fn logout(&self) -> Result<()> {
        self.send(admin::Logout).await
    }

    /// Creates a subscription request over `level_one::equities::Field`.
    pub fn equities(&self) -> SubscribeRequest<'_, level_one::equities::Field> {
        SubscribeRequest::new(self)
    }

    /// Creates a subscription request over `level_one::options::Field`.
    pub fn options(&self) -> SubscribeRequest<'_, level_one::options::Field> {
        SubscribeRequest::new(self)
    }

    /// Creates a subscription request over `level_one::futures::Field`.
    pub fn futures(&self) -> SubscribeRequest<'_, level_one::futures::Field> {
        SubscribeRequest::new(self)
    }

    /// Creates a subscription request over `level_one::futures_options::Field`.
    pub fn futures_options(&self) -> SubscribeRequest<'_, level_one::futures_options::Field> {
        SubscribeRequest::new(self)
    }

    /// Creates a subscription request over `level_one::forex::Field`.
    pub fn forex(&self) -> SubscribeRequest<'_, level_one::forex::Field> {
        SubscribeRequest::new(self)
    }

    /// Creates a subscription request over `book::nyse::Field`.
    pub fn nyse_book(&self) -> SubscribeRequest<'_, book::nyse::Field> {
        SubscribeRequest::new(self)
    }

    /// Creates a subscription request over `book::nasdaq::Field`.
    pub fn nasdaq_book(&self) -> SubscribeRequest<'_, book::nasdaq::Field> {
        SubscribeRequest::new(self)
    }

    /// Creates a subscription request over `book::options::Field`.
    pub fn options_book(&self) -> SubscribeRequest<'_, book::options::Field> {
        SubscribeRequest::new(self)
    }

    /// Creates a subscription request over `chart::equity::Field`.
    pub fn chart_equity(&self) -> SubscribeRequest<'_, chart::equity::Field> {
        SubscribeRequest::new(self)
    }

    /// Creates a subscription request over `chart::futures::Field`.
    pub fn chart_futures(&self) -> SubscribeRequest<'_, chart::futures::Field> {
        SubscribeRequest::new(self)
    }

    /// Creates a subscription request over `screener::equity::Field`.
    pub fn screener_equity(&self) -> SubscribeRequest<'_, screener::equity::Field> {
        SubscribeRequest::new(self)
    }

    /// Creates a subscription request over `screener::option::Field`.
    pub fn screener_option(&self) -> SubscribeRequest<'_, screener::option::Field> {
        SubscribeRequest::new(self)
    }

    /// Creates a subscription request over `account_activity::Field`.
    pub fn account_activity(&self) -> SubscribeRequest<'_, account_activity::Field> {
        SubscribeRequest::new(self)
    }

    /// Wraps `request` in a request payload, serializes it to JSON and writes
    /// it to the websocket as a single text frame.
    ///
    /// Each call takes the next value of the shared request id counter.
    ///
    /// # Errors
    ///
    /// Returns `Error::Codec` if serialization fails, or the converted
    /// websocket error if the frame cannot be written.
    pub(crate) async fn send<T: Into<StreamerRequest>>(&self, request: T) -> Result<()> {
        let inner: StreamerRequest = request.into();
        let request_id = self.request_id.fetch_add(1, Ordering::Relaxed);

        let envelope = RequestPayload {
            request_id,
            service: inner.service,
            command: inner.command,
            parameters: inner.parameters,
            schwab_client_customer_id: self.customer_id.clone(),
            schwab_client_correlation_id: self.correlation_id.clone(),
        };

        let json = match serde_json::to_string(&envelope) {
            Ok(json) => json,
            Err(err) => {
                return Err(Error::Codec {
                    context: "streamer request envelope".to_string(),
                    reason: err.to_string(),
                });
            }
        };

        let payload = fastwebsockets::Payload::Borrowed(json.as_bytes());
        let frame = fastwebsockets::Frame::text(payload);
        write_one(Arc::clone(&self.write_half), frame).await?;
        Ok(())
    }
}
// Unit tests covering streamer-info validation, event classification,
// websocket scheme checking and error retry classification. None of them
// opens a network connection.
#[cfg(test)]
mod tests {
use super::*;
use crate::streamer::events::{ConnectionEvent, DisconnectReason};
use crate::streamer::protocol::{ResponseCode, Service, StreamerCommand};
use crate::streamer::response::{ResponseContent, ResponsePayload};
// Builds a single-entry admin-service response with the given code, command
// and message.
fn response(code: ResponseCode, command: StreamerCommand, msg: &str) -> StreamerResponse {
StreamerResponse::Response(vec![ResponsePayload {
request_id: 1,
service: Service::Admin,
timestamp: 1,
command,
schwab_client_correlation_id: "x".into(),
content: ResponseContent {
code,
message: msg.into(),
},
}])
}
// Returns a `StreamerInfo` with every field populated; individual tests blank
// out one field at a time.
fn full_streamer_info() -> StreamerInfo {
StreamerInfo {
streamer_socket_url: Some("wss://streamer-api.schwab.com/ws".into()),
schwab_client_customer_id: Some(CustomerId::from("CUSTID")),
schwab_client_correlation_id: Some("abc-123".into()),
schwab_client_channel: Some("N9".into()),
schwab_client_function_id: Some("APIAPP".into()),
}
}
// A fully populated `StreamerInfo` validates and keeps its field values.
#[test]
fn validates_complete_streamer_info() {
let validated =
ValidatedStreamerInfo::try_from(full_streamer_info()).expect("complete info validates");
assert_eq!(validated.socket_url, "wss://streamer-api.schwab.com/ws");
assert_eq!(validated.correlation_id, "abc-123");
assert_eq!(validated.channel, "N9");
assert_eq!(validated.function_id, "APIAPP");
}
// Each of the following five tests removes one field and checks that the
// validation error names exactly that field.
#[test]
fn missing_socket_url_reports_field() {
let mut info = full_streamer_info();
info.streamer_socket_url = None;
match ValidatedStreamerInfo::try_from(info) {
Err(Error::InvalidPreference { field, .. }) => {
assert_eq!(field, "streamerSocketUrl");
}
other => panic!("expected InvalidPreference, got {other:?}"),
}
}
#[test]
fn missing_customer_id_reports_field() {
let mut info = full_streamer_info();
info.schwab_client_customer_id = None;
match ValidatedStreamerInfo::try_from(info) {
Err(Error::InvalidPreference { field, .. }) => {
assert_eq!(field, "schwabClientCustomerId");
}
other => panic!("expected InvalidPreference, got {other:?}"),
}
}
#[test]
fn missing_correlation_id_reports_field() {
let mut info = full_streamer_info();
info.schwab_client_correlation_id = None;
match ValidatedStreamerInfo::try_from(info) {
Err(Error::InvalidPreference { field, .. }) => {
assert_eq!(field, "schwabClientCorrelId");
}
other => panic!("expected InvalidPreference, got {other:?}"),
}
}
#[test]
fn missing_channel_reports_field() {
let mut info = full_streamer_info();
info.schwab_client_channel = None;
match ValidatedStreamerInfo::try_from(info) {
Err(Error::InvalidPreference { field, .. }) => {
assert_eq!(field, "schwabClientChannel");
}
other => panic!("expected InvalidPreference, got {other:?}"),
}
}
#[test]
fn missing_function_id_reports_field() {
let mut info = full_streamer_info();
info.schwab_client_function_id = None;
match ValidatedStreamerInfo::try_from(info) {
Err(Error::InvalidPreference { field, .. }) => {
assert_eq!(field, "schwabClientFunctionId");
}
other => panic!("expected InvalidPreference, got {other:?}"),
}
}
// An `Ok` response to the admin login command publishes `LoggedIn`.
#[test]
fn login_ok_emits_logged_in() {
let (tx, mut rx) = watch::channel(ConnectionEvent::Connected);
classify_and_emit(&tx, &response(ResponseCode::Ok, StreamerCommand::Login, ""));
assert!(rx.has_changed().unwrap());
assert_eq!(*rx.borrow_and_update(), ConnectionEvent::LoggedIn);
}
// `LoginDenied` publishes `Disconnected(LoginDenied)` carrying the message.
#[test]
fn login_denied_emits_disconnected() {
let (tx, mut rx) = watch::channel(ConnectionEvent::Connected);
classify_and_emit(
&tx,
&response(
ResponseCode::LoginDenied,
StreamerCommand::Login,
"token expired",
),
);
match rx.borrow_and_update().clone() {
ConnectionEvent::Disconnected(DisconnectReason::LoginDenied(msg)) => {
assert!(msg.contains("token expired"), "msg = {msg}");
}
other => panic!("expected Disconnected(LoginDenied), got {other:?}"),
}
}
// `CloseConnection` publishes `Disconnected(ServerClose)` regardless of command.
#[test]
fn close_connection_emits_disconnected_server_close() {
let (tx, mut rx) = watch::channel(ConnectionEvent::Connected);
classify_and_emit(
&tx,
&response(
ResponseCode::CloseConnection,
StreamerCommand::Subs,
"max connections",
),
);
assert!(matches!(
*rx.borrow_and_update(),
ConnectionEvent::Disconnected(DisconnectReason::ServerClose(_))
));
}
// `StopStreaming` publishes `Disconnected(StopStreaming)`.
#[test]
fn stop_streaming_emits_disconnected_stop_streaming() {
let (tx, mut rx) = watch::channel(ConnectionEvent::Connected);
classify_and_emit(
&tx,
&response(
ResponseCode::StopStreaming,
StreamerCommand::Subs,
"inactivity",
),
);
assert!(matches!(
*rx.borrow_and_update(),
ConnectionEvent::Disconnected(DisconnectReason::StopStreaming(_))
));
}
// An `Ok` response that is not an admin login leaves the channel untouched.
#[test]
fn non_admin_ok_response_does_not_emit() {
let (tx, rx) = watch::channel(ConnectionEvent::Connected);
let r = StreamerResponse::Response(vec![ResponsePayload {
request_id: 1,
service: Service::LevelOneEquities,
timestamp: 1,
command: StreamerCommand::Subs,
schwab_client_correlation_id: "x".into(),
content: ResponseContent {
code: ResponseCode::Ok,
message: "".into(),
},
}]);
classify_and_emit(&tx, &r);
assert!(!rx.has_changed().unwrap());
}
// Variants other than `Response` are ignored by `classify_and_emit`.
#[test]
fn data_payload_does_not_emit() {
let (tx, rx) = watch::channel(ConnectionEvent::Connected);
let r = StreamerResponse::Notify(vec![]);
classify_and_emit(&tx, &r);
assert!(!rx.has_changed().unwrap());
}
// `wss` selects TLS whether or not insecure connections are allowed.
#[test]
fn wss_is_accepted_in_both_modes() {
assert_eq!(
check_websocket_scheme(Some("wss"), false).unwrap(),
WsTransport::Tls
);
assert_eq!(
check_websocket_scheme(Some("wss"), true).unwrap(),
WsTransport::Tls
);
}
// Plain `ws` is rejected, and reported by name, when insecure is disallowed.
#[test]
fn ws_is_rejected_when_insecure_disallowed() {
match check_websocket_scheme(Some("ws"), false) {
Err(WebSocketError::UnsupportedScheme(scheme)) => assert_eq!(scheme, "ws"),
other => panic!("expected UnsupportedScheme(ws), got {other:?}"),
}
}
// Plain `ws` selects the plain transport when insecure is allowed.
#[test]
fn ws_is_accepted_when_insecure_permitted() {
assert_eq!(
check_websocket_scheme(Some("ws"), true).unwrap(),
WsTransport::Plain
);
}
// Non-websocket schemes (including the empty string) are rejected in both modes.
#[test]
fn other_schemes_are_always_rejected() {
for scheme in ["http", "https", "ftp", "file", ""] {
assert!(
matches!(
check_websocket_scheme(Some(scheme), true).unwrap_err(),
WebSocketError::UnsupportedScheme(_)
),
"scheme {scheme:?} should be rejected with insecure mode on"
);
assert!(
matches!(
check_websocket_scheme(Some(scheme), false).unwrap_err(),
WebSocketError::UnsupportedScheme(_)
),
"scheme {scheme:?} should be rejected with insecure mode off"
);
}
}
// A missing scheme is rejected with an empty scheme string in the error.
#[test]
fn no_scheme_is_rejected() {
assert!(matches!(
check_websocket_scheme(None, true).unwrap_err(),
WebSocketError::UnsupportedScheme(s) if s.is_empty()
));
assert!(matches!(
check_websocket_scheme(None, false).unwrap_err(),
WebSocketError::UnsupportedScheme(s) if s.is_empty()
));
}
// Scheme matching is exact: differently-cased spellings of `wss` are rejected.
#[test]
fn case_sensitive_scheme_match() {
assert!(check_websocket_scheme(Some("Wss"), false).is_err(),);
assert!(check_websocket_scheme(Some("WSS"), false).is_err(),);
}
// Connect, TLS-stream, handshake and runtime errors are classified retryable.
#[test]
fn is_retryable_classifies_transport_failures_as_retryable() {
assert!(WebSocketError::Connect(std::io::Error::other("x")).is_retryable());
assert!(WebSocketError::TlsStream(std::io::Error::other("x")).is_retryable());
assert!(
WebSocketError::Handshake(fastwebsockets::WebSocketError::ConnectionClosed)
.is_retryable()
);
assert!(
WebSocketError::Runtime(fastwebsockets::WebSocketError::ConnectionClosed)
.is_retryable()
);
}
// Missing host, unsupported scheme and invalid domain are classified terminal.
#[test]
fn is_retryable_classifies_config_failures_as_terminal() {
assert!(!WebSocketError::MissingHost.is_retryable());
assert!(!WebSocketError::UnsupportedScheme("ws".to_string()).is_retryable());
assert!(
!WebSocketError::InvalidDomain(
rustls_pki_types::ServerName::try_from("not a dns name").unwrap_err()
)
.is_retryable()
);
}
// `Error::is_retryable` gives the same answer as the wrapped `WebSocketError`.
#[test]
fn error_is_retryable_delegates_to_websocket_error() {
let terminal = Error::WebSocket(WebSocketError::UnsupportedScheme("ws".to_string()));
assert!(!terminal.is_retryable());
let transient = Error::WebSocket(WebSocketError::Connect(std::io::Error::other(
"conn refused",
)));
assert!(transient.is_retryable());
}
}