use async_trait::async_trait;
use std::collections::HashSet;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::error::HandlerResult;
use crate::types::*;
/// A hook that runs around update handling.
///
/// Both hooks have no-op defaults, so implementors only override what they need.
#[async_trait]
pub trait Middleware: Send + Sync + 'static {
    /// Called before the update is handled.
    ///
    /// Returns `true` to let the update through and `false` to reject it. The
    /// middlewares in this module return `false` for unauthorized users and for
    /// throttled chats. The default implementation accepts every update.
    ///
    /// NOTE(review): how the dispatcher reacts to `false` (skipping the handler,
    /// the remaining middlewares, or `after`) is not visible here — confirm in the caller.
    async fn before(&self, _chat_id: ChatId, _user: &UserInfo, _update: &IncomingUpdate) -> bool {
        true
    }

    /// Called after the update was handled, with the handler's result.
    ///
    /// The default implementation does nothing.
    async fn after(
        &self,
        _chat_id: ChatId,
        _user: &UserInfo,
        _update: &IncomingUpdate,
        _result: &HandlerResult,
    ) {
    }
}
/// Stateless middleware that traces every update and its handler outcome.
#[derive(Debug, Default, Clone, Copy)]
pub struct LoggingMiddleware;
#[async_trait]
impl Middleware for LoggingMiddleware {
async fn before(&self, chat_id: ChatId, user: &UserInfo, update: &IncomingUpdate) -> bool {
tracing::info!(
chat_id = chat_id.0,
user_id = user.id.0,
update_type = update.type_name(),
"incoming update"
);
true
}
async fn after(
&self,
chat_id: ChatId,
_user: &UserInfo,
update: &IncomingUpdate,
result: &HandlerResult,
) {
match result {
Ok(()) => tracing::debug!(chat_id = chat_id.0, update_type = update.type_name(), "ok"),
Err(e) => tracing::error!(chat_id = chat_id.0, error = %e, "handler error"),
}
}
}
/// Allow-list middleware: only updates from the listed user ids pass `before`.
#[derive(Debug, Clone)]
pub struct AuthMiddleware {
    /// User ids permitted to interact with the bot.
    allowed_ids: HashSet<u64>,
}
impl AuthMiddleware {
pub fn new(ids: impl IntoIterator<Item = u64>) -> Self {
Self {
allowed_ids: ids.into_iter().collect(),
}
}
}
#[async_trait]
impl Middleware for AuthMiddleware {
    /// Admits the update only if the sender's id is on the allow-list.
    ///
    /// Rejected senders are reported with a `warn`-level event.
    async fn before(&self, _chat_id: ChatId, user: &UserInfo, _update: &IncomingUpdate) -> bool {
        let sender = user.id.0;
        let allowed = self.allowed_ids.contains(&sender);
        if !allowed {
            tracing::warn!(user_id = sender, "unauthorized access blocked");
        }
        allowed
    }
}
pub struct ThrottleMiddleware {
max_per_second: u64,
counter: dashmap::DashMap<ChatId, (std::time::Instant, u64)>,
}
impl ThrottleMiddleware {
pub fn new(max_per_second: u64) -> Self {
Self {
max_per_second,
counter: dashmap::DashMap::new(),
}
}
}
#[async_trait]
impl Middleware for ThrottleMiddleware {
async fn before(&self, chat_id: ChatId, _user: &UserInfo, _update: &IncomingUpdate) -> bool {
let now = std::time::Instant::now();
let mut entry = self.counter.entry(chat_id).or_insert((now, 0));
if now.duration_since(entry.0).as_secs() >= 1 {
*entry = (now, 1);
true
} else {
entry.1 += 1;
if entry.1 > self.max_per_second {
tracing::warn!(chat_id = chat_id.0, "throttled");
false
} else {
true
}
}
}
async fn after(
&self,
_chat_id: ChatId,
_user: &UserInfo,
_update: &IncomingUpdate,
_result: &HandlerResult,
) {
static COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
let n = COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if n % 100 == 0 {
let now = std::time::Instant::now();
self.counter
.retain(|_, (ts, _)| now.duration_since(*ts).as_secs() < 60);
}
}
}
/// Shared counters describing the traffic that passed through the middleware chain.
///
/// All counters are plain atomics/concurrent maps, so the collector can be read
/// while it is being updated.
pub struct AnalyticsMiddleware {
    /// Number of updates seen, of any kind.
    pub total_updates: std::sync::atomic::AtomicU64,
    /// Number of message-like updates seen (text messages, photos, documents).
    pub total_messages: std::sync::atomic::AtomicU64,
    /// Number of callback-query updates seen.
    pub total_callbacks: std::sync::atomic::AtomicU64,
    /// Distinct user ids seen so far; used as a set (values are `()`).
    pub unique_users: dashmap::DashMap<UserId, ()>,
}
impl AnalyticsMiddleware {
pub fn new() -> Arc<Self> {
Arc::new(Self {
total_updates: AtomicU64::new(0),
total_messages: AtomicU64::new(0),
total_callbacks: AtomicU64::new(0),
unique_users: dashmap::DashMap::new(),
})
}
pub fn stats(&self) -> (u64, u64, u64, usize) {
(
self.total_updates.load(Ordering::Relaxed),
self.total_messages.load(Ordering::Relaxed),
self.total_callbacks.load(Ordering::Relaxed),
self.unique_users.len(),
)
}
}
impl Default for AnalyticsMiddleware {
    /// Returns a collector with every counter at zero and no users recorded.
    fn default() -> Self {
        let unique_users = dashmap::DashMap::new();
        Self {
            total_updates: AtomicU64::default(),
            total_messages: AtomicU64::default(),
            total_callbacks: AtomicU64::default(),
            unique_users,
        }
    }
}
#[async_trait]
impl Middleware for Arc<AnalyticsMiddleware> {
    /// Records the update in the shared counters and always lets it through.
    ///
    /// Message, photo and document updates count as messages. Callback queries
    /// count as callbacks. Any other kind only bumps the total.
    async fn before(&self, _chat_id: ChatId, user: &UserInfo, update: &IncomingUpdate) -> bool {
        let stats: &AnalyticsMiddleware = self;
        stats.total_updates.fetch_add(1, Ordering::Relaxed);
        stats.unique_users.insert(user.id, ());

        let per_kind = match &update.kind {
            UpdateKind::Message { .. } | UpdateKind::Photo { .. } | UpdateKind::Document { .. } => {
                Some(&stats.total_messages)
            }
            UpdateKind::CallbackQuery { .. } => Some(&stats.total_callbacks),
            _ => None,
        };
        if let Some(counter) = per_kind {
            counter.fetch_add(1, Ordering::Relaxed);
        }
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal user with the given id.
    fn test_user(id: u64) -> UserInfo {
        UserInfo {
            id: UserId(id),
            first_name: "Test".into(),
            last_name: None,
            username: None,
            language_code: None,
        }
    }

    /// Builds a plain text message in `chat_id`, sent by a user with the same numeric id.
    fn test_update(chat_id: i64) -> IncomingUpdate {
        IncomingUpdate {
            chat_id: ChatId(chat_id),
            user: test_user(chat_id as u64),
            message_id: None,
            kind: UpdateKind::Message {
                text: Some("hello".into()),
            },
        }
    }

    #[tokio::test]
    async fn logging_never_blocks() {
        let user = test_user(123);
        let update = test_update(123);
        assert!(LoggingMiddleware.before(ChatId(123), &user, &update).await);
    }

    #[tokio::test]
    async fn auth_allows_authorized_user() {
        let auth = AuthMiddleware::new(vec![123]);
        let user = test_user(123);
        let update = test_update(123);
        assert!(auth.before(ChatId(123), &user, &update).await);
    }

    #[tokio::test]
    async fn auth_blocks_unauthorized_user() {
        let auth = AuthMiddleware::new(vec![999]);
        let user = test_user(123);
        let update = test_update(123);
        assert!(!auth.before(ChatId(123), &user, &update).await);
    }

    #[tokio::test]
    async fn auth_with_empty_allow_list_blocks_everyone() {
        let auth = AuthMiddleware::new(Vec::<u64>::new());
        let user = test_user(123);
        let update = test_update(123);
        assert!(!auth.before(ChatId(123), &user, &update).await);
    }

    #[tokio::test]
    async fn throttle_allows_first_request() {
        let throttle = ThrottleMiddleware::new(10);
        let user = test_user(123);
        let update = test_update(123);
        assert!(throttle.before(ChatId(123), &user, &update).await);
    }

    #[tokio::test]
    async fn throttle_blocks_excess_requests() {
        let throttle = ThrottleMiddleware::new(2);
        let user = test_user(123);
        let update = test_update(123);
        assert!(throttle.before(ChatId(123), &user, &update).await);
        assert!(throttle.before(ChatId(123), &user, &update).await);
        assert!(!throttle.before(ChatId(123), &user, &update).await);
    }

    #[tokio::test]
    async fn throttle_limits_are_per_chat() {
        let throttle = ThrottleMiddleware::new(1);
        assert!(throttle.before(ChatId(1), &test_user(1), &test_update(1)).await);
        assert!(!throttle.before(ChatId(1), &test_user(1), &test_update(1)).await);
        // A different chat has its own, untouched window.
        assert!(throttle.before(ChatId(2), &test_user(2), &test_update(2)).await);
    }

    #[tokio::test]
    async fn analytics_counts() {
        let analytics = AnalyticsMiddleware::new();
        let user = test_user(123);
        let update = test_update(123);
        analytics.before(ChatId(123), &user, &update).await;
        analytics
            .before(ChatId(456), &test_user(456), &test_update(456))
            .await;
        // A repeat sender bumps the totals but not the unique-user count.
        analytics.before(ChatId(123), &user, &update).await;

        let (updates, messages, callbacks, unique_users) = analytics.stats();
        assert_eq!(updates, 3);
        assert_eq!(messages, 3);
        assert_eq!(callbacks, 0);
        assert_eq!(unique_users, 2);
    }
}