use {super::*, crate::*};
use std::future::Future;
use std::sync::Arc;
use futures::prelude::*;
use tokio::prelude::*;
use tokio::prelude::{AsyncRead, AsyncWrite};
use tokio::sync::Mutex;
use tokio::time::Duration;
/// How long the read side may stay silent before the keep-alive task sends a `PING`.
const PING_INACTIVITY: Duration = Duration::from_secs(45);
/// How long the keep-alive task waits for the `PONG` reply before reporting a timeout.
const PING_WINDOW: Duration = Duration::from_secs(10);

/// A pinned, heap-allocated, `Send` future living for `'a` and resolving to `T`.
type BoxFuture<'a, T> = std::pin::Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// The owned future produced by a `Connector`'s connect function.
type ConnectFuture<IO> = BoxFuture<'static, std::io::Result<IO>>;
/// A cloneable, type-erased factory that opens a fresh `IO` connection each time it is called.
///
/// `Runner::run_to_completion` calls it once per run, which is how `Runner::run_with_retry`
/// reconnects using the same `Connector`.
pub struct Connector<IO> {
    connect: Arc<dyn Fn() -> ConnectFuture<IO> + 'static + Send + Sync>,
}
impl<IO> Connector<IO>
where
    IO: AsyncRead + AsyncWrite,
    IO: Send + Sync + 'static,
{
    /// Wrap `connect_func` as a connector.
    ///
    /// `connect_func` is invoked once per connection attempt. The future it returns must
    /// resolve to the connected `IO`, or to the `std::io::Error` that prevented connecting.
    ///
    /// The future is stored boxed as a `dyn Future + Send`. It therefore only needs to be
    /// `Send + 'static`; it does not have to be `Sync`.
    pub fn new<F, R>(connect_func: F) -> Self
    where
        F: Fn() -> R + Send + Sync + 'static,
        R: Future<Output = Result<IO, std::io::Error>> + Send + 'static,
    {
        Self {
            connect: Arc::new(move || Box::pin(connect_func())),
        }
    }
}
impl<IO> Clone for Connector<IO> {
fn clone(&self) -> Self {
Self {
connect: Arc::clone(&self.connect),
}
}
}
impl<IO> std::fmt::Debug for Connector<IO> {
    /// Prints only the type name; the wrapped `dyn Fn` has no `Debug` representation.
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        fmt.write_str("Connector")
    }
}
/// Namespace for ready-made retry checks to pass to `Runner::run_with_retry`.
///
/// The type carries no data; its associated functions are the strategies.
#[derive(Clone, Copy, Default, Debug)]
pub struct RetryStrategy;
impl RetryStrategy {
    /// Retry after every outcome except a user-requested cancel (`Status::Canceled`).
    ///
    /// Errors are not propagated: an `Err` result also leads to a retry.
    pub async fn immediately(result: Result<Status, Error>) -> Result<bool, Error> {
        Ok(!matches!(result, Ok(Status::Canceled)))
    }

    /// Retry only when the run ended with `Status::Timeout`; any other status stops.
    ///
    /// # Errors
    /// An `Err` result is returned unchanged, which makes `run_with_retry` stop with it.
    pub async fn on_timeout(result: Result<Status, Error>) -> Result<bool, Error> {
        Ok(matches!(result?, Status::Timeout))
    }

