use std::{
fmt::{self, Debug},
pin::Pin,
str,
task::{Context, Poll},
};
#[cfg(feature = "http")]
use std::str::FromStr;
use bytes::{Bytes, BytesMut};
use futures_core::stream::Stream;
use tracing::debug;
use crate::{
ClientError, ClientResult,
io::AsyncRead,
meta::{EndRequestRec, HEADER_LEN, Header, RequestType},
};
#[cfg(feature = "http")]
use crate::{HttpConversionError, HttpConversionResult};
/// Output captured from the application for one request.
///
/// With the `http` feature, the conversion into `http::Response` treats a
/// missing `stdout` as empty and ignores `stderr` entirely.
#[non_exhaustive]
#[derive(Default, Clone)]
pub struct Response {
    /// Raw bytes written to standard output: CGI-style header lines, a blank
    /// line, then the body.
    pub stdout: Option<Vec<u8>>,
    /// Raw bytes written to standard error.
    pub stderr: Option<Vec<u8>>,
}
impl Debug for Response {
    /// Formats both output streams as UTF-8 text where possible.
    ///
    /// Each field is shown as `None`, `Some(Ok(text))`, or `Some(Err(e))`
    /// when the captured bytes are not valid UTF-8.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        fn as_text(bytes: Option<&[u8]>) -> Option<Result<&str, str::Utf8Error>> {
            bytes.map(str::from_utf8)
        }

        let stdout_text = as_text(self.stdout.as_deref());
        let stderr_text = as_text(self.stderr.as_deref());

        let mut out = f.debug_struct("Response");
        out.field("stdout", &stdout_text);
        out.field("stderr", &stderr_text);
        out.finish()
    }
}
#[cfg(feature = "http")]
impl<B> TryFrom<::http::Response<B>> for Response
where
    B: AsRef<[u8]>,
{
    type Error = HttpConversionError;

    /// Serializes an HTTP response into CGI-style `stdout` bytes.
    ///
    /// A `Status:` line is written only when the status is not `200 OK`.
    /// It is followed by every header as `name: value\r\n`, a blank `\r\n`
    /// line, and the body. `stderr` is left as `None`. This conversion
    /// currently never fails.
    fn try_from(response: ::http::Response<B>) -> Result<Self, Self::Error> {
        let (parts, body) = response.into_parts();
        let mut stdout = Vec::new();

        if parts.status != ::http::StatusCode::OK {
            let status_line = format!("Status: {}\r\n", parts.status);
            stdout.extend_from_slice(status_line.as_bytes());
        }

        for (name, value) in parts.headers.iter() {
            for piece in [
                name.as_str().as_bytes(),
                &b": "[..],
                value.as_bytes(),
                &b"\r\n"[..],
            ] {
                stdout.extend_from_slice(piece);
            }
        }

        // Blank line separating the header section from the body.
        stdout.extend_from_slice(b"\r\n");
        stdout.extend_from_slice(body.as_ref());

        Ok(Self {
            stdout: Some(stdout),
            stderr: None,
        })
    }
}
#[cfg(feature = "http")]
impl TryFrom<Response> for ::http::Response<Vec<u8>> {
    type Error = HttpConversionError;

