richat_shared/transports/
quic.rs

1use {
2    crate::{
3        config::{deserialize_rustls_server_config, deserialize_x_token_set},
4        shutdown::Shutdown,
5        transports::{RecvError, RecvItem, RecvStream, Subscribe, SubscribeError, WriteVectored},
6        version::Version,
7    },
8    futures::{
9        future::{pending, FutureExt},
10        stream::StreamExt,
11    },
12    prost::Message,
13    quinn::{
14        crypto::rustls::{NoInitialCipherSuite, QuicServerConfig},
15        Connection, Endpoint, Incoming, SendStream, VarInt,
16    },
17    richat_proto::richat::{
18        QuicSubscribeClose, QuicSubscribeCloseError, QuicSubscribeRequest, QuicSubscribeResponse,
19        QuicSubscribeResponseError,
20    },
21    serde::Deserialize,
22    std::{
23        borrow::Cow,
24        collections::{BTreeSet, HashSet, VecDeque},
25        future::Future,
26        io::{self, IoSlice},
27        net::{IpAddr, Ipv4Addr, SocketAddr},
28        sync::Arc,
29    },
30    thiserror::Error,
31    tokio::{
32        io::{AsyncReadExt, AsyncWriteExt},
33        task::{JoinError, JoinSet},
34    },
35    tracing::{error, info},
36};
37
38#[derive(Debug, Clone, Deserialize)]
39#[serde(deny_unknown_fields)]
40pub struct ConfigQuicServer {
41    #[serde(default = "ConfigQuicServer::default_endpoint")]
42    pub endpoint: SocketAddr,
43    #[serde(deserialize_with = "deserialize_rustls_server_config")]
44    pub tls_config: rustls::ServerConfig,
45    /// Value in ms
46    #[serde(default = "ConfigQuicServer::default_expected_rtt")]
47    pub expected_rtt: u32,
48    /// Value in bytes/s, default with expected rtt 100 is 100Mbps
49    #[serde(default = "ConfigQuicServer::default_max_stream_bandwidth")]
50    pub max_stream_bandwidth: u32,
51    /// Maximum duration of inactivity to accept before timing out the connection
52    #[serde(default = "ConfigQuicServer::default_max_idle_timeout")]
53    pub max_idle_timeout: Option<u32>,
54    /// Max number of outgoing streams
55    #[serde(default = "ConfigQuicServer::default_max_recv_streams")]
56    pub max_recv_streams: u32,
57    /// Max request size in bytes
58    #[serde(default = "ConfigQuicServer::default_max_request_size")]
59    pub max_request_size: usize,
60    #[serde(default, deserialize_with = "deserialize_x_token_set")]
61    pub x_tokens: HashSet<Vec<u8>>,
62}
63
64impl ConfigQuicServer {
65    pub const fn default_endpoint() -> SocketAddr {
66        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10101)
67    }
68
69    const fn default_expected_rtt() -> u32 {
70        100
71    }
72
73    const fn default_max_stream_bandwidth() -> u32 {
74        12_500 * 1000
75    }
76
77    const fn default_max_idle_timeout() -> Option<u32> {
78        Some(30_000)
79    }
80
81    const fn default_max_recv_streams() -> u32 {
82        16
83    }
84
85    const fn default_max_request_size() -> usize {
86        1024
87    }
88
89    pub fn create_endpoint(&self) -> Result<Endpoint, CreateEndpointError> {
90        let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
91            QuicServerConfig::try_from(self.tls_config.clone())?,
92        ));
93
94        // disallow incoming uni streams
95        let transport_config = Arc::get_mut(&mut server_config.transport)
96            .ok_or(CreateEndpointError::TransportConfig)?;
97        transport_config.max_concurrent_bidi_streams(1u8.into());
98        transport_config.max_concurrent_uni_streams(0u8.into());
99
100        // set window size
101        let stream_rwnd = self.max_stream_bandwidth / 1_000 * self.expected_rtt;
102        transport_config.stream_receive_window(stream_rwnd.into());
103        transport_config.send_window(8 * stream_rwnd as u64);
104        transport_config.datagram_receive_buffer_size(Some(stream_rwnd as usize));
105
106        // set idle timeout
107        transport_config
108            .max_idle_timeout(self.max_idle_timeout.map(|ms| VarInt::from_u32(ms).into()));
109
110        Endpoint::server(server_config, self.endpoint).map_err(|error| CreateEndpointError::Bind {
111            error,
112            endpoint: self.endpoint,
113        })
114    }
115}
116
117#[derive(Debug, Error)]
118pub enum CreateEndpointError {
119    #[error("failed to crate QuicServerConfig")]
120    ServerConfig(#[from] NoInitialCipherSuite),
121    #[error("failed to modify TransportConfig")]
122    TransportConfig,
123    #[error("failed to bind {endpoint}: {error}")]
124    Bind {
125        error: io::Error,
126        endpoint: SocketAddr,
127    },
128}
129
130#[derive(Debug, Error)]
131enum ConnectionError {
132    #[error(transparent)]
133    QuinnConnection(#[from] quinn::ConnectionError),
134    #[error(transparent)]
135    QuinnReadExact(#[from] quinn::ReadExactError),
136    #[error(transparent)]
137    QuinnWrite(#[from] quinn::WriteError),
138    #[error(transparent)]
139    QuinnClosedStream(#[from] quinn::ClosedStream),
140    #[error(transparent)]
141    Io(#[from] io::Error),
142    #[error(transparent)]
143    Prost(#[from] prost::DecodeError),
144    #[error(transparent)]
145    Join(#[from] JoinError),
146    #[error("stream is not available")]
147    StreamNotAvailable,
148}
149
150#[derive(Debug)]
151pub struct QuicServer;
152
153impl QuicServer {
154    pub async fn spawn(
155        config: ConfigQuicServer,
156        messages: impl Subscribe + Clone + Send + 'static,
157        on_conn_new_cb: impl Fn() + Clone + Send + 'static,
158        on_conn_drop_cb: impl Fn() + Clone + Send + 'static,
159        version: Version<'static>,
160        shutdown: Shutdown,
161    ) -> Result<impl Future<Output = Result<(), JoinError>>, CreateEndpointError> {
162        let endpoint = config.create_endpoint()?;
163        info!("start server at {}", config.endpoint);
164
165        Ok(tokio::spawn(async move {
166            let max_recv_streams = config.max_recv_streams;
167            let max_request_size = config.max_request_size as u64;
168            let x_tokens = Arc::new(config.x_tokens);
169
170            let mut id = 0;
171            tokio::pin!(shutdown);
172            loop {
173                tokio::select! {
174                    incoming = endpoint.accept() => {
175                        let Some(incoming) = incoming else {
176                            error!("quic connection closed");
177                            break;
178                        };
179
180                        let messages = messages.clone();
181                        let on_conn_new_cb = on_conn_new_cb.clone();
182                        let on_conn_drop_cb = on_conn_drop_cb.clone();
183                        let x_tokens = Arc::clone(&x_tokens);
184                        tokio::spawn(async move {
185                            on_conn_new_cb();
186                            if let Err(error) = Self::handle_incoming(
187                                id,
188                                incoming,
189                                messages,
190                                max_recv_streams,
191                                max_request_size,
192                                x_tokens,
193                                version.create_grpc_version_info().json(),
194                            ).await {
195                                error!("#{id}: connection failed: {error}");
196                            } else {
197                                info!("#{id}: connection closed");
198                            }
199                            on_conn_drop_cb();
200                        });
201                        id += 1;
202                    }
203                    () = &mut shutdown => {
204                        endpoint.close(0u32.into(), b"shutdown");
205                        info!("shutdown");
206                        break
207                    },
208                };
209            }
210        }))
211    }
212
213    async fn handle_incoming(
214        id: u64,
215        incoming: Incoming,
216        messages: impl Subscribe,
217        max_recv_streams: u32,
218        max_request_size: u64,
219        x_tokens: Arc<HashSet<Vec<u8>>>,
220        version: String,
221    ) -> Result<(), ConnectionError> {
222        let conn = incoming.await?;
223        info!("#{id}: new connection from {:?}", conn.remote_address());
224
225        // Read request and subscribe
226        let (mut send, response, maybe_rx) = Self::handle_request(
227            id,
228            &conn,
229            messages,
230            max_recv_streams,
231            max_request_size,
232            x_tokens,
233            version,
234        )
235        .await?;
236
237        // Send response
238        let buf = response.encode_to_vec();
239        send.write_u64(buf.len() as u64).await?;
240        send.write_all(&buf).await?;
241        send.flush().await?;
242
243        let Some((recv_streams, max_backlog, mut rx)) = maybe_rx else {
244            return Ok(());
245        };
246
247        // Open connections
248        let mut streams = VecDeque::with_capacity(recv_streams as usize);
249        while streams.len() < recv_streams as usize {
250            streams.push_back(conn.open_uni().await?);
251        }
252
253        // Send loop
254        let mut msg_id = 0;
255        let mut msg_ids = BTreeSet::new();
256        let mut next_message: Option<RecvItem> = None;
257        let mut set = JoinSet::new();
258        loop {
259            if msg_id - msg_ids.first().copied().unwrap_or(msg_id) < max_backlog {
260                if let Some(message) = next_message.take() {
261                    if let Some(mut stream) = streams.pop_front() {
262                        msg_ids.insert(msg_id);
263                        set.spawn(async move {
264                            WriteVectored::new(
265                                &mut stream,
266                                &mut [
267                                    IoSlice::new(&msg_id.to_be_bytes()),
268                                    IoSlice::new(&(message.len() as u64).to_be_bytes()),
269                                    IoSlice::new(&message),
270                                ],
271                            )
272                            .await?;
273                            Ok::<_, ConnectionError>((msg_id, stream))
274                        });
275                        msg_id += 1;
276                    } else {
277                        next_message = Some(message);
278                    }
279                }
280            }
281
282            let rx_recv = if next_message.is_none() {
283                rx.next().boxed()
284            } else {
285                pending().boxed()
286            };
287            let set_join_next = if !set.is_empty() {
288                set.join_next().boxed()
289            } else {
290                pending().boxed()
291            };
292
293            tokio::select! {
294                message = rx_recv => {
295                    match message {
296                        Some(Ok(message)) => next_message = Some(message),
297                        Some(Err(error)) => {
298                            error!("#{id}: failed to get message: {error}");
299                            if streams.is_empty() {
300                                let (msg_id, stream) = set.join_next().await.expect("already verified")??;
301                                msg_ids.remove(&msg_id);
302                                streams.push_back(stream);
303                            }
304                            let Some(mut stream) = streams.pop_front() else {
305                                return Err(ConnectionError::StreamNotAvailable);
306                            };
307
308                            let msg = QuicSubscribeClose {
309                                error: match error {
310                                    RecvError::Lagged => QuicSubscribeCloseError::Lagged,
311                                    RecvError::Closed => QuicSubscribeCloseError::Closed,
312                                } as i32
313                            };
314                            let message = msg.encode_to_vec();
315
316                            set.spawn(async move {
317                                stream.write_u64(u64::MAX).await?;
318                                stream.write_u64(message.len() as u64).await?;
319                                stream.write_all(&message).await?;
320                                Ok::<_, ConnectionError>((msg_id, stream))
321                            });
322                        },
323                        None => break,
324                    }
325                },
326                result = set_join_next => {
327                    let (msg_id, stream) = result.expect("already verified")??;
328                    msg_ids.remove(&msg_id);
329                    streams.push_back(stream);
330                }
331            }
332        }
333
334        for (_, mut stream) in set.join_all().await.into_iter().flatten() {
335            stream.finish()?;
336        }
337        for mut stream in streams {
338            stream.finish()?;
339        }
340        drop(conn);
341
342        Ok(())
343    }
344
345    async fn handle_request(
346        id: u64,
347        conn: &Connection,
348        messages: impl Subscribe,
349        max_recv_streams: u32,
350        max_request_size: u64,
351        x_tokens: Arc<HashSet<Vec<u8>>>,
352        version: String,
353    ) -> Result<
354        (
355            SendStream,
356            QuicSubscribeResponse,
357            Option<(u32, u64, RecvStream)>,
358        ),
359        ConnectionError,
360    > {
361        let (send, mut recv) = conn.accept_bi().await?;
362
363        // Read request
364        let size = recv.read_u64().await?;
365        if size > max_request_size {
366            let msg = QuicSubscribeResponse {
367                error: Some(QuicSubscribeResponseError::RequestSizeTooLarge as i32),
368                version,
369                ..Default::default()
370            };
371            return Ok((send, msg, None));
372        }
373        let mut buf = vec![0; size as usize]; // TODO: use MaybeUninit
374        recv.read_exact(buf.as_mut_slice()).await?;
375
376        // Decode request
377        let QuicSubscribeRequest {
378            x_token,
379            recv_streams,
380            max_backlog,
381            replay_from_slot,
382            filter,
383        } = Message::decode(buf.as_slice())?;
384
385        // verify access token
386        if !x_tokens.is_empty() {
387            if let Some(error) = match x_token {
388                Some(x_token) if !x_tokens.contains(&x_token) => {
389                    Some(QuicSubscribeResponseError::XTokenInvalid as i32)
390                }
391                None => Some(QuicSubscribeResponseError::XTokenRequired as i32),
392                _ => None,
393            } {
394                let msg = QuicSubscribeResponse {
395                    error: Some(error),
396                    version,
397                    ..Default::default()
398                };
399                return Ok((send, msg, None));
400            }
401        }
402
403        // validate number of streams
404        if recv_streams == 0 || recv_streams > max_recv_streams {
405            let code = if recv_streams == 0 {
406                QuicSubscribeResponseError::ZeroRecvStreams
407            } else {
408                QuicSubscribeResponseError::ExceedRecvStreams
409            };
410            let msg = QuicSubscribeResponse {
411                error: Some(code as i32),
412                max_recv_streams: Some(max_recv_streams),
413                version,
414                ..Default::default()
415            };
416            return Ok((send, msg, None));
417        }
418
419        Ok(match messages.subscribe(replay_from_slot, filter) {
420            Ok(rx) => {
421                let pos = replay_from_slot
422                    .map(|slot| format!("slot {slot}").into())
423                    .unwrap_or(Cow::Borrowed("latest"));
424                info!("#{id}: subscribed from {pos}");
425                (
426                    send,
427                    QuicSubscribeResponse {
428                        version,
429                        ..Default::default()
430                    },
431                    Some((
432                        recv_streams,
433                        max_backlog.map(|x| x as u64).unwrap_or(u64::MAX),
434                        rx,
435                    )),
436                )
437            }
438            Err(SubscribeError::NotInitialized) => {
439                let msg = QuicSubscribeResponse {
440                    error: Some(QuicSubscribeResponseError::NotInitialized as i32),
441                    version,
442                    ..Default::default()
443                };
444                (send, msg, None)
445            }
446            Err(SubscribeError::SlotNotAvailable { first_available }) => {
447                let msg = QuicSubscribeResponse {
448                    error: Some(QuicSubscribeResponseError::SlotNotAvailable as i32),
449                    first_available_slot: Some(first_available),
450                    version,
451                    ..Default::default()
452                };
453                (send, msg, None)
454            }
455        })
456    }
457}