use std::{
collections::{hash_map::Entry, HashMap, VecDeque},
pin::pin,
time::{Duration, Instant},
};
use either::Either;
use futures::{future, FutureExt as _};
use tokio::sync::{mpsc, mpsc::error::TryRecvError, oneshot::Sender};
use vecrem::VecExt;
use crate::{
adaptors::throttle::{request_lock::RequestLock, ChatIdHash, Limits, Settings},
errors::AsResponseParameters,
requests::Requester,
};
/// Length of the window used for the per-minute limits.
const MINUTE: Duration = Duration::from_millis(60_000);
/// Length of the window used for the per-second limits.
const SECOND: Duration = Duration::from_millis(1_000);
/// Pause at the end of every worker iteration.
const DELAY: Duration = Duration::from_millis(250);
/// Minimum time between two `on_queue_full` notifications.
const QUEUE_FULL_DELAY: Duration = Duration::from_millis(4_000);
/// Runtime queries/updates of the worker's `Limits`; pending messages are answered by
/// `answer_info` at the start of every worker iteration.
#[derive(Debug)]
pub(super) enum InfoMessage {
    /// Asks for a copy of the limits currently in effect.
    GetLimits { response: Sender<Limits> },
    /// Replaces the current limits with `new`; `response` is signalled after the replacement.
    SetLimits { new: Limits, response: Sender<()> },
}
/// Number of requests unlocked for a single chat within a time window.
type RequestsSent = u32;
/// Per-chat counters of unlocked requests, used to enforce the per-chat limits.
#[derive(Default)]
struct RequestsSentToChats {
    /// Requests unlocked during the last minute. Kept in step with the worker's history:
    /// decremented as history entries expire and removed once they reach zero.
    per_min: HashMap<ChatIdHash, RequestsSent>,
    /// Requests unlocked during the last second. Rebuilt from history on every worker
    /// iteration and cleared at its end.
    per_sec: HashMap<ChatIdHash, RequestsSent>,
}
/// Request to pause the worker, delivered through the channel whose sender is handed to
/// `RequestLock::unlock` (presumably sent when a request gets a `RetryAfter` error — confirm
/// against `RequestLock`).
pub(super) struct FreezeUntil {
    /// Instant the worker sleeps until if it decides to freeze.
    pub(super) until: Instant,
    /// Retry-after duration; logged, and compared with the chat's slow-mode delay to guess
    /// whether slow mode caused the error.
    pub(super) after: Duration,
    /// Chat the freeze request is about (presumably the chat of the failed request).
    pub(super) chat: ChatIdHash,
}
/// Background loop deciding when each queued request may be sent.
///
/// Requests arrive on `rx` as `(chat, lock)` pairs. A request is unlocked only when all of
/// these allow it:
/// - the overall per-second limit;
/// - the per-chat per-second limit;
/// - the per-chat per-minute limit (a separate limit applies to channels/supergroups);
/// - if `check_slow_mode` is set, the chat's slow-mode delay.
///
/// Other duties:
/// - Pending `InfoMessage`s on `info_rx` are answered at the start of every iteration.
/// - Freeze requests arriving on the channel handed to `RequestLock::unlock` are processed by
///   `freeze`.
/// - When the queue is full, `on_queue_full` is spawned at most once per `QUEUE_FULL_DELAY`.
///
/// Returns once `rx` is closed and the queue has been emptied.
pub(super) async fn worker<B>(
    Settings { mut limits, mut on_queue_full, retry, check_slow_mode }: Settings,
    mut rx: mpsc::Receiver<(ChatIdHash, RequestLock)>,
    mut info_rx: mpsc::Receiver<InfoMessage>,
    bot: B,
) where
    B: Requester,
    B::Err: AsResponseParameters,
{
    // Requests waiting to be unlocked. The capacity doubles as the queue size limit:
    // `read_from_rx` only fills the queue up to `capacity()`, and the queue-full check below
    // compares against it.
    // NOTE(review): the capacity comes from the initial limits only; `SetLimits` does not
    // resize the queue.
    let mut queue: Vec<(ChatIdHash, RequestLock)> =
        Vec::with_capacity(limits.messages_per_sec_overall as usize);
    // `(chat, unlock time)` of every request unlocked within the last minute, oldest first.
    let mut history: VecDeque<(ChatIdHash, Instant)> = VecDeque::new();
    let mut requests_sent = RequestsSentToChats::default();
    // Per-chat `(slow mode delay, last unlock time)`; `None` when slow mode is not checked.
    let mut slow_mode: Option<HashMap<ChatIdHash, (Duration, Instant)>> =
        check_slow_mode.then(HashMap::new);
    let mut rx_is_closed = false;
    // Backdated so the first full queue is reported right away. If the instant cannot be
    // represented, it falls back to "now".
    let mut last_queue_full =
        Instant::now().checked_sub(QUEUE_FULL_DELAY).unwrap_or_else(Instant::now);
    // Unlocked requests receive a clone of `freeze_tx` (see `lock.unlock` below) so they can
    // ask the worker to freeze.
    let (freeze_tx, mut freeze_rx) = mpsc::channel::<FreezeUntil>(1);
    // Keep going while more requests may arrive or some are still queued.
    while !rx_is_closed || !queue.is_empty() {
        answer_info(&mut info_rx, &mut limits);
        // Pull new requests into the queue (`read_from_rx` only blocks when it is empty).
        // A freeze request that arrives first is handled, then reading is retried.
        loop {
            let res = future::select(
                pin!(freeze_rx.recv()),
                pin!(read_from_rx(&mut rx, &mut queue, &mut rx_is_closed)),
            )
            .map(either)
            .await
            .map_either(|l| l.0, |r| r.0);
            match res {
                Either::Left(freeze_until) => {
                    freeze(&mut freeze_rx, slow_mode.as_mut(), &bot, freeze_until).await;
                }
                Either::Right(()) => break,
            }
        }
        // Report a full queue, rate-limited, in a separate task so the worker is not blocked.
        if queue.len() == queue.capacity() && last_queue_full.elapsed() > QUEUE_FULL_DELAY {
            last_queue_full = Instant::now();
            tokio::spawn(on_queue_full(queue.len()));
        }
        let now = Instant::now();
        // Starts of the per-minute and per-second windows. If the subtraction is not
        // representable, they are clamped to `now`.
        let min_back = now.checked_sub(MINUTE).unwrap_or(now);
        let sec_back = now.checked_sub(SECOND).unwrap_or(now);
        // Forget requests older than a minute. Decrement their per-minute counters and drop
        // counters that reach zero. `history` is chronological, so stop at the first entry
        // inside the window.
        while let Some((_, time)) = history.front() {
            if time >= &min_back {
                break;
            }
            if let Some((chat, _)) = history.pop_front() {
                let entry = requests_sent.per_min.entry(chat).and_modify(|count| {
                    *count -= 1;
                });
                if let Entry::Occupied(entry) = entry {
                    if *entry.get() == 0 {
                        entry.remove_entry();
                    }
                }
            }
        }
        // Requests unlocked within the last second, counted from the newest end of `history`.
        let used = history.iter().rev().take_while(|(_, time)| time > &sec_back).count() as u32;
        let mut allowed = limits.messages_per_sec_overall.saturating_sub(used);
        // The overall per-second budget is exhausted: wait for the window to move on.
        if allowed == 0 {
            requests_sent.per_sec.clear();
            tokio::time::sleep(DELAY).await;
            continue;
        }
        // Rebuild the per-chat counts for the last second from `history`.
        for (chat, _) in history.iter().rev().take_while(|(_, time)| time > &sec_back) {
            *requests_sent.per_sec.entry(*chat).or_insert(0) += 1;
        }
        // Walk the queue with `vecrem`'s removing iterator. Presumably, entries that are not
        // `remove()`d stay queued for the next iteration.
        let mut queue_removing = queue.removing();
        while let Some(entry) = queue_removing.next() {
            let chat = &entry.value().0;
            // Skip chats whose slow-mode cooldown has not elapsed yet.
            let slow_mode = slow_mode.as_mut().and_then(|sm| sm.get_mut(chat));
            if let Some(&mut (delay, last)) = slow_mode {
                if last + delay > Instant::now() {
                    continue;
                }
            }
            let requests_sent_per_sec_count = requests_sent.per_sec.get(chat).copied().unwrap_or(0);
            let requests_sent_per_min_count = requests_sent.per_min.get(chat).copied().unwrap_or(0);
            // Channels and supergroups have their own per-minute limit.
            let messages_per_min_limit = if chat.is_channel_or_supergroup() {
                limits.messages_per_min_channel_or_supergroup
            } else {
                limits.messages_per_min_chat
            };
            let limits_not_exceeded = requests_sent_per_sec_count < limits.messages_per_sec_chat
                && requests_sent_per_min_count < messages_per_min_limit;
            if limits_not_exceeded {
                let chat = *chat;
                let (_, lock) = entry.remove();
                // The request leaves the queue either way. Only a successful unlock counts
                // towards the limits and the slow-mode timer.
                if lock.unlock(retry, freeze_tx.clone()).is_ok() {
                    *requests_sent.per_sec.entry(chat).or_insert(0) += 1;
                    *requests_sent.per_min.entry(chat).or_insert(0) += 1;
                    history.push_back((chat, Instant::now()));
                    if let Some((_, last)) = slow_mode {
                        *last = Instant::now();
                    }
                    allowed -= 1;
                    if allowed == 0 {
                        break;
                    }
                }
            }
        }
        // Per-second counts are rebuilt from `history` on the next iteration.
        requests_sent.per_sec.clear();
        tokio::time::sleep(DELAY).await;
    }
}
/// Answers, without blocking, every `InfoMessage` currently waiting on `rx`.
///
/// `GetLimits` replies with a copy of `limits`. `SetLimits` overwrites `limits` and then
/// acknowledges. A failed reply only means the asker dropped its receiver, so it is ignored.
fn answer_info(rx: &mut mpsc::Receiver<InfoMessage>, limits: &mut Limits) {
    while let Ok(message) = rx.try_recv() {
        match message {
            InfoMessage::GetLimits { response } => {
                let _ = response.send(*limits);
            }
            InfoMessage::SetLimits { new, response } => {
                *limits = new;
                let _ = response.send(());
            }
        }
    }
}
/// Handles `imm` and then every freeze request already waiting on `rx`.
///
/// Steps for each request:
/// 1. If slow mode is tracked (`slow_mode` is `Some`) and the chat is identified by id, refresh
///    the chat's slow-mode delay via `get_chat`. The entry is removed when the chat has no slow
///    mode; a failed lookup leaves the map untouched.
/// 2. If the chat's known slow-mode delay is no longer than the retry-after duration, assume
///    slow mode caused the error and do not freeze.
/// 3. Otherwise the worker sleeps until `until`.
// NOTE(review): presumably silences a clippy false positive — `rx.try_recv()` needs `&mut`.
#[allow(clippy::needless_pass_by_ref_mut)]
async fn freeze(
    rx: &mut mpsc::Receiver<FreezeUntil>,
    mut slow_mode: Option<&mut HashMap<ChatIdHash, (Duration, Instant)>>,
    bot: &impl Requester,
    mut imm: Option<FreezeUntil>,
) {
    while let Some(FreezeUntil { until, after, chat }) =
        imm.take().or_else(|| rx.try_recv().ok())
    {
        // Refresh what is known about this chat's slow mode, when it is tracked at all.
        #[allow(clippy::needless_option_as_deref)]
        if let (Some(known_delays), hash @ ChatIdHash::Id(id)) = (slow_mode.as_deref_mut(), chat) {
            if let Ok(info) = bot.get_chat(id).await {
                if let Some(delay) = info.slow_mode_delay() {
                    let now = Instant::now();
                    known_delays.insert(hash, (delay.duration(), now));
                } else {
                    known_delays.remove(&hash);
                }
            }
        }

        let caused_by_slow_mode = match slow_mode.as_deref() {
            Some(known_delays) => {
                matches!(known_delays.get(&chat), Some((delay, _)) if *delay <= after)
            }
            None => false,
        };
        if caused_by_slow_mode {
            continue;
        }

        log::warn!(
            "freezing the bot for approximately {after:?} due to `RetryAfter` error from \
             telegram"
        );
        tokio::time::sleep_until(until.into()).await;
        log::warn!("unfreezing the bot");
    }
}
/// Moves items from `rx` into `queue` until the queue reaches its capacity or `rx` has nothing
/// more buffered.
///
/// If `queue` is empty, this first waits for one item or for the channel to close. Items beyond
/// that are only taken if already buffered. `rx_is_closed` is set to `true` once `rx` reports
/// that all senders are gone; this function never resets it.
async fn read_from_rx<T>(rx: &mut mpsc::Receiver<T>, queue: &mut Vec<T>, rx_is_closed: &mut bool) {
    if queue.is_empty() {
        log::debug!("blocking on queue");
        if let Some(item) = rx.recv().await {
            queue.push(item);
        } else {
            *rx_is_closed = true;
        }
    }

    // Top up without waiting and without growing the queue's allocation.
    while queue.len() < queue.capacity() {
        let item = match rx.try_recv() {
            Ok(item) => item,
            Err(TryRecvError::Empty) => return,
            Err(TryRecvError::Disconnected) => {
                *rx_is_closed = true;
                return;
            }
        };
        queue.push(item);
    }
}
/// Converts the `Either` produced by `future::select` into `either::Either`, whose combinators
/// (such as `map_either`) the worker loop relies on.
fn either<L, R>(x: future::Either<L, R>) -> Either<L, R> {
    match x {
        future::Either::Right(right) => Either::Right(right),
        future::Either::Left(left) => Either::Left(left),
    }
}
#[cfg(test)]
mod tests {
    use super::read_from_rx;

    /// Regression test named after issue #535. With every sender dropped, `read_from_rx` must
    /// complete and report the channel as closed without queueing anything.
    #[tokio::test]
    async fn issue_535() {
        let (tx, mut rx) = tokio::sync::mpsc::channel::<()>(1);
        drop(tx);

        let mut queue = Vec::new();
        let mut rx_is_closed = false;
        read_from_rx(&mut rx, &mut queue, &mut rx_is_closed).await;

        assert!(rx_is_closed, "a channel without senders must be reported as closed");
        assert!(queue.is_empty(), "nothing can be read from a channel without senders");
    }
}