1use {
2 crate::{
3 config::{deserialize_num_str, deserialize_x_token_set},
4 shutdown::Shutdown,
5 transports::{RecvError, RecvStream, Subscribe, SubscribeError},
6 version::Version,
7 },
8 futures::stream::{Stream, StreamExt},
9 prost::{bytes::BufMut, Message},
10 richat_proto::{
11 geyser::{GetVersionRequest, GetVersionResponse},
12 richat::GrpcSubscribeRequest,
13 },
14 serde::{
15 de::{self, Deserializer},
16 Deserialize,
17 },
18 std::{
19 borrow::Cow,
20 collections::HashSet,
21 fmt, fs,
22 future::Future,
23 marker::PhantomData,
24 net::{IpAddr, Ipv4Addr, SocketAddr},
25 pin::Pin,
26 sync::{
27 atomic::{AtomicU64, Ordering},
28 Arc,
29 },
30 task::{ready, Context, Poll},
31 time::Duration,
32 },
33 thiserror::Error,
34 tokio::task::JoinError,
35 tonic::{
36 codec::{Codec, CompressionEncoding, DecodeBuf, Decoder, EncodeBuf, Encoder},
37 service::interceptor::interceptor,
38 transport::{
39 server::{Server, TcpIncoming},
40 Identity, ServerTlsConfig,
41 },
42 Request, Response, Status, Streaming,
43 },
44 tracing::{error, info},
45};
46
47pub mod gen {
48 #![allow(clippy::clone_on_ref_ptr)]
49 #![allow(clippy::missing_const_for_fn)]
50
51 include!(concat!(env!("OUT_DIR"), "/geyser.Geyser.rs"));
52}
53
/// Compression settings for the Geyser gRPC service.
///
/// Both lists are empty by default (no compression). Missing fields take their defaults and
/// unknown fields are rejected.
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields, default)]
pub struct ConfigGrpcCompression {
    /// Encodings the service accepts from clients. Each entry is registered with
    /// `accept_compressed`. Parsed from names (`"gzip"`, `"zstd"`).
    #[serde(deserialize_with = "ConfigGrpcCompression::deserialize_compression")]
    pub accept: Vec<CompressionEncoding>,
    /// Encodings the service may use for responses. Each entry is registered with
    /// `send_compressed`. Parsed from names (`"gzip"`, `"zstd"`).
    #[serde(deserialize_with = "ConfigGrpcCompression::deserialize_compression")]
    pub send: Vec<CompressionEncoding>,
}
62
63impl Default for ConfigGrpcCompression {
64 fn default() -> Self {
65 Self {
66 accept: Self::default_compression(),
67 send: Self::default_compression(),
68 }
69 }
70}
71
72impl ConfigGrpcCompression {
73 pub fn deserialize_compression<'de, D>(
74 deserializer: D,
75 ) -> Result<Vec<CompressionEncoding>, D::Error>
76 where
77 D: Deserializer<'de>,
78 {
79 Vec::<&str>::deserialize(deserializer)?
80 .into_iter()
81 .map(|value| match value {
82 "gzip" => Ok(CompressionEncoding::Gzip),
83 "zstd" => Ok(CompressionEncoding::Zstd),
84 value => Err(de::Error::custom(format!(
85 "Unknown compression format: {value}"
86 ))),
87 })
88 .collect::<Result<_, _>>()
89 }
90
91 const fn default_compression() -> Vec<CompressionEncoding> {
92 vec![]
93 }
94}
95
/// Configuration of the Geyser gRPC server.
///
/// Covers the listener, TLS, compression, HTTP/2 tuning and `x-token` authentication. Missing
/// fields fall back to `ConfigGrpcServer::default()`, and unknown fields are rejected.
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields, default)]
pub struct ConfigGrpcServer {
    /// Address the TCP listener binds to (default `127.0.0.1:10102`).
    pub endpoint: SocketAddr,
    /// Optional TLS identity. Built while deserializing by reading the `cert` and `key` PEM file
    /// paths.
    #[serde(deserialize_with = "ConfigGrpcServer::deserialize_tls_config")]
    pub tls_config: Option<ServerTlsConfig>,
    /// Compression encodings accepted from and sent to clients.
    pub compression: ConfigGrpcCompression,
    /// Limit passed to the service's `max_decoding_message_size` (default `4 * 1024 * 1024`).
    /// NOTE(review): `deserialize_num_str` lives in `crate::config`; presumably it accepts
    /// either a number or a numeric string. Confirm there.
    #[serde(deserialize_with = "deserialize_num_str")]
    pub max_decoding_message_size: usize,
    /// TCP keepalive passed to `TcpIncoming::new` (default 15s). Parsed with `humantime_serde`.
    #[serde(with = "humantime_serde")]
    pub server_tcp_keepalive: Option<Duration>,
    /// `nodelay` flag passed to `TcpIncoming::new` (default `true`).
    pub server_tcp_nodelay: bool,
    /// HTTP/2 adaptive window flag. Applied to the server builder only when set.
    pub server_http2_adaptive_window: Option<bool>,
    /// HTTP/2 keepalive interval. Applied only when set. Parsed with `humantime_serde`.
    #[serde(with = "humantime_serde")]
    pub server_http2_keepalive_interval: Option<Duration>,
    /// HTTP/2 keepalive timeout. Applied only when set. Parsed with `humantime_serde`.
    #[serde(with = "humantime_serde")]
    pub server_http2_keepalive_timeout: Option<Duration>,
    /// HTTP/2 initial connection window size. Applied only when set.
    pub server_initial_connection_window_size: Option<u32>,
    /// HTTP/2 initial stream window size. Applied only when set.
    pub server_initial_stream_window_size: Option<u32>,
    /// Accepted `x-token` values.
    ///
    /// When the set is empty, authentication is disabled. Otherwise every request must carry
    /// an `x-token` metadata value contained in this set, or it is rejected as unauthenticated.
    #[serde(deserialize_with = "deserialize_x_token_set")]
    pub x_tokens: HashSet<Vec<u8>>,
}
119
120impl Default for ConfigGrpcServer {
121 fn default() -> Self {
122 Self {
123 endpoint: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10102),
124 tls_config: None,
125 compression: ConfigGrpcCompression::default(),
126 max_decoding_message_size: 4 * 1024 * 1024, server_tcp_keepalive: Some(Duration::from_secs(15)),
128 server_tcp_nodelay: true,
129 server_http2_adaptive_window: None,
130 server_http2_keepalive_interval: None,
131 server_http2_keepalive_timeout: None,
132 server_initial_connection_window_size: None,
133 server_initial_stream_window_size: None,
134 x_tokens: HashSet::new(),
135 }
136 }
137}
138
139impl ConfigGrpcServer {
140 pub fn deserialize_tls_config<'de, D>(
141 deserializer: D,
142 ) -> Result<Option<ServerTlsConfig>, D::Error>
143 where
144 D: Deserializer<'de>,
145 {
146 #[derive(Debug, Deserialize)]
147 #[serde(deny_unknown_fields)]
148 struct ConfigTls<'a> {
149 cert: &'a str,
150 key: &'a str,
151 }
152
153 Option::<ConfigTls>::deserialize(deserializer)?
154 .map(|config| {
155 let cert = fs::read(config.cert).map_err(|error| {
156 de::Error::custom(format!("failed to read cert {}: {error:?}", config.cert))
157 })?;
158 let key = fs::read(config.key).map_err(|error| {
159 de::Error::custom(format!("failed to read key {}: {error:?}", config.key))
160 })?;
161
162 Ok(ServerTlsConfig::new().identity(Identity::from_pem(cert, key)))
163 })
164 .transpose()
165 }
166
167 pub fn create_server_builder(&self) -> Result<(TcpIncoming, Server), CreateServerError> {
168 let incoming = TcpIncoming::new(
170 self.endpoint,
171 self.server_tcp_nodelay,
172 self.server_tcp_keepalive,
173 )
174 .map_err(|error| CreateServerError::Bind {
175 error,
176 endpoint: self.endpoint,
177 })?;
178
179 let mut server_builder = Server::builder();
181 if let Some(tls_config) = self.tls_config.clone() {
182 server_builder = server_builder.tls_config(tls_config)?;
183 }
184 if let Some(enabled) = self.server_http2_adaptive_window {
185 server_builder = server_builder.http2_adaptive_window(Some(enabled));
186 }
187 if let Some(http2_keepalive_interval) = self.server_http2_keepalive_interval {
188 server_builder =
189 server_builder.http2_keepalive_interval(Some(http2_keepalive_interval));
190 }
191 if let Some(http2_keepalive_timeout) = self.server_http2_keepalive_timeout {
192 server_builder = server_builder.http2_keepalive_timeout(Some(http2_keepalive_timeout));
193 }
194 if let Some(sz) = self.server_initial_connection_window_size {
195 server_builder = server_builder.initial_connection_window_size(sz);
196 }
197 if let Some(sz) = self.server_initial_stream_window_size {
198 server_builder = server_builder.initial_stream_window_size(sz);
199 }
200
201 Ok((incoming, server_builder))
202 }
203}
204
/// Errors returned by `ConfigGrpcServer::create_server_builder` (and propagated by
/// `GrpcServer::spawn`).
#[derive(Debug, Error)]
pub enum CreateServerError {
    /// Creating the TCP listener on `endpoint` failed.
    #[error("failed to bind {endpoint}: {error}")]
    Bind {
        error: Box<dyn std::error::Error + Send + Sync>,
        endpoint: SocketAddr,
    },
    /// The server builder rejected the TLS configuration.
    ///
    /// Because of `#[from]`, any `tonic::transport::Error` propagated with `?` in a function
    /// returning this type becomes this variant.
    #[error("failed to apply tls_config: {0}")]
    Tls(#[from] tonic::transport::Error),
}
215
/// Implementation of the generated `Geyser` gRPC service.
///
/// `F1` and `F2` are connection callbacks, fired when a subscription stream is created and when
/// it is dropped.
pub struct GrpcServer<S, F1, F2> {
    // Message source; `subscribe` is called once per client subscription.
    messages: S,
    // Counter incremented once per `subscribe` call; the value tags that connection's logs.
    subscribe_id: AtomicU64,
    // Passed to each `ReceiverStream::new`, which calls it immediately.
    on_conn_new_cb: F1,
    // Passed to each `ReceiverStream`, which calls it when dropped.
    on_conn_drop_cb: F2,
    // Source of the JSON returned by `get_version`.
    version: Version<'static>,
}
223
224impl<S, F1, F2> fmt::Debug for GrpcServer<S, F1, F2> {
225 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226 f.debug_struct("GrpcServer")
227 .field("subscribe_id", &self.subscribe_id)
228 .field("version", &self.version)
229 .finish()
230 }
231}
232
233impl<S, F1, F2> GrpcServer<S, F1, F2>
234where
235 S: Subscribe + Send + Sync + 'static,
236 F1: Fn() + Clone + Unpin + Send + Sync + 'static,
237 F2: Fn() + Clone + Unpin + Send + Sync + 'static,
238{
239 pub async fn spawn(
240 config: ConfigGrpcServer,
241 messages: S,
242 on_conn_new_cb: F1,
243 on_conn_drop_cb: F2,
244 version: Version<'static>,
245 shutdown: Shutdown,
246 ) -> Result<impl Future<Output = Result<(), JoinError>>, CreateServerError> {
247 let (incoming, server_builder) = config.create_server_builder()?;
248 info!("start server at {}", config.endpoint);
249
250 let mut service = gen::geyser_server::GeyserServer::new(Self {
251 messages,
252 subscribe_id: AtomicU64::new(0),
253 on_conn_new_cb,
254 on_conn_drop_cb,
255 version,
256 })
257 .max_decoding_message_size(config.max_decoding_message_size);
258 for encoding in config.compression.accept {
259 service = service.accept_compressed(encoding);
260 }
261 for encoding in config.compression.send {
262 service = service.send_compressed(encoding);
263 }
264
265 Ok(tokio::spawn(async move {
267 if let Err(error) = server_builder
268 .layer(interceptor(move |request: Request<()>| {
269 if config.x_tokens.is_empty() {
270 Ok(request)
271 } else {
272 match request.metadata().get("x-token") {
273 Some(token) if config.x_tokens.contains(token.as_bytes()) => {
274 Ok(request)
275 }
276 _ => Err(Status::unauthenticated("No valid auth token")),
277 }
278 }
279 }))
280 .add_service(service)
281 .serve_with_incoming_shutdown(incoming, shutdown)
282 .await
283 {
284 error!("server error: {error:?}")
285 } else {
286 info!("shutdown")
287 }
288 }))
289 }
290}
291
292#[tonic::async_trait]
293impl<S, F1, F2> gen::geyser_server::Geyser for GrpcServer<S, F1, F2>
294where
295 S: Subscribe + Send + Sync + 'static,
296 F2: Fn() + Clone + Unpin + Send + Sync + 'static,
297 F1: Fn() + Clone + Unpin + Send + Sync + 'static,
298{
299 type SubscribeStream = ReceiverStream<F2>;
300
301 async fn subscribe(
302 &self,
303 mut request: Request<Streaming<GrpcSubscribeRequest>>,
304 ) -> Result<Response<Self::SubscribeStream>, Status> {
305 let id = self.subscribe_id.fetch_add(1, Ordering::Relaxed);
306 info!("#{id}: new connection from {:?}", request.remote_addr());
307
308 let (replay_from_slot, filter) = match request.get_mut().message().await {
309 Ok(Some(GrpcSubscribeRequest {
310 replay_from_slot,
311 filter,
312 })) => (replay_from_slot, filter),
313 Ok(None) => {
314 info!("#{id}: connection closed before receiving request");
315 return Err(Status::aborted("stream closed before request received"));
316 }
317 Err(error) => {
318 error!("#{id}: error receiving request {error}");
319 return Err(Status::aborted("recv error"));
320 }
321 };
322
323 match self.messages.subscribe(replay_from_slot, filter) {
324 Ok(rx) => {
325 let pos = replay_from_slot
326 .map(|slot| format!("slot {slot}").into())
327 .unwrap_or(Cow::Borrowed("latest"));
328 info!("#{id}: subscribed from {pos}");
329 Ok(Response::new(ReceiverStream::new(
330 rx.boxed(),
331 id,
332 self.on_conn_new_cb.clone(), self.on_conn_drop_cb.clone(), )))
335 }
336 Err(SubscribeError::NotInitialized) => Err(Status::internal("not initialized")),
337 Err(SubscribeError::SlotNotAvailable { first_available }) => Err(
338 Status::invalid_argument(format!("first available slot: {first_available}")),
339 ),
340 }
341 }
342
343 async fn get_version(
344 &self,
345 _request: Request<GetVersionRequest>,
346 ) -> Result<Response<GetVersionResponse>, Status> {
347 Ok(Response::new(GetVersionResponse {
348 version: self.version.create_grpc_version_info().json(),
349 }))
350 }
351}
352
/// Response stream for a single `Subscribe` call.
///
/// Yields items from the subscription receiver and calls `on_conn_drop_cb` when dropped.
///
/// The `F2: Fn()` bound is on the struct itself because the `Drop` impl calls the callback.
/// Rust requires a `Drop` impl's bounds to match the struct's exactly.
pub struct ReceiverStream<F2: Fn()> {
    // Subscription receiver the items are polled from.
    rx: RecvStream,
    // Connection id, used to tag log messages for this stream.
    id: u64,
    // Called from `Drop`.
    on_conn_drop_cb: F2,
}
358
359impl<F2: Fn()> fmt::Debug for ReceiverStream<F2> {
360 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361 f.debug_struct("ReceiverStream").finish()
362 }
363}
364
365impl<F2: Fn()> ReceiverStream<F2> {
366 fn new<F1: Fn()>(rx: RecvStream, id: u64, on_conn_new_cb: F1, on_conn_drop_cb: F2) -> Self {
367 on_conn_new_cb();
368 Self {
369 rx,
370 id,
371 on_conn_drop_cb,
372 }
373 }
374}
375
376impl<F2: Fn()> Drop for ReceiverStream<F2> {
377 fn drop(&mut self) {
378 info!("#{}: send stream closed", self.id);
379 (self.on_conn_drop_cb)();
380 }
381}
382
383impl<F2: Fn() + Unpin> Stream for ReceiverStream<F2> {
384 type Item = Result<Arc<Vec<u8>>, Status>;
385
386 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
387 match ready!(self.rx.poll_next_unpin(cx)) {
388 Some(Ok(value)) => Poll::Ready(Some(Ok(value))),
389 Some(Err(error)) => {
390 error!("#{}: failed to get message: {error}", self.id);
391 match error {
392 RecvError::Lagged => Poll::Ready(Some(Err(Status::out_of_range("lagged")))),
393 RecvError::Closed => Poll::Ready(Some(Err(Status::out_of_range("closed")))),
394 }
395 }
396 None => Poll::Ready(None),
397 }
398 }
399}
400
/// Byte payloads written verbatim (no re-encoding) into a tonic `EncodeBuf`.
///
/// NOTE(review): presumably these are already-serialized protobuf messages. Confirm against
/// the producers of `RecvStream` items.
trait SubscribeMessage {
    /// Appends the payload's bytes to `buf`.
    fn encode(self, buf: &mut EncodeBuf<'_>);
}
404
405impl SubscribeMessage for &[u8] {
406 fn encode(self, buf: &mut EncodeBuf<'_>) {
407 let required = self.len();
408 let remaining = buf.remaining_mut();
409 if required > remaining {
410 panic!("SubscribeMessage only errors if not enough space");
411 }
412 buf.put_slice(self.as_ref());
413 }
414}
415
416impl SubscribeMessage for Vec<u8> {
417 fn encode(self, buf: &mut EncodeBuf<'_>) {
418 self.as_slice().encode(buf);
419 }
420}
421
422impl SubscribeMessage for Arc<Vec<u8>> {
423 fn encode(self, buf: &mut EncodeBuf<'_>) {
424 self.as_slice().encode(buf);
425 }
426}
427
/// tonic `Codec` that sends `T` payloads verbatim (via `SubscribeEncoder`) and decodes
/// incoming `U` messages with prost (via [`ProstDecoder`]).
pub struct SubscribeCodec<T, U> {
    // Zero-sized marker tying the codec to its encode/decode types.
    _pd: PhantomData<(T, U)>,
}
431
432impl<T, U> Default for SubscribeCodec<T, U> {
433 fn default() -> Self {
434 Self { _pd: PhantomData }
435 }
436}
437
438impl<T, U> Codec for SubscribeCodec<T, U>
439where
440 T: SubscribeMessage + Send + 'static,
441 U: Message + Default + Send + 'static,
442{
443 type Encode = T;
444 type Decode = U;
445
446 type Encoder = SubscribeEncoder<T>;
447 type Decoder = ProstDecoder<U>;
448
449 fn encoder(&mut self) -> Self::Encoder {
450 SubscribeEncoder(PhantomData)
451 }
452
453 fn decoder(&mut self) -> Self::Decoder {
454 ProstDecoder(PhantomData)
455 }
456}
457
/// tonic `Encoder` that writes `SubscribeMessage` payloads into the output buffer verbatim.
///
/// Note: the derived `Debug`/`Clone`/`Default` impls carry `T: Debug`/`T: Clone`/`T: Default`
/// bounds.
#[derive(Debug, Clone, Default)]
pub struct SubscribeEncoder<T>(PhantomData<T>);
461
462impl<T: SubscribeMessage> Encoder for SubscribeEncoder<T> {
463 type Item = T;
464 type Error = Status;
465
466 fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
467 item.encode(buf);
468 Ok(())
469 }
470}
471
/// tonic `Decoder` that decodes each incoming message as a prost `U`.
///
/// Decode failures are mapped to an internal `Status` by `from_decode_error`.
#[derive(Debug, Clone, Default)]
pub struct ProstDecoder<U>(PhantomData<U>);
475
476impl<U: Message + Default> Decoder for ProstDecoder<U> {
477 type Item = U;
478 type Error = Status;
479
480 fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
481 let item = Message::decode(buf)
482 .map(Option::Some)
483 .map_err(from_decode_error)?;
484
485 Ok(item)
486 }
487}
488
489fn from_decode_error(error: prost::DecodeError) -> Status {
490 Status::new(tonic::Code::Internal, error.to_string())
493}