Skip to main content

richat_client/
quic.rs

1use {
2    crate::{
3        error::{ReceiveError, SubscribeError},
4        stream::SubscribeStream,
5    },
6    foldhash::quality::RandomState,
7    futures::{
8        future::{BoxFuture, FutureExt},
9        stream::{Stream, StreamExt},
10    },
11    pin_project_lite::pin_project,
12    prost::Message,
13    quinn::{
14        ClientConfig, ConnectError, Connection, ConnectionError, Endpoint, RecvStream,
15        TransportConfig, VarInt,
16        crypto::rustls::{NoInitialCipherSuite, QuicClientConfig},
17    },
18    richat_proto::richat::{QuicSubscribeClose, QuicSubscribeRequest, RichatFilter},
19    richat_shared::{
20        config::{deserialize_maybe_num_str, deserialize_maybe_x_token, deserialize_num_str},
21        transports::quic::ConfigQuicServer,
22    },
23    rustls::{
24        RootCertStore,
25        pki_types::{CertificateDer, ServerName, UnixTime},
26    },
27    serde::Deserialize,
28    solana_clock::Slot,
29    std::{
30        collections::HashMap,
31        fmt,
32        future::Future,
33        io,
34        net::{IpAddr, Ipv6Addr, SocketAddr},
35        path::PathBuf,
36        pin::Pin,
37        sync::Arc,
38        task::{Context, Poll, ready},
39        time::Duration,
40    },
41    thiserror::Error,
42    tokio::{
43        fs,
44        io::{AsyncReadExt, AsyncWriteExt},
45        net::{ToSocketAddrs, lookup_host},
46    },
47};
48
49/// Dummy certificate verifier that treats any certificate as valid.
50/// NOTE, such verification is vulnerable to MITM attacks, but convenient for testing.
51#[derive(Debug)]
52struct SkipServerVerification(Arc<rustls::crypto::CryptoProvider>);
53
54impl SkipServerVerification {
55    fn new() -> Arc<Self> {
56        Arc::new(Self(Arc::new(rustls::crypto::ring::default_provider())))
57    }
58}
59
60impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
61    fn verify_server_cert(
62        &self,
63        _end_entity: &CertificateDer<'_>,
64        _intermediates: &[CertificateDer<'_>],
65        _server_name: &ServerName<'_>,
66        _ocsp: &[u8],
67        _now: UnixTime,
68    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
69        Ok(rustls::client::danger::ServerCertVerified::assertion())
70    }
71
72    fn verify_tls12_signature(
73        &self,
74        message: &[u8],
75        cert: &CertificateDer<'_>,
76        dss: &rustls::DigitallySignedStruct,
77    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
78        rustls::crypto::verify_tls12_signature(
79            message,
80            cert,
81            dss,
82            &self.0.signature_verification_algorithms,
83        )
84    }
85
86    fn verify_tls13_signature(
87        &self,
88        message: &[u8],
89        cert: &CertificateDer<'_>,
90        dss: &rustls::DigitallySignedStruct,
91    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
92        rustls::crypto::verify_tls13_signature(
93            message,
94            cert,
95            dss,
96            &self.0.signature_verification_algorithms,
97        )
98    }
99
100    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
101        self.0.signature_verification_algorithms.supported_schemes()
102    }
103}
104
105#[derive(Debug, Error)]
106pub enum QuicConnectError {
107    #[error("failed to create Quic ClientConfig from Rustls: {0}")]
108    QuicClientConfig(#[from] NoInitialCipherSuite),
109    #[error("failed to resolve endpoint: {0}")]
110    LookupError(io::Error),
111    #[error("failed to bind local port: {0}")]
112    EndpointClient(io::Error),
113    #[error("failed to connect: {0}")]
114    Connect(#[from] ConnectError),
115    #[error("invalid max idle timeout: {0:?}")]
116    InvalidMaxIdleTimeout(Duration),
117    #[error("connection failed: {0}")]
118    Connection(#[from] ConnectionError),
119    #[error("server name should be defined")]
120    ServerName,
121    #[error("errors occured when loading native certs: {0:?}")]
122    LoadNativeCerts(Vec<rustls_native_certs::Error>),
123    #[error("failed to read certificate chain: {0}")]
124    LoadCert(io::Error),
125    #[error("failed to add cert to roots: {0}")]
126    AddCert(rustls::Error),
127    #[error("invalid PEM-encoded certificate: {0}")]
128    PemCert(io::Error),
129}
130
/// Client-side settings for connecting to a QUIC subscription server.
///
/// Because of `#[serde(default)]`, any field missing from the input takes its value from the
/// `Default` implementation.
#[derive(Debug, Clone, PartialEq, Deserialize)]
#[serde(default)]
pub struct ConfigQuicClient {
    /// Server address that is resolved and connected to (e.g. `host:port`).
    pub endpoint: String,
    /// Local socket address the client endpoint binds to.
    pub local_addr: SocketAddr,
    /// Expected round-trip time; together with `max_stream_bandwidth` it sizes the stream
    /// receive window as `max_stream_bandwidth / 1000 * expected_rtt`
    /// (presumably milliseconds — TODO confirm).
    #[serde(deserialize_with = "deserialize_num_str")]
    pub expected_rtt: u32,
    /// Expected per-stream bandwidth used to size flow-control windows
    /// (presumably bytes per second — TODO confirm).
    #[serde(deserialize_with = "deserialize_num_str")]
    pub max_stream_bandwidth: u32,
    /// QUIC idle timeout passed to quinn's transport config, parsed with `humantime_serde`;
    /// its millisecond value must fit into a `u32`.
    #[serde(with = "humantime_serde")]
    pub max_idle_timeout: Option<Duration>,
    /// TLS server name used when connecting; connecting fails if this is `None`.
    pub server_name: Option<String>,
    /// Number of unidirectional streams requested from, and accepted from, the server.
    #[serde(deserialize_with = "deserialize_num_str")]
    pub recv_streams: u32,
    /// Optional backlog limit forwarded to the server in the subscribe request.
    #[serde(deserialize_with = "deserialize_maybe_num_str")]
    pub max_backlog: Option<u32>,
    /// Skip server certificate verification (vulnerable to MITM; intended for testing).
    pub insecure: bool,
    /// Extra root certificate file: DER if the extension is `der`, PEM otherwise.
    pub cert: Option<PathBuf>,
    /// Optional access token sent in the subscribe request.
    #[serde(deserialize_with = "deserialize_maybe_x_token")]
    pub x_token: Option<Vec<u8>>,
}
152
153impl Default for ConfigQuicClient {
154    fn default() -> Self {
155        Self {
156            endpoint: ConfigQuicServer::default_endpoint().to_string(),
157            local_addr: SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
158            expected_rtt: 100,
159            max_stream_bandwidth: 12_500 * 1_000,
160            max_idle_timeout: Some(Duration::from_secs(30)),
161            server_name: None,
162            recv_streams: 1,
163            max_backlog: None,
164            insecure: false,
165            cert: None,
166            x_token: None,
167        }
168    }
169}
170
171impl ConfigQuicClient {
172    pub async fn connect(self) -> Result<QuicClient, QuicConnectError> {
173        let builder = QuicClient::builder()
174            .set_local_addr(Some(self.local_addr))
175            .set_expected_rtt(self.expected_rtt)
176            .set_max_stream_bandwidth(self.max_stream_bandwidth)
177            .set_max_idle_timeout(self.max_idle_timeout)
178            .set_server_name(self.server_name.clone())
179            .set_recv_streams(self.recv_streams)
180            .set_max_backlog(self.max_backlog)
181            .set_x_token(self.x_token);
182
183        if self.insecure {
184            builder.insecure().connect(self.endpoint.clone()).await
185        } else {
186            builder
187                .secure(self.cert)
188                .connect(self.endpoint.clone())
189                .await
190        }
191    }
192}
193
/// Builder for [`QuicClient`]; its fields mirror the transport settings of [`ConfigQuicClient`].
///
/// Finish with `insecure()` or `secure(..)`, then call `connect` on the returned stage.
#[derive(Debug)]
pub struct QuicClientBuilder {
    /// Local socket address the client endpoint binds to.
    pub local_addr: SocketAddr,
    /// Expected round-trip time, used to size the stream receive window.
    pub expected_rtt: u32,
    /// Expected per-stream bandwidth, used to size the stream receive window.
    pub max_stream_bandwidth: u32,
    /// QUIC idle timeout; its millisecond value must fit into a `u32`.
    pub max_idle_timeout: Option<Duration>,
    /// TLS server name; connecting fails if this is `None`.
    pub server_name: Option<String>,
    /// Number of unidirectional streams to request and accept.
    pub recv_streams: u32,
    /// Optional backlog limit sent in the subscribe request.
    pub max_backlog: Option<u32>,
    /// Optional access token sent in the subscribe request.
    pub x_token: Option<Vec<u8>>,
}
205
206impl Default for QuicClientBuilder {
207    fn default() -> Self {
208        let config = ConfigQuicClient::default();
209        Self {
210            local_addr: config.local_addr,
211            expected_rtt: config.expected_rtt,
212            max_stream_bandwidth: config.max_stream_bandwidth,
213            max_idle_timeout: config.max_idle_timeout,
214            server_name: config.server_name,
215            recv_streams: config.recv_streams,
216            max_backlog: config.max_backlog,
217            x_token: config.x_token,
218        }
219    }
220}
221
222impl QuicClientBuilder {
223    pub fn new() -> Self {
224        Self::default()
225    }
226
227    pub fn set_local_addr(self, local_addr: Option<SocketAddr>) -> Self {
228        Self {
229            local_addr: local_addr.unwrap_or(Self::default().local_addr),
230            ..self
231        }
232    }
233
234    pub fn set_expected_rtt(self, expected_rtt: u32) -> Self {
235        Self {
236            expected_rtt,
237            ..self
238        }
239    }
240
241    pub fn set_max_stream_bandwidth(self, max_stream_bandwidth: u32) -> Self {
242        Self {
243            max_stream_bandwidth,
244            ..self
245        }
246    }
247
248    pub fn set_max_idle_timeout(self, max_idle_timeout: Option<Duration>) -> Self {
249        Self {
250            max_idle_timeout,
251            ..self
252        }
253    }
254
255    pub fn set_server_name(self, server_name: Option<String>) -> Self {
256        Self {
257            server_name,
258            ..self
259        }
260    }
261
262    pub fn set_recv_streams(self, recv_streams: u32) -> Self {
263        Self {
264            recv_streams,
265            ..self
266        }
267    }
268
269    pub fn set_max_backlog(self, max_backlog: Option<u32>) -> Self {
270        Self {
271            max_backlog,
272            ..self
273        }
274    }
275
276    pub fn set_x_token(self, x_token: Option<Vec<u8>>) -> Self {
277        Self { x_token, ..self }
278    }
279
280    pub const fn insecure(self) -> QuicClientBuilderInsecure {
281        QuicClientBuilderInsecure { builder: self }
282    }
283
284    pub const fn secure(self, cert: Option<PathBuf>) -> QuicClientBuilderSecure {
285        QuicClientBuilderSecure {
286            builder: self,
287            cert,
288        }
289    }
290
291    async fn connect<T: ToSocketAddrs>(
292        self,
293        endpoint: T,
294        client_config: rustls::ClientConfig,
295    ) -> Result<QuicClient, QuicConnectError> {
296        let addr = lookup_host(endpoint)
297            .await
298            .map_err(QuicConnectError::LookupError)?
299            .next()
300            .ok_or(io::Error::new(
301                io::ErrorKind::AddrNotAvailable,
302                "failed to resolve",
303            ))
304            .map_err(QuicConnectError::LookupError)?;
305        let server_name = self.server_name.ok_or(QuicConnectError::ServerName)?;
306
307        let mut transport_config = TransportConfig::default();
308        transport_config.max_concurrent_bidi_streams(0u8.into());
309        transport_config.max_concurrent_uni_streams(self.recv_streams.into());
310        let stream_rwnd = self.max_stream_bandwidth / 1_000 * self.expected_rtt;
311        transport_config.stream_receive_window(stream_rwnd.into());
312        transport_config.send_window(8 * stream_rwnd as u64);
313        transport_config.datagram_receive_buffer_size(Some(stream_rwnd as usize));
314        transport_config.max_idle_timeout(
315            self.max_idle_timeout
316                .map(|d| d.as_millis().try_into())
317                .transpose()
318                .map_err(|_| {
319                    QuicConnectError::InvalidMaxIdleTimeout(self.max_idle_timeout.unwrap())
320                })?
321                .map(|ms| VarInt::from_u32(ms).into()),
322        );
323
324        let crypto_config = Arc::new(QuicClientConfig::try_from(client_config)?);
325        let mut client_config = ClientConfig::new(crypto_config);
326        client_config.transport_config(Arc::new(transport_config));
327
328        let mut endpoint =
329            Endpoint::client(self.local_addr).map_err(QuicConnectError::EndpointClient)?;
330        endpoint.set_default_client_config(client_config);
331
332        let conn = endpoint.connect(addr, &server_name)?.await?;
333
334        Ok(QuicClient {
335            conn,
336            recv_streams: self.recv_streams,
337            max_backlog: self.max_backlog,
338            x_token: self.x_token,
339        })
340    }
341}
342
/// Final builder stage that connects without verifying the server certificate.
#[derive(Debug)]
pub struct QuicClientBuilderInsecure {
    /// Transport settings used for the connection.
    pub builder: QuicClientBuilder,
}
347
348impl QuicClientBuilderInsecure {
349    pub async fn connect<T: ToSocketAddrs>(
350        self,
351        endpoint: T,
352    ) -> Result<QuicClient, QuicConnectError> {
353        self.builder
354            .connect(
355                endpoint,
356                rustls::ClientConfig::builder()
357                    .dangerous()
358                    .with_custom_certificate_verifier(SkipServerVerification::new())
359                    .with_no_client_auth(),
360            )
361            .await
362    }
363}
364
/// Final builder stage that connects with server certificate verification.
#[derive(Debug)]
pub struct QuicClientBuilderSecure {
    /// Transport settings used for the connection.
    pub builder: QuicClientBuilder,
    /// Optional extra root certificate file (DER if the extension is `der`, PEM otherwise).
    pub cert: Option<PathBuf>,
}
370
371impl QuicClientBuilderSecure {
372    pub async fn connect<T: ToSocketAddrs>(
373        self,
374        endpoint: T,
375    ) -> Result<QuicClient, QuicConnectError> {
376        let mut roots = RootCertStore::empty();
377        // native
378        let rustls_native_certs::CertificateResult { certs, errors, .. } =
379            rustls_native_certs::load_native_certs();
380        if !errors.is_empty() {
381            return Err(QuicConnectError::LoadNativeCerts(errors));
382        }
383        roots.add_parsable_certificates(certs);
384        // webpki
385        roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
386        // custom
387        if let Some(cert_path) = self.cert {
388            let cert_chain = fs::read(&cert_path)
389                .await
390                .map_err(QuicConnectError::LoadCert)?;
391            if cert_path.extension().is_some_and(|x| x == "der") {
392                roots
393                    .add(CertificateDer::from(cert_chain))
394                    .map_err(QuicConnectError::AddCert)?;
395            } else {
396                for cert in rustls_pemfile::certs(&mut &*cert_chain) {
397                    roots
398                        .add(cert.map_err(QuicConnectError::PemCert)?)
399                        .map_err(QuicConnectError::AddCert)?;
400                }
401            }
402        }
403
404        self.builder
405            .connect(
406                endpoint,
407                rustls::ClientConfig::builder()
408                    .with_root_certificates(roots)
409                    .with_no_client_auth(),
410            )
411            .await
412    }
413}
414
/// An established QUIC connection, ready to `subscribe`.
#[derive(Debug)]
pub struct QuicClient {
    /// Underlying quinn connection.
    conn: Connection,
    /// Number of unidirectional streams requested and accepted on subscribe.
    recv_streams: u32,
    /// Optional backlog limit forwarded in the subscribe request.
    max_backlog: Option<u32>,
    /// Optional access token forwarded in the subscribe request.
    x_token: Option<Vec<u8>>,
}
422
423impl QuicClient {
424    pub fn builder() -> QuicClientBuilder {
425        QuicClientBuilder::new()
426    }
427
428    pub async fn subscribe(
429        self,
430        replay_from_slot: Option<Slot>,
431        filter: Option<RichatFilter>,
432    ) -> Result<QuicClientStream, SubscribeError> {
433        let message = QuicSubscribeRequest {
434            x_token: self.x_token,
435            recv_streams: self.recv_streams,
436            max_backlog: self.max_backlog,
437            replay_from_slot,
438            filter,
439        }
440        .encode_to_vec();
441
442        let (mut send, mut recv) = self.conn.open_bi().await?;
443        send.write_u64(message.len() as u64).await?;
444        send.write_all(&message).await?;
445        send.flush().await?;
446
447        let version = SubscribeError::parse_quic_response(&mut recv).await?;
448
449        let mut readers = Vec::with_capacity(self.recv_streams as usize);
450        for _ in 0..self.recv_streams {
451            let stream = self.conn.accept_uni().await?;
452            readers.push(QuicClientStreamReader::Init {
453                stream: Some(stream),
454            });
455        }
456
457        Ok(QuicClientStream {
458            conn: self.conn,
459            version,
460            messages: HashMap::default(),
461            msg_id: 0,
462            readers,
463            index: 0,
464        })
465    }
466
467    async fn recv(mut stream: RecvStream) -> Result<(RecvStream, u64, Vec<u8>), ReceiveError> {
468        let msg_id = stream.read_u64().await?;
469        let error = msg_id == u64::MAX;
470
471        let size = stream.read_u64().await? as usize;
472        let mut buffer = Vec::<u8>::with_capacity(size);
473        // SAFETY: buffer capacity is equal to `size`, `len` is equal to `size`
474        let read = unsafe { std::slice::from_raw_parts_mut(buffer.as_mut_ptr(), size) };
475        stream.read_exact(read).await?;
476        // SAFETY: `new_len` equal to `capacity`, the elements at `old_len`..`new_len` is initialized.
477        unsafe {
478            buffer.set_len(size);
479        }
480
481        if error {
482            let close = QuicSubscribeClose::decode(&buffer.as_slice()[0..size])?;
483            Err(close.into())
484        } else {
485            Ok((stream, msg_id, buffer))
486        }
487    }
488}
489
// In-order stream of raw messages received over the subscription's unidirectional streams.
//
// Every message carries a sequence id; messages that arrive before their turn are parked in
// `messages` until `msg_id` reaches them. `readers` are polled round-robin starting at `index`.
pin_project! {
    pub struct QuicClientStream {
        // Connection handle; not read while polling (presumably held to keep the connection
        // alive — TODO confirm).
        conn: Connection,
        // Version string parsed from the server's subscribe response.
        version: String,
        // Messages that arrived ahead of `msg_id`, keyed by their id.
        messages: HashMap<u64, Vec<u8>, RandomState>,
        // Id of the next message to yield.
        msg_id: u64,
        // One reader per accepted unidirectional stream.
        #[pin]
        readers: Vec<QuicClientStreamReader>,
        // Reader at which the next polling round starts.
        index: usize,
    }
}
501
502impl fmt::Debug for QuicClientStream {
503    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
504        f.debug_struct("QuicClientStream").finish()
505    }
506}
507
508impl QuicClientStream {
509    pub fn into_parsed(self) -> SubscribeStream {
510        SubscribeStream::new(self.boxed())
511    }
512
513    // clippy bug?
514    #[allow(clippy::missing_const_for_fn)]
515    pub fn get_version(&self) -> &str {
516        &self.version
517    }
518}
519
520impl Stream for QuicClientStream {
521    type Item = Result<Vec<u8>, ReceiveError>;
522
523    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
524        let mut me = self.project();
525
526        if let Some(msg) = me.messages.remove(me.msg_id) {
527            *me.msg_id += 1;
528            return Poll::Ready(Some(Ok(msg)));
529        }
530
531        let mut polled = 0;
532        loop {
533            // try to get value and increment index
534            let value = Pin::new(&mut me.readers[*me.index]).poll_next(cx);
535            *me.index = (*me.index + 1) % me.readers.len();
536            match value {
537                Poll::Ready(Some(Ok((msg_id, msg)))) => {
538                    if *me.msg_id == msg_id {
539                        *me.msg_id += 1;
540                        return Poll::Ready(Some(Ok(msg)));
541                    } else {
542                        me.messages.insert(msg_id, msg);
543                    }
544                }
545                Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
546                Poll::Ready(None) => return Poll::Ready(None),
547                Poll::Pending => {}
548            }
549
550            // return pending if already polled all streams
551            polled += 1;
552            if polled == me.readers.len() {
553                return Poll::Pending;
554            }
555        }
556    }
557}
558
// Per-stream state machine that repeatedly reads framed messages from one `RecvStream`.
//
// `Init` holds the stream between reads. It is wrapped in `Option` so it can be moved out
// when a read starts. `Read` holds the in-flight boxed read future, which hands the stream
// back once a message has been read.
pin_project! {
    #[project = QuicClientStreamReaderProj]
    pub enum QuicClientStreamReader {
        // Idle between reads.
        Init {
            stream: Option<RecvStream>,
        },
        // Reading one message.
        Read {
            #[pin] future: BoxFuture<'static, Result<(RecvStream, u64, Vec<u8>), ReceiveError>>,
        },
    }
}
570
571impl fmt::Debug for QuicClientStreamReader {
572    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
573        f.debug_struct("QuicClientStreamReader").finish()
574    }
575}
576
577impl Stream for QuicClientStreamReader {
578    type Item = Result<(u64, Vec<u8>), ReceiveError>;
579
580    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
581        loop {
582            match self.as_mut().project() {
583                QuicClientStreamReaderProj::Init { stream } => {
584                    let stream = stream.take().unwrap();
585                    let future = QuicClient::recv(stream).boxed();
586                    self.set(Self::Read { future })
587                }
588                QuicClientStreamReaderProj::Read { mut future } => {
589                    return Poll::Ready(match ready!(future.as_mut().poll(cx)) {
590                        Ok((stream, msg_id, buffer)) => {
591                            self.set(Self::Init {
592                                stream: Some(stream),
593                            });
594                            Some(Ok((msg_id, buffer)))
595                        }
596                        Err(error) => {
597                            if error.is_eof() {
598                                None
599                            } else {
600                                Some(Err(error))
601                            }
602                        }
603                    });
604                }
605            }
606        }
607    }
608}