#![allow(clippy::missing_panics_doc)]
use std::{fmt::Debug, net::SocketAddr, str::FromStr, sync::Arc};
use quinn::{
CertificateChain, ClientConfig, ClientConfigBuilder, ServerConfigBuilder, TransportConfig,
};
use rustls::RootCertStore;
use crate::{Certificate, Dangerous, Endpoint, Error, PrivateKey, Result};
/// Collects the local address, the client configuration and an optional server
/// configuration, and turns them into an [`Endpoint`] via [`Builder::build`].
pub struct Builder {
    /// Local socket address handed to [`Endpoint::new`].
    address: SocketAddr,
    /// Client-side configuration; always present.
    client: ClientConfigBuilder,
    /// Server-side configuration; `None` until a key pair or protocol list is set.
    server: Option<ServerConfigBuilder>,
}
impl Debug for Builder {
    /// Formats the builder, printing fixed placeholder strings in place of the
    /// quinn config builders.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let server = match &self.server {
            Some(_) => "Some(ServerConfigBuilder)",
            None => "None",
        };

        f.debug_struct("Builder")
            .field("address", &self.address)
            .field("client", &"ClientConfigBuilder")
            .field("server", &server)
            .finish()
    }
}
impl Default for Builder {
fn default() -> Self {
Self::new()
}
}
impl Builder {
#[must_use]
pub fn new() -> Self {
let mut client = ClientConfig::default();
#[allow(clippy::expect_used)]
let crypto = Arc::get_mut(&mut client.crypto).expect("failed to build `ClientConfig`");
crypto.root_store = RootCertStore::empty();
crypto.ct_logs = None;
Self {
#[cfg(not(feature = "test"))]
address: ([0; 8], 0).into(),
#[cfg(feature = "test")]
address: ([0, 0, 0, 0, 0, 0, 0, 1], 0).into(),
client: ClientConfigBuilder::new(client),
server: None,
}
}
pub fn set_address(&mut self, address: SocketAddr) -> &mut Self {
self.address = address;
self
}
pub fn set_address_str(&mut self, address: &str) -> Result<&mut Self> {
self.address = FromStr::from_str(address).map_err(Error::ParseAddress)?;
Ok(self)
}
pub fn add_ca(&mut self, certificate: &Certificate) -> Result<&mut Self> {
let certificate =
quinn::Certificate::from_der(certificate.as_ref()).map_err(Error::Certificate)?;
let _ = self
.client
.add_certificate_authority(certificate)
.map_err(Error::InvalidCertificate)?;
Ok(self)
}
pub fn add_key_pair(
&mut self,
certificate: &Certificate,
private_key: &PrivateKey,
) -> Result<&mut Self> {
let certificate =
quinn::Certificate::from_der(certificate.as_ref()).map_err(Error::Certificate)?;
let chain = CertificateChain::from_certs(Some(certificate));
let private_key = quinn::PrivateKey::from_der(Dangerous::as_ref(private_key))
.map_err(Error::PrivateKey)?;
let _ = self
.server
.get_or_insert(ServerConfigBuilder::default())
.certificate(chain, private_key)
.map_err(Error::InvalidKeyPair)?;
Ok(self)
}
pub fn set_protocols(&mut self, protocols: &[&[u8]]) -> &mut Self {
let _ = self
.server
.get_or_insert(ServerConfigBuilder::default())
.protocols(protocols);
self
}
#[allow(clippy::unwrap_in_result)]
pub fn build(self) -> Result<Endpoint, (Error, Self)> {
let mut transport = TransportConfig::default();
#[allow(clippy::expect_used)]
let _ = transport
.allow_spin(false)
.datagram_receive_buffer_size(None)
.max_concurrent_uni_streams(0)
.expect("can't be bigger then `VarInt`");
let transport = Arc::new(transport);
match {
let mut client = self.client.clone().build();
client.transport = Arc::clone(&transport);
let server = self.server.as_ref().map(|server| {
let mut server = server.clone().build();
server.transport = transport;
server
});
Endpoint::new(self.address, client, server)
} {
Ok(endpoint) => Ok(endpoint),
Err(error) => Err((error, self)),
}
}
}
#[cfg(test)]
mod test {
    use anyhow::Result;

    use super::*;

    /// Builds `builder`, dropping the returned `Builder` on failure so the error
    /// can be propagated with `?`.
    fn finish(builder: Builder) -> Result<Endpoint> {
        Ok(builder.build().map_err(|(error, _)| error)?)
    }

    #[tokio::test]
    async fn default() -> Result<()> {
        let _endpoint = finish(Builder::default())?;
        Ok(())
    }

    #[tokio::test]
    async fn new() -> Result<()> {
        let _endpoint = finish(Builder::new())?;
        Ok(())
    }

    #[tokio::test]
    async fn address() -> Result<()> {
        let mut builder = Builder::new();
        let _ = builder.set_address(([0, 0, 0, 0, 0, 0, 0, 1], 5000).into());
        let endpoint = finish(builder)?;

        let expected: SocketAddr = "[::1]:5000".parse()?;
        assert_eq!(expected, endpoint.local_address()?);
        Ok(())
    }

    #[tokio::test]
    async fn address_str() -> Result<()> {
        let mut builder = Builder::new();
        let _ = builder.set_address_str("[::1]:5001")?;
        let endpoint = finish(builder)?;

        let expected: SocketAddr = "[::1]:5001".parse()?;
        assert_eq!(expected, endpoint.local_address()?);
        Ok(())
    }

    #[tokio::test]
    async fn ca_key_pair() -> Result<()> {
        use futures_util::StreamExt;

        let (certificate, private_key) = crate::generate_self_signed("test");

        // Client trusts the self-signed certificate as a CA.
        let mut client_builder = Builder::new();
        let _ = client_builder.add_ca(&certificate)?;
        let client = finish(client_builder)?;

        // Server presents the same certificate with its private key.
        let mut server_builder = Builder::new();
        let _ = server_builder.add_key_pair(&certificate, &private_key)?;
        let mut server = finish(server_builder)?;

        let server_address = server.local_address()?;
        let _client_connection = client
            .connect(server_address, "test")?
            .accept::<()>()
            .await?;
        let _server_connection = server
            .next()
            .await
            .expect("client dropped")
            .accept::<()>()
            .await?;
        Ok(())
    }
}