pub mod handler;
pub mod pool;
use std::net::SocketAddr;
use std::sync::Arc;
use futures::stream::StreamExt;
use futures::SinkExt;
use quiche::h3::NameValue;
use tokio_quiche::buf_factory::BufFactory;
use tokio_quiche::http3::driver::{
H3Event, InboundFrame, IncomingH3Headers, OutboundFrame, ServerH3Event,
};
use tokio_quiche::http3::settings::Http3Settings;
use tokio_quiche::metrics::DefaultMetrics;
use tokio_quiche::quic::SimpleConnectionIdGenerator;
use tokio_quiche::settings::TlsCertificatePaths;
use tokio_quiche::{ConnectionParams, ServerH3Driver};
use crate::codec::{Codec, DefaultCodec};
use crate::error::KnafehError;
use crate::rpc::message::Metadata;
use crate::rpc::middleware::{Interceptor, MiddlewareStack};
use crate::rpc::router::MethodRouter;
use crate::rpc::service::Service;
use crate::transport::connection::{
build_response_headers, extract_method_path, validate_metadata_key, RPC_HEADER_PREFIX,
RPC_METHOD_KIND_HEADER,
};
use crate::transport::quic_wire::MAX_MESSAGE_SIZE;
use crate::transport::tls::TlsConfig;
use self::handler::RequestHandler;
use self::pool::ConnectionPool;
/// An RPC server that speaks HTTP/3 over QUIC.
///
/// Construct one with [`Server::builder`], then run it with [`Server::serve`]
/// or [`Server::serve_with_ready_signal`].
pub struct Server {
    /// Shared request dispatcher (router, middleware stack and codec) that is
    /// handed to every connection task.
    handler: Arc<RequestHandler>,
    /// Limits how many connections are served at once; sized from the
    /// builder's `max_connections` setting.
    connection_pool: Arc<ConnectionPool>,
    /// UDP address the server binds when it starts serving.
    bind_addr: SocketAddr,
    /// Certificate and private-key configuration used for the QUIC handshake.
    tls_config: TlsConfig,
}
impl Server {
    /// Returns a [`ServerBuilder`] for configuring a new server.
    pub fn builder() -> ServerBuilder {
        ServerBuilder::new()
    }

    /// The address the server was configured to bind to.
    ///
    /// This is the configured value, not the OS-assigned address. With port
    /// `0`, use [`Server::serve_with_ready_signal`] to learn the real port.
    pub fn bind_addr(&self) -> SocketAddr {
        self.bind_addr
    }

    /// Binds the configured UDP address and serves RPC requests over HTTP/3
    /// until the QUIC accept stream ends.
    ///
    /// # Errors
    /// - [`KnafehError::Transport`] if the socket cannot be bound or the QUIC
    ///   listener cannot be started.
    /// - [`KnafehError::Io`] if the bound local address cannot be read.
    /// - [`KnafehError::Tls`] if the configured certificate/key paths are
    ///   unusable.
    pub async fn serve(&self) -> Result<(), KnafehError> {
        tracing::info!(addr = %self.bind_addr, "starting Knafeh RPC server");
        let socket = self.bind_socket().await?;
        let local_addr = socket.local_addr().map_err(KnafehError::Io)?;
        tracing::info!(addr = %local_addr, "server listening");
        self.serve_quic(socket).await
    }

    /// Like [`Server::serve`], but sends the bound local address through
    /// `ready` once the socket is bound. This is useful when binding port `0`.
    ///
    /// If the receiving side of `ready` was dropped, the send is ignored and
    /// serving continues.
    ///
    /// # Errors
    /// Same as [`Server::serve`].
    pub async fn serve_with_ready_signal(
        &self,
        ready: tokio::sync::oneshot::Sender<SocketAddr>,
    ) -> Result<(), KnafehError> {
        let socket = self.bind_socket().await?;
        let local_addr = socket.local_addr().map_err(KnafehError::Io)?;
        tracing::info!(addr = %local_addr, "server listening");
        // The caller may have stopped waiting; that is not a server error.
        let _ = ready.send(local_addr);
        self.serve_quic(socket).await
    }

