pasque 0.3.0

UDP and IP over HTTP/3
Documentation
use std::{
    fmt,
    net::SocketAddr,
};

use async_trait::async_trait;
use futures::FutureExt;

use tokio::task::JoinHandle;

use super::*;
use crate::{
    client::PsqClient,
    PsqError,
    server::Endpoint,
    util::{
        send_quic_packets, 
        MAX_DATAGRAM_SIZE,
    },
};

/// UDP tunnel over HTTP/3 connection.
///
/// The tunnel is tied to one HTTP/3 stream, the one opened by a
/// `connect-udp` CONNECT request.
///
/// See [RFC 9298](https://datatracker.ietf.org/doc/html/rfc9298) for more information.
pub struct UdpTunnel {
    /// HTTP/3 stream that carries this tunnel.
    stream_id: u64,
    /// Local socket whose datagrams are relayed into the tunnel.
    socket: Arc<UdpSocket>,
    /// Peer address of the user of the local socket, `None` until known.
    clientaddr: Arc<Mutex<Option<SocketAddr>>>,
    /// Background task relaying datagrams from `socket` into the tunnel.
    taskhandle: Option<JoinHandle<Result<(), PsqError>>>,
}

impl UdpTunnel {

    /// Connect a new UDP tunnel using given HTTP/3 connection, as indicated
    /// with `pconn`. 
    /// 
    /// Sends an HTTP/3 CONNECT request, and if successful, returns the created
    /// UdpTunnel object in response that can be used for further tunnel/proxy
    /// operations. Blocks until response to the CONNECT request is processed.
    /// 
    /// `urlstr` is URL path of the UDP proxy endpoint at server. It is appended
    /// to the base URL used when establishing connection. `urlstr` should not
    /// contain the target address parameters, but these are given in `host` and
    /// `port`.
    /// 
    /// The established tunnel is connected to UDP socket that will be bound at
    /// `localaddr`. Wildcard addresses are allowed. You can query the actual
    /// address using the [`UdpTunnel::sockaddr()`] function.
    pub async fn connect<'a>(
        pconn: &'a mut PsqClient,
        urlstr: &str,
        host: &str,
        port: u16,
        localaddr: SocketAddr,
    ) -> Result<&'a UdpTunnel, PsqError> {

        // Add host, port to URL
        let mut url = pconn.get_url().join(urlstr)?;
        url.path_segments_mut()
            .map_err(|_| PsqError::Custom(
                "Base URL cannot have a non-empty fragment".into()
            ))?
            .extend(&[host, &port.to_string()]);

        let stream_id = start_connection(
            pconn,
            &url,
            "connect-udp"
        ).await?;

        let socket = UdpSocket::bind(localaddr).await?;

        // Blocks until request is replied and tunnel is set up
        let ret = pconn.add_stream(
            stream_id,
            Box::new(UdpTunnel {
                stream_id,
                socket: Arc::new(socket),
                clientaddr: Arc::new(Mutex::new(None)),
                taskhandle: None,
            })
        ).await;
        match ret {
            Ok(stream) => {
                Ok(UdpTunnel::get_from_dyn(stream))
            },
            Err(e) => Err(e)
        }
    }


    /// Returns the address of UDP socket connected to the HTTP/3 tunnel.
    /// 
    /// All datagrams sent to this address are forwarded to the tunnel, and all
    /// datagrams coming from the tunnel can be read using this socket.
    pub fn sockaddr(&self) -> Result<std::net::SocketAddr, PsqError> {
        Ok(self.socket.local_addr()?)
    }


    fn new(
        stream_id: u64,
        socket: UdpSocket,
    ) -> Result<UdpTunnel, PsqError> {
        Ok(UdpTunnel {
            stream_id,
            socket: Arc::new(socket),
            clientaddr: Arc::new(Mutex::new(None)),
            taskhandle: None,
         })
    }


    fn get_from_dyn(stream: &Box<dyn PsqStream>) -> &UdpTunnel {
        stream.as_any().downcast_ref::<UdpTunnel>().unwrap()
    }


    fn start_socket_listener(
        &mut self,
        qconn: &Arc<Mutex<quiche::Connection>>,
        qsocket: &Arc::<UdpSocket>,
    ) {

        let qconn = Arc::clone(qconn);
        let qsocket = Arc::clone(qsocket);
        let clientaddr = Arc::clone(&self.clientaddr);
        let socket = self.socket.clone();

        let stream_id = self.stream_id;

        self.taskhandle = Some(tokio::spawn(async move {
            let mut buf = [0u8; MAX_DATAGRAM_SIZE];
            loop {
                let defined;
                {
                    defined = clientaddr.lock().await.is_some();
                }
                let n = match defined {
                    true => socket.recv(&mut buf).await?,
                    false => {
                        let ret = socket.recv_from(&mut buf).await?;
                        *clientaddr.lock().await = Some(ret.1);
                        socket.connect(ret.1).await?;
                        ret.0
                    }
                };
                debug!("Sending {} bytes to HTTP/3 UDP tunnel", n);
                send_h3_dgram(&mut *qconn.lock().await, stream_id, &buf[..n])?;
                send_quic_packets(&qconn, &qsocket).await?;
            };
        }));
    }


    fn check_task_error(&mut self) -> Option<PsqError> {
        if let Some(handle) = &mut self.taskhandle {
            if let Some(result) = handle.now_or_never() {
                match result {
                    Ok(Ok(())) => {
                        debug!("Background task completed successfully.");
                        self.taskhandle = None;
                        None
                    }
                    Ok(Err(e)) => {
                        error!("Background task returned error: {}", e);
                        self.taskhandle = None;
                        Some(e)
                    }
                    Err(join_err) => {
                        error!("Background task panicked: {}", join_err);
                        self.taskhandle = None;
                        Some(PsqError::Custom("Task panicked".to_string()))
                    }
                }
            } else {
                // Task still running
                None
            }
        } else {
            // No task running
            None
        }
    }
}

