use super::router::{PropagateEvent, Response};
use crate::{
client::{Bot, Reqwest, Session},
context::Context,
either::Either,
enums::UpdateType,
errors::{EventErrorKind, HandlerError},
methods::GetUpdates,
types::Update,
Extensions, Request, RouterConfigured,
};
use backoff::{exponential::ExponentialBackoff, future::retry, SystemClock};
use futures_util::future::BoxFuture;
use std::{
future::{Future, IntoFuture},
sync::Arc,
};
use tokio::{
select,
sync::{mpsc, watch},
};
use tracing::{event, field, instrument, Level, Span};
/// Polling timeout that a default [`Builder`] starts with (passed to `getUpdates`).
pub const DEFAULT_POLLING_TIMEOUT: i64 = 30;
/// Value sent as the `limit` of every `getUpdates` request.
const GET_UPDATES_SIZE: u8 = 100;
/// Capacity of the channel between a bot's update listener and its dispatch task.
const CHANNEL_UPDATES_SIZE: usize = 100;
/// Routes Telegram updates to a propagator (typically the configured main router).
///
/// Build one with [`Dispatcher::builder`], then start it with [`Dispatcher::run_polling`]
/// (long polling via `getUpdates` for every configured bot) or
/// [`Dispatcher::run_no_polling`] (startup/shutdown events only, no update fetching).
#[derive(Clone)]
pub struct Dispatcher<
    Client = Reqwest,
    Propagator = RouterConfigured,
    Backoff = ExponentialBackoff<SystemClock>,
> {
    // Receives every update via `PropagateEvent::propagate_event`.
    propagator: Propagator,
    // Bots to poll; `run_polling` asserts this is non-empty.
    bots: Vec<Bot<Client>>,
    // Cloned into every `Request` built by `feed_update`.
    extensions: Extensions,
    // Cloned into every `Request` built by `feed_update`.
    context: Context,
    // Passed to `GetUpdates::timeout_option` when polling.
    polling_timeout: Option<i64>,
    // Retry policy for failed `getUpdates` requests; cloned for each batch.
    backoff: Backoff,
    // Passed to `GetUpdates::allowed_updates` when polling.
    allowed_updates: Vec<Box<str>>,
}
impl<Client, Propagator> Dispatcher<Client, Propagator>
where
Propagator: Default,
{
#[inline]
#[must_use]
pub fn builder() -> Builder<Client, Propagator> {
Builder::default()
}
}
/// Step-by-step configuration for a [`Dispatcher`]; finish with `build`.
///
/// Every field maps one-to-one onto the field of the same name in [`Dispatcher`].
pub struct Builder<Client, Propagator, BackoffType = ExponentialBackoff<SystemClock>> {
    // Propagator that will receive updates (set with `main_router`/`router`).
    propagator: Propagator,
    // Bots to poll (appended with `bot`/`bots`).
    bots: Vec<Bot<Client>>,
    // Keyed values cloned into every request.
    context: Context,
    // Typed values cloned into every request.
    extensions: Extensions,
    // `getUpdates` timeout; `Default` sets `Some(DEFAULT_POLLING_TIMEOUT)`.
    polling_timeout: Option<i64>,
    // Retry policy for failed `getUpdates` requests.
    backoff: BackoffType,
    // Update types requested from `getUpdates`.
    allowed_updates: Vec<Box<str>>,
}
impl<Client, Propagator> Default for Builder<Client, Propagator>
where
    Propagator: Default,
{
    /// An empty builder: default propagator, no bots, empty context and extensions,
    /// a `DEFAULT_POLLING_TIMEOUT` polling timeout, a default exponential backoff and
    /// no allowed-update filter.
    fn default() -> Self {
        let propagator = Propagator::default();
        Self {
            propagator,
            bots: Vec::new(),
            context: Context::new(),
            extensions: Extensions::new(),
            polling_timeout: Some(DEFAULT_POLLING_TIMEOUT),
            backoff: Default::default(),
            allowed_updates: Vec::new(),
        }
    }
}
impl<Client, Propagator, BackoffType> Builder<Client, Propagator, BackoffType> {
#[inline]
#[must_use]
pub fn with_backoff(mut self, backoff: BackoffType) -> Self {
self.backoff = backoff;
self
}
}
impl<Client, Propagator, BackoffType> Builder<Client, Propagator, BackoffType> {
    /// Sets the propagator (usually a configured main router) that receives every update.
    #[inline]
    #[must_use]
    pub fn main_router(mut self, val: Propagator) -> Self
    where
        Propagator: PropagateEvent<Client>,
    {
        self.propagator = val;
        self
    }

