use bytes::{Buf, BufMut, Bytes};
use url::Url;
pub use web_transport_quinn as quinn;
pub use web_transport_quinn::CongestionControl;
/// Builder for a WebTransport [`Client`], wrapping `quinn::ClientBuilder`.
///
/// Configuration methods return `Self`; the `with_*` methods that return
/// `Result<Client, Error>` finish the build.
#[derive(Default, Clone)]
pub struct ClientBuilder {
    inner: quinn::ClientBuilder,
}

impl ClientBuilder {
    /// Create a builder with default settings (same as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the congestion controller, forwarding `cc` to the inner quinn builder.
    pub fn with_congestion_control(self, cc: CongestionControl) -> Self {
        let inner = self.inner.with_congestion_control(cc);
        Self { inner }
    }

    /// Finish building a [`Client`] configured with the given server certificate hashes.
    ///
    /// Delegates to `quinn::ClientBuilder::with_server_certificate_hashes`.
    ///
    /// # Errors
    /// Any error reported by the inner builder is converted into [`Error`].
    pub fn with_server_certificate_hashes(self, hashes: Vec<Vec<u8>>) -> Result<Client, Error> {
        let inner = self.inner.with_server_certificate_hashes(hashes)?;
        Ok(Client { inner })
    }

    /// Finish building a [`Client`] via `quinn::ClientBuilder::with_system_roots`
    /// (presumably trusting the platform's root certificate store — see that API).
    ///
    /// # Errors
    /// Any error reported by the inner builder is converted into [`Error`].
    pub fn with_system_roots(self) -> Result<Client, Error> {
        let inner = self.inner.with_system_roots()?;
        Ok(Client { inner })
    }
}
/// A WebTransport client wrapping `quinn::Client`; produced by [`ClientBuilder`].
#[derive(Clone, Debug)]
pub struct Client {
    inner: quinn::Client,
}

impl Client {
    /// Open a connection to `url` and return the resulting [`Session`].
    ///
    /// # Errors
    /// Connection failures from the inner client are converted into [`Error`].
    pub async fn connect(&self, url: Url) -> Result<Session, Error> {
        let session = self.inner.connect(url).await?;
        Ok(session.into())
    }
}
/// A WebTransport server wrapping `quinn::Server`.
pub struct Server {
    inner: quinn::Server,
}

impl From<quinn::Server> for Server {
    fn from(inner: quinn::Server) -> Self {
        Server { inner }
    }
}

impl Server {
    /// Wait for the next incoming request and accept it as a [`Session`].
    ///
    /// Returns `Ok(None)` once the underlying server yields no further requests.
    ///
    /// # Errors
    /// Fails if accepting the pending request (its `ok()` step) fails.
    pub async fn accept(&mut self) -> Result<Option<Session>, Error> {
        let request = match self.inner.accept().await {
            Some(request) => request,
            None => return Ok(None),
        };
        let session = request.ok().await?;
        Ok(Some(session.into()))
    }
}
/// An established WebTransport session wrapping `quinn::Session`.
#[derive(Clone, PartialEq, Eq)]
pub struct Session {
    inner: quinn::Session,
}

impl Session {
    /// Wait for the peer to open a unidirectional stream and return its receiving side.
    pub async fn accept_uni(&self) -> Result<RecvStream, Error> {
        Ok(RecvStream::new(self.inner.accept_uni().await?))
    }

    /// Wait for the peer to open a bidirectional stream.
    pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), Error> {
        let (send, recv) = self.inner.accept_bi().await?;
        Ok((SendStream::new(send), RecvStream::new(recv)))
    }

    /// Open a new bidirectional stream to the peer.
    pub async fn open_bi(&self) -> Result<(SendStream, RecvStream), Error> {
        let (send, recv) = self.inner.open_bi().await?;
        Ok((SendStream::new(send), RecvStream::new(recv)))
    }

    /// Open a new unidirectional stream to the peer.
    pub async fn open_uni(&self) -> Result<SendStream, Error> {
        let send = self.inner.open_uni().await?;
        Ok(SendStream::new(send))
    }

    /// Send `payload` as a datagram.
    ///
    /// Declared `async` although the inner call is synchronous and nothing is awaited.
    pub async fn send_datagram(&self, payload: Bytes) -> Result<(), Error> {
        self.inner.send_datagram(payload)?;
        Ok(())
    }

    /// Maximum datagram payload size reported by the inner session.
    ///
    /// Declared `async` although nothing is awaited.
    pub async fn max_datagram_size(&self) -> usize {
        self.inner.max_datagram_size()
    }

    /// Wait for the next datagram from the peer.
    pub async fn recv_datagram(&self) -> Result<Bytes, Error> {
        let datagram = self.inner.read_datagram().await?;
        Ok(datagram)
    }

    /// Close the session with an application `code` and a UTF-8 `reason`.
    pub fn close(&self, code: u32, reason: &str) {
        let reason_bytes = reason.as_bytes();
        self.inner.close(code, reason_bytes);
    }

    /// Wait until the session is closed and return why, as an [`Error`].
    pub async fn closed(&self) -> Error {
        let cause = self.inner.closed().await;
        cause.into()
    }

    /// URL from the request that established this session.
    pub fn url(&self) -> &Url {
        let request = self.inner.request();
        &request.url
    }

    /// Protocol carried in the session's response, if one was set.
    pub fn protocol(&self) -> Option<&str> {
        let response = self.inner.response();
        response.protocol.as_deref()
    }
}

