pollen-transport 0.1.0

QUIC-based transport layer for Pollen
Documentation
//! QUIC transport implementation.

use crate::{tls, Envelope, Transport, TransportConfig};
use async_trait::async_trait;
use dashmap::DashMap;
use pollen_types::{NodeId, Result, TransportError};
use quinn::{Connection, Endpoint, RecvStream, SendStream};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, oneshot, RwLock};
use tracing::{debug, info, warn};

/// QUIC-based transport implementation.
///
/// Bundles the QUIC endpoint with the shared state used by the accept loop,
/// the per-connection stream handlers, and the request/response path.
pub struct QuicTransport {
    /// Configuration passed to [`QuicTransport::new`]; supplies the bind address and node id.
    config: TransportConfig,
    /// Endpoint used both to accept inbound connections and to dial peers.
    endpoint: Endpoint,
    /// Known connections, keyed by the peer's remote socket address.
    connections: Arc<DashMap<SocketAddr, Connection>>,
    /// Producer half of the inbound-envelope channel; cloned into every stream handler.
    incoming_tx: mpsc::Sender<Envelope>,
    /// Consumer half of the inbound-envelope channel, parked here until `incoming()` takes it.
    incoming_rx: RwLock<Option<mpsc::Receiver<Envelope>>>,
    /// Response waiters keyed by request id (the id is derived from the envelope timestamp).
    pending_requests: Arc<DashMap<u64, oneshot::Sender<Envelope>>>,
    // NOTE(review): initialised to 0 but never read in the visible code; request ids are
    // currently taken from the envelope timestamp instead, hence the lint suppression.
    #[allow(dead_code)] // Reserved for request/response correlation
    request_counter: std::sync::atomic::AtomicU64,
    /// Broadcast channel used to tell background tasks to stop.
    shutdown: tokio::sync::broadcast::Sender<()>,
}

impl QuicTransport {
    /// Create a new QUIC transport.
    ///
    /// Binds a server endpoint on `config.bind_addr`, installs the default
    /// client configuration on it so the same endpoint can dial peers, and
    /// spawns the background accept loop.
    ///
    /// # Errors
    ///
    /// Returns an error if the TLS configuration cannot be built, or
    /// [`TransportError::ConnectionFailed`] if the endpoint cannot be bound or
    /// its local address cannot be determined.
    pub async fn new(config: TransportConfig) -> Result<Arc<Self>> {
        let server_config = tls::server_config()?;
        let client_config = tls::client_config()?;

        let mut endpoint = Endpoint::server(server_config, config.bind_addr)
            .map_err(|e| TransportError::ConnectionFailed {
                addr: config.bind_addr.to_string(),
                reason: e.to_string(),
            })?;

        endpoint.set_default_client_config(client_config);

        // Resolve the bound address before any task is spawned: a failure here
        // is reported to the caller instead of panicking, and no background
        // task is left running for a transport that was never returned.
        let local_addr = endpoint
            .local_addr()
            .map_err(|e| TransportError::ConnectionFailed {
                addr: config.bind_addr.to_string(),
                reason: e.to_string(),
            })?;

        let (incoming_tx, incoming_rx) = mpsc::channel(1000);
        let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);

        let transport = Arc::new(Self {
            config,
            endpoint,
            connections: Arc::new(DashMap::new()),
            incoming_tx,
            incoming_rx: RwLock::new(Some(incoming_rx)),
            pending_requests: Arc::new(DashMap::new()),
            request_counter: std::sync::atomic::AtomicU64::new(0),
            shutdown: shutdown_tx,
        });

        // Start accepting connections
        let transport_clone = Arc::clone(&transport);
        tokio::spawn(async move {
            transport_clone.accept_loop().await;
        });

        info!("QUIC transport started on {}", local_addr);

