#[cfg(feature = "quinn")]
use crate::network::transport::{
Transport, TransportAddr, TransportConnection, TransportListener, TransportType,
};
#[cfg(feature = "quinn")]
use anyhow::Result;
#[cfg(feature = "quinn")]
use std::net::SocketAddr;
#[cfg(feature = "quinn")]
use tracing::{debug, info};
#[cfg(feature = "quinn")]
/// QUIC transport backed by `quinn`.
///
/// Holds a client-mode endpoint bound to an OS-chosen IPv4 port, together with the
/// message-size limit handed to every connection this transport produces.
#[derive(Debug)]
pub struct QuinnTransport {
    /// Client-mode endpoint created at construction time on `0.0.0.0:0`.
    endpoint: quinn::Endpoint,
    /// Largest payload, in bytes, that connections from this transport will accept in `recv`.
    max_message_length: usize,
}

#[cfg(feature = "quinn")]
impl QuinnTransport {
    /// Builds a transport using the protocol-wide limit
    /// `crate::network::protocol::MAX_PROTOCOL_MESSAGE_LENGTH`.
    ///
    /// # Errors
    /// Propagates any failure to create the client endpoint.
    pub fn new() -> Result<Self> {
        let default_limit = crate::network::protocol::MAX_PROTOCOL_MESSAGE_LENGTH;
        Self::with_max_message_length(default_limit)
    }

    /// Builds a transport whose connections reject incoming messages longer than
    /// `max_message_length` bytes.
    ///
    /// The client endpoint is bound to the IPv4 wildcard address with port 0, so the
    /// OS assigns the port.
    ///
    /// # Errors
    /// Propagates the I/O error from `quinn::Endpoint::client` (e.g. a bind failure).
    pub fn with_max_message_length(max_message_length: usize) -> Result<Self> {
        let wildcard_v4 = SocketAddr::from(([0, 0, 0, 0], 0));
        let client_endpoint = quinn::Endpoint::client(wildcard_v4)?;
        info!("Quinn transport initialized (client mode)");
        Ok(Self {
            endpoint: client_endpoint,
            max_message_length,
        })
    }
}
#[cfg(feature = "quinn")]
#[async_trait::async_trait]
impl Transport for QuinnTransport {
    type Connection = QuinnConnection;
    type Listener = QuinnListener;

    /// Identifies this transport as `TransportType::Quinn`.
    fn transport_type(&self) -> TransportType {
        TransportType::Quinn
    }

    /// Binds a QUIC server endpoint on `addr`, presenting a freshly generated
    /// self-signed certificate for the name `localhost`.
    ///
    /// The listener records the address the socket was actually bound to. When `addr`
    /// uses port 0, `local_addr()` therefore reports the OS-assigned port, not 0.
    ///
    /// # Errors
    /// Fails if certificate generation, server configuration, binding, or querying
    /// the bound address fails.
    async fn listen(&self, addr: SocketAddr) -> Result<Self::Listener> {
        use quinn::rustls::pki_types::{CertificateDer, PrivateKeyDer};

        let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])
            .map_err(|e| anyhow::anyhow!("Failed to generate certificate: {}", e))?;
        let cert_der = cert.serialize_der()?;
        let key_der = cert.serialize_private_key_der();
        let certs = vec![CertificateDer::from(cert_der)];
        let key = PrivateKeyDer::Pkcs8(key_der.into());
        let server_config = quinn::ServerConfig::with_single_cert(certs, key)?;
        let endpoint = quinn::Endpoint::server(server_config, addr)?;
        // Ask the socket rather than echoing `addr`: with port 0 the requested and
        // the real addresses differ.
        let local_addr = endpoint.local_addr()?;
        Ok(QuinnListener {
            endpoint,
            local_addr,
            max_message_length: self.max_message_length,
        })
    }

    /// Opens a QUIC connection to a `TransportAddr::Quinn` peer.
    ///
    /// The connection goes out through the transport's shared client endpoint rather
    /// than a new UDP socket per call. The TLS server name is the peer's IP address
    /// in textual form.
    ///
    /// # Errors
    /// Fails for non-Quinn addresses, or if quinn rejects the connection attempt or the
    /// handshake does not complete.
    async fn connect(&self, addr: TransportAddr) -> Result<Self::Connection> {
        let socket_addr = match addr {
            TransportAddr::Quinn(socket_addr) => socket_addr,
            _ => {
                return Err(anyhow::anyhow!(
                    "Quinn transport can only connect to Quinn addresses"
                ))
            }
        };
        // NOTE(review): nothing in this file installs a default client config on the
        // endpoint, and quinn's `Endpoint::connect` fails with
        // `ConnectError::NoDefaultClientConfig` without one. Also, `listen` issues a
        // certificate for "localhost" while the server name used here is the peer IP,
        // so a standard verifier would reject it. Confirm how trust is configured.
        let server_name = socket_addr.ip().to_string();
        let conn = self.endpoint.connect(socket_addr, &server_name)?.await?;
        Ok(QuinnConnection {
            conn,
            peer_addr: TransportAddr::Quinn(socket_addr),
            connected: true,
            max_message_length: self.max_message_length,
        })
    }
}
#[cfg(feature = "quinn")]
/// Server side of the Quinn transport: yields connections accepted on a bound
/// server endpoint.
pub struct QuinnListener {
    /// Server-mode endpoint produced by `QuinnTransport::listen`.
    endpoint: quinn::Endpoint,
    /// Address recorded when the listener was created; returned by `local_addr`.
    local_addr: SocketAddr,
    /// Limit copied into every accepted `QuinnConnection`.
    max_message_length: usize,
}

