use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::Semaphore;
use rust_tg_bot_raw::types::update::Update;
/// Errors produced by the update-processor types in this module.
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum UpdateProcessorError {
    /// The requested concurrency limit was zero.
    #[error("`max_concurrent_updates` must be a positive integer")]
    InvalidConcurrency,
    /// Wraps an arbitrary boxed error; its message is included in this variant's `Display`.
    #[error("Handler error: {0}")]
    Handler(Box<dyn std::error::Error + Sync + Send>),
}
/// Strategy for driving the handler future attached to a single [`Update`].
///
/// `BaseUpdateProcessor` acquires a concurrency slot before it calls
/// [`do_process_update`](UpdateProcessor::do_process_update), so implementations do not
/// have to enforce a limit themselves.
#[async_trait::async_trait]
pub trait UpdateProcessor: Send + Sync {
    /// Setup hook. The default implementation does nothing.
    async fn initialize(&self) {}

    /// Processes one update, presumably by awaiting `coroutine` (as `SimpleUpdateProcessor` does).
    ///
    /// `update` is the update being handled and `coroutine` is the work scheduled for it.
    async fn do_process_update(
        &self,
        update: Arc<Update>,
        coroutine: Pin<Box<dyn Future<Output = ()> + Send>>,
    );

    /// Teardown hook. The default implementation does nothing.
    async fn shutdown(&self) {}
}
/// Wraps an [`UpdateProcessor`] and limits how many updates it processes at the same time.
pub struct BaseUpdateProcessor {
/// Processor that runs each update's future once a concurrency slot is available.
inner: Box<dyn UpdateProcessor>,
/// Created with `max_concurrent_updates` permits; one permit is held per update in progress.
semaphore: Arc<Semaphore>,
/// Concurrency limit supplied at construction (always non-zero).
max_concurrent_updates: usize,
/// Number of updates currently inside the inner processor; accessed with relaxed ordering.
active: AtomicUsize,
}
impl std::fmt::Debug for BaseUpdateProcessor {
    /// Shows the configured limit and the current in-flight count.
    ///
    /// The inner processor and the semaphore are not shown.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("BaseUpdateProcessor");
        builder.field("max_concurrent_updates", &self.max_concurrent_updates);
        builder.field("active", &self.active.load(Ordering::Relaxed));
        builder.finish()
    }
}
impl BaseUpdateProcessor {
    /// Wraps `inner` so that at most `max_concurrent_updates` updates are processed at once.
    ///
    /// # Errors
    ///
    /// Returns [`UpdateProcessorError::InvalidConcurrency`] if `max_concurrent_updates` is zero.
    ///
    /// # Panics
    ///
    /// Panics if `max_concurrent_updates` exceeds [`Semaphore::MAX_PERMITS`], because
    /// tokio's `Semaphore::new` panics in that case.
    pub fn new(
        inner: Box<dyn UpdateProcessor>,
        max_concurrent_updates: usize,
    ) -> Result<Self, UpdateProcessorError> {
        if max_concurrent_updates == 0 {
            return Err(UpdateProcessorError::InvalidConcurrency);
        }
        Ok(Self {
            inner,
            semaphore: Arc::new(Semaphore::new(max_concurrent_updates)),
            max_concurrent_updates,
            active: AtomicUsize::new(0),
        })
    }

    /// Returns the concurrency limit this processor was created with.
    #[must_use]
    pub fn max_concurrent_updates(&self) -> usize {
        self.max_concurrent_updates
    }

    /// Returns how many updates are currently being run by the inner processor.
    ///
    /// This is a relaxed snapshot intended for monitoring. Under concurrent use it may already
    /// be out of date when the caller inspects it.
    #[must_use]
    pub fn current_concurrent_updates(&self) -> usize {
        self.active.load(Ordering::Relaxed)
    }

    /// Runs `coroutine` for `update` through the inner processor, first waiting for a free
    /// concurrency slot.
    ///
    /// The in-flight counter is decremented by a drop guard. It therefore stays accurate even
    /// if this future is cancelled (dropped before completion) or the inner processor panics.
    ///
    /// # Panics
    ///
    /// Panics if the internal semaphore is closed; this type never closes it.
    pub async fn process_update(
        &self,
        update: Arc<Update>,
        coroutine: Pin<Box<dyn Future<Output = ()> + Send>>,
    ) {
        /// Decrements the borrowed in-flight counter when dropped.
        struct ActiveGuard<'a>(&'a AtomicUsize);

