use crate::config::ClientConfig;
use crate::config::DnsResolver;
use crate::config::ServerConfig;
use crate::connection::Connection;
use crate::driver::streams::session::StreamSession;
use crate::driver::streams::ProtoReadError;
use crate::driver::streams::ProtoWriteError;
use crate::driver::utils::varint_w2q;
use crate::driver::Driver;
use crate::error::ConnectingError;
use crate::error::ConnectionError;
use crate::VarInt;
use quinn::TokioRuntime;
use std::collections::HashMap;
use std::future::Future;
use std::future::IntoFuture;
use std::marker::PhantomData;
use std::net::SocketAddr;
use std::net::SocketAddrV4;
use std::net::SocketAddrV6;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use tracing::debug;
use url::Host;
use url::Url;
use wtransport_proto::error::ErrorCode;
use wtransport_proto::frame::FrameKind;
use wtransport_proto::headers::Headers;
use wtransport_proto::session::ReservedHeader;
use wtransport_proto::session::SessionRequest as SessionRequestProto;
use wtransport_proto::session::SessionResponse as SessionResponseProto;
pub mod endpoint_side {
    //! Role markers for [`Endpoint`](super::Endpoint): one type per side of a connection.
    use super::*;

    /// Client role: keeps the resolver used by `connect` to turn URL host names
    /// into socket addresses.
    pub struct Client {
        pub(super) dns_resolver: Arc<dyn DnsResolver + Send + Sync>,
    }

    /// Server role: carries no data. Its only field is `pub(super)`, so values can
    /// be created only from the enclosing module.
    pub struct Server {
        pub(super) _marker: PhantomData<()>,
    }
}
/// A WebTransport endpoint built on a quinn endpoint, parameterised by its role.
///
/// In this module `Side` is instantiated with [`endpoint_side::Server`] (see
/// `Endpoint::server`) or [`endpoint_side::Client`] (see `Endpoint::client`).
pub struct Endpoint<Side> {
    /// Underlying quinn endpoint, created from the socket bound by the constructors.
    endpoint: quinn::Endpoint,
    /// Role-specific state (empty marker for servers, DNS resolver for clients).
    side: Side,
}

