use crate::types::Update;
use crate::{Bot, BotError};
use futures::FutureExt as _;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Semaphore;
use tracing::{error, warn};
/// Callback invoked once per incoming update.
///
/// It receives the bot client to use for API calls together with the update itself. It returns
/// a heap-allocated, `Send` future producing `()`. The callback must be `Send + Sync` so it can
/// be shared across tasks.
pub type UpdateHandler = Box<
    dyn Fn(Bot, Update) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
        + Send
        + Sync
        + 'static,
>;
/// Long-polling update dispatcher.
///
/// It repeatedly calls `getUpdates` and runs the handler for each update on its own tokio task.
/// The number of concurrently running handler tasks is bounded by a semaphore.
///
/// Build one with [`Poller::new`], adjust it with the builder methods, then call
/// [`Poller::start`].
pub struct Poller {
/// Client used to issue the `getUpdates` requests.
poll_bot: Bot,
/// Optional separate client handed to the handler; `start` falls back to `poll_bot` when unset.
api_bot: Option<Bot>,
/// Update callback, reference-counted so every spawned task can hold its own handle.
handler: Arc<UpdateHandler>,
/// `timeout` argument forwarded to each `getUpdates` request.
timeout: i64,
/// `limit` argument forwarded to each `getUpdates` request.
limit: i64,
/// Update types to request; when empty, the parameter is not sent at all.
allowed_updates: Vec<String>,
/// Maximum number of handler tasks allowed to run at once (semaphore permit count).
max_concurrent: usize,
}
impl Poller {
pub fn new(bot: Bot, handler: UpdateHandler) -> Self {
Poller {
poll_bot: bot,
api_bot: None,
handler: Arc::new(handler),
timeout: 30,
limit: 100,
allowed_updates: vec![],
max_concurrent: 512,
}
}
pub fn api_bot(mut self, bot: Bot) -> Self {
self.api_bot = Some(bot);
self
}
pub fn concurrency(mut self, max: usize) -> Self {
self.max_concurrent = max.max(1);
self
}
pub fn timeout(mut self, t: i64) -> Self {
self.timeout = t;
self
}
pub fn limit(mut self, l: i64) -> Self {
self.limit = l;
self
}
pub fn allowed_updates(mut self, updates: Vec<String>) -> Self {
self.allowed_updates = updates;
self
}
pub async fn start(self) -> Result<(), BotError> {
let mut offset: i64 = 0;
let allowed_updates = if self.allowed_updates.is_empty() {
None
} else {
Some(self.allowed_updates.clone())
};
let semaphore = Arc::new(Semaphore::new(self.max_concurrent));
let api_bot = self.api_bot.unwrap_or_else(|| self.poll_bot.clone());
let handler = Arc::clone(&self.handler);
tracing::debug!(
max_concurrent = self.max_concurrent,
timeout = self.timeout,
"polling started"
);
loop {
let mut req = self
.poll_bot
.get_updates()
.offset(offset)
.timeout(self.timeout)
.limit(self.limit);
if let Some(ref au) = allowed_updates {
req = req.allowed_updates(au.clone());
}
let updates = match req.await {
Ok(u) => u,
Err(e) => {
let sleep_secs = match &e {
BotError::Api {
retry_after: Some(secs),
..
} => {
warn!(retry_after = secs, "flood-wait on getUpdates");
*secs as u64
}
_ => {
error!(error = %e, "getUpdates error, retrying in 3 s");
3
}
};
tokio::time::sleep(Duration::from_secs(sleep_secs)).await;
continue;
}
};
for update in updates {
offset = update.update_id + 1;
let permit = semaphore
.clone()
.acquire_owned()
.await
.expect("semaphore should not be closed");
let bot = api_bot.clone();
let handler = Arc::clone(&handler);
tokio::spawn(async move {
let _permit = permit;
let result = std::panic::AssertUnwindSafe((handler)(bot, update))
.catch_unwind()
.await;
if result.is_err() {
error!("handler panicked on update - caught, polling continues");
}
});
}
}
}
}