use super::{Connection, Transport, TransportError};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
/// Settings consumed by [`QuicTransport::with_config`] to build a transport.
pub struct QuicTransportConfig {
/// Upper bound applied to each connection attempt made by the transport.
pub connect_timeout: Duration,
/// Local socket address the client endpoint is bound to.
pub bind_addr: SocketAddr,
/// Server name handed to the endpoint on every connection attempt.
pub server_name: String,
/// QUIC client configuration installed as the endpoint's default.
pub client_config: quinn::ClientConfig,
}
/// QUIC-backed [`Transport`] that drives async quinn operations from
/// synchronous callers by blocking on a captured Tokio runtime handle.
pub struct QuicTransport {
/// Upper bound applied to establishing a connection (not to the data
/// exchange that follows).
pub connect_timeout: Duration,
// Client endpoint shared by every outgoing connection.
endpoint: quinn::Endpoint,
// Handle of the runtime that was current when the transport was built;
// every blocking call is executed on it.
runtime: tokio::runtime::Handle,
// Server name passed to `Endpoint::connect` for each connection.
server_name: String,
}
impl QuicTransport {
pub fn with_config(config: QuicTransportConfig) -> Result<Self, TransportError> {
let runtime = tokio::runtime::Handle::try_current()
.map_err(|e| TransportError::ConnectionFailed(format!("no tokio runtime: {}", e)))?;
let mut endpoint = quinn::Endpoint::client(config.bind_addr)
.map_err(|e| TransportError::ConnectionFailed(format!("bind endpoint: {}", e)))?;
endpoint.set_default_client_config(config.client_config);
Ok(Self {
connect_timeout: config.connect_timeout,
endpoint,
runtime,
server_name: config.server_name,
})
}
pub fn with_client_config(
client_config: quinn::ClientConfig,
server_name: impl Into<String>,
) -> Result<Self, TransportError> {
let config = QuicTransportConfig {
connect_timeout: Duration::from_secs(10),
bind_addr: "0.0.0.0:0"
.parse()
.map_err(|e| TransportError::ConnectionFailed(format!("bind addr parse: {}", e)))?,
server_name: server_name.into(),
client_config,
};
Self::with_config(config)
}
pub fn with_trust_anchors(
root_certs_der: Vec<Vec<u8>>,
server_name: impl Into<String>,
) -> Result<Self, TransportError> {
let mut roots = rustls::RootCertStore::empty();
for der in root_certs_der {
roots
.add(rustls::pki_types::CertificateDer::from(der))
.map_err(|e| TransportError::ConnectionFailed(format!("add root cert: {}", e)))?;
}
let mut client_crypto = rustls::ClientConfig::builder()
.with_root_certificates(roots)
.with_no_client_auth();
client_crypto.alpn_protocols = vec![b"shape/1".to_vec()];
let client_config = quinn::ClientConfig::new(Arc::new(
quinn::crypto::rustls::QuicClientConfig::try_from(client_crypto).map_err(|e| {
TransportError::ConnectionFailed(format!("QUIC client config: {}", e))
})?,
));
Self::with_client_config(client_config, server_name)
}
pub fn new_self_signed() -> Result<Self, TransportError> {
let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
.map_err(|e| TransportError::ConnectionFailed(format!("cert generation: {}", e)))?;
let cert_der = cert.cert.der().to_vec();
Self::with_trust_anchors(vec![cert_der], "localhost")
}
}
impl Transport for QuicTransport {
fn send(&self, destination: &str, payload: &[u8]) -> Result<Vec<u8>, TransportError> {
let dest_addr = destination
.parse()
.map_err(|e| TransportError::ConnectionFailed(format!("{}", e)))?;
let endpoint = self.endpoint.clone();
let payload = payload.to_vec();
let timeout = self.connect_timeout;
let server_name = self.server_name.clone();
self.runtime.block_on(async move {
let connecting = endpoint
.connect(dest_addr, &server_name)
.map_err(|e| TransportError::ConnectionFailed(format!("{}", e)))?;
let connection = tokio::time::timeout(timeout, connecting)
.await
.map_err(|_| TransportError::Timeout)?
.map_err(|e| TransportError::ConnectionFailed(format!("{}", e)))?;
let (mut send_stream, mut recv_stream) = connection
.open_bi()
.await
.map_err(|e| TransportError::SendFailed(format!("open stream: {}", e)))?;
let framed = super::framing::encode_framed(&payload);
let len = (framed.len() as u32).to_be_bytes();
send_stream
.write_all(&len)
.await
.map_err(|e| TransportError::SendFailed(format!("{}", e)))?;
send_stream
.write_all(&framed)
.await
.map_err(|e| TransportError::SendFailed(format!("{}", e)))?;
send_stream
.finish()
.map_err(|e| TransportError::SendFailed(format!("finish: {}", e)))?;
let mut len_buf = [0u8; 4];
recv_stream
.read_exact(&mut len_buf)
.await
.map_err(|e| TransportError::ReceiveFailed(format!("{}", e)))?;
let resp_len = u32::from_be_bytes(len_buf) as usize;
if resp_len > super::tcp::MAX_PAYLOAD_SIZE {
return Err(TransportError::PayloadTooLarge {
size: resp_len,
max: super::tcp::MAX_PAYLOAD_SIZE,
});
}
let mut buf = vec![0u8; resp_len];
recv_stream
.read_exact(&mut buf)
.await
.map_err(|e| TransportError::ReceiveFailed(format!("{}", e)))?;
super::framing::decode_framed(&buf)
})
}
fn connect(&self, destination: &str) -> Result<Box<dyn Connection>, TransportError> {
let dest_addr = destination
.parse()
.map_err(|e| TransportError::ConnectionFailed(format!("{}", e)))?;
let endpoint = self.endpoint.clone();
let timeout = self.connect_timeout;
let runtime = self.runtime.clone();
let server_name = self.server_name.clone();
let connection = self.runtime.block_on(async move {
let connecting = endpoint
.connect(dest_addr, &server_name)
.map_err(|e| TransportError::ConnectionFailed(format!("{}", e)))?;
let conn = tokio::time::timeout(timeout, connecting)
.await
.map_err(|_| TransportError::Timeout)?
.map_err(|e| TransportError::ConnectionFailed(format!("{}", e)))?;
Ok::<_, TransportError>(conn)
})?;
Ok(Box::new(QuicConnection {
connection,
runtime,
}))
}
}
/// Persistent QUIC connection returned by `QuicTransport::connect`; each
/// message travels on its own unidirectional stream.
pub struct QuicConnection {
// Established quinn connection used to open and accept streams.
connection: quinn::Connection,
// Runtime handle on which the blocking send/recv calls are executed.
runtime: tokio::runtime::Handle,
}
impl Connection for QuicConnection {
fn send(&mut self, payload: &[u8]) -> Result<(), TransportError> {
let framed = super::framing::encode_framed(payload);
self.runtime.block_on(async {
let mut send_stream = self
.connection
.open_uni()
.await
.map_err(|e| TransportError::SendFailed(format!("open stream: {}", e)))?;
let len = (framed.len() as u32).to_be_bytes();
send_stream
.write_all(&len)
.await
.map_err(|e| TransportError::SendFailed(format!("{}", e)))?;
send_stream
.write_all(&framed)
.await
.map_err(|e| TransportError::SendFailed(format!("{}", e)))?;
send_stream
.finish()
.map_err(|e| TransportError::SendFailed(format!("finish: {}", e)))?;
Ok(())
})
}
fn recv(&mut self, timeout: Option<Duration>) -> Result<Vec<u8>, TransportError> {
self.runtime.block_on(async {
let accept_fut = self.connection.accept_uni();
let mut recv_stream = if let Some(t) = timeout {
tokio::time::timeout(t, accept_fut)
.await
.map_err(|_| TransportError::Timeout)?
} else {
accept_fut.await
}
.map_err(|e| TransportError::ReceiveFailed(format!("accept stream: {}", e)))?;
let mut len_buf = [0u8; 4];
recv_stream
.read_exact(&mut len_buf)
.await
.map_err(|e| TransportError::ReceiveFailed(format!("{}", e)))?;
let len = u32::from_be_bytes(len_buf) as usize;
if len > super::tcp::MAX_PAYLOAD_SIZE {
return Err(TransportError::PayloadTooLarge {
size: len,
max: super::tcp::MAX_PAYLOAD_SIZE,
});
}
let mut buf = vec![0u8; len];
recv_stream
.read_exact(&mut buf)
.await
.map_err(|e| TransportError::ReceiveFailed(format!("{}", e)))?;
super::framing::decode_framed(&buf)
})
}
fn close(&mut self) -> Result<(), TransportError> {
self.connection.close(0u32.into(), b"closed");
Ok(())
}
fn supports_sidecars(&self) -> bool {
true
}
fn send_sidecar(&mut self, payload: &[u8]) -> Result<(), TransportError> {
self.send(payload)
}
fn recv_any(&mut self, timeout: Option<Duration>) -> Result<Vec<u8>, TransportError> {
self.recv(timeout)
}
}