use super::channel::BidirChannel;
use super::types::{BidirError, StandardRequest, StandardResponse};
use crate::plexus::types::PlexusStreamItem;
use serde::{de::DeserializeOwned, Serialize};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
/// Timeouts for the different kinds of bidirectional requests.
///
/// Each field holds one duration per request kind. How (and whether) a
/// timeout is enforced is up to the code that consumes this config. The
/// named presets use one duration for every kind. Use [`TimeoutConfig::uniform`]
/// for any other single duration, or build the struct directly to mix durations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimeoutConfig {
    /// Timeout intended for confirm requests.
    pub confirm: Duration,
    /// Timeout intended for prompt requests.
    pub prompt: Duration,
    /// Timeout intended for select requests.
    pub select: Duration,
    /// Timeout intended for custom requests.
    pub custom: Duration,
}

impl TimeoutConfig {
    /// Builds a config that uses `timeout` for every request kind.
    #[must_use]
    pub const fn uniform(timeout: Duration) -> Self {
        Self {
            confirm: timeout,
            prompt: timeout,
            select: timeout,
            custom: timeout,
        }
    }

    /// 10 seconds for every request kind.
    #[must_use]
    pub const fn quick() -> Self {
        Self::uniform(Duration::from_secs(10))
    }

    /// 30 seconds for every request kind (the [`Default`]).
    #[must_use]
    pub const fn normal() -> Self {
        Self::uniform(Duration::from_secs(30))
    }

    /// 60 seconds for every request kind.
    #[must_use]
    pub const fn patient() -> Self {
        Self::uniform(Duration::from_secs(60))
    }

    /// 300 seconds (5 minutes) for every request kind.
    #[must_use]
    pub const fn extended() -> Self {
        Self::uniform(Duration::from_secs(300))
    }
}