    /// Parses CGI-style `stdout` into an HTTP response.
    ///
    /// A missing `stdout` is treated as empty and `stderr` is ignored. A
    /// `Status` header (case-insensitive) sets the status code, which
    /// defaults to `200 OK`. Every other header line is appended in order,
    /// so repeated names are preserved.
    ///
    /// # Errors
    ///
    /// Fails when there is no header/body separator, a header line has no
    /// `:`, the status is malformed, a header name/value is invalid, or the
    /// response builder rejects the result.
    fn try_from(response: Response) -> Result<Self, Self::Error> {
        let raw = response.stdout.unwrap_or_default();
        let (head, body) = split_header_body(&raw)?;

        let mut builder = ::http::Response::builder();
        let mut status_code = ::http::StatusCode::OK;

        let header_map = builder
            .headers_mut()
            .expect("response builder should provide headers");

        for line in head.split(|byte| *byte == b'\n').map(trim_cr) {
            if line.is_empty() {
                continue;
            }
            let (name, value) = line.split_first_colon().ok_or(
                HttpConversionError::MalformedHttpResponse {
                    message: "response header line is missing ':'",
                },
            )?;
            if name.eq_ignore_ascii_case(b"Status") {
                // CGI pseudo-header: becomes the status code, not a header.
                status_code = parse_status_header(value)?;
            } else {
                let header_name = ::http::header::HeaderName::from_bytes(name)?;
                let header_value =
                    ::http::header::HeaderValue::from_bytes(trim_start_ascii(value))?;
                header_map.append(header_name, header_value);
            }
        }

        Ok(builder.status(status_code).body(body.to_vec())?)
    }
}
/// One chunk of application output yielded by [`ResponseStream`].
///
/// Each variant carries the content bytes of a single output record. The
/// record's padding has already been stripped, and a chunk may be empty.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Content {
    /// Bytes the application wrote to its standard output.
    Stdout(Bytes),
    /// Bytes the application wrote to its standard error.
    Stderr(Bytes),
}
/// Decoder that reads records (header, content, padding, as in FastCGI)
/// from a connection and exposes them as a [`Stream`] of [`Content`].
pub struct ResponseStream<S>
where
    S: AsyncRead + Unpin,
{
    /// Connection the records are read from.
    stream: S,
    /// Request id; used to tag log output.
    id: u16,
    /// Set once the response is finished: `EndRequest` was consumed, or an
    /// unknown record type was reported.
    eof: bool,
    /// Header of the record being decoded while its body is still incomplete.
    header: Option<Header>,
    /// Bytes read from `stream` that have not been decoded yet.
    buf: BytesMut,
}
impl<S: AsyncRead + Unpin> ResponseStream<S> {
    /// Wraps `stream` in a record decoder with an empty read buffer.
    ///
    /// Within this impl, `id` is only used to tag the `debug!` log emitted
    /// when the `EndRequest` record arrives.
    #[inline]
    pub(crate) fn new(stream: S, id: u16) -> Self {
        Self {
            stream,
            id,
            eof: false,
            header: None,
            buf: BytesMut::new(),
        }
    }

    /// Decodes the next record header from the front of `buf`.
    ///
    /// Returns `None` and leaves `buf` untouched while fewer than
    /// `HEADER_LEN` bytes are buffered. Otherwise it consumes exactly
    /// `HEADER_LEN` bytes.
    #[inline]
    fn read_header(&mut self) -> Option<Header> {
        if self.buf.len() < HEADER_LEN {
            return None;
        }
        let buf = self.buf.split_to(HEADER_LEN);
        // NOTE(review): the `try_into` target is inferred from
        // `Header::new_from_buf`'s parameter. If that is a `HEADER_LEN`-byte
        // array, the `expect` cannot fire, because `split_to` yielded exactly
        // `HEADER_LEN` bytes.
        let header = (&buf as &[u8]).try_into().expect("failed to read header");
        Some(Header::new_from_buf(header))
    }

    /// Takes the body of the record described by `self.header` out of `buf`.
    ///
    /// Returns `None` until `content_length + padding_length` bytes are
    /// buffered. On success it discards the padding and resets `self.header`
    /// to `None`, so the next record starts with a fresh header. It then
    /// returns the content bytes.
    ///
    /// # Panics
    ///
    /// Panics if `self.header` is `None`; a header must be read first.
    #[inline]
    fn read_content(&mut self) -> Option<Bytes> {
        let header = self.header.as_ref().unwrap();
        let block_length = header.content_length as usize + header.padding_length as usize;
        if self.buf.len() < block_length {
            return None;
        }
        let content = self.buf.split_to(header.content_length as usize);
        // Padding bytes carry no data; drop them.
        let _ = self.buf.split_to(header.padding_length as usize);
        self.header = None;
        Some(content.freeze())
    }

