claude_code_acp/session/
prompt_manager.rs

1//! Prompt task management
2//!
3//! This module provides the PromptManager for tracking and cancelling
4//! active prompts per session. It ensures that only one prompt runs at
5//! a time per session, and automatically cancels old prompts when new ones arrive.
6
7use dashmap::DashMap;
8use tokio::task::JoinHandle;
9use std::time::Instant;
10
/// Identifier of a registered prompt task.
///
/// `PromptManager::register_prompt` builds these as `<session_id>-<uuid>`.
pub type PromptId = std::string::String;
13
14/// Prompt task wrapper
15///
16/// Contains all the information needed to track and cancel a prompt task.
17#[derive(Debug)]
18pub struct PromptTask {
19    /// Unique identifier for this prompt
20    pub id: PromptId,
21    /// JoinHandle for the task (used to wait for completion)
22    pub handle: JoinHandle<()>,
23    /// Cancellation token (used to signal cancellation)
24    pub cancel_token: tokio_util::sync::CancellationToken,
25    /// When this prompt was created
26    pub created_at: Instant,
27    /// Which session this prompt belongs to
28    pub session_id: String,
29}
30
31/// Prompt manager
32///
33/// Tracks active prompts per session and ensures serialization:
34/// - Only one prompt can run at a time per session
35/// - New prompts automatically cancel old prompts
36/// - Provides timeout protection for cancellation
37#[derive(Debug)]
38pub struct PromptManager {
39    /// Map of session_id -> PromptTask
40    /// Using DashMap for concurrent access without blocking
41    active_prompts: DashMap<String, PromptTask>,
42}
43
44impl Default for PromptManager {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50impl PromptManager {
51    /// Create a new prompt manager
52    pub fn new() -> Self {
53        Self {
54            active_prompts: DashMap::new(),
55        }
56    }
57
58    /// Cancel any active prompt for the given session
59    ///
60    /// This will:
61    /// 1. Send a cancellation signal via the token
62    /// 2. Wait for the task to complete (with 5 second timeout)
63    /// 3. Remove the task from tracking
64    ///
65    /// Returns `true` if an old prompt was cancelled, `false` if there was
66    /// no active prompt for this session.
67    pub async fn cancel_session_prompt(&self, session_id: &str) -> bool {
68        use tokio::time::{timeout, Duration};
69
70        const CANCEL_TIMEOUT: Duration = Duration::from_secs(5);
71
72        // Remove the old prompt task from the map
73        if let Some((_, task)) = self.active_prompts.remove(session_id) {
74            tracing::info!(
75                session_id = %session_id,
76                prompt_id = %task.id,
77                "Cancelling previous prompt"
78            );
79
80            // Send cancellation signal
81            task.cancel_token.cancel();
82
83            // Wait for task to complete (with timeout)
84            let timeout_result = timeout(CANCEL_TIMEOUT, task.handle).await;
85
86            match timeout_result {
87                Ok(Ok(())) => {
88                    tracing::info!("Previous prompt cancelled gracefully");
89                    true
90                }
91                Ok(Err(e)) => {
92                    tracing::warn!(error = ?e, "Previous prompt task failed");
93                    true // Task is done, even if it failed
94                }
95                Err(_) => {
96                    tracing::warn!(
97                        "Previous prompt did not complete in {:?}, continuing anyway",
98                        CANCEL_TIMEOUT
99                    );
100                    false // Task didn't complete in time
101                }
102            }
103        } else {
104            false // No active prompt for this session
105        }
106    }
107
108    /// Register a new prompt task
109    ///
110    /// This should be called after spawning a prompt task.
111    /// The prompt will be tracked and can be cancelled later.
112    pub fn register_prompt(
113        &self,
114        session_id: String,
115        handle: JoinHandle<()>,
116        cancel_token: tokio_util::sync::CancellationToken,
117    ) -> PromptId {
118        // Generate a unique prompt ID
119        let prompt_id = format!("{}-{}", session_id, uuid::Uuid::new_v4());
120
121        let task = PromptTask {
122            id: prompt_id.clone(),
123            handle,
124            cancel_token,
125            created_at: Instant::now(),
126            session_id: session_id.clone(),
127        };
128
129        // Insert into the map (this will replace any existing prompt for this session)
130        self.active_prompts.insert(session_id.clone(), task);
131
132        tracing::info!(
133            session_id = %session_id,
134            prompt_id = %prompt_id,
135            "Registered new prompt task"
136        );
137
138        prompt_id
139    }
140
141    /// Mark a prompt as completed
142    ///
143    /// This should be called when a prompt finishes normally (not cancelled).
144    /// It removes the prompt from tracking if the prompt_id matches.
145    pub fn complete_prompt(&self, session_id: &str, prompt_id: &str) {
146        // Only remove if the prompt ID matches
147        // Use DashMap's try_remove to check and remove atomically
148        if let Some((_, task)) = self.active_prompts.remove(session_id) {
149            if task.id != prompt_id {
150                // ID doesn't match, put it back
151                self.active_prompts.insert(session_id.to_string(), task);
152                return;
153            }
154        }
155
156        tracing::info!(
157            session_id = %session_id,
158            prompt_id = %prompt_id,
159            "Completed prompt task"
160        );
161    }
162
163    /// Get the number of active prompts
164    pub fn active_count(&self) -> usize {
165        self.active_prompts.len()
166    }
167
168    /// Check if a session has an active prompt
169    pub fn has_active_prompt(&self, session_id: &str) -> bool {
170        self.active_prompts.contains_key(session_id)
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep;
    use tokio_util::sync::CancellationToken;

    const SESSION: &str = "test-session";

    /// Spawn a task that returns immediately.
    fn spawn_noop() -> tokio::task::JoinHandle<()> {
        tokio::spawn(async {})
    }

    /// Spawn a task that sleeps for `delay` and then returns.
    fn spawn_sleeping(delay: Duration) -> tokio::task::JoinHandle<()> {
        tokio::spawn(async move { sleep(delay).await })
    }

    /// Spawn a task that returns once `token` is cancelled. A 10-second
    /// fallback keeps a broken test from hanging forever.
    fn spawn_until_cancelled(token: CancellationToken) -> tokio::task::JoinHandle<()> {
        tokio::spawn(async move {
            tokio::select! {
                _ = token.cancelled() => {}
                _ = sleep(Duration::from_secs(10)) => {}
            }
        })
    }

    #[test]
    fn test_prompt_manager_default() {
        let manager = PromptManager::default();
        assert_eq!(manager.active_count(), 0);
        assert!(!manager.has_active_prompt(SESSION));
    }

    #[tokio::test]
    async fn test_register_prompt() {
        let manager = PromptManager::new();

        let prompt_id =
            manager.register_prompt(SESSION.to_string(), spawn_noop(), CancellationToken::new());

        assert!(prompt_id.starts_with("test-session-"));
        assert_eq!(manager.active_count(), 1);
        assert!(manager.has_active_prompt(SESSION));

        manager.complete_prompt(SESSION, &prompt_id);
        assert_eq!(manager.active_count(), 0);
    }

    #[tokio::test]
    async fn test_cancel_session_prompt() {
        let manager = PromptManager::new();
        let token = CancellationToken::new();

        manager.register_prompt(
            SESSION.to_string(),
            spawn_until_cancelled(token.clone()),
            token.clone(),
        );

        assert!(manager.cancel_session_prompt(SESSION).await);
        assert!(token.is_cancelled());
        assert_eq!(manager.active_count(), 0);
        assert!(!manager.has_active_prompt(SESSION));
    }

    #[tokio::test]
    async fn test_cancel_nonexistent_prompt() {
        let manager = PromptManager::new();
        assert!(!manager.cancel_session_prompt("nonexistent").await);
    }

    #[tokio::test]
    async fn test_complete_prompt_only_if_id_matches() {
        let manager = PromptManager::new();

        let prompt_id = manager.register_prompt(
            SESSION.to_string(),
            spawn_sleeping(Duration::from_millis(100)),
            CancellationToken::new(),
        );

        // A wrong ID must leave the prompt tracked.
        manager.complete_prompt(SESSION, "wrong-id");
        assert!(manager.has_active_prompt(SESSION));

        // The correct ID removes it.
        manager.complete_prompt(SESSION, &prompt_id);
        assert!(!manager.has_active_prompt(SESSION));
    }

    #[tokio::test]
    async fn test_complete_unknown_session_is_noop() {
        let manager = PromptManager::new();
        manager.complete_prompt("nonexistent", "nonexistent-id");
        assert_eq!(manager.active_count(), 0);
    }

    #[tokio::test]
    async fn test_new_prompt_replaces_old() {
        let manager = PromptManager::new();

        let first_id = manager.register_prompt(
            SESSION.to_string(),
            spawn_sleeping(Duration::from_millis(100)),
            CancellationToken::new(),
        );
        assert_eq!(manager.active_count(), 1);

        let second_id =
            manager.register_prompt(SESSION.to_string(), spawn_noop(), CancellationToken::new());

        // Still only one active prompt, and it is the new one.
        assert_ne!(first_id, second_id);
        assert_eq!(manager.active_count(), 1);

        // A late completion from the replaced prompt must not untrack the new one.
        manager.complete_prompt(SESSION, &first_id);
        assert!(manager.has_active_prompt(SESSION));

        manager.complete_prompt(SESSION, &second_id);
        assert!(!manager.has_active_prompt(SESSION));
    }

    #[tokio::test]
    async fn test_sessions_are_independent() {
        let manager = PromptManager::new();
        let token_a = CancellationToken::new();
        let token_b = CancellationToken::new();

        manager.register_prompt(
            "session-a".to_string(),
            spawn_until_cancelled(token_a.clone()),
            token_a.clone(),
        );
        manager.register_prompt(
            "session-b".to_string(),
            spawn_until_cancelled(token_b.clone()),
            token_b.clone(),
        );
        assert_eq!(manager.active_count(), 2);

        // Cancelling one session leaves the other untouched.
        assert!(manager.cancel_session_prompt("session-a").await);
        assert!(token_a.is_cancelled());
        assert!(!token_b.is_cancelled());
        assert!(!manager.has_active_prompt("session-a"));
        assert!(manager.has_active_prompt("session-b"));

        assert!(manager.cancel_session_prompt("session-b").await);
        assert_eq!(manager.active_count(), 0);
    }
}
310}