use crate::{RpcTransport, TransportError};
use async_trait::async_trait;
use capnweb_core::{decode_message, encode_message, Message};
use quinn::{Connection, Endpoint, RecvStream, SendStream};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex, RwLock};
#[derive(Debug, Clone)]
pub struct Http3Config {
pub max_concurrent_streams: u32,
pub stream_idle_timeout: u64,
pub connection_idle_timeout: u64,
pub enable_multiplexing: bool,
pub enable_compression: bool,
}
impl Default for Http3Config {
fn default() -> Self {
Self {
max_concurrent_streams: 1000,
stream_idle_timeout: 30,
connection_idle_timeout: 300,
enable_multiplexing: true,
enable_compression: false, }
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Http3Headers {
pub method: String,
pub path: String,
pub authority: String,
pub scheme: String,
pub content_type: String,
pub user_agent: String,
pub custom_headers: HashMap<String, String>,
}
impl Default for Http3Headers {
fn default() -> Self {
Self {
method: "POST".to_string(),
path: "/rpc".to_string(),
authority: "localhost".to_string(),
scheme: "https".to_string(),
content_type: "application/json".to_string(),
user_agent: "CapnWeb-Rust/1.0".to_string(),
custom_headers: HashMap::new(),
}
}
}
#[derive(Debug)]
struct Http3Stream {
stream_id: u64,
send_stream: SendStream,
recv_stream: RecvStream,
headers: Http3Headers,
}
impl Http3Stream {
fn new(stream_id: u64, send_stream: SendStream, recv_stream: RecvStream) -> Self {
Self {
stream_id,
send_stream,
recv_stream,
headers: Http3Headers::default(),
}
}
async fn send_request(&mut self, message: &Message) -> Result<(), TransportError> {
let payload = encode_message(message).map_err(|e| TransportError::Codec(e.to_string()))?;
let headers_frame = self.create_headers_frame(&payload)?;
self.send_stream
.write_all(&headers_frame)
.await
.map_err(|e| TransportError::Io(std::io::Error::new(std::io::ErrorKind::Other, e)))?;
let data_frame = self.create_data_frame(&payload)?;
self.send_stream
.write_all(&data_frame)
.await
.map_err(|e| TransportError::Io(std::io::Error::new(std::io::ErrorKind::Other, e)))?;
self.send_stream
.finish()
.await
.map_err(|e| TransportError::Io(std::io::Error::new(std::io::ErrorKind::Other, e)))?;
tracing::trace!(
stream_id = self.stream_id,
payload_size = payload.len(),
"HTTP/3 request sent"
);
Ok(())
}
async fn receive_response(&mut self) -> Result<Message, TransportError> {
let mut len_bytes = [0u8; 4];
self.recv_stream
.read_exact(&mut len_bytes)
.await
.map_err(|e| TransportError::Protocol(format!("Failed to read length: {}", e)))?;
let len = u32::from_be_bytes(len_bytes) as usize;
let mut payload = vec![0u8; len];
self.recv_stream
.read_exact(&mut payload)
.await
.map_err(|e| TransportError::Protocol(format!("Failed to read payload: {}", e)))?;
let message = decode_message(&payload).map_err(|e| TransportError::Codec(e.to_string()))?;
tracing::trace!(
stream_id = self.stream_id,
payload_size = payload.len(),
"HTTP/3 response received"
);
Ok(message)
}
fn create_headers_frame(&self, payload: &[u8]) -> Result<Vec<u8>, TransportError> {
let mut frame = Vec::new();
frame.push(0x01);
frame.extend_from_slice(&(200u32.to_be_bytes()));
frame.extend_from_slice(format!(":method {}\n", self.headers.method).as_bytes());
frame.extend_from_slice(format!(":path {}\n", self.headers.path).as_bytes());
frame.extend_from_slice(format!(":authority {}\n", self.headers.authority).as_bytes());
frame.extend_from_slice(format!(":scheme {}\n", self.headers.scheme).as_bytes());
frame.extend_from_slice(format!("content-type {}\n", self.headers.content_type).as_bytes());
frame.extend_from_slice(format!("content-length {}\n", payload.len()).as_bytes());
frame.extend_from_slice(format!("user-agent {}\n", self.headers.user_agent).as_bytes());
for (name, value) in &self.headers.custom_headers {
frame.extend_from_slice(format!("{} {}\n", name, value).as_bytes());
}
Ok(frame)
}
fn create_data_frame(&self, payload: &[u8]) -> Result<Vec<u8>, TransportError> {
let mut frame = Vec::new();
frame.push(0x00);
frame.extend_from_slice(&(payload.len() as u32).to_be_bytes());
frame.extend_from_slice(payload);
Ok(frame)
}
}
pub struct Http3Transport {
connection: Arc<Connection>,
streams: Arc<RwLock<HashMap<u64, Http3Stream>>>,
next_stream_id: Arc<Mutex<u64>>,
config: Http3Config,
outgoing_queue: Arc<Mutex<mpsc::UnboundedReceiver<Message>>>,
incoming_queue: Arc<Mutex<mpsc::UnboundedSender<Message>>>,
queue_tx: mpsc::UnboundedSender<Message>,
queue_rx: mpsc::UnboundedReceiver<Message>,
}
impl Http3Transport {
pub fn new(connection: Connection, config: Http3Config) -> Self {
let (queue_tx, queue_rx) = mpsc::unbounded_channel();
let (incoming_tx, incoming_rx) = mpsc::unbounded_channel();
Self {
connection: Arc::new(connection),
streams: Arc::new(RwLock::new(HashMap::new())),
next_stream_id: Arc::new(Mutex::new(0)),
config,
outgoing_queue: Arc::new(Mutex::new(queue_rx)),
incoming_queue: Arc::new(Mutex::new(incoming_tx)),
queue_tx,
queue_rx: incoming_rx,
}
}
async fn create_stream(&self) -> Result<u64, TransportError> {
let (send_stream, recv_stream) = self.connection.open_bi().await.map_err(|e| {
TransportError::Protocol(format!("Failed to open bidirectional stream: {}", e))
})?;
let mut stream_id_guard = self.next_stream_id.lock().await;
let stream_id = *stream_id_guard;
*stream_id_guard += 1;
let stream = Http3Stream::new(stream_id, send_stream, recv_stream);
self.streams.write().await.insert(stream_id, stream);
tracing::debug!(stream_id = stream_id, "HTTP/3 stream created");
Ok(stream_id)
}
async fn get_send_stream(&self) -> Result<u64, TransportError> {
if self.config.enable_multiplexing {
self.create_stream().await
} else {
let streams = self.streams.read().await;
if let Some(&stream_id) = streams.keys().next() {
Ok(stream_id)
} else {
drop(streams);
self.create_stream().await
}
}
}
pub async fn start_background_processing(&self) {
let streams = self.streams.clone();
let config = self.config.clone();
let connection = self.connection.clone();
tokio::spawn(async move {
tracing::info!("HTTP/3 background processing started");
loop {
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
let mut streams_lock = streams.write().await;
let initial_count = streams_lock.len();
streams_lock.retain(|&stream_id, _stream| {
true });
if streams_lock.len() != initial_count {
tracing::debug!(
cleaned_streams = initial_count - streams_lock.len(),
remaining_streams = streams_lock.len(),
"Cleaned up idle HTTP/3 streams"
);
}
}
});
}
pub async fn get_stats(&self) -> Http3Stats {
let streams = self.streams.read().await;
let connection_stats = self.connection.stats();
Http3Stats {
active_streams: streams.len(),
max_concurrent_streams: self.config.max_concurrent_streams,
bytes_sent: connection_stats.udp_tx.bytes,
bytes_received: connection_stats.udp_rx.bytes,
packets_sent: connection_stats.udp_tx.datagrams,
packets_received: connection_stats.udp_rx.datagrams,
connection_errors: 0, stream_errors: 0, }
}
}
#[async_trait]
impl RpcTransport for Http3Transport {
async fn send(&mut self, message: Message) -> Result<(), TransportError> {
let stream_id = self.get_send_stream().await?;
let mut streams = self.streams.write().await;
let stream = streams
.get_mut(&stream_id)
.ok_or_else(|| TransportError::Protocol("Stream not found".to_string()))?;
stream.send_request(&message).await?;
tracing::trace!(stream_id = stream_id, "Message sent via HTTP/3");
Ok(())
}
async fn recv(&mut self) -> Result<Option<Message>, TransportError> {
match self.queue_rx.try_recv() {
Ok(message) => {
tracing::trace!("Message received via HTTP/3");
Ok(Some(message))
}
Err(mpsc::error::TryRecvError::Empty) => Ok(None),
Err(mpsc::error::TryRecvError::Disconnected) => Err(TransportError::ConnectionClosed),
}
}
async fn close(&mut self) -> Result<(), TransportError> {
tracing::info!("Closing HTTP/3 transport");
let mut streams = self.streams.write().await;
streams.clear();
self.connection.close(0u32.into(), b"transport closed");
tracing::info!("HTTP/3 transport closed");
Ok(())
}
}
pub struct Http3Client {
endpoint: Endpoint,
config: Http3Config,
}
impl Http3Client {
pub fn new(endpoint: Endpoint, config: Http3Config) -> Self {
Self { endpoint, config }
}
pub async fn connect(&self, server_addr: &str) -> Result<Http3Transport, TransportError> {
tracing::info!(server_addr = server_addr, "Connecting to HTTP/3 server");
let connection = self
.endpoint
.connect(
server_addr.parse().map_err(|e| {
TransportError::Protocol(format!("Invalid server address: {}", e))
})?,
"localhost",
)
.map_err(|e| TransportError::Protocol(format!("Connection failed: {}", e)))?
.await
.map_err(|e| TransportError::Protocol(format!("Connection failed: {}", e)))?;
let transport = Http3Transport::new(connection, self.config.clone());
transport.start_background_processing().await;
tracing::info!(server_addr = server_addr, "HTTP/3 connection established");
Ok(transport)
}
}
#[derive(Debug, Clone)]
pub struct Http3Stats {
pub active_streams: usize,
pub max_concurrent_streams: u32,
pub bytes_sent: u64,
pub bytes_received: u64,
pub packets_sent: u64,
pub packets_received: u64,
pub connection_errors: u64,
pub stream_errors: u64,
}
pub fn make_http3_client_endpoint(config: Http3Config) -> Result<Endpoint, TransportError> {
let mut client_cfg = configure_http3_client(config);
let endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap())
.map_err(|e| TransportError::Protocol(format!("Failed to create endpoint: {}", e)))?;
Ok(endpoint)
}
fn configure_http3_client(config: Http3Config) -> quinn::ClientConfig {
let mut transport_config = quinn::TransportConfig::default();
transport_config.max_concurrent_bidi_streams(config.max_concurrent_streams.into());
transport_config.max_idle_timeout(Some(
std::time::Duration::from_secs(config.connection_idle_timeout)
.try_into()
.unwrap(),
));
let crypto = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(SkipServerVerification::new())
.with_no_client_auth();
let mut client_config = quinn::ClientConfig::new(Arc::new(
quinn::crypto::rustls::QuicClientConfig::try_from(crypto).unwrap(),
));
client_config.transport_config(Arc::new(transport_config));
client_config
}
#[derive(Debug)]
struct SkipServerVerification;
impl SkipServerVerification {
fn new() -> Arc<Self> {
Arc::new(Self)
}
}
impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
rustls::SignatureScheme::ED25519,
]
}
}
pub mod advanced {
use super::*;
#[derive(Debug)]
pub struct Http3ConnectionPool {
connections: Arc<RwLock<HashMap<String, Arc<Http3Transport>>>>,
config: Http3Config,
max_connections_per_host: usize,
}
impl Http3ConnectionPool {
pub fn new(config: Http3Config) -> Self {
Self {
connections: Arc::new(RwLock::new(HashMap::new())),
config,
max_connections_per_host: 10,
}
}
pub async fn get_connection(
&self,
host: &str,
) -> Result<Arc<Http3Transport>, TransportError> {
let connections = self.connections.read().await;
if let Some(transport) = connections.get(host) {
return Ok(transport.clone());
}
drop(connections);
let endpoint = make_http3_client_endpoint(self.config.clone())?;
let client = Http3Client::new(endpoint, self.config.clone());
let transport = client.connect(host).await?;
let transport_arc = Arc::new(transport);
let mut connections = self.connections.write().await;
connections.insert(host.to_string(), transport_arc.clone());
tracing::info!(host = host, "New HTTP/3 connection added to pool");
Ok(transport_arc)
}
pub async fn get_pool_stats(&self) -> HashMap<String, Http3Stats> {
let connections = self.connections.read().await;
let mut stats = HashMap::new();
for (host, transport) in connections.iter() {
stats.insert(host.clone(), transport.get_stats().await);
}
stats
}
}
#[derive(Debug)]
pub struct Http3LoadBalancer {
servers: Vec<String>,
pool: Http3ConnectionPool,
current_index: Arc<Mutex<usize>>,
}
impl Http3LoadBalancer {
pub fn new(servers: Vec<String>, config: Http3Config) -> Self {
Self {
servers,
pool: Http3ConnectionPool::new(config),
current_index: Arc::new(Mutex::new(0)),
}
}
pub async fn get_next_connection(&self) -> Result<Arc<Http3Transport>, TransportError> {
if self.servers.is_empty() {
return Err(TransportError::Protocol(
"No servers configured".to_string(),
));
}
let mut index = self.current_index.lock().await;
let server = &self.servers[*index];
*index = (*index + 1) % self.servers.len();
drop(index);
self.pool.get_connection(server).await
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_http3_config_default() {
let config = Http3Config::default();
assert_eq!(config.max_concurrent_streams, 1000);
assert!(config.enable_multiplexing);
assert!(!config.enable_compression);
}
#[test]
fn test_http3_headers_default() {
let headers = Http3Headers::default();
assert_eq!(headers.method, "POST");
assert_eq!(headers.path, "/rpc");
assert_eq!(headers.content_type, "application/json");
}
#[tokio::test]
async fn test_http3_transport_creation() {
let config = Http3Config::default();
assert!(config.max_concurrent_streams > 0);
}
#[tokio::test]
async fn test_http3_connection_pool() {
let config = Http3Config::default();
let pool = advanced::Http3ConnectionPool::new(config);
assert_eq!(pool.max_connections_per_host, 10);
}
}