ombrac_transport/quic/
mod.rs1mod stream;
2
3pub mod client;
4pub mod error;
5pub mod server;
6
7use std::path::Path;
8use std::str::FromStr;
9use std::sync::Arc;
10use std::time::Duration;
11use std::{fs, io};
12
13use quinn::{IdleTimeout, VarInt};
14use rustls::pki_types::{CertificateDer, PrivateKeyDer};
15use serde::{Deserialize, Serialize};
16
17type Result<T> = std::result::Result<T, error::Error>;
18
19pub use quinn::Connection;
20
21#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
22pub enum Congestion {
23 Bbr,
24 Cubic,
25 NewReno,
26}
27
28impl FromStr for Congestion {
29 type Err = error::Error;
30 fn from_str(value: &str) -> Result<Self> {
31 match value.to_lowercase().as_str() {
32 "bbr" => Ok(Congestion::Bbr),
33 "cubic" => Ok(Congestion::Cubic),
34 "newreno" => Ok(Congestion::NewReno),
35 _ => Err(Self::Err::InvalidCongestion),
36 }
37 }
38}
39
40#[derive(Debug, Default)]
41pub struct TransportConfig(pub(crate) quinn::TransportConfig);
42
43impl TransportConfig {
44 pub fn congestion(
45 &mut self,
46 congestion: Congestion,
47 initial_window: Option<u64>,
48 ) -> Result<&mut Self> {
49 use quinn::congestion;
50
51 let congestion: Arc<dyn congestion::ControllerFactory + Send + Sync + 'static> =
52 match congestion {
53 Congestion::Bbr => {
54 let mut config = congestion::BbrConfig::default();
55 if let Some(value) = initial_window {
56 config.initial_window(value);
57 }
58 Arc::new(config)
59 }
60 Congestion::Cubic => {
61 let mut config = congestion::CubicConfig::default();
62 if let Some(value) = initial_window {
63 config.initial_window(value);
64 }
65 Arc::new(config)
66 }
67 Congestion::NewReno => {
68 let mut config = congestion::NewRenoConfig::default();
69 if let Some(value) = initial_window {
70 config.initial_window(value);
71 }
72 Arc::new(config)
73 }
74 };
75
76 self.0.congestion_controller_factory(congestion);
77 Ok(self)
78 }
79
80 pub fn max_idle_timeout(&mut self, value: Duration) -> Result<&mut Self> {
81 self.0.max_idle_timeout(Some(IdleTimeout::try_from(value)?));
82 Ok(self)
83 }
84
85 pub fn keep_alive_period(&mut self, value: Duration) -> Result<&mut Self> {
86 self.0.keep_alive_interval(Some(value));
87 Ok(self)
88 }
89
90 pub fn max_open_bidirectional_streams(&mut self, value: u64) -> Result<&mut Self> {
91 self.0.max_concurrent_bidi_streams(VarInt::try_from(value)?);
92 Ok(self)
93 }
94}
95
96fn load_certificates(path: &Path) -> io::Result<Vec<CertificateDer<'static>>> {
97 let content = fs::read(path)?;
98 let certs = if path.extension().is_some_and(|ext| ext == "der") {
99 vec![CertificateDer::from(content)]
100 } else {
101 rustls_pemfile::certs(&mut &*content).collect::<io::Result<Vec<_>>>()?
102 };
103 Ok(certs)
104}
105
106fn load_private_key(path: &Path) -> io::Result<PrivateKeyDer<'static>> {
107 let content = fs::read(path)?;
108 let key = if path.extension().is_some_and(|ext| ext == "der") {
109 PrivateKeyDer::Pkcs8(content.into())
110 } else {
111 rustls_pemfile::private_key(&mut &*content)?.ok_or_else(|| {
112 io::Error::new(io::ErrorKind::NotFound, "no private key found in PEM file")
113 })?
114 };
115 Ok(key)
116}
117
118#[derive(Debug)]
119pub enum ConnectionError {
120 QuinnConnection(quinn::ConnectionError),
121 QuinnSendDatagram(quinn::SendDatagramError),
122}
123
124impl From<quinn::ConnectionError> for ConnectionError {
125 fn from(e: quinn::ConnectionError) -> Self {
126 ConnectionError::QuinnConnection(e)
127 }
128}
129
130impl From<quinn::SendDatagramError> for ConnectionError {
131 fn from(e: quinn::SendDatagramError) -> Self {
132 ConnectionError::QuinnSendDatagram(e)
133 }
134}
135
136impl From<ConnectionError> for io::Error {
137 fn from(e: ConnectionError) -> Self {
138 match e {
139 ConnectionError::QuinnConnection(error) => {
140 use quinn::ConnectionError::*;
141 let kind = match error {
142 LocallyClosed | ConnectionClosed(_) | ApplicationClosed(_) | Reset => {
143 io::ErrorKind::ConnectionReset
144 }
145 TimedOut => io::ErrorKind::TimedOut,
146 _ => io::ErrorKind::Other,
147 };
148 io::Error::new(kind, error)
149 }
150 ConnectionError::QuinnSendDatagram(error) => {
151 use quinn::SendDatagramError::*;
152 let kind = match error {
153 ConnectionLost(_) => io::ErrorKind::ConnectionReset,
154 _ => io::ErrorKind::Other,
155 };
156 io::Error::new(kind, error)
157 }
158 }
159 }
160}
161
162impl crate::Connection for quinn::Connection {
163 type Stream = stream::Stream;
164
165 async fn accept_bidirectional(&self) -> io::Result<Self::Stream> {
166 let (send, recv) = quinn::Connection::accept_bi(self)
167 .await
168 .map_err(ConnectionError::from)?;
169 Ok(stream::Stream(send, recv))
170 }
171
172 async fn open_bidirectional(&self) -> io::Result<Self::Stream> {
173 let (send, recv) = quinn::Connection::open_bi(self)
174 .await
175 .map_err(ConnectionError::from)?;
176 Ok(stream::Stream(send, recv))
177 }
178
179 #[cfg(feature = "datagram")]
180 async fn read_datagram(&self) -> io::Result<bytes::Bytes> {
181 quinn::Connection::read_datagram(self)
182 .await
183 .map_err(|e| ConnectionError::from(e).into())
184 }
185
186 #[cfg(feature = "datagram")]
187 async fn send_datagram(&self, data: bytes::Bytes) -> io::Result<()> {
188 quinn::Connection::send_datagram_wait(self, data)
189 .await
190 .map_err(|e| ConnectionError::from(e).into())
191 }
192
193 fn remote_address(&self) -> io::Result<std::net::SocketAddr> {
194 Ok(quinn::Connection::remote_address(self))
195 }
196
197 #[cfg(feature = "datagram")]
198 fn max_datagram_size(&self) -> Option<usize> {
199 quinn::Connection::max_datagram_size(self)
200 }
201
202 fn id(&self) -> usize {
203 quinn::Connection::stable_id(self)
204 }
205
206 fn close(&self, error_code: u32, reason: &[u8]) {
207 self.close(error_code.into(), reason);
208 }
209}