Skip to main content

richat_client/
grpc.rs

/// Generated Geyser gRPC bindings.
///
/// The module body is included at compile time from the `geyser.Geyser.rs` file located in the
/// build's `OUT_DIR`. This file uses `geyser_client::GeyserClient` from it.
pub mod geyser_gen {
    include!(concat!(env!("OUT_DIR"), "/geyser.Geyser.rs"));
}
4
5use {
6    crate::{error::ReceiveError, stream::SubscribeStream},
7    bytes::{Buf, Bytes},
8    futures::{
9        channel::mpsc,
10        sink::{Sink, SinkExt},
11        stream::{Stream, StreamExt},
12    },
13    geyser_gen::geyser_client::GeyserClient,
14    pin_project_lite::pin_project,
15    prost::Message,
16    richat_proto::{
17        geyser::{
18            CommitmentLevel, GetBlockHeightRequest, GetBlockHeightResponse,
19            GetLatestBlockhashRequest, GetLatestBlockhashResponse, GetSlotRequest, GetSlotResponse,
20            GetVersionRequest, GetVersionResponse, IsBlockhashValidRequest,
21            IsBlockhashValidResponse, PingRequest, PongResponse, SubscribeReplayInfoRequest,
22            SubscribeReplayInfoResponse, SubscribeRequest,
23        },
24        richat::{GrpcSubscribeRequest, SubscribeAccountsRequest},
25    },
26    richat_shared::{
27        config::{deserialize_humansize_usize, deserialize_maybe_x_token},
28        transports::grpc::{ConfigGrpcCompression, ConfigGrpcServer},
29    },
30    serde::Deserialize,
31    std::{
32        collections::HashMap,
33        fmt, io,
34        marker::PhantomData,
35        path::PathBuf,
36        pin::Pin,
37        task::{Context, Poll},
38        time::Duration,
39    },
40    thiserror::Error,
41    tokio::fs,
42    tonic::{
43        Request, Response, Status, Streaming,
44        codec::{Codec, CompressionEncoding, DecodeBuf, Decoder, EncodeBuf, Encoder},
45        metadata::{AsciiMetadataKey, AsciiMetadataValue, errors::InvalidMetadataValueBytes},
46        service::{Interceptor, interceptor::InterceptedService},
47        transport::{
48            Certificate,
49            channel::{Channel, ClientTlsConfig, Endpoint},
50        },
51    },
52};
53
54#[derive(Debug, Clone, PartialEq, Deserialize)]
55#[serde(default)]
56pub struct ConfigGrpcClient {
57    pub endpoint: String,
58    pub ca_certificate: Option<PathBuf>,
59    #[serde(with = "humantime_serde")]
60    pub connect_timeout: Option<Duration>,
61    pub buffer_size: Option<usize>,
62    pub http2_adaptive_window: Option<bool>,
63    #[serde(with = "humantime_serde")]
64    pub http2_keep_alive_interval: Option<Duration>,
65    pub initial_connection_window_size: Option<u32>,
66    pub initial_stream_window_size: Option<u32>,
67    #[serde(with = "humantime_serde")]
68    pub keep_alive_timeout: Option<Duration>,
69    pub keep_alive_while_idle: bool,
70    #[serde(with = "humantime_serde")]
71    pub tcp_keepalive: Option<Duration>,
72    pub tcp_nodelay: bool,
73    #[serde(with = "humantime_serde")]
74    pub timeout: Option<Duration>,
75    #[serde(deserialize_with = "deserialize_humansize_usize")]
76    pub max_decoding_message_size: usize,
77    pub compression: ConfigGrpcCompression,
78    #[serde(deserialize_with = "deserialize_maybe_x_token")]
79    pub x_token: Option<Vec<u8>>,
80}
81
82impl Default for ConfigGrpcClient {
83    fn default() -> Self {
84        Self {
85            endpoint: format!("http://{}", ConfigGrpcServer::default().endpoint),
86            ca_certificate: None,
87            connect_timeout: None,
88            buffer_size: None,
89            http2_adaptive_window: None,
90            http2_keep_alive_interval: None,
91            initial_connection_window_size: None,
92            initial_stream_window_size: None,
93            keep_alive_timeout: None,
94            keep_alive_while_idle: false,
95            tcp_keepalive: Some(Duration::from_secs(15)),
96            tcp_nodelay: true,
97            timeout: None,
98            max_decoding_message_size: 4 * 1024 * 1024, // 4MiB
99            compression: ConfigGrpcCompression::default(),
100            x_token: None,
101        }
102    }
103}
104
105impl ConfigGrpcClient {
106    pub async fn connect(self) -> Result<GrpcClient<impl Interceptor>, GrpcClientBuilderError> {
107        let mut builder = GrpcClientBuilder::from_shared(self.endpoint)?
108            .tls_config_native_roots(self.ca_certificate.as_ref())
109            .await?
110            .buffer_size(self.buffer_size)
111            .keep_alive_while_idle(self.keep_alive_while_idle)
112            .tcp_keepalive(self.tcp_keepalive)
113            .tcp_nodelay(self.tcp_nodelay)
114            .max_decoding_message_size(self.max_decoding_message_size)
115            .x_token(self.x_token)?;
116        if let Some(connect_timeout) = self.connect_timeout {
117            builder = builder.connect_timeout(connect_timeout)
118        }
119        if let Some(http2_adaptive_window) = self.http2_adaptive_window {
120            builder = builder.http2_adaptive_window(http2_adaptive_window);
121        }
122        if let Some(http2_keep_alive_interval) = self.http2_keep_alive_interval {
123            builder = builder.http2_keep_alive_interval(http2_keep_alive_interval);
124        }
125        if let Some(initial_connection_window_size) = self.initial_connection_window_size {
126            builder = builder.initial_connection_window_size(initial_connection_window_size);
127        }
128        if let Some(initial_stream_window_size) = self.initial_stream_window_size {
129            builder = builder.initial_stream_window_size(initial_stream_window_size);
130        }
131        if let Some(keep_alive_timeout) = self.keep_alive_timeout {
132            builder = builder.keep_alive_timeout(keep_alive_timeout);
133        }
134        if let Some(timeout) = self.timeout {
135            builder = builder.timeout(timeout);
136        }
137        for encoding in self.compression.accept {
138            builder = builder.accept_compressed(encoding);
139        }
140        for encoding in self.compression.send {
141            builder = builder.send_compressed(encoding);
142        }
143        builder.connect().await.map_err(Into::into)
144    }
145}
146
/// Errors that can occur while configuring a [`GrpcClientBuilder`] or connecting with it.
#[derive(Debug, Error)]
pub enum GrpcClientBuilderError {
    /// Reading the CA certificate file failed.
    #[error("failed to load cert: {0}")]
    LoadCert(io::Error),
    /// Error reported by the tonic transport layer (converted automatically via `#[from]`).
    #[error("tonic transport error: {0}")]
    Tonic(#[from] tonic::transport::Error),
    /// gRPC status error (converted automatically via `#[from]`).
    #[error("tonic status error: {0}")]
    Status(#[from] tonic::Status),
    /// The `x-token` value could not be converted into a metadata value.
    #[error("x-token error: {0}")]
    XToken(#[from] InvalidMetadataValueBytes),
}
158
159#[derive(Debug)]
160pub struct GrpcClientBuilder {
161    pub endpoint: Endpoint,
162    pub send_compressed: Option<CompressionEncoding>,
163    pub accept_compressed: Option<CompressionEncoding>,
164    pub max_decoding_message_size: Option<usize>,
165    pub max_encoding_message_size: Option<usize>,
166    pub interceptor: GrpcInterceptor,
167}
168
169impl GrpcClientBuilder {
170    // Create new builder
171    fn new(endpoint: Endpoint) -> Self {
172        Self {
173            endpoint,
174            send_compressed: None,
175            accept_compressed: None,
176            max_decoding_message_size: None,
177            max_encoding_message_size: None,
178            interceptor: GrpcInterceptor::default(),
179        }
180    }
181
182    pub fn from_shared(endpoint: impl Into<Bytes>) -> Result<Self, tonic::transport::Error> {
183        Endpoint::from_shared(endpoint).map(Self::new)
184    }
185
186    pub fn from_static(endpoint: &'static str) -> Self {
187        Self::new(Endpoint::from_static(endpoint))
188    }
189
190    // Endpoint options
191    pub fn connect_timeout(self, dur: Duration) -> Self {
192        Self {
193            endpoint: self.endpoint.connect_timeout(dur),
194            ..self
195        }
196    }
197
198    pub fn buffer_size(self, sz: impl Into<Option<usize>>) -> Self {
199        Self {
200            endpoint: self.endpoint.buffer_size(sz),
201            ..self
202        }
203    }
204
205    pub fn http2_adaptive_window(self, enabled: bool) -> Self {
206        Self {
207            endpoint: self.endpoint.http2_adaptive_window(enabled),
208            ..self
209        }
210    }
211
212    pub fn http2_keep_alive_interval(self, interval: Duration) -> Self {
213        Self {
214            endpoint: self.endpoint.http2_keep_alive_interval(interval),
215            ..self
216        }
217    }
218
219    pub fn initial_connection_window_size(self, sz: impl Into<Option<u32>>) -> Self {
220        Self {
221            endpoint: self.endpoint.initial_connection_window_size(sz),
222            ..self
223        }
224    }
225
226    pub fn initial_stream_window_size(self, sz: impl Into<Option<u32>>) -> Self {
227        Self {
228            endpoint: self.endpoint.initial_stream_window_size(sz),
229            ..self
230        }
231    }
232
233    pub fn keep_alive_timeout(self, duration: Duration) -> Self {
234        Self {
235            endpoint: self.endpoint.keep_alive_timeout(duration),
236            ..self
237        }
238    }
239
240    pub fn keep_alive_while_idle(self, enabled: bool) -> Self {
241        Self {
242            endpoint: self.endpoint.keep_alive_while_idle(enabled),
243            ..self
244        }
245    }
246
247    pub fn tcp_keepalive(self, tcp_keepalive: Option<Duration>) -> Self {
248        Self {
249            endpoint: self.endpoint.tcp_keepalive(tcp_keepalive),
250            ..self
251        }
252    }
253
254    pub fn tcp_nodelay(self, enabled: bool) -> Self {
255        Self {
256            endpoint: self.endpoint.tcp_nodelay(enabled),
257            ..self
258        }
259    }
260
261    pub fn timeout(self, dur: Duration) -> Self {
262        Self {
263            endpoint: self.endpoint.timeout(dur),
264            ..self
265        }
266    }
267
268    pub fn tls_config(self, tls_config: ClientTlsConfig) -> Result<Self, GrpcClientBuilderError> {
269        Ok(Self {
270            endpoint: self.endpoint.tls_config(tls_config)?,
271            ..self
272        })
273    }
274
275    pub async fn tls_config_native_roots(
276        self,
277        ca_certificate: Option<&PathBuf>,
278    ) -> Result<Self, GrpcClientBuilderError> {
279        let mut tls_config = ClientTlsConfig::new().with_native_roots();
280        if let Some(path) = ca_certificate {
281            let bytes = fs::read(path)
282                .await
283                .map_err(GrpcClientBuilderError::LoadCert)?;
284            tls_config = tls_config.ca_certificate(Certificate::from_pem(bytes));
285        }
286        self.tls_config(tls_config)
287    }
288
289    // gRPC options
290    pub fn send_compressed(self, encoding: CompressionEncoding) -> Self {
291        Self {
292            send_compressed: Some(encoding),
293            ..self
294        }
295    }
296
297    pub fn accept_compressed(self, encoding: CompressionEncoding) -> Self {
298        Self {
299            accept_compressed: Some(encoding),
300            ..self
301        }
302    }
303
304    pub fn max_decoding_message_size(self, limit: usize) -> Self {
305        Self {
306            max_decoding_message_size: Some(limit),
307            ..self
308        }
309    }
310
311    pub fn max_encoding_message_size(self, limit: usize) -> Self {
312        Self {
313            max_encoding_message_size: Some(limit),
314            ..self
315        }
316    }
317
318    // Metadata
319    pub fn x_token<T>(mut self, x_token: Option<T>) -> Result<Self, InvalidMetadataValueBytes>
320    where
321        T: TryInto<AsciiMetadataValue, Error = InvalidMetadataValueBytes>,
322    {
323        if let Some(x_token) = x_token {
324            self.interceptor.metadata.insert(
325                AsciiMetadataKey::from_static("x-token"),
326                x_token.try_into()?,
327            );
328        } else {
329            self.interceptor.metadata.remove("x-token");
330        }
331        Ok(self)
332    }
333
334    // Create client
335    fn build(self, channel: Channel) -> GrpcClient<impl Interceptor> {
336        let mut geyser = GeyserClient::with_interceptor(channel, self.interceptor);
337        if let Some(encoding) = self.send_compressed {
338            geyser = geyser.send_compressed(encoding);
339        }
340        if let Some(encoding) = self.accept_compressed {
341            geyser = geyser.accept_compressed(encoding);
342        }
343        if let Some(limit) = self.max_decoding_message_size {
344            geyser = geyser.max_decoding_message_size(limit);
345        }
346        if let Some(limit) = self.max_encoding_message_size {
347            geyser = geyser.max_encoding_message_size(limit);
348        }
349        GrpcClient::new(geyser)
350    }
351
352    pub async fn connect(self) -> Result<GrpcClient<impl Interceptor>, tonic::transport::Error> {
353        let channel = self.endpoint.connect().await?;
354        Ok(self.build(channel))
355    }
356
357    pub fn connect_lazy(self) -> Result<GrpcClient<impl Interceptor>, tonic::transport::Error> {
358        let channel = self.endpoint.connect_lazy();
359        Ok(self.build(channel))
360    }
361}
362
363#[derive(Debug, Default)]
364pub struct GrpcInterceptor {
365    metadata: HashMap<AsciiMetadataKey, AsciiMetadataValue>,
366}
367
368impl Interceptor for GrpcInterceptor {
369    fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
370        for (key, value) in self.metadata.iter() {
371            request.metadata_mut().insert(key, value.clone());
372        }
373        Ok(request)
374    }
375}
376
377#[derive(Debug)]
378pub struct GrpcClient<F> {
379    pub geyser: GeyserClient<InterceptedService<Channel, F>>,
380}
381
382impl GrpcClient<()> {
383    pub fn build_from_shared(
384        endpoint: impl Into<Bytes>,
385    ) -> Result<GrpcClientBuilder, tonic::transport::Error> {
386        Ok(GrpcClientBuilder::new(Endpoint::from_shared(endpoint)?))
387    }
388
389    pub fn build_from_static(endpoint: &'static str) -> GrpcClientBuilder {
390        GrpcClientBuilder::new(Endpoint::from_static(endpoint))
391    }
392}
393
394impl<F: Interceptor> GrpcClient<F> {
395    pub const fn new(geyser: GeyserClient<InterceptedService<Channel, F>>) -> Self {
396        Self { geyser }
397    }
398
399    // Subscribe Yellowstone gRPC Dragon's Mouth
400    pub async fn subscribe_dragons_mouth(
401        &mut self,
402    ) -> Result<
403        (
404            impl Sink<SubscribeRequest, Error = mpsc::SendError>,
405            GrpcClientStream,
406        ),
407        Status,
408    > {
409        let (subscribe_tx, subscribe_rx) = mpsc::unbounded();
410        let response: Response<Streaming<Vec<u8>>> = self.geyser.subscribe(subscribe_rx).await?;
411        let stream = GrpcClientStream::new(response.into_inner());
412        Ok((subscribe_tx, stream))
413    }
414
415    pub async fn subscribe_dragons_mouth_once(
416        &mut self,
417        request: SubscribeRequest,
418    ) -> Result<GrpcClientStream, Status> {
419        let (mut tx, rx) = self.subscribe_dragons_mouth().await?;
420        tx.send(request)
421            .await
422            .expect("failed to send to unbounded channel");
423        Ok(rx)
424    }
425
426    // Subscribe Accounts
427    pub async fn subscribe_accounts(
428        &mut self,
429    ) -> Result<
430        (
431            mpsc::UnboundedSender<SubscribeAccountsRequest>,
432            GrpcClientStream,
433        ),
434        Status,
435    > {
436        let (subscribe_tx, subscribe_rx) = mpsc::unbounded();
437        let response: Response<Streaming<Vec<u8>>> =
438            self.geyser.subscribe_accounts(subscribe_rx).await?;
439        let stream = GrpcClientStream::new(response.into_inner());
440        Ok((subscribe_tx, stream))
441    }
442
443    // Subscribe Richat
444    pub async fn subscribe_richat(
445        &mut self,
446        request: GrpcSubscribeRequest,
447    ) -> Result<GrpcClientStream, Status> {
448        let (mut tx, rx) = mpsc::unbounded();
449        tx.send(request)
450            .await
451            .expect("failed to send to unbounded channel");
452
453        let response: Response<Streaming<Vec<u8>>> = self.geyser.subscribe_richat(rx).await?;
454        Ok(GrpcClientStream::new(response.into_inner()))
455    }
456
457    // RPC calls
458    pub async fn subscribe_replay_info(&mut self) -> Result<SubscribeReplayInfoResponse, Status> {
459        let message = SubscribeReplayInfoRequest {};
460        let request = tonic::Request::new(message);
461        let response = self.geyser.subscribe_replay_info(request).await?;
462        Ok(response.into_inner())
463    }
464
465    pub async fn ping(&mut self, count: i32) -> Result<PongResponse, Status> {
466        let message = PingRequest { count };
467        let request = Request::new(message);
468        let response = self.geyser.ping(request).await?;
469        Ok(response.into_inner())
470    }
471
472    pub async fn get_latest_blockhash(
473        &mut self,
474        commitment: Option<CommitmentLevel>,
475    ) -> Result<GetLatestBlockhashResponse, Status> {
476        let request = Request::new(GetLatestBlockhashRequest {
477            commitment: commitment.map(|value| value as i32),
478        });
479        let response = self.geyser.get_latest_blockhash(request).await?;
480        Ok(response.into_inner())
481    }
482
483    pub async fn get_block_height(
484        &mut self,
485        commitment: Option<CommitmentLevel>,
486    ) -> Result<GetBlockHeightResponse, Status> {
487        let request = Request::new(GetBlockHeightRequest {
488            commitment: commitment.map(|value| value as i32),
489        });
490        let response = self.geyser.get_block_height(request).await?;
491        Ok(response.into_inner())
492    }
493
494    pub async fn get_slot(
495        &mut self,
496        commitment: Option<CommitmentLevel>,
497    ) -> Result<GetSlotResponse, Status> {
498        let request = Request::new(GetSlotRequest {
499            commitment: commitment.map(|value| value as i32),
500        });
501        let response = self.geyser.get_slot(request).await?;
502        Ok(response.into_inner())
503    }
504
505    pub async fn is_blockhash_valid(
506        &mut self,
507        blockhash: String,
508        commitment: Option<CommitmentLevel>,
509    ) -> Result<IsBlockhashValidResponse, Status> {
510        let request = Request::new(IsBlockhashValidRequest {
511            blockhash,
512            commitment: commitment.map(|value| value as i32),
513        });
514        let response = self.geyser.is_blockhash_valid(request).await?;
515        Ok(response.into_inner())
516    }
517
518    pub async fn get_version(&mut self) -> Result<GetVersionResponse, Status> {
519        let request = Request::new(GetVersionRequest {});
520        let response = self.geyser.get_version(request).await?;
521        Ok(response.into_inner())
522    }
523}
524
525trait SubscribeMessage {
526    fn decode(buf: &mut DecodeBuf<'_>) -> Self;
527}
528
529impl SubscribeMessage for Vec<u8> {
530    fn decode(src: &mut DecodeBuf<'_>) -> Self {
531        let mut dst = Box::new_uninit_slice(src.remaining());
532        let mut start = 0;
533        while src.remaining() > 0 {
534            let chunk = src.chunk();
535            // SAFETY: writing within bounds of allocated uninit slice,
536            // MaybeUninit<u8> has the same layout as u8
537            unsafe {
538                std::ptr::copy_nonoverlapping(
539                    chunk.as_ptr(),
540                    dst.as_mut_ptr().cast::<u8>().add(start),
541                    chunk.len(),
542                );
543            }
544            start += chunk.len();
545            src.advance(chunk.len());
546        }
547        // SAFETY: all bytes initialized by copying from src
548        unsafe { dst.assume_init() }.into_vec()
549    }
550}
551
552pub struct SubscribeCodec<T, U> {
553    _pd: PhantomData<(T, U)>,
554}
555
556impl<T, U> Default for SubscribeCodec<T, U> {
557    fn default() -> Self {
558        Self { _pd: PhantomData }
559    }
560}
561
562impl<T, U> Codec for SubscribeCodec<T, U>
563where
564    T: Message + Send + 'static,
565    U: SubscribeMessage + Default + Send + 'static,
566{
567    type Encode = T;
568    type Decode = U;
569
570    type Encoder = ProstEncoder<T>;
571    type Decoder = SubscribeDecoder<U>;
572
573    fn encoder(&mut self) -> Self::Encoder {
574        ProstEncoder(PhantomData)
575    }
576
577    fn decoder(&mut self) -> Self::Decoder {
578        SubscribeDecoder(PhantomData)
579    }
580}
581
582/// A [`Encoder`] that knows how to encode `T`.
583#[derive(Debug, Clone, Default)]
584pub struct ProstEncoder<T>(PhantomData<T>);
585
586impl<T: Message> Encoder for ProstEncoder<T> {
587    type Item = T;
588    type Error = Status;
589
590    fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
591        item.encode(buf)
592            .expect("Message only errors if not enough space");
593        Ok(())
594    }
595}
596
597/// A [`Decoder`] that knows how to decode `U`.
598#[derive(Debug, Clone, Default)]
599pub struct SubscribeDecoder<U>(PhantomData<U>);
600
601impl<U: SubscribeMessage + Default> Decoder for SubscribeDecoder<U> {
602    type Item = U;
603    type Error = Status;
604
605    fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
606        Ok(Some(SubscribeMessage::decode(buf)))
607    }
608}
609
610pin_project! {
611    pub struct GrpcClientStream {
612        #[pin]
613        stream: Streaming<Vec<u8>>,
614    }
615}
616
617impl GrpcClientStream {
618    pub const fn new(stream: Streaming<Vec<u8>>) -> Self {
619        Self { stream }
620    }
621
622    pub fn into_parsed(self) -> SubscribeStream {
623        SubscribeStream::new(self.boxed())
624    }
625}
626
627impl fmt::Debug for GrpcClientStream {
628    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
629        f.debug_struct("GrpcClientStream").finish()
630    }
631}
632
633impl Stream for GrpcClientStream {
634    type Item = Result<Vec<u8>, ReceiveError>;
635
636    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
637        let me = self.project();
638        me.stream.poll_next(cx).map_err(Into::into)
639    }
640}