    /// Retry only when the run ended with an error; any `Status` stops retrying.
    pub async fn on_error(result: Result<Status, Error>) -> Result<bool, Error> {
        Ok(result.is_err())
    }
}
/// Drives one connection at a time: reads and dispatches incoming lines, writes queued
/// outgoing data, answers `PING` events and reacts to the keep-alive timeout.
///
/// Constructed together with its `Control` handle (see `Runner::new`).
pub struct Runner {
// Decoded messages are passed to its `dispatch`; also provides the internal PING/PONG subscriptions.
dispatcher: Dispatcher,
// Receiving end of the outgoing-data channel whose sender backs `writer`.
receiver: Rx,
// Outgoing writer; the run loop uses clones of it to send pongs and for the keep-alive task.
writer: Writer,
// Stop signal shared with the `Control` handle; when it fires the run loop ends with `Status::Canceled`.
abort: abort::Abort,
// Shared with the `Control` handle; notified when `run_to_completion` exits after a successful connect.
ready: Arc<tokio::sync::Notify>,
}
impl Runner {
pub fn new(dispatcher: Dispatcher) -> (Runner, Control) {
Self::new_with_rate_limit(dispatcher, RateLimit::default())
}
pub fn new_without_rate_limit(dispatcher: Dispatcher) -> (Runner, Control) {
let (sender, receiver) = mpsc::channel(64);
let stop = abort::Abort::default();
let writer = Writer::new(crate::encode::AsyncMpscWriter::new(sender));
let ready = Arc::new(tokio::sync::Notify::default());
let this = Self {
dispatcher,
receiver,
abort: stop.clone(),
writer: writer.clone(),
ready: ready.clone(),
};
let control = Control {
writer,
stop,
ready,
};
(this, control)
}
pub fn new_with_rate_limit(dispatcher: Dispatcher, rate_limit: RateLimit) -> (Runner, Control) {
let (sender, receiver) = mpsc::channel(64);
let stop = abort::Abort::default();
let writer = Writer::new(crate::encode::AsyncMpscWriter::new(sender))
.with_rate_limiter(Arc::new(Mutex::new(rate_limit)));
let ready = Arc::new(tokio::sync::Notify::default());
let this = Self {
dispatcher,
receiver,
abort: stop.clone(),
writer: writer.clone(),
ready: ready.clone(),
};
let control = Control {
writer,
stop,
ready,
};
(this, control)
}
pub async fn run_to_completion<IO>(&mut self, connector: Connector<IO>) -> Result<Status, Error>
where
IO: AsyncRead + AsyncWrite,
IO: Unpin + Send + Sync + 'static,
{
let io = (connector.connect)().await.map_err(Error::Io)?;
let mut stream = tokio::io::BufStream::new(io);
let mut buffer = String::with_capacity(1024);
let mut ping = self
.dispatcher
.subscribe_internal::<crate::events::Ping>(true);
struct Token(Arc<tokio::sync::Notify>, Arc<tokio::sync::Notify>);
impl Drop for Token {
fn drop(&mut self) {
self.0.notify();
self.1.notify();
}
}
let restart = Arc::new(tokio::sync::Notify::default());
let _token = Token(restart.clone(), self.ready.clone());
let mut out = self.writer.clone();
let (mut check_timeout, timeout_delay, timeout_task) =
check_connection(restart, &self.dispatcher, out.clone());
loop {
tokio::select! {
_ = self.abort.wait_for() => {
log::debug!("received signal from user to stop");
let _ = self.dispatcher.clear_subscriptions_all();
break Ok(Status::Canceled)
}
Some(msg) = ping.next() => {
if out.pong(&msg.token).await.is_err() {
log::debug!("cannot send pong");
break Ok(Status::Eof);
}
}
Ok(n) = &mut stream.read_line(&mut buffer) => {
if n == 0 {
log::info!("read 0 bytes. this is an EOF");
break Ok(Status::Eof)
}
let mut visited = false;
for msg in decode(&buffer) {
let msg = msg?;
log::trace!(target: "twitchchat::runner::read", "< {}", msg.raw.escape_debug());
self.dispatcher.dispatch(&msg);
visited = true;
}
if !visited {
log::warn!("twitch sent an incomplete message");
break Ok(Status::Eof)
}
buffer.clear();
let _ = check_timeout.send(()).await;
},
Some(data) = &mut self.receiver.next() => {
log::trace!(target: "twitchchat::runner::write", "> {}", std::str::from_utf8(&data).unwrap().escape_debug());
stream.write_all(&data).await?;
stream.flush().await?
},
_ = timeout_delay.notified() => {
log::warn!(target: "twitchchat::runner::timeout", "timeout detected, quitting loop");
drop(check_timeout);
timeout_task.await;
break Ok(Status::Timeout);
},
else => {
log::info!("all futures are dead. ending loop");
break Ok(Status::Eof)
}
}
}
}
pub async fn run_with_retry<IO, F, R>(
&mut self,
connector: Connector<IO>,
retry_check: F,
) -> Result<(), Error>
where
IO: AsyncRead + AsyncWrite,
IO: Unpin + Send + Sync + 'static,
F: Fn(Result<Status, Error>) -> R,
F: Send + Sync,
R: Future<Output = Result<bool, Error>> + Send + Sync + 'static,
R::Output: Send,
{
loop {
let res = self.run_to_completion(connector.clone()).await;
match retry_check(res).await {
Err(err) => break Err(err),
Ok(false) => break Ok(()),
Ok(true) => {}
}
self.dispatcher.reset_internal_subscriptions();
}
}
}
/// Spawn the keep-alive task for one connection via `tokio::task::spawn`.
///
/// Each time `PING_INACTIVITY` passes without a message arriving on the returned sender,
/// the task does the following:
/// * sends a `PING` through `writer`, using the current Unix time in seconds as the token,
/// * waits up to `PING_WINDOW` for a `Pong` event from `dispatcher`.
///
/// If the ping cannot be written, or no pong arrives in time, it notifies the returned
/// `Notify` and exits. It also exits when `restart` is notified.
///
/// Returns the activity sender, the timeout notifier and the task's join handle.
fn check_connection(
    restart: Arc<tokio::sync::Notify>,
    dispatcher: &Dispatcher,
    mut writer: Writer,
) -> (
    tokio::sync::mpsc::Sender<()>,
    Arc<tokio::sync::Notify>,
    impl Future,
) {
    use tokio::sync::{mpsc, Notify};

    let mut pong = dispatcher.subscribe_internal::<crate::events::Pong>(true);
    let timeout_notify = Arc::new(Notify::new());
    // One-slot channel carrying `()` activity signals into the task.
    let (tx, mut rx) = mpsc::channel(1);
    let timeout = timeout_notify.clone();

    let task = async move {
        loop {
            tokio::select! {
                _ = tokio::time::delay_for(PING_INACTIVITY) => {
                    log::debug!(target: "twitchchat::runner::timeout", "inactivity detected of {:?}, sending a ping", PING_INACTIVITY);
                    let ts = std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .expect("time to not go backwards")
                        .as_secs();
                    if writer.ping(&ts.to_string()).await.is_err() {
                        timeout.notify();
                        log::error!(target: "twitchchat::runner::timeout", "cannot send ping");
                        break;
                    }
                    // NOTE(review): `Ok(None)` (the pong subscription ended) is treated as a
                    // reply here; confirm that is intended.
                    if tokio::time::timeout(PING_WINDOW, pong.next())
                        .await
                        .is_err()
                    {
                        timeout.notify();
                        log::error!(target: "twitchchat::runner::timeout", "did not get a pong after {:?}", PING_WINDOW);
                        break;
                    }
                }
                // Activity: loop again, which starts a fresh inactivity delay.
                Some(..) = rx.next() => { }
                _ = restart.notified() => { break }
                else => { break }
            }
        }
    };
    (tx, timeout_notify, tokio::task::spawn(task))
}
impl std::fmt::Debug for Runner {
    /// Prints only the type name; none of the runner's fields are shown.
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        fmt.write_str("Runner")
    }
}