use crate::{
ClientError, ClientResult, Response,
conn::{KeepAlive, Mode, ShortConn},
io::{self, AsyncRead, AsyncWrite, AsyncWriteExt},
meta::{BeginRequestRec, EndRequestRec, Header, ParamPairs, RequestType, Role},
params::Params,
request::Request,
response::ResponseStream,
};
use std::marker::PhantomData;
use tracing::debug;
#[cfg(feature = "runtime-tokio")]
use crate::io::{TokioAsyncReadCompatExt, TokioCompat};
/// Request id placed in every record this client writes and expected on every record it reads.
///
/// A single fixed id is used because each `Client` issues its requests one at a time:
/// the execute methods either consume the client or borrow it mutably.
const REQUEST_ID: u16 = 0x0001;
/// Async FastCGI client over a byte stream `S`.
///
/// `M` is a zero-sized mode marker (e.g. `ShortConn` or `KeepAlive`) that selects which
/// execute methods are available; it carries no runtime data.
pub struct Client<S, M> {
    /// Compile-time marker for the connection mode.
    _mode: PhantomData<M>,
    /// Underlying transport used for both writing requests and reading responses.
    stream: S,
}
impl<S, M> Client<S, M> {
fn from_stream(stream: S) -> Self {
Self {
stream,
_mode: PhantomData,
}
}
}
#[cfg(feature = "runtime-tokio")]
impl<S> Client<TokioCompat<S>, ShortConn>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    /// Creates a short-connection client from a tokio stream.
    ///
    /// The stream is wrapped with `.compat()` so the client can drive it through the
    /// crate's own `AsyncRead`/`AsyncWrite` traits.
    pub fn new_tokio(stream: S) -> Self {
        let adapted = stream.compat();
        Self::from_stream(adapted)
    }
}
#[cfg(feature = "runtime-smol")]
impl<S> Client<S, ShortConn>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Creates a short-connection client from a stream that already implements the
    /// crate's `AsyncRead`/`AsyncWrite` traits; the stream is used as-is.
    pub fn new_smol(stream: S) -> Self {
        Client::from_stream(stream)
    }
}
impl<S: AsyncRead + AsyncWrite + Unpin> Client<S, ShortConn> {
pub async fn execute_once<I: AsyncRead + Unpin>(
mut self, request: Request<'_, I>,
) -> ClientResult<Response> {
self.inner_execute(request).await
}
pub async fn execute_once_stream<I: AsyncRead + Unpin>(
mut self, request: Request<'_, I>,
) -> ClientResult<ResponseStream<S>> {
Self::handle_request(&mut self.stream, REQUEST_ID, request.params, request.stdin).await?;
Ok(ResponseStream::new(self.stream, REQUEST_ID))
}
}
#[cfg(feature = "runtime-tokio")]
impl<S> Client<TokioCompat<S>, KeepAlive>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    /// Creates a keep-alive client from a tokio stream.
    ///
    /// The stream is wrapped with `.compat()` so the client can drive it through the
    /// crate's own `AsyncRead`/`AsyncWrite` traits.
    pub fn new_keep_alive_tokio(stream: S) -> Self {
        let adapted = stream.compat();
        Self::from_stream(adapted)
    }
}
#[cfg(feature = "runtime-smol")]
impl<S> Client<S, KeepAlive>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Creates a keep-alive client from a stream that already implements the crate's
    /// `AsyncRead`/`AsyncWrite` traits; the stream is used as-is.
    pub fn new_keep_alive_smol(stream: S) -> Self {
        Client::from_stream(stream)
    }
}
impl<S: AsyncRead + AsyncWrite + Unpin> Client<S, KeepAlive> {
pub async fn execute<I: AsyncRead + Unpin>(
&mut self, request: Request<'_, I>,
) -> ClientResult<Response> {
self.inner_execute(request).await
}
pub async fn execute_stream<I: AsyncRead + Unpin>(
&mut self, request: Request<'_, I>,
) -> ClientResult<ResponseStream<&mut S>> {
Self::handle_request(&mut self.stream, REQUEST_ID, request.params, request.stdin).await?;
Ok(ResponseStream::new(&mut self.stream, REQUEST_ID))
}
}
impl<S: AsyncRead + AsyncWrite + Unpin, M: Mode> Client<S, M> {
    /// Sends one request and buffers the whole response.
    ///
    /// Shared by the short-connection and keep-alive entry points. The mode `M` only
    /// affects the keep-alive flag written in the begin-request record.
    async fn inner_execute<I: AsyncRead + Unpin>(
        &mut self, request: Request<'_, I>,
    ) -> ClientResult<Response> {
        Self::handle_request(&mut self.stream, REQUEST_ID, request.params, request.stdin).await?;
        Self::handle_response(&mut self.stream, REQUEST_ID).await
    }

    /// Writes a complete request and then flushes the stream.
    ///
    /// The request is written as the begin-request record, then the params records,
    /// then the stdin records.
    async fn handle_request<I: AsyncRead + Unpin>(
        stream: &mut S, id: u16, params: Params<'_>, mut body: I,
    ) -> ClientResult<()> {
        Self::handle_request_start(stream, id).await?;
        Self::handle_request_params(stream, id, params).await?;
        Self::handle_request_body(stream, id, &mut body).await?;
        Self::handle_request_flush(stream).await?;
        Ok(())
    }

