use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use boring::ssl::{SslContextBuilder, SslMethod, SslVerifyMode};
use tokio::sync::{oneshot, Mutex as TokioMutex};
use tokio_quiche::http3::driver::{
ClientH3Driver, ClientH3Event, ClientRequestSender, H3Event, InboundFrame, IncomingH3Headers,
};
use tokio_quiche::http3::settings::Http3Settings;
use tokio_quiche::settings::{CertificateKind, ConnectionParams, Hooks, TlsCertificatePaths};
use tokio_quiche::ClientH3Controller;
use tokio_quiche::QuicConnection;
use crate::error::KnafehError;
use crate::transport::quic_wire::MAX_MESSAGE_SIZE;
use crate::transport::tls::TlsConfig;
/// Response delivered to a waiting request once its response headers have arrived.
pub(crate) enum H3Response {
/// The body was buffered in memory before the waiter was resolved.
Complete {
/// Response header list as received.
headers: Vec<quiche::h3::Header>,
/// Concatenation of every body frame that was read.
body: Vec<u8>,
},
/// Only the headers were consumed; the waiter reads the remaining frames itself.
Streaming {
/// Response header list as received.
headers: Vec<quiche::h3::Header>,
/// Frame channel handed over untouched from the incoming-headers event.
recv: tokio::sync::mpsc::Receiver<InboundFrame>,
/// `read_fin` flag copied from the incoming-headers event; the buffered path
/// in this file reads no body frames at all when it is `true`.
read_fin: bool,
},
}
/// Per-connection bookkeeping shared between request registration and the event loop.
struct ConnectionState {
/// Waiters for fully buffered responses, keyed by request id.
pending: HashMap<u64, oneshot::Sender<Result<H3Response, KnafehError>>>,
/// Waiters for streaming responses, keyed by request id.
pending_stream: HashMap<u64, oneshot::Sender<Result<H3Response, KnafehError>>>,
/// Stream id -> request id, recorded when the event loop sees a new outbound request.
stream_to_request: HashMap<u64, u64>,
}
/// Shared internals of one established connection: its request sender plus waiter state.
pub(crate) struct ConnectionInner {
/// Sender obtained from the connection's H3 controller via `request_sender()`.
pub(crate) request_sender: ClientRequestSender,
/// Waiter tables; the same `Arc` is also handed to the connection's event loop.
state: Arc<TokioMutex<ConnectionState>>,
/// Starts at 0 for every new connection.
/// NOTE(review): presumably callers use this to allocate request ids — no use is
/// visible in this part of the file; confirm against the call sites.
pub(crate) next_request_id: AtomicU64,
}
impl ConnectionInner {
    /// Stores `tx` as the waiter for the buffered response to `request_id`.
    ///
    /// An existing waiter under the same id is replaced (and dropped).
    pub(crate) async fn register_pending(
        &self,
        request_id: u64,
        tx: oneshot::Sender<Result<H3Response, KnafehError>>,
    ) {
        let mut tables = self.state.lock().await;
        tables.pending.insert(request_id, tx);
    }

    /// Stores `tx` as the waiter for the streaming response to `request_id`.
    ///
    /// An existing waiter under the same id is replaced (and dropped).
    pub(crate) async fn register_pending_stream(
        &self,
        request_id: u64,
        tx: oneshot::Sender<Result<H3Response, KnafehError>>,
    ) {
        let mut tables = self.state.lock().await;
        tables.pending_stream.insert(request_id, tx);
    }

    /// Forgets the buffered-response waiter for `request_id`, if one is registered.
    pub(crate) async fn remove_pending(&self, request_id: u64) {
        let mut tables = self.state.lock().await;
        tables.pending.remove(&request_id);
    }