impl Drop for UdpTunnel {
    /// Aborts the background listener task, if one is running, so it does
    /// not keep relaying after the tunnel is gone.
    fn drop(&mut self) {
        debug!("Dropping UdpTunnel");
        if let Some(task) = self.taskhandle.take() {
            task.abort();
        }
    }
}

#[async_trait]
impl PsqStream for UdpTunnel {
    /// Forwards a datagram received from the HTTP/3 tunnel to the local UDP
    /// peer.
    ///
    /// # Errors
    ///
    /// Returns the listener task's error if that task has terminated. Also
    /// fails if no local peer is known yet, or if the socket send fails.
    async fn process_datagram(&mut self, buf: &[u8]) -> Result<(), PsqError> {

        // check if Tokio reader task is still running
        if let Some(e) = self.check_task_error() {
            error!("UDP reader task failed: {}", e);
            return Err(e)
        }

        debug!("Received {} bytes from HTTP/3 UDP tunnel", buf.len());

        // send() needs a connected socket. The connect happens together with
        // setting clientaddr (start_socket_listener() or
        // UdpEndpoint::process_request()). Without a known peer there is
        // nowhere to deliver the datagram.
        if self.clientaddr.lock().await.is_none() {
            return Err(PsqError::Custom(
                "Received datagram from UDP tunnel, but no consuming socket known".into()))
        }

        self.socket.send(buf).await?;

        Ok(())
    }

    fn as_any(&self) -> &dyn Any {
        self
    }


    /// The tunnel is ready while its listener task handle is present. The
    /// handle is cleared once `check_task_error()` sees the task finish.
    fn is_ready(&self) -> bool {
        self.taskhandle.is_some()
    }


    /// Starts relaying from the local socket once response headers arrive.
    ///
    /// NOTE(review): the response headers (e.g. `:status`) are not inspected
    /// here. Presumably the caller rejects non-2xx responses before this
    /// point; confirm.
    fn process_h3_headers(
        &mut self,
        conn: &Arc<Mutex<quiche::Connection>>,
        socket: &Arc<UdpSocket>,
        _list: &Vec<Header>,
    ) -> Result<(), PsqError> {
        self.start_socket_listener(conn, socket);
        Ok(())
    }


    /// Drains body data on the tunnel's stream. It is only logged, because
    /// tunnel payload travels as HTTP datagrams, not in the stream body.
    /// Reading stops at the first `recv_body` error, including "no more data".
    async fn process_h3_data(
        &mut self,
        h3_conn: &mut quiche::h3::Connection,
        conn: &Arc<Mutex<quiche::Connection>>,
        _socket: &Arc<UdpSocket>,
        buf: &mut [u8],
    ) -> Result<(), PsqError> {
        let c = &mut *conn.lock().await;
        while let Ok(read) =
            h3_conn.recv_body(c, self.stream_id, buf)
        {
            debug!(
                "got {} bytes of response data on stream {}",
                read, self.stream_id
            );
        }
        Ok(())
    }