    /// Writes the begin-request record.
    ///
    /// The record uses the `Responder` role, and its keep-alive flag is taken from
    /// `M::is_keep_alive()`.
    async fn handle_request_start(stream: &mut S, id: u16) -> ClientResult<()> {
        debug!(id, "Start handle request");
        let begin_request_rec =
            BeginRequestRec::new(id, Role::Responder, <M>::is_keep_alive()).await?;
        debug!(id, ?begin_request_rec, "Send to stream.");
        begin_request_rec.write_to_stream(stream).await?;
        Ok(())
    }

    /// Encodes `params` as name/value pairs and writes them as `Params` records.
    ///
    /// A second write from an empty reader follows the data. Presumably it emits the
    /// empty `Params` record that closes the params stream (FastCGI stream convention);
    /// confirm against `Header::write_to_stream_batches`.
    async fn handle_request_params(
        stream: &mut S, id: u16, params: Params<'_>,
    ) -> ClientResult<()> {
        let param_pairs = ParamPairs::new(params);
        debug!(id, ?param_pairs, "Params will be sent.");
        // Bind the encoded bytes to a named local. A `&[u8]` slice over it is then
        // passed as the reader (`&mut &[u8]`).
        let content = param_pairs.to_content().await?;
        Header::write_to_stream_batches(
            RequestType::Params,
            id,
            stream,
            &mut &content[..],
            Some(|header| {
                debug!(id, ?header, "Send to stream for Params.");
                header
            }),
        )
        .await?;
        Header::write_to_stream_batches(
            RequestType::Params,
            id,
            stream,
            &mut io::empty(),
            Some(|header| {
                debug!(id, ?header, "Send to stream for Params.");
                header
            }),
        )
        .await?;
        Ok(())
    }

    /// Copies `body` into the stream as `Stdin` records.
    ///
    /// A second write from an empty reader follows the body. Presumably it emits the
    /// empty `Stdin` record that closes the stdin stream; confirm against
    /// `Header::write_to_stream_batches`.
    async fn handle_request_body<I: AsyncRead + Unpin>(
        stream: &mut S, id: u16, body: &mut I,
    ) -> ClientResult<()> {
        Header::write_to_stream_batches(
            RequestType::Stdin,
            id,
            stream,
            body,
            Some(|header| {
                debug!(id, ?header, "Send to stream for Stdin.");
                header
            }),
        )
        .await?;
        Header::write_to_stream_batches(
            RequestType::Stdin,
            id,
            stream,
            &mut io::empty(),
            Some(|header| {
                debug!(id, ?header, "Send to stream for Stdin.");
                header
            }),
        )
        .await?;
        Ok(())
    }

    /// Flushes any buffered request bytes to the peer.
    async fn handle_request_flush(stream: &mut S) -> ClientResult<()> {
        stream.flush().await?;
        Ok(())
    }

    /// Reads records until an `EndRequest` record arrives and assembles a `Response`.
    ///
    /// `Stdout` and `Stderr` payloads are concatenated in arrival order. A stream that
    /// produced no bytes is reported as `None`.
    ///
    /// # Errors
    /// - `ClientError::ResponseNotFound` if a record carries a request id other than `id`.
    /// - `ClientError::UnknownRequestType` for any record type other than
    ///   `Stdout`, `Stderr` or `EndRequest`.
    /// - Whatever `convert_to_client_result` returns for the end-request's protocol status.
    /// - Any I/O error raised while reading.
    async fn handle_response(stream: &mut S, id: u16) -> ClientResult<Response> {
        let mut response = Response::default();
        let mut stderr = Vec::new();
        let mut stdout = Vec::new();
        loop {
            let header = Header::new_from_stream(stream).await?;
            if header.request_id != id {
                return Err(ClientError::ResponseNotFound { id });
            }
            debug!(id, ?header, "Receive from stream.");
            match header.r#type {
                RequestType::Stdout => {
                    stdout.extend(header.read_content_from_stream(stream).await?);
                }
                RequestType::Stderr => {
                    stderr.extend(header.read_content_from_stream(stream).await?);
                }
                RequestType::EndRequest => {
                    let end_request_rec = EndRequestRec::from_header(&header, stream).await?;
                    debug!(id, ?end_request_rec, "Receive from stream.");
                    end_request_rec
                        .end_request
                        .protocol_status
                        .convert_to_client_result(end_request_rec.end_request.app_status)?;
                    response.stdout = Self::non_empty(stdout);
                    response.stderr = Self::non_empty(stderr);
                    return Ok(response);
                }
                r#type => {
                    return Err(ClientError::UnknownRequestType {
                        request_type: r#type,
                    });
                }
            }
        }
    }

    /// Maps an empty buffer to `None` and a non-empty one to `Some(buffer)`.
    fn non_empty(buf: Vec<u8>) -> Option<Vec<u8>> {
        if buf.is_empty() {
            None
        } else {
            Some(buf)
        }
    }
}