// richat_shared/transports/grpc.rs

1use {
2    crate::{
3        config::{deserialize_num_str, deserialize_x_token_set},
4        shutdown::Shutdown,
5        transports::{RecvError, RecvStream, Subscribe, SubscribeError},
6        version::Version,
7    },
8    futures::stream::{Stream, StreamExt},
9    prost::{bytes::BufMut, Message},
10    richat_proto::{
11        geyser::{GetVersionRequest, GetVersionResponse},
12        richat::GrpcSubscribeRequest,
13    },
14    serde::{
15        de::{self, Deserializer},
16        Deserialize,
17    },
18    std::{
19        borrow::Cow,
20        collections::HashSet,
21        fmt, fs,
22        future::Future,
23        marker::PhantomData,
24        net::{IpAddr, Ipv4Addr, SocketAddr},
25        pin::Pin,
26        sync::{
27            atomic::{AtomicU64, Ordering},
28            Arc,
29        },
30        task::{ready, Context, Poll},
31        time::Duration,
32    },
33    thiserror::Error,
34    tokio::task::JoinError,
35    tonic::{
36        codec::{Codec, CompressionEncoding, DecodeBuf, Decoder, EncodeBuf, Encoder},
37        service::interceptor::interceptor,
38        transport::{
39            server::{Server, TcpIncoming},
40            Identity, ServerTlsConfig,
41        },
42        Request, Response, Status, Streaming,
43    },
44    tracing::{error, info},
45};
46
47pub mod gen {
48    #![allow(clippy::clone_on_ref_ptr)]
49    #![allow(clippy::missing_const_for_fn)]
50
51    include!(concat!(env!("OUT_DIR"), "/geyser.Geyser.rs"));
52}
53
54#[derive(Debug, Clone, Deserialize)]
55#[serde(deny_unknown_fields, default)]
56pub struct ConfigGrpcCompression {
57    #[serde(deserialize_with = "ConfigGrpcCompression::deserialize_compression")]
58    pub accept: Vec<CompressionEncoding>,
59    #[serde(deserialize_with = "ConfigGrpcCompression::deserialize_compression")]
60    pub send: Vec<CompressionEncoding>,
61}
62
63impl Default for ConfigGrpcCompression {
64    fn default() -> Self {
65        Self {
66            accept: Self::default_compression(),
67            send: Self::default_compression(),
68        }
69    }
70}
71
72impl ConfigGrpcCompression {
73    pub fn deserialize_compression<'de, D>(
74        deserializer: D,
75    ) -> Result<Vec<CompressionEncoding>, D::Error>
76    where
77        D: Deserializer<'de>,
78    {
79        Vec::<&str>::deserialize(deserializer)?
80            .into_iter()
81            .map(|value| match value {
82                "gzip" => Ok(CompressionEncoding::Gzip),
83                "zstd" => Ok(CompressionEncoding::Zstd),
84                value => Err(de::Error::custom(format!(
85                    "Unknown compression format: {value}"
86                ))),
87            })
88            .collect::<Result<_, _>>()
89    }
90
91    const fn default_compression() -> Vec<CompressionEncoding> {
92        vec![]
93    }
94}
95
96#[derive(Debug, Clone, Deserialize)]
97#[serde(deny_unknown_fields, default)]
98pub struct ConfigGrpcServer {
99    pub endpoint: SocketAddr,
100    #[serde(deserialize_with = "ConfigGrpcServer::deserialize_tls_config")]
101    pub tls_config: Option<ServerTlsConfig>,
102    pub compression: ConfigGrpcCompression,
103    /// Limits the maximum size of a decoded message, default is 4MiB
104    #[serde(deserialize_with = "deserialize_num_str")]
105    pub max_decoding_message_size: usize,
106    #[serde(with = "humantime_serde")]
107    pub server_tcp_keepalive: Option<Duration>,
108    pub server_tcp_nodelay: bool,
109    pub server_http2_adaptive_window: Option<bool>,
110    #[serde(with = "humantime_serde")]
111    pub server_http2_keepalive_interval: Option<Duration>,
112    #[serde(with = "humantime_serde")]
113    pub server_http2_keepalive_timeout: Option<Duration>,
114    pub server_initial_connection_window_size: Option<u32>,
115    pub server_initial_stream_window_size: Option<u32>,
116    #[serde(deserialize_with = "deserialize_x_token_set")]
117    pub x_tokens: HashSet<Vec<u8>>,
118}
119
120impl Default for ConfigGrpcServer {
121    fn default() -> Self {
122        Self {
123            endpoint: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10102),
124            tls_config: None,
125            compression: ConfigGrpcCompression::default(),
126            max_decoding_message_size: 4 * 1024 * 1024, // 4MiB
127            server_tcp_keepalive: Some(Duration::from_secs(15)),
128            server_tcp_nodelay: true,
129            server_http2_adaptive_window: None,
130            server_http2_keepalive_interval: None,
131            server_http2_keepalive_timeout: None,
132            server_initial_connection_window_size: None,
133            server_initial_stream_window_size: None,
134            x_tokens: HashSet::new(),
135        }
136    }
137}
138
139impl ConfigGrpcServer {
140    pub fn deserialize_tls_config<'de, D>(
141        deserializer: D,
142    ) -> Result<Option<ServerTlsConfig>, D::Error>
143    where
144        D: Deserializer<'de>,
145    {
146        #[derive(Debug, Deserialize)]
147        #[serde(deny_unknown_fields)]
148        struct ConfigTls<'a> {
149            cert: &'a str,
150            key: &'a str,
151        }
152
153        Option::<ConfigTls>::deserialize(deserializer)?
154            .map(|config| {
155                let cert = fs::read(config.cert).map_err(|error| {
156                    de::Error::custom(format!("failed to read cert {}: {error:?}", config.cert))
157                })?;
158                let key = fs::read(config.key).map_err(|error| {
159                    de::Error::custom(format!("failed to read key {}: {error:?}", config.key))
160                })?;
161
162                Ok(ServerTlsConfig::new().identity(Identity::from_pem(cert, key)))
163            })
164            .transpose()
165    }
166
167    pub fn create_server_builder(&self) -> Result<(TcpIncoming, Server), CreateServerError> {
168        // Bind service address
169        let incoming = TcpIncoming::new(
170            self.endpoint,
171            self.server_tcp_nodelay,
172            self.server_tcp_keepalive,
173        )
174        .map_err(|error| CreateServerError::Bind {
175            error,
176            endpoint: self.endpoint,
177        })?;
178
179        // Create service
180        let mut server_builder = Server::builder();
181        if let Some(tls_config) = self.tls_config.clone() {
182            server_builder = server_builder.tls_config(tls_config)?;
183        }
184        if let Some(enabled) = self.server_http2_adaptive_window {
185            server_builder = server_builder.http2_adaptive_window(Some(enabled));
186        }
187        if let Some(http2_keepalive_interval) = self.server_http2_keepalive_interval {
188            server_builder =
189                server_builder.http2_keepalive_interval(Some(http2_keepalive_interval));
190        }
191        if let Some(http2_keepalive_timeout) = self.server_http2_keepalive_timeout {
192            server_builder = server_builder.http2_keepalive_timeout(Some(http2_keepalive_timeout));
193        }
194        if let Some(sz) = self.server_initial_connection_window_size {
195            server_builder = server_builder.initial_connection_window_size(sz);
196        }
197        if let Some(sz) = self.server_initial_stream_window_size {
198            server_builder = server_builder.initial_stream_window_size(sz);
199        }
200
201        Ok((incoming, server_builder))
202    }
203}
204
205#[derive(Debug, Error)]
206pub enum CreateServerError {
207    #[error("failed to bind {endpoint}: {error}")]
208    Bind {
209        error: Box<dyn std::error::Error + Send + Sync>,
210        endpoint: SocketAddr,
211    },
212    #[error("failed to apply tls_config: {0}")]
213    Tls(#[from] tonic::transport::Error),
214}
215
216pub struct GrpcServer<S, F1, F2> {
217    messages: S,
218    subscribe_id: AtomicU64,
219    on_conn_new_cb: F1,
220    on_conn_drop_cb: F2,
221    version: Version<'static>,
222}
223
224impl<S, F1, F2> fmt::Debug for GrpcServer<S, F1, F2> {
225    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226        f.debug_struct("GrpcServer")
227            .field("subscribe_id", &self.subscribe_id)
228            .field("version", &self.version)
229            .finish()
230    }
231}
232
233impl<S, F1, F2> GrpcServer<S, F1, F2>
234where
235    S: Subscribe + Send + Sync + 'static,
236    F1: Fn() + Clone + Unpin + Send + Sync + 'static,
237    F2: Fn() + Clone + Unpin + Send + Sync + 'static,
238{
239    pub async fn spawn(
240        config: ConfigGrpcServer,
241        messages: S,
242        on_conn_new_cb: F1,
243        on_conn_drop_cb: F2,
244        version: Version<'static>,
245        shutdown: Shutdown,
246    ) -> Result<impl Future<Output = Result<(), JoinError>>, CreateServerError> {
247        let (incoming, server_builder) = config.create_server_builder()?;
248        info!("start server at {}", config.endpoint);
249
250        let mut service = gen::geyser_server::GeyserServer::new(Self {
251            messages,
252            subscribe_id: AtomicU64::new(0),
253            on_conn_new_cb,
254            on_conn_drop_cb,
255            version,
256        })
257        .max_decoding_message_size(config.max_decoding_message_size);
258        for encoding in config.compression.accept {
259            service = service.accept_compressed(encoding);
260        }
261        for encoding in config.compression.send {
262            service = service.send_compressed(encoding);
263        }
264
265        // Spawn server
266        Ok(tokio::spawn(async move {
267            if let Err(error) = server_builder
268                .layer(interceptor(move |request: Request<()>| {
269                    if config.x_tokens.is_empty() {
270                        Ok(request)
271                    } else {
272                        match request.metadata().get("x-token") {
273                            Some(token) if config.x_tokens.contains(token.as_bytes()) => {
274                                Ok(request)
275                            }
276                            _ => Err(Status::unauthenticated("No valid auth token")),
277                        }
278                    }
279                }))
280                .add_service(service)
281                .serve_with_incoming_shutdown(incoming, shutdown)
282                .await
283            {
284                error!("server error: {error:?}")
285            } else {
286                info!("shutdown")
287            }
288        }))
289    }
290}
291
292#[tonic::async_trait]
293impl<S, F1, F2> gen::geyser_server::Geyser for GrpcServer<S, F1, F2>
294where
295    S: Subscribe + Send + Sync + 'static,
296    F2: Fn() + Clone + Unpin + Send + Sync + 'static,
297    F1: Fn() + Clone + Unpin + Send + Sync + 'static,
298{
299    type SubscribeStream = ReceiverStream<F2>;
300
301    async fn subscribe(
302        &self,
303        mut request: Request<Streaming<GrpcSubscribeRequest>>,
304    ) -> Result<Response<Self::SubscribeStream>, Status> {
305        let id = self.subscribe_id.fetch_add(1, Ordering::Relaxed);
306        info!("#{id}: new connection from {:?}", request.remote_addr());
307
308        let (replay_from_slot, filter) = match request.get_mut().message().await {
309            Ok(Some(GrpcSubscribeRequest {
310                replay_from_slot,
311                filter,
312            })) => (replay_from_slot, filter),
313            Ok(None) => {
314                info!("#{id}: connection closed before receiving request");
315                return Err(Status::aborted("stream closed before request received"));
316            }
317            Err(error) => {
318                error!("#{id}: error receiving request {error}");
319                return Err(Status::aborted("recv error"));
320            }
321        };
322
323        match self.messages.subscribe(replay_from_slot, filter) {
324            Ok(rx) => {
325                let pos = replay_from_slot
326                    .map(|slot| format!("slot {slot}").into())
327                    .unwrap_or(Cow::Borrowed("latest"));
328                info!("#{id}: subscribed from {pos}");
329                Ok(Response::new(ReceiverStream::new(
330                    rx.boxed(),
331                    id,
332                    self.on_conn_new_cb.clone(),  // on new conn
333                    self.on_conn_drop_cb.clone(), // on drop conn
334                )))
335            }
336            Err(SubscribeError::NotInitialized) => Err(Status::internal("not initialized")),
337            Err(SubscribeError::SlotNotAvailable { first_available }) => Err(
338                Status::invalid_argument(format!("first available slot: {first_available}")),
339            ),
340        }
341    }
342
343    async fn get_version(
344        &self,
345        _request: Request<GetVersionRequest>,
346    ) -> Result<Response<GetVersionResponse>, Status> {
347        Ok(Response::new(GetVersionResponse {
348            version: self.version.create_grpc_version_info().json(),
349        }))
350    }
351}
352
353pub struct ReceiverStream<F2: Fn()> {
354    rx: RecvStream,
355    id: u64,
356    on_conn_drop_cb: F2,
357}
358
359impl<F2: Fn()> fmt::Debug for ReceiverStream<F2> {
360    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361        f.debug_struct("ReceiverStream").finish()
362    }
363}
364
365impl<F2: Fn()> ReceiverStream<F2> {
366    fn new<F1: Fn()>(rx: RecvStream, id: u64, on_conn_new_cb: F1, on_conn_drop_cb: F2) -> Self {
367        on_conn_new_cb();
368        Self {
369            rx,
370            id,
371            on_conn_drop_cb,
372        }
373    }
374}
375
376impl<F2: Fn()> Drop for ReceiverStream<F2> {
377    fn drop(&mut self) {
378        info!("#{}: send stream closed", self.id);
379        (self.on_conn_drop_cb)();
380    }
381}
382
383impl<F2: Fn() + Unpin> Stream for ReceiverStream<F2> {
384    type Item = Result<Arc<Vec<u8>>, Status>;
385
386    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
387        match ready!(self.rx.poll_next_unpin(cx)) {
388            Some(Ok(value)) => Poll::Ready(Some(Ok(value))),
389            Some(Err(error)) => {
390                error!("#{}: failed to get message: {error}", self.id);
391                match error {
392                    RecvError::Lagged => Poll::Ready(Some(Err(Status::out_of_range("lagged")))),
393                    RecvError::Closed => Poll::Ready(Some(Err(Status::out_of_range("closed")))),
394                }
395            }
396            None => Poll::Ready(None),
397        }
398    }
399}
400
401trait SubscribeMessage {
402    fn encode(self, buf: &mut EncodeBuf<'_>);
403}
404
405impl SubscribeMessage for &[u8] {
406    fn encode(self, buf: &mut EncodeBuf<'_>) {
407        let required = self.len();
408        let remaining = buf.remaining_mut();
409        if required > remaining {
410            panic!("SubscribeMessage only errors if not enough space");
411        }
412        buf.put_slice(self.as_ref());
413    }
414}
415
416impl SubscribeMessage for Vec<u8> {
417    fn encode(self, buf: &mut EncodeBuf<'_>) {
418        self.as_slice().encode(buf);
419    }
420}
421
422impl SubscribeMessage for Arc<Vec<u8>> {
423    fn encode(self, buf: &mut EncodeBuf<'_>) {
424        self.as_slice().encode(buf);
425    }
426}
427
/// tonic codec pairing `SubscribeEncoder<T>` (outgoing) with `ProstDecoder<U>` (incoming).
pub struct SubscribeCodec<T, U> {
    _pd: PhantomData<(T, U)>,
}

// Written by hand: `#[derive(Default)]` would add unnecessary
// `T: Default` and `U: Default` bounds.
impl<T, U> Default for SubscribeCodec<T, U> {
    fn default() -> Self {
        let _pd = PhantomData;
        Self { _pd }
    }
}
437
438impl<T, U> Codec for SubscribeCodec<T, U>
439where
440    T: SubscribeMessage + Send + 'static,
441    U: Message + Default + Send + 'static,
442{
443    type Encode = T;
444    type Decode = U;
445
446    type Encoder = SubscribeEncoder<T>;
447    type Decoder = ProstDecoder<U>;
448
449    fn encoder(&mut self) -> Self::Encoder {
450        SubscribeEncoder(PhantomData)
451    }
452
453    fn decoder(&mut self) -> Self::Decoder {
454        ProstDecoder(PhantomData)
455    }
456}
457
/// An `Encoder` that writes `T`'s bytes straight into the output buffer.
///
/// No protobuf encoding happens here; `T` presumably already holds a serialized
/// message (see the `SubscribeMessage` impls).
#[derive(Debug, Clone, Default)]
pub struct SubscribeEncoder<T>(PhantomData<T>);
461
462impl<T: SubscribeMessage> Encoder for SubscribeEncoder<T> {
463    type Item = T;
464    type Error = Status;
465
466    fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
467        item.encode(buf);
468        Ok(())
469    }
470}
471
/// A `Decoder` that parses incoming bytes into the prost message `U`.
#[derive(Debug, Clone, Default)]
pub struct ProstDecoder<U>(PhantomData<U>);
475
476impl<U: Message + Default> Decoder for ProstDecoder<U> {
477    type Item = U;
478    type Error = Status;
479
480    fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
481        let item = Message::decode(buf)
482            .map(Option::Some)
483            .map_err(from_decode_error)?;
484
485        Ok(item)
486    }
487}
488
489fn from_decode_error(error: prost::DecodeError) -> Status {
490    // Map Protobuf parse errors to an INTERNAL status code, as per
491    // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
492    Status::new(tonic::Code::Internal, error.to_string())
493}