plexus_core/plexus/bidirectional/registry.rs
1//! Global pending response registry for bidirectional communication
2//!
3//! This module provides a global registry for pending bidirectional requests,
4//! enabling transports like MCP (which are fundamentally request-response) to
5//! route responses back to the correct BidirChannel.
6//!
7//! # Architecture
8//!
//! 1. When a BidirChannel sends a request, it registers a oneshot sender in this registry
//! 2. The transport (e.g., MCP) sends the request to the client as a notification
//! 3. The client responds via a tool call (e.g., `_plexus_respond`)
//! 4. The transport looks up the request in this registry and forwards the response
//! 5. The registry sends the raw JSON value through the oneshot sender to the waiting BidirChannel
14//!
15//! # Thread Safety
16//!
17//! The registry uses a RwLock for concurrent read access with exclusive write access.
18//! Registrations and lookups are fast; response handling is done outside the lock.
19
20use serde_json::Value;
21use std::collections::HashMap;
22use std::sync::{LazyLock, RwLock};
23use tokio::sync::oneshot;
24
25use super::types::BidirError;
26
27/// Type alias for the response sender (type-erased to Value)
28type ResponseSender = oneshot::Sender<Value>;
29
30/// Global registry for pending bidirectional requests
31///
32/// This registry allows transports to correlate response messages with
33/// the BidirChannel waiting for them.
34static PENDING_RESPONSES: LazyLock<RwLock<HashMap<String, ResponseSender>>> =
35 LazyLock::new(|| RwLock::new(HashMap::new()));
36
37/// Register a pending request in the global registry
38///
39/// Called by BidirChannel when making a request over a transport that
40/// doesn't natively support bidirectional (like MCP).
41///
42/// # Arguments
43///
44/// * `request_id` - Unique identifier for the request
45/// * `sender` - Oneshot channel sender to forward the response
46///
47/// # Example
48///
49/// ```rust,ignore
50/// let (tx, rx) = oneshot::channel();
51/// register_pending_request("req-123", tx);
52/// // Transport sends request...
53/// // Later, _plexus_respond calls handle_pending_response("req-123", value)
54/// let response = rx.await?;
55/// ```
56pub fn register_pending_request(request_id: String, sender: ResponseSender) {
57 let mut registry = PENDING_RESPONSES.write().unwrap();
58 registry.insert(request_id, sender);
59}
60
61/// Remove a pending request from the registry (e.g., on timeout)
62///
63/// # Arguments
64///
65/// * `request_id` - The request ID to remove
66///
67/// # Returns
68///
69/// The removed sender if it existed, or None
70pub fn unregister_pending_request(request_id: &str) -> Option<ResponseSender> {
71 let mut registry = PENDING_RESPONSES.write().unwrap();
72 registry.remove(request_id)
73}
74
75/// Handle a response for a pending request
76///
77/// Called by transport tools like `_plexus_respond` when receiving a client response.
78///
79/// # Arguments
80///
81/// * `request_id` - The request ID from the client's response
82/// * `response_data` - The JSON response data
83///
84/// # Returns
85///
86/// * `Ok(())` if the response was successfully forwarded
87/// * `Err(BidirError::UnknownRequest)` if no pending request with that ID
88/// * `Err(BidirError::ChannelClosed)` if the receiver was dropped (timeout/cancelled)
89///
90/// # Example
91///
92/// ```rust,ignore
93/// // In _plexus_respond tool handler:
94/// let result = handle_pending_response(request_id, response_data)?;
95/// ```
96pub fn handle_pending_response(request_id: &str, response_data: Value) -> Result<(), BidirError> {
97 // Remove from registry (takes ownership of sender)
98 let sender = {
99 let mut registry = PENDING_RESPONSES.write().unwrap();
100 registry.remove(request_id)
101 };
102
103 match sender {
104 Some(tx) => {
105 // Send response through channel
106 tx.send(response_data).map_err(|_| BidirError::ChannelClosed)
107 }
108 None => Err(BidirError::UnknownRequest),
109 }
110}
111
112/// Check if a request is pending
113///
114/// # Arguments
115///
116/// * `request_id` - The request ID to check
117///
118/// # Returns
119///
120/// `true` if a request with this ID is pending, `false` otherwise
121pub fn is_request_pending(request_id: &str) -> bool {
122 let registry = PENDING_RESPONSES.read().unwrap();
123 registry.contains_key(request_id)
124}
125
126/// Get the count of pending requests (for monitoring/debugging)
127pub fn pending_count() -> usize {
128 let registry = PENDING_RESPONSES.read().unwrap();
129 registry.len()
130}
131
/// Drop every pending request (test-only helper).
///
/// Each removed sender is dropped, so any task still awaiting one of those
/// receivers observes a closed channel. Tests share this global registry and
/// run concurrently, so calling this can disturb unrelated tests.
#[cfg(test)]
// No test in this module calls it; the allowance keeps `dead_code` quiet if
// nothing else in the test build does either.
#[allow(dead_code)]
pub fn clear_all() {
    PENDING_RESPONSES.write().unwrap().clear();
}
139
#[cfg(test)]
mod tests {
    use super::*;

