use std::net::SocketAddr;
use async_trait::async_trait;
use futures::FutureExt;
use tokio::task::JoinHandle;
use super::*;
use crate::{
client::PsqClient,
PsqError,
server::Endpoint,
util::{
hdrs_to_strings,
send_quic_packets,
MAX_DATAGRAM_SIZE,
},
};
/// State of one HTTP/3 `connect-udp` tunnel stream.
///
/// Created by the client in `UdpTunnel::connect` and by the server in
/// `UdpEndpoint::process_request`.
pub struct UdpTunnel {
    /// HTTP/3 stream that carried the CONNECT request; H3 datagrams for the
    /// tunnel are sent with this id as well.
    stream_id: u64,
    /// Local UDP socket whose traffic is relayed through the tunnel.
    socket: Arc<UdpSocket>,
    /// Address of the consumer of the tunnel's datagrams, `None` until known.
    /// The server side stores the socket's own local address here.
    clientaddr: Arc<Mutex<Option<SocketAddr>>>,
    /// Background task forwarding socket reads into the tunnel, once started.
    taskhandle: Option<JoinHandle<Result<(), PsqError>>>,
}
impl UdpTunnel {
pub async fn connect<'a>(
pconn: &'a mut PsqClient,
urlstr: &str,
host: &str,
port: u16,
localaddr: SocketAddr,
) -> Result<&'a UdpTunnel, PsqError> {
let mut url = pconn.get_url().join(urlstr)?;
url.path_segments_mut()
.map_err(|_| PsqError::Custom(
"Base URL cannot have a non-empty fragment".into()
))?
.extend(&[host, &port.to_string()]);
let stream_id = start_connection(
pconn,
&url,
"connect-udp"
).await?;
let socket = UdpSocket::bind(localaddr).await?;
let ret = pconn.add_stream(
stream_id,
Box::new(UdpTunnel {
stream_id,
socket: Arc::new(socket),
clientaddr: Arc::new(Mutex::new(None)),
taskhandle: None,
})
).await;
match ret {
Ok(stream) => {
Ok(UdpTunnel::get_from_dyn(stream))
},
Err(e) => Err(e)
}
}
pub fn sockaddr(&self) -> Result<std::net::SocketAddr, PsqError> {
Ok(self.socket.local_addr()?)
}
fn new(
stream_id: u64,
socket: UdpSocket,
) -> Result<UdpTunnel, PsqError> {
Ok(UdpTunnel {
stream_id,
socket: Arc::new(socket),
clientaddr: Arc::new(Mutex::new(None)),
taskhandle: None,
})
}
fn get_from_dyn(stream: &Box<dyn PsqStream>) -> &UdpTunnel {
stream.as_any().downcast_ref::<UdpTunnel>().unwrap()
}
fn start_socket_listener(
&mut self,
qconn: &Arc<Mutex<quiche::Connection>>,
qsocket: &Arc::<UdpSocket>,
) {
let qconn = Arc::clone(qconn);
let qsocket = Arc::clone(qsocket);
let clientaddr = Arc::clone(&self.clientaddr);
let socket = self.socket.clone();
let stream_id = self.stream_id;
let handle = tokio::spawn(async move {
let mut buf = [0u8; MAX_DATAGRAM_SIZE];
loop {
let defined;
{
defined = clientaddr.lock().await.is_some();
}
let n = match defined {
true => socket.recv(&mut buf).await?,
false => {
let ret = socket.recv_from(&mut buf).await?;
debug!("hee");
*clientaddr.lock().await = Some(ret.1);
socket.connect(ret.1).await?;
ret.0
}
};
debug!("Sending {} bytes to HTTP/3 UDP tunnel", n);
send_h3_dgram(&mut *qconn.lock().await, stream_id, &buf[..n])?;
send_quic_packets(&qconn, &qsocket).await?;
};
});
self.taskhandle = Some(handle);
}
fn check_task_error(&mut self) -> Option<PsqError> {
if let Some(handle) = &mut self.taskhandle {
if let Some(result) = handle.now_or_never() {
match result {
Ok(Ok(())) => {
debug!("Background task completed successfully.");
self.taskhandle = None;
None
}
Ok(Err(e)) => {
error!("Background task returned error: {}", e);
self.taskhandle = None;
Some(e)
}
Err(join_err) => {
error!("Background task panicked: {}", join_err);
self.taskhandle = None;
Some(PsqError::Custom("Task panicked".to_string()))
}
}
} else {
None
}
} else {
None
}
}
}
#[async_trait]
impl PsqStream for UdpTunnel {
    /// Forwards a datagram received from the HTTP/3 tunnel to the local
    /// consumer of the UDP socket.
    ///
    /// Fails with the listener task's error if that task has terminated, or
    /// when no consumer address is known yet.
    async fn process_datagram(&mut self, buf: &[u8]) -> Result<(), PsqError> {
        if let Some(e) = self.check_task_error() {
            error!("UDP reader task failed: {}", e);
            return Err(e)
        }
        debug!("Received {} bytes from HTTP/3 UDP tunnel", buf.len());
        if self.clientaddr.lock().await.is_none() {
            return Err(PsqError::Custom(
                "Received datagram from UDP tunnel, but no consuming socket known".into()))
        }
        // Relies on the socket being connected to the consumer (done in
        // `start_socket_listener` on the client, before binding on the server).
        self.socket.send(buf).await?;
        Ok(())
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The tunnel is ready once its socket listener task has been started.
    fn is_ready(&self) -> bool {
        self.taskhandle.is_some()
    }

    /// Handles an HTTP/3 event on the tunnel's CONNECT stream.
    ///
    /// A 200 response starts the socket listener; any other status is
    /// returned as `PsqError::HttpResponse`. Body data is drained and logged.
    /// A stream reset closes the QUIC connection.
    async fn process_h3_response(
        &mut self,
        h3_conn: &mut quiche::h3::Connection,
        conn: &Arc<Mutex<quiche::Connection>>,
        socket: &Arc<UdpSocket>,
        event: quiche::h3::Event,
        buf: &mut [u8],
    ) -> Result<(), PsqError> {
        match event {
            quiche::h3::Event::Headers { list, .. } => {
                info!(
                    "got response headers {:?} on stream id {}",
                    hdrs_to_strings(&list),
                    self.stream_id
                );
                let status = get_h3_status(&list)?;
                if status != 200 {
                    return Err(PsqError::HttpResponse(status, "CONNECT request unsuccesful".to_string()))
                }
                self.start_socket_listener(conn, socket);
            },
            quiche::h3::Event::Data => {
                let c = &mut *conn.lock().await;
                while let Ok(read) =
                    h3_conn.recv_body(c, self.stream_id, buf)
                {
                    debug!(
                        "got {} bytes of response data on stream {}",
                        read, self.stream_id
                    );
                }
            },
            quiche::h3::Event::Finished => {
                info!(
                    "UdpTunnel stream finished!"
                );
            },
            quiche::h3::Event::Reset(e) => {
                error!(
                    "request was reset by peer with {}, closing...",
                    e
                );
                let c = &mut *conn.lock().await;
                // quiche returns `Err(Done)` when the connection is already
                // closed (e.g. an earlier reset closed it); that must not panic.
                if let Err(err) = c.close(true, 0x100, b"kthxbye") {
                    debug!("Connection close after reset not applied: {:?}", err);
                }
            },
            quiche::h3::Event::PriorityUpdate => unreachable!(),
            quiche::h3::Event::GoAway => {
                info!("GOAWAY");
            },
        }
        Ok(())
    }
}
/// Server endpoint that accepts `connect-udp` requests and opens UDP tunnels.
pub struct UdpEndpoint {}
impl UdpEndpoint {
    /// Creates a boxed `connect-udp` endpoint; construction itself never fails.
    pub fn new() -> Result<Box<dyn Endpoint>, PsqError> {
        let endpoint: Box<dyn Endpoint> = Box::new(UdpEndpoint {});
        Ok(endpoint)
    }
}
/// Extracts the target host and port from a `connect-udp` request path.
///
/// The client (`UdpTunnel::connect`) appends the host and then the port as
/// the last two path segments of its URL template, so both are read from the
/// end of the path. A query string and empty segments (doubled or trailing
/// slashes) are ignored, and square brackets around the host are stripped so
/// `[::1]` yields the bare IPv6 literal.
fn parse_udp_target(path: &str) -> Result<(&str, u16), PsqError> {
    let path = path.split('?').next().unwrap_or(path);
    let mut segments = path.rsplit('/').filter(|s| !s.is_empty());
    let port = segments.next()
        .ok_or_else(|| PsqError::Custom("Missing port in path".to_string()))?;
    let host = segments.next()
        .ok_or_else(|| PsqError::Custom("Missing host in path".to_string()))?;
    let host = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host);
    let port = port.parse()
        .map_err(|_| PsqError::Custom("Invalid port number".to_string()))?;
    Ok((host, port))
}