impl Default for TimeoutConfig {
    /// Same as [`TimeoutConfig::normal`]: 30 seconds for every request kind.
    fn default() -> Self {
        Self::normal()
    }
}
/// Creates a channel for tests, together with the receiver for its outgoing stream items.
///
/// The channel comes from `BidirChannel::new_direct` with fixed test arguments
/// (`true`, `["test"]`, `"test-hash"`), wrapped in an `Arc`. The backing `mpsc`
/// buffer holds 32 items. Nothing drains the returned receiver; the caller owns it
/// and decides whether to read from it.
pub fn create_test_bidir_channel<Req, Resp>() -> (
    Arc<BidirChannel<Req, Resp>>,
    mpsc::Receiver<PlexusStreamItem>,
)
where
    Req: Serialize + DeserializeOwned + Send + 'static,
    Resp: Serialize + DeserializeOwned + Send + 'static,
{
    let (item_tx, item_rx) = mpsc::channel(32);
    // NOTE(review): `true` is presumably the flag reported by `is_bidirectional()`
    // (the tests assert it is true); confirm against `new_direct`.
    let direct = BidirChannel::new_direct(
        item_tx,
        true,
        vec!["test".into()],
        "test-hash".into(),
    );
    (Arc::new(direct), item_rx)
}
/// Shorthand for [`create_test_bidir_channel`] using the standard request/response types.
pub fn create_test_standard_channel() -> (
    Arc<BidirChannel<StandardRequest, StandardResponse>>,
    mpsc::Receiver<PlexusStreamItem>,
) {
    create_test_bidir_channel::<StandardRequest, StandardResponse>()
}
pub fn auto_respond_channel<Req, Resp>(
response_fn: impl Fn(&Req) -> Resp + Send + Sync + 'static,
) -> Arc<BidirChannel<Req, Resp>>
where
Req: Serialize + DeserializeOwned + Send + Sync + Clone + 'static,
Resp: Serialize + DeserializeOwned + Send + Sync + 'static,
{
let (tx, mut rx) = mpsc::channel::<PlexusStreamItem>(32);
let channel = Arc::new(BidirChannel::new_direct(
tx,
true, vec!["test".into()],
"test-hash".into(),
));
let channel_clone = channel.clone();
tokio::spawn(async move {
while let Some(item) = rx.recv().await {
if let PlexusStreamItem::Request {
request_id,
request_data,
..
} = item
{
if let Ok(req) = serde_json::from_value::<Req>(request_data) {
let resp = response_fn(&req);
if let Ok(resp_json) = serde_json::to_value(&resp) {
let _ = channel_clone.handle_response(request_id, resp_json);
}
}
}
}
});
channel
}
/// Builds an [`auto_respond_channel`] that answers every standard request with a canned value.
///
/// - Confirm: the request's `default` if present, otherwise `confirm_value`.
/// - Prompt: the request's `default` if present, otherwise an empty JSON string.
/// - Select: a one-element list holding the first option's value, or an empty
///   JSON string when there are no options.
/// - Custom: echoes the request's `data` back unchanged.
pub fn auto_confirm_channel(
    confirm_value: bool,
) -> Arc<BidirChannel<StandardRequest, StandardResponse>> {
    // Shared fallback for answers that have no better value available.
    let empty_string = || serde_json::Value::String(String::new());
    auto_respond_channel(move |request: &StandardRequest| match request {
        StandardRequest::Confirm { default, .. } => {
            let value = default.unwrap_or(confirm_value);
            StandardResponse::Confirmed { value }
        }
        StandardRequest::Prompt { default, .. } => {
            let value = default.clone().unwrap_or_else(empty_string);
            StandardResponse::Text { value }
        }
        StandardRequest::Select { options, .. } => {
            let first = options
                .first()
                .map_or_else(empty_string, |option| option.value.clone());
            StandardResponse::Selected { values: vec![first] }
        }
        StandardRequest::Custom { data } => StandardResponse::Custom { data: data.clone() },
    })
}
/// Renders a [`BidirError`] as a human-readable, single-line message.
///
/// Every variant is matched explicitly, without a wildcard arm. Adding a new
/// variant therefore causes a compile error here until a message is written for it.
pub fn bidir_error_message(err: &BidirError) -> String {
    match err {
        // Variants that carry data interpolate it into the message.
        BidirError::Timeout(ms) => {
            format!("Request timed out waiting for response (after {}ms)", ms)
        }
        BidirError::TypeMismatch { expected, got } => {
            format!("Type mismatch: expected {}, got {}", expected, got)
        }
        BidirError::Serialization(e) => format!("Serialization error: {}", e),
        BidirError::Transport(e) => format!("Transport error: {}", e),
        // Unit variants map to fixed text.
        BidirError::NotSupported => {
            String::from("Bidirectional communication not supported by this transport")
        }
        BidirError::Cancelled => String::from("Request was cancelled by user"),
        BidirError::UnknownRequest => {
            String::from("Unknown request ID (may have already been handled)")
        }
        BidirError::ChannelClosed => {
            String::from("Response channel closed before response received")
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every per-request-kind timeout in `config` equals `expected`.
    fn assert_uniform(config: TimeoutConfig, expected: Duration) {
        assert_eq!(config.confirm, expected);
        assert_eq!(config.prompt, expected);
        assert_eq!(config.select, expected);
        assert_eq!(config.custom, expected);
    }

    #[test]
    fn test_timeout_config_quick() {
        assert_uniform(TimeoutConfig::quick(), Duration::from_secs(10));
    }

    #[test]
    fn test_timeout_config_normal() {
        assert_uniform(TimeoutConfig::normal(), Duration::from_secs(30));
    }

    #[test]
    fn test_timeout_config_patient() {
        assert_uniform(TimeoutConfig::patient(), Duration::from_secs(60));
    }

    #[test]
    fn test_timeout_config_extended() {
        assert_uniform(TimeoutConfig::extended(), Duration::from_secs(300));
    }

    #[test]
    fn test_timeout_config_default() {
        assert_uniform(TimeoutConfig::default(), Duration::from_secs(30));
    }

    #[tokio::test]
    async fn test_create_test_bidir_channel() {
        let (channel, _rx) = create_test_bidir_channel::<StandardRequest, StandardResponse>();
        assert!(channel.is_bidirectional());
    }

    #[tokio::test]
    async fn test_create_test_standard_channel() {
        let (channel, _rx) = create_test_standard_channel();
        assert!(channel.is_bidirectional());
    }

    #[tokio::test]
    async fn test_auto_respond_channel() {
        let ctx = auto_respond_channel(|req: &StandardRequest| match req {
            StandardRequest::Confirm { .. } => StandardResponse::Confirmed { value: true },
            StandardRequest::Prompt { .. } => StandardResponse::Text {
                value: serde_json::Value::String("hello".into()),
            },
            StandardRequest::Select { options, .. } => StandardResponse::Selected {
                // Avoid `options[0]`: a panic here would kill the responder task
                // silently and leave the caller waiting for its timeout.
                values: options.iter().take(1).map(|o| o.value.clone()).collect(),
            },
            StandardRequest::Custom { data } => StandardResponse::Custom { data: data.clone() },
        });
        assert!(ctx.confirm("Test?").await.unwrap());
        assert_eq!(ctx.prompt("Name?").await.unwrap(), "hello");
    }

    #[tokio::test]
    async fn test_auto_confirm_channel() {
        let ctx = auto_confirm_channel(true);
        assert!(ctx.confirm("Test?").await.unwrap());
        let ctx = auto_confirm_channel(false);
        assert!(!ctx.confirm("Test?").await.unwrap());
    }

    #[test]
    fn test_bidir_error_message() {
        assert_eq!(
            bidir_error_message(&BidirError::NotSupported),
            "Bidirectional communication not supported by this transport"
        );
        assert_eq!(
            bidir_error_message(&BidirError::Timeout(30000)),
            "Request timed out waiting for response (after 30000ms)"
        );
        assert_eq!(
            bidir_error_message(&BidirError::Cancelled),
            "Request was cancelled by user"
        );
        assert_eq!(
            bidir_error_message(&BidirError::TypeMismatch {
                expected: "String".into(),
                got: "Integer".into()
            }),
            "Type mismatch: expected String, got Integer"
        );
        assert_eq!(
            bidir_error_message(&BidirError::UnknownRequest),
            "Unknown request ID (may have already been handled)"
        );
        assert_eq!(
            bidir_error_message(&BidirError::ChannelClosed),
            "Response channel closed before response received"
        );
    }
}