impl<Side> Endpoint<Side> {
    /// Closes the endpoint's connections with `error_code` and `reason`.
    ///
    /// The WebTransport `VarInt` is converted to quinn's representation and the call
    /// is delegated to `quinn::Endpoint::close`.
    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
        let quic_error_code = varint_w2q(error_code);
        self.endpoint.close(quic_error_code, reason);
    }

    /// Waits until the underlying quinn endpoint reports that it is idle
    /// (delegates to `quinn::Endpoint::wait_idle`).
    pub async fn wait_idle(&self) {
        self.endpoint.wait_idle().await
    }

    /// Returns the local socket address the endpoint is bound to.
    ///
    /// # Errors
    ///
    /// Propagates the I/O error reported by `quinn::Endpoint::local_addr`.
    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
        let endpoint = &self.endpoint;
        endpoint.local_addr()
    }

    /// Returns the number of connections quinn currently counts as open on this endpoint.
    pub fn open_connections(&self) -> usize {
        let endpoint = &self.endpoint;
        endpoint.open_connections()
    }
}
impl Endpoint<endpoint_side::Server> {
pub fn server(server_config: ServerConfig) -> std::io::Result<Self> {
let endpoint_config = server_config.endpoint_config;
let quic_config = server_config.quic_config;
let socket = server_config.bind_address_config.bind_socket()?;
let runtime = Arc::new(TokioRuntime);
let endpoint = quinn::Endpoint::new(endpoint_config, Some(quic_config), socket, runtime)?;
Ok(Self {
endpoint,
side: endpoint_side::Server {
_marker: PhantomData,
},
})
}
pub async fn accept(&self) -> IncomingSession {
let quic_incoming = self
.endpoint
.accept()
.await
.expect("Endpoint cannot be closed");
debug!("New incoming QUIC connection");
IncomingSession(quic_incoming)
}
pub fn reload_config(&self, server_config: ServerConfig, rebind: bool) -> std::io::Result<()> {
if rebind {
let socket = server_config.bind_address_config.bind_socket()?;
self.endpoint.rebind(socket)?;
}
let quic_config = server_config.quic_config;
self.endpoint.set_server_config(Some(quic_config));
Ok(())
}
}
impl Endpoint<endpoint_side::Client> {
pub fn client(client_config: ClientConfig) -> std::io::Result<Self> {
let endpoint_config = client_config.endpoint_config;
let quic_config = client_config.quic_config;
let socket = client_config.bind_address_config.bind_socket()?;
let runtime = Arc::new(TokioRuntime);
let mut endpoint = quinn::Endpoint::new(endpoint_config, None, socket, runtime)?;
endpoint.set_default_client_config(quic_config);
Ok(Self {
endpoint,
side: endpoint_side::Client {
dns_resolver: client_config.dns_resolver,
},
})
}
pub async fn connect<O>(&self, options: O) -> Result<Connection, ConnectingError>
where
O: IntoConnectOptions,
{
let options = options.into_options();
let url = Url::parse(&options.url)
.map_err(|parse_error| ConnectingError::InvalidUrl(parse_error.to_string()))?;
if url.scheme() != "https" {
return Err(ConnectingError::InvalidUrl(
"WebTransport URL scheme must be 'https'".to_string(),
));
}
let host = url.host().expect("https scheme must have an host");
let port = url.port().unwrap_or(443);
let (socket_address, server_name) = match host {
Host::Domain(domain) => {
let socket_address = self
.side
.dns_resolver
.resolve(&format!("{domain}:{port}"))
.await
.map_err(ConnectingError::DnsLookup)?
.ok_or(ConnectingError::DnsNotFound)?;
(socket_address, domain.to_string())
}
Host::Ipv4(address) => {
let socket_address = SocketAddr::V4(SocketAddrV4::new(address, port));
(socket_address, address.to_string())
}
Host::Ipv6(address) => {
let socket_address = SocketAddr::V6(SocketAddrV6::new(address, port, 0, 0));
(socket_address, address.to_string())
}
};
let quic_connection = self
.endpoint
.connect(socket_address, &server_name)
.map_err(ConnectingError::with_connect_error)?
.await
.map_err(|connection_error| {
ConnectingError::ConnectionError(connection_error.into())
})?;
let driver = Driver::init(quic_connection.clone());
let _settings = driver.accept_settings().await.map_err(|driver_error| {
ConnectingError::ConnectionError(ConnectionError::with_driver_error(
driver_error,
&quic_connection,
))
})?;
let mut session_request_proto =
SessionRequestProto::new(url.as_ref()).expect("Url has been already validate");
for (k, v) in options.additional_headers {
session_request_proto
.insert(k.clone(), v)
.map_err(|ReservedHeader| ConnectingError::ReservedHeader(k))?;
}
let mut stream_session = match driver.open_session(session_request_proto).await {
Ok(stream_session) => stream_session,
Err(driver_error) => {
return Err(ConnectingError::ConnectionError(
ConnectionError::with_driver_error(driver_error, &quic_connection),
))
}
};
let session_id = stream_session.session_id();
match stream_session
.write_frame(stream_session.request().headers().generate_frame())
.await
{
Ok(()) => {}
Err(ProtoWriteError::Stopped) => {
return Err(ConnectingError::SessionRejected);
}
Err(ProtoWriteError::NotConnected) => {
return Err(ConnectingError::with_no_connection(&quic_connection));
}
}
let frame = loop {
let frame = match stream_session.read_frame().await {
Ok(frame) => frame,
Err(ProtoReadError::H3(error_code)) => {
quic_connection.close(varint_w2q(error_code.to_code()), b"");
return Err(ConnectingError::ConnectionError(
ConnectionError::local_h3_error(error_code),
));
}
Err(ProtoReadError::IO(_io_error)) => {
return Err(ConnectingError::with_no_connection(&quic_connection));
}
};
if let FrameKind::Exercise(_) = frame.kind() {
continue;
}
break frame;
};
if !matches!(frame.kind(), FrameKind::Headers) {
quic_connection.close(varint_w2q(ErrorCode::FrameUnexpected.to_code()), b"");
return Err(ConnectingError::ConnectionError(
ConnectionError::local_h3_error(ErrorCode::FrameUnexpected),
));
}
let headers = match Headers::with_frame(&frame) {
Ok(headers) => headers,
Err(error_code) => {
quic_connection.close(varint_w2q(error_code.to_code()), b"");
return Err(ConnectingError::ConnectionError(
ConnectionError::local_h3_error(error_code),
));
}
};
let session_response = match SessionResponseProto::try_from(headers) {
Ok(session_response) => session_response,
Err(_) => {
quic_connection.close(varint_w2q(ErrorCode::Message.to_code()), b"");
return Err(ConnectingError::ConnectionError(
ConnectionError::local_h3_error(ErrorCode::Message),
));
}
};
if session_response.code().is_successful() {
match driver.register_session(stream_session).await {
Ok(()) => {}
Err(driver_error) => {
return Err(ConnectingError::ConnectionError(
ConnectionError::with_driver_error(driver_error, &quic_connection),
))
}
}
} else {
return Err(ConnectingError::SessionRejected);
}
Ok(Connection::new(quic_connection, driver, session_id))
}
}
/// Parameters for opening a WebTransport session with `Endpoint::connect`.
///
/// Create one with [`ConnectOptions::builder`]. `connect` also accepts a
/// [`ConnectRequestBuilder`] or any [`ToString`] value (used as the URL) through
/// [`IntoConnectOptions`].
#[derive(Debug, Clone)]
pub struct ConnectOptions {
    /// URL of the session exactly as supplied; `connect` parses and validates it.
    url: String,
    /// Extra headers for the session request, keyed by header name.
    additional_headers: HashMap<String, String>,
}

