Skip to main content

plexus_core/plexus/bidirectional/
registry.rs

1//! Global pending response registry for bidirectional communication
2//!
3//! This module provides a global registry for pending bidirectional requests,
4//! enabling transports like MCP (which are fundamentally request-response) to
5//! route responses back to the correct BidirChannel.
6//!
7//! # Architecture
8//!
//! 1. When a BidirChannel sends a request, it registers a oneshot sender in this
//!    registry under the request's ID (`register_pending_request`)
//! 2. The transport (e.g., MCP) sends the request to the client as a notification
//! 3. The client responds via a tool call (e.g., `_plexus_respond`)
//! 4. The transport passes the request ID and response to `handle_pending_response`,
//!    which removes the matching entry from this registry
//! 5. The raw JSON `Value` is forwarded through that oneshot channel to the waiting
//!    BidirChannel; the registry itself performs no deserialization
14//!
15//! # Thread Safety
16//!
17//! The registry uses a RwLock for concurrent read access with exclusive write access.
18//! Registrations and lookups are fast; response handling is done outside the lock.
19
20use serde_json::Value;
21use std::collections::HashMap;
22use std::sync::{LazyLock, RwLock};
23use tokio::sync::oneshot;
24
25use super::types::BidirError;
26
27/// Type alias for the response sender (type-erased to Value)
28type ResponseSender = oneshot::Sender<Value>;
29
30/// Global registry for pending bidirectional requests
31///
32/// This registry allows transports to correlate response messages with
33/// the BidirChannel waiting for them.
34static PENDING_RESPONSES: LazyLock<RwLock<HashMap<String, ResponseSender>>> =
35    LazyLock::new(|| RwLock::new(HashMap::new()));
36
37/// Register a pending request in the global registry
38///
39/// Called by BidirChannel when making a request over a transport that
40/// doesn't natively support bidirectional (like MCP).
41///
42/// # Arguments
43///
44/// * `request_id` - Unique identifier for the request
45/// * `sender` - Oneshot channel sender to forward the response
46///
47/// # Example
48///
49/// ```rust,ignore
50/// let (tx, rx) = oneshot::channel();
51/// register_pending_request("req-123", tx);
52/// // Transport sends request...
53/// // Later, _plexus_respond calls handle_pending_response("req-123", value)
54/// let response = rx.await?;
55/// ```
56pub fn register_pending_request(request_id: String, sender: ResponseSender) {
57    let mut registry = PENDING_RESPONSES.write().unwrap();
58    registry.insert(request_id, sender);
59}
60
61/// Remove a pending request from the registry (e.g., on timeout)
62///
63/// # Arguments
64///
65/// * `request_id` - The request ID to remove
66///
67/// # Returns
68///
69/// The removed sender if it existed, or None
70pub fn unregister_pending_request(request_id: &str) -> Option<ResponseSender> {
71    let mut registry = PENDING_RESPONSES.write().unwrap();
72    registry.remove(request_id)
73}
74
75/// Handle a response for a pending request
76///
77/// Called by transport tools like `_plexus_respond` when receiving a client response.
78///
79/// # Arguments
80///
81/// * `request_id` - The request ID from the client's response
82/// * `response_data` - The JSON response data
83///
84/// # Returns
85///
86/// * `Ok(())` if the response was successfully forwarded
87/// * `Err(BidirError::UnknownRequest)` if no pending request with that ID
88/// * `Err(BidirError::ChannelClosed)` if the receiver was dropped (timeout/cancelled)
89///
90/// # Example
91///
92/// ```rust,ignore
93/// // In _plexus_respond tool handler:
94/// let result = handle_pending_response(request_id, response_data)?;
95/// ```
96pub fn handle_pending_response(request_id: &str, response_data: Value) -> Result<(), BidirError> {
97    // Remove from registry (takes ownership of sender)
98    let sender = {
99        let mut registry = PENDING_RESPONSES.write().unwrap();
100        registry.remove(request_id)
101    };
102
103    match sender {
104        Some(tx) => {
105            // Send response through channel
106            tx.send(response_data).map_err(|_| BidirError::ChannelClosed)
107        }
108        None => Err(BidirError::UnknownRequest),
109    }
110}
111
112/// Check if a request is pending
113///
114/// # Arguments
115///
116/// * `request_id` - The request ID to check
117///
118/// # Returns
119///
120/// `true` if a request with this ID is pending, `false` otherwise
121pub fn is_request_pending(request_id: &str) -> bool {
122    let registry = PENDING_RESPONSES.read().unwrap();
123    registry.contains_key(request_id)
124}
125
126/// Get the count of pending requests (for monitoring/debugging)
127pub fn pending_count() -> usize {
128    let registry = PENDING_RESPONSES.read().unwrap();
129    registry.len()
130}
131
/// Remove every pending entry (test helper).
///
/// Each removed sender is dropped, so any receiver still waiting on it
/// observes a closed channel.
#[cfg(test)]
// Not every test build calls this helper; silence the unused-function lint.
#[allow(dead_code)]
pub fn clear_all() {
    PENDING_RESPONSES.write().unwrap().clear();
}
139
#[cfg(test)]
mod tests {
    use super::*;