        impl Drop for ActiveGuard<'_> {
            fn drop(&mut self) {
                self.0.fetch_sub(1, Ordering::Relaxed);
            }
        }

        // Locals drop in reverse declaration order. `_active` therefore decrements the
        // counter before `_permit` releases the slot, mirroring the acquire -> increment order.
        let _permit = self
            .semaphore
            .acquire()
            .await
            .expect("semaphore should not be closed");
        self.active.fetch_add(1, Ordering::Relaxed);
        let _active = ActiveGuard(&self.active);
        self.inner.do_process_update(update, coroutine).await;
    }

    /// Forwards to the inner processor's [`UpdateProcessor::initialize`].
    pub async fn initialize(&self) {
        self.inner.initialize().await;
    }

    /// Forwards to the inner processor's [`UpdateProcessor::shutdown`].
    pub async fn shutdown(&self) {
        self.inner.shutdown().await;
    }
}
/// Pass-through processor that simply awaits each update's handler future.
#[derive(Default, Debug)]
pub struct SimpleUpdateProcessor;
#[async_trait::async_trait]
impl UpdateProcessor for SimpleUpdateProcessor {
    /// Ignores the update itself and drives its handler future to completion.
    async fn do_process_update(
        &self,
        _ignored_update: Arc<Update>,
        coroutine: Pin<Box<dyn Future<Output = ()> + Send>>,
    ) {
        coroutine.await
    }
}
pub fn simple_processor(
max_concurrent_updates: usize,
) -> Result<BaseUpdateProcessor, UpdateProcessorError> {
BaseUpdateProcessor::new(Box::new(SimpleUpdateProcessor), max_concurrent_updates)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal `Update` containing only `update_id`.
    fn dummy_update() -> Update {
        serde_json::from_value(serde_json::json!({"update_id": 0})).unwrap()
    }

    #[tokio::test]
    async fn simple_processor_runs_coroutine() {
        let proc = simple_processor(1).unwrap();
        proc.initialize().await;
        let flag = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let flag2 = Arc::clone(&flag);
        let fut: Pin<Box<dyn Future<Output = ()> + Send>> = Box::pin(async move {
            flag2.store(true, Ordering::Relaxed);
        });
        proc.process_update(Arc::new(dummy_update()), fut).await;
        assert!(flag.load(Ordering::Relaxed));
        // Once processing has finished, nothing should be reported as in flight.
        assert_eq!(proc.current_concurrent_updates(), 0);
        proc.shutdown().await;
    }

    #[test]
    fn zero_concurrency_rejected() {
        assert!(matches!(
            simple_processor(0),
            Err(UpdateProcessorError::InvalidConcurrency)
        ));
    }

    // Nothing here awaits, so no async runtime is needed.
    #[test]
    fn concurrent_updates_tracking() {
        let proc = simple_processor(4).unwrap();
        assert_eq!(proc.max_concurrent_updates(), 4);
        assert_eq!(proc.current_concurrent_updates(), 0);
    }

    #[tokio::test]
    async fn concurrent_processing_bounded() {
        let proc = Arc::new(simple_processor(2).unwrap());
        let counter = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let p = Arc::clone(&proc);
                let c = Arc::clone(&counter);
                let m = Arc::clone(&max_seen);
                tokio::spawn(async move {
                    // `c` and `m` are already owned by this task; move them straight in.
                    let fut: Pin<Box<dyn Future<Output = ()> + Send>> = Box::pin(async move {
                        let current = c.fetch_add(1, Ordering::SeqCst) + 1;
                        m.fetch_max(current, Ordering::SeqCst);
                        tokio::task::yield_now().await;
                        c.fetch_sub(1, Ordering::SeqCst);
                    });
                    p.process_update(Arc::new(dummy_update()), fut).await;
                })
            })
            .collect();
        for h in handles {
            h.await.unwrap();
        }
        let peak = max_seen.load(Ordering::SeqCst);
        assert!((1..=2).contains(&peak), "peak concurrency {peak} outside 1..=2");
        assert_eq!(proc.current_concurrent_updates(), 0);
    }
}