impl ConnectOptions {
    /// Starts building options for `url`, initially with no additional headers.
    pub fn builder<S>(url: S) -> ConnectRequestBuilder
    where
        S: ToString,
    {
        ConnectRequestBuilder {
            url: url.to_string(),
            additional_headers: Default::default(),
        }
    }

    /// Returns the URL as it was given to the builder (its `to_string()` form).
    pub fn url(&self) -> &str {
        &self.url
    }

    /// Returns the additional headers to insert into the session request.
    pub fn additional_headers(&self) -> &HashMap<String, String> {
        &self.additional_headers
    }
}

/// Conversion into [`ConnectOptions`], used by `Endpoint::connect` to accept several
/// argument types.
pub trait IntoConnectOptions {
    /// Converts `self` into connection options.
    fn into_options(self) -> ConnectOptions;
}

/// Builder for [`ConnectOptions`], created by [`ConnectOptions::builder`].
///
/// Derives `Debug` and `Clone` like the [`ConnectOptions`] it produces.
#[derive(Debug, Clone)]
#[must_use = "a ConnectRequestBuilder does nothing until it is built or passed to `connect`"]
pub struct ConnectRequestBuilder {
    url: String,
    additional_headers: HashMap<String, String>,
}

impl ConnectRequestBuilder {
    /// Adds an extra header to the session request.
    ///
    /// Both `key` and `value` are converted with `to_string()`. Adding a key that is
    /// already present replaces its previous value.
    pub fn add_header<K, V>(mut self, key: K, value: V) -> Self
    where
        K: ToString,
        V: ToString,
    {
        self.additional_headers
            .insert(key.to_string(), value.to_string());
        self
    }

    /// Finishes the builder and returns the collected options.
    #[must_use]
    pub fn build(self) -> ConnectOptions {
        ConnectOptions {
            url: self.url,
            additional_headers: self.additional_headers,
        }
    }
}

impl IntoConnectOptions for ConnectRequestBuilder {
    fn into_options(self) -> ConnectOptions {
        self.build()
    }
}

impl IntoConnectOptions for ConnectOptions {
    fn into_options(self) -> ConnectOptions {
        self
    }
}

/// Any string-like value is treated as a URL with no additional headers.
impl<S> IntoConnectOptions for S
where
    S: ToString,
{
    fn into_options(self) -> ConnectOptions {
        ConnectOptions::builder(self).build()
    }
}
/// Type-erased future that drives the server-side setup of an incoming session and
/// resolves to a [`SessionRequest`] or a [`ConnectionError`]. It must be `Send` and `Sync`.
type DynFutureIncomingSession =
    dyn Future<Output = Result<SessionRequest, ConnectionError>> + Sync + Send;
