plexus_core/plexus/bidirectional/
helpers.rs1use super::channel::BidirChannel;
4use super::types::{BidirError, StandardRequest, StandardResponse};
5use crate::plexus::types::PlexusStreamItem;
6use serde::{de::DeserializeOwned, Serialize};
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::mpsc;
10
/// A set of timeouts, one per kind of interactive request.
///
/// The field names mirror the `StandardRequest` variants (`Confirm`, `Prompt`,
/// `Select`, `Custom`). How the durations are consumed is not visible in this
/// file.
#[derive(Debug, Clone, Copy)]
pub struct TimeoutConfig {
    /// Timeout for confirm requests.
    pub confirm: Duration,
    /// Timeout for prompt requests.
    pub prompt: Duration,
    /// Timeout for select requests.
    pub select: Duration,
    /// Timeout for custom requests.
    pub custom: Duration,
}

impl TimeoutConfig {
    /// Builds a configuration in which every request kind shares the same
    /// timeout of `secs` seconds.
    fn uniform(secs: u64) -> Self {
        let timeout = Duration::from_secs(secs);
        Self {
            confirm: timeout,
            prompt: timeout,
            select: timeout,
            custom: timeout,
        }
    }

    /// Preset with a 10-second timeout for every request kind.
    pub fn quick() -> Self {
        Self::uniform(10)
    }

    /// Preset with a 30-second timeout for every request kind.
    pub fn normal() -> Self {
        Self::uniform(30)
    }

    /// Preset with a 60-second timeout for every request kind.
    pub fn patient() -> Self {
        Self::uniform(60)
    }

    /// Preset with a 300-second (five-minute) timeout for every request kind.
    pub fn extended() -> Self {
        Self::uniform(300)
    }
}

