richat_shared/transports/quic.rs

1use {
2    crate::{
3        config::{deserialize_rustls_server_config, deserialize_x_token_set},
4        shutdown::Shutdown,
5        transports::{RecvError, RecvItem, RecvStream, Subscribe, SubscribeError, WriteVectored},
6    },
7    futures::{
8        future::{pending, FutureExt},
9        stream::StreamExt,
10    },
11    prost::Message,
12    quinn::{
13        crypto::rustls::{NoInitialCipherSuite, QuicServerConfig},
14        Connection, Endpoint, Incoming, SendStream, VarInt,
15    },
16    richat_proto::richat::{
17        QuicSubscribeClose, QuicSubscribeCloseError, QuicSubscribeRequest, QuicSubscribeResponse,
18        QuicSubscribeResponseError,
19    },
20    serde::Deserialize,
21    std::{
22        borrow::Cow,
23        collections::{BTreeSet, HashSet, VecDeque},
24        future::Future,
25        io::{self, IoSlice},
26        net::{IpAddr, Ipv4Addr, SocketAddr},
27        sync::Arc,
28    },
29    thiserror::Error,
30    tokio::{
31        io::{AsyncReadExt, AsyncWriteExt},
32        task::{JoinError, JoinSet},
33    },
34    tracing::{error, info},
35};
36
/// Configuration of the QUIC subscription server.
///
/// Deserialized with serde; unknown fields are rejected (`deny_unknown_fields`).
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ConfigQuicServer {
    /// Address the endpoint binds to (default `127.0.0.1:10100`).
    #[serde(default = "ConfigQuicServer::default_endpoint")]
    pub endpoint: SocketAddr,
    /// rustls server configuration the QUIC crypto config is built from.
    #[serde(deserialize_with = "deserialize_rustls_server_config")]
    pub tls_config: rustls::ServerConfig,
    /// Value in ms
    ///
    /// Expected round-trip time; combined with `max_stream_bandwidth` to size the per-stream
    /// receive window (default 100).
    #[serde(default = "ConfigQuicServer::default_expected_rtt")]
    pub expected_rtt: u32,
    /// Value in bytes/s, default with expected rtt 100 is 100Mbps
    ///
    /// Target bandwidth of a single stream (default 12_500_000 bytes/s).
    #[serde(default = "ConfigQuicServer::default_max_stream_bandwidth")]
    pub max_stream_bandwidth: u32,
    /// Maximum duration of inactivity to accept before timing out the connection
    ///
    /// Value in ms (default 30_000). `None` is passed through to quinn's `max_idle_timeout`
    /// unchanged.
    #[serde(default = "ConfigQuicServer::default_max_idle_timeout")]
    pub max_idle_timeout: Option<u32>,
    /// Max number of outgoing streams
    ///
    /// Upper bound on the `recv_streams` a client may request; the server opens that many
    /// unidirectional streams towards the client (default 16).
    #[serde(default = "ConfigQuicServer::default_max_recv_streams")]
    pub max_recv_streams: u32,
    /// Max request size in bytes
    ///
    /// Larger requests are answered with `RequestSizeTooLarge` (default 1024).
    #[serde(default = "ConfigQuicServer::default_max_request_size")]
    pub max_request_size: usize,
    /// Accepted access tokens. When empty, requests are not checked for an `x_token`.
    #[serde(default, deserialize_with = "deserialize_x_token_set")]
    pub x_tokens: HashSet<Vec<u8>>,
}
62
63impl ConfigQuicServer {
64    pub const fn default_endpoint() -> SocketAddr {
65        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10100)
66    }
67
68    const fn default_expected_rtt() -> u32 {
69        100
70    }
71
72    const fn default_max_stream_bandwidth() -> u32 {
73        12_500 * 1000
74    }
75
76    const fn default_max_idle_timeout() -> Option<u32> {
77        Some(30_000)
78    }
79
80    const fn default_max_recv_streams() -> u32 {
81        16
82    }
83
84    const fn default_max_request_size() -> usize {
85        1024
86    }
87
88    pub fn create_endpoint(&self) -> Result<Endpoint, CreateEndpointError> {
89        let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
90            QuicServerConfig::try_from(self.tls_config.clone())?,
91        ));
92
93        // disallow incoming uni streams
94        let transport_config = Arc::get_mut(&mut server_config.transport)
95            .ok_or(CreateEndpointError::TransportConfig)?;
96        transport_config.max_concurrent_bidi_streams(1u8.into());
97        transport_config.max_concurrent_uni_streams(0u8.into());
98
99        // set window size
100        let stream_rwnd = self.max_stream_bandwidth / 1_000 * self.expected_rtt;
101        transport_config.stream_receive_window(stream_rwnd.into());
102        transport_config.send_window(8 * stream_rwnd as u64);
103        transport_config.datagram_receive_buffer_size(Some(stream_rwnd as usize));
104
105        // set idle timeout
106        transport_config
107            .max_idle_timeout(self.max_idle_timeout.map(|ms| VarInt::from_u32(ms).into()));
108
109        Endpoint::server(server_config, self.endpoint).map_err(|error| CreateEndpointError::Bind {
110            error,
111            endpoint: self.endpoint,
112        })
113    }
114}
115
116#[derive(Debug, Error)]
117pub enum CreateEndpointError {
118    #[error("failed to crate QuicServerConfig")]
119    ServerConfig(#[from] NoInitialCipherSuite),
120    #[error("failed to modify TransportConfig")]
121    TransportConfig,
122    #[error("failed to bind {endpoint}: {error}")]
123    Bind {
124        error: io::Error,
125        endpoint: SocketAddr,
126    },
127}
128
/// Errors that end the handling of a single client connection.
///
/// Wrapped errors use `#[error(transparent)]`, so they display as the underlying error.
#[derive(Debug, Error)]
enum ConnectionError {
    /// Connection-level error from quinn: awaiting `Incoming`, `accept_bi` or `open_uni`.
    #[error(transparent)]
    QuinnConnection(#[from] quinn::ConnectionError),
    /// `RecvStream::read_exact` failed while reading the request body.
    #[error(transparent)]
    QuinnReadExact(#[from] quinn::ReadExactError),
    /// Write failure on a quinn send stream (presumably surfaced by `WriteVectored` — confirm).
    #[error(transparent)]
    QuinnWrite(#[from] quinn::WriteError),
    /// Returned by `SendStream::finish` when the stream can no longer be finished.
    #[error(transparent)]
    QuinnClosedStream(#[from] quinn::ClosedStream),
    /// I/O error from the tokio `AsyncReadExt` / `AsyncWriteExt` helpers (`read_u64`, `write_u64`, ...).
    #[error(transparent)]
    Io(#[from] io::Error),
    /// The subscribe request could not be decoded.
    #[error(transparent)]
    Prost(#[from] prost::DecodeError),
    /// A stream write task joined with a `JoinError`.
    #[error(transparent)]
    Join(#[from] JoinError),
    /// No unidirectional stream was available to send the close message on.
    #[error("stream is not available")]
    StreamNotAvailable,
}
148
/// QUIC subscription server; the entry point is [`QuicServer::spawn`].
#[derive(Debug)]
pub struct QuicServer;
151
152impl QuicServer {
153    pub async fn spawn(
154        config: ConfigQuicServer,
155        messages: impl Subscribe + Clone + Send + 'static,
156        on_conn_new_cb: impl Fn() + Copy + Send + 'static,
157        on_conn_drop_cb: impl Fn() + Copy + Send + 'static,
158        shutdown: Shutdown,
159    ) -> Result<impl Future<Output = Result<(), JoinError>>, CreateEndpointError> {
160        let endpoint = config.create_endpoint()?;
161        info!("start server at {}", config.endpoint);
162
163        Ok(tokio::spawn(async move {
164            let max_recv_streams = config.max_recv_streams;
165            let max_request_size = config.max_request_size as u64;
166            let x_tokens = Arc::new(config.x_tokens);
167
168            let mut id = 0;
169            tokio::pin!(shutdown);
170            loop {
171                tokio::select! {
172                    incoming = endpoint.accept() => {
173                        let Some(incoming) = incoming else {
174                            error!("quic connection closed");
175                            break;
176                        };
177
178                        let messages = messages.clone();
179                        let x_tokens = Arc::clone(&x_tokens);
180                        tokio::spawn(async move {
181                            on_conn_new_cb();
182                            if let Err(error) = Self::handle_incoming(
183                                id, incoming, messages, max_recv_streams, max_request_size, x_tokens
184                            ).await {
185                                error!("#{id}: connection failed: {error}");
186                            } else {
187                                info!("#{id}: connection closed");
188                            }
189                            on_conn_drop_cb();
190                        });
191                        id += 1;
192                    }
193                    () = &mut shutdown => {
194                        endpoint.close(0u32.into(), b"shutdown");
195                        info!("shutdown");
196                        break
197                    },
198                };
199            }
200        }))
201    }
202
203    async fn handle_incoming(
204        id: u64,
205        incoming: Incoming,
206        messages: impl Subscribe,
207        max_recv_streams: u32,
208        max_request_size: u64,
209        x_tokens: Arc<HashSet<Vec<u8>>>,
210    ) -> Result<(), ConnectionError> {
211        let conn = incoming.await?;
212        info!("#{id}: new connection from {:?}", conn.remote_address());
213
214        // Read request and subscribe
215        let (mut send, response, maybe_rx) = Self::handle_request(
216            id,
217            &conn,
218            messages,
219            max_recv_streams,
220            max_request_size,
221            x_tokens,
222        )
223        .await?;
224
225        // Send response
226        let buf = response.encode_to_vec();
227        send.write_u64(buf.len() as u64).await?;
228        send.write_all(&buf).await?;
229        send.flush().await?;
230
231        let Some((recv_streams, max_backlog, mut rx)) = maybe_rx else {
232            return Ok(());
233        };
234
235        // Open connections
236        let mut streams = VecDeque::with_capacity(recv_streams as usize);
237        while streams.len() < recv_streams as usize {
238            streams.push_back(conn.open_uni().await?);
239        }
240
241        // Send loop
242        let mut msg_id = 0;
243        let mut msg_ids = BTreeSet::new();
244        let mut next_message: Option<RecvItem> = None;
245        let mut set = JoinSet::new();
246        loop {
247            if msg_id - msg_ids.first().copied().unwrap_or(msg_id) < max_backlog {
248                if let Some(message) = next_message.take() {
249                    if let Some(mut stream) = streams.pop_front() {
250                        msg_ids.insert(msg_id);
251                        set.spawn(async move {
252                            WriteVectored::new(
253                                &mut stream,
254                                &mut [
255                                    IoSlice::new(&msg_id.to_be_bytes()),
256                                    IoSlice::new(&(message.len() as u64).to_be_bytes()),
257                                    IoSlice::new(&message),
258                                ],
259                            )
260                            .await?;
261                            Ok::<_, ConnectionError>((msg_id, stream))
262                        });
263                        msg_id += 1;
264                    } else {
265                        next_message = Some(message);
266                    }
267                }
268            }
269
270            let rx_recv = if next_message.is_none() {
271                rx.next().boxed()
272            } else {
273                pending().boxed()
274            };
275            let set_join_next = if !set.is_empty() {
276                set.join_next().boxed()
277            } else {
278                pending().boxed()
279            };
280
281            tokio::select! {
282                message = rx_recv => {
283                    match message {
284                        Some(Ok(message)) => next_message = Some(message),
285                        Some(Err(error)) => {
286                            error!("#{id}: failed to get message: {error}");
287                            if streams.is_empty() {
288                                let (msg_id, stream) = set.join_next().await.expect("already verified")??;
289                                msg_ids.remove(&msg_id);
290                                streams.push_back(stream);
291                            }
292                            let Some(mut stream) = streams.pop_front() else {
293                                return Err(ConnectionError::StreamNotAvailable);
294                            };
295
296                            let msg = QuicSubscribeClose {
297                                error: match error {
298                                    RecvError::Lagged => QuicSubscribeCloseError::Lagged,
299                                    RecvError::Closed => QuicSubscribeCloseError::Closed,
300                                } as i32
301                            };
302                            let message = msg.encode_to_vec();
303
304                            set.spawn(async move {
305                                stream.write_u64(u64::MAX).await?;
306                                stream.write_u64(message.len() as u64).await?;
307                                stream.write_all(&message).await?;
308                                Ok::<_, ConnectionError>((msg_id, stream))
309                            });
310                        },
311                        None => break,
312                    }
313                },
314                result = set_join_next => {
315                    let (msg_id, stream) = result.expect("already verified")??;
316                    msg_ids.remove(&msg_id);
317                    streams.push_back(stream);
318                }
319            }
320        }
321
322        for (_, mut stream) in set.join_all().await.into_iter().flatten() {
323            stream.finish()?;
324        }
325        for mut stream in streams {
326            stream.finish()?;
327        }
328        drop(conn);
329
330        Ok(())
331    }
332
333    async fn handle_request(
334        id: u64,
335        conn: &Connection,
336        messages: impl Subscribe,
337        max_recv_streams: u32,
338        max_request_size: u64,
339        x_tokens: Arc<HashSet<Vec<u8>>>,
340    ) -> Result<
341        (
342            SendStream,
343            QuicSubscribeResponse,
344            Option<(u32, u64, RecvStream)>,
345        ),
346        ConnectionError,
347    > {
348        let (send, mut recv) = conn.accept_bi().await?;
349
350        // Read request
351        let size = recv.read_u64().await?;
352        if size > max_request_size {
353            let msg = QuicSubscribeResponse {
354                error: Some(QuicSubscribeResponseError::RequestSizeTooLarge as i32),
355                ..Default::default()
356            };
357            return Ok((send, msg, None));
358        }
359        let mut buf = vec![0; size as usize]; // TODO: use MaybeUninit
360        recv.read_exact(buf.as_mut_slice()).await?;
361
362        // Decode request
363        let QuicSubscribeRequest {
364            x_token,
365            recv_streams,
366            max_backlog,
367            replay_from_slot,
368            filter,
369        } = Message::decode(buf.as_slice())?;
370
371        // verify access token
372        if !x_tokens.is_empty() {
373            if let Some(error) = match x_token {
374                Some(x_token) if !x_tokens.contains(&x_token) => {
375                    Some(QuicSubscribeResponseError::XTokenInvalid as i32)
376                }
377                None => Some(QuicSubscribeResponseError::XTokenRequired as i32),
378                _ => None,
379            } {
380                let msg = QuicSubscribeResponse {
381                    error: Some(error),
382                    ..Default::default()
383                };
384                return Ok((send, msg, None));
385            }
386        }
387
388        // validate number of streams
389        if recv_streams == 0 || recv_streams > max_recv_streams {
390            let code = if recv_streams == 0 {
391                QuicSubscribeResponseError::ZeroRecvStreams
392            } else {
393                QuicSubscribeResponseError::ExceedRecvStreams
394            };
395            let msg = QuicSubscribeResponse {
396                error: Some(code as i32),
397                max_recv_streams: Some(max_recv_streams),
398                ..Default::default()
399            };
400            return Ok((send, msg, None));
401        }
402
403        Ok(match messages.subscribe(replay_from_slot, filter) {
404            Ok(rx) => {
405                let pos = replay_from_slot
406                    .map(|slot| format!("slot {slot}").into())
407                    .unwrap_or(Cow::Borrowed("latest"));
408                info!("#{id}: subscribed from {pos}");
409                (
410                    send,
411                    QuicSubscribeResponse::default(),
412                    Some((
413                        recv_streams,
414                        max_backlog.map(|x| x as u64).unwrap_or(u64::MAX),
415                        rx,
416                    )),
417                )
418            }
419            Err(SubscribeError::NotInitialized) => {
420                let msg = QuicSubscribeResponse {
421                    error: Some(QuicSubscribeResponseError::NotInitialized as i32),
422                    ..Default::default()
423                };
424                (send, msg, None)
425            }
426            Err(SubscribeError::SlotNotAvailable { first_available }) => {
427                let msg = QuicSubscribeResponse {
428                    error: Some(QuicSubscribeResponseError::SlotNotAvailable as i32),
429                    first_available_slot: Some(first_available),
430                    ..Default::default()
431                };
432                (send, msg, None)
433            }
434        })
435    }
436}