claude_code_acp/session/
prompt_manager.rs1use dashmap::DashMap;
8use tokio::task::JoinHandle;
9use std::time::Instant;
10
/// Identifier assigned to a registered prompt task.
///
/// Generated by `PromptManager::register_prompt` as `"{session_id}-{uuid}"`.
pub type PromptId = String;
13
/// A spawned prompt task tracked by `PromptManager`.
#[derive(Debug)]
pub struct PromptTask {
    /// Unique id of this prompt (see `PromptId`).
    pub id: PromptId,
    /// Join handle of the spawned task; awaited when the prompt is cancelled.
    pub handle: JoinHandle<()>,
    /// Token cancelled to signal the task to stop; the task must observe it itself.
    pub cancel_token: tokio_util::sync::CancellationToken,
    /// Moment the task was registered with the manager.
    pub created_at: Instant,
    /// Session this prompt belongs to (also its key in the manager's map).
    pub session_id: String,
}
30
/// Tracks at most one active prompt task per session.
///
/// Backed by a concurrent `DashMap`, so all operations take `&self`.
#[derive(Debug)]
pub struct PromptManager {
    /// Active prompt tasks keyed by session id.
    active_prompts: DashMap<String, PromptTask>,
}
43
44impl Default for PromptManager {
45 fn default() -> Self {
46 Self::new()
47 }
48}
49
50impl PromptManager {
51 pub fn new() -> Self {
53 Self {
54 active_prompts: DashMap::new(),
55 }
56 }
57
58 pub async fn cancel_session_prompt(&self, session_id: &str) -> bool {
68 use tokio::time::{timeout, Duration};
69
70 const CANCEL_TIMEOUT: Duration = Duration::from_secs(5);
71
72 if let Some((_, task)) = self.active_prompts.remove(session_id) {
74 tracing::info!(
75 session_id = %session_id,
76 prompt_id = %task.id,
77 "Cancelling previous prompt"
78 );
79
80 task.cancel_token.cancel();
82
83 let timeout_result = timeout(CANCEL_TIMEOUT, task.handle).await;
85
86 match timeout_result {
87 Ok(Ok(())) => {
88 tracing::info!("Previous prompt cancelled gracefully");
89 true
90 }
91 Ok(Err(e)) => {
92 tracing::warn!(error = ?e, "Previous prompt task failed");
93 true }
95 Err(_) => {
96 tracing::warn!(
97 "Previous prompt did not complete in {:?}, continuing anyway",
98 CANCEL_TIMEOUT
99 );
100 false }
102 }
103 } else {
104 false }
106 }
107
108 pub fn register_prompt(
113 &self,
114 session_id: String,
115 handle: JoinHandle<()>,
116 cancel_token: tokio_util::sync::CancellationToken,
117 ) -> PromptId {
118 let prompt_id = format!("{}-{}", session_id, uuid::Uuid::new_v4());
120
121 let task = PromptTask {
122 id: prompt_id.clone(),
123 handle,
124 cancel_token,
125 created_at: Instant::now(),
126 session_id: session_id.clone(),
127 };
128
129 self.active_prompts.insert(session_id.clone(), task);
131
132 tracing::info!(
133 session_id = %session_id,
134 prompt_id = %prompt_id,
135 "Registered new prompt task"
136 );
137
138 prompt_id
139 }
140
141 pub fn complete_prompt(&self, session_id: &str, prompt_id: &str) {
146 if let Some((_, task)) = self.active_prompts.remove(session_id) {
149 if task.id != prompt_id {
150 self.active_prompts.insert(session_id.to_string(), task);
152 return;
153 }
154 }
155
156 tracing::info!(
157 session_id = %session_id,
158 prompt_id = %prompt_id,
159 "Completed prompt task"
160 );
161 }
162
163 pub fn active_count(&self) -> usize {
165 self.active_prompts.len()
166 }
167
168 pub fn has_active_prompt(&self, session_id: &str) -> bool {
170 self.active_prompts.contains_key(session_id)
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep;
    use tokio_util::sync::CancellationToken;

    /// Session id shared by the tests below.
    const SESSION: &str = "test-session";

    #[test]
    fn test_prompt_manager_default() {
        let mgr = PromptManager::new();
        assert_eq!(mgr.active_count(), 0);
        assert!(!mgr.has_active_prompt(SESSION));
    }

    #[tokio::test]
    async fn test_register_prompt() {
        let mgr = PromptManager::new();
        let noop = tokio::spawn(async {});

        let id = mgr.register_prompt(SESSION.to_string(), noop, CancellationToken::new());

        assert!(id.starts_with("test-session-"));
        assert_eq!(mgr.active_count(), 1);
        assert!(mgr.has_active_prompt(SESSION));

        mgr.complete_prompt(SESSION, &id);
        assert_eq!(mgr.active_count(), 0);
    }

    #[tokio::test]
    async fn test_cancel_session_prompt() {
        let mgr = PromptManager::new();
        let token = CancellationToken::new();
        let watcher = token.clone();

        // Exits as soon as the token is cancelled; the long sleep is only a fallback.
        let worker = tokio::spawn(async move {
            tokio::select! {
                _ = watcher.cancelled() => {}
                _ = sleep(Duration::from_secs(10)) => {}
            }
        });

        mgr.register_prompt(SESSION.to_string(), worker, token);

        assert!(mgr.cancel_session_prompt(SESSION).await);
        assert_eq!(mgr.active_count(), 0);
    }

    #[tokio::test]
    async fn test_cancel_nonexistent_prompt() {
        let mgr = PromptManager::new();
        assert!(!mgr.cancel_session_prompt("nonexistent").await);
    }

    #[tokio::test]
    async fn test_complete_prompt_only_if_id_matches() {
        let mgr = PromptManager::new();
        let worker = tokio::spawn(async move {
            sleep(Duration::from_millis(100)).await;
        });
        let id = mgr.register_prompt(SESSION.to_string(), worker, CancellationToken::new());

        // A foreign id must leave the active prompt in place.
        mgr.complete_prompt(SESSION, "wrong-id");
        assert!(mgr.has_active_prompt(SESSION));

        // The matching id removes it.
        mgr.complete_prompt(SESSION, &id);
        assert!(!mgr.has_active_prompt(SESSION));
    }

    #[tokio::test]
    async fn test_new_prompt_replaces_old() {
        let mgr = PromptManager::new();

        let first = tokio::spawn(async move {
            sleep(Duration::from_millis(100)).await;
        });
        mgr.register_prompt(SESSION.to_string(), first, CancellationToken::new());
        assert_eq!(mgr.active_count(), 1);

        let second = tokio::spawn(async move {});
        mgr.register_prompt(SESSION.to_string(), second, CancellationToken::new());

        // Still a single entry: the second registration took the first's slot.
        assert_eq!(mgr.active_count(), 1);
    }
}