richat_shared/transports/quic.rs

1use {
2    crate::{
3        config::{deserialize_num_str, deserialize_rustls_server_config, deserialize_x_tokens_set},
4        transports::{RecvError, RecvItem, RecvStream, Subscribe, SubscribeError, WriteVectored},
5        version::Version,
6    },
7    futures::{
8        future::{FutureExt, pending},
9        stream::StreamExt,
10    },
11    prost::Message,
12    quinn::{
13        Connection, Endpoint, Incoming, SendStream, VarInt,
14        crypto::rustls::{NoInitialCipherSuite, QuicServerConfig},
15    },
16    richat_proto::richat::{
17        QuicSubscribeClose, QuicSubscribeCloseError, QuicSubscribeRequest, QuicSubscribeResponse,
18        QuicSubscribeResponseError,
19    },
20    serde::Deserialize,
21    std::{
22        borrow::Cow,
23        collections::{BTreeSet, HashSet, VecDeque},
24        future::Future,
25        io::{self, IoSlice},
26        net::{IpAddr, Ipv4Addr, SocketAddr},
27        sync::Arc,
28    },
29    thiserror::Error,
30    tokio::{
31        io::{AsyncReadExt, AsyncWriteExt},
32        task::{JoinError, JoinSet},
33    },
34    tokio_util::sync::CancellationToken,
35    tracing::{error, info},
36};
37
38#[derive(Debug, Clone, Deserialize)]
39#[serde(deny_unknown_fields)]
40pub struct ConfigQuicServer {
41    #[serde(default = "ConfigQuicServer::default_endpoint")]
42    pub endpoint: SocketAddr,
43    #[serde(deserialize_with = "deserialize_rustls_server_config")]
44    pub tls_config: rustls::ServerConfig,
45    /// Value in ms
46    #[serde(default = "ConfigQuicServer::default_expected_rtt")]
47    pub expected_rtt: u32,
48    /// Value in bytes/s, default with expected rtt 100 is 100Mbps
49    #[serde(
50        default = "ConfigQuicServer::default_max_stream_bandwidth",
51        deserialize_with = "deserialize_num_str"
52    )]
53    pub max_stream_bandwidth: u32,
54    /// Maximum duration of inactivity to accept before timing out the connection
55    #[serde(default = "ConfigQuicServer::default_max_idle_timeout")]
56    pub max_idle_timeout: Option<u32>,
57    /// Max number of outgoing streams
58    #[serde(default = "ConfigQuicServer::default_max_recv_streams")]
59    pub max_recv_streams: u32,
60    /// Max request size in bytes
61    #[serde(default = "ConfigQuicServer::default_max_request_size")]
62    pub max_request_size: usize,
63    #[serde(default, deserialize_with = "deserialize_x_tokens_set")]
64    pub x_tokens: HashSet<Vec<u8>>,
65}
66
67impl ConfigQuicServer {
68    pub const fn default_endpoint() -> SocketAddr {
69        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10101)
70    }
71
72    const fn default_expected_rtt() -> u32 {
73        100
74    }
75
76    const fn default_max_stream_bandwidth() -> u32 {
77        12_500 * 1000
78    }
79
80    const fn default_max_idle_timeout() -> Option<u32> {
81        Some(30_000)
82    }
83
84    const fn default_max_recv_streams() -> u32 {
85        16
86    }
87
88    const fn default_max_request_size() -> usize {
89        1024
90    }
91
92    pub fn create_endpoint(&self) -> Result<Endpoint, CreateEndpointError> {
93        let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
94            QuicServerConfig::try_from(self.tls_config.clone())?,
95        ));
96
97        // disallow incoming uni streams
98        let transport_config = Arc::get_mut(&mut server_config.transport)
99            .ok_or(CreateEndpointError::TransportConfig)?;
100        transport_config.max_concurrent_bidi_streams(1u8.into());
101        transport_config.max_concurrent_uni_streams(0u8.into());
102
103        // set window size
104        let stream_rwnd = self.max_stream_bandwidth / 1_000 * self.expected_rtt;
105        transport_config.stream_receive_window(stream_rwnd.into());
106        transport_config.send_window(8 * stream_rwnd as u64);
107        transport_config.datagram_receive_buffer_size(Some(stream_rwnd as usize));
108
109        // set idle timeout
110        transport_config
111            .max_idle_timeout(self.max_idle_timeout.map(|ms| VarInt::from_u32(ms).into()));
112
113        Endpoint::server(server_config, self.endpoint).map_err(|error| CreateEndpointError::Bind {
114            error,
115            endpoint: self.endpoint,
116        })
117    }
118}
119
120#[derive(Debug, Error)]
121pub enum CreateEndpointError {
122    #[error("failed to crate QuicServerConfig")]
123    ServerConfig(#[from] NoInitialCipherSuite),
124    #[error("failed to modify TransportConfig")]
125    TransportConfig,
126    #[error("failed to bind {endpoint}: {error}")]
127    Bind {
128        error: io::Error,
129        endpoint: SocketAddr,
130    },
131}
132
133#[derive(Debug, Error)]
134enum ConnectionError {
135    #[error(transparent)]
136    QuinnConnection(#[from] quinn::ConnectionError),
137    #[error(transparent)]
138    QuinnReadExact(#[from] quinn::ReadExactError),
139    #[error(transparent)]
140    QuinnWrite(#[from] quinn::WriteError),
141    #[error(transparent)]
142    QuinnClosedStream(#[from] quinn::ClosedStream),
143    #[error(transparent)]
144    Io(#[from] io::Error),
145    #[error(transparent)]
146    Prost(#[from] prost::DecodeError),
147    #[error(transparent)]
148    Join(#[from] JoinError),
149    #[error("stream is not available")]
150    StreamNotAvailable,
151}
152
/// Namespace type for the QUIC subscription server; see [`QuicServer::spawn`].
#[derive(Debug)]
pub struct QuicServer;
155
156impl QuicServer {
157    pub async fn spawn(
158        config: ConfigQuicServer,
159        messages: impl Subscribe + Clone + Send + 'static,
160        on_conn_new_cb: impl Fn() + Clone + Send + 'static,
161        on_conn_drop_cb: impl Fn() + Clone + Send + 'static,
162        version: Version<'static>,
163        shutdown: CancellationToken,
164    ) -> Result<impl Future<Output = Result<(), JoinError>>, CreateEndpointError> {
165        let endpoint = config.create_endpoint()?;
166        info!("start server at {}", config.endpoint);
167
168        Ok(tokio::spawn(async move {
169            let max_recv_streams = config.max_recv_streams;
170            let max_request_size = config.max_request_size as u64;
171            let x_tokens = Arc::new(config.x_tokens);
172
173            let mut id = 0;
174            loop {
175                tokio::select! {
176                    incoming = endpoint.accept() => {
177                        let Some(incoming) = incoming else {
178                            error!("quic connection closed");
179                            break;
180                        };
181
182                        let messages = messages.clone();
183                        let on_conn_new_cb = on_conn_new_cb.clone();
184                        let on_conn_drop_cb = on_conn_drop_cb.clone();
185                        let x_tokens = Arc::clone(&x_tokens);
186                        tokio::spawn(async move {
187                            on_conn_new_cb();
188                            if let Err(error) = Self::handle_incoming(
189                                id,
190                                incoming,
191                                messages,
192                                max_recv_streams,
193                                max_request_size,
194                                x_tokens,
195                                version.create_grpc_version_info().json(),
196                            ).await {
197                                error!("#{id}: connection failed: {error}");
198                            } else {
199                                info!("#{id}: connection closed");
200                            }
201                            on_conn_drop_cb();
202                        });
203                        id += 1;
204                    }
205                    () = shutdown.cancelled() => {
206                        endpoint.close(0u32.into(), b"shutdown");
207                        info!("shutdown");
208                        break
209                    },
210                };
211            }
212        }))
213    }
214
215    async fn handle_incoming(
216        id: u64,
217        incoming: Incoming,
218        messages: impl Subscribe,
219        max_recv_streams: u32,
220        max_request_size: u64,
221        x_tokens: Arc<HashSet<Vec<u8>>>,
222        version: String,
223    ) -> Result<(), ConnectionError> {
224        let conn = incoming.await?;
225        info!("#{id}: new connection from {:?}", conn.remote_address());
226
227        // Read request and subscribe
228        let (mut send, response, maybe_rx) = Self::handle_request(
229            id,
230            &conn,
231            messages,
232            max_recv_streams,
233            max_request_size,
234            x_tokens,
235            version,
236        )
237        .await?;
238
239        // Send response
240        let buf = response.encode_to_vec();
241        send.write_u64(buf.len() as u64).await?;
242        send.write_all(&buf).await?;
243        send.flush().await?;
244
245        let Some((recv_streams, max_backlog, mut rx)) = maybe_rx else {
246            return Ok(());
247        };
248
249        // Open connections
250        let mut streams = VecDeque::with_capacity(recv_streams as usize);
251        while streams.len() < recv_streams as usize {
252            streams.push_back(conn.open_uni().await?);
253        }
254
255        // Send loop
256        let mut msg_id = 0;
257        let mut msg_ids = BTreeSet::new();
258        let mut next_message: Option<RecvItem> = None;
259        let mut set = JoinSet::new();
260        loop {
261            if msg_id - msg_ids.first().copied().unwrap_or(msg_id) < max_backlog {
262                if let Some(message) = next_message.take() {
263                    if let Some(mut stream) = streams.pop_front() {
264                        msg_ids.insert(msg_id);
265                        set.spawn(async move {
266                            WriteVectored::new(
267                                &mut stream,
268                                &mut [
269                                    IoSlice::new(&msg_id.to_be_bytes()),
270                                    IoSlice::new(&(message.len() as u64).to_be_bytes()),
271                                    IoSlice::new(&message),
272                                ],
273                            )
274                            .await?;
275                            Ok::<_, ConnectionError>((msg_id, stream))
276                        });
277                        msg_id += 1;
278                    } else {
279                        next_message = Some(message);
280                    }
281                }
282            }
283
284            let rx_recv = if next_message.is_none() {
285                rx.next().boxed()
286            } else {
287                pending().boxed()
288            };
289            let set_join_next = if !set.is_empty() {
290                set.join_next().boxed()
291            } else {
292                pending().boxed()
293            };
294
295            tokio::select! {
296                message = rx_recv => {
297                    match message {
298                        Some(Ok(message)) => next_message = Some(message),
299                        Some(Err(error)) => {
300                            error!("#{id}: failed to get message: {error}");
301                            if streams.is_empty() {
302                                let (msg_id, stream) = set.join_next().await.expect("already verified")??;
303                                msg_ids.remove(&msg_id);
304                                streams.push_back(stream);
305                            }
306                            let Some(mut stream) = streams.pop_front() else {
307                                return Err(ConnectionError::StreamNotAvailable);
308                            };
309
310                            let msg = QuicSubscribeClose {
311                                error: match error {
312                                    RecvError::Lagged => QuicSubscribeCloseError::Lagged,
313                                    RecvError::Closed => QuicSubscribeCloseError::Closed,
314                                } as i32
315                            };
316                            let message = msg.encode_to_vec();
317
318                            set.spawn(async move {
319                                stream.write_u64(u64::MAX).await?;
320                                stream.write_u64(message.len() as u64).await?;
321                                stream.write_all(&message).await?;
322                                Ok::<_, ConnectionError>((msg_id, stream))
323                            });
324                        },
325                        None => break,
326                    }
327                },
328                result = set_join_next => {
329                    let (msg_id, stream) = result.expect("already verified")??;
330                    msg_ids.remove(&msg_id);
331                    streams.push_back(stream);
332                }
333            }
334        }
335
336        for (_, mut stream) in set.join_all().await.into_iter().flatten() {
337            stream.finish()?;
338        }
339        for mut stream in streams {
340            stream.finish()?;
341        }
342        drop(conn);
343
344        Ok(())
345    }
346
347    async fn handle_request(
348        id: u64,
349        conn: &Connection,
350        messages: impl Subscribe,
351        max_recv_streams: u32,
352        max_request_size: u64,
353        x_tokens: Arc<HashSet<Vec<u8>>>,
354        version: String,
355    ) -> Result<
356        (
357            SendStream,
358            QuicSubscribeResponse,
359            Option<(u32, u64, RecvStream)>,
360        ),
361        ConnectionError,
362    > {
363        let (send, mut recv) = conn.accept_bi().await?;
364
365        // Read request
366        let size = recv.read_u64().await?;
367        if size > max_request_size {
368            let msg = QuicSubscribeResponse {
369                error: Some(QuicSubscribeResponseError::RequestSizeTooLarge as i32),
370                version,
371                ..Default::default()
372            };
373            return Ok((send, msg, None));
374        }
375        let mut buf = vec![0; size as usize]; // TODO: use MaybeUninit
376        recv.read_exact(buf.as_mut_slice()).await?;
377
378        // Decode request
379        let QuicSubscribeRequest {
380            x_token,
381            recv_streams,
382            max_backlog,
383            replay_from_slot,
384            filter,
385        } = Message::decode(buf.as_slice())?;
386
387        // verify access token
388        if !x_tokens.is_empty() {
389            if let Some(error) = match x_token {
390                Some(x_token) if !x_tokens.contains(&x_token) => {
391                    Some(QuicSubscribeResponseError::XTokenInvalid as i32)
392                }
393                None => Some(QuicSubscribeResponseError::XTokenRequired as i32),
394                _ => None,
395            } {
396                let msg = QuicSubscribeResponse {
397                    error: Some(error),
398                    version,
399                    ..Default::default()
400                };
401                return Ok((send, msg, None));
402            }
403        }
404
405        // validate number of streams
406        if recv_streams == 0 || recv_streams > max_recv_streams {
407            let code = if recv_streams == 0 {
408                QuicSubscribeResponseError::ZeroRecvStreams
409            } else {
410                QuicSubscribeResponseError::ExceedRecvStreams
411            };
412            let msg = QuicSubscribeResponse {
413                error: Some(code as i32),
414                max_recv_streams: Some(max_recv_streams),
415                version,
416                ..Default::default()
417            };
418            return Ok((send, msg, None));
419        }
420
421        Ok(match messages.subscribe(replay_from_slot, filter) {
422            Ok(rx) => {
423                let pos = replay_from_slot
424                    .map(|slot| format!("slot {slot}").into())
425                    .unwrap_or(Cow::Borrowed("latest"));
426                info!("#{id}: subscribed from {pos}");
427                (
428                    send,
429                    QuicSubscribeResponse {
430                        version,
431                        ..Default::default()
432                    },
433                    Some((
434                        recv_streams,
435                        max_backlog.map(|x| x as u64).unwrap_or(u64::MAX),
436                        rx,
437                    )),
438                )
439            }
440            Err(SubscribeError::NotInitialized) => {
441                let msg = QuicSubscribeResponse {
442                    error: Some(QuicSubscribeResponseError::NotInitialized as i32),
443                    version,
444                    ..Default::default()
445                };
446                (send, msg, None)
447            }
448            Err(SubscribeError::SlotNotAvailable { first_available }) => {
449                let msg = QuicSubscribeResponse {
450                    error: Some(QuicSubscribeResponseError::SlotNotAvailable as i32),
451                    first_available_slot: Some(first_available),
452                    version,
453                    ..Default::default()
454                };
455                (send, msg, None)
456            }
457        })
458    }
459}