1use {
2 crate::{
3 config::{deserialize_rustls_server_config, deserialize_x_token_set},
4 shutdown::Shutdown,
5 transports::{RecvError, RecvItem, RecvStream, Subscribe, SubscribeError, WriteVectored},
6 },
7 futures::{
8 future::{pending, FutureExt},
9 stream::StreamExt,
10 },
11 prost::Message,
12 quinn::{
13 crypto::rustls::{NoInitialCipherSuite, QuicServerConfig},
14 Connection, Endpoint, Incoming, SendStream, VarInt,
15 },
16 richat_proto::richat::{
17 QuicSubscribeClose, QuicSubscribeCloseError, QuicSubscribeRequest, QuicSubscribeResponse,
18 QuicSubscribeResponseError,
19 },
20 serde::Deserialize,
21 std::{
22 borrow::Cow,
23 collections::{BTreeSet, HashSet, VecDeque},
24 future::Future,
25 io::{self, IoSlice},
26 net::{IpAddr, Ipv4Addr, SocketAddr},
27 sync::Arc,
28 },
29 thiserror::Error,
30 tokio::{
31 io::{AsyncReadExt, AsyncWriteExt},
32 task::{JoinError, JoinSet},
33 },
34 tracing::{error, info},
35};
36
37#[derive(Debug, Clone, Deserialize)]
38#[serde(deny_unknown_fields)]
39pub struct ConfigQuicServer {
40 #[serde(default = "ConfigQuicServer::default_endpoint")]
41 pub endpoint: SocketAddr,
42 #[serde(deserialize_with = "deserialize_rustls_server_config")]
43 pub tls_config: rustls::ServerConfig,
44 #[serde(default = "ConfigQuicServer::default_expected_rtt")]
46 pub expected_rtt: u32,
47 #[serde(default = "ConfigQuicServer::default_max_stream_bandwidth")]
49 pub max_stream_bandwidth: u32,
50 #[serde(default = "ConfigQuicServer::default_max_idle_timeout")]
52 pub max_idle_timeout: Option<u32>,
53 #[serde(default = "ConfigQuicServer::default_max_recv_streams")]
55 pub max_recv_streams: u32,
56 #[serde(default = "ConfigQuicServer::default_max_request_size")]
58 pub max_request_size: usize,
59 #[serde(default, deserialize_with = "deserialize_x_token_set")]
60 pub x_tokens: HashSet<Vec<u8>>,
61}
62
63impl ConfigQuicServer {
64 pub const fn default_endpoint() -> SocketAddr {
65 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10100)
66 }
67
68 const fn default_expected_rtt() -> u32 {
69 100
70 }
71
72 const fn default_max_stream_bandwidth() -> u32 {
73 12_500 * 1000
74 }
75
76 const fn default_max_idle_timeout() -> Option<u32> {
77 Some(30_000)
78 }
79
80 const fn default_max_recv_streams() -> u32 {
81 16
82 }
83
84 const fn default_max_request_size() -> usize {
85 1024
86 }
87
88 pub fn create_endpoint(&self) -> Result<Endpoint, CreateEndpointError> {
89 let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
90 QuicServerConfig::try_from(self.tls_config.clone())?,
91 ));
92
93 let transport_config = Arc::get_mut(&mut server_config.transport)
95 .ok_or(CreateEndpointError::TransportConfig)?;
96 transport_config.max_concurrent_bidi_streams(1u8.into());
97 transport_config.max_concurrent_uni_streams(0u8.into());
98
99 let stream_rwnd = self.max_stream_bandwidth / 1_000 * self.expected_rtt;
101 transport_config.stream_receive_window(stream_rwnd.into());
102 transport_config.send_window(8 * stream_rwnd as u64);
103 transport_config.datagram_receive_buffer_size(Some(stream_rwnd as usize));
104
105 transport_config
107 .max_idle_timeout(self.max_idle_timeout.map(|ms| VarInt::from_u32(ms).into()));
108
109 Endpoint::server(server_config, self.endpoint).map_err(|error| CreateEndpointError::Bind {
110 error,
111 endpoint: self.endpoint,
112 })
113 }
114}
115
116#[derive(Debug, Error)]
117pub enum CreateEndpointError {
118 #[error("failed to crate QuicServerConfig")]
119 ServerConfig(#[from] NoInitialCipherSuite),
120 #[error("failed to modify TransportConfig")]
121 TransportConfig,
122 #[error("failed to bind {endpoint}: {error}")]
123 Bind {
124 error: io::Error,
125 endpoint: SocketAddr,
126 },
127}
128
129#[derive(Debug, Error)]
130enum ConnectionError {
131 #[error(transparent)]
132 QuinnConnection(#[from] quinn::ConnectionError),
133 #[error(transparent)]
134 QuinnReadExact(#[from] quinn::ReadExactError),
135 #[error(transparent)]
136 QuinnWrite(#[from] quinn::WriteError),
137 #[error(transparent)]
138 QuinnClosedStream(#[from] quinn::ClosedStream),
139 #[error(transparent)]
140 Io(#[from] io::Error),
141 #[error(transparent)]
142 Prost(#[from] prost::DecodeError),
143 #[error(transparent)]
144 Join(#[from] JoinError),
145 #[error("stream is not available")]
146 StreamNotAvailable,
147}
148
/// QUIC transport server. This is a field-less unit type; its behavior lives in associated
/// functions such as `QuicServer::spawn`.
#[derive(Debug)]
pub struct QuicServer;
151
152impl QuicServer {
153 pub async fn spawn(
154 config: ConfigQuicServer,
155 messages: impl Subscribe + Clone + Send + 'static,
156 on_conn_new_cb: impl Fn() + Copy + Send + 'static,
157 on_conn_drop_cb: impl Fn() + Copy + Send + 'static,
158 shutdown: Shutdown,
159 ) -> Result<impl Future<Output = Result<(), JoinError>>, CreateEndpointError> {
160 let endpoint = config.create_endpoint()?;
161 info!("start server at {}", config.endpoint);
162
163 Ok(tokio::spawn(async move {
164 let max_recv_streams = config.max_recv_streams;
165 let max_request_size = config.max_request_size as u64;
166 let x_tokens = Arc::new(config.x_tokens);
167
168 let mut id = 0;
169 tokio::pin!(shutdown);
170 loop {
171 tokio::select! {
172 incoming = endpoint.accept() => {
173 let Some(incoming) = incoming else {
174 error!("quic connection closed");
175 break;
176 };
177
178 let messages = messages.clone();
179 let x_tokens = Arc::clone(&x_tokens);
180 tokio::spawn(async move {
181 on_conn_new_cb();
182 if let Err(error) = Self::handle_incoming(
183 id, incoming, messages, max_recv_streams, max_request_size, x_tokens
184 ).await {
185 error!("#{id}: connection failed: {error}");
186 } else {
187 info!("#{id}: connection closed");
188 }
189 on_conn_drop_cb();
190 });
191 id += 1;
192 }
193 () = &mut shutdown => {
194 endpoint.close(0u32.into(), b"shutdown");
195 info!("shutdown");
196 break
197 },
198 };
199 }
200 }))
201 }
202
203 async fn handle_incoming(
204 id: u64,
205 incoming: Incoming,
206 messages: impl Subscribe,
207 max_recv_streams: u32,
208 max_request_size: u64,
209 x_tokens: Arc<HashSet<Vec<u8>>>,
210 ) -> Result<(), ConnectionError> {
211 let conn = incoming.await?;
212 info!("#{id}: new connection from {:?}", conn.remote_address());
213
214 let (mut send, response, maybe_rx) = Self::handle_request(
216 id,
217 &conn,
218 messages,
219 max_recv_streams,
220 max_request_size,
221 x_tokens,
222 )
223 .await?;
224
225 let buf = response.encode_to_vec();
227 send.write_u64(buf.len() as u64).await?;
228 send.write_all(&buf).await?;
229 send.flush().await?;
230
231 let Some((recv_streams, max_backlog, mut rx)) = maybe_rx else {
232 return Ok(());
233 };
234
235 let mut streams = VecDeque::with_capacity(recv_streams as usize);
237 while streams.len() < recv_streams as usize {
238 streams.push_back(conn.open_uni().await?);
239 }
240
241 let mut msg_id = 0;
243 let mut msg_ids = BTreeSet::new();
244 let mut next_message: Option<RecvItem> = None;
245 let mut set = JoinSet::new();
246 loop {
247 if msg_id - msg_ids.first().copied().unwrap_or(msg_id) < max_backlog {
248 if let Some(message) = next_message.take() {
249 if let Some(mut stream) = streams.pop_front() {
250 msg_ids.insert(msg_id);
251 set.spawn(async move {
252 WriteVectored::new(
253 &mut stream,
254 &mut [
255 IoSlice::new(&msg_id.to_be_bytes()),
256 IoSlice::new(&(message.len() as u64).to_be_bytes()),
257 IoSlice::new(&message),
258 ],
259 )
260 .await?;
261 Ok::<_, ConnectionError>((msg_id, stream))
262 });
263 msg_id += 1;
264 } else {
265 next_message = Some(message);
266 }
267 }
268 }
269
270 let rx_recv = if next_message.is_none() {
271 rx.next().boxed()
272 } else {
273 pending().boxed()
274 };
275 let set_join_next = if !set.is_empty() {
276 set.join_next().boxed()
277 } else {
278 pending().boxed()
279 };
280
281 tokio::select! {
282 message = rx_recv => {
283 match message {
284 Some(Ok(message)) => next_message = Some(message),
285 Some(Err(error)) => {
286 error!("#{id}: failed to get message: {error}");
287 if streams.is_empty() {
288 let (msg_id, stream) = set.join_next().await.expect("already verified")??;
289 msg_ids.remove(&msg_id);
290 streams.push_back(stream);
291 }
292 let Some(mut stream) = streams.pop_front() else {
293 return Err(ConnectionError::StreamNotAvailable);
294 };
295
296 let msg = QuicSubscribeClose {
297 error: match error {
298 RecvError::Lagged => QuicSubscribeCloseError::Lagged,
299 RecvError::Closed => QuicSubscribeCloseError::Closed,
300 } as i32
301 };
302 let message = msg.encode_to_vec();
303
304 set.spawn(async move {
305 stream.write_u64(u64::MAX).await?;
306 stream.write_u64(message.len() as u64).await?;
307 stream.write_all(&message).await?;
308 Ok::<_, ConnectionError>((msg_id, stream))
309 });
310 },
311 None => break,
312 }
313 },
314 result = set_join_next => {
315 let (msg_id, stream) = result.expect("already verified")??;
316 msg_ids.remove(&msg_id);
317 streams.push_back(stream);
318 }
319 }
320 }
321
322 for (_, mut stream) in set.join_all().await.into_iter().flatten() {
323 stream.finish()?;
324 }
325 for mut stream in streams {
326 stream.finish()?;
327 }
328 drop(conn);
329
330 Ok(())
331 }
332
333 async fn handle_request(
334 id: u64,
335 conn: &Connection,
336 messages: impl Subscribe,
337 max_recv_streams: u32,
338 max_request_size: u64,
339 x_tokens: Arc<HashSet<Vec<u8>>>,
340 ) -> Result<
341 (
342 SendStream,
343 QuicSubscribeResponse,
344 Option<(u32, u64, RecvStream)>,
345 ),
346 ConnectionError,
347 > {
348 let (send, mut recv) = conn.accept_bi().await?;
349
350 let size = recv.read_u64().await?;
352 if size > max_request_size {
353 let msg = QuicSubscribeResponse {
354 error: Some(QuicSubscribeResponseError::RequestSizeTooLarge as i32),
355 ..Default::default()
356 };
357 return Ok((send, msg, None));
358 }
359 let mut buf = vec![0; size as usize]; recv.read_exact(buf.as_mut_slice()).await?;
361
362 let QuicSubscribeRequest {
364 x_token,
365 recv_streams,
366 max_backlog,
367 replay_from_slot,
368 filter,
369 } = Message::decode(buf.as_slice())?;
370
371 if !x_tokens.is_empty() {
373 if let Some(error) = match x_token {
374 Some(x_token) if !x_tokens.contains(&x_token) => {
375 Some(QuicSubscribeResponseError::XTokenInvalid as i32)
376 }
377 None => Some(QuicSubscribeResponseError::XTokenRequired as i32),
378 _ => None,
379 } {
380 let msg = QuicSubscribeResponse {
381 error: Some(error),
382 ..Default::default()
383 };
384 return Ok((send, msg, None));
385 }
386 }
387
388 if recv_streams == 0 || recv_streams > max_recv_streams {
390 let code = if recv_streams == 0 {
391 QuicSubscribeResponseError::ZeroRecvStreams
392 } else {
393 QuicSubscribeResponseError::ExceedRecvStreams
394 };
395 let msg = QuicSubscribeResponse {
396 error: Some(code as i32),
397 max_recv_streams: Some(max_recv_streams),
398 ..Default::default()
399 };
400 return Ok((send, msg, None));
401 }
402
403 Ok(match messages.subscribe(replay_from_slot, filter) {
404 Ok(rx) => {
405 let pos = replay_from_slot
406 .map(|slot| format!("slot {slot}").into())
407 .unwrap_or(Cow::Borrowed("latest"));
408 info!("#{id}: subscribed from {pos}");
409 (
410 send,
411 QuicSubscribeResponse::default(),
412 Some((
413 recv_streams,
414 max_backlog.map(|x| x as u64).unwrap_or(u64::MAX),
415 rx,
416 )),
417 )
418 }
419 Err(SubscribeError::NotInitialized) => {
420 let msg = QuicSubscribeResponse {
421 error: Some(QuicSubscribeResponseError::NotInitialized as i32),
422 ..Default::default()
423 };
424 (send, msg, None)
425 }
426 Err(SubscribeError::SlotNotAvailable { first_available }) => {
427 let msg = QuicSubscribeResponse {
428 error: Some(QuicSubscribeResponseError::SlotNotAvailable as i32),
429 first_available_slot: Some(first_available),
430 ..Default::default()
431 };
432 (send, msg, None)
433 }
434 })
435 }
436}