use crate::error::{Error, ErrorCode};
use crate::types::Message;
use std::future::Future;
use tokio_util::sync::CancellationToken;
#[cfg(feature = "http-server")]
pub use http::HttpServer;
#[cfg(feature = "server")]
pub(crate) use stdio::StdIoServer;
#[cfg(feature = "http-client")]
pub(crate) use http::HttpClient;
#[cfg(feature = "client")]
pub(crate) use stdio::StdIoClient;
#[cfg(any(feature = "http-server", feature = "http-client"))]
pub(crate) mod http;
pub(crate) mod stdio;
/// Outgoing half of a transport: delivers one [`Message`] per call.
pub(crate) trait Sender {
    /// Sends `resp` to the peer.
    ///
    /// Despite the parameter name, the signature accepts any [`Message`]
    /// variant, not only responses.
    ///
    /// Failures are reported as `Err`; the exact conditions depend on the
    /// implementation.
    fn send(&mut self, resp: Message) -> impl Future<Output = Result<(), Error>>;
}
/// Incoming half of a transport: yields received [`Message`]s one at a time.
pub(crate) trait Receiver {
    /// Waits for and returns the next incoming message.
    ///
    /// Failures are reported as `Err`; the exact conditions depend on the
    /// implementation.
    fn recv(&mut self) -> impl Future<Output = Result<Message, Error>>;
}
/// A transport that can be started and then split into independent
/// [`Sender`] / [`Receiver`] halves.
pub(crate) trait Transport {
    /// Outgoing half produced by `split`.
    type Sender: Sender;
    /// Incoming half produced by `split`.
    type Receiver: Receiver;
    /// Starts the transport and returns a [`CancellationToken`] associated with it.
    ///
    /// NOTE(review): presumably cancelling this token stops the transport;
    /// confirm against the concrete implementations.
    fn start(&mut self) -> CancellationToken;
    /// Consumes the transport and returns its sender and receiver halves.
    fn split(self) -> (Self::Sender, Self::Receiver);
}
/// Selects the concrete transport in use.
///
/// Each non-`None` variant exists only when its Cargo feature is enabled.
/// `None` is always available.
pub(crate) enum TransportProto {
    /// No transport configured. `start` returns a fresh, unrelated token, and
    /// `split` yields the `None` sender/receiver, whose operations return an
    /// `InternalError`.
    None,
    /// Stdio transport in client mode (`client` feature).
    ///
    /// NOTE(review): spelled `Stdio` here but `StdIo` in `StdIoServer`.
    /// Renaming would affect callers outside this file.
    #[cfg(feature = "client")]
    StdioClient(StdIoClient),
    /// Stdio transport in server mode (`server` feature).
    #[cfg(feature = "server")]
    StdIoServer(StdIoServer),
    /// HTTP server transport (`http-server` feature).
    ///
    /// NOTE(review): boxed, presumably to keep the enum small because
    /// `HttpServer` is large; confirm.
    #[cfg(feature = "http-server")]
    HttpServer(Box<HttpServer>),
    /// HTTP client transport (`http-client` feature).
    #[cfg(feature = "http-client")]
    HttpClient(HttpClient),
}
/// Sending half produced when a `TransportProto` is split; hides which
/// protocol is in use.
///
/// Cloning is cheap for `BatchCollect`: clones share the same `Arc`-held state.
#[derive(Clone)]
pub(crate) enum TransportProtoSender {
    /// Placeholder for an unconfigured transport; `send` always returns an
    /// `ErrorCode::InternalError`.
    None,
    /// Sender for a stdio transport (produced for both client and server).
    Stdio(stdio::StdIoSender),
    /// Sender for an HTTP transport (produced for both client and server).
    #[cfg(any(feature = "http-server", feature = "http-client"))]
    Http(http::HttpSender),
    /// Batching wrapper (`server` feature). `Message::Response` values are
    /// appended to `responses` instead of being sent; every other message is
    /// forwarded to `real_sender`.
    #[cfg(feature = "server")]
    BatchCollect {
        // Destination for non-response messages. It uses an async mutex
        // because the guard is held across the forwarding `.await`.
        real_sender: std::sync::Arc<tokio::sync::Mutex<TransportProtoSender>>,
        // Responses collected so far. NOTE(review): presumably drained by the
        // code that assembles the batch reply; confirm at the construction site.
        responses: std::sync::Arc<std::sync::Mutex<Vec<crate::types::MessageEnvelope>>>,
    },
}
/// Receiving half produced when a `TransportProto` is split; hides which
/// protocol is in use.
pub(crate) enum TransportProtoReceiver {
    /// Placeholder for an unconfigured transport; `recv` always returns an
    /// `ErrorCode::InternalError`.
    None,
    /// Receiver for a stdio transport (produced for both client and server).
    Stdio(stdio::StdIoReceiver),
    /// Receiver for an HTTP transport (produced for both client and server).
    #[cfg(any(feature = "http-server", feature = "http-client"))]
    Http(http::HttpReceiver),
}
impl Default for TransportProto {
#[inline]
fn default() -> Self {
TransportProto::None
}
}
impl Sender for TransportProtoSender {
#[inline]
async fn send(&mut self, resp: Message) -> Result<(), Error> {
match self {
TransportProtoSender::Stdio(stdio) => stdio.send(resp).await,
#[cfg(any(feature = "http-server", feature = "http-client"))]
TransportProtoSender::Http(http) => http.send(resp).await,
TransportProtoSender::None => Err(Error::new(
ErrorCode::InternalError,
"Transport protocol must be specified",
)),
#[cfg(feature = "server")]
TransportProtoSender::BatchCollect {
real_sender,
responses,
} => match resp {
Message::Response(response) => {
if let Ok(mut guard) = responses.lock() {
guard.push(crate::types::MessageEnvelope::Response(response));
}
Ok(())
}
other => {
let mut guard = real_sender.lock().await;
Box::pin(guard.send(other)).await
}
},
}
}
}
impl Receiver for TransportProtoReceiver {
#[inline]
async fn recv(&mut self) -> Result<Message, Error> {
match self {
TransportProtoReceiver::Stdio(stdio) => stdio.recv().await,
#[cfg(any(feature = "http-server", feature = "http-client"))]
TransportProtoReceiver::Http(http) => http.recv().await,
TransportProtoReceiver::None => Err(Error::new(
ErrorCode::InternalError,
"Transport protocol must be specified",
)),
}
}
}
impl Transport for TransportProto {
    type Sender = TransportProtoSender;
    type Receiver = TransportProtoReceiver;

