use std::{
fmt,
net::SocketAddr,
};
use async_trait::async_trait;
use futures::FutureExt;
use tokio::task::JoinHandle;
use super::*;
use crate::{
client::PsqClient,
PsqError,
server::Endpoint,
util::{
send_quic_packets,
MAX_DATAGRAM_SIZE,
},
};
/// One side of an HTTP/3 UDP tunnel (`connect-udp`).
///
/// A background task reads datagrams from `socket` and forwards them into the
/// HTTP/3 stream `stream_id`. Datagrams arriving from the tunnel are written
/// back to `socket` by `process_datagram`.
pub struct UdpTunnel {
    /// HTTP/3 stream that carries this tunnel.
    stream_id: u64,
    /// Local UDP socket, shared with the background reader task.
    socket: Arc<UdpSocket>,
    /// `Some(..)` once `socket` is connected to its peer. The reader task fills
    /// it in on the first received datagram when it is still `None`. The stored
    /// address is only inspected with `is_some`/`is_none` in this module.
    clientaddr: Arc<Mutex<Option<SocketAddr>>>,
    /// Background reader task; `None` until started or after it has finished.
    taskhandle: Option<JoinHandle<Result<(), PsqError>>>,
}

impl UdpTunnel {
    /// Open a `connect-udp` tunnel to `host:port` through the proxy behind `pconn`.
    ///
    /// The request URL is `urlstr` resolved against the client's base URL,
    /// with `host` and `port` appended as path segments. A local UDP socket is
    /// bound at `localaddr`. Once the reader task runs, the first datagram it
    /// receives fixes the local peer that the tunnel serves.
    ///
    /// # Errors
    ///
    /// Fails in these cases:
    /// - the URL cannot be built or cannot take path segments;
    /// - the HTTP/3 request cannot be started;
    /// - the local socket cannot be bound;
    /// - the stream cannot be registered with `pconn`.
    pub async fn connect<'a>(
        pconn: &'a mut PsqClient,
        urlstr: &str,
        host: &str,
        port: u16,
        localaddr: SocketAddr,
    ) -> Result<&'a UdpTunnel, PsqError> {
        let mut url = pconn.get_url().join(urlstr)?;
        url.path_segments_mut()
            // `path_segments_mut` only fails for cannot-be-a-base URLs.
            .map_err(|_| PsqError::Custom(
                "URL cannot be a base; cannot append host and port path segments".into()
            ))?
            .extend(&[host, &port.to_string()]);
        let stream_id = start_connection(pconn, &url, "connect-udp").await?;
        let socket = UdpSocket::bind(localaddr).await?;
        let tunnel = UdpTunnel::new(stream_id, socket)?;
        let stream = pconn.add_stream(stream_id, Box::new(tunnel)).await?;
        Ok(UdpTunnel::get_from_dyn(stream))
    }

    /// Local address of the tunnel's UDP socket.
    pub fn sockaddr(&self) -> Result<SocketAddr, PsqError> {
        Ok(self.socket.local_addr()?)
    }

    /// Wrap a bound socket in a tunnel that has no reader task and no known peer yet.
    fn new(stream_id: u64, socket: UdpSocket) -> Result<UdpTunnel, PsqError> {
        Ok(UdpTunnel {
            stream_id,
            socket: Arc::new(socket),
            clientaddr: Arc::new(Mutex::new(None)),
            taskhandle: None,
        })
    }

    /// Recover the concrete tunnel from a stream that was registered as one.
    ///
    /// # Panics
    ///
    /// Panics if `stream` is not a `UdpTunnel`. The only caller passes the
    /// stream it has just registered from a `UdpTunnel`.
    fn get_from_dyn(stream: &dyn PsqStream) -> &UdpTunnel {
        stream
            .as_any()
            .downcast_ref::<UdpTunnel>()
            .expect("stream registered by UdpTunnel::connect must be a UdpTunnel")
    }

    /// Spawn the task that forwards datagrams from the local socket into the
    /// tunnel as HTTP/3 datagrams.
    ///
    /// If no peer is known yet, the sender of the first datagram is adopted and
    /// the socket is connected to it. Any error ends the task; the error is
    /// reported later through `check_task_error`. A previously started
    /// listener is aborted first, so two tasks never read the same socket.
    fn start_socket_listener(
        &mut self,
        qconn: &Arc<Mutex<quiche::Connection>>,
        qsocket: &Arc<UdpSocket>,
    ) {
        if let Some(previous) = self.taskhandle.take() {
            previous.abort();
        }
        let qconn = Arc::clone(qconn);
        let qsocket = Arc::clone(qsocket);
        let clientaddr = Arc::clone(&self.clientaddr);
        let socket = Arc::clone(&self.socket);
        let stream_id = self.stream_id;
        self.taskhandle = Some(tokio::spawn(async move {
            let mut buf = [0u8; MAX_DATAGRAM_SIZE];
            loop {
                let connected = clientaddr.lock().await.is_some();
                let n = if connected {
                    socket.recv(&mut buf).await?
                } else {
                    let (n, peer) = socket.recv_from(&mut buf).await?;
                    // Connect before publishing the peer: `process_datagram`
                    // treats `Some(..)` as "connected" and uses plain `send`,
                    // which fails on an unconnected socket.
                    socket.connect(peer).await?;
                    *clientaddr.lock().await = Some(peer);
                    n
                };
                debug!("Sending {} bytes to HTTP/3 UDP tunnel", n);
                send_h3_dgram(&mut *qconn.lock().await, stream_id, &buf[..n])?;
                send_quic_packets(&qconn, &qsocket).await?;
            }
        }));
    }

    /// Poll the reader task once, without blocking, and report how it ended.
    ///
    /// Returns `None` while the task is still running or if none was started.
    /// Once the task has finished, the handle is cleared. Its error, or a
    /// `PsqError::Custom("Task panicked")` for a join failure, is returned.
    fn check_task_error(&mut self) -> Option<PsqError> {
        let result = self.taskhandle.as_mut()?.now_or_never()?;
        self.taskhandle = None;
        match result {
            Ok(Ok(())) => {
                debug!("Background task completed successfully.");
                None
            }
            Ok(Err(e)) => {
                error!("Background task returned error: {}", e);
                Some(e)
            }
            Err(join_err) => {
                error!("Background task panicked: {}", join_err);
                Some(PsqError::Custom("Task panicked".to_string()))
            }
        }
    }
}
impl Drop for UdpTunnel {
    /// Abort the background socket reader, if one was started, so that it does
    /// not outlive the tunnel.
    fn drop(&mut self) {
        debug!("Dropping UdpTunnel");
        if let Some(task) = self.taskhandle.take() {
            task.abort();
        }
    }
}
#[async_trait]
impl PsqStream for UdpTunnel {
    /// Deliver a datagram received from the HTTP/3 tunnel to the local UDP peer.
    ///
    /// A failure of the background reader task is reported first. A datagram
    /// that arrives before any local peer is known is rejected with an error.
    async fn process_datagram(&mut self, buf: &[u8]) -> Result<(), PsqError> {
        if let Some(e) = self.check_task_error() {
            error!("UDP reader task failed: {}", e);
            return Err(e);
        }
        debug!("Received {} bytes from HTTP/3 UDP tunnel", buf.len());
        if self.clientaddr.lock().await.is_none() {
            return Err(PsqError::Custom(
                "Received datagram from UDP tunnel, but no consuming socket known".into(),
            ));
        }
        // A set `clientaddr` is treated as "socket is connected", so a plain
        // `send` (no explicit destination) is used.
        self.socket.send(buf).await?;
        Ok(())
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Ready while a background reader task handle is held.
    fn is_ready(&self) -> bool {
        self.taskhandle.is_some()
    }

    /// Start forwarding local datagrams once response headers arrive.
    ///
    /// The header list itself is not inspected.
    fn process_h3_headers(
        &mut self,
        conn: &Arc<Mutex<quiche::Connection>>,
        socket: &Arc<UdpSocket>,
        _list: &Vec<Header>,
    ) -> Result<(), PsqError> {
        self.start_socket_listener(conn, socket);
        Ok(())
    }

    /// Drain body data on the tunnel stream; the bytes are only logged.
    ///
    /// Reads until `recv_body` stops returning `Ok`. The QUIC connection lock
    /// is held for the whole drain.
    async fn process_h3_data(
        &mut self,
        h3_conn: &mut quiche::h3::Connection,
        conn: &Arc<Mutex<quiche::Connection>>,
        _socket: &Arc<UdpSocket>,
        buf: &mut [u8],
    ) -> Result<(), PsqError> {
        let c = &mut *conn.lock().await;
        while let Ok(read) = h3_conn.recv_body(c, self.stream_id, buf) {
            debug!(
                "got {} bytes of response data on stream {}",
                read, self.stream_id
            );
        }
        Ok(())
    }

    fn stream_id(&self) -> u64 {
        self.stream_id
    }
}
/// Server-side endpoint that accepts `connect-udp` requests and opens a UDP
/// socket towards the requested destination.
#[derive(Default)]
pub struct UdpEndpoint {
    /// Permission checked against each request's headers; `None` disables the check.
    permission: Option<String>,
}

impl UdpEndpoint {
    /// Create an endpoint that does not require any permission.
    pub fn new() -> UdpEndpoint {
        Self::default()
    }