    /// Alias for [`Builder::main_router`].
    #[inline]
    #[must_use]
    pub fn router(self, val: Propagator) -> Self
    where
        Propagator: PropagateEvent<Client>,
    {
        self.main_router(val)
    }

    /// Appends one bot to the list of bots to poll.
    #[must_use]
    pub fn bot(mut self, val: Bot<Client>) -> Self {
        // Push in place instead of rebuilding the Vec through `chain().collect()`.
        self.bots.push(val);
        self
    }

    /// Appends every bot yielded by `val`, preserving order.
    #[must_use]
    pub fn bots(mut self, val: impl IntoIterator<Item = Bot<Client>>) -> Self {
        self.bots.extend(val);
        self
    }

    /// Inserts `val` into the context under `key`.
    ///
    /// The context is cloned into every request the dispatcher builds.
    #[must_use]
    pub fn context<T>(mut self, key: &'static str, val: T) -> Self
    where
        T: Clone + Send + Sync + 'static,
    {
        self.context.insert(key, val);
        self
    }

    /// Extends the context with `val` (via `Context::extend`).
    #[must_use]
    pub fn context_extend(mut self, val: Context) -> Self {
        self.context.extend(val);
        self
    }

    /// Inserts a typed extension value.
    ///
    /// Extensions are cloned into every request the dispatcher builds.
    #[must_use]
    pub fn extension<T>(mut self, val: T) -> Self
    where
        T: Clone + Send + Sync + 'static,
    {
        self.extensions.insert(val);
        self
    }

    /// Extends the extensions with `val` (via `Extensions::extend`).
    #[must_use]
    pub fn extensions_extend(mut self, val: Extensions) -> Self {
        self.extensions.extend(val);
        self
    }

    /// Sets the timeout passed to every `getUpdates` request.
    #[inline]
    #[must_use]
    pub fn polling_timeout(mut self, val: i64) -> Self {
        self.polling_timeout = Some(val);
        self
    }

    /// Replaces the retry policy, keeping the same backoff type.
    #[inline]
    #[must_use]
    pub fn backoff(mut self, val: BackoffType) -> Self {
        self.backoff = val;
        self
    }

    /// Appends one update type to the `allowed_updates` filter sent to `getUpdates`.
    #[must_use]
    pub fn allowed_update(mut self, val: impl Into<Box<str>>) -> Self {
        self.allowed_updates.push(val.into());
        self
    }

    /// Appends every update type yielded by `val` to the `allowed_updates` filter.
    #[must_use]
    pub fn allowed_updates<T, I>(mut self, val: I) -> Self
    where
        T: Into<Box<str>>,
        I: IntoIterator<Item = T>,
    {
        self.allowed_updates
            .extend(val.into_iter().map(Into::into));
        self
    }

    /// Finishes configuration and produces the [`Dispatcher`].
    #[inline]
    #[must_use]
    pub fn build(self) -> Dispatcher<Client, Propagator, BackoffType> {
        Dispatcher {
            propagator: self.propagator,
            bots: self.bots,
            extensions: self.extensions,
            context: self.context,
            polling_timeout: self.polling_timeout,
            backoff: self.backoff,
            allowed_updates: self.allowed_updates,
        }
    }
}
impl<Client, Propagator, Backoff> Dispatcher<Client, Propagator, Backoff> {
#[instrument(skip_all, fields(update_id = update.update_id(), update_type))]
pub async fn feed_update(
&mut self,
bot: Bot<Client>,
update: Arc<Update>,
) -> Result<Response<Client>, EventErrorKind>
where
Client: Send + Sync + Clone + 'static,
Propagator: PropagateEvent<Client>,
{
let update_type = UpdateType::from(update.as_ref());
Span::current().record("update_type", field::display(&update_type));
self.propagator
.propagate_event(
update_type,
Request {
bot,
update,
context: self.context.clone(),
extensions: self.extensions.clone(),
},
)
.await
}
#[instrument(skip_all)]
async fn listen_updates(
bot: Bot<Client>,
polling_timeout: Option<i64>,
allowed_updates: Vec<Box<str>>,
update_tx: mpsc::Sender<Update>,
backoff: Backoff,
) -> mpsc::error::SendError<Update>
where
Client: Session,
Backoff: backoff::backoff::Backoff + Clone,
{
event!(Level::TRACE, "Start listening updates");
let mut method = GetUpdates::new()
.limit(GET_UPDATES_SIZE)
.timeout_option(polling_timeout)
.allowed_updates(allowed_updates.clone());
loop {
let updates = retry(backoff.clone(), || {
let bot = ⊥
let method = method.clone();
async move {
match bot.send(method).await {
Ok(updates) => Ok(updates),
Err(err) => {
event!(Level::ERROR, %err, "Failed to fetch updates");
Err(backoff::Error::transient(err))
}
}
}
})
.await
.expect("Retry gave up due to permanent error");
let id = match updates.last() {
Some(Either::Left(update)) => update.update_id(),
Some(Either::Right(update)) => update.update_id,
None => {
event!(Level::TRACE, "No updates received");
continue;
}
};
method.offset = Some(id + 1);
for update in updates {
match update {
Either::Left(update) => {
if let Err(err) = update_tx.send(update).await {
return err;
}
}
Either::Right(update) => {
event!(
Level::ERROR,
update_id = update.update_id,
update = ?update.extra,
"Failed to parse update",
);
}
}
}
}
}
#[instrument(skip_all, fields(bot_id = bot.id))]
fn polling(&self, bot: Bot<Client>) -> impl Drop
where
Client: Session + Clone + 'static,
Propagator: PropagateEvent<Client> + Clone,
Backoff: backoff::backoff::Backoff + Send + Sync + Clone + 'static,
{
let (signal_tx, signal_rx) = watch::channel(());
let (update_tx, mut update_rx) = mpsc::channel(CHANNEL_UPDATES_SIZE);
let hidden_token = bot.hidden_token.clone();
tokio::spawn({
let fut = Self::listen_updates(
bot.clone(),
self.polling_timeout,
self.allowed_updates.clone(),
update_tx,
self.backoff.clone(),
);
async move {
select! {
() = signal_tx.closed() => event!(Level::TRACE, "Select signal branch"),
_ = fut => event!(Level::TRACE, "Select future branch"),
};
event!(Level::WARN, "Graceful shutdown signal received");
}
});
tokio::spawn({
let dispatcher = self.clone();
async move {
while let Some(update) = update_rx.recv().await {
event!(
Level::TRACE,
update_id = update.update_id(),
"Received update from the listener"
);
let update = Arc::new(update);
let bot = bot.clone();
let mut dispatcher = dispatcher.clone();
tokio::spawn(async move { dispatcher.feed_update(bot, update).await });
}
}
});
event!(Level::INFO, token = hidden_token, "Started");
signal_rx
}
#[inline]
#[must_use]
pub fn run_polling(self) -> ServePolling<Client, Propagator, Backoff>
where
Client: Session + Clone + 'static,
Propagator: PropagateEvent<Client> + 'static,
Backoff: backoff::backoff::Backoff + Send + Sync + Clone + 'static,
{
assert!(
!self.bots.is_empty(),
"You must add at least one bot to the dispatcher",
);
ServePolling::new(self)
}
#[inline]
#[must_use]
pub fn run_no_polling(self) -> Serve<Client, Propagator, Backoff>
where
Propagator: PropagateEvent<Client> + 'static,
{
Serve::new(self)
}
}
/// Long-polling run of a [`Dispatcher`], created by `Dispatcher::run_polling`.
///
/// Await it (via `IntoFuture`) to start polling, or attach a shutdown signal first with
/// `with_graceful_shutdown`.
pub struct ServePolling<Client, Propagator, BackoffType> {
    dispatcher: Dispatcher<Client, Propagator, BackoffType>,
}
impl<Client, Propagator, BackoffType> ServePolling<Client, Propagator, BackoffType> {
    /// Wraps `dispatcher` into a long-polling run.
    #[inline]
    #[must_use]
    pub const fn new(dispatcher: Dispatcher<Client, Propagator, BackoffType>) -> Self {
        Self { dispatcher }
    }

