#![allow(non_camel_case_types)]
#![cfg(feature = "http3")]
use crate::utils::refstr::Headers;
use crate::{HttpMethod, HttpRequest, HttpResponse, HttpResponseBody, SERVER_STR};
use anyhow::anyhow;
use bytes::Buf;
use h3_quinn::quinn;
use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_rustls::rustls;
/// Certificate verifier that accepts every server certificate and every
/// handshake signature without inspecting them.
///
/// Using it disables TLS server authentication entirely; it is only meant for
/// connections where the caller explicitly opted out of verification.
#[derive(Debug)]
struct NoCertificateVerification;

impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
    /// Accepts the presented certificate chain unconditionally.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Lists the signature schemes offered to the server.
    ///
    /// Since no signature is ever checked, the list only has to be broad
    /// enough that negotiation succeeds. The EdDSA schemes are included so
    /// that servers holding Ed25519/Ed448 keys are not turned away.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::RSA_PKCS1_SHA384,
            rustls::SignatureScheme::RSA_PKCS1_SHA512,
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
            rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
            rustls::SignatureScheme::RSA_PSS_SHA256,
            rustls::SignatureScheme::RSA_PSS_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA512,
            rustls::SignatureScheme::ED25519,
            rustls::SignatureScheme::ED448,
        ]
    }
}
/// State of one established HTTP/3 connection to a single host and port.
pub struct H3SessionImpl {
/// `(host, port)` pair this connection was opened for.
pub unique_host: (String, u16),
/// QUIC client endpoint the connection was made from.
pub endpoint: quinn::Endpoint,
/// HTTP/3 handle used to issue requests on the connection.
pub send_request: h3::client::SendRequest<h3_quinn::OpenStreams, bytes::Bytes>,
/// Background task that polls the HTTP/3 connection driver until it closes.
pub driver_handle: tokio::task::JoinHandle<()>,
/// `true` when created by `new` (certificates verified against the bundled
/// roots), `false` when created by `new_without_encrypt`.
pub use_encrypt: bool,
}
impl H3SessionImpl {
pub async fn new(host: String, port: u16) -> anyhow::Result<Self> {
let mut root_cert = tokio_rustls::rustls::RootCertStore::empty();
root_cert.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let mut tls_config = tokio_rustls::rustls::ClientConfig::builder()
.with_root_certificates(root_cert)
.with_no_client_auth();
tls_config.alpn_protocols = vec![b"h3".to_vec()];
let mut endpoint = quinn::Endpoint::client("[::]:0".parse()?)?;
let client_config = quinn::ClientConfig::new(Arc::new(
quinn::crypto::rustls::QuicClientConfig::try_from(tls_config)?,
));
endpoint.set_default_client_config(client_config);
let quic_conn = endpoint
.connect(format!("{host}:{port}").parse()?, &host)?
.await
.map_err(|e| anyhow!("QUIC connection failed: {e}"))?;
let (mut driver, send_request) = h3::client::new(h3_quinn::Connection::new(quic_conn))
.await
.map_err(|e| anyhow!("HTTP/3 client initialization failed: {e}"))?;
let driver_handle = tokio::spawn(async move {
let _ = std::future::poll_fn(|cx| driver.poll_close(cx)).await;
});
Ok(H3SessionImpl {
unique_host: (host, port),
endpoint,
send_request,
driver_handle,
use_encrypt: true,
})
}
pub async fn new_without_encrypt(host: String, port: u16) -> anyhow::Result<Self> {
let mut tls_config = tokio_rustls::rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoCertificateVerification))
.with_no_client_auth();
tls_config.alpn_protocols = vec![b"h3".to_vec()];
let mut endpoint = quinn::Endpoint::client("[::]:0".parse()?)?;
let client_config = quinn::ClientConfig::new(Arc::new(
quinn::crypto::rustls::QuicClientConfig::try_from(tls_config)?,
));
endpoint.set_default_client_config(client_config);
let socket_addr: std::net::SocketAddr = if host == "localhost" {
format!("127.0.0.1:{port}")
.parse()
.map_err(|e| anyhow!("Invalid address {}: {}", format!("127.0.0.1:{port}"), e))?
} else {
format!("{host}:{port}")
.parse()
.map_err(|e| anyhow!("Invalid address {}: {}", format!("{host}:{port}"), e))?
};
let connecting = endpoint
.connect(socket_addr, &host)
.map_err(|e| anyhow!("QUIC connection failed: {e}"))?;
let quic_conn = tokio::time::timeout(std::time::Duration::from_secs(10), connecting)
.await
.map_err(|e| anyhow!("QUIC connection timeout: {e}"))?
.map_err(|e| anyhow!("QUIC connection failed: {e}"))?;
let (mut driver, send_request) = h3::client::new(h3_quinn::Connection::new(quic_conn))
.await
.map_err(|e| anyhow!("HTTP/3 client initialization failed: {e}"))?;
let driver_handle = tokio::spawn(async move {
let _ = std::future::poll_fn(|cx| driver.poll_close(cx)).await;
});
Ok(H3SessionImpl {
unique_host: (host, port),
endpoint,
send_request,
driver_handle,
use_encrypt: false,
})
}
}
/// HTTP/3 client session that holds at most one [`H3SessionImpl`] at a time.
pub struct H3Session {
/// Current connection, or `None` until a request has opened one.
pub sess_impl: Option<H3SessionImpl>,
}
/// Generates the per-verb request methods of `H3Session`.
///
/// * `(name, VERB)` emits a body-less method `name(url, args)`.
/// * `(name, name_json, name_json_str, VERB)` emits three methods: one taking
///   raw bytes, one taking a `serde_json::Value`, and one taking an already
///   serialised JSON string. The two JSON variants append a
///   `Content-Type: application/json` header before delegating to the raw one.
macro_rules! define_h3_session_method {
    ($fn_name:ident, $method:ident) => {
        pub async fn $fn_name(
            &mut self,
            url: &str,
            args: Vec<Headers>,
        ) -> anyhow::Result<HttpResponse> {
            let (mut request, _) = self.new_request(HttpMethod::$method, url).await?;
            args.into_iter().for_each(|header| {
                request.apply_header(header);
            });
            self.do_request(request).await
        }
    };
    ($fn_name:ident, $fn_name2:ident, $fn_name3:ident, $method:ident) => {
        pub async fn $fn_name(
            &mut self,
            url: &str,
            body: Vec<u8>,
            args: Vec<Headers>,
        ) -> anyhow::Result<HttpResponse> {
            let (mut request, _) = self.new_request(HttpMethod::$method, url).await?;
            request.body = body.into();
            args.into_iter().for_each(|header| {
                request.apply_header(header);
            });
            self.do_request(request).await
        }

        pub async fn $fn_name2(
            &mut self,
            url: &str,
            body: serde_json::Value,
            mut args: Vec<Headers>,
        ) -> anyhow::Result<HttpResponse> {
            let payload = serde_json::to_vec(&body)?;
            args.push(Headers::Content_Type("application/json".into()));
            self.$fn_name(url, payload, args).await
        }

        pub async fn $fn_name3(
            &mut self,
            url: &str,
            body: String,
            mut args: Vec<Headers>,
        ) -> anyhow::Result<HttpResponse> {
            let payload = body.into_bytes();
            args.push(Headers::Content_Type("application/json".into()));
            self.$fn_name(url, payload, args).await
        }
    };
}
impl Default for H3Session {
    /// Returns a session that has not opened any connection yet.
    fn default() -> Self {
        Self { sess_impl: None }
    }
}
impl H3Session {
pub fn new() -> Self {
Self { sess_impl: None }
}
fn is_encrypt_url(url: &str) -> bool {
url.starts_with("https://")
}
async fn new_request(
&mut self,
method: HttpMethod,
url: &str,
) -> anyhow::Result<(HttpRequest, &mut H3SessionImpl)> {
let (mut req, _use_ssl, port) = HttpRequest::from_url(url, method)?;
let use_encrypt = Self::is_encrypt_url(url);
let host = url
.parse::<http::Uri>()?
.host()
.unwrap_or("127.0.0.1")
.to_string();
let mut is_same_host = false;
if let Some(sess_impl) = &mut self.sess_impl {
let (host1, port1) = &sess_impl.unique_host;
if (host1, port1) == (&host, &port) && sess_impl.use_encrypt == use_encrypt {
is_same_host = true;
}
}
if !is_same_host {
if let Some(old_impl) = self.sess_impl.take() {
old_impl.driver_handle.abort();
old_impl.endpoint.wait_idle().await;
}
if use_encrypt {
self.sess_impl = Some(H3SessionImpl::new(host, port).await?);
} else {
self.sess_impl = Some(H3SessionImpl::new_without_encrypt(host, port).await?);
}
}
req.apply_header(Headers::User_Agent(SERVER_STR.clone()));
req.version = 30;
let sess_impl = self
.sess_impl
.as_mut()
.ok_or_else(|| anyhow!("session implementation not initialized"))?;
Ok((req, sess_impl))
}
async fn do_request(&mut self, req: HttpRequest) -> anyhow::Result<HttpResponse> {
let sess_impl = self
.sess_impl
.as_mut()
.ok_or_else(|| anyhow!("session implementation not initialized"))?;
let host_with_port = if sess_impl.unique_host.1 == 443 {
sess_impl.unique_host.0.clone()
} else {
format!("{}:{}", sess_impl.unique_host.0, sess_impl.unique_host.1)
};
let uri_str = format!("https://{}{}", host_with_port, req.url_path);
let uri: http::Uri = if !req.url_query.is_empty() {
let query: Vec<String> = req
.url_query
.iter()
.map(|(k, v)| format!("{k}={v}"))
.collect();
format!("{uri_str}?{}", query.join("&")).parse()?
} else {
uri_str.parse()?
};
let method_str = match req.method {
HttpMethod::GET => http::Method::GET,
HttpMethod::POST => http::Method::POST,
HttpMethod::PUT => http::Method::PUT,
HttpMethod::DELETE => http::Method::DELETE,
HttpMethod::HEAD => http::Method::HEAD,
HttpMethod::OPTIONS => http::Method::OPTIONS,
HttpMethod::PATCH => http::Method::PATCH,
HttpMethod::CONNECT => http::Method::CONNECT,
HttpMethod::TRACE => http::Method::TRACE,
_ => http::Method::GET,
};
let mut builder = http::Request::builder().method(method_str).uri(uri);
for (key, value) in req.headers.iter() {
if let (Ok(name), Ok(val)) = (
http::header::HeaderName::from_str(key.to_str()),
http::HeaderValue::from_str(value.as_ref()),
) {
builder = builder.header(name, val);
}
}
let has_body = !req.body.is_empty();
let request = builder.body(())?;
let mut stream = sess_impl
.send_request
.send_request(request)
.await
.map_err(|e| anyhow!("Failed to send request: {e}"))?;
if has_body {
stream
.send_data(bytes::Bytes::from(req.body.to_vec()))
.await
.map_err(|e| anyhow!("Failed to send request body: {e}"))?;
}
stream
.finish()
.await
.map_err(|e| anyhow!("Failed to finish request: {e}"))?;
let response = stream
.recv_response()
.await
.map_err(|e| anyhow!("Failed to receive response: {e}"))?;
let status = response.status().as_u16();
let response_headers: Vec<(String, String)> = response
.headers()
.iter()
.filter_map(|(name, value)| {
let name_str = name.to_string();
let value_str = value.to_str().ok()?.to_string();
Some((name_str, value_str))
})
.collect();
let is_sse = response_headers
.iter()
.find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
.map(|(_, value)| {
value
.split(';')
.next()
.map(|v| v.trim().eq_ignore_ascii_case("text/event-stream"))
.unwrap_or(false)
})
.unwrap_or(false);
if is_sse {
let (tx, rx) = mpsc::channel(64);
tokio::spawn(async move {
loop {
match stream.recv_data().await {
Ok(Some(mut chunk)) => {
let data = chunk.copy_to_bytes(chunk.remaining()).to_vec();
if tx.send(data).await.is_err() {
break;
}
}
Ok(None) => break,
Err(_) => break,
}
}
});
let mut res = HttpResponse::new();
res.http_code = status;
for (name, value) in response_headers.iter() {
res.headers
.insert(name.clone().into(), value.clone().into());
}
res.body = HttpResponseBody::Stream(rx);
Ok(res)
} else {
let mut body_data = Vec::new();
loop {
match stream.recv_data().await {
Ok(Some(mut chunk)) => {
body_data.extend_from_slice(&chunk.copy_to_bytes(chunk.remaining()));
}
Ok(None) => break,
Err(e) => return Err(anyhow!("Failed to read response body: {e}")),
}
}
let mut res = HttpResponse::new();
res.http_code = status;
for (name, value) in response_headers.iter() {
res.headers
.insert(name.clone().into(), value.clone().into());
}
res.body = HttpResponseBody::Data(body_data);
Ok(res)
}
}
define_h3_session_method!(get, GET);
define_h3_session_method!(post, post_json, post_json_str, POST);
define_h3_session_method!(put, put_json, put_json_str, PUT);
define_h3_session_method!(delete, DELETE);
define_h3_session_method!(head, HEAD);
define_h3_session_method!(options, OPTIONS);
define_h3_session_method!(patch, PATCH);
define_h3_session_method!(connect, CONNECT);
define_h3_session_method!(trace, TRACE);
}
/// Generates free-standing one-shot request functions.
///
/// Each generated function creates a fresh `H3Session`, forwards its arguments
/// to the session method of the same name and returns that method's result.
///
/// NOTE(review): the session is a local that is dropped when the function
/// returns; confirm that streamed response bodies remain readable afterwards.
macro_rules! define_h3_client_method {
    ($fn_name:ident) => {
        pub async fn $fn_name(url: &str, args: Vec<Headers>) -> anyhow::Result<HttpResponse> {
            let mut session = H3Session::new();
            session.$fn_name(url, args).await
        }
    };
    ($fn_name:ident, $fn_name2:ident, $fn_name3:ident) => {
        pub async fn $fn_name(
            url: &str,
            body: Vec<u8>,
            args: Vec<Headers>,
        ) -> anyhow::Result<HttpResponse> {
            let mut session = H3Session::new();
            session.$fn_name(url, body, args).await
        }

        pub async fn $fn_name2(
            url: &str,
            body: serde_json::Value,
            args: Vec<Headers>,
        ) -> anyhow::Result<HttpResponse> {
            let mut session = H3Session::new();
            session.$fn_name2(url, body, args).await
        }

        pub async fn $fn_name3(
            url: &str,
            body: String,
            args: Vec<Headers>,
        ) -> anyhow::Result<HttpResponse> {
            let mut session = H3Session::new();
            session.$fn_name3(url, body, args).await
        }
    };
}
// Module-level one-shot helpers, one per HTTP verb. Each call runs on its own
// freshly created `H3Session`. POST and PUT additionally get `*_json` and
// `*_json_str` variants.
define_h3_client_method!(get);
define_h3_client_method!(post, post_json, post_json_str);
define_h3_client_method!(put, put_json, put_json_str);
define_h3_client_method!(delete);
define_h3_client_method!(head);
define_h3_client_method!(options);
define_h3_client_method!(patch);
define_h3_client_method!(connect);
define_h3_client_method!(trace);
pub struct WebTransport {
connection: quinn::Connection,
driver_handle: tokio::task::JoinHandle<()>,
}
impl Drop for WebTransport {
fn drop(&mut self) {
self.driver_handle.abort();
}
}
impl WebTransport {
pub async fn connect(url: &str, _headers: Vec<Headers>) -> anyhow::Result<Self> {
let uri: http::Uri = url.parse()?;
let host = uri
.host()
.ok_or_else(|| anyhow!("Invalid URL: missing host"))?;
let port = uri.port_u16().unwrap_or(443);
let path = uri.path().to_string();
use tokio_rustls::rustls;
let mut roots = rustls::RootCertStore::empty();
roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let mut tls_config = rustls::ClientConfig::builder()
.with_root_certificates(roots)
.with_no_client_auth();
tls_config.alpn_protocols = vec![b"h3".to_vec()];
let mut endpoint = quinn::Endpoint::client("[::]:0".parse()?)?;
let client_config = quinn::ClientConfig::new(Arc::new(
quinn::crypto::rustls::QuicClientConfig::try_from(tls_config)?,
));
endpoint.set_default_client_config(client_config);
let connection = endpoint
.connect(format!("{host}:{port}").parse()?, host)?
.await
.map_err(|e| anyhow!("QUIC connection failed: {e}"))?;
let (mut driver, mut send_request) =
h3::client::new(h3_quinn::Connection::new(connection.clone()))
.await
.map_err(|e| anyhow!("HTTP/3 client initialization failed: {e}"))?;
let driver_handle = tokio::spawn(async move {
let _ = std::future::poll_fn(|cx| driver.poll_close(cx)).await;
});
let req = http::Request::builder()
.method(http::Method::CONNECT)
.uri(&path)
.header(":protocol", "webtransport")
.header(":scheme", "https")
.header(":authority", format!("{host}:{port}"))
.body(())
.map_err(|e| anyhow!("Failed to build CONNECT request: {e}"))?;
let mut stream = send_request
.send_request(req)
.await
.map_err(|e| anyhow!("Failed to send CONNECT request: {e}"))?;
let response = stream
.recv_response()
.await
.map_err(|e| anyhow!("Failed to get response: {e}"))?;
if response.status() != 200 {
return Err(anyhow!(
"WebTransport connection failed with status: {}",
response.status()
));
}
Ok(Self {
connection,
driver_handle,
})
}
pub async fn open_bi(&self) -> anyhow::Result<crate::WebTransportStream> {
let (send, recv) = self.connection.open_bi().await?;
Ok(crate::WebTransportStream::new(send, recv))
}
pub async fn open_uni(&self) -> anyhow::Result<quinn::SendStream> {
let send = self.connection.open_uni().await?;
Ok(send)
}
pub async fn accept_uni(&self) -> anyhow::Result<Option<quinn::RecvStream>> {
match self.connection.accept_uni().await {
Ok(recv) => Ok(Some(recv)),
Err(quinn::ConnectionError::ApplicationClosed(_)) => Ok(None),
Err(quinn::ConnectionError::ConnectionClosed(_)) => Ok(None),
Err(e) => Err(anyhow::anyhow!(
"Failed to accept unidirectional stream: {}",
e
)),
}
}
pub async fn send_datagram(&self, data: &[u8]) -> anyhow::Result<()> {
self.connection.send_datagram(data.to_vec().into())?;
Ok(())
}
pub async fn recv_datagram(&self) -> anyhow::Result<Vec<u8>> {
let data = self.connection.read_datagram().await?;
Ok(data.to_vec())
}
}