use std::{
any::Any,
sync::Arc,
};
use async_trait::async_trait;
use quiche::h3::{Header, NameValue};
use tokio::{
net::UdpSocket,
sync::Mutex,
};
use crate::{
client::PsqClient,
jwt::Jwt,
PsqError,
VERSION_IDENTIFICATION,
util::MAX_DATAGRAM_SIZE,
};
/// Capsule types this crate understands, keyed by their wire type value.
///
/// NOTE(review): the values match the Connect-IP (RFC 9484) capsule registry
/// (ADDRESS_ASSIGN = 0x01, ROUTE_ADVERTISEMENT = 0x03) — confirm that is the
/// intended source.
pub(crate) enum Capsule {
    /// Capsule type 0x01.
    AddressAssign = 0x01,
    /// Capsule type 0x03.
    RouteAdvertisement = 0x03,
}
impl TryFrom<u64> for Capsule {
    type Error = PsqError;

    /// Maps a wire capsule type to the matching [`Capsule`] variant.
    ///
    /// # Errors
    /// Returns `PsqError::H3Capsule` for any type value other than 0x01 or 0x03.
    fn try_from(value: u64) -> Result<Self, Self::Error> {
        let capsule = match value {
            0x01 => Self::AddressAssign,
            0x03 => Self::RouteAdvertisement,
            unknown => {
                return Err(PsqError::H3Capsule(format!(
                    "Unknown capsule type: {}",
                    unknown
                )));
            }
        };
        Ok(capsule)
    }
}
/// A stream carried over a pasque HTTP/3 connection.
///
/// NOTE(review): implementations presumably live in the `iptunnel`,
/// `filestream` and `udptunnel` submodules declared in this module — confirm.
///
/// The `Any` supertrait together with [`PsqStream::as_any`] lets holders of a
/// `dyn PsqStream` downcast to the concrete type; `Send + Sync` allow the
/// stream object to be moved and shared across threads/tasks.
#[async_trait]
pub trait PsqStream: Any + Send + Sync {
    /// Handles a datagram `buf` delivered to this stream.
    ///
    /// NOTE(review): presumably `buf` is the payload that follows the header
    /// parsed by `process_h3_datagram` — confirm with callers.
    async fn process_datagram(&mut self, buf: &[u8]) -> Result<(), PsqError>;

    /// Reports whether the stream is ready; the exact condition is defined by
    /// each implementation.
    fn is_ready(&self) -> bool;

    /// Returns `self` as `&dyn Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Handles the HTTP/3 header list `list` received for this stream.
    ///
    /// `conn` is the shared QUIC connection and `socket` the shared UDP socket,
    /// available to implementations that need to send in reaction to the
    /// headers.
    fn process_h3_headers(
        &mut self,
        conn: &Arc<Mutex<quiche::Connection>>,
        socket: &Arc<UdpSocket>,
        list: &Vec<Header>,
    ) -> Result<(), PsqError>;

    /// Handles HTTP/3 data for this stream using the HTTP/3 connection
    /// `h3_conn` and the shared QUIC connection/UDP socket.
    ///
    /// NOTE(review): `buf` is presumably a scratch buffer for reading body
    /// bytes through `h3_conn` — confirm with implementations.
    async fn process_h3_data(
        &mut self,
        h3_conn: &mut quiche::h3::Connection,
        conn: &Arc<Mutex<quiche::Connection>>,
        socket: &Arc<UdpSocket>,
        buf: &mut [u8],
    ) -> Result<(), PsqError>;

