use dashmap::DashMap;
use std::time::Instant;
use tokio::task::JoinHandle;
/// Identifier of a registered prompt task.
///
/// [`PromptManager::register_prompt`] builds it as `"<session_id>-<uuid>"`.
pub type PromptId = String;
/// Bookkeeping for one spawned prompt task stored by a [`PromptManager`].
#[derive(Debug)]
pub struct PromptTask {
    /// Identifier returned from [`PromptManager::register_prompt`].
    pub id: PromptId,
    /// Handle of the spawned task; [`PromptManager::cancel_session_prompt`] awaits it.
    pub handle: JoinHandle<()>,
    /// Token that is triggered to ask the task to stop.
    pub cancel_token: tokio_util::sync::CancellationToken,
    /// Moment the task was registered with the manager.
    pub created_at: Instant,
    /// Session this task belongs to; also the key it is stored under in the manager.
    pub session_id: String,
}
/// Tracks the active prompt task of each session (at most one per session).
#[derive(Debug)]
pub struct PromptManager {
    /// Currently active prompt tasks, keyed by session id.
    active_prompts: DashMap<String, PromptTask>,
}
impl Default for PromptManager {
fn default() -> Self {
Self::new()
}
}
impl PromptManager {
    /// Creates a manager with no tracked prompts.
    pub fn new() -> Self {
        Self {
            active_prompts: DashMap::new(),
        }
    }

    /// Cancels the prompt registered for `session_id` and waits for its task to exit.
    ///
    /// The entry is removed from the manager and its cancellation token is triggered. The
    /// call then waits up to five seconds for the task to finish.
    ///
    /// Returns `true` if a prompt was registered and its task ended within the timeout.
    /// A task that ended by panicking or being aborted also counts as ended. Returns `false`
    /// if no prompt was registered for the session, or if the task did not end in time. In
    /// the timeout case the task is no longer tracked and is left to finish on its own.
    pub async fn cancel_session_prompt(&self, session_id: &str) -> bool {
        use tokio::time::{Duration, timeout};
        const CANCEL_TIMEOUT: Duration = Duration::from_secs(5);

        // Take the entry out of the map before awaiting, so no DashMap guard is held across
        // the await and a later `complete_prompt` for this task is a no-op.
        let Some((_, task)) = self.active_prompts.remove(session_id) else {
            return false;
        };

        tracing::info!(
            session_id = %session_id,
            prompt_id = %task.id,
            "Cancelling previous prompt"
        );
        task.cancel_token.cancel();

        match timeout(CANCEL_TIMEOUT, task.handle).await {
            Ok(Ok(())) => {
                tracing::info!("Previous prompt cancelled gracefully");
                true
            }
            Ok(Err(e)) => {
                // The task panicked or was aborted; either way it is no longer running.
                tracing::warn!(error = ?e, "Previous prompt task failed");
                true
            }
            Err(_) => {
                // The elapsed `timeout` drops the JoinHandle, which detaches (not aborts) the task.
                tracing::warn!(
                    "Previous prompt did not complete in {:?}, continuing anyway",
                    CANCEL_TIMEOUT
                );
                false
            }
        }
    }

    /// Registers a spawned prompt task as the active prompt for `session_id`.
    ///
    /// Returns the new prompt's id, formatted as `"<session_id>-<uuid v4>"`. Pass it to
    /// [`PromptManager::complete_prompt`] when the task finishes.
    ///
    /// Only one prompt is tracked per session. If one is already registered, it is replaced
    /// and its cancellation token is triggered. Without this, the old task would keep running
    /// untracked, because dropping its `JoinHandle` only detaches it. To wait for the old
    /// task to actually exit, call [`PromptManager::cancel_session_prompt`] first.
    pub fn register_prompt(
        &self,
        session_id: String,
        handle: JoinHandle<()>,
        cancel_token: tokio_util::sync::CancellationToken,
    ) -> PromptId {
        let prompt_id = format!("{}-{}", session_id, uuid::Uuid::new_v4());
        let task = PromptTask {
            id: prompt_id.clone(),
            handle,
            cancel_token,
            created_at: Instant::now(),
            session_id: session_id.clone(),
        };

        if let Some(previous) = self.active_prompts.insert(session_id.clone(), task) {
            tracing::warn!(
                session_id = %session_id,
                prompt_id = %previous.id,
                "Replacing an active prompt that was not cancelled; cancelling it now"
            );
            previous.cancel_token.cancel();
        }

        tracing::info!(
            session_id = %session_id,
            prompt_id = %prompt_id,
            "Registered new prompt task"
        );
        prompt_id
    }