    /// Starts the selected transport and returns its cancellation token.
    ///
    /// The `None` variant starts nothing and hands back a brand-new token.
    #[inline]
    fn start(&mut self) -> CancellationToken {
        match self {
            Self::None => CancellationToken::new(),
            #[cfg(feature = "client")]
            Self::StdioClient(client) => client.start(),
            #[cfg(feature = "server")]
            Self::StdIoServer(server) => server.start(),
            #[cfg(feature = "http-client")]
            Self::HttpClient(client) => client.start(),
            #[cfg(feature = "http-server")]
            Self::HttpServer(server) => server.start(),
        }
    }

    /// Consumes the transport and splits it into protocol-erased halves.
    ///
    /// Both stdio variants map to the `Stdio` halves, and both HTTP variants
    /// map to the `Http` halves. `None` produces the `None` placeholders.
    fn split(self) -> (Self::Sender, Self::Receiver) {
        match self {
            Self::None => (TransportProtoSender::None, TransportProtoReceiver::None),
            #[cfg(feature = "client")]
            Self::StdioClient(client) => {
                let (outgoing, incoming) = client.split();
                (
                    TransportProtoSender::Stdio(outgoing),
                    TransportProtoReceiver::Stdio(incoming),
                )
            }
            #[cfg(feature = "server")]
            Self::StdIoServer(server) => {
                let (outgoing, incoming) = server.split();
                (
                    TransportProtoSender::Stdio(outgoing),
                    TransportProtoReceiver::Stdio(incoming),
                )
            }
            #[cfg(feature = "http-client")]
            Self::HttpClient(client) => {
                let (outgoing, incoming) = client.split();
                (
                    TransportProtoSender::Http(outgoing),
                    TransportProtoReceiver::Http(incoming),
                )
            }
            #[cfg(feature = "http-server")]
            Self::HttpServer(server) => {
                let (outgoing, incoming) = server.split();
                (
                    TransportProtoSender::Http(outgoing),
                    TransportProtoReceiver::Http(incoming),
                )
            }
        }
    }
}