use crate::{tls, Envelope, Transport, TransportConfig};
use async_trait::async_trait;
use dashmap::DashMap;
use pollen_types::{NodeId, Result, TransportError};
use quinn::{Connection, Endpoint, RecvStream, SendStream};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, oneshot, RwLock};
use tracing::{debug, info, warn};
/// QUIC-backed implementation of the [`Transport`] trait.
///
/// Holds a single `quinn` endpoint used for both accepting and initiating
/// connections, a cache of connections keyed by remote address, and the
/// channels used to hand received envelopes to the rest of the application.
pub struct QuicTransport {
/// Configuration passed to `new`; supplies the bind address and node id.
config: TransportConfig,
/// The QUIC endpoint used to accept and to open connections.
endpoint: Endpoint,
/// Connections keyed by the peer's remote socket address.
connections: Arc<DashMap<SocketAddr, Connection>>,
/// Sending half of the channel carrying envelopes that do not match a pending request.
incoming_tx: mpsc::Sender<Envelope>,
/// Receiving half of that channel; moved out of the `Option` by `incoming()`.
incoming_rx: RwLock<Option<mpsc::Receiver<Envelope>>>,
/// Reply channels for in-flight `send_recv` calls, keyed by the request id
/// derived from the envelope timestamp.
pending_requests: Arc<DashMap<u64, oneshot::Sender<Envelope>>>,
/// Initialised to zero and not read anywhere in the visible code, hence the
/// `dead_code` allowance. NOTE(review): presumably reserved for generating
/// request ids instead of reusing the timestamp — confirm or remove.
#[allow(dead_code)] request_counter: std::sync::atomic::AtomicU64,
/// Broadcast sender whose subscribers (accept loop, connection handlers)
/// stop when `shutdown()` sends on it.
shutdown: tokio::sync::broadcast::Sender<()>,
}
impl QuicTransport {
pub async fn new(config: TransportConfig) -> Result<Arc<Self>> {
let server_config = tls::server_config()?;
let client_config = tls::client_config()?;
let mut endpoint = Endpoint::server(server_config, config.bind_addr)
.map_err(|e| TransportError::ConnectionFailed {
addr: config.bind_addr.to_string(),
reason: e.to_string(),
})?;
endpoint.set_default_client_config(client_config);
let (incoming_tx, incoming_rx) = mpsc::channel(1000);
let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);
let transport = Arc::new(Self {
config,
endpoint,
connections: Arc::new(DashMap::new()),
incoming_tx,
incoming_rx: RwLock::new(Some(incoming_rx)),
pending_requests: Arc::new(DashMap::new()),
request_counter: std::sync::atomic::AtomicU64::new(0),
shutdown: shutdown_tx,
});
let transport_clone = Arc::clone(&transport);
tokio::spawn(async move {
transport_clone.accept_loop().await;
});
info!("QUIC transport started on {}", transport.endpoint.local_addr().unwrap());
Ok(transport)
}
async fn accept_loop(self: Arc<Self>) {
let mut shutdown_rx = self.shutdown.subscribe();
loop {
tokio::select! {
_ = shutdown_rx.recv() => {
info!("Transport shutting down");
break;
}
Some(incoming) = self.endpoint.accept() => {
let transport = Arc::clone(&self);
tokio::spawn(async move {
match incoming.await {
Ok(conn) => {
let addr = conn.remote_address();
debug!("Accepted connection from {}", addr);
transport.connections.insert(addr, conn.clone());
transport.handle_connection(conn).await;
}
Err(e) => {
warn!("Failed to accept connection: {}", e);
}
}
});
}
}
}
}
async fn handle_connection(&self, conn: Connection) {
let addr = conn.remote_address();
let mut shutdown_rx = self.shutdown.subscribe();
loop {
tokio::select! {
_ = shutdown_rx.recv() => {
break;
}
result = conn.accept_bi() => {
match result {
Ok((send, recv)) => {
let incoming_tx = self.incoming_tx.clone();
let pending = Arc::clone(&self.pending_requests);
tokio::spawn(async move {
if let Err(e) = handle_stream(send, recv, incoming_tx, pending).await {
debug!("Stream error from {}: {}", addr, e);
}
});
}
Err(e) => {
debug!("Connection closed from {}: {}", addr, e);
break;
}
}
}
}
}
self.connections.remove(&addr);
}
async fn get_connection(&self, addr: SocketAddr) -> Result<Connection> {
if let Some(conn) = self.connections.get(&addr) {
if conn.close_reason().is_none() {
return Ok(conn.clone());
}
}
let conn = self
.endpoint
.connect(addr, "pollen")
.map_err(|e| TransportError::ConnectionFailed {
addr: addr.to_string(),
reason: e.to_string(),
})?
.await
.map_err(|e| TransportError::ConnectionFailed {
addr: addr.to_string(),
reason: e.to_string(),
})?;
self.connections.insert(addr, conn.clone());
let transport = Arc::new(self.clone_inner());
let conn_clone = conn.clone();
tokio::spawn(async move {
transport.handle_connection(conn_clone).await;
});
Ok(conn)
}
fn clone_inner(&self) -> QuicTransportInner {
QuicTransportInner {
connections: Arc::clone(&self.connections),
incoming_tx: self.incoming_tx.clone(),
pending_requests: Arc::clone(&self.pending_requests),
shutdown: self.shutdown.clone(),
}
}
}
/// The subset of `QuicTransport` state needed to serve a connection.
///
/// Built by `QuicTransport::clone_inner`; every field is a clone of, or a
/// shared handle to, the corresponding field on the transport.
struct QuicTransportInner {
/// Shared connection map, keyed by remote address.
connections: Arc<DashMap<SocketAddr, Connection>>,
/// Sender for envelopes that do not match a pending request.
incoming_tx: mpsc::Sender<Envelope>,
/// Shared map of reply channels for in-flight requests.
pending_requests: Arc<DashMap<u64, oneshot::Sender<Envelope>>>,
/// Shutdown broadcast sender; the handler subscribes to it to know when to stop.
shutdown: tokio::sync::broadcast::Sender<()>,
}
impl QuicTransportInner {
    /// Serves bidirectional streams on `conn` until it closes or shutdown is signalled.
    ///
    /// Every accepted stream is processed on its own task by `handle_stream`.
    /// When the loop ends, the connection's entry is removed from the shared
    /// map, but only if the map still holds this very connection.
    async fn handle_connection(&self, conn: Connection) {
        let addr = conn.remote_address();
        // Identifies this connection so that cleanup below cannot evict a
        // newer connection stored under the same remote address.
        let conn_id = conn.stable_id();
        let mut shutdown_rx = self.shutdown.subscribe();

        loop {
            tokio::select! {
                _ = shutdown_rx.recv() => {
                    break;
                }
                result = conn.accept_bi() => {
                    match result {
                        Ok((send, recv)) => {
                            let incoming_tx = self.incoming_tx.clone();
                            let pending = Arc::clone(&self.pending_requests);
                            tokio::spawn(async move {
                                if let Err(e) = handle_stream(send, recv, incoming_tx, pending).await {
                                    debug!("Stream error from {}: {}", addr, e);
                                }
                            });
                        }
                        Err(e) => {
                            debug!("Connection closed from {}: {}", addr, e);
                            break;
                        }
                    }
                }
            }
        }

        // An unconditional remove could drop a replacement connection that was
        // inserted for `addr` after this one was closed.
        self.connections
            .remove_if(&addr, |_, current| current.stable_id() == conn_id);
    }
}
/// Reads one length-prefixed envelope from `recv` and routes it.
///
/// The wire format is a 4-byte big-endian length followed by that many bytes
/// of serialized envelope. The envelope's request id is its timestamp
/// truncated to 64 bits; if a `send_recv` call is waiting on that id the
/// envelope is delivered to it, otherwise it is forwarded to `incoming_tx`.
/// Delivery failures (receiver already gone) are ignored.
///
/// `_send` is unused; it is held only until this function returns.
///
/// # Errors
///
/// Returns [`TransportError::ReceiveFailed`] if the stream ends early, a read
/// fails, or the announced length exceeds the frame limit. Propagates any
/// error from `Envelope::deserialize`.
async fn handle_stream(
    _send: SendStream,
    mut recv: RecvStream,
    incoming_tx: mpsc::Sender<Envelope>,
    pending_requests: Arc<DashMap<u64, oneshot::Sender<Envelope>>>,
) -> Result<()> {
    /// Upper bound on a single frame. The length prefix comes from the remote
    /// peer, so it must be checked before it is used to size an allocation.
    const MAX_FRAME_LEN: usize = 16 * 1024 * 1024;

    let mut len_buf = [0u8; 4];
    recv.read_exact(&mut len_buf)
        .await
        .map_err(|e| TransportError::ReceiveFailed(e.to_string()))?;

    let len = u32::from_be_bytes(len_buf) as usize;
    if len > MAX_FRAME_LEN {
        return Err(TransportError::ReceiveFailed(format!(
            "frame length {} exceeds maximum of {} bytes",
            len, MAX_FRAME_LEN
        ))
        .into());
    }

    let mut buf = vec![0u8; len];
    recv.read_exact(&mut buf)
        .await
        .map_err(|e| TransportError::ReceiveFailed(e.to_string()))?;

    let envelope = Envelope::deserialize(&buf)?;

    // Truncation to the low 64 bits is intentional: it mirrors how `send_recv`
    // derives the key it registers in `pending_requests`.
    let request_id = envelope.timestamp.as_u128() as u64;
    match pending_requests.remove(&request_id) {
        Some((_, reply_tx)) => {
            let _ = reply_tx.send(envelope);
        }
        None => {
            let _ = incoming_tx.send(envelope).await;
        }
    }
    Ok(())
}
#[async_trait]
impl Transport for QuicTransport {
    /// Sends `envelope` to `to` on a fresh bidirectional stream.
    ///
    /// The payload is framed as a 4-byte big-endian length followed by the
    /// serialized envelope, matching what `handle_stream` reads.
    ///
    /// # Errors
    ///
    /// Returns an error if a connection cannot be obtained, if serialization
    /// fails, or [`TransportError::SendFailed`] if the stream cannot be opened,
    /// written or finished, or if the payload does not fit the 32-bit length
    /// prefix.
    async fn send(&self, to: SocketAddr, envelope: Envelope) -> Result<()> {
        let conn = self.get_connection(to).await?;
        let (mut send, _recv) = conn
            .open_bi()
            .await
            .map_err(|e| TransportError::SendFailed(e.to_string()))?;

        let data = envelope.serialize()?;
        // A plain `as u32` cast would silently wrap for oversized payloads and
        // put a wrong length on the wire, so reject them explicitly.
        if data.len() > u32::MAX as usize {
            return Err(TransportError::SendFailed(format!(
                "payload of {} bytes does not fit in a 32-bit length prefix",
                data.len()
            ))
            .into());
        }
        let len = (data.len() as u32).to_be_bytes();

        send.write_all(&len)
            .await
            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
        send.write_all(&data)
            .await
            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
        send.finish()
            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
        Ok(())
    }