impl Default for TimeoutConfig {
    /// The default configuration is the [`TimeoutConfig::normal`] preset.
    fn default() -> Self {
        TimeoutConfig::normal()
    }
}
108
109pub fn create_test_bidir_channel<Req, Resp>() -> (
132 Arc<BidirChannel<Req, Resp>>,
133 mpsc::Receiver<PlexusStreamItem>,
134)
135where
136 Req: Serialize + DeserializeOwned + Send + 'static,
137 Resp: Serialize + DeserializeOwned + Send + 'static,
138{
139 let (tx, rx) = mpsc::channel(32);
140 let channel = Arc::new(BidirChannel::new_direct(
141 tx,
142 true, vec!["test".into()],
144 "test-hash".into(),
145 ));
146 (channel, rx)
147}
148
149pub fn create_test_standard_channel() -> (
165 Arc<BidirChannel<StandardRequest, StandardResponse>>,
166 mpsc::Receiver<PlexusStreamItem>,
167) {
168 create_test_bidir_channel()
169}
170
171pub fn auto_respond_channel<Req, Resp>(
196 response_fn: impl Fn(&Req) -> Resp + Send + Sync + 'static,
197) -> Arc<BidirChannel<Req, Resp>>
198where
199 Req: Serialize + DeserializeOwned + Send + Sync + Clone + 'static,
200 Resp: Serialize + DeserializeOwned + Send + Sync + 'static,
201{
202 let (tx, mut rx) = mpsc::channel::<PlexusStreamItem>(32);
203 let channel = Arc::new(BidirChannel::new_direct(
204 tx,
205 true, vec!["test".into()],
207 "test-hash".into(),
208 ));
209
210 let channel_clone = channel.clone();
212 tokio::spawn(async move {
213 while let Some(item) = rx.recv().await {
214 if let PlexusStreamItem::Request {
215 request_id,
216 request_data,
217 ..
218 } = item
219 {
220 if let Ok(req) = serde_json::from_value::<Req>(request_data) {
222 let resp = response_fn(&req);
224
225 if let Ok(resp_json) = serde_json::to_value(&resp) {
227 let _ = channel_clone.handle_response(request_id, resp_json);
228 }
229 }
230 }
231 }
232 });
233
234 channel
235}
236
237pub fn auto_confirm_channel(confirm_value: bool) -> Arc<BidirChannel<StandardRequest, StandardResponse>> {
254 auto_respond_channel(move |req: &StandardRequest| match req {
255 StandardRequest::Confirm { default, .. } => StandardResponse::Confirmed {
256 value: default.unwrap_or(confirm_value),
257 },
258 StandardRequest::Prompt { default, .. } => StandardResponse::Text {
259 value: default
260 .clone()
261 .unwrap_or(serde_json::Value::String(String::new())),
262 },
263 StandardRequest::Select { options, .. } => StandardResponse::Selected {
264 values: vec![options
265 .first()
266 .map(|o| o.value.clone())
267 .unwrap_or(serde_json::Value::String(String::new()))],
268 },
269 StandardRequest::Custom { data } => StandardResponse::Custom { data: data.clone() },
270 })
271}
272
273pub fn bidir_error_message(err: &BidirError) -> String {
286 match err {
287 BidirError::NotSupported => {
288 "Bidirectional communication not supported by this transport".to_string()
289 }
290 BidirError::Timeout(ms) => {
291 format!("Request timed out waiting for response (after {}ms)", ms)
292 }
293 BidirError::Cancelled => "Request was cancelled by user".to_string(),
294 BidirError::TypeMismatch { expected, got } => {
295 format!("Type mismatch: expected {}, got {}", expected, got)
296 }
297 BidirError::Serialization(e) => format!("Serialization error: {}", e),
298 BidirError::Transport(e) => format!("Transport error: {}", e),
299 BidirError::UnknownRequest => "Unknown request ID (may have already been handled)".to_string(),
300 BidirError::ChannelClosed => "Response channel closed before response received".to_string(),
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// `quick()` yields 10-second confirm and prompt timeouts.
    #[test]
    fn test_timeout_config_quick() {
        let TimeoutConfig { confirm, prompt, .. } = TimeoutConfig::quick();
        assert_eq!(confirm, Duration::from_secs(10));
        assert_eq!(prompt, Duration::from_secs(10));
    }

    /// `normal()` yields a 30-second confirm timeout.
    #[test]
    fn test_timeout_config_normal() {
        let TimeoutConfig { confirm, .. } = TimeoutConfig::normal();
        assert_eq!(confirm, Duration::from_secs(30));
    }

    /// `patient()` yields a 60-second confirm timeout.
    #[test]
    fn test_timeout_config_patient() {
        let TimeoutConfig { confirm, .. } = TimeoutConfig::patient();
        assert_eq!(confirm, Duration::from_secs(60));
    }

    /// `extended()` yields a 300-second confirm timeout.
    #[test]
    fn test_timeout_config_extended() {
        let TimeoutConfig { confirm, .. } = TimeoutConfig::extended();
        assert_eq!(confirm, Duration::from_secs(300));
    }

    /// `Default` yields the same 30-second confirm timeout as `normal()`.
    #[test]
    fn test_timeout_config_default() {
        let TimeoutConfig { confirm, .. } = TimeoutConfig::default();
        assert_eq!(confirm, Duration::from_secs(30));
    }

    /// The generic test channel reports itself as bidirectional.
    #[tokio::test]
    async fn test_create_test_bidir_channel() {
        // The receiver is kept alive for the duration of the test.
        let (bidir, _outgoing) = create_test_bidir_channel::<StandardRequest, StandardResponse>();
        assert!(bidir.is_bidirectional());
    }

    /// The standard-typed test channel reports itself as bidirectional.
    #[tokio::test]
    async fn test_create_test_standard_channel() {
        let (bidir, _outgoing) = create_test_standard_channel();
        assert!(bidir.is_bidirectional());
    }

    /// A custom responder answers both confirm and prompt requests.
    #[tokio::test]
    async fn test_auto_respond_channel() {
        let responder = |req: &StandardRequest| match req {
            StandardRequest::Confirm { .. } => StandardResponse::Confirmed { value: true },
            StandardRequest::Prompt { .. } => StandardResponse::Text {
                value: serde_json::Value::String("hello".into()),
            },
            StandardRequest::Select { options, .. } => StandardResponse::Selected {
                values: vec![options[0].value.clone()],
            },
            StandardRequest::Custom { data } => StandardResponse::Custom { data: data.clone() },
        };
        let channel = auto_respond_channel(responder);

        let confirmed = channel.confirm("Test?").await;
        assert_eq!(confirmed.unwrap(), true);

        let answer = channel.prompt("Name?").await;
        assert_eq!(answer.unwrap(), "hello");
    }

    /// Without a request default, the configured confirm value is returned.
    #[tokio::test]
    async fn test_auto_confirm_channel() {
        let accepting = auto_confirm_channel(true);
        let accepted = accepting.confirm("Test?").await;
        assert_eq!(accepted.unwrap(), true);

        let rejecting = auto_confirm_channel(false);
        let rejected = rejecting.confirm("Test?").await;
        assert_eq!(rejected.unwrap(), false);
    }

    /// Spot-checks the messages produced for several error variants.
    #[test]
    fn test_bidir_error_message() {
        let cases = [
            (
                BidirError::NotSupported,
                "Bidirectional communication not supported by this transport",
            ),
            (
                BidirError::Timeout(30000),
                "Request timed out waiting for response (after 30000ms)",
            ),
            (BidirError::Cancelled, "Request was cancelled by user"),
            (
                BidirError::TypeMismatch {
                    expected: "String".into(),
                    got: "Integer".into(),
                },
                "Type mismatch: expected String, got Integer",
            ),
        ];

        for (err, expected) in &cases {
            assert_eq!(bidir_error_message(err), *expected);
        }
    }
}