    /// Binds a UDP socket on the configured address.
    async fn bind_socket(&self) -> Result<tokio::net::UdpSocket, KnafehError> {
        tokio::net::UdpSocket::bind(self.bind_addr)
            .await
            .map_err(|e| KnafehError::Transport(format!("failed to bind UDP socket: {e}")))
    }

    /// Starts the QUIC listener on `socket` and spawns one task per accepted
    /// connection.
    ///
    /// A connection is rejected (dropped) when the connection pool has no
    /// free slot.
    async fn serve_quic(&self, socket: tokio::net::UdpSocket) -> Result<(), KnafehError> {
        let (cert, private_key) = validate_server_tls(&self.tls_config)?;
        let tls_cert = TlsCertificatePaths {
            cert,
            private_key,
            kind: Default::default(),
        };
        let conn_params =
            ConnectionParams::new_server(Default::default(), tls_cert, Default::default());
        let listeners = tokio_quiche::listen(
            [socket],
            conn_params,
            SimpleConnectionIdGenerator,
            DefaultMetrics,
        )
        .map_err(|e| KnafehError::Transport(format!("failed to start QUIC listener: {e}")))?;
        // One socket was passed in, so one accept stream is expected. Report an
        // error rather than panicking if none comes back.
        let mut accept_stream = listeners.into_iter().next().ok_or_else(|| {
            KnafehError::Transport("QUIC listener returned no accept stream".to_string())
        })?;
        while let Some(conn_result) = accept_stream.next().await {
            let initial_conn = match conn_result {
                Ok(conn) => conn,
                Err(e) => {
                    tracing::warn!(error = %e, "error accepting QUIC connection");
                    continue;
                }
            };
            let guard = match pool::ConnectionGuard::new(Arc::clone(&self.connection_pool)) {
                Some(guard) => guard,
                None => {
                    tracing::warn!("connection pool exhausted, rejecting connection");
                    continue;
                }
            };
            let handler = Arc::clone(&self.handler);
            tokio::spawn(async move {
                // Held for the task's whole lifetime so the pool slot stays
                // occupied until the connection handler returns.
                let _guard = guard;
                if let Err(e) = Self::handle_connection(initial_conn, handler).await {
                    tracing::error!(error = %e, "connection handler error");
                }
            });
        }
        Ok(())
    }

    /// Drives a single QUIC connection through the HTTP/3 driver and routes
    /// its events until the controller's event channel closes.
    async fn handle_connection<M: tokio_quiche::metrics::Metrics>(
        initial_conn: tokio_quiche::InitialQuicConnection<tokio::net::UdpSocket, M>,
        handler: Arc<RequestHandler>,
    ) -> Result<(), KnafehError> {
        let (driver, mut controller) = ServerH3Driver::new(Http3Settings::default());
        // Bound to a named variable (not `_`) so the value is not dropped
        // before the event loop below finishes.
        let _quic_conn = initial_conn.start(driver);
        tracing::debug!("new HTTP/3 connection established");
        while let Some(event) = controller.event_receiver_mut().recv().await {
            match event {
                ServerH3Event::Core(core_event) => Self::handle_h3_event(core_event, &handler),
            }
        }
        tracing::debug!("HTTP/3 connection closed");
        Ok(())
    }

    /// Reacts to one HTTP/3 event.
    ///
    /// Incoming request headers get their own task; every other event is
    /// only logged.
    fn handle_h3_event(event: H3Event, handler: &Arc<RequestHandler>) {
        match event {
            H3Event::IncomingHeaders(incoming) => {
                let handler = Arc::clone(handler);
                tokio::spawn(async move {
                    Self::handle_request(incoming, &handler).await;
                });
            }
            H3Event::BodyBytesReceived {
                stream_id,
                num_bytes,
                fin,
            } => {
                tracing::trace!(stream_id, num_bytes, fin, "body bytes received");
            }
            H3Event::ResetStream { stream_id } => {
                tracing::debug!(stream_id, "stream reset");
            }
            H3Event::ConnectionError(e) => {
                tracing::error!(error = ?e, "HTTP/3 connection error");
            }
            H3Event::ConnectionShutdown(reason) => {
                tracing::info!(reason = ?reason, "HTTP/3 connection shutdown");
            }
            _ => {
                tracing::trace!("unhandled H3 event");
            }
        }
    }