/// An incoming QUIC connection attempt, as returned by `Endpoint::accept`.
///
/// No handshake has been performed yet. Await it to complete the connection and
/// receive a [`SessionRequest`], or dispose of it with [`retry`](Self::retry),
/// [`refuse`](Self::refuse) or [`ignore`](Self::ignore).
pub struct IncomingSession(quinn::Incoming);

impl IncomingSession {
    /// Returns the address of the peer attempting to connect.
    pub fn remote_address(&self) -> SocketAddr {
        let Self(incoming) = self;
        incoming.remote_address()
    }

    /// Returns whether quinn considers the peer's address validated
    /// (delegates to `quinn::Incoming::remote_address_validated`).
    pub fn remote_address_validated(&self) -> bool {
        let Self(incoming) = self;
        incoming.remote_address_validated()
    }

    /// Answers the attempt with a retry (delegates to `quinn::Incoming::retry`).
    ///
    /// # Panics
    ///
    /// Panics with "remote address already verified" if quinn rejects the retry.
    /// Checking [`remote_address_validated`](Self::remote_address_validated) first
    /// presumably avoids this.
    pub fn retry(self) {
        let Self(incoming) = self;
        incoming.retry().expect("remote address already verified");
    }

    /// Refuses the attempt (delegates to `quinn::Incoming::refuse`).
    pub fn refuse(self) {
        let Self(incoming) = self;
        incoming.refuse();
    }

    /// Ignores the attempt (delegates to `quinn::Incoming::ignore`).
    pub fn ignore(self) {
        let Self(incoming) = self;
        incoming.ignore();
    }
}
impl IntoFuture for IncomingSession {
type IntoFuture = IncomingSessionFuture;
type Output = Result<SessionRequest, ConnectionError>;
fn into_future(self) -> Self::IntoFuture {
IncomingSessionFuture::new(self.0)
}
}
/// Future returned by awaiting an [`IncomingSession`].
///
/// Wraps a pinned, boxed, type-erased future so the public type stays nameable.
pub struct IncomingSessionFuture(Pin<Box<DynFutureIncomingSession>>);

impl IncomingSessionFuture {
    /// Boxes the setup future for `quic_incoming`. Nothing runs until it is polled.
    fn new(quic_incoming: quinn::Incoming) -> Self {
        let setup = Self::accept(quic_incoming);
        Self(Box::pin(setup))
    }

    /// Server-side session setup.
    ///
    /// Completes the QUIC handshake, starts the driver, waits for
    /// `Driver::accept_settings` and then for the first session stream from
    /// `Driver::accept_session`.
    ///
    /// # Errors
    ///
    /// Returns the converted quinn error if the handshake fails, or a driver error
    /// (mapped against the connection) from either waiting step.
    async fn accept(quic_incoming: quinn::Incoming) -> Result<SessionRequest, ConnectionError> {
        let quic_connection = quic_incoming.await?;
        let driver = Driver::init(quic_connection.clone());

        let _settings = match driver.accept_settings().await {
            Ok(settings) => settings,
            Err(driver_error) => {
                return Err(ConnectionError::with_driver_error(
                    driver_error,
                    &quic_connection,
                ))
            }
        };

        let stream_session = match driver.accept_session().await {
            Ok(stream_session) => stream_session,
            Err(driver_error) => {
                return Err(ConnectionError::with_driver_error(
                    driver_error,
                    &quic_connection,
                ))
            }
        };

        Ok(SessionRequest::new(quic_connection, driver, stream_session))
    }
}

impl Future for IncomingSessionFuture {
    type Output = Result<SessionRequest, ConnectionError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The wrapper only holds a `Pin<Box<_>>`, so it is `Unpin` and can be unwrapped.
        let this = self.get_mut();
        this.0.as_mut().poll(cx)
    }
}
/// A client's WebTransport session request that has not been answered yet.
///
/// Produced, for example, by awaiting an [`IncomingSession`]. Inspect it with the
/// accessors, then [`accept`](Self::accept) it or reject it with
/// [`forbidden`](Self::forbidden), [`not_found`](Self::not_found) or
/// [`too_many_requests`](Self::too_many_requests).
pub struct SessionRequest {
    quic_connection: quinn::Connection,
    driver: Driver,
    stream_session: StreamSession,
}

