use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use bytes::Bytes;
use monoio::net::TcpStream;
use harrow_core::dispatch::{SharedState, dispatch};
use harrow_core::request::Body;
use harrow_core::response::Response;
use crate::o11y::ConnectionMetrics;
use crate::protocol::ProtocolError;
/// Settings used to serve a single HTTP/2 connection.
pub(crate) struct H2Config {
    /// Application state passed to `dispatch` for every request; its
    /// `max_body_size` is used as the request-body size limit.
    pub shared: Arc<SharedState>,
    /// Bounds the HTTP/2 handshake and the wait for the next stream while no
    /// streams are in flight. `None` disables it.
    pub header_read_timeout: Option<Duration>,
    /// Bounds each wait for the next request-body chunk; on expiry the stream is
    /// answered with 408 Request Timeout. `None` disables it.
    pub body_read_timeout: Option<Duration>,
    /// Absolute lifetime of the connection, measured from when
    /// `H2Connection::run` starts. `None` disables it.
    pub connection_timeout: Option<Duration>,
    /// Passed to the h2 server builder's `max_concurrent_streams` setting.
    pub max_concurrent_streams: u32,
    /// Per-connection metrics; `H2Connection::run` calls `close()` on it when
    /// it stops serving because of a timeout or a closed connection.
    pub metrics: ConnectionMetrics,
}
/// An accepted TCP connection to be served with the HTTP/2 server handshake,
/// together with its settings. Nothing happens until `run` is awaited.
pub(crate) struct H2Connection {
    /// The raw transport the HTTP/2 handshake is performed on.
    stream: TcpStream,
    /// Timeouts, limits, shared state and metrics for this connection.
    config: H2Config,
}
/// Guard that counts one in-flight HTTP/2 stream.
///
/// Constructing it increments the shared counter and dropping it decrements the
/// counter, so the count is restored however the owning stream task finishes.
struct ActiveStreamGuard {
    active_streams: Arc<AtomicUsize>,
}

impl ActiveStreamGuard {
    /// Registers one more active stream and returns the guard that unregisters it.
    fn new(active_streams: Arc<AtomicUsize>) -> Self {
        let guard = Self { active_streams };
        guard.active_streams.fetch_add(1, Ordering::AcqRel);
        guard
    }
}

impl Drop for ActiveStreamGuard {
    fn drop(&mut self) {
        let Self { active_streams } = self;
        active_streams.fetch_sub(1, Ordering::AcqRel);
    }
}
impl H2Connection {
    /// Wraps an accepted TCP stream and its settings; no I/O happens until
    /// [`run`](Self::run) is awaited.
    pub(crate) fn new(stream: TcpStream, config: H2Config) -> Self {
        Self { stream, config }
    }

