agent_core/controller/tools/
user_interaction.rs1use std::collections::HashMap;
7
8use tokio::sync::{oneshot, Mutex, mpsc};
9
10use super::ask_user_questions::{AskUserQuestionsRequest, AskUserQuestionsResponse};
11use crate::controller::types::{ControllerEvent, TurnId};
12
/// Errors returned by [`UserInteractionRegistry`] operations.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum UserInteractionError {
    /// No interaction is pending under the requested tool-use id
    /// (it was never registered, or was already answered or cancelled).
    NotFound,
    /// The interaction has already been answered.
    ///
    /// Reserved: not currently returned by [`UserInteractionRegistry`].
    AlreadyResponded,
    /// The response could not be handed to the waiting task because its
    /// receiving end of the one-shot channel was dropped.
    SendFailed,
    /// The "user interaction required" event could not be published because
    /// the controller event channel is closed.
    EventSendFailed,
}

impl std::fmt::Display for UserInteractionError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::NotFound => "No pending interaction found",
            Self::AlreadyResponded => "Interaction already responded to",
            Self::SendFailed => "Failed to send response",
            Self::EventSendFailed => "Failed to send event notification",
        };
        f.write_str(message)
    }
}

/// No underlying cause is tracked, so the default `source()` (`None`) is used.
impl std::error::Error for UserInteractionError {}
38
39#[derive(Debug, Clone)]
41pub struct PendingQuestionInfo {
42 pub tool_use_id: String,
44 pub session_id: i64,
46 pub request: AskUserQuestionsRequest,
48 pub turn_id: Option<TurnId>,
50}
51
52struct PendingInteraction {
54 session_id: i64,
55 request: AskUserQuestionsRequest,
56 turn_id: Option<TurnId>,
57 responder: oneshot::Sender<AskUserQuestionsResponse>,
58}
59
60pub struct UserInteractionRegistry {
65 pending: Mutex<HashMap<String, PendingInteraction>>,
67 event_tx: mpsc::Sender<ControllerEvent>,
69}
70
71impl UserInteractionRegistry {
72 pub fn new(event_tx: mpsc::Sender<ControllerEvent>) -> Self {
77 Self {
78 pending: Mutex::new(HashMap::new()),
79 event_tx,
80 }
81 }
82
83 pub async fn register(
97 &self,
98 tool_use_id: String,
99 session_id: i64,
100 request: AskUserQuestionsRequest,
101 turn_id: Option<TurnId>,
102 ) -> Result<oneshot::Receiver<AskUserQuestionsResponse>, UserInteractionError> {
103 let (tx, rx) = oneshot::channel();
104
105 {
107 let mut pending = self.pending.lock().await;
108 pending.insert(
109 tool_use_id.clone(),
110 PendingInteraction {
111 session_id,
112 request: request.clone(),
113 turn_id: turn_id.clone(),
114 responder: tx,
115 },
116 );
117 }
118
119 self.event_tx
121 .send(ControllerEvent::UserInteractionRequired {
122 session_id,
123 tool_use_id,
124 request,
125 turn_id,
126 })
127 .await
128 .map_err(|_| UserInteractionError::EventSendFailed)?;
129
130 Ok(rx)
131 }
132
133 pub async fn respond(
144 &self,
145 tool_use_id: &str,
146 response: AskUserQuestionsResponse,
147 ) -> Result<(), UserInteractionError> {
148 let interaction = {
149 let mut pending = self.pending.lock().await;
150 pending
151 .remove(tool_use_id)
152 .ok_or(UserInteractionError::NotFound)?
153 };
154
155 interaction
156 .responder
157 .send(response)
158 .map_err(|_| UserInteractionError::SendFailed)
159 }
160
161 pub async fn cancel(&self, tool_use_id: &str) -> Result<(), UserInteractionError> {
172 let mut pending = self.pending.lock().await;
173 if pending.remove(tool_use_id).is_some() {
174 Ok(())
177 } else {
178 Err(UserInteractionError::NotFound)
179 }
180 }
181
182 pub async fn pending_for_session(&self, session_id: i64) -> Vec<PendingQuestionInfo> {
193 let pending = self.pending.lock().await;
194 pending
195 .iter()
196 .filter(|(_, interaction)| interaction.session_id == session_id)
197 .map(|(tool_use_id, interaction)| PendingQuestionInfo {
198 tool_use_id: tool_use_id.clone(),
199 session_id: interaction.session_id,
200 request: interaction.request.clone(),
201 turn_id: interaction.turn_id.clone(),
202 })
203 .collect()
204 }
205
206 pub async fn cancel_session(&self, session_id: i64) {
214 let mut pending = self.pending.lock().await;
215 pending.retain(|_, interaction| interaction.session_id != session_id);
216 }
218
219 pub async fn has_pending(&self, session_id: i64) -> bool {
227 let pending = self.pending.lock().await;
228 pending
229 .values()
230 .any(|interaction| interaction.session_id == session_id)
231 }
232
233 pub async fn pending_count(&self) -> usize {
235 let pending = self.pending.lock().await;
236 pending.len()
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::tools::ask_user_questions::{Answer, Question};

    /// Creates a registry plus the receiving end of its event channel.
    ///
    /// Callers must bind the receiver to a named variable (e.g. `_events`,
    /// never `_`) so the channel stays open and `register` can publish.
    fn registry_with_events() -> (UserInteractionRegistry, mpsc::Receiver<ControllerEvent>) {
        let (event_tx, event_rx) = mpsc::channel(10);
        (UserInteractionRegistry::new(event_tx), event_rx)
    }

    /// One required single-choice question with two options.
    fn sample_request() -> AskUserQuestionsRequest {
        let choices = vec!["Option A".to_string(), "Option B".to_string()];
        AskUserQuestionsRequest {
            questions: vec![Question::SingleChoice {
                text: "Which option?".to_string(),
                choices,
                required: true,
            }],
        }
    }

    /// An answer selecting "Option A" for the question in `sample_request`.
    fn sample_response() -> AskUserQuestionsResponse {
        let answer = Answer {
            question: "Which option?".to_string(),
            answer: vec!["Option A".to_string()],
        };
        AskUserQuestionsResponse {
            answers: vec![answer],
        }
    }

    #[tokio::test]
    async fn test_register_and_respond() {
        let (registry, mut events) = registry_with_events();

        let receiver = registry
            .register("tool_123".to_string(), 1, sample_request(), None)
            .await
            .unwrap();

        // Registration must announce itself on the event channel.
        match events.recv().await.unwrap() {
            ControllerEvent::UserInteractionRequired {
                session_id,
                tool_use_id,
                ..
            } => {
                assert_eq!(session_id, 1);
                assert_eq!(tool_use_id, "tool_123");
            }
            _ => panic!("Expected UserInteractionRequired event"),
        }

        registry
            .respond("tool_123", sample_response())
            .await
            .unwrap();

        let delivered = receiver.await.unwrap();
        assert_eq!(delivered.answers.len(), 1);
    }

    #[tokio::test]
    async fn test_respond_not_found() {
        let (registry, _events) = registry_with_events();

        let outcome = registry.respond("nonexistent", sample_response()).await;

        assert_eq!(outcome, Err(UserInteractionError::NotFound));
    }

    #[tokio::test]
    async fn test_pending_for_session() {
        let (registry, _events) = registry_with_events();

        // Two questions in session 1, one in session 2; receivers are discarded.
        for (id, session) in [("tool_1", 1), ("tool_2", 1), ("tool_3", 2)] {
            let _ = registry
                .register(id.to_string(), session, sample_request(), None)
                .await;
        }

        assert_eq!(registry.pending_for_session(1).await.len(), 2);
        assert_eq!(registry.pending_for_session(2).await.len(), 1);
        assert_eq!(registry.pending_for_session(999).await.len(), 0);
    }

    #[tokio::test]
    async fn test_cancel_session() {
        let (registry, _events) = registry_with_events();

        let first = registry
            .register("tool_1".to_string(), 1, sample_request(), None)
            .await
            .unwrap();
        let _ = registry
            .register("tool_2".to_string(), 2, sample_request(), None)
            .await;

        registry.cancel_session(1).await;

        assert!(!registry.has_pending(1).await);
        assert!(registry.has_pending(2).await);
        // The cancelled interaction's sender was dropped without a value.
        assert!(first.await.is_err());
    }

    #[tokio::test]
    async fn test_has_pending() {
        let (registry, _events) = registry_with_events();
        assert!(!registry.has_pending(1).await);

        let _ = registry
            .register("tool_1".to_string(), 1, sample_request(), None)
            .await;

        assert!(registry.has_pending(1).await);
        assert!(!registry.has_pending(2).await);
    }

    #[tokio::test]
    async fn test_pending_count() {
        let (registry, _events) = registry_with_events();
        assert_eq!(registry.pending_count().await, 0);

        let _ = registry
            .register("tool_1".to_string(), 1, sample_request(), None)
            .await;
        assert_eq!(registry.pending_count().await, 1);

        let _ = registry
            .register("tool_2".to_string(), 1, sample_request(), None)
            .await;
        assert_eq!(registry.pending_count().await, 2);
    }
}