agent_core/controller/tools/user_interaction.rs

1//! User interaction registry for managing pending user questions.
2//!
3//! This module provides a registry for tools that need to wait for user input,
4//! such as the AskUserQuestions tool.
5
6use std::collections::HashMap;
7
8use tokio::sync::{oneshot, Mutex, mpsc};
9
10use super::ask_user_questions::{AskUserQuestionsRequest, AskUserQuestionsResponse};
11use crate::controller::types::{ControllerEvent, TurnId};
12
/// Failure modes of the user-interaction registry operations.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum UserInteractionError {
    /// No interaction is pending under the given `tool_use_id`.
    NotFound,
    /// The interaction has already been answered.
    AlreadyResponded,
    /// The answer could not be delivered because the response channel is closed.
    SendFailed,
    /// The notification event could not be delivered to the controller.
    EventSendFailed,
}

impl std::fmt::Display for UserInteractionError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::NotFound => "No pending interaction found",
            Self::AlreadyResponded => "Interaction already responded to",
            Self::SendFailed => "Failed to send response",
            Self::EventSendFailed => "Failed to send event notification",
        };
        f.write_str(message)
    }
}

/// Allows the error to be used as a `dyn std::error::Error` (e.g. boxed and propagated).
impl std::error::Error for UserInteractionError {}
38
39/// Information about a pending interaction for UI display.
40#[derive(Debug, Clone)]
41pub struct PendingQuestionInfo {
42    /// Tool use ID for this interaction.
43    pub tool_use_id: String,
44    /// Session ID this interaction belongs to.
45    pub session_id: i64,
46    /// The questions being asked.
47    pub request: AskUserQuestionsRequest,
48    /// Turn ID for this interaction.
49    pub turn_id: Option<TurnId>,
50}
51
52/// Internal state for a pending interaction.
53struct PendingInteraction {
54    session_id: i64,
55    request: AskUserQuestionsRequest,
56    turn_id: Option<TurnId>,
57    responder: oneshot::Sender<AskUserQuestionsResponse>,
58}
59
60/// Registry for managing pending user interactions.
61///
62/// This registry tracks tools that are blocked waiting for user input
63/// and provides methods for the UI to query and respond to these interactions.
64pub struct UserInteractionRegistry {
65    /// Pending interactions keyed by tool_use_id.
66    pending: Mutex<HashMap<String, PendingInteraction>>,
67    /// Channel to send events to the controller.
68    event_tx: mpsc::Sender<ControllerEvent>,
69}
70
71impl UserInteractionRegistry {
72    /// Create a new UserInteractionRegistry.
73    ///
74    /// # Arguments
75    /// * `event_tx` - Channel to send events when interactions are registered.
76    pub fn new(event_tx: mpsc::Sender<ControllerEvent>) -> Self {
77        Self {
78            pending: Mutex::new(HashMap::new()),
79            event_tx,
80        }
81    }
82
83    /// Register a pending interaction and get a receiver to await on.
84    ///
85    /// This is called by the AskUserQuestionsTool when it starts executing.
86    /// The tool will await on the returned receiver until the UI responds.
87    ///
88    /// # Arguments
89    /// * `tool_use_id` - Unique ID for this tool use request.
90    /// * `session_id` - Session that requested the interaction.
91    /// * `request` - The questions to ask the user.
92    /// * `turn_id` - Optional turn ID for this interaction.
93    ///
94    /// # Returns
95    /// A oneshot receiver that will receive the user's response.
96    pub async fn register(
97        &self,
98        tool_use_id: String,
99        session_id: i64,
100        request: AskUserQuestionsRequest,
101        turn_id: Option<TurnId>,
102    ) -> Result<oneshot::Receiver<AskUserQuestionsResponse>, UserInteractionError> {
103        let (tx, rx) = oneshot::channel();
104
105        // Store the pending interaction
106        {
107            let mut pending = self.pending.lock().await;
108            pending.insert(
109                tool_use_id.clone(),
110                PendingInteraction {
111                    session_id,
112                    request: request.clone(),
113                    turn_id: turn_id.clone(),
114                    responder: tx,
115                },
116            );
117        }
118
119        // Emit event to notify UI
120        self.event_tx
121            .send(ControllerEvent::UserInteractionRequired {
122                session_id,
123                tool_use_id,
124                request,
125                turn_id,
126            })
127            .await
128            .map_err(|_| UserInteractionError::EventSendFailed)?;
129
130        Ok(rx)
131    }
132
133    /// Respond to a pending interaction.
134    ///
135    /// This is called by the UI when the user has answered the questions.
136    ///
137    /// # Arguments
138    /// * `tool_use_id` - ID of the tool use to respond to.
139    /// * `response` - The user's answers.
140    ///
141    /// # Returns
142    /// Ok(()) if the response was sent successfully, or an error.
143    pub async fn respond(
144        &self,
145        tool_use_id: &str,
146        response: AskUserQuestionsResponse,
147    ) -> Result<(), UserInteractionError> {
148        let interaction = {
149            let mut pending = self.pending.lock().await;
150            pending
151                .remove(tool_use_id)
152                .ok_or(UserInteractionError::NotFound)?
153        };
154
155        interaction
156            .responder
157            .send(response)
158            .map_err(|_| UserInteractionError::SendFailed)
159    }
160
161    /// Cancel a pending interaction (user declined to answer).
162    ///
163    /// This is called by the UI when the user presses Escape to close
164    /// the question panel without answering.
165    ///
166    /// # Arguments
167    /// * `tool_use_id` - ID of the tool use to cancel.
168    ///
169    /// # Returns
170    /// Ok(()) if the interaction was found and cancelled, or NotFound error.
171    pub async fn cancel(&self, tool_use_id: &str) -> Result<(), UserInteractionError> {
172        let mut pending = self.pending.lock().await;
173        if pending.remove(tool_use_id).is_some() {
174            // Dropping the sender will cause the tool to receive a RecvError
175            // which will be converted to "User declined to answer"
176            Ok(())
177        } else {
178            Err(UserInteractionError::NotFound)
179        }
180    }
181
182    /// Get all pending interactions for a session.
183    ///
184    /// This is called by the UI when switching sessions to display
185    /// any pending questions for that session.
186    ///
187    /// # Arguments
188    /// * `session_id` - Session ID to query.
189    ///
190    /// # Returns
191    /// List of pending question information for the session.
192    pub async fn pending_for_session(&self, session_id: i64) -> Vec<PendingQuestionInfo> {
193        let pending = self.pending.lock().await;
194        pending
195            .iter()
196            .filter(|(_, interaction)| interaction.session_id == session_id)
197            .map(|(tool_use_id, interaction)| PendingQuestionInfo {
198                tool_use_id: tool_use_id.clone(),
199                session_id: interaction.session_id,
200                request: interaction.request.clone(),
201                turn_id: interaction.turn_id.clone(),
202            })
203            .collect()
204    }
205
206    /// Cancel all pending interactions for a session.
207    ///
208    /// This is called when a session is destroyed. It drops the senders,
209    /// which will cause the awaiting tools to receive a RecvError.
210    ///
211    /// # Arguments
212    /// * `session_id` - Session ID to cancel.
213    pub async fn cancel_session(&self, session_id: i64) {
214        let mut pending = self.pending.lock().await;
215        pending.retain(|_, interaction| interaction.session_id != session_id);
216        // Dropped senders will cause RecvError on the tool side
217    }
218
219    /// Check if there are any pending interactions for a session.
220    ///
221    /// # Arguments
222    /// * `session_id` - Session ID to check.
223    ///
224    /// # Returns
225    /// True if there are pending interactions.
226    pub async fn has_pending(&self, session_id: i64) -> bool {
227        let pending = self.pending.lock().await;
228        pending
229            .values()
230            .any(|interaction| interaction.session_id == session_id)
231    }
232
233    /// Get the count of pending interactions.
234    pub async fn pending_count(&self) -> usize {
235        let pending = self.pending.lock().await;
236        pending.len()
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::tools::ask_user_questions::{Answer, Question};

    /// A request containing one required single-choice question.
    fn create_test_request() -> AskUserQuestionsRequest {
        AskUserQuestionsRequest {
            questions: vec![Question::SingleChoice {
                text: "Which option?".to_string(),
                choices: vec!["Option A".to_string(), "Option B".to_string()],
                required: true,
            }],
        }
    }

    /// An answer to the question built by `create_test_request`.
    fn create_test_response() -> AskUserQuestionsResponse {
        AskUserQuestionsResponse {
            answers: vec![Answer {
                question: "Which option?".to_string(),
                answer: vec!["Option A".to_string()],
            }],
        }
    }

    #[tokio::test]
    async fn test_register_and_respond() {
        let (event_tx, mut event_rx) = mpsc::channel(10);
        let registry = UserInteractionRegistry::new(event_tx);

        // Register interaction
        let rx = registry
            .register("tool_123".to_string(), 1, create_test_request(), None)
            .await
            .unwrap();

        // Verify event was emitted
        let event = event_rx.recv().await.unwrap();
        if let ControllerEvent::UserInteractionRequired {
            session_id,
            tool_use_id,
            ..
        } = event
        {
            assert_eq!(session_id, 1);
            assert_eq!(tool_use_id, "tool_123");
        } else {
            panic!("Expected UserInteractionRequired event");
        }

        // Respond to interaction
        registry
            .respond("tool_123", create_test_response())
            .await
            .unwrap();

        // Verify response was received and the entry was consumed
        let received = rx.await.unwrap();
        assert_eq!(received.answers.len(), 1);
        assert_eq!(registry.pending_count().await, 0);
    }

    #[tokio::test]
    async fn test_respond_not_found() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = UserInteractionRegistry::new(event_tx);

        let result = registry
            .respond("nonexistent", create_test_response())
            .await;

        assert_eq!(result, Err(UserInteractionError::NotFound));
    }

    #[tokio::test]
    async fn test_respond_after_receiver_dropped() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = UserInteractionRegistry::new(event_tx);

        let rx = registry
            .register("tool_1".to_string(), 1, create_test_request(), None)
            .await
            .unwrap();
        drop(rx);

        let result = registry.respond("tool_1", create_test_response()).await;
        assert_eq!(result, Err(UserInteractionError::SendFailed));
        // The entry is consumed even though delivery failed.
        assert_eq!(registry.pending_count().await, 0);
    }

    #[tokio::test]
    async fn test_register_fails_when_event_channel_closed() {
        let (event_tx, event_rx) = mpsc::channel(10);
        drop(event_rx);
        let registry = UserInteractionRegistry::new(event_tx);

        let result = registry
            .register("tool_1".to_string(), 1, create_test_request(), None)
            .await;

        assert!(matches!(
            result,
            Err(UserInteractionError::EventSendFailed)
        ));
    }

    #[tokio::test]
    async fn test_cancel() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = UserInteractionRegistry::new(event_tx);

        let rx = registry
            .register("tool_1".to_string(), 1, create_test_request(), None)
            .await
            .unwrap();

        registry.cancel("tool_1").await.unwrap();

        // The waiting tool sees the dropped sender as an error.
        assert!(rx.await.is_err());
        assert_eq!(registry.pending_count().await, 0);
        // A second cancel has nothing left to remove.
        assert_eq!(
            registry.cancel("tool_1").await,
            Err(UserInteractionError::NotFound)
        );
    }

    #[tokio::test]
    async fn test_pending_for_session() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = UserInteractionRegistry::new(event_tx);

        let request = create_test_request();

        // Register interactions for different sessions
        let _rx1 = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        let _rx2 = registry
            .register("tool_2".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        let _rx3 = registry
            .register("tool_3".to_string(), 2, request, None)
            .await
            .unwrap();

        // Query session 1
        let mut ids: Vec<String> = registry
            .pending_for_session(1)
            .await
            .into_iter()
            .map(|info| info.tool_use_id)
            .collect();
        ids.sort();
        assert_eq!(ids, ["tool_1", "tool_2"]);

        // Query session 2
        let pending = registry.pending_for_session(2).await;
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].tool_use_id, "tool_3");
        assert_eq!(pending[0].session_id, 2);

        // Query nonexistent session
        let pending = registry.pending_for_session(999).await;
        assert!(pending.is_empty());
    }

    #[tokio::test]
    async fn test_cancel_session() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = UserInteractionRegistry::new(event_tx);

        let request = create_test_request();

        // Register interactions
        let rx1 = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        let _rx2 = registry
            .register("tool_2".to_string(), 2, request, None)
            .await
            .unwrap();

        // Cancel session 1
        registry.cancel_session(1).await;

        // Session 1 interaction should be gone
        assert!(!registry.has_pending(1).await);
        assert!(registry.has_pending(2).await);

        // Receiver should get error
        assert!(rx1.await.is_err());
    }

    #[tokio::test]
    async fn test_has_pending() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = UserInteractionRegistry::new(event_tx);

        assert!(!registry.has_pending(1).await);

        let _rx = registry
            .register("tool_1".to_string(), 1, create_test_request(), None)
            .await
            .unwrap();

        assert!(registry.has_pending(1).await);
        assert!(!registry.has_pending(2).await);
    }

    #[tokio::test]
    async fn test_pending_count() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = UserInteractionRegistry::new(event_tx);

        assert_eq!(registry.pending_count().await, 0);

        let request = create_test_request();
        let _rx1 = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        assert_eq!(registry.pending_count().await, 1);

        let _rx2 = registry
            .register("tool_2".to_string(), 1, request, None)
            .await
            .unwrap();
        assert_eq!(registry.pending_count().await, 2);
    }
}