    /// Attaches a shutdown `signal`.
    ///
    /// Once `signal` completes, the shutdown event is emitted and the per-bot polling
    /// guards are released.
    #[inline]
    #[must_use]
    pub fn with_graceful_shutdown<Signal>(
        self,
        signal: Signal,
    ) -> ServePollingWithGracefulShutdown<Client, Propagator, BackoffType, Signal>
    where
        Signal: Future + Send + 'static,
        Signal::Output: Send,
    {
        let Self { dispatcher } = self;
        ServePollingWithGracefulShutdown::new(dispatcher, signal)
    }
}
impl<Client, Propagator, BackoffType> IntoFuture for ServePolling<Client, Propagator, BackoffType>
where
    Client: Session + Clone + 'static,
    Propagator: PropagateEvent<Client>,
    BackoffType: backoff::backoff::Backoff + Send + Sync + Clone + 'static,
{
    type IntoFuture = BoxFuture<'static, Self::Output>;
    type Output = Result<(), HandlerError>;

    /// Runs polling until `crate::utils::shutdown_signal()` completes.
    #[cfg(feature = "default_signal")]
    fn into_future(self) -> Self::IntoFuture {
        use crate::utils::shutdown_signal;
        self.with_graceful_shutdown(shutdown_signal()).into_future()
    }

    /// Runs polling with a signal that never completes, so the shutdown event is never
    /// emitted. A warning is logged if any shutdown handlers are registered.
    #[cfg(not(feature = "default_signal"))]
    fn into_future(self) -> Self::IntoFuture {
        if self.dispatcher.propagator.shutdown_handlers_len() != 0 {
            event!(
                target: "telers:dispatcher:into_future",
                Level::WARN,
                "Shutdown observer can't be called without graceful shutdown. \
                You can turn off this log with `telers:dispatcher:into_future=off`.",
            );
        }
        self.with_graceful_shutdown(std::future::pending::<Self::Output>())
            .into_future()
    }
}
/// Long-polling run of a [`Dispatcher`] that stops once `signal` completes.
pub struct ServePollingWithGracefulShutdown<Client, Propagator, BackoffType, Signal> {
    dispatcher: Dispatcher<Client, Propagator, BackoffType>,
    // Awaited after polling has started; its output is ignored.
    signal: Signal,
}
impl<Client, Propagator, BackoffType, Signal>
    ServePollingWithGracefulShutdown<Client, Propagator, BackoffType, Signal>
{
    /// Pairs `dispatcher` with the shutdown `signal` that will end polling.
    #[inline]
    #[must_use]
    pub const fn new(
        dispatcher: Dispatcher<Client, Propagator, BackoffType>,
        signal: Signal,
    ) -> Self {
        Self { signal, dispatcher }
    }
}
impl<Client, Propagator, BackoffType, Signal> IntoFuture
    for ServePollingWithGracefulShutdown<Client, Propagator, BackoffType, Signal>
where
    Client: Session + Clone + 'static,
    Signal: Future + Send + 'static,
    Signal::Output: Send,
    Propagator: PropagateEvent<Client>,
    BackoffType: backoff::backoff::Backoff + Send + Sync + Clone + 'static,
{
    type IntoFuture = BoxFuture<'static, Self::Output>;
    type Output = Result<(), HandlerError>;

    /// Emits startup, polls every bot until `signal` completes, stops polling, then emits
    /// shutdown.
    ///
    /// Errors from the startup or shutdown handlers are returned as-is. If startup fails,
    /// no polling is started.
    fn into_future(mut self) -> Self::IntoFuture {
        Box::pin(async move {
            self.dispatcher.propagator.emit_startup().await?;
            // One guard per bot; dropping a guard stops that bot's listener task.
            let pollings: Vec<_> = self
                .dispatcher
                .bots
                .iter()
                .map(|bot| self.dispatcher.polling(bot.clone()))
                .collect();
            self.signal.await;
            // Stop fetching new updates before shutdown handlers run, so they don't race
            // with freshly received updates.
            drop(pollings);
            self.dispatcher.propagator.emit_shutdown().await?;
            Ok(())
        })
    }
}
/// Run of a [`Dispatcher`] without polling, created by `Dispatcher::run_no_polling`.
///
/// Awaiting it only emits the startup and shutdown events.
pub struct Serve<Client, Propagator, BackoffType> {
    dispatcher: Dispatcher<Client, Propagator, BackoffType>,
}
impl<Client, Propagator, BackoffType> Serve<Client, Propagator, BackoffType> {
    /// Wraps `dispatcher` into a run that does not poll for updates.
    #[inline]
    #[must_use]
    pub const fn new(dispatcher: Dispatcher<Client, Propagator, BackoffType>) -> Self {
        Self { dispatcher }
    }