    /// Returns this stream's ID.
    ///
    /// NOTE(review): presumably the HTTP/3 request stream ID returned by
    /// `start_connection` — confirm.
    fn stream_id(&self) -> u64;
}
/// Builds the HTTP/3 request header list for `method` targeting `url`.
///
/// The list always contains the `:method`, `:scheme`, `:authority` and
/// `:path` pseudo-headers, a `user-agent` header and `capsule-protocol: ?1`.
/// A `:protocol` pseudo-header is added when `protocol` is non-empty (as used
/// by extended CONNECT), and `authorization: Bearer <token>` when `token` is
/// `Some`.
///
/// `:authority` is `host[:port]`: the port is included whenever the URL
/// carries a non-default port, because the authority component of a URI
/// includes it. `:path` is the URL path plus `?query` when a query is present.
///
/// # Panics
/// Panics if `url` has no host (e.g. a `data:` URL); a CONNECT request cannot
/// be addressed without an authority.
fn prepare_h3_request(
    method: &str,
    protocol: &str,
    url: &url::Url,
    token: &Option<String>,
) -> Vec<quiche::h3::Header> {
    let mut path = String::from(url.path());
    if let Some(query) = url.query() {
        path.push('?');
        path.push_str(query);
    }

    let host = url
        .host_str()
        .expect("request URL must have a host to build :authority");
    // `Url::port` is `None` both when no port was given and when the given
    // port is the scheme's default, so default ports are never duplicated.
    let authority = match url.port() {
        Some(port) => format!("{}:{}", host, port),
        None => host.to_string(),
    };

    let mut headers = vec![
        quiche::h3::Header::new(b":method", method.as_bytes()),
        quiche::h3::Header::new(b":scheme", url.scheme().as_bytes()),
        quiche::h3::Header::new(b":authority", authority.as_bytes()),
        quiche::h3::Header::new(b":path", path.as_bytes()),
        quiche::h3::Header::new(
            b"user-agent",
            format!("pasque/{}", VERSION_IDENTIFICATION).as_bytes(),
        ),
        quiche::h3::Header::new(b"capsule-protocol", b"?1"),
    ];
    if !protocol.is_empty() {
        headers.push(quiche::h3::Header::new(b":protocol", protocol.as_bytes()));
    }
    if let Some(token) = token {
        headers.push(quiche::h3::Header::new(
            b"authorization",
            format!("Bearer {}", token).as_bytes(),
        ));
    }
    headers
}
/// Parses the HTTP/3 datagram header at the start of `buf`.
///
/// Reads two varints, the quarter stream ID and the context ID, and returns
/// `(stream_id, payload_offset)`. `stream_id` is the quarter stream ID
/// multiplied by 4. `payload_offset` is the index in `buf` just past both
/// varints.
///
/// NOTE(review): the context ID is read but not checked, so callers receive
/// the payload whatever its context ID is.
///
/// # Errors
/// Fails if `buf` is too short to contain both varints.
pub(crate) fn process_h3_datagram(buf: &[u8]) -> Result<(u64, usize), PsqError> {
    let mut reader = octets::Octets::with_slice(buf);
    let quarter_stream_id = reader.get_varint()?;
    let _context_id = reader.get_varint()?;
    let payload_offset = reader.off();
    Ok((quarter_stream_id * 4, payload_offset))
}
/// Validates one request header against what every stream endpoint requires.
///
/// - `:method` must be `CONNECT` (otherwise 405).
/// - `:protocol` must equal `protocol` (otherwise 406).
/// - `capsule-protocol` must be `?1` (otherwise 406).
///
/// Any other header is accepted unchanged.
///
/// # Errors
/// Returns `PsqError::HttpResponse` with the status code and message listed
/// above when the header has one of those names but an unexpected value.
fn check_common_headers(
    header: &quiche::h3::Header,
    protocol: &str,
) -> Result<(), PsqError> {
    let value = header.value();
    let rejection = match header.name() {
        b":method" if value != b"CONNECT" => Some(PsqError::HttpResponse(
            405,
            "Only CONNECT method supported for this endpoint".to_string(),
        )),
        b":protocol" if value != protocol.as_bytes() => Some(PsqError::HttpResponse(
            406,
            format!("Only protocol '{}' supported at this endpoint", protocol),
        )),
        b"capsule-protocol" if value != b"?1" => Some(PsqError::HttpResponse(
            406,
            "Unsupported capsule protocol".to_string(),
        )),
        _ => None,
    };
    match rejection {
        Some(err) => Err(err),
        None => Ok(()),
    }
}
/// Checks whether `header` is an `authorization` header carrying a bearer
/// token that grants `permission`.
///
/// The token is verified with `Jwt::verify_token` against `jwt_secret`. The
/// authentication scheme name is matched case-insensitively, as HTTP requires,
/// so `Bearer`, `bearer` and `BEARER` are all accepted.
///
/// # Returns
/// - `Ok(true)` if the header carries a valid bearer token whose claims include
///   `permission`.
/// - `Ok(false)` if the header is not `authorization`, or if its value does not
///   use the Bearer scheme. The caller decides how to treat a missing token.
///
/// # Errors
/// - `PsqError::HttpResponse(403, ..)` if the token is valid but lacks
///   `permission`.
/// - `PsqError::HttpResponse(401, ..)` if the token fails verification.
fn check_authorized(
    header: &quiche::h3::Header,
    permission: &String,
    jwt_secret: &Vec<u8>,
) -> Result<bool, PsqError> {
    if header.name() != b"authorization" {
        return Ok(false);
    }
    let value = String::from_utf8_lossy(header.value());
    let token = match bearer_token(&value) {
        Some(token) => token,
        None => return Ok(false),
    };

    match Jwt::verify_token(token, jwt_secret) {
        Ok(token) => {
            if token.claims.has_permission(permission) {
                info!("Received valid token: {:?}", token.claims);
                Ok(true)
            } else {
                info!("Received token with insufficient permissions: {:?}",
                    token.claims
                );
                Err(PsqError::HttpResponse(
                    403,
                    "Permission denied".to_string(),
                ))
            }
        }
        Err(err) => {
            warn!("Received invalid JWT token");
            Err(PsqError::HttpResponse(
                401,
                format!("Invalid token: {}", err),
            ))
        }
    }
}

