richat_shared/transports/
quic.rs

1use {
2    crate::{
3        config::{deserialize_rustls_server_config, deserialize_x_token_set},
4        shutdown::Shutdown,
5        transports::{RecvError, RecvItem, RecvStream, Subscribe, SubscribeError, WriteVectored},
6    },
7    futures::{
8        future::{pending, FutureExt},
9        stream::StreamExt,
10    },
11    prost::Message,
12    quinn::{
13        crypto::rustls::{NoInitialCipherSuite, QuicServerConfig},
14        Connection, Endpoint, Incoming, SendStream, VarInt,
15    },
16    richat_proto::richat::{
17        QuicSubscribeClose, QuicSubscribeCloseError, QuicSubscribeRequest, QuicSubscribeResponse,
18        QuicSubscribeResponseError,
19    },
20    serde::Deserialize,
21    std::{
22        borrow::Cow,
23        collections::{BTreeSet, HashSet, VecDeque},
24        future::Future,
25        io::{self, IoSlice},
26        net::{IpAddr, Ipv4Addr, SocketAddr},
27        sync::Arc,
28    },
29    thiserror::Error,
30    tokio::{
31        io::{AsyncReadExt, AsyncWriteExt},
32        task::{JoinError, JoinSet},
33    },
34    tracing::{error, info},
35};
36
/// Configuration of the QUIC transport server.
///
/// Unknown fields are rejected (`deny_unknown_fields`). Every field except `tls_config`
/// has a default.
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ConfigQuicServer {
    /// Address the server endpoint binds to; defaults to `127.0.0.1:10100`
    #[serde(default = "ConfigQuicServer::default_endpoint")]
    pub endpoint: SocketAddr,
    /// Server TLS configuration, parsed by `deserialize_rustls_server_config`
    #[serde(deserialize_with = "deserialize_rustls_server_config")]
    pub tls_config: rustls::ServerConfig,
    /// Expected round-trip time in ms, used to size the flow-control windows (default 100)
    #[serde(default = "ConfigQuicServer::default_expected_rtt")]
    pub expected_rtt: u32,
    /// Value in bytes/s, default with expected rtt 100 is 100Mbps
    #[serde(default = "ConfigQuicServer::default_max_stream_bandwidth")]
    pub max_stream_bandwidth: u32,
    /// Maximum duration of inactivity, in ms, to accept before timing out the connection
    /// (default 30_000); `None` is forwarded as-is to quinn's `max_idle_timeout`
    #[serde(default = "ConfigQuicServer::default_max_idle_timeout")]
    pub max_idle_timeout: Option<u32>,
    /// Upper bound on the `recv_streams` a client may request, i.e. on the number of
    /// server-opened (outgoing) uni streams per subscription (default 16)
    #[serde(default = "ConfigQuicServer::default_max_recv_streams")]
    pub max_recv_streams: u32,
    /// Max encoded subscribe-request size in bytes (default 1024); larger requests are
    /// answered with `RequestSizeTooLarge`
    #[serde(default = "ConfigQuicServer::default_max_request_size")]
    pub max_request_size: usize,
    /// Accepted access tokens; when empty (the default) no token is required
    #[serde(default, deserialize_with = "deserialize_x_token_set")]
    pub x_tokens: HashSet<Vec<u8>>,
}
62
63impl ConfigQuicServer {
64    pub const fn default_endpoint() -> SocketAddr {
65        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10100)
66    }
67
68    const fn default_expected_rtt() -> u32 {
69        100
70    }
71
72    const fn default_max_stream_bandwidth() -> u32 {
73        12_500 * 1000
74    }
75
76    const fn default_max_idle_timeout() -> Option<u32> {
77        Some(30_000)
78    }
79
80    const fn default_max_recv_streams() -> u32 {
81        16
82    }
83
84    const fn default_max_request_size() -> usize {
85        1024
86    }
87
88    pub fn create_endpoint(&self) -> Result<Endpoint, CreateEndpointError> {
89        let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
90            QuicServerConfig::try_from(self.tls_config.clone())?,
91        ));
92
93        // disallow incoming uni streams
94        let transport_config = Arc::get_mut(&mut server_config.transport)
95            .ok_or(CreateEndpointError::TransportConfig)?;
96        transport_config.max_concurrent_bidi_streams(1u8.into());
97        transport_config.max_concurrent_uni_streams(0u8.into());
98
99        // set window size
100        let stream_rwnd = self.max_stream_bandwidth / 1_000 * self.expected_rtt;
101        transport_config.stream_receive_window(stream_rwnd.into());
102        transport_config.send_window(8 * stream_rwnd as u64);
103        transport_config.datagram_receive_buffer_size(Some(stream_rwnd as usize));
104
105        // set idle timeout
106        transport_config
107            .max_idle_timeout(self.max_idle_timeout.map(|ms| VarInt::from_u32(ms).into()));
108
109        Endpoint::server(server_config, self.endpoint).map_err(|error| CreateEndpointError::Bind {
110            error,
111            endpoint: self.endpoint,
112        })
113    }
114}
115
116#[derive(Debug, Error)]
117pub enum CreateEndpointError {
118    #[error("failed to crate QuicServerConfig")]
119    ServerConfig(#[from] NoInitialCipherSuite),
120    #[error("failed to modify TransportConfig")]
121    TransportConfig,
122    #[error("failed to bind {endpoint}: {error}")]
123    Bind {
124        error: io::Error,
125        endpoint: SocketAddr,
126    },
127}
128
/// Errors that terminate the handling of a single client connection.
///
/// They are logged by the accept loop as `#{id}: connection failed: {error}`.
#[derive(Debug, Error)]
enum ConnectionError {
    /// Connection-level failure (handshake, accepting/opening streams).
    #[error(transparent)]
    QuinnConnection(#[from] quinn::ConnectionError),
    /// Reading the request body with `read_exact` failed.
    #[error(transparent)]
    QuinnReadExact(#[from] quinn::ReadExactError),
    /// Writing to a quinn send stream failed.
    #[error(transparent)]
    QuinnWrite(#[from] quinn::WriteError),
    /// `finish` was called on a stream that is already closed.
    #[error(transparent)]
    QuinnClosedStream(#[from] quinn::ClosedStream),
    /// I/O error from the tokio read/write helpers.
    #[error(transparent)]
    Io(#[from] io::Error),
    /// The subscribe request could not be decoded.
    #[error(transparent)]
    Prost(#[from] prost::DecodeError),
    /// A stream write task could not be joined.
    #[error(transparent)]
    Join(#[from] JoinError),
    /// No send stream was available to deliver the close message.
    #[error("stream is not available")]
    StreamNotAvailable,
}
148
/// QUIC transport server.
///
/// This is a stateless namespace for [`QuicServer::spawn`] and its per-connection handlers.
#[derive(Debug)]
pub struct QuicServer;
151
152impl QuicServer {
153    pub async fn spawn(
154        config: ConfigQuicServer,
155        messages: impl Subscribe + Clone + Send + 'static,
156        on_conn_new_cb: impl Fn() + Clone + Send + 'static,
157        on_conn_drop_cb: impl Fn() + Clone + Send + 'static,
158        shutdown: Shutdown,
159    ) -> Result<impl Future<Output = Result<(), JoinError>>, CreateEndpointError> {
160        let endpoint = config.create_endpoint()?;
161        info!("start server at {}", config.endpoint);
162
163        Ok(tokio::spawn(async move {
164            let max_recv_streams = config.max_recv_streams;
165            let max_request_size = config.max_request_size as u64;
166            let x_tokens = Arc::new(config.x_tokens);
167
168            let mut id = 0;
169            tokio::pin!(shutdown);
170            loop {
171                tokio::select! {
172                    incoming = endpoint.accept() => {
173                        let Some(incoming) = incoming else {
174                            error!("quic connection closed");
175                            break;
176                        };
177
178                        let messages = messages.clone();
179                        let on_conn_new_cb = on_conn_new_cb.clone();
180                        let on_conn_drop_cb = on_conn_drop_cb.clone();
181                        let x_tokens = Arc::clone(&x_tokens);
182                        tokio::spawn(async move {
183                            on_conn_new_cb();
184                            if let Err(error) = Self::handle_incoming(
185                                id, incoming, messages, max_recv_streams, max_request_size, x_tokens
186                            ).await {
187                                error!("#{id}: connection failed: {error}");
188                            } else {
189                                info!("#{id}: connection closed");
190                            }
191                            on_conn_drop_cb();
192                        });
193                        id += 1;
194                    }
195                    () = &mut shutdown => {
196                        endpoint.close(0u32.into(), b"shutdown");
197                        info!("shutdown");
198                        break
199                    },
200                };
201            }
202        }))
203    }
204
205    async fn handle_incoming(
206        id: u64,
207        incoming: Incoming,
208        messages: impl Subscribe,
209        max_recv_streams: u32,
210        max_request_size: u64,
211        x_tokens: Arc<HashSet<Vec<u8>>>,
212    ) -> Result<(), ConnectionError> {
213        let conn = incoming.await?;
214        info!("#{id}: new connection from {:?}", conn.remote_address());
215
216        // Read request and subscribe
217        let (mut send, response, maybe_rx) = Self::handle_request(
218            id,
219            &conn,
220            messages,
221            max_recv_streams,
222            max_request_size,
223            x_tokens,
224        )
225        .await?;
226
227        // Send response
228        let buf = response.encode_to_vec();
229        send.write_u64(buf.len() as u64).await?;
230        send.write_all(&buf).await?;
231        send.flush().await?;
232
233        let Some((recv_streams, max_backlog, mut rx)) = maybe_rx else {
234            return Ok(());
235        };
236
237        // Open connections
238        let mut streams = VecDeque::with_capacity(recv_streams as usize);
239        while streams.len() < recv_streams as usize {
240            streams.push_back(conn.open_uni().await?);
241        }
242
243        // Send loop
244        let mut msg_id = 0;
245        let mut msg_ids = BTreeSet::new();
246        let mut next_message: Option<RecvItem> = None;
247        let mut set = JoinSet::new();
248        loop {
249            if msg_id - msg_ids.first().copied().unwrap_or(msg_id) < max_backlog {
250                if let Some(message) = next_message.take() {
251                    if let Some(mut stream) = streams.pop_front() {
252                        msg_ids.insert(msg_id);
253                        set.spawn(async move {
254                            WriteVectored::new(
255                                &mut stream,
256                                &mut [
257                                    IoSlice::new(&msg_id.to_be_bytes()),
258                                    IoSlice::new(&(message.len() as u64).to_be_bytes()),
259                                    IoSlice::new(&message),
260                                ],
261                            )
262                            .await?;
263                            Ok::<_, ConnectionError>((msg_id, stream))
264                        });
265                        msg_id += 1;
266                    } else {
267                        next_message = Some(message);
268                    }
269                }
270            }
271
272            let rx_recv = if next_message.is_none() {
273                rx.next().boxed()
274            } else {
275                pending().boxed()
276            };
277            let set_join_next = if !set.is_empty() {
278                set.join_next().boxed()
279            } else {
280                pending().boxed()
281            };
282
283            tokio::select! {
284                message = rx_recv => {
285                    match message {
286                        Some(Ok(message)) => next_message = Some(message),
287                        Some(Err(error)) => {
288                            error!("#{id}: failed to get message: {error}");
289                            if streams.is_empty() {
290                                let (msg_id, stream) = set.join_next().await.expect("already verified")??;
291                                msg_ids.remove(&msg_id);
292                                streams.push_back(stream);
293                            }
294                            let Some(mut stream) = streams.pop_front() else {
295                                return Err(ConnectionError::StreamNotAvailable);
296                            };
297
298                            let msg = QuicSubscribeClose {
299                                error: match error {
300                                    RecvError::Lagged => QuicSubscribeCloseError::Lagged,
301                                    RecvError::Closed => QuicSubscribeCloseError::Closed,
302                                } as i32
303                            };
304                            let message = msg.encode_to_vec();
305
306                            set.spawn(async move {
307                                stream.write_u64(u64::MAX).await?;
308                                stream.write_u64(message.len() as u64).await?;
309                                stream.write_all(&message).await?;
310                                Ok::<_, ConnectionError>((msg_id, stream))
311                            });
312                        },
313                        None => break,
314                    }
315                },
316                result = set_join_next => {
317                    let (msg_id, stream) = result.expect("already verified")??;
318                    msg_ids.remove(&msg_id);
319                    streams.push_back(stream);
320                }
321            }
322        }
323
324        for (_, mut stream) in set.join_all().await.into_iter().flatten() {
325            stream.finish()?;
326        }
327        for mut stream in streams {
328            stream.finish()?;
329        }
330        drop(conn);
331
332        Ok(())
333    }
334
335    async fn handle_request(
336        id: u64,
337        conn: &Connection,
338        messages: impl Subscribe,
339        max_recv_streams: u32,
340        max_request_size: u64,
341        x_tokens: Arc<HashSet<Vec<u8>>>,
342    ) -> Result<
343        (
344            SendStream,
345            QuicSubscribeResponse,
346            Option<(u32, u64, RecvStream)>,
347        ),
348        ConnectionError,
349    > {
350        let (send, mut recv) = conn.accept_bi().await?;
351
352        // Read request
353        let size = recv.read_u64().await?;
354        if size > max_request_size {
355            let msg = QuicSubscribeResponse {
356                error: Some(QuicSubscribeResponseError::RequestSizeTooLarge as i32),
357                ..Default::default()
358            };
359            return Ok((send, msg, None));
360        }
361        let mut buf = vec![0; size as usize]; // TODO: use MaybeUninit
362        recv.read_exact(buf.as_mut_slice()).await?;
363
364        // Decode request
365        let QuicSubscribeRequest {
366            x_token,
367            recv_streams,
368            max_backlog,
369            replay_from_slot,
370            filter,
371        } = Message::decode(buf.as_slice())?;
372
373        // verify access token
374        if !x_tokens.is_empty() {
375            if let Some(error) = match x_token {
376                Some(x_token) if !x_tokens.contains(&x_token) => {
377                    Some(QuicSubscribeResponseError::XTokenInvalid as i32)
378                }
379                None => Some(QuicSubscribeResponseError::XTokenRequired as i32),
380                _ => None,
381            } {
382                let msg = QuicSubscribeResponse {
383                    error: Some(error),
384                    ..Default::default()
385                };
386                return Ok((send, msg, None));
387            }
388        }
389
390        // validate number of streams
391        if recv_streams == 0 || recv_streams > max_recv_streams {
392            let code = if recv_streams == 0 {
393                QuicSubscribeResponseError::ZeroRecvStreams
394            } else {
395                QuicSubscribeResponseError::ExceedRecvStreams
396            };
397            let msg = QuicSubscribeResponse {
398                error: Some(code as i32),
399                max_recv_streams: Some(max_recv_streams),
400                ..Default::default()
401            };
402            return Ok((send, msg, None));
403        }
404
405        Ok(match messages.subscribe(replay_from_slot, filter) {
406            Ok(rx) => {
407                let pos = replay_from_slot
408                    .map(|slot| format!("slot {slot}").into())
409                    .unwrap_or(Cow::Borrowed("latest"));
410                info!("#{id}: subscribed from {pos}");
411                (
412                    send,
413                    QuicSubscribeResponse::default(),
414                    Some((
415                        recv_streams,
416                        max_backlog.map(|x| x as u64).unwrap_or(u64::MAX),
417                        rx,
418                    )),
419                )
420            }
421            Err(SubscribeError::NotInitialized) => {
422                let msg = QuicSubscribeResponse {
423                    error: Some(QuicSubscribeResponseError::NotInitialized as i32),
424                    ..Default::default()
425                };
426                (send, msg, None)
427            }
428            Err(SubscribeError::SlotNotAvailable { first_available }) => {
429                let msg = QuicSubscribeResponse {
430                    error: Some(QuicSubscribeResponseError::SlotNotAvailable as i32),
431                    first_available_slot: Some(first_available),
432                    ..Default::default()
433                };
434                (send, msg, None)
435            }
436        })
437    }
438}