use std::convert::Infallible;
use std::process::ExitStatus;
use std::sync::Arc;
use anyhow::{Context, Result, anyhow};
use http_body_util::{BodyExt, Full, StreamBody, combinators::BoxBody};
use hyper::body::{Bytes, Frame, Incoming};
use hyper::header::CONTENT_TYPE;
use hyper::service::service_fn;
use hyper::{Method, Request, Response, StatusCode, Version};
use hyper_util::rt::{TokioExecutor, TokioIo};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::process::{ChildStdin, ChildStdout};
use tokio::sync::Mutex;
use tokio_stream::wrappers::ReceiverStream;
use crate::runtime::prepare::PreparedCommand;
use crate::runtime::process::{spawn_stream_child, terminate_child};
const H2_STREAM_CONTENT_TYPE: &str = "application/octet-stream";
const STREAM_BUFFER_CAPACITY: usize = 16;
const COPY_BUFFER_SIZE: usize = 8192;
type ResponseBody = BoxBody<Bytes, Infallible>;
type ResponseFrame = Result<Frame<Bytes>, Infallible>;
pub async fn serve_h2(
prepared: PreparedCommand,
subject: &str,
host: &str,
port: u16,
) -> Result<ExitStatus> {
let bound_listener = bind_listener(host, port).await?;
eprintln!(
"Running {} over http://{} (HTTP/2 stream transport)",
subject, bound_listener.display_address
);
let (socket, peer_address) = bound_listener.listener.accept().await.with_context(|| {
format!(
"failed to accept HTTP connection on {}",
bound_listener.display_address
)
})?;
eprintln!("Accepted connection from {}", peer_address);
serve_h2_connection(prepared, subject, socket).await
}
async fn serve_h2_connection(
prepared: PreparedCommand,
subject: &str,
socket: TcpStream,
) -> Result<ExitStatus> {
let _temp_dir = prepared.temp_dir;
let mut child = spawn_stream_child(&prepared.spec, subject)?;
let child_stdin = child
.stdin
.take()
.ok_or_else(|| anyhow!("child process missing piped stdin"))?;
let child_stdout = child
.stdout
.take()
.ok_or_else(|| anyhow!("child process missing piped stdout"))?;
let session = Arc::new(Mutex::new(Some(StreamSession {
child_stdin,
child_stdout,
})));
let service = service_fn(move |request| {
let session = session.clone();
async move { handle_h2_request(request, session).await }
});
let connection = hyper::server::conn::http2::Builder::new(TokioExecutor::new())
.serve_connection(TokioIo::new(socket), service);
tokio::pin!(connection);
tokio::select! {
status = child.wait() => {
let status = status.context("failed while waiting on child process")?;
connection.as_mut().graceful_shutdown();
connection
.await
.map_err(|source| anyhow!("HTTP/2 connection failed: {source}"))?;
Ok(status)
}
connection_result = &mut connection => {
match connection_result {
Ok(()) => terminate_child(&mut child).await,
Err(source) => {
let _ = terminate_child(&mut child).await;
Err(anyhow!("HTTP/2 connection failed: {source}"))
}
}
}
}
}
async fn handle_h2_request(
request: Request<Incoming>,
session: Arc<Mutex<Option<StreamSession>>>,
) -> Result<Response<ResponseBody>, Infallible> {
if request.version() != Version::HTTP_2 {
return Ok(text_response(
StatusCode::HTTP_VERSION_NOT_SUPPORTED,
"HTTP/2 is required for this transport.\n",
));
}
if request.method() != Method::POST {
return Ok(text_response(
StatusCode::METHOD_NOT_ALLOWED,
"POST is required for this transport.\n",
));
}
if !content_type_is_stream(request.headers().get(CONTENT_TYPE)) {
return Ok(text_response(
StatusCode::UNSUPPORTED_MEDIA_TYPE,
format!("Content-Type: {H2_STREAM_CONTENT_TYPE} is required.\n"),
));
}
let Some(session) = session.lock().await.take() else {
return Ok(text_response(
StatusCode::CONFLICT,
"This transport accepts only one active HTTP/2 stream.\n",
));
};
let (response_tx, response_rx) =
tokio::sync::mpsc::channel::<ResponseFrame>(STREAM_BUFFER_CAPACITY);
tokio::spawn(pump_request_body_to_child(
request.into_body(),
session.child_stdin,
));
tokio::spawn(pump_child_stdout_to_response(
session.child_stdout,
response_tx,
));
Ok(stream_response(response_rx))
}
async fn pump_request_body_to_child(mut body: Incoming, mut child_stdin: ChildStdin) {
while let Some(frame) = body.frame().await {
let Ok(frame) = frame else {
break;
};
let Ok(data) = frame.into_data() else {
continue;
};
if child_stdin.write_all(&data).await.is_err() {
break;
}
}
let _ = child_stdin.shutdown().await;
}
async fn pump_child_stdout_to_response(
mut child_stdout: ChildStdout,
response_tx: tokio::sync::mpsc::Sender<ResponseFrame>,
) {
let mut buffer = [0_u8; COPY_BUFFER_SIZE];
loop {
let read = match child_stdout.read(&mut buffer).await {
Ok(0) => break,
Ok(read) => read,
Err(_) => break,
};
if response_tx
.send(Ok(Frame::data(Bytes::copy_from_slice(&buffer[..read]))))
.await
.is_err()
{
break;
}
}
}
fn stream_response(
response_rx: tokio::sync::mpsc::Receiver<ResponseFrame>,
) -> Response<ResponseBody> {
Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, H2_STREAM_CONTENT_TYPE)
.body(StreamBody::new(ReceiverStream::new(response_rx)).boxed())
.expect("known-good response")
}
fn text_response(status: StatusCode, message: impl Into<Bytes>) -> Response<ResponseBody> {
Response::builder()
.status(status)
.header(CONTENT_TYPE, "text/plain; charset=utf-8")
.body(Full::new(message.into()).boxed())
.expect("known-good response")
}
fn content_type_is_stream(content_type: Option<&hyper::header::HeaderValue>) -> bool {
content_type
.and_then(|value| value.to_str().ok())
.is_some_and(|value| value.eq_ignore_ascii_case(H2_STREAM_CONTENT_TYPE))
}
async fn bind_listener(host: &str, port: u16) -> Result<BoundListener> {
let listener = TcpListener::bind((host, port)).await.with_context(|| {
format!(
"failed to bind HTTP listener on {}",
bind_target_display(host, port)
)
})?;
let display_address = listener
.local_addr()
.map(|address| address.to_string())
.unwrap_or_else(|_| bind_target_display(host, port));
Ok(BoundListener {
listener,
display_address,
})
}
fn bind_target_display(host: &str, port: u16) -> String {
format!("host={host}, port={port}")
}
struct StreamSession {
child_stdin: ChildStdin,
child_stdout: ChildStdout,
}
struct BoundListener {
listener: TcpListener,
display_address: String,
}
#[cfg(test)]
mod tests {
use std::ffi::OsString;
use std::net::SocketAddr;
use std::time::Duration;
use super::*;
use crate::runtime::prepare::{CommandSpec, PreparedCommand};
#[test]
fn accepts_stream_content_type() {
let header = hyper::header::HeaderValue::from_static(H2_STREAM_CONTENT_TYPE);
assert!(content_type_is_stream(Some(&header)));
}
#[test]
fn rejects_non_stream_content_type() {
let header = hyper::header::HeaderValue::from_static("application/json");
assert!(!content_type_is_stream(Some(&header)));
}
#[cfg(unix)]
#[tokio::test]
async fn http_transport_streams_full_duplex_over_h2() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let address = listener.local_addr().unwrap();
let server = tokio::spawn(async move {
run_command_h2_with_listener(
prepared_command_with_program(
OsString::from("sh"),
vec![
OsString::from("-c"),
OsString::from("printf 'boot\\n'; cat"),
],
),
"demo-agent",
BoundListener {
listener,
display_address: address.to_string(),
},
)
.await
});
let (mut sender, connection) = h2_handshake(address).await;
tokio::spawn(async move {
let _ = connection.await;
});
let (request_tx, request_body) = request_body_channel();
let request: Request<ResponseBody> = Request::builder()
.method(Method::POST)
.uri(format!("http://{address}/"))
.version(Version::HTTP_2)
.header(CONTENT_TYPE, H2_STREAM_CONTENT_TYPE)
.body(request_body)
.unwrap();
let response = sender.send_request(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(response.version(), Version::HTTP_2);
let mut response_body = response.into_body();
let first_chunk = tokio::time::timeout(
Duration::from_secs(2),
read_next_data_frame(&mut response_body),
)
.await
.unwrap()
.unwrap();
assert_eq!(first_chunk.as_ref(), b"boot\n");
request_tx
.send(Ok(Frame::data(Bytes::from_static(b"ping\n"))))
.await
.unwrap();
drop(request_tx);
let rest = response_body.collect().await.unwrap().to_bytes();
assert_eq!(rest.as_ref(), b"ping\n");
let status = tokio::time::timeout(Duration::from_secs(2), server)
.await
.unwrap()
.unwrap()
.unwrap();
assert!(status.success());
}
#[cfg(unix)]
#[tokio::test]
async fn http_transport_rejects_second_stream() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let address = listener.local_addr().unwrap();
let server = tokio::spawn(async move {
run_command_h2_with_listener(
prepared_command_with_program(
OsString::from("sh"),
vec![OsString::from("-c"), OsString::from("cat")],
),
"demo-agent",
BoundListener {
listener,
display_address: address.to_string(),
},
)
.await
});
let (mut sender, connection) = h2_handshake(address).await;
tokio::spawn(async move {
let _ = connection.await;
});
let (request_tx, request_body) = request_body_channel();
let first: Request<ResponseBody> = Request::builder()
.method(Method::POST)
.uri(format!("http://{address}/"))
.version(Version::HTTP_2)
.header(CONTENT_TYPE, H2_STREAM_CONTENT_TYPE)
.body(request_body)
.unwrap();
let first_response = sender.send_request(first).await.unwrap();
assert_eq!(first_response.status(), StatusCode::OK);
let second: Request<ResponseBody> = Request::builder()
.method(Method::POST)
.uri(format!("http://{address}/"))
.version(Version::HTTP_2)
.header(CONTENT_TYPE, H2_STREAM_CONTENT_TYPE)
.body(Full::new(Bytes::new()).boxed())
.unwrap();
let second_response = sender.send_request(second).await.unwrap();
assert_eq!(second_response.status(), StatusCode::CONFLICT);
request_tx
.send(Ok(Frame::data(Bytes::from_static(b"done"))))
.await
.unwrap();
drop(request_tx);
let _ = first_response.into_body().collect().await.unwrap();
let status = tokio::time::timeout(Duration::from_secs(2), server)
.await
.unwrap()
.unwrap()
.unwrap();
assert!(status.success());
}
async fn h2_handshake(
address: SocketAddr,
) -> (
hyper::client::conn::http2::SendRequest<ResponseBody>,
hyper::client::conn::http2::Connection<TokioIo<TcpStream>, ResponseBody, TokioExecutor>,
) {
let stream = TcpStream::connect(address).await.unwrap();
hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.handshake(TokioIo::new(stream))
.await
.unwrap()
}
fn request_body_channel() -> (tokio::sync::mpsc::Sender<ResponseFrame>, ResponseBody) {
let (tx, rx) = tokio::sync::mpsc::channel::<ResponseFrame>(STREAM_BUFFER_CAPACITY);
(tx, StreamBody::new(ReceiverStream::new(rx)).boxed())
}
async fn read_next_data_frame(body: &mut Incoming) -> Option<Bytes> {
while let Some(frame) = body.frame().await {
let frame = frame.unwrap();
if let Ok(data) = frame.into_data() {
return Some(data);
}
}
None
}
#[cfg(unix)]
async fn run_command_h2_with_listener(
prepared: PreparedCommand,
subject: &str,
bound_listener: BoundListener,
) -> Result<ExitStatus> {
eprintln!(
"Running {} over http://{} (HTTP/2 stream transport)",
subject, bound_listener.display_address
);
let (socket, _) = bound_listener.listener.accept().await.with_context(|| {
format!(
"failed to accept HTTP connection on {}",
bound_listener.display_address
)
})?;
serve_h2_connection(prepared, subject, socket).await
}
#[cfg(unix)]
fn prepared_command_with_program(program: OsString, args: Vec<OsString>) -> PreparedCommand {
PreparedCommand {
spec: CommandSpec {
program,
args,
env: Vec::new(),
current_dir: None,
},
temp_dir: None,
}
}
}