    /// Parses one request's headers and body, then dispatches it to the unary
    /// or server-streaming path and writes the response to its stream.
    ///
    /// The following are answered with an error status instead of being
    /// dispatched:
    /// - an invalid or missing `:path`;
    /// - invalid metadata keys;
    /// - a body larger than `MAX_MESSAGE_SIZE`.
    ///
    /// A request whose body stream closes before the peer's FIN is abandoned
    /// without a response.
    async fn handle_request(incoming: IncomingH3Headers, handler: &RequestHandler) {
        let IncomingH3Headers {
            stream_id,
            headers,
            mut send,
            mut recv,
            read_fin,
            ..
        } = incoming;
        let mut method_path = String::new();
        let mut method_kind_is_streaming = false;
        let mut metadata = Metadata::new();
        let prefix_bytes = RPC_HEADER_PREFIX.as_bytes();
        let method_kind_bytes = RPC_METHOD_KIND_HEADER.as_bytes();
        for header in &headers {
            let name = header.name();
            let value = header.value();
            if name == b":path" {
                method_path = match extract_method_path(value) {
                    Ok(path) => path,
                    Err(e) => {
                        tracing::warn!(stream_id, error = %e, "invalid RPC path");
                        Self::send_error_response(
                            stream_id,
                            handler,
                            crate::error::RpcStatusCode::InvalidArgument,
                            e.to_string(),
                            &mut send,
                        )
                        .await;
                        return;
                    }
                };
            } else if name == method_kind_bytes {
                method_kind_is_streaming = value == b"server_streaming";
            } else if name.starts_with(prefix_bytes) {
                // Prefixed headers become request metadata with the prefix
                // stripped. Keys are validated below, after all headers are read.
                let key = String::from_utf8_lossy(&name[prefix_bytes.len()..]).into_owned();
                let val = String::from_utf8_lossy(value).into_owned();
                metadata.insert(key, val);
            }
        }
        if method_path.is_empty() {
            tracing::warn!(stream_id, "missing :path header");
            Self::send_error_response(
                stream_id,
                handler,
                crate::error::RpcStatusCode::InvalidArgument,
                "missing :path header".to_string(),
                &mut send,
            )
            .await;
            return;
        }
        for key in metadata.keys() {
            if let Err(e) = validate_metadata_key(key) {
                tracing::warn!(stream_id, error = %e, "invalid RPC metadata");
                Self::send_error_response(
                    stream_id,
                    handler,
                    crate::error::RpcStatusCode::InvalidArgument,
                    e.to_string(),
                    &mut send,
                )
                .await;
                return;
            }
        }
        let mut body = Vec::new();
        if !read_fin {
            let mut saw_fin = false;
            while let Some(frame) = recv.recv().await {
                if let InboundFrame::Body(buf, fin) = frame {
                    if body.len().saturating_add(buf.len()) > MAX_MESSAGE_SIZE {
                        Self::send_error_response(
                            stream_id,
                            handler,
                            crate::error::RpcStatusCode::ResourceExhausted,
                            format!("request body exceeds maximum {MAX_MESSAGE_SIZE} bytes"),
                            &mut send,
                        )
                        .await;
                        return;
                    }
                    body.extend_from_slice(&buf);
                    if fin {
                        saw_fin = true;
                        break;
                    }
                }
            }
            if !saw_fin {
                // The inbound channel closed before the peer finished the
                // stream (e.g. a reset or dropped connection). `body` is
                // therefore incomplete, and dispatching it would run the RPC
                // on a truncated message.
                tracing::warn!(
                    stream_id,
                    received = body.len(),
                    "request stream ended before body was complete"
                );
                return;
            }
        }
        if method_kind_is_streaming {
            Self::send_streaming_response(
                stream_id,
                handler,
                method_path,
                body,
                metadata,
                &mut send,
            )
            .await;
        } else {
            Self::send_unary_response(stream_id, handler, method_path, body, metadata, &mut send)
                .await;
        }
    }

