Skip to main content

aster/
action_required_manager.rs

1use anyhow::Result;
2use serde_json::Value;
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::{mpsc, Mutex, RwLock};
7use tokio::time::timeout;
8use tracing::warn;
9use uuid::Uuid;
10
11use crate::conversation::message::{Message, MessageContent};
12
13struct PendingRequest {
14    response_tx: Option<tokio::sync::oneshot::Sender<Value>>,
15}
16
17pub struct ActionRequiredManager {
18    pending: Arc<RwLock<HashMap<String, Arc<Mutex<PendingRequest>>>>>,
19    request_tx: mpsc::UnboundedSender<Message>,
20    pub request_rx: Mutex<mpsc::UnboundedReceiver<Message>>,
21}
22
23impl ActionRequiredManager {
24    fn new() -> Self {
25        let (request_tx, request_rx) = mpsc::unbounded_channel();
26        Self {
27            pending: Arc::new(RwLock::new(HashMap::new())),
28            request_tx,
29            request_rx: Mutex::new(request_rx),
30        }
31    }
32
33    pub fn global() -> &'static Self {
34        static INSTANCE: once_cell::sync::Lazy<ActionRequiredManager> =
35            once_cell::sync::Lazy::new(ActionRequiredManager::new);
36        &INSTANCE
37    }
38
39    pub async fn request_and_wait(
40        &self,
41        message: String,
42        schema: Value,
43        timeout_duration: Duration,
44    ) -> Result<Value> {
45        let id = Uuid::new_v4().to_string();
46        let (tx, rx) = tokio::sync::oneshot::channel();
47        let pending_request = PendingRequest {
48            response_tx: Some(tx),
49        };
50
51        self.pending
52            .write()
53            .await
54            .insert(id.clone(), Arc::new(Mutex::new(pending_request)));
55
56        let action_required_message = Message::assistant().with_content(
57            MessageContent::action_required_elicitation(id.clone(), message, schema),
58        );
59
60        if let Err(e) = self.request_tx.send(action_required_message) {
61            warn!("Failed to send action required message: {}", e);
62        }
63
64        let result = match timeout(timeout_duration, rx).await {
65            Ok(Ok(user_data)) => Ok(user_data),
66            Ok(Err(_)) => {
67                warn!("Response channel closed for request: {}", id);
68                Err(anyhow::anyhow!("Response channel closed"))
69            }
70            Err(_) => {
71                warn!("Timeout waiting for response: {}", id);
72                Err(anyhow::anyhow!("Timeout waiting for user response"))
73            }
74        };
75
76        self.pending.write().await.remove(&id);
77
78        result
79    }
80
81    pub async fn submit_response(&self, request_id: String, user_data: Value) -> Result<()> {
82        let pending_arc = {
83            let pending = self.pending.read().await;
84            pending
85                .get(&request_id)
86                .cloned()
87                .ok_or_else(|| anyhow::anyhow!("Request not found: {}", request_id))?
88        };
89
90        let mut pending = pending_arc.lock().await;
91        if let Some(tx) = pending.response_tx.take() {
92            if tx.send(user_data).is_err() {
93                warn!("Failed to send response through oneshot channel");
94            }
95        }
96
97        Ok(())
98    }
99}