/// Generated Geyser gRPC client code.
///
/// The module body is included verbatim from `geyser.Geyser.rs` inside the
/// build's `OUT_DIR`. Presumably that file is emitted by the crate's build
/// script, which is not visible from this file.
pub mod geyser_gen {
    include!(concat!(env!("OUT_DIR"), "/geyser.Geyser.rs"));
}
4
5use {
6 crate::{error::ReceiveError, stream::SubscribeStream},
7 bytes::{Buf, Bytes},
8 futures::{
9 channel::mpsc,
10 sink::{Sink, SinkExt},
11 stream::{Stream, StreamExt},
12 },
13 geyser_gen::geyser_client::GeyserClient,
14 pin_project_lite::pin_project,
15 prost::Message,
16 richat_proto::{
17 geyser::{
18 CommitmentLevel, GetBlockHeightRequest, GetBlockHeightResponse,
19 GetLatestBlockhashRequest, GetLatestBlockhashResponse, GetSlotRequest, GetSlotResponse,
20 GetVersionRequest, GetVersionResponse, IsBlockhashValidRequest,
21 IsBlockhashValidResponse, PingRequest, PongResponse, SubscribeReplayInfoRequest,
22 SubscribeReplayInfoResponse, SubscribeRequest,
23 },
24 richat::{GrpcSubscribeRequest, SubscribeAccountsRequest},
25 },
26 richat_shared::{
27 config::{deserialize_humansize_usize, deserialize_maybe_x_token},
28 transports::grpc::{ConfigGrpcCompression, ConfigGrpcServer},
29 },
30 serde::Deserialize,
31 std::{
32 collections::HashMap,
33 fmt, io,
34 marker::PhantomData,
35 path::PathBuf,
36 pin::Pin,
37 task::{Context, Poll},
38 time::Duration,
39 },
40 thiserror::Error,
41 tokio::fs,
42 tonic::{
43 Request, Response, Status, Streaming,
44 codec::{Codec, CompressionEncoding, DecodeBuf, Decoder, EncodeBuf, Encoder},
45 metadata::{AsciiMetadataKey, AsciiMetadataValue, errors::InvalidMetadataValueBytes},
46 service::{Interceptor, interceptor::InterceptedService},
47 transport::{
48 Certificate,
49 channel::{Channel, ClientTlsConfig, Endpoint},
50 },
51 },
52};
53
/// Deserializable configuration for a gRPC client connection.
///
/// Because of `#[serde(default)]`, any field missing from the input takes the
/// value produced by this type's `Default` implementation. Each field is
/// handed to the [`GrpcClientBuilder`] method of the same name by
/// [`ConfigGrpcClient::connect`].
#[derive(Debug, Clone, PartialEq, Deserialize)]
#[serde(default)]
pub struct ConfigGrpcClient {
    /// Endpoint URI to connect to.
    pub endpoint: String,
    /// Optional path to a PEM file whose contents are added as a CA
    /// certificate to the TLS configuration.
    pub ca_certificate: Option<PathBuf>,
    /// Connect timeout; only applied when set.
    #[serde(with = "humantime_serde")]
    pub connect_timeout: Option<Duration>,
    /// Buffer size passed to the endpoint (always forwarded, even when `None`).
    pub buffer_size: Option<usize>,
    /// HTTP/2 adaptive window switch; only applied when set.
    pub http2_adaptive_window: Option<bool>,
    /// HTTP/2 keep-alive interval; only applied when set.
    #[serde(with = "humantime_serde")]
    pub http2_keep_alive_interval: Option<Duration>,
    /// Initial connection window size; only applied when set.
    pub initial_connection_window_size: Option<u32>,
    /// Initial stream window size; only applied when set.
    pub initial_stream_window_size: Option<u32>,
    /// Keep-alive timeout; only applied when set.
    #[serde(with = "humantime_serde")]
    pub keep_alive_timeout: Option<Duration>,
    /// Keep-alive-while-idle switch (defaults to `false`).
    pub keep_alive_while_idle: bool,
    /// TCP keepalive duration (defaults to 15 seconds).
    #[serde(with = "humantime_serde")]
    pub tcp_keepalive: Option<Duration>,
    /// `TCP_NODELAY` switch (defaults to `true`).
    pub tcp_nodelay: bool,
    /// Request timeout; only applied when set.
    #[serde(with = "humantime_serde")]
    pub timeout: Option<Duration>,
    /// Maximum size of a decoded message in bytes (defaults to 4 MiB).
    #[serde(deserialize_with = "deserialize_humansize_usize")]
    pub max_decoding_message_size: usize,
    /// Compression encodings to accept and to send.
    pub compression: ConfigGrpcCompression,
    /// Optional token attached to every request as `x-token` metadata.
    #[serde(deserialize_with = "deserialize_maybe_x_token")]
    pub x_token: Option<Vec<u8>>,
}
81
82impl Default for ConfigGrpcClient {
83 fn default() -> Self {
84 Self {
85 endpoint: format!("http://{}", ConfigGrpcServer::default().endpoint),
86 ca_certificate: None,
87 connect_timeout: None,
88 buffer_size: None,
89 http2_adaptive_window: None,
90 http2_keep_alive_interval: None,
91 initial_connection_window_size: None,
92 initial_stream_window_size: None,
93 keep_alive_timeout: None,
94 keep_alive_while_idle: false,
95 tcp_keepalive: Some(Duration::from_secs(15)),
96 tcp_nodelay: true,
97 timeout: None,
98 max_decoding_message_size: 4 * 1024 * 1024, compression: ConfigGrpcCompression::default(),
100 x_token: None,
101 }
102 }
103}
104
105impl ConfigGrpcClient {
106 pub async fn connect(self) -> Result<GrpcClient<impl Interceptor>, GrpcClientBuilderError> {
107 let mut builder = GrpcClientBuilder::from_shared(self.endpoint)?
108 .tls_config_native_roots(self.ca_certificate.as_ref())
109 .await?
110 .buffer_size(self.buffer_size)
111 .keep_alive_while_idle(self.keep_alive_while_idle)
112 .tcp_keepalive(self.tcp_keepalive)
113 .tcp_nodelay(self.tcp_nodelay)
114 .max_decoding_message_size(self.max_decoding_message_size)
115 .x_token(self.x_token)?;
116 if let Some(connect_timeout) = self.connect_timeout {
117 builder = builder.connect_timeout(connect_timeout)
118 }
119 if let Some(http2_adaptive_window) = self.http2_adaptive_window {
120 builder = builder.http2_adaptive_window(http2_adaptive_window);
121 }
122 if let Some(http2_keep_alive_interval) = self.http2_keep_alive_interval {
123 builder = builder.http2_keep_alive_interval(http2_keep_alive_interval);
124 }
125 if let Some(initial_connection_window_size) = self.initial_connection_window_size {
126 builder = builder.initial_connection_window_size(initial_connection_window_size);
127 }
128 if let Some(initial_stream_window_size) = self.initial_stream_window_size {
129 builder = builder.initial_stream_window_size(initial_stream_window_size);
130 }
131 if let Some(keep_alive_timeout) = self.keep_alive_timeout {
132 builder = builder.keep_alive_timeout(keep_alive_timeout);
133 }
134 if let Some(timeout) = self.timeout {
135 builder = builder.timeout(timeout);
136 }
137 for encoding in self.compression.accept {
138 builder = builder.accept_compressed(encoding);
139 }
140 for encoding in self.compression.send {
141 builder = builder.send_compressed(encoding);
142 }
143 builder.connect().await.map_err(Into::into)
144 }
145}
146
/// Errors produced while configuring or connecting a [`GrpcClientBuilder`].
#[derive(Debug, Error)]
pub enum GrpcClientBuilderError {
    /// Reading the CA certificate file failed.
    #[error("failed to load cert: {0}")]
    LoadCert(io::Error),
    /// A tonic transport-level error (invalid endpoint, TLS configuration or
    /// connection failure).
    #[error("tonic transport error: {0}")]
    Tonic(#[from] tonic::transport::Error),
    /// A gRPC status, converted from [`tonic::Status`].
    #[error("tonic status error: {0}")]
    Status(#[from] tonic::Status),
    /// The supplied token could not be turned into an ASCII metadata value.
    #[error("x-token error: {0}")]
    XToken(#[from] InvalidMetadataValueBytes),
}
158
/// Builder that collects transport and codec settings before a
/// [`GrpcClient`] is created.
#[derive(Debug)]
pub struct GrpcClientBuilder {
    /// Underlying tonic endpoint carrying all transport-level settings.
    pub endpoint: Endpoint,
    /// Encoding applied to outgoing messages, if any.
    pub send_compressed: Option<CompressionEncoding>,
    /// Encoding accepted for incoming messages, if any.
    pub accept_compressed: Option<CompressionEncoding>,
    /// Limit for decoded message size; left to the generated client when `None`.
    pub max_decoding_message_size: Option<usize>,
    /// Limit for encoded message size; left to the generated client when `None`.
    pub max_encoding_message_size: Option<usize>,
    /// Interceptor that attaches the stored metadata to every request.
    pub interceptor: GrpcInterceptor,
}
168
169impl GrpcClientBuilder {
170 fn new(endpoint: Endpoint) -> Self {
172 Self {
173 endpoint,
174 send_compressed: None,
175 accept_compressed: None,
176 max_decoding_message_size: None,
177 max_encoding_message_size: None,
178 interceptor: GrpcInterceptor::default(),
179 }
180 }
181
182 pub fn from_shared(endpoint: impl Into<Bytes>) -> Result<Self, tonic::transport::Error> {
183 Endpoint::from_shared(endpoint).map(Self::new)
184 }
185
186 pub fn from_static(endpoint: &'static str) -> Self {
187 Self::new(Endpoint::from_static(endpoint))
188 }
189
190 pub fn connect_timeout(self, dur: Duration) -> Self {
192 Self {
193 endpoint: self.endpoint.connect_timeout(dur),
194 ..self
195 }
196 }
197
198 pub fn buffer_size(self, sz: impl Into<Option<usize>>) -> Self {
199 Self {
200 endpoint: self.endpoint.buffer_size(sz),
201 ..self
202 }
203 }
204
205 pub fn http2_adaptive_window(self, enabled: bool) -> Self {
206 Self {
207 endpoint: self.endpoint.http2_adaptive_window(enabled),
208 ..self
209 }
210 }
211
212 pub fn http2_keep_alive_interval(self, interval: Duration) -> Self {
213 Self {
214 endpoint: self.endpoint.http2_keep_alive_interval(interval),
215 ..self
216 }
217 }
218
219 pub fn initial_connection_window_size(self, sz: impl Into<Option<u32>>) -> Self {
220 Self {
221 endpoint: self.endpoint.initial_connection_window_size(sz),
222 ..self
223 }
224 }
225
226 pub fn initial_stream_window_size(self, sz: impl Into<Option<u32>>) -> Self {
227 Self {
228 endpoint: self.endpoint.initial_stream_window_size(sz),
229 ..self
230 }
231 }
232
233 pub fn keep_alive_timeout(self, duration: Duration) -> Self {
234 Self {
235 endpoint: self.endpoint.keep_alive_timeout(duration),
236 ..self
237 }
238 }
239
240 pub fn keep_alive_while_idle(self, enabled: bool) -> Self {
241 Self {
242 endpoint: self.endpoint.keep_alive_while_idle(enabled),
243 ..self
244 }
245 }
246
247 pub fn tcp_keepalive(self, tcp_keepalive: Option<Duration>) -> Self {
248 Self {
249 endpoint: self.endpoint.tcp_keepalive(tcp_keepalive),
250 ..self
251 }
252 }
253
254 pub fn tcp_nodelay(self, enabled: bool) -> Self {
255 Self {
256 endpoint: self.endpoint.tcp_nodelay(enabled),
257 ..self
258 }
259 }
260
261 pub fn timeout(self, dur: Duration) -> Self {
262 Self {
263 endpoint: self.endpoint.timeout(dur),
264 ..self
265 }
266 }
267
268 pub fn tls_config(self, tls_config: ClientTlsConfig) -> Result<Self, GrpcClientBuilderError> {
269 Ok(Self {
270 endpoint: self.endpoint.tls_config(tls_config)?,
271 ..self
272 })
273 }
274
275 pub async fn tls_config_native_roots(
276 self,
277 ca_certificate: Option<&PathBuf>,
278 ) -> Result<Self, GrpcClientBuilderError> {
279 let mut tls_config = ClientTlsConfig::new().with_native_roots();
280 if let Some(path) = ca_certificate {
281 let bytes = fs::read(path)
282 .await
283 .map_err(GrpcClientBuilderError::LoadCert)?;
284 tls_config = tls_config.ca_certificate(Certificate::from_pem(bytes));
285 }
286 self.tls_config(tls_config)
287 }
288
289 pub fn send_compressed(self, encoding: CompressionEncoding) -> Self {
291 Self {
292 send_compressed: Some(encoding),
293 ..self
294 }
295 }
296
297 pub fn accept_compressed(self, encoding: CompressionEncoding) -> Self {
298 Self {
299 accept_compressed: Some(encoding),
300 ..self
301 }
302 }
303
304 pub fn max_decoding_message_size(self, limit: usize) -> Self {
305 Self {
306 max_decoding_message_size: Some(limit),
307 ..self
308 }
309 }
310
311 pub fn max_encoding_message_size(self, limit: usize) -> Self {
312 Self {
313 max_encoding_message_size: Some(limit),
314 ..self
315 }
316 }
317
318 pub fn x_token<T>(mut self, x_token: Option<T>) -> Result<Self, InvalidMetadataValueBytes>
320 where
321 T: TryInto<AsciiMetadataValue, Error = InvalidMetadataValueBytes>,
322 {
323 if let Some(x_token) = x_token {
324 self.interceptor.metadata.insert(
325 AsciiMetadataKey::from_static("x-token"),
326 x_token.try_into()?,
327 );
328 } else {
329 self.interceptor.metadata.remove("x-token");
330 }
331 Ok(self)
332 }
333
334 fn build(self, channel: Channel) -> GrpcClient<impl Interceptor> {
336 let mut geyser = GeyserClient::with_interceptor(channel, self.interceptor);
337 if let Some(encoding) = self.send_compressed {
338 geyser = geyser.send_compressed(encoding);
339 }
340 if let Some(encoding) = self.accept_compressed {
341 geyser = geyser.accept_compressed(encoding);
342 }
343 if let Some(limit) = self.max_decoding_message_size {
344 geyser = geyser.max_decoding_message_size(limit);
345 }
346 if let Some(limit) = self.max_encoding_message_size {
347 geyser = geyser.max_encoding_message_size(limit);
348 }
349 GrpcClient::new(geyser)
350 }
351
352 pub async fn connect(self) -> Result<GrpcClient<impl Interceptor>, tonic::transport::Error> {
353 let channel = self.endpoint.connect().await?;
354 Ok(self.build(channel))
355 }
356
357 pub fn connect_lazy(self) -> Result<GrpcClient<impl Interceptor>, tonic::transport::Error> {
358 let channel = self.endpoint.connect_lazy();
359 Ok(self.build(channel))
360 }
361}
362
363#[derive(Debug, Default)]
364pub struct GrpcInterceptor {
365 metadata: HashMap<AsciiMetadataKey, AsciiMetadataValue>,
366}
367
368impl Interceptor for GrpcInterceptor {
369 fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
370 for (key, value) in self.metadata.iter() {
371 request.metadata_mut().insert(key, value.clone());
372 }
373 Ok(request)
374 }
375}
376
/// Thin wrapper around the generated Geyser client whose channel is wrapped
/// with an interceptor of type `F`.
#[derive(Debug)]
pub struct GrpcClient<F> {
    /// The generated client; public so callers can reach it directly.
    pub geyser: GeyserClient<InterceptedService<Channel, F>>,
}
381
382impl GrpcClient<()> {
383 pub fn build_from_shared(
384 endpoint: impl Into<Bytes>,
385 ) -> Result<GrpcClientBuilder, tonic::transport::Error> {
386 Ok(GrpcClientBuilder::new(Endpoint::from_shared(endpoint)?))
387 }
388
389 pub fn build_from_static(endpoint: &'static str) -> GrpcClientBuilder {
390 GrpcClientBuilder::new(Endpoint::from_static(endpoint))
391 }
392}
393
394impl<F: Interceptor> GrpcClient<F> {
395 pub const fn new(geyser: GeyserClient<InterceptedService<Channel, F>>) -> Self {
396 Self { geyser }
397 }
398
399 pub async fn subscribe_dragons_mouth(
401 &mut self,
402 ) -> Result<
403 (
404 impl Sink<SubscribeRequest, Error = mpsc::SendError>,
405 GrpcClientStream,
406 ),
407 Status,
408 > {
409 let (subscribe_tx, subscribe_rx) = mpsc::unbounded();
410 let response: Response<Streaming<Vec<u8>>> = self.geyser.subscribe(subscribe_rx).await?;
411 let stream = GrpcClientStream::new(response.into_inner());
412 Ok((subscribe_tx, stream))
413 }
414
415 pub async fn subscribe_dragons_mouth_once(
416 &mut self,
417 request: SubscribeRequest,
418 ) -> Result<GrpcClientStream, Status> {
419 let (mut tx, rx) = self.subscribe_dragons_mouth().await?;
420 tx.send(request)
421 .await
422 .expect("failed to send to unbounded channel");
423 Ok(rx)
424 }
425
426 pub async fn subscribe_accounts(
428 &mut self,
429 ) -> Result<
430 (
431 mpsc::UnboundedSender<SubscribeAccountsRequest>,
432 GrpcClientStream,
433 ),
434 Status,
435 > {
436 let (subscribe_tx, subscribe_rx) = mpsc::unbounded();
437 let response: Response<Streaming<Vec<u8>>> =
438 self.geyser.subscribe_accounts(subscribe_rx).await?;
439 let stream = GrpcClientStream::new(response.into_inner());
440 Ok((subscribe_tx, stream))
441 }
442
443 pub async fn subscribe_richat(
445 &mut self,
446 request: GrpcSubscribeRequest,
447 ) -> Result<GrpcClientStream, Status> {
448 let (mut tx, rx) = mpsc::unbounded();
449 tx.send(request)
450 .await
451 .expect("failed to send to unbounded channel");
452
453 let response: Response<Streaming<Vec<u8>>> = self.geyser.subscribe_richat(rx).await?;
454 Ok(GrpcClientStream::new(response.into_inner()))
455 }
456
457 pub async fn subscribe_replay_info(&mut self) -> Result<SubscribeReplayInfoResponse, Status> {
459 let message = SubscribeReplayInfoRequest {};
460 let request = tonic::Request::new(message);
461 let response = self.geyser.subscribe_replay_info(request).await?;
462 Ok(response.into_inner())
463 }
464
465 pub async fn ping(&mut self, count: i32) -> Result<PongResponse, Status> {
466 let message = PingRequest { count };
467 let request = Request::new(message);
468 let response = self.geyser.ping(request).await?;
469 Ok(response.into_inner())
470 }
471
472 pub async fn get_latest_blockhash(
473 &mut self,
474 commitment: Option<CommitmentLevel>,
475 ) -> Result<GetLatestBlockhashResponse, Status> {
476 let request = Request::new(GetLatestBlockhashRequest {
477 commitment: commitment.map(|value| value as i32),
478 });
479 let response = self.geyser.get_latest_blockhash(request).await?;
480 Ok(response.into_inner())
481 }
482
483 pub async fn get_block_height(
484 &mut self,
485 commitment: Option<CommitmentLevel>,
486 ) -> Result<GetBlockHeightResponse, Status> {
487 let request = Request::new(GetBlockHeightRequest {
488 commitment: commitment.map(|value| value as i32),
489 });
490 let response = self.geyser.get_block_height(request).await?;
491 Ok(response.into_inner())
492 }
493
494 pub async fn get_slot(
495 &mut self,
496 commitment: Option<CommitmentLevel>,
497 ) -> Result<GetSlotResponse, Status> {
498 let request = Request::new(GetSlotRequest {
499 commitment: commitment.map(|value| value as i32),
500 });
501 let response = self.geyser.get_slot(request).await?;
502 Ok(response.into_inner())
503 }
504
505 pub async fn is_blockhash_valid(
506 &mut self,
507 blockhash: String,
508 commitment: Option<CommitmentLevel>,
509 ) -> Result<IsBlockhashValidResponse, Status> {
510 let request = Request::new(IsBlockhashValidRequest {
511 blockhash,
512 commitment: commitment.map(|value| value as i32),
513 });
514 let response = self.geyser.is_blockhash_valid(request).await?;
515 Ok(response.into_inner())
516 }
517
518 pub async fn get_version(&mut self) -> Result<GetVersionResponse, Status> {
519 let request = Request::new(GetVersionRequest {});
520 let response = self.geyser.get_version(request).await?;
521 Ok(response.into_inner())
522 }
523}
524
/// A message type that can be built directly from the bytes left in a tonic
/// decode buffer.
trait SubscribeMessage {
    /// Builds `Self` from `buf`. The signature is infallible, so
    /// implementations have no way to report a decoding error.
    fn decode(buf: &mut DecodeBuf<'_>) -> Self;
}
528
529impl SubscribeMessage for Vec<u8> {
530 fn decode(src: &mut DecodeBuf<'_>) -> Self {
531 let mut dst = Box::new_uninit_slice(src.remaining());
532 let mut start = 0;
533 while src.remaining() > 0 {
534 let chunk = src.chunk();
535 unsafe {
538 std::ptr::copy_nonoverlapping(
539 chunk.as_ptr(),
540 dst.as_mut_ptr().cast::<u8>().add(start),
541 chunk.len(),
542 );
543 }
544 start += chunk.len();
545 src.advance(chunk.len());
546 }
547 unsafe { dst.assume_init() }.into_vec()
549 }
550}
551
/// Zero-sized codec marker parameterised by the encoded type `T` and the
/// decoded type `U`.
pub struct SubscribeCodec<T, U> {
    // Ties the type parameters to the struct without storing any value.
    _pd: PhantomData<(T, U)>,
}

impl<T, U> Default for SubscribeCodec<T, U> {
    /// Creates the codec; written by hand so that neither `T` nor `U` has to
    /// implement `Default`.
    fn default() -> Self {
        Self {
            _pd: PhantomData::<(T, U)>,
        }
    }
}
561
562impl<T, U> Codec for SubscribeCodec<T, U>
563where
564 T: Message + Send + 'static,
565 U: SubscribeMessage + Default + Send + 'static,
566{
567 type Encode = T;
568 type Decode = U;
569
570 type Encoder = ProstEncoder<T>;
571 type Decoder = SubscribeDecoder<U>;
572
573 fn encoder(&mut self) -> Self::Encoder {
574 ProstEncoder(PhantomData)
575 }
576
577 fn decoder(&mut self) -> Self::Decoder {
578 SubscribeDecoder(PhantomData)
579 }
580}
581
582#[derive(Debug, Clone, Default)]
584pub struct ProstEncoder<T>(PhantomData<T>);
585
586impl<T: Message> Encoder for ProstEncoder<T> {
587 type Item = T;
588 type Error = Status;
589
590 fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
591 item.encode(buf)
592 .expect("Message only errors if not enough space");
593 Ok(())
594 }
595}
596
597#[derive(Debug, Clone, Default)]
599pub struct SubscribeDecoder<U>(PhantomData<U>);
600
601impl<U: SubscribeMessage + Default> Decoder for SubscribeDecoder<U> {
602 type Item = U;
603 type Error = Status;
604
605 fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
606 Ok(Some(SubscribeMessage::decode(buf)))
607 }
608}
609
pin_project! {
    /// Stream of raw response messages (`Vec<u8>`) received from a
    /// subscription call.
    pub struct GrpcClientStream {
        // Pinned so the wrapper can poll the inner stream through a projection.
        #[pin]
        stream: Streaming<Vec<u8>>,
    }
}
616
617impl GrpcClientStream {
618 pub const fn new(stream: Streaming<Vec<u8>>) -> Self {
619 Self { stream }
620 }
621
622 pub fn into_parsed(self) -> SubscribeStream {
623 SubscribeStream::new(self.boxed())
624 }
625}
626
627impl fmt::Debug for GrpcClientStream {
628 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
629 f.debug_struct("GrpcClientStream").finish()
630 }
631}
632
633impl Stream for GrpcClientStream {
634 type Item = Result<Vec<u8>, ReceiveError>;
635
636 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
637 let me = self.project();
638 me.stream.poll_next(cx).map_err(Into::into)
639 }
640}