    /// Converts `(name, value)` byte pairs into quiche HTTP/3 headers.
    fn to_h3_headers<I, K, V>(headers: I) -> Vec<quiche::h3::Header>
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<[u8]>,
        V: AsRef<[u8]>,
    {
        headers
            .into_iter()
            .map(|(k, v)| quiche::h3::Header::new(k.as_ref(), v.as_ref()))
            .collect()
    }

    /// Builds an empty body frame with FIN set, which ends the response stream.
    fn empty_fin_frame() -> OutboundFrame {
        OutboundFrame::body(BufFactory::buf_from_slice(&[]), true)
    }

    /// Sends a header-only response carrying `status_code` and
    /// `status_message`, then finishes the stream with an empty body.
    ///
    /// Send failures are logged and otherwise ignored.
    async fn send_error_response(
        stream_id: u64,
        handler: &RequestHandler,
        status_code: crate::error::RpcStatusCode,
        status_message: String,
        send: &mut tokio_quiche::http3::driver::OutboundFrameSender,
    ) {
        let response_headers = build_response_headers(
            status_code as u8,
            &status_message,
            handler.codec.content_type(),
            &Metadata::new(),
        )
        .expect("empty response metadata is valid");
        if send
            .send(OutboundFrame::Headers(Self::to_h3_headers(response_headers)))
            .await
            .is_err()
        {
            tracing::warn!(stream_id, "failed to send error response headers");
            return;
        }
        if send.send(Self::empty_fin_frame()).await.is_err() {
            tracing::warn!(stream_id, "failed to send error response body");
        }
    }

    /// Runs a unary RPC and writes its headers and body (with FIN).
    ///
    /// If the handler's response metadata cannot be encoded into headers, an
    /// `Internal` status with an empty body is sent instead.
    async fn send_unary_response(
        stream_id: u64,
        handler: &RequestHandler,
        method_path: String,
        body: Vec<u8>,
        metadata: Metadata,
        send: &mut tokio_quiche::http3::driver::OutboundFrameSender,
    ) {
        let (response_body, status_code, status_message, response_metadata) =
            handler.handle_unary(method_path, body, metadata).await;
        let (response_headers, response_body) = match build_response_headers(
            status_code as u8,
            &status_message,
            handler.codec.content_type(),
            &response_metadata,
        ) {
            Ok(headers) => (headers, response_body),
            Err(e) => {
                tracing::warn!(stream_id, error = %e, "invalid response metadata");
                (
                    build_response_headers(
                        crate::error::RpcStatusCode::Internal as u8,
                        "invalid response metadata",
                        handler.codec.content_type(),
                        &Metadata::new(),
                    )
                    .expect("empty response metadata is valid"),
                    Vec::new(),
                )
            }
        };
        if send
            .send(OutboundFrame::Headers(Self::to_h3_headers(response_headers)))
            .await
            .is_err()
        {
            tracing::warn!(stream_id, "failed to send response headers");
            return;
        }
        let buf = BufFactory::buf_from_slice(&response_body);
        if send.send(OutboundFrame::body(buf, true)).await.is_err() {
            tracing::warn!(stream_id, "failed to send response body");
        }
    }