    /// Sends `envelope` to `to` and waits up to 10 seconds for a matching reply.
    ///
    /// A reply matches when its timestamp, truncated to 64 bits, equals that of
    /// the request. The pending-request entry is removed again if the send
    /// fails or the wait times out.
    ///
    /// # Errors
    ///
    /// Returns any error from `send`, `PollenError::Timeout` if no reply
    /// arrives in time, or `PollenError::Cancelled` if the reply channel is
    /// dropped without a value.
    async fn send_recv(&self, to: SocketAddr, envelope: Envelope) -> Result<Envelope> {
        let request_id = envelope.timestamp.as_u128() as u64;
        let (tx, rx) = oneshot::channel();
        // Register before sending so a fast reply cannot arrive unmatched.
        self.pending_requests.insert(request_id, tx);

        if let Err(e) = self.send(to, envelope).await {
            // Nothing was sent, so no reply can ever claim this entry.
            self.pending_requests.remove(&request_id);
            return Err(e);
        }

        match tokio::time::timeout(Duration::from_secs(10), rx).await {
            Ok(Ok(reply)) => Ok(reply),
            Ok(Err(_)) => Err(pollen_types::PollenError::Cancelled),
            Err(_) => {
                // Drop the stale entry so the map does not grow with every
                // request that never gets answered.
                self.pending_requests.remove(&request_id);
                Err(pollen_types::PollenError::Timeout)
            }
        }
    }