    /// Forgets the streaming-response waiter for `request_id`, if one is registered.
    pub(crate) async fn remove_pending_stream(&self, request_id: u64) {
        let mut tables = self.state.lock().await;
        tables.pending_stream.remove(&request_id);
    }
}
/// Pool-side record of one connection and how many stream slots are checked out.
struct PoolEntry {
/// Identifier matched by `release` and copied into every handle for this entry.
id: u64,
/// Slots currently handed out: incremented by `acquire`, decremented
/// (saturating at zero) by `release`.
active_streams: usize,
/// `acquire` reuses this entry only while `active_streams < max_streams`.
max_streams: usize,
/// Shared connection internals, cloned into each handle.
inner: Arc<ConnectionInner>,
}
/// Value returned by `ClientConnectionPool::acquire`: a pool entry id plus the
/// shared connection internals.
pub struct ConnectionHandle {
/// Id of the pool entry; `ClientConnectionPool::release` takes this value.
pub id: u64,
/// Shared internals of the connection this handle refers to.
pub(crate) inner: Arc<ConnectionInner>,
}
/// RAII guard for one stream slot on a pooled connection.
///
/// Dropping the guard calls `ClientConnectionPool::release` for `conn_id`,
/// unless responsibility for the slot was taken over with
/// [`ConnectionGuard::detach`].
pub(crate) struct ConnectionGuard {
    /// `Some` until `detach` moves the pool reference out; `None` tells `Drop`
    /// that there is nothing left to release.
    pool: Option<Arc<ClientConnectionPool>>,
    /// Id of the pool entry whose slot this guard holds.
    conn_id: u64,
}

impl ConnectionGuard {
    /// Creates a guard that releases `conn_id` on `pool` when dropped.
    pub(crate) fn new(pool: Arc<ClientConnectionPool>, conn_id: u64) -> Self {
        Self {
            pool: Some(pool),
            conn_id,
        }
    }

    /// Disarms the guard and hands its pool reference and connection id to the
    /// caller, who becomes responsible for calling `release`.
    ///
    /// The guard's own `Arc` is moved out rather than cloned-and-forgotten, so
    /// the pool's strong count is unchanged by this call and the pool can still
    /// be freed once every owner is gone.
    pub(crate) fn detach(mut self) -> (Arc<ClientConnectionPool>, u64) {
        let pool = self
            .pool
            .take()
            .expect("ConnectionGuard::pool is only emptied by detach, which consumes the guard");
        // `self` is dropped on return; with `pool` now `None` its `Drop` is a no-op.
        (pool, self.conn_id)
    }
}

impl Drop for ConnectionGuard {
    /// Gives the stream slot back to the pool unless the guard was detached.
    fn drop(&mut self) {
        if let Some(pool) = self.pool.take() {
            pool.release(self.conn_id);
        }
    }
}
/// Pool of QUIC/HTTP3 client connections to a single endpoint.
pub struct ClientConnectionPool {
/// Upper bound on the number of pooled connections.
max_size: usize,
/// Endpoint string as given to `new`; parsed or resolved when a connection is opened.
endpoint: String,
/// Host part derived from `endpoint` in `new`; passed to the QUIC connect call.
hostname: String,
/// TLS settings used to build the client SSL context.
tls_config: TlsConfig,
/// Pool entries, guarded by a synchronous (std) mutex.
connections: Mutex<Vec<PoolEntry>>,
}
impl ClientConnectionPool {
    /// Creates an empty pool for `endpoint` that holds at most `max_size` connections.
    ///
    /// The hostname is the text before the last `:` of `endpoint` (the whole
    /// string when there is no `:`), with leading `[` and trailing `]` removed.
    /// No connection is opened here.
    pub fn new(endpoint: String, max_size: usize, tls_config: TlsConfig) -> Self {
        let host_part = match endpoint.rsplit_once(':') {
            Some((host, _port)) => host,
            None => endpoint.as_str(),
        };
        let hostname = host_part
            .trim_start_matches('[')
            .trim_end_matches(']')
            .to_string();

        Self {
            max_size,
            endpoint,
            hostname,
            tls_config,
            connections: Mutex::new(Vec::new()),
        }
    }

    /// Reserves one stream slot and returns a handle to the connection that owns it.
    ///
    /// An existing connection with spare capacity is preferred. Otherwise a new
    /// connection is opened, provided the pool is below `max_size`.
    ///
    /// # Errors
    ///
    /// Returns `KnafehError::Transport("connection pool exhausted")` when every
    /// connection is at capacity and the pool is full, and propagates any error
    /// from establishing a new connection.
    ///
    /// # Panics
    ///
    /// Panics if the pool mutex is poisoned.
    pub async fn acquire(&self) -> Result<ConnectionHandle, KnafehError> {
        let has_room = {
            let mut pool = self.connections.lock().unwrap();
            if let Some(handle) = Self::reserve_stream(pool.as_mut_slice()) {
                return Ok(handle);
            }
            pool.len() < self.max_size
        };
        if !has_room {
            return Err(KnafehError::Transport(
                "connection pool exhausted".to_string(),
            ));
        }

        // The pool lock is not held while connecting, so other tasks may fill
        // the pool in the meantime; that is re-checked below.
        let (handle, entry, quic_conn, controller, state) = self.create_connection().await?;

        let mut pool = self.connections.lock().unwrap();
        if pool.len() >= self.max_size {
            // Lost the race: the connection just opened is dropped and an
            // existing one is used instead, if any has capacity.
            return Self::reserve_stream(pool.as_mut_slice()).ok_or_else(|| {
                KnafehError::Transport("connection pool exhausted".to_string())
            });
        }
        pool.push(entry);
        tokio::spawn(connection_event_loop(quic_conn, controller, state));
        Ok(handle)
    }

