use std::sync::Arc;
use async_trait::async_trait;
use crate::{types::Update, Bot};
/// A hook that observes each [`Update`] around its handling.
///
/// Both methods ship with pass-through defaults, so an implementor only
/// overrides the phase it cares about.
#[async_trait]
pub trait Middleware: Send + Sync + 'static {
    /// Invoked (via `run_before`) ahead of handling.
    ///
    /// Returning `false` makes `run_before` stop consulting later middleware
    /// and report `false`; presumably the dispatcher then skips the update.
    /// The default accepts every update.
    async fn before(&self, bot: &Bot, update: &Update) -> bool {
        // The default ignores its inputs; bind them to silence unused warnings.
        let _ = bot;
        let _ = update;
        true
    }

    /// Invoked (via `run_after`), presumably once handling has finished.
    /// The default does nothing.
    async fn after(&self, bot: &Bot, update: &Update) {
        let _ = bot;
        let _ = update;
    }
}

/// Shared, type-erased middleware handle, as stored in a middleware chain.
pub type ArcMiddleware = Arc<dyn Middleware>;
/// Runs each middleware's `before` hook in chain order.
///
/// The first hook that returns `false` ends the walk (later middleware is not
/// consulted) and `false` is returned. Otherwise returns `true`, including for
/// an empty chain.
pub async fn run_before(chain: &[ArcMiddleware], bot: &Bot, update: &Update) -> bool {
    let mut accepted = true;
    for middleware in chain.iter() {
        accepted = middleware.before(bot, update).await;
        if !accepted {
            break;
        }
    }
    accepted
}
/// Runs every middleware's `after` hook, in the same order as the chain.
///
/// There is no short-circuit: each hook is awaited in turn, one at a time.
pub async fn run_after(chain: &[ArcMiddleware], bot: &Bot, update: &Update) {
    for middleware in chain.iter() {
        middleware.after(bot, update).await;
    }
}
/// Middleware that traces updates: an `info` event when one arrives (tagged
/// with its kind) and a `debug` event from the `after` hook.
pub struct LoggingMiddleware;

#[async_trait]
impl Middleware for LoggingMiddleware {
    /// Emits the update id and kind at `info` level; never rejects the update.
    async fn before(&self, _bot: &Bot, update: &Update) -> bool {
        tracing::info!(
            update_id = update.update_id,
            kind = update_kind(update),
            "update received"
        );
        true
    }

    /// Emits the update id at `debug` level.
    async fn after(&self, _bot: &Bot, update: &Update) {
        tracing::debug!(update_id = update.update_id, "update handled");
    }
}
/// Names the first populated payload field of `u`, checked in a fixed
/// priority order; returns `"other"` when none of the listed fields is set.
fn update_kind(u: &Update) -> &'static str {
    // Order matters: earlier entries win if several fields are populated.
    let payloads = [
        ("message", u.message.is_some()),
        ("edited_message", u.edited_message.is_some()),
        ("callback_query", u.callback_query.is_some()),
        ("inline_query", u.inline_query.is_some()),
        ("channel_post", u.channel_post.is_some()),
        ("chat_member", u.chat_member.is_some()),
    ];
    payloads
        .iter()
        .find(|(_, present)| *present)
        .map_or("other", |&(name, _)| name)
}
/// Middleware that throttles updates per chat.
///
/// Each chat's updates are counted in fixed one-second windows and checked
/// against `max_per_second`; updates over the limit are dropped. Updates from
/// which no chat id can be extracted are never throttled.
pub struct RateLimiter {
    // Per-chat budget for a single one-second window.
    max_per_second: u32,
    // Per-chat window state, shared behind an `Arc`.
    counters: Arc<dashmap_rl::DashMap>,
}
/// Per-chat fixed-window counters backing `RateLimiter`.
///
/// Despite the name, this is not the `dashmap` crate: it is a plain `HashMap`
/// guarded by a single `std::sync::Mutex`.
mod dashmap_rl {
    use std::{
        collections::HashMap,
        sync::Mutex,
        time::{Duration, Instant},
    };

    /// Length of one rate-limiting window.
    const WINDOW: Duration = Duration::from_secs(1);

    /// State guarded by the mutex.
    struct Inner {
        // chat id -> (start of its current window, updates admitted in that window).
        windows: HashMap<i64, (Instant, u32)>,
        // When expired windows were last evicted; bounds sweeps to one per window.
        last_sweep: Instant,
    }

    pub struct DashMap(Mutex<Inner>);

    impl DashMap {
        /// Creates an empty counter table.
        pub fn new() -> Self {
            Self(Mutex::new(Inner {
                windows: HashMap::new(),
                last_sweep: Instant::now(),
            }))
        }

        /// Records one update for chat `id` and reports whether it fits within
        /// `max` updates for the chat's current one-second window.
        ///
        /// A `max` of 0 rejects every update.
        pub fn check_and_increment(&self, id: i64, max: u32) -> bool {
            self.check_and_increment_at(id, max, Instant::now())
        }

        /// Same as [`DashMap::check_and_increment`], but evaluated at the
        /// instant `now` instead of the current time.
        pub fn check_and_increment_at(&self, id: i64, max: u32, now: Instant) -> bool {
            // The map holds only plain counters, so a poisoned lock (another
            // thread panicked while holding it) is recovered rather than propagated.
            let mut inner = self.0.lock().unwrap_or_else(|e| e.into_inner());

            // Without eviction, every chat ever seen would stay in the map forever.
            // Expired windows carry no information, so drop them, at most once
            // per window to keep the O(n) sweep amortised.
            if now.saturating_duration_since(inner.last_sweep) >= WINDOW {
                inner
                    .windows
                    .retain(|_, (start, _)| now.saturating_duration_since(*start) < WINDOW);
                inner.last_sweep = now;
            }

            let entry = inner.windows.entry(id).or_insert((now, 0));
            if now.saturating_duration_since(entry.0) >= WINDOW {
                // Start a fresh window with an empty count. The budget check
                // below then applies uniformly, so `max == 0` never admits.
                *entry = (now, 0);
            }
            if entry.1 < max {
                entry.1 += 1;
                true
            } else {
                false
            }
        }
    }
}
impl RateLimiter {
    /// Creates a limiter with a per-chat budget of `max_per_second` updates
    /// per one-second window.
    pub fn new(max_per_second: u32) -> Self {
        let counters = Arc::new(dashmap_rl::DashMap::new());
        Self {
            max_per_second,
            counters,
        }
    }

    /// Extracts the chat id used as the rate-limiting key.
    ///
    /// Only `message`, `edited_message` and `channel_post` are consulted, in
    /// that order; any other kind of update yields `None`.
    fn chat_id(update: &Update) -> Option<i64> {
        if let Some(msg) = &update.message {
            return Some(msg.chat.id);
        }
        if let Some(msg) = &update.edited_message {
            return Some(msg.chat.id);
        }
        update.channel_post.as_ref().map(|msg| msg.chat.id)
    }
}
#[async_trait]
impl Middleware for RateLimiter {
    /// Admits the update unless its chat has used up its budget for the
    /// current window; updates without a chat id are always admitted.
    async fn before(&self, _bot: &Bot, update: &Update) -> bool {
        match Self::chat_id(update) {
            // Nothing to key on (e.g. inline or callback queries): never throttle.
            None => true,
            Some(chat) => {
                let within_budget =
                    self.counters.check_and_increment(chat, self.max_per_second);
                if !within_budget {
                    tracing::debug!(chat_id = chat, "rate limit: update dropped");
                }
                within_budget
            }
        }
    }
}