1use std::net::SocketAddr;
6use std::sync::Arc;
7
8use quinn::{Endpoint, Incoming, RecvStream, SendStream, ServerConfig, TransportConfig};
9use thiserror::Error;
10use tracing::{debug, error, info, instrument, warn};
11
12use crate::frame::{Frame, FrameError, FramedStream, read_frame, write_frame};
13
/// Errors produced by the QUIC server, its connections and its streams.
#[derive(Debug, Error)]
pub enum ServerError {
    /// An I/O error.
    ///
    /// Because of `#[from]`, *every* `std::io::Error` propagated with `?` in this
    /// module lands here (e.g. from `Endpoint::server` and `Endpoint::local_addr`),
    /// not only socket-bind failures.
    #[error("bind error: {0}")]
    Bind(#[from] std::io::Error),

    /// A QUIC connection-level error, e.g. from accepting or opening streams.
    #[error("connection error: {0}")]
    Connection(#[from] quinn::ConnectionError),

    /// A framing error, including stream read/write failures wrapped as
    /// `FrameError::Io` by `StreamHandler`.
    #[error("frame error: {0}")]
    Frame(#[from] FrameError),

    /// TLS material could not be generated, parsed, or accepted.
    #[error("TLS error: {0}")]
    Tls(String),

    /// The server was closed.
    /// NOTE(review): not constructed in this module; presumably used by callers.
    #[error("server closed")]
    Closed,
}
32
/// Configuration for a `RuntaraServer`.
///
/// `Debug` is implemented by hand so that the private key (and the bulky
/// certificate chain) never end up in logs or traces; only their byte lengths
/// are printed.
#[derive(Clone)]
pub struct RuntaraServerConfig {
    /// Address the QUIC endpoint binds to.
    pub bind_addr: SocketAddr,
    /// PEM-encoded certificate chain presented to clients.
    pub cert_pem: Vec<u8>,
    /// PEM-encoded private key for the certificate chain in `cert_pem`.
    pub key_pem: Vec<u8>,
    /// Upper bound on the number of connections the server should handle concurrently.
    pub max_connections: u32,
    /// Maximum number of concurrent bidirectional streams per connection.
    pub max_bi_streams: u32,
    /// Maximum number of concurrent unidirectional streams per connection.
    pub max_uni_streams: u32,
    /// QUIC idle timeout, in milliseconds.
    pub idle_timeout_ms: u64,
}

impl std::fmt::Debug for RuntaraServerConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RuntaraServerConfig")
            .field("bind_addr", &self.bind_addr)
            .field("cert_pem", &format_args!("<{} bytes>", self.cert_pem.len()))
            // Never print private key material.
            .field(
                "key_pem",
                &format_args!("<redacted, {} bytes>", self.key_pem.len()),
            )
            .field("max_connections", &self.max_connections)
            .field("max_bi_streams", &self.max_bi_streams)
            .field("max_uni_streams", &self.max_uni_streams)
            .field("idle_timeout_ms", &self.idle_timeout_ms)
            .finish()
    }
}

impl Default for RuntaraServerConfig {
    /// Listens on all IPv4 interfaces, port 7001, with no TLS material set.
    fn default() -> Self {
        Self {
            // Infallible construction; no runtime parse/unwrap needed.
            bind_addr: SocketAddr::from(([0, 0, 0, 0], 7001)),
            cert_pem: Vec::new(),
            key_pem: Vec::new(),
            max_connections: 10_000,
            max_bi_streams: 100,
            max_uni_streams: 100,
            idle_timeout_ms: 30_000,
        }
    }
}
65
66pub struct RuntaraServer {
68 endpoint: Endpoint,
69}
70
71impl RuntaraServer {
72 pub fn new(config: RuntaraServerConfig) -> Result<Self, ServerError> {
74 let server_config = Self::build_server_config(&config)?;
75 let endpoint = Endpoint::server(server_config, config.bind_addr)?;
76
77 info!(addr = %config.bind_addr, "QUIC server bound");
78
79 Ok(Self { endpoint })
80 }
81
82 pub fn localhost(bind_addr: SocketAddr) -> Result<Self, ServerError> {
84 let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
85 .map_err(|e| ServerError::Tls(e.to_string()))?;
86
87 let cert_pem = cert.cert.pem().into_bytes();
88 let key_pem = cert.key_pair.serialize_pem().into_bytes();
89
90 let config = RuntaraServerConfig {
91 bind_addr,
92 cert_pem,
93 key_pem,
94 ..Default::default()
95 };
96
97 Self::new(config)
98 }
99
100 fn build_server_config(config: &RuntaraServerConfig) -> Result<ServerConfig, ServerError> {
101 let certs = rustls_pemfile::certs(&mut config.cert_pem.as_slice())
102 .collect::<Result<Vec<_>, _>>()
103 .map_err(|e| ServerError::Tls(format!("failed to parse certificates: {}", e)))?;
104
105 let key = rustls_pemfile::private_key(&mut config.key_pem.as_slice())
106 .map_err(|e| ServerError::Tls(format!("failed to parse private key: {}", e)))?
107 .ok_or_else(|| ServerError::Tls("no private key found".to_string()))?;
108
109 let crypto = rustls::ServerConfig::builder()
110 .with_no_client_auth()
111 .with_single_cert(certs, key)
112 .map_err(|e| ServerError::Tls(e.to_string()))?;
113
114 let mut transport = TransportConfig::default();
115 transport.max_idle_timeout(Some(
116 std::time::Duration::from_millis(config.idle_timeout_ms)
117 .try_into()
118 .unwrap(),
119 ));
120 transport.max_concurrent_bidi_streams(config.max_bi_streams.into());
121 transport.max_concurrent_uni_streams(config.max_uni_streams.into());
122
123 let mut server_config = ServerConfig::with_crypto(Arc::new(
124 quinn::crypto::rustls::QuicServerConfig::try_from(crypto)
125 .map_err(|e| ServerError::Tls(e.to_string()))?,
126 ));
127 server_config.transport_config(Arc::new(transport));
128
129 Ok(server_config)
130 }
131
132 pub async fn accept(&self) -> Option<Incoming> {
134 self.endpoint.accept().await
135 }
136
137 pub fn local_addr(&self) -> Result<SocketAddr, ServerError> {
139 Ok(self.endpoint.local_addr()?)
140 }
141
142 pub fn close(&self) {
144 self.endpoint.close(0u32.into(), b"server closing");
145 }
146
147 #[instrument(skip(self, handler))]
149 pub async fn run<H, Fut>(&self, handler: H) -> Result<(), ServerError>
150 where
151 H: Fn(ConnectionHandler) -> Fut + Send + Sync + Clone + 'static,
152 Fut: std::future::Future<Output = ()> + Send + 'static,
153 {
154 info!("QUIC server running");
155
156 while let Some(incoming) = self.accept().await {
157 let handler = handler.clone();
158
159 tokio::spawn(async move {
160 match incoming.await {
161 Ok(connection) => {
162 let remote_addr = connection.remote_address();
163 debug!(%remote_addr, "accepted connection");
164
165 let conn_handler = ConnectionHandler::new(connection);
166 handler(conn_handler).await;
167 }
168 Err(e) => {
169 warn!("failed to accept connection: {}", e);
170 }
171 }
172 });
173 }
174
175 Ok(())
176 }
177}
178
179pub struct ConnectionHandler {
181 connection: quinn::Connection,
182}
183
184impl ConnectionHandler {
185 pub fn new(connection: quinn::Connection) -> Self {
186 Self { connection }
187 }
188
189 pub fn remote_address(&self) -> SocketAddr {
191 self.connection.remote_address()
192 }
193
194 pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ServerError> {
196 Ok(self.connection.accept_bi().await?)
197 }
198
199 pub async fn accept_uni(&self) -> Result<RecvStream, ServerError> {
201 Ok(self.connection.accept_uni().await?)
202 }
203
204 pub async fn open_bi(&self) -> Result<(SendStream, RecvStream), ServerError> {
206 Ok(self.connection.open_bi().await?)
207 }
208
209 pub async fn open_uni(&self) -> Result<SendStream, ServerError> {
211 Ok(self.connection.open_uni().await?)
212 }
213
214 #[instrument(skip(self, handler), fields(remote = %self.remote_address()))]
216 pub async fn run<H, Fut>(&self, handler: H)
217 where
218 H: Fn(StreamHandler) -> Fut + Send + Sync + Clone + 'static,
219 Fut: std::future::Future<Output = ()> + Send + 'static,
220 {
221 loop {
222 tokio::select! {
223 result = self.accept_bi() => {
224 match result {
225 Ok((send, recv)) => {
226 let handler = handler.clone();
227 tokio::spawn(async move {
228 let stream_handler = StreamHandler::new(send, recv);
229 handler(stream_handler).await;
230 });
231 }
232 Err(e) => {
233 match &e {
234 ServerError::Connection(quinn::ConnectionError::ApplicationClosed(_)) |
235 ServerError::Connection(quinn::ConnectionError::LocallyClosed) => {
236 debug!("connection closed");
237 }
238 _ => {
239 error!("error accepting stream: {}", e);
240 }
241 }
242 break;
243 }
244 }
245 }
246 }
247 }
248 }
249
250 pub fn is_open(&self) -> bool {
252 self.connection.close_reason().is_none()
253 }
254
255 pub fn close(&self, code: u32, reason: &[u8]) {
257 self.connection.close(code.into(), reason);
258 }
259}
260
261pub struct StreamHandler {
263 send: SendStream,
264 recv: RecvStream,
265}
266
267impl StreamHandler {
268 pub fn new(send: SendStream, recv: RecvStream) -> Self {
269 Self { send, recv }
270 }
271
272 pub async fn read_frame(&mut self) -> Result<Frame, ServerError> {
274 Ok(read_frame(&mut self.recv).await?)
275 }
276
277 pub async fn write_frame(&mut self, frame: &Frame) -> Result<(), ServerError> {
279 Ok(write_frame(&mut self.send, frame).await?)
280 }
281
282 pub async fn handle_request<Req, Resp, H, Fut>(&mut self, handler: H) -> Result<(), ServerError>
284 where
285 Req: prost::Message + Default,
286 Resp: prost::Message,
287 H: FnOnce(Req) -> Fut,
288 Fut: std::future::Future<Output = Result<Resp, ServerError>>,
289 {
290 let request_frame = self.read_frame().await?;
292 let request: Req = request_frame.decode()?;
293
294 match handler(request).await {
296 Ok(response) => {
297 let response_frame = Frame::response(&response)?;
298 self.write_frame(&response_frame).await?;
299 }
300 Err(e) => {
301 error!("request handler error: {}", e);
302 let error_frame = Frame {
305 message_type: crate::frame::MessageType::Error,
306 payload: bytes::Bytes::new(),
307 };
308 self.write_frame(&error_frame).await?;
309 }
310 }
311
312 Ok(())
313 }
314
315 pub fn into_framed(self) -> FramedStream<(SendStream, RecvStream)> {
317 FramedStream::new((self.send, self.recv))
318 }
319
320 pub fn finish(&mut self) -> Result<(), ServerError> {
322 self.send
323 .finish()
324 .map_err(|e| ServerError::Frame(FrameError::Io(std::io::Error::other(e))))?;
325 Ok(())
326 }
327
328 pub async fn read_bytes(&mut self, buf: &mut [u8]) -> Result<usize, ServerError> {
331 match self.recv.read(buf).await {
332 Ok(Some(n)) => Ok(n),
333 Ok(None) => Ok(0), Err(e) => Err(ServerError::Frame(FrameError::Io(std::io::Error::other(
335 e.to_string(),
336 )))),
337 }
338 }
339
340 pub async fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), ServerError> {
342 self.recv.read_exact(buf).await.map_err(|e| {
343 ServerError::Frame(FrameError::Io(std::io::Error::other(e.to_string())))
344 })?;
345 Ok(())
346 }
347
348 pub async fn read_to_end(&mut self, size_limit: usize) -> Result<Vec<u8>, ServerError> {
350 self.recv
351 .read_to_end(size_limit)
352 .await
353 .map_err(|e| ServerError::Frame(FrameError::Io(std::io::Error::other(e.to_string()))))
354 }
355
356 pub async fn stream_to_writer<W: tokio::io::AsyncWrite + Unpin>(
358 &mut self,
359 writer: &mut W,
360 expected_size: Option<u64>,
361 ) -> Result<u64, ServerError> {
362 use tokio::io::AsyncWriteExt;
363
364 let mut total = 0u64;
365 let mut buf = [0u8; 64 * 1024]; loop {
368 let n = match self.recv.read(&mut buf).await {
369 Ok(Some(n)) => n,
370 Ok(None) => 0, Err(e) => {
372 return Err(ServerError::Frame(FrameError::Io(std::io::Error::other(
373 e.to_string(),
374 ))));
375 }
376 };
377 if n == 0 {
378 break;
379 }
380 writer.write_all(&buf[..n]).await?;
381 total += n as u64;
382 }
383
384 if let Some(expected) = expected_size
385 && total != expected
386 {
387 return Err(ServerError::Frame(FrameError::Io(std::io::Error::new(
388 std::io::ErrorKind::UnexpectedEof,
389 format!("Expected {} bytes, got {}", expected, total),
390 ))));
391 }
392
393 Ok(total)
394 }
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every default value of `RuntaraServerConfig`, not only a subset.
    #[test]
    fn test_default_config() {
        let config = RuntaraServerConfig::default();
        assert_eq!(config.bind_addr, "0.0.0.0:7001".parse::<SocketAddr>().unwrap());
        assert_eq!(config.max_connections, 10_000);
        assert_eq!(config.max_bi_streams, 100);
        assert_eq!(config.max_uni_streams, 100);
        assert_eq!(config.idle_timeout_ms, 30_000);
        assert!(config.cert_pem.is_empty());
        assert!(config.key_pem.is_empty());
    }
}