/// Extracts the credentials from an `authorization` field value of the form
/// `<scheme> <credentials>` when the scheme is `Bearer`, compared
/// case-insensitively.
///
/// Returns `None` when the value has no space-separated scheme or the scheme
/// is not `Bearer`. Extra spaces before the credentials are skipped.
fn bearer_token(value: &str) -> Option<&str> {
    let (scheme, credentials) = value.split_once(' ')?;
    if scheme.eq_ignore_ascii_case("Bearer") {
        Some(credentials.trim_start_matches(' '))
    } else {
        None
    }
}
async fn start_connection<'a>(
pconn: &'a mut PsqClient,
url: &url::Url,
protocol: &str,
) -> Result<u64, PsqError> {
let req = prepare_h3_request(
"CONNECT",
protocol,
&url,
pconn.token(),
);
info!("sending HTTP request {:?}", req);
let a = pconn.connection();
let mut conn = a.lock().await;
let h3_conn = pconn.h3_connection().as_mut().unwrap();
let stream_id = h3_conn
.send_request(&mut *conn, &req, false)?;
Ok(stream_id)
}
/// Sends `buf` as an HTTP/3 datagram associated with request stream
/// `stream_id`.
///
/// The datagram is laid out as three parts, in order:
/// - the quarter stream ID (`stream_id / 4`) as a varint;
/// - a context ID varint of 0;
/// - `buf`.
///
/// The quarter stream ID uses a 2-byte varint when it fits (values up to
/// 16383) and the shortest wider encoding otherwise. The framed datagram is
/// assembled in a `MAX_DATAGRAM_SIZE`-byte stack buffer.
///
/// # Errors
/// Returns an error if the framed datagram does not fit in `MAX_DATAGRAM_SIZE`
/// bytes, or if `quiche::Connection::dgram_send` fails.
fn send_h3_dgram(
    conn: &mut quiche::Connection,
    stream_id: u64,
    buf: &[u8],
) -> Result<(), PsqError> {
    let mut data: [u8; MAX_DATAGRAM_SIZE] = [0; MAX_DATAGRAM_SIZE];
    let quarter_id = stream_id / 4;
    // Keep the fixed 2-byte encoding for small IDs, but widen it once the
    // value no longer fits. `put_varint_with_len` does not check that the
    // value fits the requested length, and an oversized value corrupts the
    // header.
    let id_len = octets::varint_len(quarter_id).max(2);
    let end = {
        let mut writer = octets::OctetsMut::with_slice(data.as_mut_slice());
        writer.put_varint_with_len(quarter_id, id_len)?;
        writer.put_varint_with_len(0, 1)?;
        // An oversized payload now fails here with an error instead of
        // panicking on an out-of-range slice.
        writer.put_bytes(buf)?;
        writer.off()
    };
    conn.dgram_send(&data[..end])?;
    Ok(())
}
pub mod iptunnel;
pub mod filestream;
pub mod udptunnel;