impl From<quinn::Session> for Session {
    fn from(inner: quinn::Session) -> Self {
        Self { inner }
    }
}
pub struct SendStream {
inner: quinn::SendStream,
}
impl SendStream {
fn new(inner: quinn::SendStream) -> Self {
Self { inner }
}
#[must_use = "returns the number of bytes written"]
pub async fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
self.inner.write(buf).await.map_err(Into::into)
}
pub async fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Result<usize, Error> {
let size = buf.chunk().len();
let chunk = buf.copy_to_bytes(size);
self.inner.write_chunk(chunk).await?;
Ok(size)
}
pub fn set_priority(&mut self, order: i32) {
self.inner.set_priority(order).ok();
}
pub fn reset(&mut self, code: u32) {
self.inner.reset(code).ok();
}
pub fn finish(&mut self) -> Result<(), Error> {
self.inner
.finish()
.map_err(|_| Error::Write(quinn::WriteError::ClosedStream))?;
Ok(())
}
pub async fn closed(&mut self) -> Result<Option<u8>, Error> {
match self.inner.stopped().await {
Ok(None) => Ok(None),
Ok(Some(code)) => Ok(Some(code as u8)),
Err(e) => Err(Error::Session(e)),
}
}
}
/// The receiving half of a WebTransport stream, wrapping `quinn::RecvStream`.
pub struct RecvStream {
    inner: quinn::RecvStream,
}

impl RecvStream {
    fn new(inner: quinn::RecvStream) -> Self {
        Self { inner }
    }

    /// Read the next in-order chunk of at most `max` bytes.
    ///
    /// Returns `Ok(None)` when the inner stream reports no more data.
    pub async fn read(&mut self, max: usize) -> Result<Option<Bytes>, Error> {
        Ok(self
            .inner
            .read_chunk(max, true)
            .await?
            .map(|chunk| chunk.bytes))
    }

    /// Read in-order data into `buf`, advancing it by the number of bytes read.
    ///
    /// Returns:
    /// - `Ok(Some(n))` when `n` bytes were appended to `buf`.
    /// - `Ok(Some(0))` without reading when `buf` has no writable space left. Reporting
    ///   `None` here would be indistinguishable from end-of-stream.
    /// - `Ok(None)` when the inner stream reports no more data.
    pub async fn read_buf<B: BufMut>(&mut self, buf: &mut B) -> Result<Option<usize>, Error> {
        // Cap the read at the contiguous space `buf` currently exposes, as before.
        let max = buf.chunk_mut().len();
        if max == 0 {
            return Ok(Some(0));
        }

        // Read into an owned chunk and copy it in with the safe `put_slice`. The previous
        // version reinterpreted the uninitialized `chunk_mut()` slice as `&mut [u8]`,
        // which exposes uninitialized memory to safe code and is unsound.
        match self.inner.read_chunk(max, true).await? {
            Some(chunk) => {
                let size = chunk.bytes.len();
                // `size <= max <= buf.remaining_mut()`, so `put_slice` cannot overflow `buf`.
                buf.put_slice(&chunk.bytes);
                Ok(Some(size))
            }
            None => Ok(None),
        }
    }

    /// Ask the peer to stop sending with `code`; errors from the inner stream are ignored.
    pub fn stop(&mut self, code: u32) {
        self.inner.stop(code).ok();
    }

    /// Wait for the inner `received_reset()` future and return the code it reports, if any.
    ///
    /// # Errors
    /// Session failures are returned as [`Error::Session`].
    pub async fn closed(&mut self) -> Result<Option<u8>, Error> {
        match self.inner.received_reset().await {
            Ok(None) => Ok(None),
            // NOTE(review): `as u8` keeps only the low 8 bits of the code — confirm intended.
            Ok(Some(code)) => Ok(Some(code as u8)),
            Err(e) => Err(Error::Session(e)),
        }
    }
}
/// Errors returned by this WebTransport wrapper.
///
/// Stream read/write errors that are really session failures are normalized into
/// [`Error::Session`] by the `From` impls below.
#[derive(Debug, thiserror::Error, Clone)]
pub enum Error {
    /// The session failed.
    #[error("session error: {0}")]
    Session(#[from] quinn::SessionError),

    /// A server-side error.
    #[error("server error: {0}")]
    Server(#[from] quinn::ServerError),

    /// A client-side error.
    #[error("client error: {0}")]
    Client(#[from] quinn::ClientError),

    /// A stream write error other than a session failure.
    #[error("write error: {0}")]
    Write(quinn::WriteError),

    /// A stream read error other than a session failure.
    #[error("read error: {0}")]
    Read(quinn::ReadError),
}

impl From<quinn::WriteError> for Error {
    fn from(err: quinn::WriteError) -> Self {
        // Unwrap session failures so callers see one variant for them.
        if let quinn::WriteError::SessionError(session) = err {
            Self::Session(session)
        } else {
            Self::Write(err)
        }
    }
}

impl From<quinn::ReadError> for Error {
    fn from(err: quinn::ReadError) -> Self {
        // Unwrap session failures so callers see one variant for them.
        if let quinn::ReadError::SessionError(session) = err {
            Self::Session(session)
        } else {
            Self::Read(err)
        }
    }
}