    /// Runs a server-streaming RPC.
    ///
    /// Writes the response headers, then each stream item as a length-prefixed
    /// frame: a 4-byte big-endian length followed by the codec-encoded item.
    /// The stream always ends with an empty FIN frame.
    async fn send_streaming_response(
        stream_id: u64,
        handler: &RequestHandler,
        method_path: String,
        body: Vec<u8>,
        metadata: Metadata,
        send: &mut tokio_quiche::http3::driver::OutboundFrameSender,
    ) {
        let (status_code, status_message, response_metadata, receiver) = handler
            .handle_server_stream(method_path, body, metadata)
            .await;
        let response_headers = match build_response_headers(
            status_code as u8,
            &status_message,
            handler.codec.content_type(),
            &response_metadata,
        ) {
            Ok(headers) => headers,
            Err(e) => {
                tracing::warn!(stream_id, error = %e, "invalid streaming response metadata");
                let fallback = build_response_headers(
                    crate::error::RpcStatusCode::Internal as u8,
                    "invalid response metadata",
                    handler.codec.content_type(),
                    &Metadata::new(),
                )
                .expect("empty response metadata is valid");
                if send
                    .send(OutboundFrame::Headers(Self::to_h3_headers(fallback)))
                    .await
                    .is_err()
                {
                    tracing::warn!(stream_id, "failed to send streaming response headers");
                } else {
                    let _ = send.send(Self::empty_fin_frame()).await;
                }
                return;
            }
        };
        if send
            .send(OutboundFrame::Headers(Self::to_h3_headers(response_headers)))
            .await
            .is_err()
        {
            tracing::warn!(stream_id, "failed to send streaming response headers");
            return;
        }
        if let Some(mut receiver) = receiver {
            // Reused across items to avoid one allocation per frame.
            let mut frame_buf = Vec::with_capacity(4096);
            // NOTE(review): every `break` below ends the stream with the same
            // empty FIN frame as a normal completion, so the client cannot tell
            // a mid-stream failure from success. Consider signalling the error
            // (e.g. trailers or a stream reset) once the wire protocol defines it.
            while let Some(chunk_result) = receiver.next().await {
                let chunk = match chunk_result {
                    Ok(chunk) => chunk,
                    Err(e) => {
                        tracing::warn!(stream_id, error = %e, "stream error");
                        break;
                    }
                };
                let encoded = match handler.codec.encode(&chunk) {
                    Ok(encoded) => encoded,
                    Err(e) => {
                        tracing::warn!(stream_id, error = %e, "codec error in stream");
                        break;
                    }
                };
                let len_u32 = match u32::try_from(encoded.len()) {
                    Ok(len) => len,
                    Err(_) => {
                        tracing::warn!(
                            stream_id,
                            encoded_len = encoded.len(),
                            "chunk too large for length-prefix framing"
                        );
                        break;
                    }
                };
                frame_buf.clear();
                frame_buf.extend_from_slice(&len_u32.to_be_bytes());
                frame_buf.extend_from_slice(&encoded);
                let buf = BufFactory::buf_from_slice(&frame_buf);
                if send.send(OutboundFrame::body(buf, false)).await.is_err() {
                    tracing::debug!(stream_id, "client disconnected during stream");
                    break;
                }
            }
        }
        // Best-effort: if the peer is already gone there is nobody to notify.
        let _ = send.send(Self::empty_fin_frame()).await;
    }
}
impl Clone for RequestHandler {
    /// Returns another handle to the same router, middleware stack and codec.
    ///
    /// Only the `Arc` reference counts are bumped; nothing is deep-copied.
    fn clone(&self) -> Self {
        let Self {
            router,
            middleware,
            codec,
        } = self;
        Self {
            router: Arc::clone(router),
            middleware: Arc::clone(middleware),
            codec: Arc::clone(codec),
        }
    }
}
/// Step-by-step configuration for a [`Server`]; obtain one via
/// [`Server::builder`].
pub struct ServerBuilder {
    /// Routes for services registered through [`ServerBuilder::add_service`].
    router: MethodRouter,
    /// Interceptors registered through [`ServerBuilder::add_interceptor`].
    middleware: MiddlewareStack,
    /// Optional codec override; `build` falls back to `DefaultCodec` when unset.
    codec: Option<Arc<dyn Codec>>,
    /// TLS settings; `build` fails when this is unset.
    tls_config: Option<TlsConfig>,
    /// UDP address to bind; `0.0.0.0:4433` unless overridden.
    bind_addr: SocketAddr,
    /// Upper bound on concurrently served connections; 10 000 unless overridden.
    max_connections: usize,
}
impl ServerBuilder {
    /// Port bound when no address is configured.
    const DEFAULT_PORT: u16 = 4433;
    /// Connection cap used when [`ServerBuilder::max_connections`] is not called.
    const DEFAULT_MAX_CONNECTIONS: usize = 10_000;

