use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicU64, Ordering};
use chrono::{DateTime, Utc};
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
/// Length of the duplicate-suppression window, in whole seconds.
///
/// A `(session, text)` pair recorded less than this long ago is treated as a
/// duplicate; older records are dropped from the dedup table.
const DEDUP_WINDOW_SECS: i64 = 2 * 60;
/// A message waiting in a session's queue.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct QueuedMessage {
    /// Message body.
    pub text: String,
    /// When the message was queued. A coalesced message produced by
    /// `TaskRegistry::pop_all_queued_messages` keeps the timestamp of the
    /// first (oldest) fragment.
    pub queued_at: DateTime<Utc>,
}
/// Lifecycle state of a task tracked by `TaskRegistry`.
///
/// `Running` is the only state treated as active; cancellation and cleanup
/// consider every other variant to be finished.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TaskStatus {
    /// The task has been registered and has not finished.
    Running,
    /// The task finished successfully.
    Completed,
    /// The task finished with an error; carries the error message.
    Failed(String),
    /// The task was cancelled.
    Cancelled,
}
impl std::fmt::Display for TaskStatus {
    /// Renders a short human-readable label; a failure is shown as
    /// `Failed: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            TaskStatus::Failed(reason) => return write!(f, "Failed: {}", reason),
            TaskStatus::Running => "Running",
            TaskStatus::Completed => "Completed",
            TaskStatus::Cancelled => "Cancelled",
        };
        f.write_str(label)
    }
}
/// Public snapshot of a task's metadata, as returned by `list_for_session`.
#[derive(Clone, Debug)]
pub struct TaskEntry {
    /// Id allocated by `TaskRegistry::register` from a monotonically increasing counter.
    pub id: u64,
    /// Session the task was registered for.
    pub session_id: String,
    /// Human-readable description supplied at registration.
    pub description: String,
    /// Current lifecycle state.
    pub status: TaskStatus,
    /// When the task was registered.
    pub started_at: DateTime<Utc>,
    /// Set when the task is completed, failed or cancelled; `None` while it is running.
    pub finished_at: Option<DateTime<Utc>>,
}
/// Internal bookkeeping for one task: its public entry plus the tokens used to cancel it.
struct TaskHandle {
    /// Metadata exposed to callers.
    entry: TaskEntry,
    /// Token whose clone is returned by `register`; cancelled when the task is cancelled.
    cancel_token: CancellationToken,
    /// Optional extra token attached via `set_typing_cancel` and cancelled alongside
    /// `cancel_token`. Presumably stops a typing-indicator loop — TODO confirm against callers.
    typing_cancel: Option<CancellationToken>,
}
/// In-memory registry of tasks plus per-session message queues with deduplication.
///
/// Each piece of state sits behind its own `tokio::sync::RwLock`.
pub struct TaskRegistry {
    /// Every tracked task keyed by id; finished entries are pruned by `cleanup_locked`.
    tasks: RwLock<HashMap<u64, TaskHandle>>,
    /// Next task id to hand out; starts at 1 and is bumped by every `register` call.
    next_id: AtomicU64,
    /// How many finished (non-`Running`) tasks are kept before the oldest are evicted.
    max_completed: usize,
    /// Pending messages per session id, in arrival order.
    queues: RwLock<HashMap<String, VecDeque<QueuedMessage>>>,
    /// Dedup table: (session id, hash of message text) -> time the pair was recorded.
    recently_seen: RwLock<HashMap<(String, u64), DateTime<Utc>>>,
}
impl TaskRegistry {
pub fn new(max_completed: usize) -> Self {
Self {
tasks: RwLock::new(HashMap::new()),
next_id: AtomicU64::new(1),
max_completed,
queues: RwLock::new(HashMap::new()),
recently_seen: RwLock::new(HashMap::new()),
}
}
fn text_hash(text: &str) -> u64 {
use std::hash::{Hash, Hasher};
let mut hasher = std::collections::hash_map::DefaultHasher::new();
text.hash(&mut hasher);
hasher.finish()
}
pub async fn register(&self, session_id: &str, description: &str) -> (u64, CancellationToken) {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let cancel_token = CancellationToken::new();
let handle = TaskHandle {
entry: TaskEntry {
id,
session_id: session_id.to_string(),
description: description.to_string(),
status: TaskStatus::Running,
started_at: Utc::now(),
finished_at: None,
},
cancel_token: cancel_token.clone(),
typing_cancel: None,
};
let mut tasks = self.tasks.write().await;
tasks.insert(id, handle);
(id, cancel_token)
}
pub async fn complete(&self, task_id: u64) {
let mut tasks = self.tasks.write().await;
if let Some(handle) = tasks.get_mut(&task_id) {
handle.entry.status = TaskStatus::Completed;
handle.entry.finished_at = Some(Utc::now());
}
Self::cleanup_locked(&mut tasks, self.max_completed);
}
pub async fn fail(&self, task_id: u64, error: &str) {
let mut tasks = self.tasks.write().await;
if let Some(handle) = tasks.get_mut(&task_id) {
handle.entry.status = TaskStatus::Failed(error.to_string());
handle.entry.finished_at = Some(Utc::now());
}
Self::cleanup_locked(&mut tasks, self.max_completed);
}
pub async fn set_typing_cancel(&self, task_id: u64, token: CancellationToken) {
let mut tasks = self.tasks.write().await;
if let Some(handle) = tasks.get_mut(&task_id) {
handle.typing_cancel = Some(token);
}
}
pub async fn cancel(&self, task_id: u64) -> bool {
let mut tasks = self.tasks.write().await;
if let Some(handle) = tasks.get_mut(&task_id) {
if matches!(handle.entry.status, TaskStatus::Running) {
handle.cancel_token.cancel();
if let Some(ref typing) = handle.typing_cancel {
typing.cancel();
}
handle.entry.status = TaskStatus::Cancelled;
handle.entry.finished_at = Some(Utc::now());
return true;
}
}
false
}
pub async fn cancel_running_for_session(&self, session_id: &str) -> Vec<(u64, String)> {
let mut tasks = self.tasks.write().await;
let mut cancelled = Vec::new();
for (id, handle) in tasks.iter_mut() {
if handle.entry.session_id == session_id
&& matches!(handle.entry.status, TaskStatus::Running)
{
handle.cancel_token.cancel();
if let Some(ref typing) = handle.typing_cancel {
typing.cancel();
}
handle.entry.status = TaskStatus::Cancelled;
handle.entry.finished_at = Some(Utc::now());
cancelled.push((*id, handle.entry.description.clone()));
}
}
cancelled
}
pub async fn list_for_session(&self, session_id: &str) -> Vec<TaskEntry> {
let tasks = self.tasks.read().await;
let mut entries: Vec<TaskEntry> = tasks
.values()
.filter(|h| h.entry.session_id == session_id)
.map(|h| h.entry.clone())
.collect();
entries.sort_by_key(|e| e.id);
entries
}
fn cleanup_locked(tasks: &mut HashMap<u64, TaskHandle>, max_completed: usize) {
let mut finished: Vec<u64> = tasks
.iter()
.filter(|(_, h)| !matches!(h.entry.status, TaskStatus::Running))
.map(|(&id, _)| id)
.collect();
if finished.len() <= max_completed {
return;
}
finished.sort();
let to_remove = finished.len() - max_completed;
for &id in finished.iter().take(to_remove) {
tasks.remove(&id);
}
}
pub async fn has_running_task(&self, session_id: &str) -> bool {
let tasks = self.tasks.read().await;
tasks.values().any(|h| {
h.entry.session_id == session_id && matches!(h.entry.status, TaskStatus::Running)
})
}
pub async fn get_running_task_description(&self, session_id: &str) -> Option<String> {
let tasks = self.tasks.read().await;
tasks
.values()
.find(|h| {
h.entry.session_id == session_id && matches!(h.entry.status, TaskStatus::Running)
})
.map(|h| h.entry.description.clone())
}
pub async fn mark_seen(&self, session_id: &str, text: &str) -> bool {
let now = Utc::now();
let hash = Self::text_hash(text);
let key = (session_id.to_string(), hash);
let mut seen = self.recently_seen.write().await;
seen.retain(|_, ts| (now - *ts).num_seconds() < DEDUP_WINDOW_SECS);
use std::collections::hash_map::Entry;
match seen.entry(key) {
Entry::Occupied(_) => false,
Entry::Vacant(e) => {
e.insert(now);
true
}
}
}
pub async fn queue_message(&self, session_id: &str, text: &str) -> Option<usize> {
let now = Utc::now();
let hash = Self::text_hash(text);
let key = (session_id.to_string(), hash);
{
let mut seen = self.recently_seen.write().await;
seen.retain(|_, ts| (now - *ts).num_seconds() < DEDUP_WINDOW_SECS);
if let Some(first_seen) = seen.get(&key) {
if (now - *first_seen).num_seconds() < DEDUP_WINDOW_SECS {
return None; }
}
seen.insert(key, now);
}
let mut queues = self.queues.write().await;
let queue = queues.entry(session_id.to_string()).or_default();
queue.push_back(QueuedMessage {
text: text.to_string(),
queued_at: now,
});
Some(queue.len())
}
#[allow(dead_code)] pub async fn pop_queued_message(&self, session_id: &str) -> Option<QueuedMessage> {
let mut queues = self.queues.write().await;
queues.get_mut(session_id).and_then(|q| q.pop_front())
}
pub async fn pop_all_queued_messages(&self, session_id: &str) -> Option<QueuedMessage> {
let mut queues = self.queues.write().await;
let queue = queues.get_mut(session_id)?;
if queue.is_empty() {
return None;
}
let first = queue.pop_front().unwrap();
if queue.is_empty() {
return Some(first);
}
let mut combined = first.text;
while let Some(msg) = queue.pop_front() {
combined.push('\n');
combined.push_str(&msg.text);
}
Some(QueuedMessage {
text: combined,
queued_at: first.queued_at,
})
}
pub async fn queue_len(&self, session_id: &str) -> usize {
let queues = self.queues.read().await;
queues.get(session_id).map(|q| q.len()).unwrap_or(0)
}
pub async fn clear_queue(&self, session_id: &str) {
let mut queues = self.queues.write().await;
if let Some(queue) = queues.get_mut(session_id) {
queue.clear();
}
}
}
#[cfg(test)]
mod tests {
    //! Behavioural tests for queue deduplication, coalescing and the task lifecycle.
    use super::*;

