use std::collections::HashMap;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use tokio::sync::{Mutex, RwLock};
use tokio_util::sync::CancellationToken;
/// Lifecycle state of a tracked operation.
///
/// Every entry starts out `Running`. The registry only ever moves an entry *out of*
/// `Running`, into one of the other three variants, and never back.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OperationStatus {
    /// Registered and not yet finished.
    Running,
    /// Finished successfully (set by `OperationRegistry::complete`).
    Completed,
    /// Finished with an error (set by `OperationRegistry::fail`).
    Failed,
    /// Stopped on request (set by `OperationRegistry::cancel`).
    Cancelled,
}
/// Bookkeeping record for a single operation tracked by `OperationRegistry`.
#[derive(Debug)]
pub struct OperationEntry {
    /// Identifier of the request that started the operation, as supplied by the caller.
    pub request_id: String,
    /// Name of the method the operation was started for (the tests use `"parse.start"`).
    pub method: String,
    /// Current lifecycle state; starts as `OperationStatus::Running`.
    pub status: OperationStatus,
    /// Token used to signal cancellation. `OperationRegistry::register` hands a clone
    /// of it back to the caller, so it can be triggered from outside the registry too.
    pub cancel_token: CancellationToken,
    /// Instant at which the entry was constructed.
    pub started_at: std::time::Instant,
}
impl OperationEntry {
    /// Builds a fresh entry in the `Running` state.
    ///
    /// The entry gets a new, not-yet-cancelled token, and `started_at` is set to now.
    pub fn new(request_id: String, method: String) -> Self {
        let status = OperationStatus::Running;
        let cancel_token = CancellationToken::new();
        let started_at = std::time::Instant::now();
        Self {
            request_id,
            method,
            status,
            cancel_token,
            started_at,
        }
    }

    /// Reports whether `status` is still `OperationStatus::Running`.
    pub fn is_running(&self) -> bool {
        matches!(self.status, OperationStatus::Running)
    }

    /// Reports whether the cancellation token has been triggered.
    ///
    /// This looks only at the token, not at `status`. The token can be cancelled
    /// through a clone while `status` still says `Running`.
    pub fn is_cancelled(&self) -> bool {
        self.cancel_token.is_cancelled()
    }

    /// Time elapsed since the entry was created.
    pub fn elapsed(&self) -> std::time::Duration {
        std::time::Instant::elapsed(&self.started_at)
    }
}
/// Tracks operations by generated id and can cap how many run at the same time.
///
/// The derived `Default` yields an empty registry with `max_concurrent == 0`, i.e. no limit.
#[derive(Default)]
pub struct OperationRegistry {
    /// All known operations, keyed by the id returned from `register`.
    ///
    /// Finished entries stay here until they are removed or swept by `cleanup`.
    operations: RwLock<HashMap<String, Arc<Mutex<OperationEntry>>>>,
    /// Maximum number of simultaneously running operations; `0` disables the limit.
    max_concurrent: usize,
    /// Running-operation counter.
    ///
    /// `register` increments it, and it is decremented when the registry retires
    /// an operation that was still running.
    running: AtomicUsize,
}
impl OperationRegistry {
pub fn new() -> Self {
Self {
operations: RwLock::new(HashMap::new()),
max_concurrent: 0,
running: AtomicUsize::new(0),
}
}
pub fn with_max_concurrent(max: usize) -> Self {
Self {
operations: RwLock::new(HashMap::new()),
max_concurrent: max,
running: AtomicUsize::new(0),
}
}
fn generate_op_id() -> String {
uuid::Uuid::new_v4().to_string()
}
pub async fn register(
&self,
request_id: String,
method: String,
) -> Result<(String, CancellationToken), RegistryError> {
if self.max_concurrent > 0 {
let current = self.running.load(Ordering::Relaxed);
if current >= self.max_concurrent {
return Err(RegistryError::ConcurrencyLimitReached);
}
}
let mut ops = self.operations.write().await;
if self.max_concurrent > 0 {
let current = self.running.load(Ordering::Relaxed);
if current >= self.max_concurrent {
return Err(RegistryError::ConcurrencyLimitReached);
}
}
let op_id = Self::generate_op_id();
let entry = OperationEntry::new(request_id, method);
let cancel_token = entry.cancel_token.clone();
ops.insert(op_id.clone(), Arc::new(Mutex::new(entry)));
self.running.fetch_add(1, Ordering::Relaxed);
Ok((op_id, cancel_token))
}
pub async fn get(&self, op_id: &str) -> Option<Arc<Mutex<OperationEntry>>> {
let ops = self.operations.read().await;
ops.get(op_id).cloned()
}
pub async fn cancel(&self, op_id: &str) -> Result<(), RegistryError> {
let ops = self.operations.read().await;
match ops.get(op_id) {
Some(entry) => {
let mut entry = entry.lock().await;
if entry.status == OperationStatus::Running {
entry.status = OperationStatus::Cancelled;
entry.cancel_token.cancel();
self.running.fetch_sub(1, Ordering::Relaxed);
Ok(())
} else {
Err(RegistryError::OperationNotRunning)
}
}
None => Err(RegistryError::OperationNotFound),
}
}
pub async fn complete(&self, op_id: &str) -> Result<(), RegistryError> {
let ops = self.operations.read().await;
match ops.get(op_id) {
Some(entry) => {
let mut entry = entry.lock().await;
if entry.status == OperationStatus::Running {
entry.status = OperationStatus::Completed;
self.running.fetch_sub(1, Ordering::Relaxed);
}
Ok(())
}
None => Err(RegistryError::OperationNotFound),
}
}
pub async fn fail(&self, op_id: &str) -> Result<(), RegistryError> {
let ops = self.operations.read().await;
match ops.get(op_id) {
Some(entry) => {
let mut entry = entry.lock().await;
if entry.status == OperationStatus::Running {
entry.status = OperationStatus::Failed;
self.running.fetch_sub(1, Ordering::Relaxed);
}
Ok(())
}
None => Err(RegistryError::OperationNotFound),
}
}
pub async fn remove(&self, op_id: &str) -> Option<Arc<Mutex<OperationEntry>>> {
let mut ops = self.operations.write().await;
ops.remove(op_id)
}
pub async fn complete_and_remove(&self, op_id: &str) -> Result<(), RegistryError> {
self.complete(op_id).await?;
self.remove(op_id).await;
Ok(())
}
pub async fn fail_and_remove(&self, op_id: &str) -> Result<(), RegistryError> {
self.fail(op_id).await?;
self.remove(op_id).await;
Ok(())
}
pub async fn cancel_and_remove(&self, op_id: &str) -> Result<(), RegistryError> {
self.cancel(op_id).await?;
self.remove(op_id).await;
Ok(())
}
pub async fn running_count(&self) -> usize {
self.running.load(Ordering::Relaxed)
}
pub async fn op_ids(&self) -> Vec<String> {
let ops = self.operations.read().await;
ops.keys().cloned().collect()
}
pub async fn cleanup(&self, older_than: std::time::Duration) {
let mut ops = self.operations.write().await;
let now = std::time::Instant::now();
ops.retain(|_, entry| {
if let Ok(e) = entry.try_lock() {
e.is_running() || now.duration_since(e.started_at) < older_than
} else {
true }
});
}
}
/// Errors returned by `OperationRegistry` methods.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RegistryError {
    /// No operation is registered under the given id.
    OperationNotFound,
    /// The operation exists but is no longer in the `Running` state.
    OperationNotRunning,
    /// Registration was refused because the configured concurrency limit is reached.
    ConcurrencyLimitReached,
}