    /// Creates a builder with these defaults:
    /// - binds `0.0.0.0:4433`;
    /// - no TLS configuration;
    /// - no services or interceptors;
    /// - codec left unset (so `build` uses `DefaultCodec`).
    fn new() -> Self {
        Self {
            // Built directly instead of parsing a string literal, so this path
            // has no `unwrap`.
            bind_addr: SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, Self::DEFAULT_PORT)),
            tls_config: None,
            codec: None,
            router: MethodRouter::new(),
            middleware: MiddlewareStack::new(),
            max_connections: Self::DEFAULT_MAX_CONNECTIONS,
        }
    }

    /// Sets the UDP address the server will bind.
    pub fn bind(mut self, addr: impl Into<SocketAddr>) -> Self {
        self.bind_addr = addr.into();
        self
    }

    /// Parses `addr` (for example `"127.0.0.1:4433"`) and uses it as the bind
    /// address.
    ///
    /// # Errors
    /// Returns [`KnafehError::Transport`] if `addr` is not a valid socket
    /// address.
    pub fn bind_str(mut self, addr: &str) -> Result<Self, KnafehError> {
        self.bind_addr = addr
            .parse()
            .map_err(|e| KnafehError::Transport(format!("invalid bind address: {e}")))?;
        Ok(self)
    }

    /// Sets the TLS configuration. It is required by [`ServerBuilder::build`].
    pub fn tls(mut self, config: TlsConfig) -> Self {
        self.tls_config = Some(config);
        self
    }

    /// Overrides the message codec (otherwise `DefaultCodec` is used).
    pub fn codec(mut self, codec: impl Codec) -> Self {
        self.codec = Some(Arc::new(codec));
        self
    }

    /// Registers a service with the method router.
    pub fn add_service(mut self, service: impl Service) -> Self {
        self.router.add_service(Arc::new(service));
        self
    }

    /// Adds an interceptor to the middleware stack.
    pub fn add_interceptor(mut self, interceptor: impl Interceptor) -> Self {
        self.middleware.add(Arc::new(interceptor));
        self
    }

    /// Sets the maximum number of concurrently served connections.
    ///
    /// Connections accepted while the pool is full are rejected.
    pub fn max_connections(mut self, max: usize) -> Self {
        self.max_connections = max;
        self
    }

    /// Validates the configuration and assembles the [`Server`].
    ///
    /// # Errors
    /// Returns [`KnafehError::Tls`] if:
    /// - no TLS configuration was supplied;
    /// - no private key path is configured;
    /// - the certificate or key path is not valid UTF-8.
    pub fn build(self) -> Result<Server, KnafehError> {
        let tls_config = self
            .tls_config
            .ok_or_else(|| KnafehError::Tls("TLS configuration is required".to_string()))?;
        // Fail at build time rather than on the first `serve` call.
        validate_server_tls(&tls_config)?;
        let codec = self.codec.unwrap_or_else(|| Arc::new(DefaultCodec::new()));
        let handler = Arc::new(RequestHandler::new(
            Arc::new(self.router),
            Arc::new(self.middleware),
            codec,
        ));
        let connection_pool = Arc::new(ConnectionPool::new(self.max_connections));
        Ok(Server {
            bind_addr: self.bind_addr,
            tls_config,
            handler,
            connection_pool,
        })
    }
}
/// Returns the server certificate path and private-key path from `config`,
/// both as UTF-8 strings, in that order.
///
/// # Errors
/// Returns [`KnafehError::Tls`] if no private key path is configured, or if
/// either path is not valid UTF-8. The checks run in this order: key
/// presence, certificate encoding, key encoding.
fn validate_server_tls(config: &TlsConfig) -> Result<(&str, &str), KnafehError> {
    let tls_err = |msg: &str| KnafehError::Tls(msg.to_string());
    let key_path = match &config.key_path {
        Some(path) => path,
        None => return Err(tls_err("server TLS requires cert + key")),
    };
    let cert = match config.cert_path.to_str() {
        Some(text) => text,
        None => return Err(tls_err("server certificate path is not valid UTF-8")),
    };
    let private_key = match key_path.to_str() {
        Some(text) => text,
        None => return Err(tls_err("server private key path is not valid UTF-8")),
    };
    Ok((cert, private_key))
}