agent_core_runtime/controller/tools/
user_interaction.rs1use std::collections::HashMap;
7use std::time::{Duration, Instant};
8
9use tokio::sync::{oneshot, Mutex, mpsc};
10
11use super::ask_user_questions::{AskUserQuestionsRequest, AskUserQuestionsResponse};
12use crate::controller::types::{ControllerEvent, TurnId};
13
/// Pending-entry count at or above which `register` sweeps out stale entries
/// before inserting a new one.
const PENDING_CLEANUP_THRESHOLD: usize = 50;

/// Minimum age at which a pending entry is treated as stale during a sweep
/// (five minutes).
const PENDING_MAX_AGE: Duration = Duration::from_secs(5 * 60);
19
/// Errors reported by the user-interaction registry.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UserInteractionError {
    /// No pending interaction is registered under the given `tool_use_id`.
    NotFound,
    /// The interaction was already responded to.
    /// NOTE(review): no code in this module currently returns this variant.
    AlreadyResponded,
    /// The answer could not be delivered because the waiting receiver is gone.
    SendFailed,
    /// The controller event announcing the interaction could not be sent.
    EventSendFailed,
}

impl std::fmt::Display for UserInteractionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Each variant maps to a fixed, human-readable message.
        let message = match self {
            Self::NotFound => "No pending interaction found",
            Self::AlreadyResponded => "Interaction already responded to",
            Self::SendFailed => "Failed to send response",
            Self::EventSendFailed => "Failed to send event notification",
        };
        f.write_str(message)
    }
}