    /// Require requests to be authorized for `permission` before a tunnel is opened.
    ///
    /// Takes `&str`; callers holding a `String` can pass `&s` unchanged.
    pub fn require_permission(&mut self, permission: &str) {
        self.permission = Some(permission.to_owned());
    }
}
#[async_trait]
impl Endpoint for UdpEndpoint {
    /// Handle a `connect-udp` request and open the server side of the tunnel.
    ///
    /// Every header goes through `check_common_headers`. If a permission is
    /// configured, headers are passed to `check_authorized` until one succeeds.
    /// The destination comes from the `:path` header: after skipping the first
    /// two path components, the next two are the host and the port. A UDP
    /// socket connected to that destination is wrapped in a `UdpTunnel`, and
    /// its reader task is started.
    ///
    /// # Errors
    ///
    /// - A `PsqError::HttpResponse` 401 if authorization is required but not granted.
    /// - A `PsqError::Custom` for a malformed `:path`, or a missing or zero port.
    /// - Any I/O error from binding or connecting the UDP socket.
    async fn process_request(
        &mut self,
        request: &[quiche::h3::Header],
        qconn: &Arc<Mutex<quiche::Connection>>,
        qsocket: &Arc<UdpSocket>,
        stream_id: u64,
        jwt_secret: &Vec<u8>,
    ) -> Result<(Option<Box<dyn PsqStream + Send + Sync + 'static>>, Vec<u8>), PsqError> {
        let mut destination: Option<(&str, u16)> = None;
        let mut authorized = self.permission.is_none();
        for hdr in request {
            check_common_headers(hdr, "connect-udp")?;
            if !authorized {
                if let Some(permission) = &self.permission {
                    authorized = check_authorized(hdr, permission, jwt_secret)?;
                }
            }
            if hdr.name() == b":path" {
                // The path is client-controlled: reject bad UTF-8 instead of panicking.
                let path = std::str::from_utf8(hdr.value())
                    .map_err(|_| PsqError::Custom("Invalid UTF-8 in path".to_string()))?;
                // For an absolute path, the two skipped components are the
                // root and the endpoint's own segment.
                let mut segments = std::path::Path::new(path).iter().skip(2);
                let host = segments
                    .next()
                    .ok_or_else(|| PsqError::Custom("Missing host in path".to_string()))?;
                let port = segments
                    .next()
                    .ok_or_else(|| PsqError::Custom("Missing port in path".to_string()))?;
                let host = host
                    .to_str()
                    .ok_or_else(|| PsqError::Custom("Invalid UTF-8 in host".to_string()))?;
                let port: u16 = port
                    .to_str()
                    .ok_or_else(|| PsqError::Custom("Invalid UTF-8 in port".to_string()))?
                    .parse()
                    .map_err(|_| PsqError::Custom("Invalid port number".to_string()))?;
                destination = Some((host, port));
            }
        }
        if !authorized {
            return Err(PsqError::HttpResponse(401, "Authorization required".to_string()));
        }
        // A missing `:path` and an explicit port 0 are both unusable destinations.
        let (desthost, destport) = match destination {
            Some((host, port)) if port != 0 => (host, port),
            _ => {
                return Err(PsqError::Custom(
                    "Could not parse destination address for the UDP tunnel".into(),
                ))
            }
        };
        debug!("Starting UDP tunnel to {}:{}", desthost, destport);
        // Accept IPv6 literals with or without URL-style brackets.
        let desthost = desthost
            .strip_prefix('[')
            .and_then(|h| h.strip_suffix(']'))
            .unwrap_or(desthost);
        // An IPv4 wildcard socket cannot reach an IPv6 destination. Bind `[::]`
        // for IPv6 literals; host names and IPv4 literals keep the IPv4 bind.
        let bindaddr = if desthost.parse::<std::net::Ipv6Addr>().is_ok() {
            "[::]:0"
        } else {
            "0.0.0.0:0"
        };
        let socket = UdpSocket::bind(bindaddr).await?;
        // A (host, port) tuple avoids formatting IPv6 literals into an
        // ambiguous "a:b::c:port" string.
        socket.connect((desthost, destport)).await?;
        let mut udptunnel = Box::new(UdpTunnel::new(stream_id, socket)?);
        // The socket is already connected, so mark the peer as known. The
        // reader task then uses `recv` instead of adopting the first sender.
        *udptunnel.clientaddr.lock().await = Some(udptunnel.socket.local_addr()?);
        udptunnel.start_socket_listener(qconn, qsocket);
        Ok((Some(udptunnel), Vec::new()))
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
impl fmt::Debug for UdpEndpoint {
    /// Emit a fixed label; the configured permission is deliberately left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("UdpEndpoint()")
    }
}