1use {
2 crate::{
3 config::{deserialize_num_str, deserialize_rustls_server_config, deserialize_x_token_set},
4 shutdown::Shutdown,
5 transports::{RecvError, RecvItem, RecvStream, Subscribe, SubscribeError, WriteVectored},
6 version::Version,
7 },
8 futures::{
9 future::{pending, FutureExt},
10 stream::StreamExt,
11 },
12 prost::Message,
13 quinn::{
14 crypto::rustls::{NoInitialCipherSuite, QuicServerConfig},
15 Connection, Endpoint, Incoming, SendStream, VarInt,
16 },
17 richat_proto::richat::{
18 QuicSubscribeClose, QuicSubscribeCloseError, QuicSubscribeRequest, QuicSubscribeResponse,
19 QuicSubscribeResponseError,
20 },
21 serde::Deserialize,
22 std::{
23 borrow::Cow,
24 collections::{BTreeSet, HashSet, VecDeque},
25 future::Future,
26 io::{self, IoSlice},
27 net::{IpAddr, Ipv4Addr, SocketAddr},
28 sync::Arc,
29 },
30 thiserror::Error,
31 tokio::{
32 io::{AsyncReadExt, AsyncWriteExt},
33 task::{JoinError, JoinSet},
34 },
35 tracing::{error, info},
36};
37
/// Configuration of the QUIC subscription server.
///
/// Deserialized with serde; unknown fields are rejected (`deny_unknown_fields`).
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ConfigQuicServer {
    /// Socket address the endpoint binds to; defaults to `127.0.0.1:10101`.
    #[serde(default = "ConfigQuicServer::default_endpoint")]
    pub endpoint: SocketAddr,
    /// TLS configuration from which the QUIC crypto config is built.
    #[serde(deserialize_with = "deserialize_rustls_server_config")]
    pub tls_config: rustls::ServerConfig,
    /// Expected round-trip time, used to size the stream receive window
    /// (`max_stream_bandwidth / 1000 * expected_rtt`). Defaults to `100`.
    /// NOTE(review): presumably milliseconds — confirm.
    #[serde(default = "ConfigQuicServer::default_expected_rtt")]
    pub expected_rtt: u32,
    /// Per-stream bandwidth used to size the stream receive window; parsed with
    /// `deserialize_num_str`. Defaults to `12_500_000`.
    /// NOTE(review): presumably bytes per second — confirm.
    #[serde(
        default = "ConfigQuicServer::default_max_stream_bandwidth",
        deserialize_with = "deserialize_num_str"
    )]
    pub max_stream_bandwidth: u32,
    /// Idle timeout passed to quinn (the closure parameter is named `ms`);
    /// `None` means no idle timeout. Defaults to `Some(30_000)`.
    #[serde(default = "ConfigQuicServer::default_max_idle_timeout")]
    pub max_idle_timeout: Option<u32>,
    /// Upper bound on the number of unidirectional streams a client may request
    /// for receiving messages. Defaults to `16`.
    #[serde(default = "ConfigQuicServer::default_max_recv_streams")]
    pub max_recv_streams: u32,
    /// Maximum accepted size, in bytes, of an encoded subscribe request.
    /// Defaults to `1024`.
    #[serde(default = "ConfigQuicServer::default_max_request_size")]
    pub max_request_size: usize,
    /// Accepted `x-token` values; when non-empty, every subscribe request must
    /// carry one of them. Empty by default (no authentication).
    #[serde(default, deserialize_with = "deserialize_x_token_set")]
    pub x_tokens: HashSet<Vec<u8>>,
}
66
67impl ConfigQuicServer {
68 pub const fn default_endpoint() -> SocketAddr {
69 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10101)
70 }
71
72 const fn default_expected_rtt() -> u32 {
73 100
74 }
75
76 const fn default_max_stream_bandwidth() -> u32 {
77 12_500 * 1000
78 }
79
80 const fn default_max_idle_timeout() -> Option<u32> {
81 Some(30_000)
82 }
83
84 const fn default_max_recv_streams() -> u32 {
85 16
86 }
87
88 const fn default_max_request_size() -> usize {
89 1024
90 }
91
92 pub fn create_endpoint(&self) -> Result<Endpoint, CreateEndpointError> {
93 let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
94 QuicServerConfig::try_from(self.tls_config.clone())?,
95 ));
96
97 let transport_config = Arc::get_mut(&mut server_config.transport)
99 .ok_or(CreateEndpointError::TransportConfig)?;
100 transport_config.max_concurrent_bidi_streams(1u8.into());
101 transport_config.max_concurrent_uni_streams(0u8.into());
102
103 let stream_rwnd = self.max_stream_bandwidth / 1_000 * self.expected_rtt;
105 transport_config.stream_receive_window(stream_rwnd.into());
106 transport_config.send_window(8 * stream_rwnd as u64);
107 transport_config.datagram_receive_buffer_size(Some(stream_rwnd as usize));
108
109 transport_config
111 .max_idle_timeout(self.max_idle_timeout.map(|ms| VarInt::from_u32(ms).into()));
112
113 Endpoint::server(server_config, self.endpoint).map_err(|error| CreateEndpointError::Bind {
114 error,
115 endpoint: self.endpoint,
116 })
117 }
118}
119
120#[derive(Debug, Error)]
121pub enum CreateEndpointError {
122 #[error("failed to crate QuicServerConfig")]
123 ServerConfig(#[from] NoInitialCipherSuite),
124 #[error("failed to modify TransportConfig")]
125 TransportConfig,
126 #[error("failed to bind {endpoint}: {error}")]
127 Bind {
128 error: io::Error,
129 endpoint: SocketAddr,
130 },
131}
132
133#[derive(Debug, Error)]
134enum ConnectionError {
135 #[error(transparent)]
136 QuinnConnection(#[from] quinn::ConnectionError),
137 #[error(transparent)]
138 QuinnReadExact(#[from] quinn::ReadExactError),
139 #[error(transparent)]
140 QuinnWrite(#[from] quinn::WriteError),
141 #[error(transparent)]
142 QuinnClosedStream(#[from] quinn::ClosedStream),
143 #[error(transparent)]
144 Io(#[from] io::Error),
145 #[error(transparent)]
146 Prost(#[from] prost::DecodeError),
147 #[error(transparent)]
148 Join(#[from] JoinError),
149 #[error("stream is not available")]
150 StreamNotAvailable,
151}
152
/// Stateless handle for the QUIC subscription server; see `QuicServer::spawn`.
#[derive(Debug)]
pub struct QuicServer;
155
156impl QuicServer {
157 pub async fn spawn(
158 config: ConfigQuicServer,
159 messages: impl Subscribe + Clone + Send + 'static,
160 on_conn_new_cb: impl Fn() + Clone + Send + 'static,
161 on_conn_drop_cb: impl Fn() + Clone + Send + 'static,
162 version: Version<'static>,
163 shutdown: Shutdown,
164 ) -> Result<impl Future<Output = Result<(), JoinError>>, CreateEndpointError> {
165 let endpoint = config.create_endpoint()?;
166 info!("start server at {}", config.endpoint);
167
168 Ok(tokio::spawn(async move {
169 let max_recv_streams = config.max_recv_streams;
170 let max_request_size = config.max_request_size as u64;
171 let x_tokens = Arc::new(config.x_tokens);
172
173 let mut id = 0;
174 tokio::pin!(shutdown);
175 loop {
176 tokio::select! {
177 incoming = endpoint.accept() => {
178 let Some(incoming) = incoming else {
179 error!("quic connection closed");
180 break;
181 };
182
183 let messages = messages.clone();
184 let on_conn_new_cb = on_conn_new_cb.clone();
185 let on_conn_drop_cb = on_conn_drop_cb.clone();
186 let x_tokens = Arc::clone(&x_tokens);
187 tokio::spawn(async move {
188 on_conn_new_cb();
189 if let Err(error) = Self::handle_incoming(
190 id,
191 incoming,
192 messages,
193 max_recv_streams,
194 max_request_size,
195 x_tokens,
196 version.create_grpc_version_info().json(),
197 ).await {
198 error!("#{id}: connection failed: {error}");
199 } else {
200 info!("#{id}: connection closed");
201 }
202 on_conn_drop_cb();
203 });
204 id += 1;
205 }
206 () = &mut shutdown => {
207 endpoint.close(0u32.into(), b"shutdown");
208 info!("shutdown");
209 break
210 },
211 };
212 }
213 }))
214 }
215
216 async fn handle_incoming(
217 id: u64,
218 incoming: Incoming,
219 messages: impl Subscribe,
220 max_recv_streams: u32,
221 max_request_size: u64,
222 x_tokens: Arc<HashSet<Vec<u8>>>,
223 version: String,
224 ) -> Result<(), ConnectionError> {
225 let conn = incoming.await?;
226 info!("#{id}: new connection from {:?}", conn.remote_address());
227
228 let (mut send, response, maybe_rx) = Self::handle_request(
230 id,
231 &conn,
232 messages,
233 max_recv_streams,
234 max_request_size,
235 x_tokens,
236 version,
237 )
238 .await?;
239
240 let buf = response.encode_to_vec();
242 send.write_u64(buf.len() as u64).await?;
243 send.write_all(&buf).await?;
244 send.flush().await?;
245
246 let Some((recv_streams, max_backlog, mut rx)) = maybe_rx else {
247 return Ok(());
248 };
249
250 let mut streams = VecDeque::with_capacity(recv_streams as usize);
252 while streams.len() < recv_streams as usize {
253 streams.push_back(conn.open_uni().await?);
254 }
255
256 let mut msg_id = 0;
258 let mut msg_ids = BTreeSet::new();
259 let mut next_message: Option<RecvItem> = None;
260 let mut set = JoinSet::new();
261 loop {
262 if msg_id - msg_ids.first().copied().unwrap_or(msg_id) < max_backlog {
263 if let Some(message) = next_message.take() {
264 if let Some(mut stream) = streams.pop_front() {
265 msg_ids.insert(msg_id);
266 set.spawn(async move {
267 WriteVectored::new(
268 &mut stream,
269 &mut [
270 IoSlice::new(&msg_id.to_be_bytes()),
271 IoSlice::new(&(message.len() as u64).to_be_bytes()),
272 IoSlice::new(&message),
273 ],
274 )
275 .await?;
276 Ok::<_, ConnectionError>((msg_id, stream))
277 });
278 msg_id += 1;
279 } else {
280 next_message = Some(message);
281 }
282 }
283 }
284
285 let rx_recv = if next_message.is_none() {
286 rx.next().boxed()
287 } else {
288 pending().boxed()
289 };
290 let set_join_next = if !set.is_empty() {
291 set.join_next().boxed()
292 } else {
293 pending().boxed()
294 };
295
296 tokio::select! {
297 message = rx_recv => {
298 match message {
299 Some(Ok(message)) => next_message = Some(message),
300 Some(Err(error)) => {
301 error!("#{id}: failed to get message: {error}");
302 if streams.is_empty() {
303 let (msg_id, stream) = set.join_next().await.expect("already verified")??;
304 msg_ids.remove(&msg_id);
305 streams.push_back(stream);
306 }
307 let Some(mut stream) = streams.pop_front() else {
308 return Err(ConnectionError::StreamNotAvailable);
309 };
310
311 let msg = QuicSubscribeClose {
312 error: match error {
313 RecvError::Lagged => QuicSubscribeCloseError::Lagged,
314 RecvError::Closed => QuicSubscribeCloseError::Closed,
315 } as i32
316 };
317 let message = msg.encode_to_vec();
318
319 set.spawn(async move {
320 stream.write_u64(u64::MAX).await?;
321 stream.write_u64(message.len() as u64).await?;
322 stream.write_all(&message).await?;
323 Ok::<_, ConnectionError>((msg_id, stream))
324 });
325 },
326 None => break,
327 }
328 },
329 result = set_join_next => {
330 let (msg_id, stream) = result.expect("already verified")??;
331 msg_ids.remove(&msg_id);
332 streams.push_back(stream);
333 }
334 }
335 }
336
337 for (_, mut stream) in set.join_all().await.into_iter().flatten() {
338 stream.finish()?;
339 }
340 for mut stream in streams {
341 stream.finish()?;
342 }
343 drop(conn);
344
345 Ok(())
346 }
347
348 async fn handle_request(
349 id: u64,
350 conn: &Connection,
351 messages: impl Subscribe,
352 max_recv_streams: u32,
353 max_request_size: u64,
354 x_tokens: Arc<HashSet<Vec<u8>>>,
355 version: String,
356 ) -> Result<
357 (
358 SendStream,
359 QuicSubscribeResponse,
360 Option<(u32, u64, RecvStream)>,
361 ),
362 ConnectionError,
363 > {
364 let (send, mut recv) = conn.accept_bi().await?;
365
366 let size = recv.read_u64().await?;
368 if size > max_request_size {
369 let msg = QuicSubscribeResponse {
370 error: Some(QuicSubscribeResponseError::RequestSizeTooLarge as i32),
371 version,
372 ..Default::default()
373 };
374 return Ok((send, msg, None));
375 }
376 let mut buf = vec![0; size as usize]; recv.read_exact(buf.as_mut_slice()).await?;
378
379 let QuicSubscribeRequest {
381 x_token,
382 recv_streams,
383 max_backlog,
384 replay_from_slot,
385 filter,
386 } = Message::decode(buf.as_slice())?;
387
388 if !x_tokens.is_empty() {
390 if let Some(error) = match x_token {
391 Some(x_token) if !x_tokens.contains(&x_token) => {
392 Some(QuicSubscribeResponseError::XTokenInvalid as i32)
393 }
394 None => Some(QuicSubscribeResponseError::XTokenRequired as i32),
395 _ => None,
396 } {
397 let msg = QuicSubscribeResponse {
398 error: Some(error),
399 version,
400 ..Default::default()
401 };
402 return Ok((send, msg, None));
403 }
404 }
405
406 if recv_streams == 0 || recv_streams > max_recv_streams {
408 let code = if recv_streams == 0 {
409 QuicSubscribeResponseError::ZeroRecvStreams
410 } else {
411 QuicSubscribeResponseError::ExceedRecvStreams
412 };
413 let msg = QuicSubscribeResponse {
414 error: Some(code as i32),
415 max_recv_streams: Some(max_recv_streams),
416 version,
417 ..Default::default()
418 };
419 return Ok((send, msg, None));
420 }
421
422 Ok(match messages.subscribe(replay_from_slot, filter) {
423 Ok(rx) => {
424 let pos = replay_from_slot
425 .map(|slot| format!("slot {slot}").into())
426 .unwrap_or(Cow::Borrowed("latest"));
427 info!("#{id}: subscribed from {pos}");
428 (
429 send,
430 QuicSubscribeResponse {
431 version,
432 ..Default::default()
433 },
434 Some((
435 recv_streams,
436 max_backlog.map(|x| x as u64).unwrap_or(u64::MAX),
437 rx,
438 )),
439 )
440 }
441 Err(SubscribeError::NotInitialized) => {
442 let msg = QuicSubscribeResponse {
443 error: Some(QuicSubscribeResponseError::NotInitialized as i32),
444 version,
445 ..Default::default()
446 };
447 (send, msg, None)
448 }
449 Err(SubscribeError::SlotNotAvailable { first_available }) => {
450 let msg = QuicSubscribeResponse {
451 error: Some(QuicSubscribeResponseError::SlotNotAvailable as i32),
452 first_available_slot: Some(first_available),
453 version,
454 ..Default::default()
455 };
456 (send, msg, None)
457 }
458 })
459 }
460}