impl std::error::Error for UserInteractionError {}
45
46#[derive(Debug, Clone)]
48pub struct PendingQuestionInfo {
49 pub tool_use_id: String,
51 pub session_id: i64,
53 pub request: AskUserQuestionsRequest,
55 pub turn_id: Option<TurnId>,
57}
58
59struct PendingInteraction {
61 session_id: i64,
62 request: AskUserQuestionsRequest,
63 turn_id: Option<TurnId>,
64 responder: oneshot::Sender<AskUserQuestionsResponse>,
65 created_at: Instant,
66}
67
68pub struct UserInteractionRegistry {
73 pending: Mutex<HashMap<String, PendingInteraction>>,
75 event_tx: mpsc::Sender<ControllerEvent>,
77}
78
79impl UserInteractionRegistry {
80 pub fn new(event_tx: mpsc::Sender<ControllerEvent>) -> Self {
85 Self {
86 pending: Mutex::new(HashMap::new()),
87 event_tx,
88 }
89 }
90
91 pub async fn register(
105 &self,
106 tool_use_id: String,
107 session_id: i64,
108 request: AskUserQuestionsRequest,
109 turn_id: Option<TurnId>,
110 ) -> Result<oneshot::Receiver<AskUserQuestionsResponse>, UserInteractionError> {
111 let (tx, rx) = oneshot::channel();
112
113 {
115 let mut pending = self.pending.lock().await;
116
117 if pending.len() >= PENDING_CLEANUP_THRESHOLD {
119 let now = Instant::now();
120 pending.retain(|id, interaction| {
121 let keep = now.duration_since(interaction.created_at) < PENDING_MAX_AGE;
122 if !keep {
123 tracing::warn!(
124 tool_use_id = %id,
125 age_secs = now.duration_since(interaction.created_at).as_secs(),
126 "Cleaning up stale pending user interaction"
127 );
128 }
129 keep
130 });
131 }
132
133 pending.insert(
134 tool_use_id.clone(),
135 PendingInteraction {
136 session_id,
137 request: request.clone(),
138 turn_id: turn_id.clone(),
139 responder: tx,
140 created_at: Instant::now(),
141 },
142 );
143 }
144
145 self.event_tx
147 .send(ControllerEvent::UserInteractionRequired {
148 session_id,
149 tool_use_id,
150 request,
151 turn_id,
152 })
153 .await
154 .map_err(|_| UserInteractionError::EventSendFailed)?;
155
156 Ok(rx)
157 }
158
159 pub async fn respond(
170 &self,
171 tool_use_id: &str,
172 response: AskUserQuestionsResponse,
173 ) -> Result<(), UserInteractionError> {
174 let interaction = {
175 let mut pending = self.pending.lock().await;
176 pending
177 .remove(tool_use_id)
178 .ok_or(UserInteractionError::NotFound)?
179 };
180
181 interaction
182 .responder
183 .send(response)
184 .map_err(|_| UserInteractionError::SendFailed)
185 }
186
187 pub async fn cancel(&self, tool_use_id: &str) -> Result<(), UserInteractionError> {
198 let mut pending = self.pending.lock().await;
199 if pending.remove(tool_use_id).is_some() {
200 Ok(())
203 } else {
204 Err(UserInteractionError::NotFound)
205 }
206 }
207
208 pub async fn pending_for_session(&self, session_id: i64) -> Vec<PendingQuestionInfo> {
219 let pending = self.pending.lock().await;
220 pending
221 .iter()
222 .filter(|(_, interaction)| interaction.session_id == session_id)
223 .map(|(tool_use_id, interaction)| PendingQuestionInfo {
224 tool_use_id: tool_use_id.clone(),
225 session_id: interaction.session_id,
226 request: interaction.request.clone(),
227 turn_id: interaction.turn_id.clone(),
228 })
229 .collect()
230 }
231
232 pub async fn cancel_session(&self, session_id: i64) {
240 let mut pending = self.pending.lock().await;
241 pending.retain(|_, interaction| interaction.session_id != session_id);
242 }
244
245 pub async fn has_pending(&self, session_id: i64) -> bool {
253 let pending = self.pending.lock().await;
254 pending
255 .values()
256 .any(|interaction| interaction.session_id == session_id)
257 }
258
259 pub async fn pending_count(&self) -> usize {
261 let pending = self.pending.lock().await;
262 pending.len()
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::tools::ask_user_questions::{Answer, Question};

    /// Builds a registry plus the receiving end of its event channel.
    ///
    /// Callers must keep the receiver alive; otherwise `register` fails to send.
    fn registry_with_events() -> (UserInteractionRegistry, mpsc::Receiver<ControllerEvent>) {
        let (event_tx, event_rx) = mpsc::channel(10);
        (UserInteractionRegistry::new(event_tx), event_rx)
    }

    /// A request containing a single required single-choice question.
    fn create_test_request() -> AskUserQuestionsRequest {
        let question = Question::SingleChoice {
            text: "Which option?".to_string(),
            choices: vec!["Option A".to_string(), "Option B".to_string()],
            required: true,
        };
        AskUserQuestionsRequest {
            questions: vec![question],
        }
    }

    /// A response answering the question from `create_test_request`.
    fn create_test_response() -> AskUserQuestionsResponse {
        let answer = Answer {
            question: "Which option?".to_string(),
            answer: vec!["Option A".to_string()],
        };
        AskUserQuestionsResponse {
            answers: vec![answer],
        }
    }

    #[tokio::test]
    async fn test_register_and_respond() {
        let (registry, mut event_rx) = registry_with_events();

        let rx = registry
            .register("tool_123".to_string(), 1, create_test_request(), None)
            .await
            .unwrap();

        // Registration must announce itself on the event channel.
        match event_rx.recv().await.unwrap() {
            ControllerEvent::UserInteractionRequired {
                session_id,
                tool_use_id,
                ..
            } => {
                assert_eq!(session_id, 1);
                assert_eq!(tool_use_id, "tool_123");
            }
            _ => panic!("Expected UserInteractionRequired event"),
        }

        registry
            .respond("tool_123", create_test_response())
            .await
            .unwrap();

        let delivered = rx.await.unwrap();
        assert_eq!(delivered.answers.len(), 1);
    }

    #[tokio::test]
    async fn test_respond_not_found() {
        let (registry, _event_rx) = registry_with_events();

        let outcome = registry
            .respond("nonexistent", create_test_response())
            .await;

        assert_eq!(outcome, Err(UserInteractionError::NotFound));
    }

    #[tokio::test]
    async fn test_pending_for_session() {
        let (registry, _event_rx) = registry_with_events();

        for (id, session) in [("tool_1", 1), ("tool_2", 1), ("tool_3", 2)] {
            let _ = registry
                .register(id.to_string(), session, create_test_request(), None)
                .await;
        }

        assert_eq!(registry.pending_for_session(1).await.len(), 2);
        assert_eq!(registry.pending_for_session(2).await.len(), 1);
        assert_eq!(registry.pending_for_session(999).await.len(), 0);
    }

    #[tokio::test]
    async fn test_cancel_session() {
        let (registry, _event_rx) = registry_with_events();

        let first_rx = registry
            .register("tool_1".to_string(), 1, create_test_request(), None)
            .await
            .unwrap();
        let _ = registry
            .register("tool_2".to_string(), 2, create_test_request(), None)
            .await;

        registry.cancel_session(1).await;

        assert!(!registry.has_pending(1).await);
        assert!(registry.has_pending(2).await);
        // Cancelling drops the stored sender, so the waiter sees a closed channel.
        assert!(first_rx.await.is_err());
    }

    #[tokio::test]
    async fn test_has_pending() {
        let (registry, _event_rx) = registry_with_events();

        assert!(!registry.has_pending(1).await);

        let _ = registry
            .register("tool_1".to_string(), 1, create_test_request(), None)
            .await;

        assert!(registry.has_pending(1).await);
        assert!(!registry.has_pending(2).await);
    }

    #[tokio::test]
    async fn test_pending_count() {
        let (registry, _event_rx) = registry_with_events();

        assert_eq!(registry.pending_count().await, 0);

        for (expected, id) in [(1, "tool_1"), (2, "tool_2")] {
            let _ = registry
                .register(id.to_string(), 1, create_test_request(), None)
                .await;
            assert_eq!(registry.pending_count().await, expected);
        }
    }
}