        Ok(transport)
    }

    /// Accept incoming connections until shutdown is signalled or the
    /// endpoint is closed.
    ///
    /// Every connection attempt is completed on its own task so a slow
    /// handshake cannot stall the loop.
    async fn accept_loop(self: Arc<Self>) {
        let mut shutdown_rx = self.shutdown.subscribe();

        loop {
            tokio::select! {
                _ = shutdown_rx.recv() => {
                    info!("Transport shutting down");
                    break;
                }
                incoming = self.endpoint.accept() => {
                    // `accept` yields `None` once the endpoint has been closed.
                    // Stop here rather than relying on the shutdown broadcast,
                    // which this task misses if it subscribed after the signal
                    // was sent.
                    let incoming = match incoming {
                        Some(incoming) => incoming,
                        None => {
                            debug!("Endpoint closed, stopping accept loop");
                            break;
                        }
                    };

                    let transport = Arc::clone(&self);
                    tokio::spawn(async move {
                        match incoming.await {
                            Ok(conn) => {
                                let addr = conn.remote_address();
                                debug!("Accepted connection from {}", addr);
                                transport.connections.insert(addr, conn.clone());
                                transport.handle_connection(conn).await;
                            }
                            Err(e) => {
                                warn!("Failed to accept connection: {}", e);
                            }
                        }
                    });
                }
            }
        }
    }

    /// Handle incoming streams on a connection.
    ///
    /// Spawns one task per accepted bidirectional stream and returns when the
    /// connection ends or shutdown is signalled. On exit the connection is
    /// removed from the connection map.
    async fn handle_connection(&self, conn: Connection) {
        let addr = conn.remote_address();
        let conn_id = conn.stable_id();
        let mut shutdown_rx = self.shutdown.subscribe();

        loop {
            tokio::select! {
                _ = shutdown_rx.recv() => {
                    break;
                }
                result = conn.accept_bi() => {
                    match result {
                        Ok((send, recv)) => {
                            let incoming_tx = self.incoming_tx.clone();
                            let pending = Arc::clone(&self.pending_requests);
                            tokio::spawn(async move {
                                if let Err(e) = handle_stream(send, recv, incoming_tx, pending).await {
                                    debug!("Stream error from {}: {}", addr, e);
                                }
                            });
                        }
                        Err(e) => {
                            debug!("Connection closed from {}: {}", addr, e);
                            break;
                        }
                    }
                }
            }
        }

        // Only evict the entry if it still refers to *this* connection. The
        // map is keyed by address, so a newer connection to the same peer may
        // have replaced it and must not be dropped from the map.
        self.connections
            .remove_if(&addr, |_, existing| existing.stable_id() == conn_id);
    }

    /// Get or create a connection to a peer.
    ///
    /// Reuses the stored connection for `addr` when it has not been closed;
    /// otherwise dials the peer, stores the new connection, and spawns a task
    /// that serves streams the peer opens on it.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::ConnectionFailed`] if the connection attempt
    /// cannot be started or the handshake fails.
    async fn get_connection(&self, addr: SocketAddr) -> Result<Connection> {
        // Clone the handle out of the map so the map guard is released before
        // anything else happens.
        let reusable = self
            .connections
            .get(&addr)
            .filter(|entry| entry.close_reason().is_none())
            .map(|entry| entry.value().clone());
        if let Some(conn) = reusable {
            return Ok(conn);
        }

        // Create new connection
        let conn = self
            .endpoint
            .connect(addr, "pollen")
            .map_err(|e| TransportError::ConnectionFailed {
                addr: addr.to_string(),
                reason: e.to_string(),
            })?
            .await
            .map_err(|e| TransportError::ConnectionFailed {
                addr: addr.to_string(),
                reason: e.to_string(),
            })?;

        self.connections.insert(addr, conn.clone());

        // Start handling incoming streams. The helper is moved into the task,
        // so it does not need to be wrapped in an `Arc`.
        let inner = self.clone_inner();
        let conn_clone = conn.clone();
        tokio::spawn(async move {
            inner.handle_connection(conn_clone).await;
        });

        Ok(conn)
    }

    /// Build a detached handle to the shared state needed to serve a
    /// connection from a spawned task that only has `&self` available.
    fn clone_inner(&self) -> QuicTransportInner {
        QuicTransportInner {
            connections: Arc::clone(&self.connections),
            incoming_tx: self.incoming_tx.clone(),
            pending_requests: Arc::clone(&self.pending_requests),
            shutdown: self.shutdown.clone(),
        }
    }
}