    /// Performs the HTTP/2 server handshake, then accepts streams until the
    /// connection ends. Each accepted stream is handled on its own spawned task.
    ///
    /// # Timeouts
    /// - `header_read_timeout` bounds the handshake. After the handshake it bounds the
    ///   wait for the next stream, but only ends the connection when no streams are
    ///   in flight. Otherwise the loop keeps driving the connection.
    /// - `connection_timeout` is an absolute deadline measured from the start of this
    ///   call. Every timed wait is clamped to the time remaining before it.
    ///
    /// Every exit path calls `metrics.close()`.
    ///
    /// # Errors
    /// Returns an error only if the HTTP/2 handshake fails. The connection ends with
    /// `Ok(())` on timeouts, peer close, and accept errors; those cases are logged at
    /// debug level.
    pub(crate) async fn run(self) -> Result<(), Box<dyn std::error::Error>> {
        let H2Connection { stream, config } = self;
        let H2Config {
            shared,
            header_read_timeout,
            body_read_timeout,
            connection_timeout,
            max_concurrent_streams,
            metrics,
        } = config;
        let metrics_id = metrics.id;
        // `checked_add` avoids overflowing `Instant` for absurdly large timeouts;
        // on overflow the connection simply has no absolute deadline.
        let connection_deadline =
            connection_timeout.and_then(|timeout| Instant::now().checked_add(timeout));

        let mut builder = monoio_http::h2::server::Builder::new();
        builder.max_concurrent_streams(max_concurrent_streams);

        let handshake_timeout = effective_timeout(header_read_timeout, connection_deadline);
        let handshake_result = if let Some(timeout) = handshake_timeout {
            match monoio::select! {
                result = builder.handshake(stream) => Ok(result),
                _ = monoio::time::sleep(timeout) => Err(()),
            } {
                Ok(result) => result,
                Err(()) => {
                    if deadline_expired(connection_deadline) {
                        tracing::debug!(
                            connection.id = metrics_id,
                            "h2 connection timeout during handshake"
                        );
                    } else {
                        tracing::debug!(
                            connection.id = metrics_id,
                            "h2 header read timeout during handshake"
                        );
                    }
                    let _duration = metrics.close();
                    return Ok(());
                }
            }
        } else {
            builder.handshake(stream).await
        };
        let mut connection = match handshake_result {
            Ok(connection) => connection,
            Err(e) => {
                // Close the metrics here too, so a failed handshake is accounted
                // for like every other way this connection can end.
                let _duration = metrics.close();
                return Err(format!("h2 handshake failed: {}", e).into());
            }
        };
        tracing::debug!(connection.id = metrics_id, "h2 connection established");

        // Shared with every spawned stream task through `ActiveStreamGuard`, so the
        // idle check below can tell whether any request is still in flight.
        let active_streams = Arc::new(AtomicUsize::new(0));
        loop {
            let accept_timeout = effective_timeout(header_read_timeout, connection_deadline);
            let accept_result = if let Some(timeout) = accept_timeout {
                match monoio::select! {
                    result = connection.accept() => Ok(result),
                    _ = monoio::time::sleep(timeout) => Err(()),
                } {
                    Ok(result) => result,
                    Err(()) => {
                        if deadline_expired(connection_deadline) {
                            tracing::debug!(connection.id = metrics_id, "h2 connection timeout");
                            break;
                        }
                        if active_streams.load(Ordering::Acquire) == 0 {
                            tracing::debug!(connection.id = metrics_id, "h2 header read timeout");
                            break;
                        }
                        // Streams are still in flight: keep driving the connection.
                        continue;
                    }
                }
            } else {
                connection.accept().await
            };

            match accept_result {
                Some(Ok((request, respond))) => {
                    let shared = Arc::clone(&shared);
                    let max_body = shared.max_body_size;
                    // Count the stream before spawning, so the idle check never sees
                    // zero while this request is pending.
                    let active_stream = ActiveStreamGuard::new(Arc::clone(&active_streams));
                    monoio::spawn(async move {
                        let _active_stream = active_stream;
                        if let Err(e) =
                            handle_stream(request, respond, shared, max_body, body_read_timeout)
                                .await
                        {
                            tracing::debug!(
                                connection.id = metrics_id,
                                error = %e,
                                "h2 stream error"
                            );
                        }
                    });
                }
                Some(Err(e)) => {
                    // Errors surfaced by the h2 server `accept` are connection-level and
                    // terminal. Polling again would only revisit a dead connection, so
                    // stop here.
                    tracing::debug!(
                        connection.id = metrics_id,
                        error = %e,
                        "h2 accept error"
                    );
                    break;
                }
                None => {
                    tracing::debug!(connection.id = metrics_id, "h2 connection closed by peer");
                    break;
                }
            }
        }
        let _duration = metrics.close();
        Ok(())
    }
}
/// Chooses the wait budget for the next timed operation.
///
/// The result is the smaller of the per-operation `timeout` and the time left
/// before `connection_deadline`. It is `None` only when neither bound is set.
/// A deadline that has already passed yields `Some(Duration::ZERO)`.
fn effective_timeout(
    timeout: Option<Duration>,
    connection_deadline: Option<Instant>,
) -> Option<Duration> {
    let remaining = remaining_until(connection_deadline);
    match (timeout, remaining) {
        (Some(per_op), Some(left)) => Some(per_op.min(left)),
        (only_timeout, only_remaining) => only_timeout.or(only_remaining),
    }
}
/// Time left until `deadline`, clamped to zero once it has passed.
/// Returns `None` when there is no deadline.
fn remaining_until(deadline: Option<Instant>) -> Option<Duration> {
    let at = deadline?;
    Some(at.saturating_duration_since(Instant::now()))
}
/// Reports whether a deadline is set and the current time has reached it.
fn deadline_expired(deadline: Option<Instant>) -> bool {
    match deadline {
        Some(at) => at <= Instant::now(),
        None => false,
    }
}
/// Serves one HTTP/2 request stream.
///
/// The request body is buffered with `read_h2_body`. An oversized body is
/// answered with 413 and a body read timeout with 408. Otherwise the request is
/// converted, dispatched through harrow, and the response is sent back on the
/// same stream.
///
/// # Errors
/// Returns other body-read errors, conversion errors, and any error from
/// sending the response.
async fn handle_stream(
    mut request: http::Request<monoio_http::h2::RecvStream>,
    mut respond: monoio_http::h2::server::SendResponse<bytes::Bytes>,
    shared: Arc<SharedState>,
    max_body: usize,
    body_read_timeout: Option<Duration>,
) -> Result<(), Box<dyn std::error::Error>> {
    let body_bytes = match read_h2_body(&mut request, max_body, body_read_timeout).await {
        Ok(body) => body,
        Err(failure) => {
            // Size-limit and timeout failures get an HTTP answer on the stream;
            // any other failure is handed back to the caller.
            let (status, message) = match failure {
                ProtocolError::BodyTooLarge => {
                    (http::StatusCode::PAYLOAD_TOO_LARGE, "payload too large")
                }
                ProtocolError::Timeout => (http::StatusCode::REQUEST_TIMEOUT, "request timeout"),
                other => return Err(Box::new(other)),
            };
            send_h2_response(&mut respond, Response::new(status, message).into_inner()).await?;
            return Ok(());
        }
    };

    let harrow_request = convert_to_harrow_request(request, body_bytes)?;
    let harrow_response = dispatch(shared, harrow_request).await;
    send_h2_response(&mut respond, harrow_response).await
}
async fn read_h2_body(
request: &mut http::Request<monoio_http::h2::RecvStream>,
max_body: usize,
body_read_timeout: Option<Duration>,
) -> Result<Bytes, ProtocolError> {
let body = request.body_mut();
let mut chunks = Vec::new();
let mut total_len: usize = 0;
loop {
let data = if let Some(timeout) = body_read_timeout {
monoio::select! {
data = body.data() => data,
_ = monoio::time::sleep(timeout) => return Err(ProtocolError::Timeout),
}
} else {
body.data().await
};
let Some(data) = data else {
break;
};
let data = data.map_err(|e| ProtocolError::Parse(format!("h2 body error: {e}")))?;
let len = data.len();
if max_body > 0 && total_len + len > max_body {
return Err(ProtocolError::BodyTooLarge);
}
total_len += len;
chunks.push(data);
body.flow_control()
.release_capacity(len)
.map_err(|e| ProtocolError::ProtocolViolation(e.to_string()))?;
}
let mut result = bytes::BytesMut::with_capacity(total_len);
for chunk in chunks {
result.extend_from_slice(&chunk);
}
Ok(result.freeze())
}
/// Rebuilds an h2 request as a harrow request carrying an already-buffered body.
///
/// Only the request head is kept; the h2 `RecvStream` is dropped, since the
/// caller (`handle_stream`) has already read it into `body_bytes`. This
/// currently always returns `Ok`.
fn convert_to_harrow_request(
    request: http::Request<monoio_http::h2::RecvStream>,
    body_bytes: Bytes,
) -> Result<http::Request<Body>, Box<dyn std::error::Error>> {
    let head = request.into_parts().0;
    let harrow_body = crate::protocol::body_from_bytes(body_bytes);
    Ok(http::Request::from_parts(head, harrow_body))
}
/// Sends a harrow response on an HTTP/2 stream.
///
/// The whole body is buffered first. DATA frames are collected, and any trailer
/// frames are merged into one trailer map. Then the head is sent, followed by the
/// data chunks and, when present, the trailers. The trailers (not the last DATA
/// frame) end the stream in that case.
///
/// Connection-specific header fields are dropped from the head (see
/// `is_connection_specific_header`). RFC 9113 §8.2.2 forbids them in HTTP/2, and
/// the h2 stack rejects a head carrying them. That would reset the stream instead
/// of delivering the response.
///
/// # Errors
/// Returns an error if polling the body fails, or if the h2 stream rejects the
/// head, a data frame, or the trailers.
async fn send_h2_response(
    respond: &mut monoio_http::h2::server::SendResponse<bytes::Bytes>,
    response: http::Response<harrow_core::response::ResponseBody>,
) -> Result<(), Box<dyn std::error::Error>> {
    use http_body_util::BodyExt;

    let (parts, mut body) = response.into_parts();

    let mut builder = http::Response::builder()
        .status(parts.status)
        .version(http::Version::HTTP_2);
    for (name, value) in &parts.headers {
        if !is_connection_specific_header(name, value) {
            builder = builder.header(name, value);
        }
    }
    // Status and header values come from an already-valid `http::Response`,
    // so building the head cannot fail.
    let head = builder.body(()).expect("valid response");

    let mut chunks = Vec::new();
    let mut trailers: Option<http::HeaderMap> = None;
    while let Some(frame) = body.frame().await {
        let frame = frame.map_err(|e| format!("body frame error: {}", e))?;
        match frame.into_data() {
            Ok(data) => chunks.push(data),
            Err(frame) => {
                if let Ok(frame_trailers) = frame.into_trailers() {
                    trailers
                        .get_or_insert_with(http::HeaderMap::new)
                        .extend(frame_trailers);
                }
            }
        }
    }

    // With neither data nor trailers, the HEADERS frame itself ends the stream.
    let head_ends_stream = chunks.is_empty() && trailers.is_none();
    let mut stream = respond.send_response(head, head_ends_stream)?;
    if head_ends_stream {
        return Ok(());
    }

    let last_index = chunks.len().saturating_sub(1);
    for (index, data) in chunks.into_iter().enumerate() {
        let is_end = trailers.is_none() && index == last_index;
        stream.send_data(data, is_end)?;
    }
    if let Some(trailers) = trailers {
        stream.send_trailers(trailers)?;
    }
    Ok(())
}

