// richat_shared/transports/grpc.rs

1use {
2    crate::{
3        config::{deserialize_humansize_usize, deserialize_x_tokens_set},
4        transports::{RecvError, RecvStream, Subscribe, SubscribeError},
5        version::Version,
6    },
7    futures::stream::{Stream, StreamExt},
8    prost::{Message, bytes::BufMut},
9    richat_proto::{
10        geyser::{GetVersionRequest, GetVersionResponse},
11        richat::GrpcSubscribeRequest,
12    },
13    serde::{
14        Deserialize,
15        de::{self, Deserializer},
16    },
17    std::{
18        borrow::Cow,
19        collections::HashSet,
20        fmt, fs,
21        future::Future,
22        marker::PhantomData,
23        net::{IpAddr, Ipv4Addr, SocketAddr},
24        pin::Pin,
25        sync::{
26            Arc,
27            atomic::{AtomicU64, Ordering},
28        },
29        task::{Context, Poll, ready},
30        time::Duration,
31    },
32    thiserror::Error,
33    tokio::task::JoinError,
34    tokio_util::sync::CancellationToken,
35    tonic::{
36        Request, Response, Status, Streaming,
37        codec::{Codec, CompressionEncoding, DecodeBuf, Decoder, EncodeBuf, Encoder},
38        service::interceptor::InterceptorLayer,
39        transport::{
40            Identity, ServerTlsConfig,
41            server::{Server, TcpIncoming},
42        },
43    },
44    tracing::{error, info},
45};
46
47pub mod geyser_gen {
48    #![allow(clippy::clone_on_ref_ptr)]
49    #![allow(clippy::missing_const_for_fn)]
50
51    include!(concat!(env!("OUT_DIR"), "/geyser.Geyser.rs"));
52}
53
54#[derive(Debug, Default, Clone, Deserialize)]
55#[serde(deny_unknown_fields, default)]
56pub struct ConfigGrpcCompression {
57    #[serde(deserialize_with = "ConfigGrpcCompression::deserialize_compression")]
58    pub accept: Vec<CompressionEncoding>,
59    #[serde(deserialize_with = "ConfigGrpcCompression::deserialize_compression")]
60    pub send: Vec<CompressionEncoding>,
61}
62
63impl ConfigGrpcCompression {
64    pub fn deserialize_compression<'de, D>(
65        deserializer: D,
66    ) -> Result<Vec<CompressionEncoding>, D::Error>
67    where
68        D: Deserializer<'de>,
69    {
70        Vec::<&str>::deserialize(deserializer)?
71            .into_iter()
72            .map(|value| match value {
73                "gzip" => Ok(CompressionEncoding::Gzip),
74                "zstd" => Ok(CompressionEncoding::Zstd),
75                value => Err(de::Error::custom(format!(
76                    "Unknown compression format: {value}"
77                ))),
78            })
79            .collect::<Result<_, _>>()
80    }
81}
82
83#[derive(Debug, Clone, Deserialize)]
84#[serde(deny_unknown_fields, default)]
85pub struct ConfigGrpcServer {
86    pub endpoint: SocketAddr,
87    #[serde(deserialize_with = "ConfigGrpcServer::deserialize_tls_config")]
88    pub tls_config: Option<ServerTlsConfig>,
89    pub compression: ConfigGrpcCompression,
90    /// Limits the maximum size of a decoded message, default is 4MiB
91    #[serde(deserialize_with = "deserialize_humansize_usize")]
92    pub max_decoding_message_size: usize,
93    #[serde(with = "humantime_serde")]
94    pub server_tcp_keepalive: Option<Duration>,
95    pub server_tcp_nodelay: bool,
96    pub server_http2_adaptive_window: Option<bool>,
97    #[serde(with = "humantime_serde")]
98    pub server_http2_keepalive_interval: Option<Duration>,
99    #[serde(with = "humantime_serde")]
100    pub server_http2_keepalive_timeout: Option<Duration>,
101    pub server_initial_connection_window_size: Option<u32>,
102    pub server_initial_stream_window_size: Option<u32>,
103    #[serde(deserialize_with = "deserialize_x_tokens_set")]
104    pub x_tokens: HashSet<Vec<u8>>,
105}
106
107impl Default for ConfigGrpcServer {
108    fn default() -> Self {
109        Self {
110            endpoint: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10100),
111            tls_config: None,
112            compression: ConfigGrpcCompression::default(),
113            max_decoding_message_size: 4 * 1024 * 1024, // 4MiB
114            server_tcp_keepalive: Some(Duration::from_secs(15)),
115            server_tcp_nodelay: true,
116            server_http2_adaptive_window: None,
117            server_http2_keepalive_interval: None,
118            server_http2_keepalive_timeout: None,
119            server_initial_connection_window_size: None,
120            server_initial_stream_window_size: None,
121            x_tokens: HashSet::new(),
122        }
123    }
124}
125
126impl ConfigGrpcServer {
127    pub fn deserialize_tls_config<'de, D>(
128        deserializer: D,
129    ) -> Result<Option<ServerTlsConfig>, D::Error>
130    where
131        D: Deserializer<'de>,
132    {
133        #[derive(Debug, Deserialize)]
134        #[serde(deny_unknown_fields)]
135        struct ConfigTls<'a> {
136            cert: &'a str,
137            key: &'a str,
138        }
139
140        Option::<ConfigTls>::deserialize(deserializer)?
141            .map(|config| {
142                let cert = fs::read(config.cert).map_err(|error| {
143                    de::Error::custom(format!("failed to read cert {}: {error:?}", config.cert))
144                })?;
145                let key = fs::read(config.key).map_err(|error| {
146                    de::Error::custom(format!("failed to read key {}: {error:?}", config.key))
147                })?;
148
149                Ok(ServerTlsConfig::new().identity(Identity::from_pem(cert, key)))
150            })
151            .transpose()
152    }
153
154    pub fn create_server_builder(&self) -> Result<(TcpIncoming, Server), CreateServerError> {
155        // Bind service address
156        let incoming = TcpIncoming::bind(self.endpoint)
157            .map_err(|error| CreateServerError::Bind {
158                error,
159                endpoint: self.endpoint,
160            })?
161            .with_nodelay(Some(self.server_tcp_nodelay))
162            .with_keepalive(self.server_tcp_keepalive);
163
164        // Create service
165        let mut server_builder = Server::builder();
166        if let Some(tls_config) = self.tls_config.clone() {
167            server_builder = server_builder.tls_config(tls_config)?;
168        }
169        if let Some(enabled) = self.server_http2_adaptive_window {
170            server_builder = server_builder.http2_adaptive_window(Some(enabled));
171        }
172        if let Some(http2_keepalive_interval) = self.server_http2_keepalive_interval {
173            server_builder =
174                server_builder.http2_keepalive_interval(Some(http2_keepalive_interval));
175        }
176        if let Some(http2_keepalive_timeout) = self.server_http2_keepalive_timeout {
177            server_builder = server_builder.http2_keepalive_timeout(Some(http2_keepalive_timeout));
178        }
179        if let Some(sz) = self.server_initial_connection_window_size {
180            server_builder = server_builder.initial_connection_window_size(sz);
181        }
182        if let Some(sz) = self.server_initial_stream_window_size {
183            server_builder = server_builder.initial_stream_window_size(sz);
184        }
185
186        Ok((incoming, server_builder))
187    }
188}
189
/// Errors returned by [`ConfigGrpcServer::create_server_builder`] (and therefore by
/// `GrpcServer::spawn`).
#[derive(Debug, Error)]
pub enum CreateServerError {
    /// Binding the TCP listener to `endpoint` failed.
    // NOTE(review): `error` is not marked `#[source]`, so this variant presumably reports no
    // `Error::source()`; the I/O error only appears in the Display text — confirm intent.
    #[error("failed to bind {endpoint}: {error}")]
    Bind {
        error: std::io::Error,
        endpoint: SocketAddr,
    },
    /// The server builder rejected the TLS configuration. Because of `#[from]`, any
    /// `tonic::transport::Error` propagated with `?` converts into this variant.
    #[error("failed to apply tls_config: {0}")]
    Tls(#[from] tonic::transport::Error),
}
200
201pub struct GrpcServer<S, F1, F2> {
202    messages: S,
203    subscribe_id: AtomicU64,
204    on_conn_new_cb: F1,
205    on_conn_drop_cb: F2,
206    version: Version<'static>,
207}
208
209impl<S, F1, F2> fmt::Debug for GrpcServer<S, F1, F2> {
210    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211        f.debug_struct("GrpcServer")
212            .field("subscribe_id", &self.subscribe_id)
213            .field("version", &self.version)
214            .finish()
215    }
216}
217
218impl<S, F1, F2> GrpcServer<S, F1, F2>
219where
220    S: Subscribe + Send + Sync + 'static,
221    F1: Fn() + Clone + Unpin + Send + Sync + 'static,
222    F2: Fn() + Clone + Unpin + Send + Sync + 'static,
223{
224    pub async fn spawn(
225        config: ConfigGrpcServer,
226        messages: S,
227        on_conn_new_cb: F1,
228        on_conn_drop_cb: F2,
229        version: Version<'static>,
230        shutdown: CancellationToken,
231    ) -> Result<impl Future<Output = Result<(), JoinError>>, CreateServerError> {
232        let (incoming, server_builder) = config.create_server_builder()?;
233        info!("start server at {}", config.endpoint);
234
235        let mut service = geyser_gen::geyser_server::GeyserServer::new(Self {
236            messages,
237            subscribe_id: AtomicU64::new(0),
238            on_conn_new_cb,
239            on_conn_drop_cb,
240            version,
241        })
242        .max_decoding_message_size(config.max_decoding_message_size);
243        for encoding in config.compression.accept {
244            service = service.accept_compressed(encoding);
245        }
246        for encoding in config.compression.send {
247            service = service.send_compressed(encoding);
248        }
249
250        // Spawn server
251        Ok(tokio::spawn(async move {
252            if let Err(error) = server_builder
253                .layer(InterceptorLayer::new(move |request: Request<()>| {
254                    if config.x_tokens.is_empty() {
255                        Ok(request)
256                    } else {
257                        match request.metadata().get("x-token") {
258                            Some(token) if config.x_tokens.contains(token.as_bytes()) => {
259                                Ok(request)
260                            }
261                            _ => Err(Status::unauthenticated("No valid auth token")),
262                        }
263                    }
264                }))
265                .add_service(service)
266                .serve_with_incoming_shutdown(incoming, shutdown.cancelled())
267                .await
268            {
269                error!("server error: {error:?}")
270            } else {
271                info!("shutdown")
272            }
273        }))
274    }
275}
276
277#[tonic::async_trait]
278impl<S, F1, F2> geyser_gen::geyser_server::Geyser for GrpcServer<S, F1, F2>
279where
280    S: Subscribe + Send + Sync + 'static,
281    F2: Fn() + Clone + Unpin + Send + Sync + 'static,
282    F1: Fn() + Clone + Unpin + Send + Sync + 'static,
283{
284    type SubscribeStream = ReceiverStream<F2>;
285
286    async fn subscribe(
287        &self,
288        mut request: Request<Streaming<GrpcSubscribeRequest>>,
289    ) -> Result<Response<Self::SubscribeStream>, Status> {
290        let id = self.subscribe_id.fetch_add(1, Ordering::Relaxed);
291        info!("#{id}: new connection from {:?}", request.remote_addr());
292
293        let (replay_from_slot, filter) = match request.get_mut().message().await {
294            Ok(Some(GrpcSubscribeRequest {
295                replay_from_slot,
296                filter,
297            })) => (replay_from_slot, filter),
298            Ok(None) => {
299                info!("#{id}: connection closed before receiving request");
300                return Err(Status::aborted("stream closed before request received"));
301            }
302            Err(error) => {
303                error!("#{id}: error receiving request {error}");
304                return Err(Status::aborted("recv error"));
305            }
306        };
307
308        match self.messages.subscribe(replay_from_slot, filter) {
309            Ok(rx) => {
310                let pos = replay_from_slot
311                    .map(|slot| format!("slot {slot}").into())
312                    .unwrap_or(Cow::Borrowed("latest"));
313                info!("#{id}: subscribed from {pos}");
314                Ok(Response::new(ReceiverStream::new(
315                    rx.boxed(),
316                    id,
317                    self.on_conn_new_cb.clone(),  // on new conn
318                    self.on_conn_drop_cb.clone(), // on drop conn
319                )))
320            }
321            Err(SubscribeError::NotInitialized) => Err(Status::internal("not initialized")),
322            Err(SubscribeError::SlotNotAvailable { first_available }) => Err(
323                Status::invalid_argument(format!("first available slot: {first_available}")),
324            ),
325        }
326    }
327
328    async fn get_version(
329        &self,
330        _request: Request<GetVersionRequest>,
331    ) -> Result<Response<GetVersionResponse>, Status> {
332        Ok(Response::new(GetVersionResponse {
333            version: self.version.create_grpc_version_info().json(),
334        }))
335    }
336}
337
338pub struct ReceiverStream<F2: Fn()> {
339    rx: RecvStream,
340    id: u64,
341    on_conn_drop_cb: F2,
342}
343
344impl<F2: Fn()> fmt::Debug for ReceiverStream<F2> {
345    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
346        f.debug_struct("ReceiverStream").finish()
347    }
348}
349
350impl<F2: Fn()> ReceiverStream<F2> {
351    fn new<F1: Fn()>(rx: RecvStream, id: u64, on_conn_new_cb: F1, on_conn_drop_cb: F2) -> Self {
352        on_conn_new_cb();
353        Self {
354            rx,
355            id,
356            on_conn_drop_cb,
357        }
358    }
359}
360
361impl<F2: Fn()> Drop for ReceiverStream<F2> {
362    fn drop(&mut self) {
363        info!("#{}: send stream closed", self.id);
364        (self.on_conn_drop_cb)();
365    }
366}
367
368impl<F2: Fn() + Unpin> Stream for ReceiverStream<F2> {
369    type Item = Result<Arc<Vec<u8>>, Status>;
370
371    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
372        match ready!(self.rx.poll_next_unpin(cx)) {
373            Some(Ok(value)) => Poll::Ready(Some(Ok(value))),
374            Some(Err(error)) => {
375                error!("#{}: failed to get message: {error}", self.id);
376                match error {
377                    RecvError::Lagged => Poll::Ready(Some(Err(Status::out_of_range("lagged")))),
378                    RecvError::Closed => Poll::Ready(Some(Err(Status::out_of_range("closed")))),
379                }
380            }
381            None => Poll::Ready(None),
382        }
383    }
384}
385
386trait SubscribeMessage {
387    fn encode(self, buf: &mut EncodeBuf<'_>);
388}
389
390impl SubscribeMessage for &[u8] {
391    fn encode(self, buf: &mut EncodeBuf<'_>) {
392        let required = self.len();
393        let remaining = buf.remaining_mut();
394        if required > remaining {
395            panic!("SubscribeMessage only errors if not enough space");
396        }
397        buf.put_slice(self.as_ref());
398    }
399}
400
401impl SubscribeMessage for Vec<u8> {
402    fn encode(self, buf: &mut EncodeBuf<'_>) {
403        self.as_slice().encode(buf);
404    }
405}
406
407impl SubscribeMessage for Arc<Vec<u8>> {
408    fn encode(self, buf: &mut EncodeBuf<'_>) {
409        self.as_slice().encode(buf);
410    }
411}
412
/// A tonic [`Codec`] that writes outgoing `T` values as raw bytes and decodes incoming `U`
/// values with prost. It carries no state.
pub struct SubscribeCodec<T, U> {
    _pd: PhantomData<(T, U)>,
}

impl<T, U> Default for SubscribeCodec<T, U> {
    fn default() -> Self {
        let _pd = PhantomData;
        SubscribeCodec { _pd }
    }
}
422
423impl<T, U> Codec for SubscribeCodec<T, U>
424where
425    T: SubscribeMessage + Send + 'static,
426    U: Message + Default + Send + 'static,
427{
428    type Encode = T;
429    type Decode = U;
430
431    type Encoder = SubscribeEncoder<T>;
432    type Decoder = ProstDecoder<U>;
433
434    fn encoder(&mut self) -> Self::Encoder {
435        SubscribeEncoder(PhantomData)
436    }
437
438    fn decoder(&mut self) -> Self::Decoder {
439        ProstDecoder(PhantomData)
440    }
441}
442
443/// A [`Encoder`] that knows how to encode `T`.
444#[derive(Debug, Clone, Default)]
445pub struct SubscribeEncoder<T>(PhantomData<T>);
446
447impl<T: SubscribeMessage> Encoder for SubscribeEncoder<T> {
448    type Item = T;
449    type Error = Status;
450
451    fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
452        item.encode(buf);
453        Ok(())
454    }
455}
456
457/// A [`Decoder`] that knows how to decode `U`.
458#[derive(Debug, Clone, Default)]
459pub struct ProstDecoder<U>(PhantomData<U>);
460
461impl<U: Message + Default> Decoder for ProstDecoder<U> {
462    type Item = U;
463    type Error = Status;
464
465    fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
466        let item = Message::decode(buf)
467            .map(Option::Some)
468            .map_err(from_decode_error)?;
469
470        Ok(item)
471    }
472}
473
474fn from_decode_error(error: prost::DecodeError) -> Status {
475    // Map Protobuf parse errors to an INTERNAL status code, as per
476    // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
477    Status::new(tonic::Code::Internal, error.to_string())
478}