    /// Hands out the receiver for envelopes that are not replies to `send_recv`.
    ///
    /// # Panics
    ///
    /// Panics if called more than once, or if the internal lock cannot be
    /// acquired immediately.
    fn incoming(&self) -> mpsc::Receiver<Envelope> {
        self.incoming_rx
            .try_write()
            .ok()
            .and_then(|mut guard| guard.take())
            .expect("incoming() can only be called once")
    }

    /// Returns the address the endpoint is bound to.
    ///
    /// # Panics
    ///
    /// Panics if the endpoint cannot report its local address.
    fn local_addr(&self) -> SocketAddr {
        self.endpoint
            .local_addr()
            .expect("QUIC endpoint has no local address")
    }

    /// Returns the node id from the transport configuration.
    fn node_id(&self) -> NodeId {
        self.config.node_id
    }

    /// Signals all subscribed tasks to stop and closes the endpoint with
    /// error code 0 and reason `shutdown`.
    async fn shutdown(&self) {
        // Sending fails only when nobody is subscribed, which is harmless.
        let _ = self.shutdown.send(());
        self.endpoint.close(0u32.into(), b"shutdown");
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Binding to port 0 must yield a transport with an OS-assigned, non-zero port.
    #[tokio::test]
    async fn test_transport_creation() {
        let bind_addr = "127.0.0.1:0".parse().unwrap();
        let transport = QuicTransport::new(TransportConfig::new(bind_addr))
            .await
            .unwrap();

        let bound_port = transport.local_addr().port();
        assert!(bound_port > 0);

        transport.shutdown().await;
    }
}