    fn stream_id(&self) -> u64 {
        self.stream_id
    }
}


/// Server endpoint for UDP tunnel over HTTP/3
/// (see [RFC 9298](https://datatracker.ietf.org/doc/html/rfc9298)).
///
/// `UdpEndpoint::default()` is the same as [`UdpEndpoint::new()`].
#[derive(Default)]
pub struct UdpEndpoint {
    /// Permission label required to be present in incoming JWT token.
    permission: Option<String>,
}

impl UdpEndpoint {

    /// Create a new UDP tunnel endpoint.
    ///
    /// The tunnel destination is not fixed here. Each request names it: the
    /// endpoint path is appended with a DNS name or IP address and a port.
    /// For example, if the endpoint prefix is "udp":
    ///
    /// `https://someaddress.org/udp/192.0.2.6/443/`
    ///
    /// No permission is required initially; see
    /// [`UdpEndpoint::require_permission()`].
    pub fn new(
    ) -> UdpEndpoint {

        UdpEndpoint {
            permission: None,
        }
    }

    /// Require a permission label in the incoming JWT token.
    ///
    /// A request is rejected as unauthorized if it has no JWT token, or if
    /// the token's claims do not include this permission label.
    ///
    /// Accepts `&str`. Existing callers passing `&String` still compile
    /// through deref coercion.
    pub fn require_permission(
        &mut self,
        permission: &str,
    ) {
        self.permission = Some(permission.to_string());
    }
}

#[async_trait]
impl Endpoint for UdpEndpoint {
    async fn process_request(
        &mut self,
        request: &[quiche::h3::Header],
        qconn: &Arc<Mutex<quiche::Connection>>,
        qsocket: &Arc<UdpSocket>,
        stream_id: u64,
        jwt_secret: &Vec<u8>,
    ) -> Result<(Option<Box<dyn PsqStream + Send + Sync + 'static>>, Vec<u8>), PsqError> {

        let mut desthost = "";
        let mut destport: u16 = 0;

        let mut authorized = self.permission.is_none();
        for hdr in request {
            check_common_headers(hdr, "connect-udp")?;
            authorized = authorized ||
                check_authorized(hdr, self.permission.as_ref().unwrap(), jwt_secret)?;

            if hdr.name() == b":path" {
                let path = std::path::Path::new(
                    // UTF8 validity was already checked earlier
                    std::str::from_utf8(hdr.value()).unwrap()
                );

                let mut segments = path.iter();

                // Skip the first segment (like "udp")
                // TODO: implement properly
                segments.next();
                segments.next();

                let host = segments.next()
                    .ok_or_else(|| PsqError::Custom("Missing host in path".to_string()))?;
                let port = segments.next()
                    .ok_or_else(|| PsqError::Custom("Missing port in path".to_string()))?;

                desthost = host.to_str().ok_or_else(|| PsqError::Custom("Invalid UTF-8 in host".to_string()))?;
                let port_str = port.to_str().ok_or_else(|| PsqError::Custom("Invalid UTF-8 in port".to_string()))?;
                destport = port_str.parse()
                    .map_err(|_| PsqError::Custom("Invalid port number".to_string()))?;

            }
        }

        if !authorized {
            return Err(
                PsqError::HttpResponse(401, "Authorization required".to_string())
            )
        }

        if destport == 0 {
            return Err(PsqError::Custom(
                "Could not parse destination address for the UDP tunnel".into()
            ))
        }

        debug!("Starting UDP tunnel to {}:{}", desthost, destport);

        // Open UDP socket to given address
        let socket = UdpSocket::bind("0.0.0.0:0").await?;
        socket.connect(format!("{}:{}", desthost, destport)).await?;

        let mut udptunnel = Box::new(UdpTunnel::new(
            stream_id,
            socket,
        )?);
        {
            *udptunnel.clientaddr.lock().await = Some(udptunnel.socket.local_addr().unwrap());
        }
        udptunnel.start_socket_listener(&qconn, &qsocket);

        let body = Vec::<u8>::new();
        Ok((Some(udptunnel), body))
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl fmt::Debug for UdpEndpoint {
    /// Writes a fixed `UdpEndpoint()` marker; the permission label is not shown.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("UdpEndpoint()")
    }
}