    #[tokio::test]
    async fn test_queue_deduplicates_identical_messages() {
        let registry = TaskRegistry::new(10);
        let session = "test-session";
        let result = registry.queue_message(session, "hello world").await;
        assert_eq!(result, Some(1));
        let result = registry.queue_message(session, "hello world").await;
        assert_eq!(result, None);
        let result = registry.queue_message(session, "different message").await;
        assert_eq!(result, Some(2));
        assert_eq!(registry.queue_len(session).await, 2);
    }

    #[tokio::test]
    async fn test_queue_allows_same_message_different_sessions() {
        let registry = TaskRegistry::new(10);
        let result = registry.queue_message("session-a", "hello").await;
        assert_eq!(result, Some(1));
        let result = registry.queue_message("session-b", "hello").await;
        assert_eq!(result, Some(1));
    }

    #[tokio::test]
    async fn test_queue_deduplicates_after_pop() {
        let registry = TaskRegistry::new(10);
        let session = "test-session";
        let result = registry.queue_message(session, "hello").await;
        assert_eq!(result, Some(1));
        let popped = registry.pop_queued_message(session).await;
        assert!(popped.is_some());
        assert_eq!(popped.unwrap().text, "hello");
        // The dedup window outlives the queue entry.
        let result = registry.queue_message(session, "hello").await;
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_mark_seen_prevents_duplicates() {
        let registry = TaskRegistry::new(10);
        let session = "test-session";
        assert!(registry.mark_seen(session, "hello world").await);
        assert!(!registry.mark_seen(session, "hello world").await);
        assert!(registry.mark_seen(session, "different text").await);
        assert!(registry.mark_seen("other-session", "hello world").await);
    }

    #[tokio::test]
    async fn test_mark_seen_blocks_subsequent_queue() {
        let registry = TaskRegistry::new(10);
        let session = "test-session";
        assert!(registry.mark_seen(session, "hello").await);
        let result = registry.queue_message(session, "hello").await;
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_pop_all_coalesces_fragments() {
        let registry = TaskRegistry::new(10);
        let session = "test-session";
        registry.queue_message(session, "Part 1: Hello").await;
        registry.queue_message(session, "Part 2: World").await;
        registry.queue_message(session, "Part 3: How are").await;
        registry.queue_message(session, "Part 4: you?").await;
        assert_eq!(registry.queue_len(session).await, 4);
        let coalesced = registry.pop_all_queued_messages(session).await;
        assert!(coalesced.is_some());
        let msg = coalesced.unwrap();
        assert_eq!(
            msg.text,
            "Part 1: Hello\nPart 2: World\nPart 3: How are\nPart 4: you?"
        );
        assert_eq!(registry.queue_len(session).await, 0);
        assert!(registry.pop_all_queued_messages(session).await.is_none());
    }

    #[tokio::test]
    async fn test_pop_all_single_message() {
        let registry = TaskRegistry::new(10);
        let session = "test-session";
        registry.queue_message(session, "only one").await;
        let result = registry.pop_all_queued_messages(session).await;
        assert!(result.is_some());
        assert_eq!(result.unwrap().text, "only one");
    }

    #[tokio::test]
    async fn test_pop_on_unknown_session_is_none() {
        let registry = TaskRegistry::new(10);
        assert!(registry.pop_all_queued_messages("nobody").await.is_none());
        assert!(registry.pop_queued_message("nobody").await.is_none());
        assert_eq!(registry.queue_len("nobody").await, 0);
    }

    #[tokio::test]
    async fn test_clear_queue_empties_session() {
        let registry = TaskRegistry::new(10);
        let session = "test-session";
        registry.queue_message(session, "first").await;
        registry.queue_message(session, "second").await;
        registry.clear_queue(session).await;
        assert_eq!(registry.queue_len(session).await, 0);
        assert!(registry.pop_all_queued_messages(session).await.is_none());
    }

    #[tokio::test]
    async fn test_cancel_marks_task_cancelled_once() {
        let registry = TaskRegistry::new(10);
        let (id, token) = registry.register("s", "work").await;
        assert!(registry.has_running_task("s").await);
        assert!(registry.cancel(id).await);
        assert!(token.is_cancelled());
        assert!(!registry.cancel(id).await);
        assert!(!registry.has_running_task("s").await);
        let entries = registry.list_for_session("s").await;
        assert_eq!(entries.len(), 1);
        assert!(matches!(entries[0].status, TaskStatus::Cancelled));
        assert!(entries[0].finished_at.is_some());
    }

    #[tokio::test]
    async fn test_cancel_running_for_session_only_touches_that_session() {
        let registry = TaskRegistry::new(10);
        let (_, first) = registry.register("s", "one").await;
        let (_, second) = registry.register("s", "two").await;
        let (_, other) = registry.register("other", "three").await;
        let cancelled = registry.cancel_running_for_session("s").await;
        assert_eq!(cancelled.len(), 2);
        assert!(first.is_cancelled());
        assert!(second.is_cancelled());
        assert!(!other.is_cancelled());
        assert!(registry.has_running_task("other").await);
        assert!(!registry.has_running_task("s").await);
    }

    #[tokio::test]
    async fn test_cleanup_keeps_newest_finished_tasks() {
        let registry = TaskRegistry::new(1);
        let (first, _) = registry.register("s", "a").await;
        let (second, _) = registry.register("s", "b").await;
        registry.complete(first).await;
        registry.fail(second, "boom").await;
        let entries = registry.list_for_session("s").await;
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].id, second);
        assert!(matches!(&entries[0].status, TaskStatus::Failed(e) if e.as_str() == "boom"));
    }

    #[tokio::test]
    async fn test_running_task_description() {
        let registry = TaskRegistry::new(10);
        assert_eq!(registry.get_running_task_description("s").await, None);
        let (id, _) = registry.register("s", "indexing").await;
        assert_eq!(
            registry.get_running_task_description("s").await,
            Some("indexing".to_string())
        );
        registry.complete(id).await;
        assert_eq!(registry.get_running_task_description("s").await, None);
    }
}