    /// Decodes at most one record from the buffered bytes.
    ///
    /// Returns:
    /// - `Ok(Some(_))` when a complete `Stdout` or `Stderr` record was consumed;
    /// - `Ok(None)` when more bytes are needed, or after an `EndRequest`
    ///   record was consumed (which also sets `eof`);
    /// - `Err(_)` when the `EndRequest` status converts to an error, or when
    ///   a record of any other type is seen. Both cases set `eof`; for an
    ///   unknown type the header is left in `self.header`.
    ///
    /// When a header has been decoded but its body is incomplete, the header
    /// stays in `self.header` so the next call resumes with the body.
    fn process_message(&mut self) -> Result<Option<Content>, ClientError> {
        if self.buf.is_empty() {
            return Ok(None);
        }
        if self.header.is_none() {
            match self.read_header() {
                Some(header) => self.header = Some(header),
                None => return Ok(None),
            }
        }
        let header = self.header.as_ref().unwrap();
        match header.r#type.clone() {
            RequestType::Stdout => {
                if let Some(data) = self.read_content() {
                    return Ok(Some(Content::Stdout(data)));
                }
            }
            RequestType::Stderr => {
                if let Some(data) = self.read_content() {
                    return Ok(Some(Content::Stderr(data)));
                }
            }
            RequestType::EndRequest => {
                // Keep a copy: `read_content` clears `self.header`, but the
                // header is still needed to build the `EndRequestRec`.
                let header = header.clone();
                let Some(data) = self.read_content() else {
                    return Ok(None);
                };
                let end = EndRequestRec::new_from_buf(header, &data);
                debug!(id = self.id, ?end, "Receive from stream.");
                self.eof = true;
                // NOTE(review): presumably turns a non-success protocol status
                // into a `ClientError`; `?` propagates it to the caller.
                end.end_request
                    .protocol_status
                    .convert_to_client_result(end.end_request.app_status)?;
                return Ok(None);
            }
            r#type => {
                self.eof = true;
                return Err(ClientError::UnknownRequestType {
                    request_type: r#type,
                });
            }
        }
        // A Stdout/Stderr header whose body is not fully buffered yet.
        Ok(None)
    }

    /// Performs one read from the underlying stream into `buf`.
    ///
    /// Reads at most 8192 bytes per call. Returns `Ready(Ok(Some(())))` when
    /// bytes were appended and `Ready(Ok(None))` when the read returned 0
    /// bytes (end of stream). Errors and `Pending` from `poll_read` are
    /// passed through.
    fn poll_fill_buf(&mut self, cx: &mut Context<'_>) -> Poll<ClientResult<Option<()>>> {
        let mut chunk = [0; 8192];
        match Pin::new(&mut self.stream).poll_read(cx, &mut chunk) {
            Poll::Ready(Ok(0)) => Poll::Ready(Ok(None)),
            Poll::Ready(Ok(read)) => {
                self.buf.extend_from_slice(&chunk[..read]);
                Poll::Ready(Ok(Some(())))
            }
            Poll::Ready(Err(err)) => Poll::Ready(Err(err.into())),
            Poll::Pending => Poll::Pending,
        }
    }
}
impl<S> Stream for ResponseStream<S>
where
    S: AsyncRead + Unpin,
{
    type Item = ClientResult<Content>;

    /// Yields the next `Stdout`/`Stderr` chunk of the response.
    ///
    /// Records that are already fully buffered are returned before the
    /// underlying reader is polled again. This way a later I/O error, or a
    /// peer that never sends more bytes, cannot hold back or discard data
    /// that was already received.
    ///
    /// The stream ends (`Ready(None)`) in three cases:
    /// - after the `EndRequest` record has been consumed;
    /// - after an unknown record type has been reported;
    /// - when the reader hits end of file. Any partially buffered record is
    ///   dropped in that case.
    ///
    /// Once finished, further polls return `Ready(None)` without touching the
    /// reader again.
    ///
    /// Errors from the reader and from the `EndRequest` status are yielded as
    /// `Some(Err(_))`.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            // Fused: after the response has ended, never read from the
            // connection again and never re-report a terminal error.
            if self.eof {
                return Poll::Ready(None);
            }

            // Drain buffered bytes first; `process_message` returns `Ok(None)`
            // only when no further progress is possible without more input
            // (or when it just consumed `EndRequest`).
            match self.process_message() {
                Ok(Some(data)) => return Poll::Ready(Some(Ok(data))),
                Ok(None) if self.eof => return Poll::Ready(None),
                Ok(None) => {}
                Err(err) => return Poll::Ready(Some(Err(err))),
            }

            match self.poll_fill_buf(cx) {
                // New bytes arrived: try to decode another record.
                Poll::Ready(Ok(Some(()))) => {}
                // Peer closed before `EndRequest`; nothing more can arrive.
                Poll::Ready(Ok(None)) => return Poll::Ready(None),
                Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
                // `poll_read` has registered the waker for us.
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
/// Splits CGI-style output into its header section and its body.
///
/// The header section ends at the first empty line. Lines may end with
/// `\r\n` or a bare `\n`, and the two may be mixed. The returned header slice
/// excludes the terminator of its last line; the body starts right after the
/// blank line.
///
/// # Errors
///
/// Returns `HttpConversionError::MalformedHttpResponse` when no blank line is
/// present.
#[cfg(feature = "http")]
fn split_header_body(stdout: &[u8]) -> HttpConversionResult<(&[u8], &[u8])> {
    // Visit line terminators in order so the *earliest* blank line wins.
    // Scanning the whole buffer for `\r\n\r\n` first would split inside the
    // body whenever the headers are LF-terminated and the body contains CRLFs.
    let newlines = stdout
        .iter()
        .enumerate()
        .filter_map(|(index, &byte)| (byte == b'\n').then_some(index));
    for newline in newlines {
        let rest = &stdout[newline + 1..];
        let blank_len = if rest.starts_with(b"\n") {
            1
        } else if rest.starts_with(b"\r\n") {
            2
        } else {
            continue;
        };
        let head = &stdout[..newline];
        let head = head.strip_suffix(b"\r").unwrap_or(head);
        return Ok((head, &rest[blank_len..]));
    }
    Err(HttpConversionError::MalformedHttpResponse {
        message: "response does not contain a header/body separator",
    })
}
/// Parses the value of a CGI `Status:` header into a status code.
///
/// Leading ASCII whitespace is skipped, and only the first
/// whitespace-separated token is used, so `"404 Not Found"` yields `404`.
///
/// # Errors
///
/// Fails when the value is not UTF-8, is empty, or is not a valid status code.
#[cfg(feature = "http")]
fn parse_status_header(value: &[u8]) -> HttpConversionResult<::http::StatusCode> {
    let text = match str::from_utf8(trim_start_ascii(value)) {
        Ok(text) => text,
        Err(_) => {
            return Err(HttpConversionError::MalformedHttpResponse {
                message: "status header is not valid UTF-8",
            });
        }
    };
    let code = text.split_whitespace().next().ok_or(
        HttpConversionError::MalformedHttpResponse {
            message: "status header is empty",
        },
    )?;
    let status = ::http::StatusCode::from_str(code)?;
    Ok(status)
}
#[cfg(feature = "http")]
/// Removes a single trailing `\r`, if any, from a header line.
fn trim_cr(line: &[u8]) -> &[u8] {
    match line {
        [head @ .., b'\r'] => head,
        _ => line,
    }
}
#[cfg(feature = "http")]
/// Skips leading ASCII whitespace (as defined by `u8::is_ascii_whitespace`).
fn trim_start_ascii(bytes: &[u8]) -> &[u8] {
    let mut rest = bytes;
    while let [first, tail @ ..] = rest {
        if !first.is_ascii_whitespace() {
            break;
        }
        rest = tail;
    }
    rest
}
#[cfg(feature = "http")]
/// Splits a header line into `(name, value)` around its first `:`.
trait SplitFirstColon {
    /// Returns the bytes before and after the first `:`, or `None` when the
    /// input contains no colon. The colon itself is not included in either part.
    fn split_first_colon(&self) -> Option<(&[u8], &[u8])>;
}
#[cfg(feature = "http")]
impl SplitFirstColon for [u8] {
    /// Splits at the first `:`. Later colons stay in the value part.
    fn split_first_colon(&self) -> Option<(&[u8], &[u8])> {
        let mut halves = self.splitn(2, |byte| *byte == b':');
        let name = halves.next()?;
        let value = halves.next()?;
        Some((name, value))
    }
}
#[cfg(all(test, feature = "http"))]
mod http_tests {
    use crate::Response;

    /// Output without a `Status:` line maps to `200 OK`; headers and body
    /// carry over, and stderr is ignored.
    #[test]
    fn response_into_http_defaults_status_to_ok() {
        let cgi = Response {
            stdout: Some(b"Content-type: text/plain\r\nX-Test: 1\r\n\r\nhello".to_vec()),
            stderr: Some(b"notice".to_vec()),
        };

        let http_response = ::http::Response::<Vec<u8>>::try_from(cgi).unwrap();

        assert_eq!(http_response.status(), ::http::StatusCode::OK);
        assert_eq!(http_response.headers()["content-type"], "text/plain");
        assert_eq!(http_response.body(), b"hello");
    }

    /// A non-200 status is written as a `Status:` line ahead of the headers.
    #[test]
    fn response_from_http_serializes_status_and_headers() {
        let http_response = ::http::Response::builder()
            .status(::http::StatusCode::CREATED)
            .header(::http::header::CONTENT_TYPE, "text/plain")
            .body(b"hello".to_vec())
            .unwrap();

        let cgi: Response = http_response.try_into().unwrap();
        let text = String::from_utf8(cgi.stdout.unwrap()).unwrap();

        assert!(text.starts_with("Status: 201 Created\r\n"));
        assert!(text.contains("content-type: text/plain\r\n\r\nhello"));
    }
}