aster/
action_required_manager.rs1use anyhow::Result;
2use serde_json::Value;
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::{mpsc, Mutex, RwLock};
7use tokio::time::timeout;
8use tracing::warn;
9use uuid::Uuid;
10
11use crate::conversation::message::{Message, MessageContent};
12
13struct PendingRequest {
14 response_tx: Option<tokio::sync::oneshot::Sender<Value>>,
15}
16
17pub struct ActionRequiredManager {
18 pending: Arc<RwLock<HashMap<String, Arc<Mutex<PendingRequest>>>>>,
19 request_tx: mpsc::UnboundedSender<Message>,
20 pub request_rx: Mutex<mpsc::UnboundedReceiver<Message>>,
21}
22
23impl ActionRequiredManager {
24 fn new() -> Self {
25 let (request_tx, request_rx) = mpsc::unbounded_channel();
26 Self {
27 pending: Arc::new(RwLock::new(HashMap::new())),
28 request_tx,
29 request_rx: Mutex::new(request_rx),
30 }
31 }
32
33 pub fn global() -> &'static Self {
34 static INSTANCE: once_cell::sync::Lazy<ActionRequiredManager> =
35 once_cell::sync::Lazy::new(ActionRequiredManager::new);
36 &INSTANCE
37 }
38
39 pub async fn request_and_wait(
40 &self,
41 message: String,
42 schema: Value,
43 timeout_duration: Duration,
44 ) -> Result<Value> {
45 let id = Uuid::new_v4().to_string();
46 let (tx, rx) = tokio::sync::oneshot::channel();
47 let pending_request = PendingRequest {
48 response_tx: Some(tx),
49 };
50
51 self.pending
52 .write()
53 .await
54 .insert(id.clone(), Arc::new(Mutex::new(pending_request)));
55
56 let action_required_message = Message::assistant().with_content(
57 MessageContent::action_required_elicitation(id.clone(), message, schema),
58 );
59
60 if let Err(e) = self.request_tx.send(action_required_message) {
61 warn!("Failed to send action required message: {}", e);
62 }
63
64 let result = match timeout(timeout_duration, rx).await {
65 Ok(Ok(user_data)) => Ok(user_data),
66 Ok(Err(_)) => {
67 warn!("Response channel closed for request: {}", id);
68 Err(anyhow::anyhow!("Response channel closed"))
69 }
70 Err(_) => {
71 warn!("Timeout waiting for response: {}", id);
72 Err(anyhow::anyhow!("Timeout waiting for user response"))
73 }
74 };
75
76 self.pending.write().await.remove(&id);
77
78 result
79 }
80
81 pub async fn submit_response(&self, request_id: String, user_data: Value) -> Result<()> {
82 let pending_arc = {
83 let pending = self.pending.read().await;
84 pending
85 .get(&request_id)
86 .cloned()
87 .ok_or_else(|| anyhow::anyhow!("Request not found: {}", request_id))?
88 };
89
90 let mut pending = pending_arc.lock().await;
91 if let Some(tx) = pending.response_tx.take() {
92 if tx.send(user_data).is_err() {
93 warn!("Failed to send response through oneshot channel");
94 }
95 }
96
97 Ok(())
98 }
99}