/// Cloned handles to the shared state of a `QuicTransport`.
///
/// Built by `QuicTransport::clone_inner` so that a spawned task can serve a
/// connection without holding a reference to the transport itself.
struct QuicTransportInner {
    /// Known connections, keyed by the peer's remote socket address.
    connections: Arc<DashMap<SocketAddr, Connection>>,
    /// Producer half of the inbound-envelope channel.
    incoming_tx: mpsc::Sender<Envelope>,
    /// Response waiters keyed by request id.
    pending_requests: Arc<DashMap<u64, oneshot::Sender<Envelope>>>,
    /// Broadcast channel used to tell background tasks to stop.
    shutdown: tokio::sync::broadcast::Sender<()>,
}

impl QuicTransportInner {
    /// Handle incoming streams on a connection.
    ///
    /// Spawns one task per accepted bidirectional stream and returns when the
    /// connection ends or shutdown is signalled. On exit the connection is
    /// removed from the connection map.
    async fn handle_connection(&self, conn: Connection) {
        let addr = conn.remote_address();
        let conn_id = conn.stable_id();
        let mut shutdown_rx = self.shutdown.subscribe();

        loop {
            tokio::select! {
                _ = shutdown_rx.recv() => {
                    break;
                }
                result = conn.accept_bi() => {
                    match result {
                        Ok((send, recv)) => {
                            let incoming_tx = self.incoming_tx.clone();
                            let pending = Arc::clone(&self.pending_requests);
                            tokio::spawn(async move {
                                if let Err(e) = handle_stream(send, recv, incoming_tx, pending).await {
                                    debug!("Stream error from {}: {}", addr, e);
                                }
                            });
                        }
                        Err(e) => {
                            debug!("Connection closed from {}: {}", addr, e);
                            break;
                        }
                    }
                }
            }
        }

        // Only evict the entry if it still refers to *this* connection. The
        // map is keyed by address, so a newer connection to the same peer may
        // have replaced it and must not be dropped from the map.
        self.connections
            .remove_if(&addr, |_, existing| existing.stable_id() == conn_id);
    }
}

async fn handle_stream(
    _send: SendStream,
    mut recv: RecvStream,
    incoming_tx: mpsc::Sender<Envelope>,
    pending_requests: Arc<DashMap<u64, oneshot::Sender<Envelope>>>,
) -> Result<()> {
    // Read length prefix
    let mut len_buf = [0u8; 4];
    recv.read_exact(&mut len_buf)
        .await
        .map_err(|e| TransportError::ReceiveFailed(e.to_string()))?;
    let len = u32::from_be_bytes(len_buf) as usize;

    // Read message
    let mut buf = vec![0u8; len];
    recv.read_exact(&mut buf)
        .await
        .map_err(|e| TransportError::ReceiveFailed(e.to_string()))?;

    let envelope = Envelope::deserialize(&buf)?;

    // Check if this is a response to a pending request
    // For simplicity, we use the timestamp as a request ID
    let request_id = envelope.timestamp.as_u128() as u64;
    if let Some((_, tx)) = pending_requests.remove(&request_id) {
        let _ = tx.send(envelope);
    } else {
        // Send to incoming channel
        let _ = incoming_tx.send(envelope).await;
    }

    Ok(())
}