#[async_trait]
impl Endpoint for UdpEndpoint {
    /// Accepts a `connect-udp` request and opens a UDP socket to its target.
    ///
    /// Every header is validated with `check_common_headers`; the target is
    /// taken from `:path`. The returned `UdpTunnel` already runs its socket
    /// listener, and the response body is empty.
    ///
    /// # Errors
    ///
    /// Fails on invalid headers, a missing/malformed target (including a
    /// non-UTF-8 `:path` or port 0), or socket bind/connect errors.
    async fn process_request(
        &mut self,
        request: &[quiche::h3::Header],
        qconn: &Arc<Mutex<quiche::Connection>>,
        qsocket: &Arc<UdpSocket>,
        stream_id: u64,
    ) -> Result<(Option<Box<dyn PsqStream + Send + Sync + 'static>>, Vec<u8>), PsqError> {
        let mut target = None;
        for hdr in request {
            check_common_headers(hdr, "connect-udp")?;
            if hdr.name() == b":path" {
                // Header bytes come from the peer: reject bad UTF-8, don't panic.
                let path = std::str::from_utf8(hdr.value())
                    .map_err(|_| PsqError::Custom("Invalid UTF-8 in path".to_string()))?;
                target = Some(parse_udp_target(path)?);
            }
        }
        let (desthost, destport) = match target {
            Some((host, port)) if port != 0 => (host, port),
            _ => return Err(PsqError::Custom(
                "Could not parse destination address for the UDP tunnel".into()
            )),
        };
        debug!("Starting UDP tunnel to {}:{}", desthost, destport);
        // An IPv4 wildcard socket cannot reach an IPv6 target, so bind the
        // wildcard of the matching family for IPv6 literals.
        let bindaddr = if desthost.parse::<std::net::Ipv6Addr>().is_ok() {
            "[::]:0"
        } else {
            "0.0.0.0:0"
        };
        let socket = UdpSocket::bind(bindaddr).await?;
        // A (host, port) tuple handles IPv6 literals without brackets, unlike
        // a "host:port" string.
        socket.connect((desthost, destport)).await?;
        let mut udptunnel = Box::new(UdpTunnel::new(stream_id, socket)?);
        // The socket is already connected to the target, so mark the consumer
        // as known; the listener then uses plain `recv`.
        let localaddr = udptunnel.socket.local_addr()?;
        *udptunnel.clientaddr.lock().await = Some(localaddr);
        udptunnel.start_socket_listener(qconn, qsocket);
        Ok((Some(udptunnel), Vec::new()))
    }
}