1use {
2 crate::{
3 config::{deserialize_rustls_server_config, deserialize_x_token_set},
4 shutdown::Shutdown,
5 transports::{RecvError, RecvItem, RecvStream, Subscribe, SubscribeError, WriteVectored},
6 },
7 futures::{
8 future::{pending, FutureExt},
9 stream::StreamExt,
10 },
11 prost::Message,
12 quinn::{
13 crypto::rustls::{NoInitialCipherSuite, QuicServerConfig},
14 Connection, Endpoint, Incoming, SendStream, VarInt,
15 },
16 richat_proto::richat::{
17 QuicSubscribeClose, QuicSubscribeCloseError, QuicSubscribeRequest, QuicSubscribeResponse,
18 QuicSubscribeResponseError,
19 },
20 serde::Deserialize,
21 std::{
22 borrow::Cow,
23 collections::{BTreeSet, HashSet, VecDeque},
24 future::Future,
25 io::{self, IoSlice},
26 net::{IpAddr, Ipv4Addr, SocketAddr},
27 sync::Arc,
28 },
29 thiserror::Error,
30 tokio::{
31 io::{AsyncReadExt, AsyncWriteExt},
32 task::{JoinError, JoinSet},
33 },
34 tracing::{error, info},
35};
36
/// Configuration of the QUIC subscription server.
///
/// Unknown keys are rejected during deserialization. Every field except
/// `tls_config` has a default.
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ConfigQuicServer {
    /// Socket address the QUIC endpoint binds to (default `127.0.0.1:10100`).
    #[serde(default = "ConfigQuicServer::default_endpoint")]
    pub endpoint: SocketAddr,
    /// rustls server configuration used for the QUIC handshake; parsed by the
    /// project helper `deserialize_rustls_server_config` (input format defined there).
    #[serde(deserialize_with = "deserialize_rustls_server_config")]
    pub tls_config: rustls::ServerConfig,
    /// Expected round-trip time, combined with `max_stream_bandwidth` to size
    /// the transport windows. Default 100; presumably milliseconds — TODO confirm.
    #[serde(default = "ConfigQuicServer::default_expected_rtt")]
    pub expected_rtt: u32,
    /// Target per-stream bandwidth used to size the transport windows.
    /// Default 12_500_000; presumably bytes per second — TODO confirm.
    #[serde(default = "ConfigQuicServer::default_max_stream_bandwidth")]
    pub max_stream_bandwidth: u32,
    /// Idle timeout handed to quinn's `TransportConfig::max_idle_timeout`
    /// (default `Some(30_000)`); `None` is passed through as `None`.
    #[serde(default = "ConfigQuicServer::default_max_idle_timeout")]
    pub max_idle_timeout: Option<u32>,
    /// Upper bound on the number of receive streams a client may request (default 16).
    #[serde(default = "ConfigQuicServer::default_max_recv_streams")]
    pub max_recv_streams: u32,
    /// Maximum accepted size, in bytes, of an encoded subscribe request (default 1024).
    #[serde(default = "ConfigQuicServer::default_max_request_size")]
    pub max_request_size: usize,
    /// Accepted access tokens. When empty, no token is required; otherwise a
    /// client must present one of these.
    #[serde(default, deserialize_with = "deserialize_x_token_set")]
    pub x_tokens: HashSet<Vec<u8>>,
}
62
63impl ConfigQuicServer {
64 pub const fn default_endpoint() -> SocketAddr {
65 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10100)
66 }
67
68 const fn default_expected_rtt() -> u32 {
69 100
70 }
71
72 const fn default_max_stream_bandwidth() -> u32 {
73 12_500 * 1000
74 }
75
76 const fn default_max_idle_timeout() -> Option<u32> {
77 Some(30_000)
78 }
79
80 const fn default_max_recv_streams() -> u32 {
81 16
82 }
83
84 const fn default_max_request_size() -> usize {
85 1024
86 }
87
88 pub fn create_endpoint(&self) -> Result<Endpoint, CreateEndpointError> {
89 let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
90 QuicServerConfig::try_from(self.tls_config.clone())?,
91 ));
92
93 let transport_config = Arc::get_mut(&mut server_config.transport)
95 .ok_or(CreateEndpointError::TransportConfig)?;
96 transport_config.max_concurrent_bidi_streams(1u8.into());
97 transport_config.max_concurrent_uni_streams(0u8.into());
98
99 let stream_rwnd = self.max_stream_bandwidth / 1_000 * self.expected_rtt;
101 transport_config.stream_receive_window(stream_rwnd.into());
102 transport_config.send_window(8 * stream_rwnd as u64);
103 transport_config.datagram_receive_buffer_size(Some(stream_rwnd as usize));
104
105 transport_config
107 .max_idle_timeout(self.max_idle_timeout.map(|ms| VarInt::from_u32(ms).into()));
108
109 Endpoint::server(server_config, self.endpoint).map_err(|error| CreateEndpointError::Bind {
110 error,
111 endpoint: self.endpoint,
112 })
113 }
114}
115
116#[derive(Debug, Error)]
117pub enum CreateEndpointError {
118 #[error("failed to crate QuicServerConfig")]
119 ServerConfig(#[from] NoInitialCipherSuite),
120 #[error("failed to modify TransportConfig")]
121 TransportConfig,
122 #[error("failed to bind {endpoint}: {error}")]
123 Bind {
124 error: io::Error,
125 endpoint: SocketAddr,
126 },
127}
128
129#[derive(Debug, Error)]
130enum ConnectionError {
131 #[error(transparent)]
132 QuinnConnection(#[from] quinn::ConnectionError),
133 #[error(transparent)]
134 QuinnReadExact(#[from] quinn::ReadExactError),
135 #[error(transparent)]
136 QuinnWrite(#[from] quinn::WriteError),
137 #[error(transparent)]
138 QuinnClosedStream(#[from] quinn::ClosedStream),
139 #[error(transparent)]
140 Io(#[from] io::Error),
141 #[error(transparent)]
142 Prost(#[from] prost::DecodeError),
143 #[error(transparent)]
144 Join(#[from] JoinError),
145 #[error("stream is not available")]
146 StreamNotAvailable,
147}
148
/// QUIC transport server. Holds no state; all functionality is provided by
/// associated functions (entry point: `QuicServer::spawn`).
#[derive(Debug)]
pub struct QuicServer;
151
152impl QuicServer {
153 pub async fn spawn(
154 config: ConfigQuicServer,
155 messages: impl Subscribe + Clone + Send + 'static,
156 on_conn_new_cb: impl Fn() + Clone + Send + 'static,
157 on_conn_drop_cb: impl Fn() + Clone + Send + 'static,
158 shutdown: Shutdown,
159 ) -> Result<impl Future<Output = Result<(), JoinError>>, CreateEndpointError> {
160 let endpoint = config.create_endpoint()?;
161 info!("start server at {}", config.endpoint);
162
163 Ok(tokio::spawn(async move {
164 let max_recv_streams = config.max_recv_streams;
165 let max_request_size = config.max_request_size as u64;
166 let x_tokens = Arc::new(config.x_tokens);
167
168 let mut id = 0;
169 tokio::pin!(shutdown);
170 loop {
171 tokio::select! {
172 incoming = endpoint.accept() => {
173 let Some(incoming) = incoming else {
174 error!("quic connection closed");
175 break;
176 };
177
178 let messages = messages.clone();
179 let on_conn_new_cb = on_conn_new_cb.clone();
180 let on_conn_drop_cb = on_conn_drop_cb.clone();
181 let x_tokens = Arc::clone(&x_tokens);
182 tokio::spawn(async move {
183 on_conn_new_cb();
184 if let Err(error) = Self::handle_incoming(
185 id, incoming, messages, max_recv_streams, max_request_size, x_tokens
186 ).await {
187 error!("#{id}: connection failed: {error}");
188 } else {
189 info!("#{id}: connection closed");
190 }
191 on_conn_drop_cb();
192 });
193 id += 1;
194 }
195 () = &mut shutdown => {
196 endpoint.close(0u32.into(), b"shutdown");
197 info!("shutdown");
198 break
199 },
200 };
201 }
202 }))
203 }
204
205 async fn handle_incoming(
206 id: u64,
207 incoming: Incoming,
208 messages: impl Subscribe,
209 max_recv_streams: u32,
210 max_request_size: u64,
211 x_tokens: Arc<HashSet<Vec<u8>>>,
212 ) -> Result<(), ConnectionError> {
213 let conn = incoming.await?;
214 info!("#{id}: new connection from {:?}", conn.remote_address());
215
216 let (mut send, response, maybe_rx) = Self::handle_request(
218 id,
219 &conn,
220 messages,
221 max_recv_streams,
222 max_request_size,
223 x_tokens,
224 )
225 .await?;
226
227 let buf = response.encode_to_vec();
229 send.write_u64(buf.len() as u64).await?;
230 send.write_all(&buf).await?;
231 send.flush().await?;
232
233 let Some((recv_streams, max_backlog, mut rx)) = maybe_rx else {
234 return Ok(());
235 };
236
237 let mut streams = VecDeque::with_capacity(recv_streams as usize);
239 while streams.len() < recv_streams as usize {
240 streams.push_back(conn.open_uni().await?);
241 }
242
243 let mut msg_id = 0;
245 let mut msg_ids = BTreeSet::new();
246 let mut next_message: Option<RecvItem> = None;
247 let mut set = JoinSet::new();
248 loop {
249 if msg_id - msg_ids.first().copied().unwrap_or(msg_id) < max_backlog {
250 if let Some(message) = next_message.take() {
251 if let Some(mut stream) = streams.pop_front() {
252 msg_ids.insert(msg_id);
253 set.spawn(async move {
254 WriteVectored::new(
255 &mut stream,
256 &mut [
257 IoSlice::new(&msg_id.to_be_bytes()),
258 IoSlice::new(&(message.len() as u64).to_be_bytes()),
259 IoSlice::new(&message),
260 ],
261 )
262 .await?;
263 Ok::<_, ConnectionError>((msg_id, stream))
264 });
265 msg_id += 1;
266 } else {
267 next_message = Some(message);
268 }
269 }
270 }
271
272 let rx_recv = if next_message.is_none() {
273 rx.next().boxed()
274 } else {
275 pending().boxed()
276 };
277 let set_join_next = if !set.is_empty() {
278 set.join_next().boxed()
279 } else {
280 pending().boxed()
281 };
282
283 tokio::select! {
284 message = rx_recv => {
285 match message {
286 Some(Ok(message)) => next_message = Some(message),
287 Some(Err(error)) => {
288 error!("#{id}: failed to get message: {error}");
289 if streams.is_empty() {
290 let (msg_id, stream) = set.join_next().await.expect("already verified")??;
291 msg_ids.remove(&msg_id);
292 streams.push_back(stream);
293 }
294 let Some(mut stream) = streams.pop_front() else {
295 return Err(ConnectionError::StreamNotAvailable);
296 };
297
298 let msg = QuicSubscribeClose {
299 error: match error {
300 RecvError::Lagged => QuicSubscribeCloseError::Lagged,
301 RecvError::Closed => QuicSubscribeCloseError::Closed,
302 } as i32
303 };
304 let message = msg.encode_to_vec();
305
306 set.spawn(async move {
307 stream.write_u64(u64::MAX).await?;
308 stream.write_u64(message.len() as u64).await?;
309 stream.write_all(&message).await?;
310 Ok::<_, ConnectionError>((msg_id, stream))
311 });
312 },
313 None => break,
314 }
315 },
316 result = set_join_next => {
317 let (msg_id, stream) = result.expect("already verified")??;
318 msg_ids.remove(&msg_id);
319 streams.push_back(stream);
320 }
321 }
322 }
323
324 for (_, mut stream) in set.join_all().await.into_iter().flatten() {
325 stream.finish()?;
326 }
327 for mut stream in streams {
328 stream.finish()?;
329 }
330 drop(conn);
331
332 Ok(())
333 }
334
335 async fn handle_request(
336 id: u64,
337 conn: &Connection,
338 messages: impl Subscribe,
339 max_recv_streams: u32,
340 max_request_size: u64,
341 x_tokens: Arc<HashSet<Vec<u8>>>,
342 ) -> Result<
343 (
344 SendStream,
345 QuicSubscribeResponse,
346 Option<(u32, u64, RecvStream)>,
347 ),
348 ConnectionError,
349 > {
350 let (send, mut recv) = conn.accept_bi().await?;
351
352 let size = recv.read_u64().await?;
354 if size > max_request_size {
355 let msg = QuicSubscribeResponse {
356 error: Some(QuicSubscribeResponseError::RequestSizeTooLarge as i32),
357 ..Default::default()
358 };
359 return Ok((send, msg, None));
360 }
361 let mut buf = vec![0; size as usize]; recv.read_exact(buf.as_mut_slice()).await?;
363
364 let QuicSubscribeRequest {
366 x_token,
367 recv_streams,
368 max_backlog,
369 replay_from_slot,
370 filter,
371 } = Message::decode(buf.as_slice())?;
372
373 if !x_tokens.is_empty() {
375 if let Some(error) = match x_token {
376 Some(x_token) if !x_tokens.contains(&x_token) => {
377 Some(QuicSubscribeResponseError::XTokenInvalid as i32)
378 }
379 None => Some(QuicSubscribeResponseError::XTokenRequired as i32),
380 _ => None,
381 } {
382 let msg = QuicSubscribeResponse {
383 error: Some(error),
384 ..Default::default()
385 };
386 return Ok((send, msg, None));
387 }
388 }
389
390 if recv_streams == 0 || recv_streams > max_recv_streams {
392 let code = if recv_streams == 0 {
393 QuicSubscribeResponseError::ZeroRecvStreams
394 } else {
395 QuicSubscribeResponseError::ExceedRecvStreams
396 };
397 let msg = QuicSubscribeResponse {
398 error: Some(code as i32),
399 max_recv_streams: Some(max_recv_streams),
400 ..Default::default()
401 };
402 return Ok((send, msg, None));
403 }
404
405 Ok(match messages.subscribe(replay_from_slot, filter) {
406 Ok(rx) => {
407 let pos = replay_from_slot
408 .map(|slot| format!("slot {slot}").into())
409 .unwrap_or(Cow::Borrowed("latest"));
410 info!("#{id}: subscribed from {pos}");
411 (
412 send,
413 QuicSubscribeResponse::default(),
414 Some((
415 recv_streams,
416 max_backlog.map(|x| x as u64).unwrap_or(u64::MAX),
417 rx,
418 )),
419 )
420 }
421 Err(SubscribeError::NotInitialized) => {
422 let msg = QuicSubscribeResponse {
423 error: Some(QuicSubscribeResponseError::NotInitialized as i32),
424 ..Default::default()
425 };
426 (send, msg, None)
427 }
428 Err(SubscribeError::SlotNotAvailable { first_available }) => {
429 let msg = QuicSubscribeResponse {
430 error: Some(QuicSubscribeResponseError::SlotNotAvailable as i32),
431 first_available_slot: Some(first_available),
432 ..Default::default()
433 };
434 (send, msg, None)
435 }
436 })
437 }
438}