    /// Attaches a shutdown `signal`; the shutdown event is emitted once it completes.
    #[inline]
    #[must_use]
    pub fn with_graceful_shutdown<Signal>(
        self,
        signal: Signal,
    ) -> ServeWithGracefulShutdown<Client, Propagator, BackoffType, Signal>
    where
        Signal: Future + Send + 'static,
        Signal::Output: Send,
    {
        let Self { dispatcher } = self;
        ServeWithGracefulShutdown::new(dispatcher, signal)
    }
}
impl<Client, Propagator, BackoffType> IntoFuture for Serve<Client, Propagator, BackoffType>
where
    Propagator: PropagateEvent<Client>,
{
    type IntoFuture = BoxFuture<'static, Self::Output>;
    type Output = Result<(), HandlerError>;

    /// Emits startup, waits for `crate::utils::shutdown_signal()`, then emits shutdown.
    #[cfg(feature = "default_signal")]
    fn into_future(self) -> Self::IntoFuture {
        use crate::utils::shutdown_signal;
        self.with_graceful_shutdown(shutdown_signal()).into_future()
    }

    /// Emits startup and then waits on a signal that never completes, so the shutdown event
    /// is never emitted. A warning is logged if any shutdown handlers are registered.
    #[cfg(not(feature = "default_signal"))]
    fn into_future(self) -> Self::IntoFuture {
        if self.dispatcher.propagator.shutdown_handlers_len() != 0 {
            event!(
                target: "telers:dispatcher:into_future",
                Level::WARN,
                "Shutdown observer can't be called without graceful shutdown. \
                You can turn off this log with `telers:dispatcher:into_future=off`.",
            );
        }
        self.with_graceful_shutdown(std::future::pending::<Self::Output>())
            .into_future()
    }
}
/// Non-polling run of a [`Dispatcher`] that emits shutdown once `signal` completes.
pub struct ServeWithGracefulShutdown<Client, Propagator, BackoffType, Signal> {
    dispatcher: Dispatcher<Client, Propagator, BackoffType>,
    // Awaited after startup; its output is ignored.
    signal: Signal,
}
impl<Client, Propagator, BackoffType, Signal>
    ServeWithGracefulShutdown<Client, Propagator, BackoffType, Signal>
{
    /// Pairs `dispatcher` with the shutdown `signal` that triggers the shutdown event.
    #[inline]
    #[must_use]
    pub const fn new(
        dispatcher: Dispatcher<Client, Propagator, BackoffType>,
        signal: Signal,
    ) -> Self {
        Self { signal, dispatcher }
    }
}
impl<Client, Propagator, BackoffType, Signal> IntoFuture
    for ServeWithGracefulShutdown<Client, Propagator, BackoffType, Signal>
where
    Signal: Future + Send + 'static,
    Signal::Output: Send,
    Propagator: PropagateEvent<Client>,
{
    type IntoFuture = BoxFuture<'static, Self::Output>;
    type Output = Result<(), HandlerError>;

    /// Emits startup, waits for `signal`, then emits shutdown.
    ///
    /// Errors from the startup or shutdown handlers are returned as-is.
    fn into_future(mut self) -> Self::IntoFuture {
        let serve = async move {
            self.dispatcher.propagator.emit_startup().await?;
            self.signal.await;
            self.dispatcher.propagator.emit_shutdown().await?;
            Ok(())
        };
        Box::pin(serve)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        client::Reqwest,
        event::{
            bases::{EventReturn, PropagateEventResult},
            telegram::Handler,
        },
        router::Router,
        types::{ChatPrivate, MessageText, UpdateMessage},
    };
    use std::convert::Infallible;

    #[tokio::test]
    async fn test_feed_update() {
        let bot = Bot::<Reqwest>::default();
        let update = Arc::new(Update::Message(UpdateMessage::new(
            0,
            MessageText::new(0, 0, ChatPrivate::new(0), ""),
        )));

        // A router with no handlers leaves the update unhandled.
        let router = Router::new("main");
        let mut dispatcher = Dispatcher::builder()
            .main_router(router.configure_default())
            .build();
        let response = dispatcher
            .feed_update(bot.clone(), update.clone())
            .await
            .unwrap();
        assert!(
            matches!(response.propagate_result, PropagateEventResult::Unhandled),
            "Unexpected result"
        );

        // A message handler returning `Finish` marks the update as handled.
        let router = Router::new("main").on_message(|observer| {
            observer.register(Handler::new(|| async {
                Ok::<_, Infallible>(EventReturn::Finish)
            }))
        });
        let mut dispatcher = Dispatcher::builder()
            .main_router(router.configure_default())
            .build();
        let response = dispatcher.feed_update(bot, update).await.unwrap();
        assert!(
            matches!(response.propagate_result, PropagateEventResult::Handled(_)),
            "Unexpected result"
        );
    }

    #[derive(Clone)]
    struct Test1;
    #[derive(Clone)]
    struct Test2;
    #[derive(Clone)]
    struct Test3;

    #[test]
    fn test_builder() {
        let bot = Bot::<Reqwest>::default();
        let dispatcher = Dispatcher::builder()
            .main_router(Router::new("main").configure_default())
            .bot(bot.clone())
            .bots([bot])
            .extension(Test1)
            .extension(Test2)
            .extensions_extend({
                let mut extensions = Extensions::new();
                extensions.insert(Test3);
                extensions
            })
            .context("1", Test1)
            .context("2", Test2)
            .context_extend({
                let mut context = Context::new();
                context.insert("3", Test3);
                context
            })
            .polling_timeout(123)
            .allowed_update(UpdateType::Message)
            .allowed_updates([UpdateType::InlineQuery, UpdateType::ChosenInlineResult])
            .build();
        assert_eq!(dispatcher.bots.len(), 2);
        assert_eq!(dispatcher.extensions.len(), 3);
        assert_eq!(dispatcher.context.len(), 3);
        assert_eq!(dispatcher.polling_timeout, Some(123));
        assert_eq!(dispatcher.allowed_updates.len(), 3);
    }
}