Skip to main content

agent_core_runtime/controller/tools/
user_interaction.rs

//! User interaction registry for managing pending user questions.
//!
//! This module provides a registry for tools that need to wait for user input,
//! such as the AskUserQuestions tool.
5
6use std::collections::HashMap;
7use std::time::{Duration, Instant};
8
9use tokio::sync::{oneshot, Mutex, mpsc};
10
11use super::ask_user_questions::{AskUserQuestionsRequest, AskUserQuestionsResponse};
12use crate::controller::types::{ControllerEvent, TurnId};
13
/// Once the pending map holds at least this many entries, registering a new
/// interaction first sweeps out stale ones.
const PENDING_CLEANUP_THRESHOLD: usize = 50;

/// Age at which a pending interaction is treated as stale during a sweep
/// (five minutes).
const PENDING_MAX_AGE: Duration = Duration::from_secs(5 * 60);
19
/// Failure modes reported by the user-interaction registry operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UserInteractionError {
    /// Nothing is registered under the requested `tool_use_id`.
    NotFound,
    /// The interaction has already received an answer.
    AlreadyResponded,
    /// The answer could not be delivered because the response channel is closed.
    SendFailed,
    /// The notification event could not be delivered to the controller.
    EventSendFailed,
}

impl std::fmt::Display for UserInteractionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Each variant maps to a fixed, human-readable message.
        let message = match self {
            Self::NotFound => "No pending interaction found",
            Self::AlreadyResponded => "Interaction already responded to",
            Self::SendFailed => "Failed to send response",
            Self::EventSendFailed => "Failed to send event notification",
        };
        f.write_str(message)
    }
}