impl std::fmt::Display for RegistryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::OperationNotFound => "Operation not found",
            Self::OperationNotRunning => "Operation is not running",
            Self::ConcurrencyLimitReached => "Concurrency limit reached",
        };
        f.write_str(message)
    }
}

impl std::error::Error for RegistryError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers a `parse.start` operation for `request_id` and returns the raw result.
    async fn start_parse(
        registry: &OperationRegistry,
        request_id: &str,
    ) -> Result<(String, CancellationToken), RegistryError> {
        registry
            .register(request_id.to_string(), "parse.start".to_string())
            .await
    }

    #[tokio::test]
    async fn test_register_operation() {
        let registry = OperationRegistry::new();
        let (op_id, _token) = start_parse(&registry, "req-1").await.unwrap();
        assert!(!op_id.is_empty());
        assert_eq!(registry.running_count().await, 1);
    }

    #[tokio::test]
    async fn test_cancel_operation() {
        let registry = OperationRegistry::new();
        let (op_id, token) = start_parse(&registry, "req-1").await.unwrap();
        assert!(!token.is_cancelled());
        registry.cancel(&op_id).await.unwrap();
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn test_cancel_unknown_operation() {
        let registry = OperationRegistry::new();
        assert_eq!(
            registry.cancel("unknown-op-id").await,
            Err(RegistryError::OperationNotFound)
        );
    }

    #[tokio::test]
    async fn test_concurrency_limit() {
        let registry = OperationRegistry::with_max_concurrent(2);
        let (first_id, _) = start_parse(&registry, "req-1").await.unwrap();
        let _second = start_parse(&registry, "req-2").await.unwrap();

        let rejected = start_parse(&registry, "req-3").await;
        assert!(matches!(
            rejected,
            Err(RegistryError::ConcurrencyLimitReached)
        ));

        registry.complete(&first_id).await.unwrap();
        assert!(start_parse(&registry, "req-3").await.is_ok());
    }

    #[tokio::test]
    async fn test_complete_operation() {
        let registry = OperationRegistry::new();
        let (op_id, _) = start_parse(&registry, "req-1").await.unwrap();
        assert_eq!(registry.running_count().await, 1);
        registry.complete(&op_id).await.unwrap();
        assert_eq!(registry.running_count().await, 0);
    }

    #[tokio::test]
    async fn test_fail_operation() {
        let registry = OperationRegistry::new();
        let (op_id, _) = start_parse(&registry, "req-1").await.unwrap();
        registry.fail(&op_id).await.unwrap();
        let handle = registry.get(&op_id).await.unwrap();
        assert_eq!(handle.lock().await.status, OperationStatus::Failed);
    }

    #[tokio::test]
    async fn test_remove_operation() {
        let registry = OperationRegistry::new();
        let (op_id, _) = start_parse(&registry, "req-1").await.unwrap();
        assert!(registry.get(&op_id).await.is_some());
        registry.remove(&op_id).await;
        assert!(registry.get(&op_id).await.is_none());
    }
}