Skip to main content

ombrac_transport/quic/
mod.rs

1mod stream;
2
3pub mod client;
4pub mod error;
5pub mod server;
6
7use std::path::Path;
8use std::str::FromStr;
9use std::sync::Arc;
10use std::time::Duration;
11use std::{fs, io};
12
13use quinn::{IdleTimeout, VarInt};
14use rustls::pki_types::{CertificateDer, PrivateKeyDer};
15use serde::{Deserialize, Serialize};
16
17type Result<T> = std::result::Result<T, error::Error>;
18
19pub use quinn::Connection;
20
21#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
22pub enum Congestion {
23    Bbr,
24    Cubic,
25    NewReno,
26}
27
28impl FromStr for Congestion {
29    type Err = error::Error;
30    fn from_str(value: &str) -> Result<Self> {
31        match value.to_lowercase().as_str() {
32            "bbr" => Ok(Congestion::Bbr),
33            "cubic" => Ok(Congestion::Cubic),
34            "newreno" => Ok(Congestion::NewReno),
35            _ => Err(Self::Err::InvalidCongestion),
36        }
37    }
38}
39
40#[derive(Debug, Default)]
41pub struct TransportConfig(pub(crate) quinn::TransportConfig);
42
43impl TransportConfig {
44    pub fn congestion(
45        &mut self,
46        congestion: Congestion,
47        initial_window: Option<u64>,
48    ) -> Result<&mut Self> {
49        use quinn::congestion;
50
51        let congestion: Arc<dyn congestion::ControllerFactory + Send + Sync + 'static> =
52            match congestion {
53                Congestion::Bbr => {
54                    let mut config = congestion::BbrConfig::default();
55                    if let Some(value) = initial_window {
56                        config.initial_window(value);
57                    }
58                    Arc::new(config)
59                }
60                Congestion::Cubic => {
61                    let mut config = congestion::CubicConfig::default();
62                    if let Some(value) = initial_window {
63                        config.initial_window(value);
64                    }
65                    Arc::new(config)
66                }
67                Congestion::NewReno => {
68                    let mut config = congestion::NewRenoConfig::default();
69                    if let Some(value) = initial_window {
70                        config.initial_window(value);
71                    }
72                    Arc::new(config)
73                }
74            };
75
76        self.0.congestion_controller_factory(congestion);
77        Ok(self)
78    }
79
80    pub fn max_idle_timeout(&mut self, value: Duration) -> Result<&mut Self> {
81        self.0.max_idle_timeout(Some(IdleTimeout::try_from(value)?));
82        Ok(self)
83    }
84
85    pub fn keep_alive_period(&mut self, value: Duration) -> Result<&mut Self> {
86        self.0.keep_alive_interval(Some(value));
87        Ok(self)
88    }
89
90    pub fn max_open_bidirectional_streams(&mut self, value: u64) -> Result<&mut Self> {
91        self.0.max_concurrent_bidi_streams(VarInt::try_from(value)?);
92        Ok(self)
93    }
94}
95
96fn load_certificates(path: &Path) -> io::Result<Vec<CertificateDer<'static>>> {
97    let content = fs::read(path)?;
98    let certs = if path.extension().is_some_and(|ext| ext == "der") {
99        vec![CertificateDer::from(content)]
100    } else {
101        rustls_pemfile::certs(&mut &*content).collect::<io::Result<Vec<_>>>()?
102    };
103    Ok(certs)
104}
105
106fn load_private_key(path: &Path) -> io::Result<PrivateKeyDer<'static>> {
107    let content = fs::read(path)?;
108    let key = if path.extension().is_some_and(|ext| ext == "der") {
109        PrivateKeyDer::Pkcs8(content.into())
110    } else {
111        rustls_pemfile::private_key(&mut &*content)?.ok_or_else(|| {
112            io::Error::new(io::ErrorKind::NotFound, "no private key found in PEM file")
113        })?
114    };
115    Ok(key)
116}
117
118#[derive(Debug)]
119pub enum ConnectionError {
120    QuinnConnection(quinn::ConnectionError),
121    QuinnSendDatagram(quinn::SendDatagramError),
122}
123
124impl From<quinn::ConnectionError> for ConnectionError {
125    fn from(e: quinn::ConnectionError) -> Self {
126        ConnectionError::QuinnConnection(e)
127    }
128}
129
130impl From<quinn::SendDatagramError> for ConnectionError {
131    fn from(e: quinn::SendDatagramError) -> Self {
132        ConnectionError::QuinnSendDatagram(e)
133    }
134}
135
136impl From<ConnectionError> for io::Error {
137    fn from(e: ConnectionError) -> Self {
138        match e {
139            ConnectionError::QuinnConnection(error) => {
140                use quinn::ConnectionError::*;
141                let kind = match error {
142                    LocallyClosed | ConnectionClosed(_) | ApplicationClosed(_) | Reset => {
143                        io::ErrorKind::ConnectionReset
144                    }
145                    TimedOut => io::ErrorKind::TimedOut,
146                    _ => io::ErrorKind::Other,
147                };
148                io::Error::new(kind, error)
149            }
150            ConnectionError::QuinnSendDatagram(error) => {
151                use quinn::SendDatagramError::*;
152                let kind = match error {
153                    ConnectionLost(_) => io::ErrorKind::ConnectionReset,
154                    _ => io::ErrorKind::Other,
155                };
156                io::Error::new(kind, error)
157            }
158        }
159    }
160}
161
162impl crate::Connection for quinn::Connection {
163    type Stream = stream::Stream;
164
165    async fn accept_bidirectional(&self) -> io::Result<Self::Stream> {
166        let (send, recv) = quinn::Connection::accept_bi(self)
167            .await
168            .map_err(ConnectionError::from)?;
169        Ok(stream::Stream(send, recv))
170    }
171
172    async fn open_bidirectional(&self) -> io::Result<Self::Stream> {
173        let (send, recv) = quinn::Connection::open_bi(self)
174            .await
175            .map_err(ConnectionError::from)?;
176        Ok(stream::Stream(send, recv))
177    }
178
179    #[cfg(feature = "datagram")]
180    async fn read_datagram(&self) -> io::Result<bytes::Bytes> {
181        quinn::Connection::read_datagram(self)
182            .await
183            .map_err(|e| ConnectionError::from(e).into())
184    }
185
186    #[cfg(feature = "datagram")]
187    async fn send_datagram(&self, data: bytes::Bytes) -> io::Result<()> {
188        quinn::Connection::send_datagram_wait(self, data)
189            .await
190            .map_err(|e| ConnectionError::from(e).into())
191    }
192
193    fn remote_address(&self) -> io::Result<std::net::SocketAddr> {
194        Ok(quinn::Connection::remote_address(self))
195    }
196
197    #[cfg(feature = "datagram")]
198    fn max_datagram_size(&self) -> Option<usize> {
199        quinn::Connection::max_datagram_size(self)
200    }
201
202    fn id(&self) -> usize {
203        quinn::Connection::stable_id(self)
204    }
205
206    fn close(&self, error_code: u32, reason: &[u8]) {
207        self.close(error_code.into(), reason);
208    }
209}