1use {
2 crate::{
3 error::{ReceiveError, SubscribeError},
4 stream::SubscribeStream,
5 },
6 foldhash::quality::RandomState,
7 futures::{
8 future::{BoxFuture, FutureExt},
9 stream::{Stream, StreamExt},
10 },
11 pin_project_lite::pin_project,
12 prost::Message,
13 quinn::{
14 ClientConfig, ConnectError, Connection, ConnectionError, Endpoint, RecvStream,
15 TransportConfig, VarInt,
16 crypto::rustls::{NoInitialCipherSuite, QuicClientConfig},
17 },
18 richat_proto::richat::{QuicSubscribeClose, QuicSubscribeRequest, RichatFilter},
19 richat_shared::{
20 config::{deserialize_maybe_num_str, deserialize_maybe_x_token, deserialize_num_str},
21 transports::quic::ConfigQuicServer,
22 },
23 rustls::{
24 RootCertStore,
25 pki_types::{CertificateDer, ServerName, UnixTime},
26 },
27 serde::Deserialize,
28 solana_clock::Slot,
29 std::{
30 collections::HashMap,
31 fmt,
32 future::Future,
33 io,
34 net::{IpAddr, Ipv6Addr, SocketAddr},
35 path::PathBuf,
36 pin::Pin,
37 sync::Arc,
38 task::{Context, Poll, ready},
39 time::Duration,
40 },
41 thiserror::Error,
42 tokio::{
43 fs,
44 io::{AsyncReadExt, AsyncWriteExt},
45 net::{ToSocketAddrs, lookup_host},
46 },
47};
48
49#[derive(Debug)]
52struct SkipServerVerification(Arc<rustls::crypto::CryptoProvider>);
53
54impl SkipServerVerification {
55 fn new() -> Arc<Self> {
56 Arc::new(Self(Arc::new(rustls::crypto::ring::default_provider())))
57 }
58}
59
60impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
61 fn verify_server_cert(
62 &self,
63 _end_entity: &CertificateDer<'_>,
64 _intermediates: &[CertificateDer<'_>],
65 _server_name: &ServerName<'_>,
66 _ocsp: &[u8],
67 _now: UnixTime,
68 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
69 Ok(rustls::client::danger::ServerCertVerified::assertion())
70 }
71
72 fn verify_tls12_signature(
73 &self,
74 message: &[u8],
75 cert: &CertificateDer<'_>,
76 dss: &rustls::DigitallySignedStruct,
77 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
78 rustls::crypto::verify_tls12_signature(
79 message,
80 cert,
81 dss,
82 &self.0.signature_verification_algorithms,
83 )
84 }
85
86 fn verify_tls13_signature(
87 &self,
88 message: &[u8],
89 cert: &CertificateDer<'_>,
90 dss: &rustls::DigitallySignedStruct,
91 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
92 rustls::crypto::verify_tls13_signature(
93 message,
94 cert,
95 dss,
96 &self.0.signature_verification_algorithms,
97 )
98 }
99
100 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
101 self.0.signature_verification_algorithms.supported_schemes()
102 }
103}
104
105#[derive(Debug, Error)]
106pub enum QuicConnectError {
107 #[error("failed to create Quic ClientConfig from Rustls: {0}")]
108 QuicClientConfig(#[from] NoInitialCipherSuite),
109 #[error("failed to resolve endpoint: {0}")]
110 LookupError(io::Error),
111 #[error("failed to bind local port: {0}")]
112 EndpointClient(io::Error),
113 #[error("failed to connect: {0}")]
114 Connect(#[from] ConnectError),
115 #[error("invalid max idle timeout: {0:?}")]
116 InvalidMaxIdleTimeout(Duration),
117 #[error("connection failed: {0}")]
118 Connection(#[from] ConnectionError),
119 #[error("server name should be defined")]
120 ServerName,
121 #[error("errors occured when loading native certs: {0:?}")]
122 LoadNativeCerts(Vec<rustls_native_certs::Error>),
123 #[error("failed to read certificate chain: {0}")]
124 LoadCert(io::Error),
125 #[error("failed to add cert to roots: {0}")]
126 AddCert(rustls::Error),
127 #[error("invalid PEM-encoded certificate: {0}")]
128 PemCert(io::Error),
129}
130
131#[derive(Debug, Clone, PartialEq, Deserialize)]
132#[serde(default)]
133pub struct ConfigQuicClient {
134 pub endpoint: String,
135 pub local_addr: SocketAddr,
136 #[serde(deserialize_with = "deserialize_num_str")]
137 pub expected_rtt: u32,
138 #[serde(deserialize_with = "deserialize_num_str")]
139 pub max_stream_bandwidth: u32,
140 #[serde(with = "humantime_serde")]
141 pub max_idle_timeout: Option<Duration>,
142 pub server_name: Option<String>,
143 #[serde(deserialize_with = "deserialize_num_str")]
144 pub recv_streams: u32,
145 #[serde(deserialize_with = "deserialize_maybe_num_str")]
146 pub max_backlog: Option<u32>,
147 pub insecure: bool,
148 pub cert: Option<PathBuf>,
149 #[serde(deserialize_with = "deserialize_maybe_x_token")]
150 pub x_token: Option<Vec<u8>>,
151}
152
153impl Default for ConfigQuicClient {
154 fn default() -> Self {
155 Self {
156 endpoint: ConfigQuicServer::default_endpoint().to_string(),
157 local_addr: SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
158 expected_rtt: 100,
159 max_stream_bandwidth: 12_500 * 1_000,
160 max_idle_timeout: Some(Duration::from_secs(30)),
161 server_name: None,
162 recv_streams: 1,
163 max_backlog: None,
164 insecure: false,
165 cert: None,
166 x_token: None,
167 }
168 }
169}
170
171impl ConfigQuicClient {
172 pub async fn connect(self) -> Result<QuicClient, QuicConnectError> {
173 let builder = QuicClient::builder()
174 .set_local_addr(Some(self.local_addr))
175 .set_expected_rtt(self.expected_rtt)
176 .set_max_stream_bandwidth(self.max_stream_bandwidth)
177 .set_max_idle_timeout(self.max_idle_timeout)
178 .set_server_name(self.server_name.clone())
179 .set_recv_streams(self.recv_streams)
180 .set_max_backlog(self.max_backlog)
181 .set_x_token(self.x_token);
182
183 if self.insecure {
184 builder.insecure().connect(self.endpoint.clone()).await
185 } else {
186 builder
187 .secure(self.cert)
188 .connect(self.endpoint.clone())
189 .await
190 }
191 }
192}
193
/// Builder collecting the connection settings for a [`QuicClient`].
///
/// Finish with [`QuicClientBuilder::insecure`] or [`QuicClientBuilder::secure`]
/// to choose how the server certificate is checked, then call `connect`.
#[derive(Debug)]
pub struct QuicClientBuilder {
    /// Local socket address the QUIC endpoint binds to.
    pub local_addr: SocketAddr,
    /// Expected round-trip time used to size the receive windows.
    pub expected_rtt: u32,
    /// Expected per-stream bandwidth used to size the receive/send windows.
    pub max_stream_bandwidth: u32,
    /// QUIC idle timeout; must fit in `u32` milliseconds.
    pub max_idle_timeout: Option<Duration>,
    /// TLS server name; required by `connect`.
    pub server_name: Option<String>,
    /// Number of unidirectional data streams to request and accept.
    pub recv_streams: u32,
    /// Forwarded to the server in the subscribe request.
    pub max_backlog: Option<u32>,
    /// Optional access token sent in the subscribe request.
    pub x_token: Option<Vec<u8>>,
}
205
206impl Default for QuicClientBuilder {
207 fn default() -> Self {
208 let config = ConfigQuicClient::default();
209 Self {
210 local_addr: config.local_addr,
211 expected_rtt: config.expected_rtt,
212 max_stream_bandwidth: config.max_stream_bandwidth,
213 max_idle_timeout: config.max_idle_timeout,
214 server_name: config.server_name,
215 recv_streams: config.recv_streams,
216 max_backlog: config.max_backlog,
217 x_token: config.x_token,
218 }
219 }
220}
221
222impl QuicClientBuilder {
223 pub fn new() -> Self {
224 Self::default()
225 }
226
227 pub fn set_local_addr(self, local_addr: Option<SocketAddr>) -> Self {
228 Self {
229 local_addr: local_addr.unwrap_or(Self::default().local_addr),
230 ..self
231 }
232 }
233
234 pub fn set_expected_rtt(self, expected_rtt: u32) -> Self {
235 Self {
236 expected_rtt,
237 ..self
238 }
239 }
240
241 pub fn set_max_stream_bandwidth(self, max_stream_bandwidth: u32) -> Self {
242 Self {
243 max_stream_bandwidth,
244 ..self
245 }
246 }
247
248 pub fn set_max_idle_timeout(self, max_idle_timeout: Option<Duration>) -> Self {
249 Self {
250 max_idle_timeout,
251 ..self
252 }
253 }
254
255 pub fn set_server_name(self, server_name: Option<String>) -> Self {
256 Self {
257 server_name,
258 ..self
259 }
260 }
261
262 pub fn set_recv_streams(self, recv_streams: u32) -> Self {
263 Self {
264 recv_streams,
265 ..self
266 }
267 }
268
269 pub fn set_max_backlog(self, max_backlog: Option<u32>) -> Self {
270 Self {
271 max_backlog,
272 ..self
273 }
274 }
275
276 pub fn set_x_token(self, x_token: Option<Vec<u8>>) -> Self {
277 Self { x_token, ..self }
278 }
279
280 pub const fn insecure(self) -> QuicClientBuilderInsecure {
281 QuicClientBuilderInsecure { builder: self }
282 }
283
284 pub const fn secure(self, cert: Option<PathBuf>) -> QuicClientBuilderSecure {
285 QuicClientBuilderSecure {
286 builder: self,
287 cert,
288 }
289 }
290
291 async fn connect<T: ToSocketAddrs>(
292 self,
293 endpoint: T,
294 client_config: rustls::ClientConfig,
295 ) -> Result<QuicClient, QuicConnectError> {
296 let addr = lookup_host(endpoint)
297 .await
298 .map_err(QuicConnectError::LookupError)?
299 .next()
300 .ok_or(io::Error::new(
301 io::ErrorKind::AddrNotAvailable,
302 "failed to resolve",
303 ))
304 .map_err(QuicConnectError::LookupError)?;
305 let server_name = self.server_name.ok_or(QuicConnectError::ServerName)?;
306
307 let mut transport_config = TransportConfig::default();
308 transport_config.max_concurrent_bidi_streams(0u8.into());
309 transport_config.max_concurrent_uni_streams(self.recv_streams.into());
310 let stream_rwnd = self.max_stream_bandwidth / 1_000 * self.expected_rtt;
311 transport_config.stream_receive_window(stream_rwnd.into());
312 transport_config.send_window(8 * stream_rwnd as u64);
313 transport_config.datagram_receive_buffer_size(Some(stream_rwnd as usize));
314 transport_config.max_idle_timeout(
315 self.max_idle_timeout
316 .map(|d| d.as_millis().try_into())
317 .transpose()
318 .map_err(|_| {
319 QuicConnectError::InvalidMaxIdleTimeout(self.max_idle_timeout.unwrap())
320 })?
321 .map(|ms| VarInt::from_u32(ms).into()),
322 );
323
324 let crypto_config = Arc::new(QuicClientConfig::try_from(client_config)?);
325 let mut client_config = ClientConfig::new(crypto_config);
326 client_config.transport_config(Arc::new(transport_config));
327
328 let mut endpoint =
329 Endpoint::client(self.local_addr).map_err(QuicConnectError::EndpointClient)?;
330 endpoint.set_default_client_config(client_config);
331
332 let conn = endpoint.connect(addr, &server_name)?.await?;
333
334 Ok(QuicClient {
335 conn,
336 recv_streams: self.recv_streams,
337 max_backlog: self.max_backlog,
338 x_token: self.x_token,
339 })
340 }
341}
342
343#[derive(Debug)]
344pub struct QuicClientBuilderInsecure {
345 pub builder: QuicClientBuilder,
346}
347
348impl QuicClientBuilderInsecure {
349 pub async fn connect<T: ToSocketAddrs>(
350 self,
351 endpoint: T,
352 ) -> Result<QuicClient, QuicConnectError> {
353 self.builder
354 .connect(
355 endpoint,
356 rustls::ClientConfig::builder()
357 .dangerous()
358 .with_custom_certificate_verifier(SkipServerVerification::new())
359 .with_no_client_auth(),
360 )
361 .await
362 }
363}
364
365#[derive(Debug)]
366pub struct QuicClientBuilderSecure {
367 pub builder: QuicClientBuilder,
368 pub cert: Option<PathBuf>,
369}
370
371impl QuicClientBuilderSecure {
372 pub async fn connect<T: ToSocketAddrs>(
373 self,
374 endpoint: T,
375 ) -> Result<QuicClient, QuicConnectError> {
376 let mut roots = RootCertStore::empty();
377 let rustls_native_certs::CertificateResult { certs, errors, .. } =
379 rustls_native_certs::load_native_certs();
380 if !errors.is_empty() {
381 return Err(QuicConnectError::LoadNativeCerts(errors));
382 }
383 roots.add_parsable_certificates(certs);
384 roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
386 if let Some(cert_path) = self.cert {
388 let cert_chain = fs::read(&cert_path)
389 .await
390 .map_err(QuicConnectError::LoadCert)?;
391 if cert_path.extension().is_some_and(|x| x == "der") {
392 roots
393 .add(CertificateDer::from(cert_chain))
394 .map_err(QuicConnectError::AddCert)?;
395 } else {
396 for cert in rustls_pemfile::certs(&mut &*cert_chain) {
397 roots
398 .add(cert.map_err(QuicConnectError::PemCert)?)
399 .map_err(QuicConnectError::AddCert)?;
400 }
401 }
402 }
403
404 self.builder
405 .connect(
406 endpoint,
407 rustls::ClientConfig::builder()
408 .with_root_certificates(roots)
409 .with_no_client_auth(),
410 )
411 .await
412 }
413}
414
415#[derive(Debug)]
416pub struct QuicClient {
417 conn: Connection,
418 recv_streams: u32,
419 max_backlog: Option<u32>,
420 x_token: Option<Vec<u8>>,
421}
422
423impl QuicClient {
424 pub fn builder() -> QuicClientBuilder {
425 QuicClientBuilder::new()
426 }
427
428 pub async fn subscribe(
429 self,
430 replay_from_slot: Option<Slot>,
431 filter: Option<RichatFilter>,
432 ) -> Result<QuicClientStream, SubscribeError> {
433 let message = QuicSubscribeRequest {
434 x_token: self.x_token,
435 recv_streams: self.recv_streams,
436 max_backlog: self.max_backlog,
437 replay_from_slot,
438 filter,
439 }
440 .encode_to_vec();
441
442 let (mut send, mut recv) = self.conn.open_bi().await?;
443 send.write_u64(message.len() as u64).await?;
444 send.write_all(&message).await?;
445 send.flush().await?;
446
447 let version = SubscribeError::parse_quic_response(&mut recv).await?;
448
449 let mut readers = Vec::with_capacity(self.recv_streams as usize);
450 for _ in 0..self.recv_streams {
451 let stream = self.conn.accept_uni().await?;
452 readers.push(QuicClientStreamReader::Init {
453 stream: Some(stream),
454 });
455 }
456
457 Ok(QuicClientStream {
458 conn: self.conn,
459 version,
460 messages: HashMap::default(),
461 msg_id: 0,
462 readers,
463 index: 0,
464 })
465 }
466
467 async fn recv(mut stream: RecvStream) -> Result<(RecvStream, u64, Vec<u8>), ReceiveError> {
468 let msg_id = stream.read_u64().await?;
469 let error = msg_id == u64::MAX;
470
471 let size = stream.read_u64().await? as usize;
472 let mut buffer = Vec::<u8>::with_capacity(size);
473 let read = unsafe { std::slice::from_raw_parts_mut(buffer.as_mut_ptr(), size) };
475 stream.read_exact(read).await?;
476 unsafe {
478 buffer.set_len(size);
479 }
480
481 if error {
482 let close = QuicSubscribeClose::decode(&buffer.as_slice()[0..size])?;
483 Err(close.into())
484 } else {
485 Ok((stream, msg_id, buffer))
486 }
487 }
488}
489
pin_project! {
    /// Stream of raw subscription messages, yielded in `msg_id` order.
    ///
    /// Messages arrive over several unidirectional QUIC streams. Each carries a
    /// sequence id. Ids that arrive ahead of the next expected one are
    /// buffered until their turn.
    pub struct QuicClientStream {
        // Connection the data streams belong to; stored for the lifetime of the
        // stream (not read here — presumably kept to hold the connection open).
        conn: Connection,
        // Value returned by `SubscribeError::parse_quic_response` at subscribe time.
        version: String,
        // Messages received ahead of `msg_id`, keyed by their id.
        messages: HashMap<u64, Vec<u8>, RandomState>,
        // Id of the next message to yield; starts at 0.
        msg_id: u64,
        // One reader per accepted unidirectional stream.
        #[pin]
        readers: Vec<QuicClientStreamReader>,
        // Round-robin cursor: the reader to poll first on the next call.
        index: usize,
    }
}
501
502impl fmt::Debug for QuicClientStream {
503 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
504 f.debug_struct("QuicClientStream").finish()
505 }
506}
507
508impl QuicClientStream {
509 pub fn into_parsed(self) -> SubscribeStream {
510 SubscribeStream::new(self.boxed())
511 }
512
513 #[allow(clippy::missing_const_for_fn)]
515 pub fn get_version(&self) -> &str {
516 &self.version
517 }
518}
519
520impl Stream for QuicClientStream {
521 type Item = Result<Vec<u8>, ReceiveError>;
522
523 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
524 let mut me = self.project();
525
526 if let Some(msg) = me.messages.remove(me.msg_id) {
527 *me.msg_id += 1;
528 return Poll::Ready(Some(Ok(msg)));
529 }
530
531 let mut polled = 0;
532 loop {
533 let value = Pin::new(&mut me.readers[*me.index]).poll_next(cx);
535 *me.index = (*me.index + 1) % me.readers.len();
536 match value {
537 Poll::Ready(Some(Ok((msg_id, msg)))) => {
538 if *me.msg_id == msg_id {
539 *me.msg_id += 1;
540 return Poll::Ready(Some(Ok(msg)));
541 } else {
542 me.messages.insert(msg_id, msg);
543 }
544 }
545 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
546 Poll::Ready(None) => return Poll::Ready(None),
547 Poll::Pending => {}
548 }
549
550 polled += 1;
552 if polled == me.readers.len() {
553 return Poll::Pending;
554 }
555 }
556 }
557}
558
pin_project! {
    /// State machine reading framed messages from one QUIC receive stream.
    #[project = QuicClientStreamReaderProj]
    pub enum QuicClientStreamReader {
        // Idle: owns the stream until the next read starts. The `Option` lets
        // `poll_next` move the stream into the read future via `take()`.
        Init {
            stream: Option<RecvStream>,
        },
        // A `QuicClient::recv` call is in flight; on success it hands the stream
        // back together with the message id and payload.
        Read {
            #[pin] future: BoxFuture<'static, Result<(RecvStream, u64, Vec<u8>), ReceiveError>>,
        },
    }
}
570
571impl fmt::Debug for QuicClientStreamReader {
572 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
573 f.debug_struct("QuicClientStreamReader").finish()
574 }
575}
576
577impl Stream for QuicClientStreamReader {
578 type Item = Result<(u64, Vec<u8>), ReceiveError>;
579
580 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
581 loop {
582 match self.as_mut().project() {
583 QuicClientStreamReaderProj::Init { stream } => {
584 let stream = stream.take().unwrap();
585 let future = QuicClient::recv(stream).boxed();
586 self.set(Self::Read { future })
587 }
588 QuicClientStreamReaderProj::Read { mut future } => {
589 return Poll::Ready(match ready!(future.as_mut().poll(cx)) {
590 Ok((stream, msg_id, buffer)) => {
591 self.set(Self::Init {
592 stream: Some(stream),
593 });
594 Some(Ok((msg_id, buffer)))
595 }
596 Err(error) => {
597 if error.is_eof() {
598 None
599 } else {
600 Some(Err(error))
601 }
602 }
603 });
604 }
605 }
606 }
607 }
608}