    // Note: These tests run concurrently and share a global registry.
    // Use unique request IDs and assert on per-ID presence (is_request_pending)
    // rather than global pending_count(), which races with concurrent tests.

    /// Build a request ID that no concurrently running test can collide with.
    fn unique_id(prefix: &str) -> String {
        format!("{prefix}-{}", uuid::Uuid::new_v4())
    }

    #[tokio::test]
    async fn test_register_and_handle() {
        let (tx, rx) = oneshot::channel();
        let request_id = unique_id("test-reg-handle");

        // Register
        register_pending_request(request_id.clone(), tx);
        assert!(is_request_pending(&request_id));

        // Handle response
        let response = serde_json::json!({"confirmed": true});
        handle_pending_response(&request_id, response.clone()).unwrap();

        // Verify response received
        let received = rx.await.unwrap();
        assert_eq!(received, response);

        // Request should be removed once answered
        assert!(!is_request_pending(&request_id));
    }

    // No `.await` needed: a plain test avoids spinning up a runtime.
    #[test]
    fn test_unknown_request() {
        let result = handle_pending_response(&unique_id("nonexistent"), serde_json::json!({}));
        assert!(matches!(result, Err(BidirError::UnknownRequest)));
    }

    #[test]
    fn test_unregister() {
        let (tx, _rx) = oneshot::channel();
        let request_id = unique_id("test-unreg");

        register_pending_request(request_id.clone(), tx);
        assert!(is_request_pending(&request_id));

        let removed = unregister_pending_request(&request_id);
        assert!(removed.is_some());
        assert!(!is_request_pending(&request_id));

        // A second unregister finds nothing, and a late response is rejected.
        assert!(unregister_pending_request(&request_id).is_none());
        let late = handle_pending_response(&request_id, serde_json::json!({}));
        assert!(matches!(late, Err(BidirError::UnknownRequest)));
    }

    #[test]
    fn test_channel_closed() {
        let (tx, rx) = oneshot::channel();
        let request_id = unique_id("test-closed");

        register_pending_request(request_id.clone(), tx);

        // Drop the receiver
        drop(rx);

        // Handle should fail with ChannelClosed
        let result = handle_pending_response(&request_id, serde_json::json!({}));
        assert!(matches!(result, Err(BidirError::ChannelClosed)));

        // The entry is consumed even though delivery failed.
        assert!(!is_request_pending(&request_id));
    }

    #[tokio::test]
    async fn test_duplicate_registration_replaces_previous() {
        let (first_tx, first_rx) = oneshot::channel();
        let (second_tx, second_rx) = oneshot::channel();
        let request_id = unique_id("test-duplicate");

        register_pending_request(request_id.clone(), first_tx);
        register_pending_request(request_id.clone(), second_tx);

        // The displaced sender was dropped, so its receiver sees a closed channel.
        assert!(first_rx.await.is_err());

        // The response goes to the most recent registration.
        let response = serde_json::json!({"n": 2});
        handle_pending_response(&request_id, response.clone()).unwrap();
        assert_eq!(second_rx.await.unwrap(), response);
        assert!(!is_request_pending(&request_id));
    }
}