use std::{
convert::Infallible,
future::{Future, IntoFuture},
marker::PhantomData,
pin::{Pin, pin},
str::FromStr,
time::Duration,
};
use futures_util::{
FutureExt, SinkExt, StreamExt,
stream::{SplitSink, SplitStream},
};
use tokio::{net::TcpStream, sync::watch, time};
use tokio_tungstenite::{
MaybeTlsStream, WebSocketStream, connect_async,
tungstenite::{Message, Utf8Bytes},
};
use tower::{Service, util::ServiceExt};
use tracing::{error, info, trace, warn};
use crate::websocket::{Request, Response, scanner};
/// Write half of a split client WebSocket stream; a sink that accepts [`Message`]s.
type WsSink = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
/// Read half of a split client WebSocket stream.
type WsStream = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
/// One item as yielded by `WsStream::next()`: `None` once the stream has ended,
/// otherwise either a received message or a tungstenite error.
type WsMessage = Option<Result<Message, tokio_tungstenite::tungstenite::Error>>;
/// Builds a [`Client`] targeting `url`, using the default [`Config`].
///
/// `make_service` is invoked once for every connection that gets established and must
/// produce the service that handles the requests received on that connection.
/// This function only assembles the client; no I/O happens here.
pub fn client<M, S>(url: impl Into<String>, make_service: M) -> Client<M, S>
where
    M: Service<(), Error = Infallible, Response = S>,
    <M as Service<()>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    let url = url.into();
    let config = Config::default();
    Client {
        url,
        make_service,
        config,
        _marker: PhantomData,
    }
}
/// Reconnection settings used by the WebSocket client.
#[derive(Debug, Clone)]
pub struct Config {
/// Number of connection attempts made in one connect cycle before giving up
/// with `Error::FailedConnect`.
pub max_reconnect_attempts: usize,
/// Delay before the first retry; later retries scale this value up exponentially.
pub initial_reconnect_delay: Duration,
/// Upper bound applied to the computed delay between retries.
pub max_reconnect_delay: Duration,
/// Pause between finishing a connection that ended with a reconnect request
/// and connecting to the new URL.
pub reconnect_grace_period: Duration,
}
impl Default for Config {
    /// Defaults: 5 connection attempts, backoff starting at 1 second and capped at
    /// 30 seconds, and a 1 second grace period before following a reconnect request.
    fn default() -> Self {
        let one_second = Duration::from_secs(1);
        let thirty_seconds = Duration::from_secs(30);
        Self {
            max_reconnect_attempts: 5,
            initial_reconnect_delay: one_second,
            max_reconnect_delay: thirty_seconds,
            reconnect_grace_period: one_second,
        }
    }
}
/// A WebSocket client that has been configured but not yet started.
///
/// `M` is the factory called once per established connection; `S` is the service
/// it produces, which handles the requests received on that connection.
pub struct Client<M, S> {
// Endpoint to connect to first; must start with `ws://` or `wss://`.
url: String,
// Called with `()` after each successful connect to obtain a fresh `S`.
make_service: M,
// Retry/backoff and reconnect settings.
config: Config,
// Ties the otherwise unused `S` parameter to the struct.
_marker: PhantomData<S>,
}
impl<M, S> Client<M, S>
where
M: Service<(), Error = Infallible, Response = S>,
<M as Service<()>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
{
pub fn with_config(mut self, config: Config) -> Self {
self.config = config;
self
}
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<M, S, F>
where
F: Future<Output = ()> + Send + 'static,
{
WithGracefulShutdown {
url: self.url,
make_service: self.make_service,
config: self.config,
signal,
_marker: PhantomData,
}
}
async fn run(self) -> Result<(), Error> {
let Self {
url,
mut make_service,
config,
_marker,
} = self;
if !url.starts_with("ws://") && !url.starts_with("wss://") {
return Err(Error::InvalidUrl(url));
}
let mut current_url = url;
loop {
let (mut write, read) = try_accept(¤t_url, &config).await?;
trace!("websocket connection established to twitch eventsub");
let mut svc = make_service
.call(())
.await
.expect("make_service error is Infallible");
let recv_task = handle_connection(&mut write, read, &mut svc);
match recv_task.await? {
ConnectionResult::Url(url) => {
trace!("reconnect requested, switching to new url");
current_url = url;
time::sleep(config.reconnect_grace_period).await;
}
ConnectionResult::Closed => {
trace!("connection closed");
return Ok(());
}
}
}
}
}
/// Lets a [`Client`] be `.await`ed directly; awaiting it runs the client to completion.
impl<M, S> IntoFuture for Client<M, S>
where
    M: Service<(), Error = Infallible, Response = S> + Clone + Send + 'static,
    <M as Service<()>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    type Output = Result<(), Error>;
    type IntoFuture = Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>;

    /// Boxes the client's run loop; nothing executes until the future is polled.
    fn into_future(self) -> Self::IntoFuture {
        let run_loop = self.run();
        Box::pin(run_loop)
    }
}
/// A configured WebSocket client paired with a shutdown signal.
///
/// Holds the same state as `Client` plus `signal`, a future whose completion
/// requests shutdown.
pub struct WithGracefulShutdown<M, S, F> {
// Endpoint to connect to first; must start with `ws://` or `wss://`.
url: String,
// Called with `()` after each successful connect to obtain a fresh `S`.
make_service: M,
// Retry/backoff and reconnect settings.
config: Config,
// Future whose completion triggers shutdown.
signal: F,
// Ties the otherwise unused `S` parameter to the struct.
_marker: PhantomData<S>,
}
impl<M, S, F> WithGracefulShutdown<M, S, F>
where
M: Service<(), Error = Infallible, Response = S>,
<M as Service<()>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
F: Future<Output = ()> + Send + 'static,
{
async fn run(self) -> Result<(), Error> {
let Self {
url,
mut make_service,
config,
signal,
_marker,
} = self;
if !url.starts_with("ws://") && !url.starts_with("wss://") {
return Err(Error::InvalidUrl(url));
}
let (shutdown_tx, signal_rx) = watch::channel(());
tokio::spawn(async move {
signal.await;
trace!("received shutdown signal, initiating graceful shutdown");
drop(signal_rx);
});
let mut current_url = url;
let mut signal_closed = pin!(shutdown_tx.closed().fuse());
loop {
let (mut write, read) = tokio::select! {
conn = try_accept(¤t_url, &config) => conn?,
_ = &mut signal_closed => {
trace!("shutdown signal received, stopping connection");
return Ok(());
}
};
trace!("websocket connection established to twitch eventsub");
let mut svc = make_service
.call(())
.await
.expect("make_service error is Infallible");
let recv_task = handle_connection(&mut write, read, &mut svc);
tokio::select! {
result = recv_task => {
match result? {
ConnectionResult::Url(url) => {
trace!("reconnect requested, switching to new url");
current_url = url;
time::sleep(config.reconnect_grace_period).await;
},
ConnectionResult::Closed => {
trace!("connection closed");
return Ok(());
},
}
}
_ = &mut signal_closed => {
trace!("shutdown signal received");
let _ = write.close().await;
return Ok(())
}
}
}
}
}
/// Lets a [`WithGracefulShutdown`] be `.await`ed directly; awaiting it runs the client
/// to completion.
impl<M, S, F> IntoFuture for WithGracefulShutdown<M, S, F>
where
    M: Service<(), Error = Infallible, Response = S> + Clone + Send + 'static,
    <M as Service<()>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
    F: Future<Output = ()> + Send + 'static,
{
    type Output = Result<(), Error>;
    type IntoFuture = Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>;

    /// Boxes the run loop; nothing executes until the future is polled.
    fn into_future(self) -> Self::IntoFuture {
        let run_loop = self.run();
        Box::pin(run_loop)
    }
}
async fn try_accept(url: &str, config: &Config) -> Result<(WsSink, WsStream), Error> {
let mut attempts = 0;
loop {
attempts += 1;
match connect_async(url).await {
Ok((ws_stream, _)) => {
info!("successfully conncted to websocket");
return Ok(ws_stream.split());
}
Err(e) => {
if attempts >= config.max_reconnect_attempts {
error!("failed to connect after {} attempts", attempts);
return Err(Error::FailedConnect { attempts });
}
let delay = calculate_backoff_delay(attempts, config);
warn!(
"connection attempt {}/{} failed: {}, retrying in {:?}",
attempts, config.max_reconnect_attempts, e, delay
);
time::sleep(delay).await;
}
}
}
}
/// How a single connection ended.
enum ConnectionResult {
/// A reconnect was requested; the payload is the URL to connect to next.
Url(String),
/// The connection was closed (a close frame was received or the stream ended).
Closed,
}
/// Reads messages from `read` one at a time and dispatches each to `handle_messages`
/// until one of them ends the connection.
///
/// Returns the `ConnectionResult` reported by `handle_messages`, or the first error
/// it produces.
async fn handle_connection<S>(
    write: &mut WsSink,
    mut read: WsStream,
    svc: &mut S,
) -> Result<ConnectionResult, Error>
where
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    loop {
        let next = read.next().await;
        match handle_messages(write, svc, next).await? {
            Some(outcome) => return Ok(outcome),
            // The message was handled and the connection stays open.
            None => continue,
        }
    }
}
async fn handle_messages<S>(
write: &mut WsSink,
svc: &mut S,
msg: WsMessage,
) -> Result<Option<ConnectionResult>, Error>
where
S: Service<Request, Response = Response, Error = Infallible>,
{
match msg {
Some(Ok(Message::Text(text))) => match handle_text_message(write, svc, text).await {
Ok(Some(url)) => {
trace!("received reconnect request, closing current connection");
let _ = write.close().await;
Ok(Some(ConnectionResult::Url(url)))
}
Ok(None) => Ok(None),
Err(e) => {
warn!("error handling text message: {}, closing connection", e);
let _ = write.close().await;
Err(e)
}
},
Some(Ok(Message::Ping(ping))) => {
trace!("received ping, sending pong");
write.send(Message::Pong(ping)).await?;
Ok(None)
}
Some(Ok(Message::Close(frame))) => {
trace!("received close frame: {:?}", frame);
let _ = write.close().await;
Ok(Some(ConnectionResult::Closed))
}
Some(Err(e)) => {
error!("websocket error: {}", e);
let _ = write.close().await;
Err(Error::WebSocket(e))
}
Some(Ok(Message::Pong(_) | Message::Binary(_) | Message::Frame(_))) => {
trace!("ignoring non-text message");
Ok(None)
}
None => {
trace!("websocket stream ended");
Ok(Some(ConnectionResult::Closed))
}
}
}
/// Parses a text frame into a `Request` and acts on it.
///
/// Keepalives are answered with an empty pong, reconnect requests yield the new URL,
/// and every other request is passed to `svc`, whose response may itself ask for a
/// reconnect.
///
/// Returns `Ok(Some(url))` when the caller should reconnect to `url`, `Ok(None)`
/// otherwise.
///
/// # Errors
///
/// Fails if the text cannot be parsed into a `Request` or if sending the pong fails.
async fn handle_text_message<S>(
    write: &mut WsSink,
    svc: &mut S,
    text: Utf8Bytes,
) -> Result<Option<String>, Error>
where
    S: Service<Request, Response = Response, Error = Infallible>,
{
    let req = Request::from_str(&text)?;

    if req.is_keepalive() {
        trace!("received keepalive, sending pong");
        write.send(Message::Pong("".into())).await?;
        Ok(None)
    } else if req.is_reconnect() {
        info!("server requested reconnect to: {}", req.get_reconnect_url());
        Ok(Some(req.get_reconnect_url().to_string()))
    } else {
        svc.ready().await.expect("service error is Infallible");
        let resp = svc.call(req).await.expect("service error is Infallible");
        if resp.is_reconnect() {
            trace!("handler requested reconnect");
            Ok(resp.url)
        } else {
            Ok(None)
        }
    }
}
/// Computes the delay to wait after the `attempts`-th failed connection attempt.
///
/// The delay is `initial_reconnect_delay * 2^(attempts - 1)`, with the exponent clamped
/// to at most 5 (a factor of 32) and the result capped at `max_reconnect_delay`.
/// `attempts == 0` is treated like the first attempt.
fn calculate_backoff_delay(attempts: usize, config: &Config) -> Duration {
    // Clamped to 0..=5, so the cast is lossless and the shift cannot overflow.
    let exponent = attempts.saturating_sub(1).min(5) as u32;
    let factor = 1u32 << exponent;
    let scaled = config.initial_reconnect_delay.saturating_mul(factor);
    std::cmp::min(scaled, config.max_reconnect_delay)
}
/// Errors produced while running the WebSocket client.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// An error reported by the underlying tungstenite WebSocket, for example while
/// reading from the stream or sending a frame.
#[error("WebSocket error: {0}")]
WebSocket(#[from] tokio_tungstenite::tungstenite::Error),
/// Connecting did not succeed within the configured number of attempts.
#[error("Connection failed after {attempts} attempts")]
FailedConnect { attempts: usize },
/// A `scanner::ScanError`, converted via `From`; presumably raised when a text
/// message cannot be parsed into a `Request` (verify against `Request::from_str`).
#[error("Failed to parse message: {0}")]
ParseError(#[from] scanner::ScanError),
/// The configured URL does not start with `ws://` or `wss://`; carries the URL.
#[error("Invalid Websocket URL: {0}")]
InvalidUrl(String),
}