#[cfg(feature = "quinn")]
#[async_trait::async_trait]
impl TransportListener for QuinnListener {
    type Connection = QuinnConnection;

    /// Waits for the next incoming connection and completes its handshake.
    ///
    /// # Errors
    /// Returns "Endpoint closed" once the endpoint stops yielding connections, or the
    /// handshake error if the incoming connection fails to establish.
    async fn accept(&mut self) -> Result<(Self::Connection, TransportAddr)> {
        let Some(incoming) = self.endpoint.accept().await else {
            return Err(anyhow::anyhow!("Endpoint closed"));
        };
        let conn = incoming.await?;
        let remote = conn.remote_address();
        debug!("Accepted Quinn connection from {}", remote);

        let peer = TransportAddr::Quinn(remote);
        let connection = QuinnConnection {
            conn,
            peer_addr: peer.clone(),
            connected: true,
            max_message_length: self.max_message_length,
        };
        Ok((connection, peer))
    }

    /// Returns the address stored when this listener was constructed.
    fn local_addr(&self) -> Result<SocketAddr> {
        let bound = self.local_addr;
        Ok(bound)
    }
}
#[cfg(feature = "quinn")]
/// A single QUIC connection used by the transport layer.
///
/// Each message travels on its own unidirectional stream, framed as a 4-byte
/// big-endian length prefix followed by that many payload bytes. A zero length
/// prefix is treated by `recv` as the peer signalling end of connection.
pub struct QuinnConnection {
    conn: quinn::Connection,
    peer_addr: TransportAddr,
    /// Locally tracked liveness flag. It is cleared by `close`, by receiving a
    /// zero-length frame, and when opening or accepting a stream fails.
    connected: bool,
    /// Largest payload, in bytes, that `recv` will accept.
    max_message_length: usize,
}

#[cfg(feature = "quinn")]
#[async_trait::async_trait]
impl TransportConnection for QuinnConnection {
    /// Sends `data` as one length-prefixed message on a new unidirectional stream.
    ///
    /// An empty `data` produces a zero length prefix, which the receiving side's
    /// `recv` interprets as the connection being closed.
    ///
    /// # Errors
    /// Fails if the connection was marked closed, if `data` is longer than `u32::MAX`
    /// bytes (the prefix cannot represent it), or if opening, writing, or finishing
    /// the stream fails. A failure to open the stream also marks the connection closed.
    async fn send(&mut self, data: &[u8]) -> Result<()> {
        if !self.connected {
            return Err(anyhow::anyhow!("Connection closed"));
        }
        // The wire prefix is a u32. An `as` cast would silently truncate payloads of
        // 4 GiB or more and desynchronise the peer's framing, so reject them up front,
        // before any stream is opened.
        let len = u32::try_from(data.len()).map_err(|_| {
            anyhow::anyhow!(
                "Message too large to frame: {} bytes (max: {} bytes)",
                data.len(),
                u32::MAX
            )
        })?;
        let mut stream = match self.conn.open_uni().await {
            Ok(stream) => stream,
            Err(e) => {
                // open_uni fails with a connection-level error, so the connection is
                // gone; mirror recv's handling of accept_uni failures.
                self.connected = false;
                return Err(e.into());
            }
        };
        stream.write_all(&len.to_be_bytes()).await?;
        stream.write_all(data).await?;
        stream.finish()?;
        Ok(())
    }

    /// Receives the next length-prefixed message from a peer-opened unidirectional stream.
    ///
    /// Returns an empty vector when the connection is already marked closed, or when
    /// the peer sends a zero length prefix (which also marks it closed).
    ///
    /// # Errors
    /// Fails if accepting the stream fails (the connection is marked closed), if
    /// reading the frame fails, or if the announced length exceeds `max_message_length`.
    async fn recv(&mut self) -> Result<Vec<u8>> {
        if !self.connected {
            return Ok(Vec::new());
        }
        let mut stream = match self.conn.accept_uni().await {
            Ok(stream) => stream,
            Err(e) => {
                self.connected = false;
                return Err(anyhow::anyhow!("Failed to accept stream: {}", e));
            }
        };
        let mut len_bytes = [0u8; 4];
        stream.read_exact(&mut len_bytes).await?;
        let len = u32::from_be_bytes(len_bytes) as usize;
        // A zero-length frame is the close signal; see `send`.
        if len == 0 {
            self.connected = false;
            return Ok(Vec::new());
        }
        // Check before allocating so a hostile prefix cannot force a huge buffer.
        if len > self.max_message_length {
            return Err(anyhow::anyhow!(
                "Message too large: {} bytes (max: {} bytes)",
                len,
                self.max_message_length
            ));
        }
        let mut buffer = vec![0u8; len];
        stream.read_exact(&mut buffer).await?;
        Ok(buffer)
    }

    /// Returns the remote address this connection was created for.
    fn peer_addr(&self) -> TransportAddr {
        self.peer_addr.clone()
    }

    /// True while the local flag is set and quinn reports no close reason.
    fn is_connected(&self) -> bool {
        self.connected && self.conn.close_reason().is_none()
    }

    /// Closes the QUIC connection with error code 0. Calling it again is a no-op.
    async fn close(&mut self) -> Result<()> {
        if self.connected {
            self.conn.close(0u32.into(), b"Connection closed");
            self.connected = false;
        }
        Ok(())
    }
}
#[cfg(not(feature = "quinn"))]
pub struct QuinnTransport;
#[cfg(not(feature = "quinn"))]
impl QuinnTransport {
pub async fn new() -> Result<Self> {
Err(anyhow::anyhow!("Quinn transport requires 'quinn' feature"))
}
}