    /// Reserves a slot on the first entry with `active_streams < max_streams`
    /// and returns a handle to it, or `None` when every entry is at capacity.
    fn reserve_stream(entries: &mut [PoolEntry]) -> Option<ConnectionHandle> {
        let entry = entries
            .iter_mut()
            .find(|e| e.active_streams < e.max_streams)?;
        entry.active_streams += 1;
        Some(ConnectionHandle {
            id: entry.id,
            inner: Arc::clone(&entry.inner),
        })
    }

    /// Returns one stream slot of connection `conn_id` to the pool.
    ///
    /// Unknown ids are ignored and the counter never goes below zero.
    ///
    /// # Panics
    ///
    /// Panics if the pool mutex is poisoned.
    pub fn release(&self, conn_id: u64) {
        let mut pool = self.connections.lock().unwrap();
        let Some(entry) = pool.iter_mut().find(|e| e.id == conn_id) else {
            return;
        };
        entry.active_streams = entry.active_streams.saturating_sub(1);
    }

    /// Opens a new QUIC/HTTP3 connection to the endpoint.
    ///
    /// Returns the caller's handle, the pool entry (already counting one active
    /// stream), and the pieces the event loop needs: the QUIC connection, the
    /// H3 controller and the shared waiter state. Nothing is added to the pool
    /// and no task is spawned here.
    ///
    /// # Errors
    ///
    /// Returns `KnafehError::Transport` when the endpoint cannot be resolved,
    /// the UDP socket cannot be bound or connected, or the QUIC connect fails,
    /// and propagates errors from `connection_params`.
    async fn create_connection(
        &self,
    ) -> Result<
        (
            ConnectionHandle,
            PoolEntry,
            QuicConnection,
            ClientH3Controller,
            Arc<TokioMutex<ConnectionState>>,
        ),
        KnafehError,
    > {
        // Connection ids are process-wide, not per pool.
        static NEXT_ID: AtomicU64 = AtomicU64::new(1);
        let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        tracing::info!(endpoint = %self.endpoint, id, "establishing new QUIC connection");

        // A literal socket address is used as is; anything else is resolved
        // and the first result wins.
        let addr = if let Ok(literal) = self.endpoint.parse::<SocketAddr>() {
            literal
        } else {
            let mut resolved = tokio::net::lookup_host(&self.endpoint)
                .await
                .map_err(|e| {
                    KnafehError::Transport(format!(
                        "failed to resolve endpoint '{}': {e}",
                        self.endpoint
                    ))
                })?;
            resolved.next().ok_or_else(|| {
                KnafehError::Transport(format!("no addresses found for '{}'", self.endpoint))
            })?
        };

        // Bind to the wildcard address of the peer's family, ephemeral port.
        let bind_addr = if addr.is_ipv4() {
            SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, 0))
        } else {
            SocketAddr::from((std::net::Ipv6Addr::UNSPECIFIED, 0))
        };
        let udp = tokio::net::UdpSocket::bind(bind_addr)
            .await
            .map_err(|e| KnafehError::Transport(format!("failed to bind UDP socket: {e}")))?;
        udp.connect(addr)
            .await
            .map_err(|e| KnafehError::Transport(format!("failed to connect UDP socket: {e}")))?;

        let (h3_driver, controller) = ClientH3Driver::new(Http3Settings::default());
        let conn_params = self.connection_params()?;
        let quic_conn = tokio_quiche::quic::connect_with_config(
            udp,
            Some(&self.hostname),
            &conn_params,
            h3_driver,
        )
        .await
        .map_err(|e| KnafehError::Transport(format!("QUIC connect failed: {e}")))?;

        let state = Arc::new(TokioMutex::new(ConnectionState {
            pending: HashMap::new(),
            pending_stream: HashMap::new(),
            stream_to_request: HashMap::new(),
        }));
        let inner = Arc::new(ConnectionInner {
            request_sender: controller.request_sender(),
            state: Arc::clone(&state),
            next_request_id: AtomicU64::new(0),
        });

        // The entry starts with one active stream: the one being returned.
        let entry = PoolEntry {
            id,
            active_streams: 1,
            max_streams: 100,
            inner: Arc::clone(&inner),
        };
        let handle = ConnectionHandle { id, inner };
        Ok((handle, entry, quic_conn, controller, state))
    }

    /// Builds the client connection parameters: a 30 s idle timeout, the
    /// configured ALPN list, and a hook that supplies the SSL context.
    ///
    /// # Errors
    ///
    /// Returns the error from `build_client_tls_context` when the TLS settings
    /// cannot be turned into a context.
    fn connection_params(&self) -> Result<ConnectionParams<'static>, KnafehError> {
        // Build (and discard) a context up front so a bad TLS configuration is
        // reported as an error here; the hook itself can only log and return `None`.
        drop(build_client_tls_context(&self.tls_config)?);

        let hooks = Hooks {
            connection_hook: Some(Arc::new(ClientTlsHook {
                tls_config: self.tls_config.clone(),
            })),
        };
        let settings = tokio_quiche::settings::QuicSettings {
            max_idle_timeout: Some(std::time::Duration::from_secs(30)),
            alpn: self.tls_config.alpn.clone(),
            ..Default::default()
        };
        // NOTE(review): the empty paths look like placeholders whose only job is
        // to make the library consult the hook — confirm against tokio-quiche.
        let placeholder_cert = TlsCertificatePaths {
            cert: "",
            private_key: "",
            kind: CertificateKind::X509,
        };
        Ok(ConnectionParams::new_client(
            settings,
            Some(placeholder_cert),
            hooks,
        ))
    }

    /// Number of connections currently in the pool.
    ///
    /// # Panics
    ///
    /// Panics if the pool mutex is poisoned.
    pub fn size(&self) -> usize {
        let pool = self.connections.lock().unwrap();
        pool.len()
    }

    /// The endpoint string this pool was created with.
    pub fn endpoint(&self) -> &str {
        self.endpoint.as_str()
    }

    /// The host part derived from the endpoint in `new`.
    pub fn hostname(&self) -> &str {
        self.hostname.as_str()
    }
}
/// Connection hook that supplies a custom SSL context built from `tls_config`.
struct ClientTlsHook {
/// Settings passed to `build_client_tls_context` each time the hook is invoked.
tls_config: TlsConfig,
}
impl tokio_quiche::quic::ConnectionHook for ClientTlsHook {
    /// Builds the client SSL context from the stored TLS settings.
    ///
    /// The supplied certificate paths are ignored. On failure the error is
    /// logged and `None` is returned.
    fn create_custom_ssl_context_builder(
        &self,
        _settings: TlsCertificatePaths<'_>,
    ) -> Option<SslContextBuilder> {
        build_client_tls_context(&self.tls_config)
            .map_err(|e| {
                tracing::error!(error = %e, "failed to build client TLS context");
            })
            .ok()
    }
}
/// Builds an SSL context builder for outgoing connections from `tls_config`.
///
/// With `verify_peer` unset, peer verification is disabled. Otherwise peer
/// verification is enabled and trust anchors come from `ca_path` when it is
/// set, or from the default verify paths when it is not.
///
/// # Errors
///
/// Returns `KnafehError::Tls` carrying the underlying error text when the
/// builder cannot be created or the trust anchors cannot be loaded.
fn build_client_tls_context(tls_config: &TlsConfig) -> Result<SslContextBuilder, KnafehError> {
    /// Wraps any displayable TLS-library error in `KnafehError::Tls`.
    fn to_tls_error<E: std::fmt::Display>(e: E) -> KnafehError {
        KnafehError::Tls(e.to_string())
    }

    let mut builder = SslContextBuilder::new(SslMethod::tls()).map_err(to_tls_error)?;

    if !tls_config.verify_peer {
        builder.set_verify(SslVerifyMode::NONE);
        return Ok(builder);
    }

    builder.set_verify(SslVerifyMode::PEER);
    let anchors_loaded = match &tls_config.ca_path {
        Some(ca_path) => builder.set_ca_file(ca_path),
        None => builder.set_default_verify_paths(),
    };
    anchors_loaded.map_err(to_tls_error)?;
    Ok(builder)
}
/// Drives one connection: consumes controller events until the connection ends,
/// then fails every request that is still waiting.
///
/// * `NewOutboundRequest` records the stream id -> request id mapping.
/// * `IncomingHeaders` is handed to a spawned `dispatch_incoming_response` task
///   so that reading a body never blocks event processing.
/// * `ConnectionError` / `ConnectionShutdown` end the loop.
///
/// The loop also ends when the event channel closes without either of those
/// events. In every case all pending waiters are resolved with
/// `KnafehError::ConnectionClosed`, so no caller is left waiting on a
/// connection that will never produce another event.
///
/// `_quic_conn` is never used; it is only held so that it is dropped when this
/// function returns rather than earlier.
async fn connection_event_loop(
    _quic_conn: QuicConnection,
    mut controller: ClientH3Controller,
    state: Arc<TokioMutex<ConnectionState>>,
) {
    while let Some(event) = controller.event_receiver_mut().recv().await {
        match event {
            ClientH3Event::NewOutboundRequest {
                stream_id,
                request_id,
            } => {
                state
                    .lock()
                    .await
                    .stream_to_request
                    .insert(stream_id, request_id);
            }
            ClientH3Event::Core(H3Event::IncomingHeaders(incoming)) => {
                let state = Arc::clone(&state);
                tokio::spawn(async move {
                    dispatch_incoming_response(incoming, state).await;
                });
            }
            ClientH3Event::Core(H3Event::ConnectionError(_) | H3Event::ConnectionShutdown(_)) => {
                break;
            }
            // NOTE(review): every other event is ignored. If the driver reports
            // per-stream resets as separate events, the affected waiter is not
            // failed here — confirm against the driver's event list.
            _ => {}
        }
    }

    // Reached on an explicit error/shutdown event AND when the event channel
    // simply closes; previously the latter left every waiter hanging.
    let mut s = state.lock().await;
    for (_, tx) in s.pending.drain() {
        // A send error only means the waiter already went away.
        let _ = tx.send(Err(KnafehError::ConnectionClosed));
    }
    for (_, tx) in s.pending_stream.drain() {
        let _ = tx.send(Err(KnafehError::ConnectionClosed));
    }
    // No further events will arrive for these streams.
    s.stream_to_request.clear();
}
/// Routes freshly received response headers to whoever is waiting for them.
///
/// The stream id is mapped to its request id. Headers for an unknown stream are
/// dropped. If a streaming waiter is registered for the request, it receives the
/// headers together with the untouched frame channel and the stream mapping is
/// removed. Otherwise the response is passed to `read_response_and_resolve`,
/// which buffers the body.
async fn dispatch_incoming_response(
    incoming: IncomingH3Headers,
    state: Arc<TokioMutex<ConnectionState>>,
) {
    let stream_id = incoming.stream_id;

    let known_request = state
        .lock()
        .await
        .stream_to_request
        .get(&stream_id)
        .copied();
    let request_id = match known_request {
        Some(id) => id,
        None => return,
    };

    {
        let mut tables = state.lock().await;
        if let Some(waiter) = tables.pending_stream.remove(&request_id) {
            tables.stream_to_request.remove(&stream_id);
            let response = H3Response::Streaming {
                headers: incoming.headers,
                recv: incoming.recv,
                read_fin: incoming.read_fin,
            };
            // A send error only means the waiter already went away.
            let _ = waiter.send(Ok(response));
            return;
        }
        // The lock must be released before the buffered path locks it again.
    }

    read_response_and_resolve(incoming, state).await;
}
async fn read_response_and_resolve(
incoming: IncomingH3Headers,
state: Arc<TokioMutex<ConnectionState>>,
) {
let IncomingH3Headers {
stream_id,
headers,
mut recv,
read_fin,
..
} = incoming;
let mut body = Vec::new();
if !read_fin {
while let Some(frame) = recv.recv().await {
if let InboundFrame::Body(buf, fin) = frame {
if body.len().saturating_add(buf.len()) > MAX_MESSAGE_SIZE {
let mut s = state.lock().await;
if let Some(&request_id) = s.stream_to_request.get(&stream_id) {
if let Some(tx) = s.pending.remove(&request_id) {
let _ = tx.send(Err(KnafehError::InvalidMessage(format!(
"response body exceeds maximum {MAX_MESSAGE_SIZE} bytes"
))));
}
s.stream_to_request.remove(&stream_id);
}
return;
}
body.extend_from_slice(&buf);
if fin {
break;
}
}
}
}
let mut s = state.lock().await;
if let Some(&request_id) = s.stream_to_request.get(&stream_id) {
if let Some(tx) = s.pending.remove(&request_id) {
let _ = tx.send(Ok(H3Response::Complete { headers, body }));
}
s.stream_to_request.remove(&stream_id);
}
}