    // Note: These tests run concurrently and share a global registry.
    // Use unique request IDs and assert on per-ID presence (is_request_pending)
    // rather than global pending_count(), which races with concurrent tests.

    /// Builds a request ID that cannot collide with IDs used by other tests.
    fn unique_id(prefix: &str) -> String {
        format!("{prefix}-{}", uuid::Uuid::new_v4())
    }

    #[tokio::test]
    async fn test_register_and_handle() {
        let (tx, rx) = oneshot::channel();
        let request_id = unique_id("test-reg-handle");

        // Register
        register_pending_request(request_id.clone(), tx);
        assert!(is_request_pending(&request_id));

        // Handle response
        let response = serde_json::json!({"confirmed": true});
        handle_pending_response(&request_id, response.clone()).unwrap();

        // Verify response received
        let received = rx.await.unwrap();
        assert_eq!(received, response);

        // Request should be removed
        assert!(!is_request_pending(&request_id));
    }

    #[tokio::test]
    async fn test_unknown_request() {
        let result = handle_pending_response(&unique_id("nonexistent"), serde_json::json!({}));
        assert!(matches!(result, Err(BidirError::UnknownRequest)));
    }

    #[tokio::test]
    async fn test_unregister() {
        let (tx, _rx) = oneshot::channel();
        let request_id = unique_id("test-unreg");

        register_pending_request(request_id.clone(), tx);
        assert!(is_request_pending(&request_id));

        let removed = unregister_pending_request(&request_id);
        assert!(removed.is_some());
        assert!(!is_request_pending(&request_id));

        // Unregistering again finds nothing.
        assert!(unregister_pending_request(&request_id).is_none());

        // A response arriving after unregistration is reported as unknown.
        let late = handle_pending_response(&request_id, serde_json::json!({}));
        assert!(matches!(late, Err(BidirError::UnknownRequest)));
    }

    #[tokio::test]
    async fn test_channel_closed() {
        let (tx, rx) = oneshot::channel();
        let request_id = unique_id("test-closed");

        register_pending_request(request_id.clone(), tx);

        // Drop the receiver
        drop(rx);

        // Handle should fail with ChannelClosed
        let result = handle_pending_response(&request_id, serde_json::json!({}));
        assert!(matches!(result, Err(BidirError::ChannelClosed)));

        // The entry is consumed even though delivery failed.
        assert!(!is_request_pending(&request_id));
    }

    #[tokio::test]
    async fn test_duplicate_registration_replaces_previous_sender() {
        let (first_tx, first_rx) = oneshot::channel();
        let (second_tx, second_rx) = oneshot::channel();
        let request_id = unique_id("test-dup");

        register_pending_request(request_id.clone(), first_tx);
        register_pending_request(request_id.clone(), second_tx);

        // The displaced sender was dropped, so its receiver sees a closed channel.
        assert!(first_rx.await.is_err());

        // The most recent registration receives the response.
        let response = serde_json::json!({"n": 2});
        handle_pending_response(&request_id, response.clone()).unwrap();
        assert_eq!(second_rx.await.unwrap(), response);
        assert!(!is_request_pending(&request_id));
    }
}