1use {
2 crate::{
3 config::{deserialize_rustls_server_config, deserialize_x_token_set},
4 shutdown::Shutdown,
5 transports::{RecvError, RecvItem, RecvStream, Subscribe, SubscribeError, WriteVectored},
6 version::Version,
7 },
8 futures::{
9 future::{pending, FutureExt},
10 stream::StreamExt,
11 },
12 prost::Message,
13 quinn::{
14 crypto::rustls::{NoInitialCipherSuite, QuicServerConfig},
15 Connection, Endpoint, Incoming, SendStream, VarInt,
16 },
17 richat_proto::richat::{
18 QuicSubscribeClose, QuicSubscribeCloseError, QuicSubscribeRequest, QuicSubscribeResponse,
19 QuicSubscribeResponseError,
20 },
21 serde::Deserialize,
22 std::{
23 borrow::Cow,
24 collections::{BTreeSet, HashSet, VecDeque},
25 future::Future,
26 io::{self, IoSlice},
27 net::{IpAddr, Ipv4Addr, SocketAddr},
28 sync::Arc,
29 },
30 thiserror::Error,
31 tokio::{
32 io::{AsyncReadExt, AsyncWriteExt},
33 task::{JoinError, JoinSet},
34 },
35 tracing::{error, info},
36};
37
/// Configuration of the QUIC transport server.
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ConfigQuicServer {
    /// Address the QUIC endpoint binds to (default `127.0.0.1:10101`).
    #[serde(default = "ConfigQuicServer::default_endpoint")]
    pub endpoint: SocketAddr,
    /// rustls server config used to build the QUIC crypto config. Required; has no default.
    #[serde(deserialize_with = "deserialize_rustls_server_config")]
    pub tls_config: rustls::ServerConfig,
    /// Expected round-trip time (default `100`), used with `max_stream_bandwidth` to size
    /// the per-stream receive window (`max_stream_bandwidth / 1_000 * expected_rtt`).
    /// NOTE(review): the `/ 1_000` suggests milliseconds — confirm.
    #[serde(default = "ConfigQuicServer::default_expected_rtt")]
    pub expected_rtt: u32,
    /// Per-stream bandwidth used to size flow-control windows (default `12_500 * 1000`).
    /// NOTE(review): presumably bytes per second — confirm.
    #[serde(default = "ConfigQuicServer::default_max_stream_bandwidth")]
    pub max_stream_bandwidth: u32,
    /// Connection idle timeout in milliseconds (default `30_000`); `None` disables it.
    #[serde(default = "ConfigQuicServer::default_max_idle_timeout")]
    pub max_idle_timeout: Option<u32>,
    /// Largest `recv_streams` value a client may request (default `16`).
    #[serde(default = "ConfigQuicServer::default_max_recv_streams")]
    pub max_recv_streams: u32,
    /// Largest accepted length prefix of an encoded subscribe request, in bytes
    /// (default `1024`).
    #[serde(default = "ConfigQuicServer::default_max_request_size")]
    pub max_request_size: usize,
    /// Accepted x-tokens. When empty (the default), clients are not authenticated.
    #[serde(default, deserialize_with = "deserialize_x_token_set")]
    pub x_tokens: HashSet<Vec<u8>>,
}
63
64impl ConfigQuicServer {
65 pub const fn default_endpoint() -> SocketAddr {
66 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10101)
67 }
68
69 const fn default_expected_rtt() -> u32 {
70 100
71 }
72
73 const fn default_max_stream_bandwidth() -> u32 {
74 12_500 * 1000
75 }
76
77 const fn default_max_idle_timeout() -> Option<u32> {
78 Some(30_000)
79 }
80
81 const fn default_max_recv_streams() -> u32 {
82 16
83 }
84
85 const fn default_max_request_size() -> usize {
86 1024
87 }
88
89 pub fn create_endpoint(&self) -> Result<Endpoint, CreateEndpointError> {
90 let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
91 QuicServerConfig::try_from(self.tls_config.clone())?,
92 ));
93
94 let transport_config = Arc::get_mut(&mut server_config.transport)
96 .ok_or(CreateEndpointError::TransportConfig)?;
97 transport_config.max_concurrent_bidi_streams(1u8.into());
98 transport_config.max_concurrent_uni_streams(0u8.into());
99
100 let stream_rwnd = self.max_stream_bandwidth / 1_000 * self.expected_rtt;
102 transport_config.stream_receive_window(stream_rwnd.into());
103 transport_config.send_window(8 * stream_rwnd as u64);
104 transport_config.datagram_receive_buffer_size(Some(stream_rwnd as usize));
105
106 transport_config
108 .max_idle_timeout(self.max_idle_timeout.map(|ms| VarInt::from_u32(ms).into()));
109
110 Endpoint::server(server_config, self.endpoint).map_err(|error| CreateEndpointError::Bind {
111 error,
112 endpoint: self.endpoint,
113 })
114 }
115}
116
117#[derive(Debug, Error)]
118pub enum CreateEndpointError {
119 #[error("failed to crate QuicServerConfig")]
120 ServerConfig(#[from] NoInitialCipherSuite),
121 #[error("failed to modify TransportConfig")]
122 TransportConfig,
123 #[error("failed to bind {endpoint}: {error}")]
124 Bind {
125 error: io::Error,
126 endpoint: SocketAddr,
127 },
128}
129
130#[derive(Debug, Error)]
131enum ConnectionError {
132 #[error(transparent)]
133 QuinnConnection(#[from] quinn::ConnectionError),
134 #[error(transparent)]
135 QuinnReadExact(#[from] quinn::ReadExactError),
136 #[error(transparent)]
137 QuinnWrite(#[from] quinn::WriteError),
138 #[error(transparent)]
139 QuinnClosedStream(#[from] quinn::ClosedStream),
140 #[error(transparent)]
141 Io(#[from] io::Error),
142 #[error(transparent)]
143 Prost(#[from] prost::DecodeError),
144 #[error(transparent)]
145 Join(#[from] JoinError),
146 #[error("stream is not available")]
147 StreamNotAvailable,
148}
149
/// Stateless handle for the QUIC transport; the server is started with `QuicServer::spawn`.
#[derive(Debug)]
pub struct QuicServer;
152
153impl QuicServer {
154 pub async fn spawn(
155 config: ConfigQuicServer,
156 messages: impl Subscribe + Clone + Send + 'static,
157 on_conn_new_cb: impl Fn() + Clone + Send + 'static,
158 on_conn_drop_cb: impl Fn() + Clone + Send + 'static,
159 version: Version<'static>,
160 shutdown: Shutdown,
161 ) -> Result<impl Future<Output = Result<(), JoinError>>, CreateEndpointError> {
162 let endpoint = config.create_endpoint()?;
163 info!("start server at {}", config.endpoint);
164
165 Ok(tokio::spawn(async move {
166 let max_recv_streams = config.max_recv_streams;
167 let max_request_size = config.max_request_size as u64;
168 let x_tokens = Arc::new(config.x_tokens);
169
170 let mut id = 0;
171 tokio::pin!(shutdown);
172 loop {
173 tokio::select! {
174 incoming = endpoint.accept() => {
175 let Some(incoming) = incoming else {
176 error!("quic connection closed");
177 break;
178 };
179
180 let messages = messages.clone();
181 let on_conn_new_cb = on_conn_new_cb.clone();
182 let on_conn_drop_cb = on_conn_drop_cb.clone();
183 let x_tokens = Arc::clone(&x_tokens);
184 tokio::spawn(async move {
185 on_conn_new_cb();
186 if let Err(error) = Self::handle_incoming(
187 id,
188 incoming,
189 messages,
190 max_recv_streams,
191 max_request_size,
192 x_tokens,
193 version.create_grpc_version_info().json(),
194 ).await {
195 error!("#{id}: connection failed: {error}");
196 } else {
197 info!("#{id}: connection closed");
198 }
199 on_conn_drop_cb();
200 });
201 id += 1;
202 }
203 () = &mut shutdown => {
204 endpoint.close(0u32.into(), b"shutdown");
205 info!("shutdown");
206 break
207 },
208 };
209 }
210 }))
211 }
212
213 async fn handle_incoming(
214 id: u64,
215 incoming: Incoming,
216 messages: impl Subscribe,
217 max_recv_streams: u32,
218 max_request_size: u64,
219 x_tokens: Arc<HashSet<Vec<u8>>>,
220 version: String,
221 ) -> Result<(), ConnectionError> {
222 let conn = incoming.await?;
223 info!("#{id}: new connection from {:?}", conn.remote_address());
224
225 let (mut send, response, maybe_rx) = Self::handle_request(
227 id,
228 &conn,
229 messages,
230 max_recv_streams,
231 max_request_size,
232 x_tokens,
233 version,
234 )
235 .await?;
236
237 let buf = response.encode_to_vec();
239 send.write_u64(buf.len() as u64).await?;
240 send.write_all(&buf).await?;
241 send.flush().await?;
242
243 let Some((recv_streams, max_backlog, mut rx)) = maybe_rx else {
244 return Ok(());
245 };
246
247 let mut streams = VecDeque::with_capacity(recv_streams as usize);
249 while streams.len() < recv_streams as usize {
250 streams.push_back(conn.open_uni().await?);
251 }
252
253 let mut msg_id = 0;
255 let mut msg_ids = BTreeSet::new();
256 let mut next_message: Option<RecvItem> = None;
257 let mut set = JoinSet::new();
258 loop {
259 if msg_id - msg_ids.first().copied().unwrap_or(msg_id) < max_backlog {
260 if let Some(message) = next_message.take() {
261 if let Some(mut stream) = streams.pop_front() {
262 msg_ids.insert(msg_id);
263 set.spawn(async move {
264 WriteVectored::new(
265 &mut stream,
266 &mut [
267 IoSlice::new(&msg_id.to_be_bytes()),
268 IoSlice::new(&(message.len() as u64).to_be_bytes()),
269 IoSlice::new(&message),
270 ],
271 )
272 .await?;
273 Ok::<_, ConnectionError>((msg_id, stream))
274 });
275 msg_id += 1;
276 } else {
277 next_message = Some(message);
278 }
279 }
280 }
281
282 let rx_recv = if next_message.is_none() {
283 rx.next().boxed()
284 } else {
285 pending().boxed()
286 };
287 let set_join_next = if !set.is_empty() {
288 set.join_next().boxed()
289 } else {
290 pending().boxed()
291 };
292
293 tokio::select! {
294 message = rx_recv => {
295 match message {
296 Some(Ok(message)) => next_message = Some(message),
297 Some(Err(error)) => {
298 error!("#{id}: failed to get message: {error}");
299 if streams.is_empty() {
300 let (msg_id, stream) = set.join_next().await.expect("already verified")??;
301 msg_ids.remove(&msg_id);
302 streams.push_back(stream);
303 }
304 let Some(mut stream) = streams.pop_front() else {
305 return Err(ConnectionError::StreamNotAvailable);
306 };
307
308 let msg = QuicSubscribeClose {
309 error: match error {
310 RecvError::Lagged => QuicSubscribeCloseError::Lagged,
311 RecvError::Closed => QuicSubscribeCloseError::Closed,
312 } as i32
313 };
314 let message = msg.encode_to_vec();
315
316 set.spawn(async move {
317 stream.write_u64(u64::MAX).await?;
318 stream.write_u64(message.len() as u64).await?;
319 stream.write_all(&message).await?;
320 Ok::<_, ConnectionError>((msg_id, stream))
321 });
322 },
323 None => break,
324 }
325 },
326 result = set_join_next => {
327 let (msg_id, stream) = result.expect("already verified")??;
328 msg_ids.remove(&msg_id);
329 streams.push_back(stream);
330 }
331 }
332 }
333
334 for (_, mut stream) in set.join_all().await.into_iter().flatten() {
335 stream.finish()?;
336 }
337 for mut stream in streams {
338 stream.finish()?;
339 }
340 drop(conn);
341
342 Ok(())
343 }
344
345 async fn handle_request(
346 id: u64,
347 conn: &Connection,
348 messages: impl Subscribe,
349 max_recv_streams: u32,
350 max_request_size: u64,
351 x_tokens: Arc<HashSet<Vec<u8>>>,
352 version: String,
353 ) -> Result<
354 (
355 SendStream,
356 QuicSubscribeResponse,
357 Option<(u32, u64, RecvStream)>,
358 ),
359 ConnectionError,
360 > {
361 let (send, mut recv) = conn.accept_bi().await?;
362
363 let size = recv.read_u64().await?;
365 if size > max_request_size {
366 let msg = QuicSubscribeResponse {
367 error: Some(QuicSubscribeResponseError::RequestSizeTooLarge as i32),
368 version,
369 ..Default::default()
370 };
371 return Ok((send, msg, None));
372 }
373 let mut buf = vec![0; size as usize]; recv.read_exact(buf.as_mut_slice()).await?;
375
376 let QuicSubscribeRequest {
378 x_token,
379 recv_streams,
380 max_backlog,
381 replay_from_slot,
382 filter,
383 } = Message::decode(buf.as_slice())?;
384
385 if !x_tokens.is_empty() {
387 if let Some(error) = match x_token {
388 Some(x_token) if !x_tokens.contains(&x_token) => {
389 Some(QuicSubscribeResponseError::XTokenInvalid as i32)
390 }
391 None => Some(QuicSubscribeResponseError::XTokenRequired as i32),
392 _ => None,
393 } {
394 let msg = QuicSubscribeResponse {
395 error: Some(error),
396 version,
397 ..Default::default()
398 };
399 return Ok((send, msg, None));
400 }
401 }
402
403 if recv_streams == 0 || recv_streams > max_recv_streams {
405 let code = if recv_streams == 0 {
406 QuicSubscribeResponseError::ZeroRecvStreams
407 } else {
408 QuicSubscribeResponseError::ExceedRecvStreams
409 };
410 let msg = QuicSubscribeResponse {
411 error: Some(code as i32),
412 max_recv_streams: Some(max_recv_streams),
413 version,
414 ..Default::default()
415 };
416 return Ok((send, msg, None));
417 }
418
419 Ok(match messages.subscribe(replay_from_slot, filter) {
420 Ok(rx) => {
421 let pos = replay_from_slot
422 .map(|slot| format!("slot {slot}").into())
423 .unwrap_or(Cow::Borrowed("latest"));
424 info!("#{id}: subscribed from {pos}");
425 (
426 send,
427 QuicSubscribeResponse {
428 version,
429 ..Default::default()
430 },
431 Some((
432 recv_streams,
433 max_backlog.map(|x| x as u64).unwrap_or(u64::MAX),
434 rx,
435 )),
436 )
437 }
438 Err(SubscribeError::NotInitialized) => {
439 let msg = QuicSubscribeResponse {
440 error: Some(QuicSubscribeResponseError::NotInitialized as i32),
441 version,
442 ..Default::default()
443 };
444 (send, msg, None)
445 }
446 Err(SubscribeError::SlotNotAvailable { first_available }) => {
447 let msg = QuicSubscribeResponse {
448 error: Some(QuicSubscribeResponseError::SlotNotAvailable as i32),
449 first_available_slot: Some(first_available),
450 version,
451 ..Default::default()
452 };
453 (send, msg, None)
454 }
455 })
456 }
457}