    /// Marks the prompt `prompt_id` of `session_id` as finished and stops tracking it.
    ///
    /// The entry is removed only if it still belongs to `prompt_id`. A completion that
    /// arrives after the prompt was replaced or cancelled leaves the map untouched.
    pub fn complete_prompt(&self, session_id: &str, prompt_id: &str) {
        // Check-and-remove under the shard lock. A separate remove + re-insert could overwrite
        // a newer prompt registered concurrently for the same session.
        let removed = self
            .active_prompts
            .remove_if(session_id, |_, task| task.id == prompt_id);

        if removed.is_some() {
            tracing::info!(
                session_id = %session_id,
                prompt_id = %prompt_id,
                "Completed prompt task"
            );
        } else {
            tracing::debug!(
                session_id = %session_id,
                prompt_id = %prompt_id,
                "Prompt no longer active; nothing to complete"
            );
        }
    }

    /// Number of sessions that currently have an active prompt.
    pub fn active_count(&self) -> usize {
        self.active_prompts.len()
    }

    /// Whether `session_id` currently has an active prompt.
    pub fn has_active_prompt(&self, session_id: &str) -> bool {
        self.active_prompts.contains_key(session_id)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep;
    use tokio_util::sync::CancellationToken;

    /// Session key shared by the tests below.
    const SESSION: &str = "test-session";

    #[test]
    fn test_prompt_manager_default() {
        let pm = PromptManager::new();
        assert_eq!(pm.active_count(), 0);
        assert!(!pm.has_active_prompt(SESSION));
    }

    #[tokio::test]
    async fn test_register_prompt() {
        let pm = PromptManager::new();
        let noop = tokio::spawn(async {});

        let id = pm.register_prompt(SESSION.to_string(), noop, CancellationToken::new());
        assert!(id.starts_with("test-session-"));
        assert_eq!(pm.active_count(), 1);
        assert!(pm.has_active_prompt(SESSION));

        pm.complete_prompt(SESSION, &id);
        assert_eq!(pm.active_count(), 0);
    }

    #[tokio::test]
    async fn test_cancel_session_prompt() {
        let pm = PromptManager::new();
        let token = CancellationToken::new();
        let watcher = token.clone();
        let worker = tokio::spawn(async move {
            // Exit as soon as cancellation is requested; the sleep is only a fallback.
            tokio::select! {
                () = watcher.cancelled() => {}
                () = sleep(Duration::from_secs(10)) => {}
            }
        });
        pm.register_prompt(SESSION.to_string(), worker, token);

        assert!(pm.cancel_session_prompt(SESSION).await);
        assert_eq!(pm.active_count(), 0);
    }

    #[tokio::test]
    async fn test_cancel_nonexistent_prompt() {
        let pm = PromptManager::new();
        assert!(!pm.cancel_session_prompt("nonexistent").await);
    }

    #[tokio::test]
    async fn test_complete_prompt_only_if_id_matches() {
        let pm = PromptManager::new();
        let worker = tokio::spawn(async {
            sleep(Duration::from_millis(100)).await;
        });
        let id = pm.register_prompt(SESSION.to_string(), worker, CancellationToken::new());

        // A mismatched id must leave the active prompt in place.
        pm.complete_prompt(SESSION, "wrong-id");
        assert!(pm.has_active_prompt(SESSION));

        pm.complete_prompt(SESSION, &id);
        assert!(!pm.has_active_prompt(SESSION));
    }

    #[tokio::test]
    async fn test_new_prompt_replaces_old() {
        let pm = PromptManager::new();

        let first = tokio::spawn(async {
            sleep(Duration::from_millis(100)).await;
        });
        pm.register_prompt(SESSION.to_string(), first, CancellationToken::new());
        assert_eq!(pm.active_count(), 1);

        let second = tokio::spawn(async {});
        pm.register_prompt(SESSION.to_string(), second, CancellationToken::new());
        // The second registration takes over the session's slot instead of adding an entry.
        assert_eq!(pm.active_count(), 1);
    }
}