/// Returns `true` for header fields that RFC 9113 §8.2.2 forbids in HTTP/2
/// messages. These are the connection-management fields, plus `te` with any
/// value other than `trailers`.
fn is_connection_specific_header(name: &http::HeaderName, value: &http::HeaderValue) -> bool {
    // `HeaderName::as_str` is always lowercase, so plain string matching suffices.
    match name.as_str() {
        "connection" | "keep-alive" | "proxy-connection" | "transfer-encoding" | "upgrade" => true,
        "te" => value.as_bytes() != b"trailers",
        _ => false,
    }
}
/// Entry point for one accepted TCP connection that is to be served as HTTP/2.
///
/// This function:
/// - creates the connection metrics and a tracing span for the connection;
/// - runs the connection inside that span;
/// - logs at debug level any error returned by `H2Connection::run`.
pub(crate) async fn handle_connection(
    stream: TcpStream,
    remote_addr: Option<SocketAddr>,
    shared: Arc<SharedState>,
    header_read_timeout: Option<Duration>,
    body_read_timeout: Option<Duration>,
    connection_timeout: Option<Duration>,
    max_concurrent_streams: u32,
    active_count: std::rc::Rc<std::cell::Cell<usize>>,
) {
    use tracing::Instrument;

    use crate::o11y::connection_span;

    let metrics = ConnectionMetrics::new(active_count);
    let connection_id = metrics.id;
    let span = connection_span(connection_id, remote_addr);

    let connection = H2Connection::new(
        stream,
        H2Config {
            shared,
            header_read_timeout,
            body_read_timeout,
            connection_timeout,
            max_concurrent_streams,
            metrics,
        },
    );

    let outcome = connection.run().instrument(span).await;
    if let Err(error) = outcome {
        tracing::debug!(
            connection.id = connection_id,
            error = %error,
            "h2 connection error"
        );
    }
}
#[cfg(test)]
mod tests {
    // NOTE(review): no unit tests yet. The timeout helpers (`effective_timeout`,
    // `remaining_until`, `deadline_expired`) and `ActiveStreamGuard` use only std
    // types and are synchronous, so they could be covered here without a monoio
    // runtime.
}