#[async_trait]
impl Transport for QuicTransport {
    /// Send one envelope to `to` on a fresh bidirectional stream.
    ///
    /// The envelope is written as a 4-byte big-endian length followed by the
    /// serialized bytes, after which the stream is finished.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization fails, if the serialized envelope
    /// does not fit the 32-bit length prefix, if no connection can be
    /// established, or [`TransportError::SendFailed`] if opening, writing or
    /// finishing the stream fails.
    async fn send(&self, to: SocketAddr, envelope: Envelope) -> Result<()> {
        // Encode and validate up front so an unsendable message never costs a
        // connection or a stream.
        let data = envelope.serialize()?;
        let len = u32::try_from(data.len())
            .map_err(|_| {
                TransportError::SendFailed(format!(
                    "message of {} bytes does not fit the 32-bit length prefix",
                    data.len()
                ))
            })?
            .to_be_bytes();

        let conn = self.get_connection(to).await?;

        let (mut send, _recv) = conn
            .open_bi()
            .await
            .map_err(|e| TransportError::SendFailed(e.to_string()))?;

        send.write_all(&len)
            .await
            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
        send.write_all(&data)
            .await
            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
        send.finish()
            .map_err(|e| TransportError::SendFailed(e.to_string()))?;

        Ok(())
    }

    /// Send a request and wait up to 10 seconds for the matching response.
    ///
    /// The request id is the envelope timestamp truncated to 64 bits; a
    /// response is any inbound envelope carrying the same id.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Transport::send`]. Returns
    /// `PollenError::Timeout` if no response arrives in time and
    /// `PollenError::Cancelled` if the response waiter is dropped first.
    async fn send_recv(&self, to: SocketAddr, envelope: Envelope) -> Result<Envelope> {
        let request_id = envelope.timestamp.as_u128() as u64;
        let (tx, rx) = oneshot::channel();
        self.pending_requests.insert(request_id, tx);

        // Send the request; on failure no response can arrive, so drop the
        // waiter instead of leaving it in the map forever.
        if let Err(e) = self.send(to, envelope).await {
            self.pending_requests.remove(&request_id);
            return Err(e);
        }

        // Wait for response with timeout
        match tokio::time::timeout(Duration::from_secs(10), rx).await {
            Ok(Ok(response)) => Ok(response),
            // The sender was dropped, so the map entry is already gone.
            Ok(Err(_)) => Err(pollen_types::PollenError::Cancelled),
            Err(_) => {
                // Timed out: the waiter is still registered and must be removed.
                self.pending_requests.remove(&request_id);
                Err(pollen_types::PollenError::Timeout)
            }
        }
    }

    /// Take the receiver for inbound envelopes.
    ///
    /// # Panics
    ///
    /// Panics if the receiver has already been taken, or if the lock guarding
    /// it cannot be acquired immediately.
    fn incoming(&self) -> mpsc::Receiver<Envelope> {
        self.incoming_rx
            .try_write()
            .ok()
            .and_then(|mut guard| guard.take())
            .expect("incoming() can only be called once")
    }

    /// Address the endpoint is bound to.
    ///
    /// # Panics
    ///
    /// Panics if the endpoint cannot report its local address.
    fn local_addr(&self) -> SocketAddr {
        self.endpoint
            .local_addr()
            .expect("QUIC endpoint should report its local address")
    }

    /// Node id taken from the transport configuration.
    fn node_id(&self) -> NodeId {
        self.config.node_id
    }

    /// Signal background tasks to stop and close the endpoint.
    async fn shutdown(&self) {
        // A send error only means no task is currently subscribed.
        let _ = self.shutdown.send(());
        self.endpoint.close(0u32.into(), b"shutdown");
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Binding to port 0 must yield a transport bound to a real, non-zero port.
    #[tokio::test]
    async fn test_transport_creation() {
        let bind_addr = "127.0.0.1:0".parse().unwrap();
        let quic = QuicTransport::new(TransportConfig::new(bind_addr))
            .await
            .unwrap();

        let bound_port = quic.local_addr().port();
        assert!(bound_port > 0);

        quic.shutdown().await;
    }
}