impl std::error::Error for UserInteractionError {}
45
46/// Information about a pending interaction for UI display.
47#[derive(Debug, Clone)]
48pub struct PendingQuestionInfo {
49    /// Tool use ID for this interaction.
50    pub tool_use_id: String,
51    /// Session ID this interaction belongs to.
52    pub session_id: i64,
53    /// The questions being asked.
54    pub request: AskUserQuestionsRequest,
55    /// Turn ID for this interaction.
56    pub turn_id: Option<TurnId>,
57}
58
59/// Internal state for a pending interaction.
60struct PendingInteraction {
61    session_id: i64,
62    request: AskUserQuestionsRequest,
63    turn_id: Option<TurnId>,
64    responder: oneshot::Sender<AskUserQuestionsResponse>,
65    created_at: Instant,
66}
67
68/// Registry for managing pending user interactions.
69///
70/// This registry tracks tools that are blocked waiting for user input
71/// and provides methods for the UI to query and respond to these interactions.
72pub struct UserInteractionRegistry {
73    /// Pending interactions keyed by tool_use_id.
74    pending: Mutex<HashMap<String, PendingInteraction>>,
75    /// Channel to send events to the controller.
76    event_tx: mpsc::Sender<ControllerEvent>,
77}
78
79impl UserInteractionRegistry {
80    /// Create a new UserInteractionRegistry.
81    ///
82    /// # Arguments
83    /// * `event_tx` - Channel to send events when interactions are registered.
84    pub fn new(event_tx: mpsc::Sender<ControllerEvent>) -> Self {
85        Self {
86            pending: Mutex::new(HashMap::new()),
87            event_tx,
88        }
89    }
90
91    /// Register a pending interaction and get a receiver to await on.
92    ///
93    /// This is called by the AskUserQuestionsTool when it starts executing.
94    /// The tool will await on the returned receiver until the UI responds.
95    ///
96    /// # Arguments
97    /// * `tool_use_id` - Unique ID for this tool use request.
98    /// * `session_id` - Session that requested the interaction.
99    /// * `request` - The questions to ask the user.
100    /// * `turn_id` - Optional turn ID for this interaction.
101    ///
102    /// # Returns
103    /// A oneshot receiver that will receive the user's response.
104    pub async fn register(
105        &self,
106        tool_use_id: String,
107        session_id: i64,
108        request: AskUserQuestionsRequest,
109        turn_id: Option<TurnId>,
110    ) -> Result<oneshot::Receiver<AskUserQuestionsResponse>, UserInteractionError> {
111        let (tx, rx) = oneshot::channel();
112
113        // Store the pending interaction
114        {
115            let mut pending = self.pending.lock().await;
116
117            // Cleanup stale entries if map is getting large
118            if pending.len() >= PENDING_CLEANUP_THRESHOLD {
119                let now = Instant::now();
120                pending.retain(|id, interaction| {
121                    let keep = now.duration_since(interaction.created_at) < PENDING_MAX_AGE;
122                    if !keep {
123                        tracing::warn!(
124                            tool_use_id = %id,
125                            age_secs = now.duration_since(interaction.created_at).as_secs(),
126                            "Cleaning up stale pending user interaction"
127                        );
128                    }
129                    keep
130                });
131            }
132
133            pending.insert(
134                tool_use_id.clone(),
135                PendingInteraction {
136                    session_id,
137                    request: request.clone(),
138                    turn_id: turn_id.clone(),
139                    responder: tx,
140                    created_at: Instant::now(),
141                },
142            );
143        }
144
145        // Emit event to notify UI
146        self.event_tx
147            .send(ControllerEvent::UserInteractionRequired {
148                session_id,
149                tool_use_id,
150                request,
151                turn_id,
152            })
153            .await
154            .map_err(|_| UserInteractionError::EventSendFailed)?;
155
156        Ok(rx)
157    }
158
159    /// Respond to a pending interaction.
160    ///
161    /// This is called by the UI when the user has answered the questions.
162    ///
163    /// # Arguments
164    /// * `tool_use_id` - ID of the tool use to respond to.
165    /// * `response` - The user's answers.
166    ///
167    /// # Returns
168    /// Ok(()) if the response was sent successfully, or an error.
169    pub async fn respond(
170        &self,
171        tool_use_id: &str,
172        response: AskUserQuestionsResponse,
173    ) -> Result<(), UserInteractionError> {
174        let interaction = {
175            let mut pending = self.pending.lock().await;
176            pending
177                .remove(tool_use_id)
178                .ok_or(UserInteractionError::NotFound)?
179        };
180
181        interaction
182            .responder
183            .send(response)
184            .map_err(|_| UserInteractionError::SendFailed)
185    }
186
187    /// Cancel a pending interaction (user declined to answer).
188    ///
189    /// This is called by the UI when the user presses Escape to close
190    /// the question panel without answering.
191    ///
192    /// # Arguments
193    /// * `tool_use_id` - ID of the tool use to cancel.
194    ///
195    /// # Returns
196    /// Ok(()) if the interaction was found and cancelled, or NotFound error.
197    pub async fn cancel(&self, tool_use_id: &str) -> Result<(), UserInteractionError> {
198        let mut pending = self.pending.lock().await;
199        if pending.remove(tool_use_id).is_some() {
200            // Dropping the sender will cause the tool to receive a RecvError
201            // which will be converted to "User declined to answer"
202            Ok(())
203        } else {
204            Err(UserInteractionError::NotFound)
205        }
206    }
207
208    /// Get all pending interactions for a session.
209    ///
210    /// This is called by the UI when switching sessions to display
211    /// any pending questions for that session.
212    ///
213    /// # Arguments
214    /// * `session_id` - Session ID to query.
215    ///
216    /// # Returns
217    /// List of pending question information for the session.
218    pub async fn pending_for_session(&self, session_id: i64) -> Vec<PendingQuestionInfo> {
219        let pending = self.pending.lock().await;
220        pending
221            .iter()
222            .filter(|(_, interaction)| interaction.session_id == session_id)
223            .map(|(tool_use_id, interaction)| PendingQuestionInfo {
224                tool_use_id: tool_use_id.clone(),
225                session_id: interaction.session_id,
226                request: interaction.request.clone(),
227                turn_id: interaction.turn_id.clone(),
228            })
229            .collect()
230    }
231
232    /// Cancel all pending interactions for a session.
233    ///
234    /// This is called when a session is destroyed. It drops the senders,
235    /// which will cause the awaiting tools to receive a RecvError.
236    ///
237    /// # Arguments
238    /// * `session_id` - Session ID to cancel.
239    pub async fn cancel_session(&self, session_id: i64) {
240        let mut pending = self.pending.lock().await;
241        pending.retain(|_, interaction| interaction.session_id != session_id);
242        // Dropped senders will cause RecvError on the tool side
243    }
244
245    /// Check if there are any pending interactions for a session.
246    ///
247    /// # Arguments
248    /// * `session_id` - Session ID to check.
249    ///
250    /// # Returns
251    /// True if there are pending interactions.
252    pub async fn has_pending(&self, session_id: i64) -> bool {
253        let pending = self.pending.lock().await;
254        pending
255            .values()
256            .any(|interaction| interaction.session_id == session_id)
257    }
258
259    /// Get the count of pending interactions.
260    pub async fn pending_count(&self) -> usize {
261        let pending = self.pending.lock().await;
262        pending.len()
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::tools::ask_user_questions::{Answer, Question};

    /// Build a registry plus the receiving end of its event channel.
    ///
    /// Callers must keep the receiver alive (bind it to a named variable, not
    /// `_`). Otherwise the registry's event send fails and `register` errors.
    fn new_registry() -> (UserInteractionRegistry, mpsc::Receiver<ControllerEvent>) {
        let (event_tx, event_rx) = mpsc::channel(10);
        (UserInteractionRegistry::new(event_tx), event_rx)
    }

    /// A request containing one required single-choice question.
    fn create_test_request() -> AskUserQuestionsRequest {
        let question = Question::SingleChoice {
            text: "Which option?".to_string(),
            choices: vec!["Option A".to_string(), "Option B".to_string()],
            required: true,
        };
        AskUserQuestionsRequest {
            questions: vec![question],
        }
    }

    /// A response answering the question from `create_test_request`.
    fn create_test_response() -> AskUserQuestionsResponse {
        let answer = Answer {
            question: "Which option?".to_string(),
            answer: vec!["Option A".to_string()],
        };
        AskUserQuestionsResponse {
            answers: vec![answer],
        }
    }

    #[tokio::test]
    async fn test_register_and_respond() {
        let (registry, mut events) = new_registry();

        let waiter = registry
            .register("tool_123".to_string(), 1, create_test_request(), None)
            .await
            .unwrap();

        // Registration must emit the notification the UI listens for.
        match events.recv().await.unwrap() {
            ControllerEvent::UserInteractionRequired {
                session_id,
                tool_use_id,
                ..
            } => {
                assert_eq!(session_id, 1);
                assert_eq!(tool_use_id, "tool_123");
            }
            _ => panic!("Expected UserInteractionRequired event"),
        }

        registry
            .respond("tool_123", create_test_response())
            .await
            .unwrap();

        // The waiting side receives exactly what was sent.
        let delivered = waiter.await.unwrap();
        assert_eq!(delivered.answers.len(), 1);
    }

    #[tokio::test]
    async fn test_respond_not_found() {
        let (registry, _events) = new_registry();

        let outcome = registry
            .respond("nonexistent", create_test_response())
            .await;

        assert_eq!(outcome, Err(UserInteractionError::NotFound));
    }

    #[tokio::test]
    async fn test_pending_for_session() {
        let (registry, _events) = new_registry();

        // Two interactions in session 1, one in session 2.
        for (id, session) in [("tool_1", 1), ("tool_2", 1), ("tool_3", 2)] {
            let _ = registry
                .register(id.to_string(), session, create_test_request(), None)
                .await;
        }

        assert_eq!(registry.pending_for_session(1).await.len(), 2);
        assert_eq!(registry.pending_for_session(2).await.len(), 1);
        assert_eq!(registry.pending_for_session(999).await.len(), 0);
    }

    #[tokio::test]
    async fn test_cancel_session() {
        let (registry, _events) = new_registry();

        let session_one_waiter = registry
            .register("tool_1".to_string(), 1, create_test_request(), None)
            .await
            .unwrap();
        let _ = registry
            .register("tool_2".to_string(), 2, create_test_request(), None)
            .await;

        registry.cancel_session(1).await;

        // Only session 1 is cleared.
        assert!(!registry.has_pending(1).await);
        assert!(registry.has_pending(2).await);

        // The cancelled interaction's sender was dropped, so its waiter errors.
        assert!(session_one_waiter.await.is_err());
    }

    #[tokio::test]
    async fn test_has_pending() {
        let (registry, _events) = new_registry();

        assert!(!registry.has_pending(1).await);

        let _ = registry
            .register("tool_1".to_string(), 1, create_test_request(), None)
            .await;

        assert!(registry.has_pending(1).await);
        assert!(!registry.has_pending(2).await);
    }

    #[tokio::test]
    async fn test_pending_count() {
        let (registry, _events) = new_registry();

        assert_eq!(registry.pending_count().await, 0);

        // Each registration grows the count by one.
        for (expected, id) in [(1, "tool_1"), (2, "tool_2")] {
            let _ = registry
                .register(id.to_string(), 1, create_test_request(), None)
                .await;
            assert_eq!(registry.pending_count().await, expected);
        }
    }
}