use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use crate::domain::error::{GraphError, Result, StygianError};
use crate::ports::{ScrapingService, ServiceInput, ServiceOutput};
/// A single job queued for execution by the pool: the service to run, its
/// input, and a one-shot channel on which the executing worker sends the
/// result back to the submitter.
struct WorkItem {
    // Shared handle so the caller keeps ownership while the job is queued.
    service: Arc<dyn ScrapingService>,
    input: ServiceInput,
    // Dropped without sending if the pool shuts down before this item runs,
    // which surfaces as an error on the submitter's side.
    reply: tokio::sync::oneshot::Sender<Result<ServiceOutput>>,
}
/// A fixed-size pool of async worker tasks that pull [`WorkItem`]s from a
/// bounded queue and execute them, providing backpressure via the queue depth.
pub struct WorkerPool {
    // Producer side of the bounded work queue; dropped during shutdown so the
    // channel closes and idle workers see `None` from `recv`.
    tx: mpsc::Sender<WorkItem>,
    // Cancelled during shutdown so workers stop waiting for new items.
    cancel: CancellationToken,
    // Handles of the spawned worker tasks; drained (joined) during shutdown.
    workers: Arc<Mutex<JoinSet<()>>>,
}
impl WorkerPool {
#[allow(clippy::significant_drop_tightening)]
pub fn new(concurrency: usize, queue_depth: usize) -> Self {
let (tx, rx) = mpsc::channel::<WorkItem>(queue_depth);
let rx = Arc::new(Mutex::new(rx));
let cancel = CancellationToken::new();
let mut join_set = JoinSet::new();
for _ in 0..concurrency {
let rx_clone = Arc::clone(&rx);
let cancel_clone = cancel.clone();
join_set.spawn(async move {
loop {
if cancel_clone.is_cancelled() {
break;
}
let item = {
#[allow(clippy::significant_drop_tightening)]
let mut guard = rx_clone.lock().await;
tokio::select! {
biased;
() = cancel_clone.cancelled() => break,
item = guard.recv() => {
match item {
Some(item) => item,
None => break, }
}
}
};
let result = item.service.execute(item.input).await;
let _ = item.reply.send(result);
}
});
}
Self {
tx,
cancel,
workers: Arc::new(Mutex::new(join_set)),
}
}
pub async fn submit(
&self,
service: Arc<dyn ScrapingService>,
input: ServiceInput,
) -> Result<ServiceOutput> {
let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
self.tx
.send(WorkItem {
service,
input,
reply: reply_tx,
})
.await
.map_err(|_| {
StygianError::Graph(GraphError::ExecutionFailed(
"Worker pool is shut down".into(),
))
})?;
reply_rx.await.map_err(|_| {
StygianError::Graph(GraphError::ExecutionFailed(
"Worker task dropped reply channel".into(),
))
})?
}
pub async fn shutdown(self) {
self.cancel.cancel();
drop(self.tx);
let mut workers = self.workers.lock().await;
while workers.join_next().await.is_some() {}
}
#[must_use]
pub fn is_saturated(&self) -> bool {
self.tx.capacity() == 0
}
pub fn available_capacity(&self) -> usize {
self.tx.capacity()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::noop::NoopService;

    /// Builds a `ServiceInput` with empty JSON params for the given URL.
    fn input_for(url: &str) -> ServiceInput {
        ServiceInput {
            url: url.to_string(),
            params: serde_json::json!({}),
        }
    }

    #[tokio::test]
    async fn test_worker_pool_basic_execution() {
        let pool = WorkerPool::new(2, 10);
        let svc: Arc<dyn ScrapingService> = Arc::new(NoopService);
        let result = pool.submit(svc, input_for("https://example.com")).await;
        assert!(result.is_ok());
        pool.shutdown().await;
    }

    #[tokio::test]
    async fn test_worker_pool_concurrent_tasks()
    -> std::result::Result<(), Box<dyn std::error::Error>> {
        let pool = Arc::new(WorkerPool::new(4, 20));
        let svc: Arc<dyn ScrapingService> = Arc::new(NoopService);
        // Spawn ten submitters that race against each other for the pool.
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let pool_clone = Arc::clone(&pool);
                let svc_clone = Arc::clone(&svc);
                tokio::spawn(async move {
                    let input = input_for(&format!("https://example.com/{i}"));
                    pool_clone.submit(svc_clone, input).await
                })
            })
            .collect();
        for handle in handles {
            let result = handle.await?;
            assert!(result.is_ok(), "Task failed: {result:?}");
        }
        // Only the last holder of the Arc performs the shutdown.
        if let Some(p) = Arc::into_inner(pool) {
            p.shutdown().await;
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_worker_pool_backpressure() {
        // A pool of one worker with a single-slot queue starts with exactly
        // one free slot, and a single submission still round-trips cleanly.
        let pool = WorkerPool::new(1, 1);
        assert_eq!(pool.available_capacity(), 1);
        let svc: Arc<dyn ScrapingService> = Arc::new(NoopService);
        let result = pool.submit(svc, input_for("https://example.com")).await;
        assert!(result.is_ok());
        pool.shutdown().await;
    }

    #[tokio::test]
    async fn test_worker_pool_graceful_shutdown() {
        // Shutting down an idle pool must terminate all workers promptly.
        let pool = WorkerPool::new(2, 10);
        pool.shutdown().await;
    }
}