use std::{
collections::HashMap,
sync::Arc,
};
use quiche::h3::NameValue;
use tokio::{
net::UdpSocket,
sync::{watch, Mutex},
time::Duration,
};
use super::*;
use crate::{
PsqError,
server::Endpoints,
stream::{
process_h3_datagram,
PsqStream,
},
util::hdrs_to_strings,
};
/// Per-client server state: the QUIC connection, the HTTP/3 connection layered
/// on top of it (created lazily once the handshake allows it), and the
/// bookkeeping for the HTTP/3 streams opened by this client.
pub(crate) struct ClientSession {
    /// Socket this client's QUIC packets are sent on.
    socket: Arc<UdpSocket>,
    /// QUIC connection. It is shared (`Arc<Mutex<_>>`) because it is also
    /// handed to endpoints while they process requests.
    conn: Arc<Mutex<quiche::Connection>>,
    /// HTTP/3 connection; `None` until it has been created from `conn`.
    http3_conn: Option<quiche::h3::Connection>,
    /// Responses that could not be written in one go, keyed by stream ID.
    partial_responses: HashMap<u64, PartialResponse>,
    /// Publishes the connection's next timeout as reported by quiche.
    timeout_tx: watch::Sender<Option<Duration>>,
    /// Streams produced by endpoints, keyed by HTTP/3 stream ID; incoming
    /// datagrams addressed to a stream ID are delivered to the matching entry.
    streams: HashMap<u64, Box<dyn PsqStream>>,
    /// Request handlers, looked up by the first segment of the request `:path`.
    endpoints: Arc<Mutex<Endpoints>>,
    /// Secret handed to endpoints when they process a request.
    jwt_secret: Vec<u8>,
}
impl ClientSession {
pub (crate) fn new(
socket: &Arc<UdpSocket>,
conn: quiche::Connection,
timeout_tx: watch::Sender<Option<Duration>>,
endpoints: &Arc<Mutex<Endpoints>>,
jwt_secret: &Vec<u8>,
) -> ClientSession {
ClientSession {
socket: Arc::clone(socket),
conn: Arc::new(Mutex::new(conn)),
http3_conn: None,
partial_responses: HashMap::new(),
timeout_tx: timeout_tx,
streams: HashMap::new(),
endpoints: Arc::clone(endpoints),
jwt_secret: jwt_secret.to_vec(),
}
}
pub (crate) fn connection(&self) -> &Arc<Mutex<quiche::Connection>> {
&self.conn
}
pub (crate) fn h3_connection(&self) -> &Option<quiche::h3::Connection> {
&self.http3_conn
}
pub (crate) async fn process_data(&mut self, pkt_buf: &mut [u8], recv_info: quiche::RecvInfo) {
self.set_timeout().await;
let mut conn = self.conn.lock().await;
let _read = match conn.recv(pkt_buf, recv_info) {
Ok(v) => v,
Err(e) => {
error!("{} recv failed: {:?}", conn.trace_id(), e);
return;
},
};
if (conn.is_in_early_data() || conn.is_established()) &&
self.http3_conn.is_none()
{
debug!(
"{} QUIC handshake completed, now trying HTTP/3",
conn.trace_id()
);
let mut h3_config = quiche::h3::Config::new().unwrap();
h3_config.enable_extended_connect(true);
let h3_conn = match quiche::h3::Connection::with_transport(
&mut conn,
&h3_config,
) {
Ok(v) => v,
Err(e) => {
error!("failed to create HTTP/3 connection: {}", e);
return;
},
};
self.http3_conn = Some(h3_conn);
}
let mut buf = [0; 10000]; match conn.dgram_recv(&mut buf) {
Ok(n) => {
let (stream_id, offset) = match process_h3_datagram(&buf) {
Ok((stream, off)) => (stream, off),
Err(e) => {
error!("Error processing HTTP/3 capsule: {}", e);
return;
},
};
let stream = self.streams.get_mut(&stream_id);
if stream.is_none() {
warn!("Datagram received but no matching stream ID: {}", stream_id);
} else {
if let Err(e) = stream.unwrap().process_datagram(&buf[offset..n]).await {
warn!("Error with received datagram: {}", e);
}
}
},
Err(e) => {
if e != quiche::Error::Done {
error!("Error receiving datagram: {}", e);
}
},
}
}
pub (crate) async fn send_packets(&self) {
if let Err(e) = send_quic_packets(&self.conn, &self.socket).await {
error!("Error sending packets: {}", e);
}
}
pub (crate) async fn handle_h3_requests(&mut self) {
self.handle_writable().await;
loop {
match self.poll_helper().await {
Ok((
stream_id,
quiche::h3::Event::Headers { list, .. },
)) => {
self.handle_request(stream_id, &list).await;
},
Ok((stream_id, quiche::h3::Event::Data)) => {
info!(
"{} got data on stream id {}",
self.conn.lock().await.trace_id(),
stream_id
);
},
Ok((stream_id, quiche::h3::Event::Finished)) => {
info!("Stream {} closed", stream_id);
self.remove_stream(stream_id).await;
},
Ok((stream_id, quiche::h3::Event::Reset(e))) => {
error!("Stream {} was reset: {}", stream_id, e);
self.remove_stream(stream_id).await;
},
Ok((
_prioritized_element_id,
quiche::h3::Event::PriorityUpdate,
)) => (),
Ok((_goaway_id, quiche::h3::Event::GoAway)) => (),
Err(quiche::h3::Error::Done) => {
break;
},
Err(e) => {
error!(
"{} HTTP/3 error {:?}",
self.conn.lock().await.trace_id(),
e
);
break;
},
}
}
}
async fn handle_request(
&mut self, stream_id: u64, headers: &[quiche::h3::Header],
) {
info!(
"{} got request {:?} on stream id {}",
self.conn.lock().await.trace_id(),
hdrs_to_strings(headers),
stream_id
);
let (headers, body, fin) = self.build_response(stream_id, headers).await;
let conn = &mut self.conn.lock().await;
let http3_conn = &mut self.http3_conn.as_mut().unwrap();
match http3_conn.send_response(conn, stream_id, &headers, false) {
Ok(v) => v,
Err(quiche::h3::Error::StreamBlocked) => {
let response = PartialResponse {
headers: Some(headers),
body,
written: 0,
};
self.partial_responses.insert(stream_id, response);
return;
},
Err(e) => {
error!("{} stream send failed {:?}", conn.trace_id(), e);
return;
},
}
let written = match http3_conn.send_body(conn, stream_id, &body, fin) {
Ok(v) => v,
Err(quiche::h3::Error::Done) => 0,
Err(e) => {
error!("{} stream send failed {:?}", conn.trace_id(), e);
return;
},
};
if written < body.len() {
let response = PartialResponse {
headers: None,
body,
written,
};
self.partial_responses.insert(stream_id, response);
}
}
async fn poll_helper(&mut self) -> Result<(u64, quiche::h3::Event), quiche::h3::Error> {
let mut conn = &mut *self.conn.lock().await;
self.http3_conn.as_mut().unwrap().poll(&mut conn)
}
async fn set_timeout(&self) {
let new_duration = self.conn.lock().await.timeout();
let _ = self.timeout_tx.send(new_duration);
}
async fn handle_writable(&mut self) {
let conn = &mut self.conn.lock().await;
for stream_id in conn.writable() {
let http3_conn = &mut self.http3_conn.as_mut().unwrap();
if !self.partial_responses.contains_key(&stream_id) {
return;
}
let resp = self.partial_responses.get_mut(&stream_id).unwrap();
if let Some(ref headers) = resp.headers {
match http3_conn.send_response(conn, stream_id, headers, false) {
Ok(_) => (),
Err(quiche::h3::Error::StreamBlocked) => {
return;
},
Err(e) => {
error!("{} stream send failed {:?}", conn.trace_id(), e);
return;
},
}
}
resp.headers = None;
let body = &resp.body[resp.written..];
let written = match http3_conn.send_body(conn, stream_id, body, true) {
Ok(v) => v,
Err(quiche::h3::Error::Done) => 0,
Err(e) => {
self.partial_responses.remove(&stream_id);
error!("{} stream send failed {:?}", conn.trace_id(), e);
return;
},
};
resp.written += written;
if resp.written == resp.body.len() {
self.partial_responses.remove(&stream_id);
}
}
}
async fn build_response(
&mut self,
stream_id: u64,
request: &[quiche::h3::Header],
) -> (Vec<quiche::h3::Header>, Vec<u8>, bool) {
let mut path = std::path::Path::new("");
for hdr in request {
match hdr.name() {
b":path" => {
let s = std::str::from_utf8(hdr.value());
if s.is_err() {
warn!("Invalid path");
return build_h3_response(400, "Invalid path!")
}
path = std::path::Path::new(s.unwrap())
},
_ => (),
}
}
let ep = path.components().nth(1);
if ep.is_none() {
return build_h3_response(404, "Not Found (empty path)")
}
let string = ep.unwrap().as_os_str().to_string_lossy().to_string();
match self.endpoints.lock().await.get_mut(&string) {
Some(endpoint) => {
let (status, body, fin) = match endpoint.process_request(
request,
&self.conn,
&self.socket,
stream_id,
&self.jwt_secret,
).await {
Ok((stream, body)) => {
if stream.is_some() {
self.streams.insert(stream_id, stream.unwrap());
}
(200, body, false)
},
Err(PsqError::HttpResponse(status, body)) => {
warn!("Http Response with error {}: {}", status, body);
(status, body.as_bytes().to_vec(), true)
},
Err(e) => {
error!("Error processing request: {}", e);
(500, format!("Error processing request: {}", e).as_bytes().to_vec(), true)
},
};
(build_h3_resp_headers(status, &body), body, fin)
}
None => {
let body = format!("Not Found: {}", string).as_bytes().to_vec();
(build_h3_resp_headers(404, &body), body, true)
}
}
}
async fn remove_stream(&mut self, stream_id: u64) {
if let Err(e) = self.conn.lock().await.stream_shutdown(stream_id, quiche::Shutdown::Read, 0) {
warn!("Could not send shutdown message: {}", e);
}
self.streams.remove(&stream_id);
}
}
struct PartialResponse {
headers: Option<Vec<quiche::h3::Header>>,
body: Vec<u8>,
written: usize,
}