use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;
use tokio::sync::{mpsc, oneshot};
use tokio::time::timeout;
use uuid::Uuid;
use super::registry::{register_pending_request, unregister_pending_request};
use super::types::{BidirError, SelectOption, StandardRequest, StandardResponse};
use crate::plexus::types::PlexusStreamItem;
/// Request/response channel layered on top of an outbound `PlexusStreamItem` stream.
///
/// A request of type `Req` is serialized to JSON, emitted on the stream as a
/// request item tagged with a fresh UUID, and the caller then waits (with a
/// timeout) for a response that deserializes into `Resp`.
///
/// Responses are correlated in one of two ways, fixed at construction time:
/// through the registry functions imported from `super::registry`, or through
/// this channel's own `pending` map, which is completed by `handle_response`.
pub struct BidirChannel<Req, Resp>
where
Req: Serialize + DeserializeOwned + Send + 'static,
Resp: Serialize + DeserializeOwned + Send + 'static,
{
/// Outbound stream on which request items are emitted.
stream_tx: mpsc::Sender<PlexusStreamItem>,
/// Direct-mode bookkeeping: request id -> sender that completes the waiting request.
pending: Arc<Mutex<HashMap<String, oneshot::Sender<Resp>>>>,
/// When `false`, every request fails immediately with `BidirError::NotSupported`.
bidirectional_supported: bool,
/// `true` routes responses through `super::registry`; `false` uses `pending`.
use_global_registry: bool,
/// Provenance entries supplied by the creator; only stored and exposed via `provenance()`.
provenance: Vec<String>,
/// Hash string supplied by the creator; only stored and exposed via `plexus_hash()`.
plexus_hash: String,
/// Ties the otherwise-unused `Req` type parameter to the struct.
_phantom_req: PhantomData<Req>,
}
/// A [`BidirChannel`] specialised to the `StandardRequest` / `StandardResponse`
/// message pair; this is the instantiation that provides `confirm`, `prompt`
/// and `select`.
pub type StandardBidirChannel = BidirChannel<StandardRequest, StandardResponse>;
impl<Req, Resp> BidirChannel<Req, Resp>
where
    Req: Serialize + DeserializeOwned + Send + 'static,
    Resp: Serialize + DeserializeOwned + Send + 'static,
{
    /// Timeout applied by `request` when the caller does not pick one.
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

    /// Creates a channel whose responses are routed through the pending-request
    /// registry in `super::registry`.
    ///
    /// * `stream_tx` - outbound stream on which request items are emitted.
    /// * `bidirectional_supported` - when `false`, every request fails with
    ///   [`BidirError::NotSupported`].
    /// * `provenance`, `plexus_hash` - stored as-is and exposed through the
    ///   accessors of the same name.
    pub fn new(
        stream_tx: mpsc::Sender<PlexusStreamItem>,
        bidirectional_supported: bool,
        provenance: Vec<String>,
        plexus_hash: String,
    ) -> Self {
        Self::with_mode(
            stream_tx,
            bidirectional_supported,
            provenance,
            plexus_hash,
            true,
        )
    }

    /// Creates a channel that tracks its own pending requests; responses must
    /// be delivered by calling [`handle_response`](Self::handle_response) on
    /// this same channel.
    ///
    /// Parameters have the same meaning as for [`new`](Self::new).
    pub fn new_direct(
        stream_tx: mpsc::Sender<PlexusStreamItem>,
        bidirectional_supported: bool,
        provenance: Vec<String>,
        plexus_hash: String,
    ) -> Self {
        Self::with_mode(
            stream_tx,
            bidirectional_supported,
            provenance,
            plexus_hash,
            false,
        )
    }

    /// Shared constructor; the two public constructors differ only in
    /// `use_global_registry`.
    fn with_mode(
        stream_tx: mpsc::Sender<PlexusStreamItem>,
        bidirectional_supported: bool,
        provenance: Vec<String>,
        plexus_hash: String,
        use_global_registry: bool,
    ) -> Self {
        Self {
            stream_tx,
            pending: Arc::new(Mutex::new(HashMap::new())),
            bidirectional_supported,
            use_global_registry,
            provenance,
            plexus_hash,
            _phantom_req: PhantomData,
        }
    }

    /// Locks the direct-mode pending map.
    ///
    /// Every critical section is a single `HashMap` insert or remove, so the
    /// map stays structurally valid even if a holder panicked; a poisoned lock
    /// is therefore recovered instead of turning every later request into a
    /// panic.
    fn lock_pending(&self) -> std::sync::MutexGuard<'_, HashMap<String, oneshot::Sender<Resp>>> {
        self.pending
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Deserializes a raw JSON response into `Resp`, reporting a failure as
    /// [`BidirError::TypeMismatch`] that names the expected Rust type.
    fn decode_response(value: Value) -> Result<Resp, BidirError> {
        serde_json::from_value(value).map_err(|e| BidirError::TypeMismatch {
            expected: std::any::type_name::<Resp>().to_string(),
            got: e.to_string(),
        })
    }

    /// Returns whether this channel was created with bidirectional support.
    pub fn is_bidirectional(&self) -> bool {
        self.bidirectional_supported
    }

    /// Sends `req` and waits up to 30 seconds for the response.
    ///
    /// # Errors
    ///
    /// Same as [`request_with_timeout`](Self::request_with_timeout).
    pub async fn request(&self, req: Req) -> Result<Resp, BidirError> {
        self.request_with_timeout(req, Self::DEFAULT_TIMEOUT).await
    }

    /// Sends `req` and waits up to `timeout_duration` for the response.
    ///
    /// # Errors
    ///
    /// * [`BidirError::NotSupported`] - the channel was created without
    ///   bidirectional support; nothing is sent.
    /// * [`BidirError::Serialization`] - `req` could not be converted to JSON.
    /// * [`BidirError::Transport`] - the outbound stream rejected the item.
    /// * [`BidirError::Timeout`] - no response arrived in time; carries the
    ///   timeout in milliseconds.
    /// * [`BidirError::ChannelClosed`] - the response sender was dropped
    ///   without sending.
    /// * [`BidirError::TypeMismatch`] - (registry mode) the response did not
    ///   deserialize into `Resp`.
    pub async fn request_with_timeout(
        &self,
        req: Req,
        timeout_duration: Duration,
    ) -> Result<Resp, BidirError> {
        if !self.bidirectional_supported {
            return Err(BidirError::NotSupported);
        }

        let request_id = Uuid::new_v4().to_string();
        let request_data =
            serde_json::to_value(&req).map_err(|e| BidirError::Serialization(e.to_string()))?;

        // `as_millis` yields a u128; clamp before narrowing so that an
        // extremely long timeout saturates instead of wrapping to a small one.
        let timeout_ms = timeout_duration.as_millis().min(u128::from(u64::MAX)) as u64;

        if self.use_global_registry {
            self.request_via_registry(request_id, request_data, timeout_duration, timeout_ms)
                .await
        } else {
            self.request_direct(request_id, request_data, timeout_duration, timeout_ms)
                .await
        }
    }

    /// Direct mode: records the response sender in `pending`, emits the
    /// request item and waits for `handle_response` to complete it.
    ///
    /// The `pending` entry is removed on every failure path (send failure,
    /// closed channel, timeout) so that failed requests do not accumulate.
    /// NOTE(review): if the returned future is dropped mid-wait the entry
    /// stays in the map; confirm whether callers ever cancel requests.
    async fn request_direct(
        &self,
        request_id: String,
        request_data: Value,
        timeout_duration: Duration,
        timeout_ms: u64,
    ) -> Result<Resp, BidirError> {
        let (tx, rx) = oneshot::channel();
        // The guard is a temporary dropped at the end of this statement, so
        // the std mutex is never held across an `.await`.
        self.lock_pending().insert(request_id.clone(), tx);

        let item = PlexusStreamItem::request(request_id.clone(), request_data, timeout_ms);
        if let Err(e) = self.stream_tx.send(item).await {
            // Nobody can answer a request that was never sent; drop the entry
            // instead of leaking it (mirrors the registry path).
            self.lock_pending().remove(&request_id);
            return Err(BidirError::Transport(format!(
                "Failed to send request: {}",
                e
            )));
        }

        match timeout(timeout_duration, rx).await {
            Ok(Ok(resp)) => Ok(resp),
            Ok(Err(_)) => {
                self.lock_pending().remove(&request_id);
                Err(BidirError::ChannelClosed)
            }
            Err(_) => {
                self.lock_pending().remove(&request_id);
                Err(BidirError::Timeout(timeout_ms))
            }
        }
    }

    /// Registry mode: registers the response sender via
    /// `register_pending_request`, emits the request item and waits for a raw
    /// JSON value, which is then deserialized into `Resp`.
    ///
    /// The registration is removed explicitly on send failure, closed channel
    /// and timeout. On success it is presumably consumed by whoever delivered
    /// the value - TODO confirm against `super::registry`.
    async fn request_via_registry(
        &self,
        request_id: String,
        request_data: Value,
        timeout_duration: Duration,
        timeout_ms: u64,
    ) -> Result<Resp, BidirError> {
        let (tx, rx) = oneshot::channel::<Value>();
        register_pending_request(request_id.clone(), tx);

        let item = PlexusStreamItem::request(request_id.clone(), request_data, timeout_ms);
        if let Err(e) = self.stream_tx.send(item).await {
            unregister_pending_request(&request_id);
            return Err(BidirError::Transport(format!(
                "Failed to send request: {}",
                e
            )));
        }

        match timeout(timeout_duration, rx).await {
            Ok(Ok(value)) => Self::decode_response(value),
            Ok(Err(_)) => {
                unregister_pending_request(&request_id);
                Err(BidirError::ChannelClosed)
            }
            Err(_) => {
                unregister_pending_request(&request_id);
                Err(BidirError::Timeout(timeout_ms))
            }
        }
    }

    /// Completes the direct-mode request identified by `request_id` with
    /// `response_data`.
    ///
    /// The pending entry is removed before the payload is decoded, so a
    /// malformed payload drops the sender and the waiting request resolves
    /// with [`BidirError::ChannelClosed`] rather than running into its timeout.
    ///
    /// # Errors
    ///
    /// * [`BidirError::UnknownRequest`] - no pending request has this id.
    /// * [`BidirError::TypeMismatch`] - the payload does not deserialize into
    ///   `Resp`.
    /// * [`BidirError::ChannelClosed`] - the requester is no longer waiting.
    pub fn handle_response(
        &self,
        request_id: String,
        response_data: Value,
    ) -> Result<(), BidirError> {
        let tx = self
            .lock_pending()
            .remove(&request_id)
            .ok_or(BidirError::UnknownRequest)?;
        let resp = Self::decode_response(response_data)?;
        tx.send(resp).map_err(|_| BidirError::ChannelClosed)
    }

    /// Returns the provenance entries supplied at construction.
    pub fn provenance(&self) -> &[String] {
        &self.provenance
    }

    /// Returns the hash string supplied at construction.
    pub fn plexus_hash(&self) -> &str {
        &self.plexus_hash
    }
}
impl BidirChannel<StandardRequest, StandardResponse> {
pub async fn confirm(&self, message: &str) -> Result<bool, BidirError> {
let resp = self
.request(StandardRequest::Confirm {
message: message.to_string(),
default: None,
})
.await?;
match resp {
StandardResponse::Confirmed { value } => Ok(value),
StandardResponse::Cancelled => Err(BidirError::Cancelled),
_ => Err(BidirError::TypeMismatch {
expected: "Confirmed".into(),
got: format!("{:?}", resp),
}),
}
}
pub async fn prompt(&self, message: &str) -> Result<String, BidirError> {
let resp = self
.request(StandardRequest::Prompt {
message: message.to_string(),
default: None,
placeholder: None,
})
.await?;
match resp {
StandardResponse::Text { value } => {
match value {
serde_json::Value::String(s) => Ok(s),
other => Ok(other.to_string()),
}
}
StandardResponse::Cancelled => Err(BidirError::Cancelled),
_ => Err(BidirError::TypeMismatch {
expected: "Text".into(),
got: format!("{:?}", resp),
}),
}
}
pub async fn select(
&self,
message: &str,
options: Vec<SelectOption>,
) -> Result<Vec<String>, BidirError> {
let resp = self
.request(StandardRequest::Select {
message: message.to_string(),
options,
multi_select: false,
})
.await?;
match resp {
StandardResponse::Selected { values } => {
let strings = values
.into_iter()
.map(|v| match v {
serde_json::Value::String(s) => s,
other => other.to_string(),
})
.collect();
Ok(strings)
}
StandardResponse::Cancelled => Err(BidirError::Cancelled),
_ => Err(BidirError::TypeMismatch {
expected: "Selected".into(),
got: format!("{:?}", resp),
}),
}
}
}
/// Wraps a shared [`BidirChannel`] together with a fallback function so that
/// `request` always yields a response: whenever the channel request fails,
/// the fallback is called with the original request to produce the answer.
pub struct BidirWithFallback<Req, Resp>
where
Req: Serialize + DeserializeOwned + Send + 'static,
Resp: Serialize + DeserializeOwned + Send + 'static,
{
/// Channel on which the real request is attempted first.
channel: Arc<BidirChannel<Req, Resp>>,
/// Produces the substitute response from the original request when the channel fails.
fallback_fn: Box<dyn Fn(&Req) -> Resp + Send + Sync>,
}
impl<Req, Resp> BidirWithFallback<Req, Resp>
where
    Req: Serialize + DeserializeOwned + Send + 'static,
    Resp: Serialize + DeserializeOwned + Send + 'static,
{
    /// Pairs `channel` with the `fallback` used whenever a request on it fails.
    pub fn new(
        channel: Arc<BidirChannel<Req, Resp>>,
        fallback: impl Fn(&Req) -> Resp + Send + Sync + 'static,
    ) -> Self {
        let fallback_fn: Box<dyn Fn(&Req) -> Resp + Send + Sync> = Box::new(fallback);
        Self {
            channel,
            fallback_fn,
        }
    }

    /// Sends `req` over the channel and returns its response, or the
    /// fallback's answer for `req` if the channel reports any error.
    ///
    /// The request is cloned up front because the channel consumes it while
    /// the fallback still needs to inspect the original.
    pub async fn request(&self, req: Req) -> Resp
    where
        Req: Clone,
    {
        let outcome = self.channel.request(req.clone()).await;
        // Unsupported channels and timeouts are the expected triggers, but
        // every other error is answered by the fallback as well.
        outcome.unwrap_or_else(|_| (self.fallback_fn)(&req))
    }
}
impl BidirWithFallback<StandardRequest, StandardResponse> {
pub fn auto_confirm(
channel: Arc<BidirChannel<StandardRequest, StandardResponse>>,
) -> Self {
Self::new(channel, |req| match req {
StandardRequest::Confirm { default, .. } => StandardResponse::Confirmed {
value: default.unwrap_or(true),
},
StandardRequest::Prompt { default, .. } => StandardResponse::Text {
value: default.clone().unwrap_or(serde_json::Value::String(String::new())),
},
StandardRequest::Select { options, .. } => StandardResponse::Selected {
values: vec![options
.first()
.map(|o| o.value.clone())
.unwrap_or(serde_json::Value::String(String::new()))],
},
StandardRequest::Custom { data } => StandardResponse::Custom { data: data.clone() },
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a direct-mode standard channel together with the receiving end
    /// of its outbound stream. Callers must keep the receiver alive for as
    /// long as requests are expected to be sendable.
    fn direct_channel(
        bidirectional_supported: bool,
    ) -> (
        Arc<StandardBidirChannel>,
        mpsc::Receiver<PlexusStreamItem>,
    ) {
        let (tx, rx) = mpsc::channel(32);
        let channel = Arc::new(BidirChannel::new_direct(
            tx,
            bidirectional_supported,
            vec!["test".into()],
            "hash".into(),
        ));
        (channel, rx)
    }

    /// JSON payload for a positive confirmation response.
    fn confirmed_true_payload() -> serde_json::Value {
        serde_json::to_value(&StandardResponse::<serde_json::Value>::Confirmed { value: true })
            .unwrap()
    }

    /// A channel without bidirectional support rejects requests up front.
    #[tokio::test]
    async fn test_bidir_channel_not_supported() {
        let (channel, _rx) = direct_channel(false);
        let result = channel.confirm("Test?").await;
        assert!(matches!(result, Err(BidirError::NotSupported)));
    }

    /// A request emitted on the stream is completed by `handle_response`.
    #[tokio::test]
    async fn test_bidir_request_response() {
        let (channel, mut rx) = direct_channel(true);
        let channel_clone = channel.clone();
        let handle = tokio::spawn(async move {
            channel_clone
                .request(StandardRequest::Confirm {
                    message: "Test?".into(),
                    default: None,
                })
                .await
        });

        if let Some(PlexusStreamItem::Request {
            request_id,
            request_data,
            ..
        }) = rx.recv().await
        {
            let req: StandardRequest = serde_json::from_value(request_data).unwrap();
            assert!(matches!(req, StandardRequest::Confirm { .. }));
            channel
                .handle_response(request_id, confirmed_true_payload())
                .unwrap();
        } else {
            panic!("Expected Request item");
        }

        let result: StandardResponse = handle.await.unwrap().unwrap();
        assert_eq!(result, StandardResponse::Confirmed { value: true });
    }

    /// `confirm` unwraps a `Confirmed` response into its boolean.
    #[tokio::test]
    async fn test_convenience_methods() {
        let (channel, mut rx) = direct_channel(true);
        let channel_clone = channel.clone();
        let handle = tokio::spawn(async move { channel_clone.confirm("Delete?").await });

        if let Some(PlexusStreamItem::Request { request_id, .. }) = rx.recv().await {
            channel
                .handle_response(request_id, confirmed_true_payload())
                .unwrap();
        } else {
            // Fail immediately instead of letting the spawned request run
            // into its 30 second default timeout.
            panic!("Expected Request item");
        }

        assert!(handle.await.unwrap().unwrap());
    }

    /// An unanswered request resolves to `Timeout` carrying the timeout in ms.
    #[tokio::test]
    async fn test_timeout() {
        // `_rx` stays alive so the send succeeds and only the wait times out.
        let (channel, _rx) = direct_channel(true);
        let result = channel
            .request_with_timeout(
                StandardRequest::Confirm {
                    message: "Test?".into(),
                    default: None,
                },
                Duration::from_millis(100),
            )
            .await;
        assert!(matches!(result, Err(BidirError::Timeout(100))));
    }

    /// With an unsupported channel the auto-confirm fallback answers using the
    /// request's own default.
    #[tokio::test]
    async fn test_fallback() {
        let (channel, _rx) = direct_channel(false);
        let fallback = BidirWithFallback::auto_confirm(channel);
        let resp = fallback
            .request(StandardRequest::Confirm {
                message: "Test?".into(),
                default: Some(false),
            })
            .await;
        assert_eq!(resp, StandardResponse::Confirmed { value: false });
    }
}