1use {
2 crate::{
3 config::{deserialize_num_str, deserialize_rustls_server_config, deserialize_x_tokens_set},
4 transports::{RecvError, RecvItem, RecvStream, Subscribe, SubscribeError, WriteVectored},
5 version::Version,
6 },
7 futures::{
8 future::{FutureExt, pending},
9 stream::StreamExt,
10 },
11 prost::Message,
12 quinn::{
13 Connection, Endpoint, Incoming, SendStream, VarInt,
14 crypto::rustls::{NoInitialCipherSuite, QuicServerConfig},
15 },
16 richat_proto::richat::{
17 QuicSubscribeClose, QuicSubscribeCloseError, QuicSubscribeRequest, QuicSubscribeResponse,
18 QuicSubscribeResponseError,
19 },
20 serde::Deserialize,
21 std::{
22 borrow::Cow,
23 collections::{BTreeSet, HashSet, VecDeque},
24 future::Future,
25 io::{self, IoSlice},
26 net::{IpAddr, Ipv4Addr, SocketAddr},
27 sync::Arc,
28 },
29 thiserror::Error,
30 tokio::{
31 io::{AsyncReadExt, AsyncWriteExt},
32 task::{JoinError, JoinSet},
33 },
34 tokio_util::sync::CancellationToken,
35 tracing::{error, info},
36};
37
38#[derive(Debug, Clone, Deserialize)]
39#[serde(deny_unknown_fields)]
40pub struct ConfigQuicServer {
41 #[serde(default = "ConfigQuicServer::default_endpoint")]
42 pub endpoint: SocketAddr,
43 #[serde(deserialize_with = "deserialize_rustls_server_config")]
44 pub tls_config: rustls::ServerConfig,
45 #[serde(default = "ConfigQuicServer::default_expected_rtt")]
47 pub expected_rtt: u32,
48 #[serde(
50 default = "ConfigQuicServer::default_max_stream_bandwidth",
51 deserialize_with = "deserialize_num_str"
52 )]
53 pub max_stream_bandwidth: u32,
54 #[serde(default = "ConfigQuicServer::default_max_idle_timeout")]
56 pub max_idle_timeout: Option<u32>,
57 #[serde(default = "ConfigQuicServer::default_max_recv_streams")]
59 pub max_recv_streams: u32,
60 #[serde(default = "ConfigQuicServer::default_max_request_size")]
62 pub max_request_size: usize,
63 #[serde(default, deserialize_with = "deserialize_x_tokens_set")]
64 pub x_tokens: HashSet<Vec<u8>>,
65}
66
67impl ConfigQuicServer {
68 pub const fn default_endpoint() -> SocketAddr {
69 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10101)
70 }
71
72 const fn default_expected_rtt() -> u32 {
73 100
74 }
75
76 const fn default_max_stream_bandwidth() -> u32 {
77 12_500 * 1000
78 }
79
80 const fn default_max_idle_timeout() -> Option<u32> {
81 Some(30_000)
82 }
83
84 const fn default_max_recv_streams() -> u32 {
85 16
86 }
87
88 const fn default_max_request_size() -> usize {
89 1024
90 }
91
92 pub fn create_endpoint(&self) -> Result<Endpoint, CreateEndpointError> {
93 let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
94 QuicServerConfig::try_from(self.tls_config.clone())?,
95 ));
96
97 let transport_config = Arc::get_mut(&mut server_config.transport)
99 .ok_or(CreateEndpointError::TransportConfig)?;
100 transport_config.max_concurrent_bidi_streams(1u8.into());
101 transport_config.max_concurrent_uni_streams(0u8.into());
102
103 let stream_rwnd = self.max_stream_bandwidth / 1_000 * self.expected_rtt;
105 transport_config.stream_receive_window(stream_rwnd.into());
106 transport_config.send_window(8 * stream_rwnd as u64);
107 transport_config.datagram_receive_buffer_size(Some(stream_rwnd as usize));
108
109 transport_config
111 .max_idle_timeout(self.max_idle_timeout.map(|ms| VarInt::from_u32(ms).into()));
112
113 Endpoint::server(server_config, self.endpoint).map_err(|error| CreateEndpointError::Bind {
114 error,
115 endpoint: self.endpoint,
116 })
117 }
118}
119
120#[derive(Debug, Error)]
121pub enum CreateEndpointError {
122 #[error("failed to crate QuicServerConfig")]
123 ServerConfig(#[from] NoInitialCipherSuite),
124 #[error("failed to modify TransportConfig")]
125 TransportConfig,
126 #[error("failed to bind {endpoint}: {error}")]
127 Bind {
128 error: io::Error,
129 endpoint: SocketAddr,
130 },
131}
132
133#[derive(Debug, Error)]
134enum ConnectionError {
135 #[error(transparent)]
136 QuinnConnection(#[from] quinn::ConnectionError),
137 #[error(transparent)]
138 QuinnReadExact(#[from] quinn::ReadExactError),
139 #[error(transparent)]
140 QuinnWrite(#[from] quinn::WriteError),
141 #[error(transparent)]
142 QuinnClosedStream(#[from] quinn::ClosedStream),
143 #[error(transparent)]
144 Io(#[from] io::Error),
145 #[error(transparent)]
146 Prost(#[from] prost::DecodeError),
147 #[error(transparent)]
148 Join(#[from] JoinError),
149 #[error("stream is not available")]
150 StreamNotAvailable,
151}
152
/// Marker type for the QUIC server; it carries no state and only groups the
/// server's associated functions.
#[derive(Debug)]
pub struct QuicServer;
155
156impl QuicServer {
157 pub async fn spawn(
158 config: ConfigQuicServer,
159 messages: impl Subscribe + Clone + Send + 'static,
160 on_conn_new_cb: impl Fn() + Clone + Send + 'static,
161 on_conn_drop_cb: impl Fn() + Clone + Send + 'static,
162 version: Version<'static>,
163 shutdown: CancellationToken,
164 ) -> Result<impl Future<Output = Result<(), JoinError>>, CreateEndpointError> {
165 let endpoint = config.create_endpoint()?;
166 info!("start server at {}", config.endpoint);
167
168 Ok(tokio::spawn(async move {
169 let max_recv_streams = config.max_recv_streams;
170 let max_request_size = config.max_request_size as u64;
171 let x_tokens = Arc::new(config.x_tokens);
172
173 let mut id = 0;
174 loop {
175 tokio::select! {
176 incoming = endpoint.accept() => {
177 let Some(incoming) = incoming else {
178 error!("quic connection closed");
179 break;
180 };
181
182 let messages = messages.clone();
183 let on_conn_new_cb = on_conn_new_cb.clone();
184 let on_conn_drop_cb = on_conn_drop_cb.clone();
185 let x_tokens = Arc::clone(&x_tokens);
186 tokio::spawn(async move {
187 on_conn_new_cb();
188 if let Err(error) = Self::handle_incoming(
189 id,
190 incoming,
191 messages,
192 max_recv_streams,
193 max_request_size,
194 x_tokens,
195 version.create_grpc_version_info().json(),
196 ).await {
197 error!("#{id}: connection failed: {error}");
198 } else {
199 info!("#{id}: connection closed");
200 }
201 on_conn_drop_cb();
202 });
203 id += 1;
204 }
205 () = shutdown.cancelled() => {
206 endpoint.close(0u32.into(), b"shutdown");
207 info!("shutdown");
208 break
209 },
210 };
211 }
212 }))
213 }
214
215 async fn handle_incoming(
216 id: u64,
217 incoming: Incoming,
218 messages: impl Subscribe,
219 max_recv_streams: u32,
220 max_request_size: u64,
221 x_tokens: Arc<HashSet<Vec<u8>>>,
222 version: String,
223 ) -> Result<(), ConnectionError> {
224 let conn = incoming.await?;
225 info!("#{id}: new connection from {:?}", conn.remote_address());
226
227 let (mut send, response, maybe_rx) = Self::handle_request(
229 id,
230 &conn,
231 messages,
232 max_recv_streams,
233 max_request_size,
234 x_tokens,
235 version,
236 )
237 .await?;
238
239 let buf = response.encode_to_vec();
241 send.write_u64(buf.len() as u64).await?;
242 send.write_all(&buf).await?;
243 send.flush().await?;
244
245 let Some((recv_streams, max_backlog, mut rx)) = maybe_rx else {
246 return Ok(());
247 };
248
249 let mut streams = VecDeque::with_capacity(recv_streams as usize);
251 while streams.len() < recv_streams as usize {
252 streams.push_back(conn.open_uni().await?);
253 }
254
255 let mut msg_id = 0;
257 let mut msg_ids = BTreeSet::new();
258 let mut next_message: Option<RecvItem> = None;
259 let mut set = JoinSet::new();
260 loop {
261 if msg_id - msg_ids.first().copied().unwrap_or(msg_id) < max_backlog {
262 if let Some(message) = next_message.take() {
263 if let Some(mut stream) = streams.pop_front() {
264 msg_ids.insert(msg_id);
265 set.spawn(async move {
266 WriteVectored::new(
267 &mut stream,
268 &mut [
269 IoSlice::new(&msg_id.to_be_bytes()),
270 IoSlice::new(&(message.len() as u64).to_be_bytes()),
271 IoSlice::new(&message),
272 ],
273 )
274 .await?;
275 Ok::<_, ConnectionError>((msg_id, stream))
276 });
277 msg_id += 1;
278 } else {
279 next_message = Some(message);
280 }
281 }
282 }
283
284 let rx_recv = if next_message.is_none() {
285 rx.next().boxed()
286 } else {
287 pending().boxed()
288 };
289 let set_join_next = if !set.is_empty() {
290 set.join_next().boxed()
291 } else {
292 pending().boxed()
293 };
294
295 tokio::select! {
296 message = rx_recv => {
297 match message {
298 Some(Ok(message)) => next_message = Some(message),
299 Some(Err(error)) => {
300 error!("#{id}: failed to get message: {error}");
301 if streams.is_empty() {
302 let (msg_id, stream) = set.join_next().await.expect("already verified")??;
303 msg_ids.remove(&msg_id);
304 streams.push_back(stream);
305 }
306 let Some(mut stream) = streams.pop_front() else {
307 return Err(ConnectionError::StreamNotAvailable);
308 };
309
310 let msg = QuicSubscribeClose {
311 error: match error {
312 RecvError::Lagged => QuicSubscribeCloseError::Lagged,
313 RecvError::Closed => QuicSubscribeCloseError::Closed,
314 } as i32
315 };
316 let message = msg.encode_to_vec();
317
318 set.spawn(async move {
319 stream.write_u64(u64::MAX).await?;
320 stream.write_u64(message.len() as u64).await?;
321 stream.write_all(&message).await?;
322 Ok::<_, ConnectionError>((msg_id, stream))
323 });
324 },
325 None => break,
326 }
327 },
328 result = set_join_next => {
329 let (msg_id, stream) = result.expect("already verified")??;
330 msg_ids.remove(&msg_id);
331 streams.push_back(stream);
332 }
333 }
334 }
335
336 for (_, mut stream) in set.join_all().await.into_iter().flatten() {
337 stream.finish()?;
338 }
339 for mut stream in streams {
340 stream.finish()?;
341 }
342 drop(conn);
343
344 Ok(())
345 }
346
347 async fn handle_request(
348 id: u64,
349 conn: &Connection,
350 messages: impl Subscribe,
351 max_recv_streams: u32,
352 max_request_size: u64,
353 x_tokens: Arc<HashSet<Vec<u8>>>,
354 version: String,
355 ) -> Result<
356 (
357 SendStream,
358 QuicSubscribeResponse,
359 Option<(u32, u64, RecvStream)>,
360 ),
361 ConnectionError,
362 > {
363 let (send, mut recv) = conn.accept_bi().await?;
364
365 let size = recv.read_u64().await?;
367 if size > max_request_size {
368 let msg = QuicSubscribeResponse {
369 error: Some(QuicSubscribeResponseError::RequestSizeTooLarge as i32),
370 version,
371 ..Default::default()
372 };
373 return Ok((send, msg, None));
374 }
375 let mut buf = vec![0; size as usize]; recv.read_exact(buf.as_mut_slice()).await?;
377
378 let QuicSubscribeRequest {
380 x_token,
381 recv_streams,
382 max_backlog,
383 replay_from_slot,
384 filter,
385 } = Message::decode(buf.as_slice())?;
386
387 if !x_tokens.is_empty() {
389 if let Some(error) = match x_token {
390 Some(x_token) if !x_tokens.contains(&x_token) => {
391 Some(QuicSubscribeResponseError::XTokenInvalid as i32)
392 }
393 None => Some(QuicSubscribeResponseError::XTokenRequired as i32),
394 _ => None,
395 } {
396 let msg = QuicSubscribeResponse {
397 error: Some(error),
398 version,
399 ..Default::default()
400 };
401 return Ok((send, msg, None));
402 }
403 }
404
405 if recv_streams == 0 || recv_streams > max_recv_streams {
407 let code = if recv_streams == 0 {
408 QuicSubscribeResponseError::ZeroRecvStreams
409 } else {
410 QuicSubscribeResponseError::ExceedRecvStreams
411 };
412 let msg = QuicSubscribeResponse {
413 error: Some(code as i32),
414 max_recv_streams: Some(max_recv_streams),
415 version,
416 ..Default::default()
417 };
418 return Ok((send, msg, None));
419 }
420
421 Ok(match messages.subscribe(replay_from_slot, filter) {
422 Ok(rx) => {
423 let pos = replay_from_slot
424 .map(|slot| format!("slot {slot}").into())
425 .unwrap_or(Cow::Borrowed("latest"));
426 info!("#{id}: subscribed from {pos}");
427 (
428 send,
429 QuicSubscribeResponse {
430 version,
431 ..Default::default()
432 },
433 Some((
434 recv_streams,
435 max_backlog.map(|x| x as u64).unwrap_or(u64::MAX),
436 rx,
437 )),
438 )
439 }
440 Err(SubscribeError::NotInitialized) => {
441 let msg = QuicSubscribeResponse {
442 error: Some(QuicSubscribeResponseError::NotInitialized as i32),
443 version,
444 ..Default::default()
445 };
446 (send, msg, None)
447 }
448 Err(SubscribeError::SlotNotAvailable { first_available }) => {
449 let msg = QuicSubscribeResponse {
450 error: Some(QuicSubscribeResponseError::SlotNotAvailable as i32),
451 first_available_slot: Some(first_available),
452 version,
453 ..Default::default()
454 };
455 (send, msg, None)
456 }
457 })
458 }
459}