impl SessionRequest {
    /// Bundles the connection, its driver and the session stream carrying the request.
    pub(crate) fn new(
        quic_connection: quinn::Connection,
        driver: Driver,
        stream_session: StreamSession,
    ) -> Self {
        Self {
            quic_connection,
            driver,
            stream_session,
        }
    }

    /// Returns the remote address of the underlying QUIC connection.
    #[inline(always)]
    pub fn remote_address(&self) -> SocketAddr {
        self.quic_connection.remote_address()
    }

    /// Returns the request's authority.
    pub fn authority(&self) -> &str {
        let request = self.stream_session.request();
        request.authority()
    }

    /// Returns the request's path.
    pub fn path(&self) -> &str {
        let request = self.stream_session.request();
        request.path()
    }

    /// Returns the request's origin, if one was sent.
    pub fn origin(&self) -> Option<&str> {
        let request = self.stream_session.request();
        request.origin()
    }

    /// Returns the request's user agent, if one was sent.
    pub fn user_agent(&self) -> Option<&str> {
        let request = self.stream_session.request();
        request.user_agent()
    }

    /// Returns all request headers as a name-to-value map.
    pub fn headers(&self) -> &HashMap<String, String> {
        let request = self.stream_session.request();
        request.headers().as_ref()
    }

    /// Accepts the session: sends an OK response, registers the session with the
    /// driver and returns the resulting [`Connection`].
    ///
    /// # Errors
    ///
    /// Returns the error from sending the response (see `send_response`) or the
    /// driver error from registering the session.
    pub async fn accept(mut self) -> Result<Connection, ConnectionError> {
        self.send_response(SessionResponseProto::ok()).await?;

        let session_id = self.stream_session.session_id();
        let registration = self.driver.register_session(self.stream_session).await;
        if let Err(driver_error) = registration {
            return Err(ConnectionError::with_driver_error(
                driver_error,
                &self.quic_connection,
            ));
        }

        Ok(Connection::new(
            self.quic_connection,
            self.driver,
            session_id,
        ))
    }

    /// Rejects the session with `SessionResponseProto::forbidden()` (presumably HTTP 403).
    pub async fn forbidden(self) {
        let response = SessionResponseProto::forbidden();
        self.reject(response).await;
    }

    /// Rejects the session with `SessionResponseProto::not_found()` (presumably HTTP 404).
    pub async fn not_found(self) {
        let response = SessionResponseProto::not_found();
        self.reject(response).await;
    }

    /// Rejects the session with `SessionResponseProto::too_many_requests()`
    /// (presumably HTTP 429).
    pub async fn too_many_requests(self) {
        let response = SessionResponseProto::too_many_requests();
        self.reject(response).await;
    }

    /// Sends `response` on a best-effort basis, then finishes the session stream.
    async fn reject(mut self, response: SessionResponseProto) {
        // Any send error is deliberately discarded: the stream is finished regardless.
        let _ = self.send_response(response).await;
        self.stream_session.finish().await;
    }

    /// Writes `response` as a headers frame on the session stream.
    ///
    /// # Errors
    ///
    /// - Not connected: returns `ConnectionError::no_connect` for the connection.
    /// - Stream stopped by the peer: closes the QUIC connection with
    ///   `ClosedCriticalStream` and returns the matching local H3 error.
    async fn send_response(
        &mut self,
        response: SessionResponseProto,
    ) -> Result<(), ConnectionError> {
        let headers_frame = response.headers().generate_frame();
        self.stream_session
            .write_frame(headers_frame)
            .await
            .map_err(|write_error| match write_error {
                ProtoWriteError::NotConnected => {
                    ConnectionError::no_connect(&self.quic_connection)
                }
                ProtoWriteError::Stopped => {
                    self.quic_connection
                        .close(varint_w2q(ErrorCode::ClosedCriticalStream.to_code()), b"");
                    ConnectionError::local_h3_error(ErrorCode::ClosedCriticalStream)
                }
            })
    }
}