1use {
2 crate::{
3 config::{deserialize_humansize_usize, deserialize_x_tokens_set},
4 transports::{RecvError, RecvStream, Subscribe, SubscribeError},
5 version::Version,
6 },
7 futures::stream::{Stream, StreamExt},
8 prost::{Message, bytes::BufMut},
9 richat_proto::{
10 geyser::{GetVersionRequest, GetVersionResponse},
11 richat::GrpcSubscribeRequest,
12 },
13 serde::{
14 Deserialize,
15 de::{self, Deserializer},
16 },
17 std::{
18 borrow::Cow,
19 collections::HashSet,
20 fmt, fs,
21 future::Future,
22 marker::PhantomData,
23 net::{IpAddr, Ipv4Addr, SocketAddr},
24 pin::Pin,
25 sync::{
26 Arc,
27 atomic::{AtomicU64, Ordering},
28 },
29 task::{Context, Poll, ready},
30 time::Duration,
31 },
32 thiserror::Error,
33 tokio::task::JoinError,
34 tokio_util::sync::CancellationToken,
35 tonic::{
36 Request, Response, Status, Streaming,
37 codec::{Codec, CompressionEncoding, DecodeBuf, Decoder, EncodeBuf, Encoder},
38 service::interceptor::InterceptorLayer,
39 transport::{
40 Identity, ServerTlsConfig,
41 server::{Server, TcpIncoming},
42 },
43 },
44 tracing::{error, info},
45};
46
/// Code generated by the build script (written to `OUT_DIR`) for the `geyser.Geyser` gRPC
/// service; this file uses its `geyser_server` module to expose [`GrpcServer`].
pub mod geyser_gen {
    // Generated code is not hand-maintained, so these clippy lints are silenced for it.
    #![allow(clippy::clone_on_ref_ptr)]
    #![allow(clippy::missing_const_for_fn)]

    include!(concat!(env!("OUT_DIR"), "/geyser.Geyser.rs"));
}
53
/// Compression settings for the gRPC service; both lists default to empty.
#[derive(Debug, Default, Clone, Deserialize)]
#[serde(deny_unknown_fields, default)]
pub struct ConfigGrpcCompression {
    /// Encodings registered via `accept_compressed` (incoming requests).
    #[serde(deserialize_with = "ConfigGrpcCompression::deserialize_compression")]
    pub accept: Vec<CompressionEncoding>,
    /// Encodings registered via `send_compressed` (outgoing responses).
    #[serde(deserialize_with = "ConfigGrpcCompression::deserialize_compression")]
    pub send: Vec<CompressionEncoding>,
}
62
63impl ConfigGrpcCompression {
64 pub fn deserialize_compression<'de, D>(
65 deserializer: D,
66 ) -> Result<Vec<CompressionEncoding>, D::Error>
67 where
68 D: Deserializer<'de>,
69 {
70 Vec::<&str>::deserialize(deserializer)?
71 .into_iter()
72 .map(|value| match value {
73 "gzip" => Ok(CompressionEncoding::Gzip),
74 "zstd" => Ok(CompressionEncoding::Zstd),
75 value => Err(de::Error::custom(format!(
76 "Unknown compression format: {value}"
77 ))),
78 })
79 .collect::<Result<_, _>>()
80 }
81}
82
/// Configuration for the gRPC server; missing fields take the values from its `Default` impl.
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields, default)]
pub struct ConfigGrpcServer {
    /// Address the TCP listener binds to (default `127.0.0.1:10100`).
    pub endpoint: SocketAddr,
    /// Optional TLS identity, given as `cert`/`key` PEM file paths that are read while
    /// deserializing (default: no TLS).
    #[serde(deserialize_with = "ConfigGrpcServer::deserialize_tls_config")]
    pub tls_config: Option<ServerTlsConfig>,
    /// Accepted/sent compression encodings.
    pub compression: ConfigGrpcCompression,
    /// Maximum decoded size of an incoming message (default 4 MiB).
    #[serde(deserialize_with = "deserialize_humansize_usize")]
    pub max_decoding_message_size: usize,
    /// Value passed to `TcpIncoming::with_keepalive` (default 15s).
    #[serde(with = "humantime_serde")]
    pub server_tcp_keepalive: Option<Duration>,
    /// Value passed to `TcpIncoming::with_nodelay` (default `true`).
    pub server_tcp_nodelay: bool,
    /// HTTP/2 adaptive window; only applied to the server builder when set.
    pub server_http2_adaptive_window: Option<bool>,
    /// HTTP/2 keepalive ping interval; only applied when set.
    #[serde(with = "humantime_serde")]
    pub server_http2_keepalive_interval: Option<Duration>,
    /// HTTP/2 keepalive ping timeout; only applied when set.
    #[serde(with = "humantime_serde")]
    pub server_http2_keepalive_timeout: Option<Duration>,
    /// HTTP/2 initial connection window size; only applied when set.
    pub server_initial_connection_window_size: Option<u32>,
    /// HTTP/2 initial stream window size; only applied when set.
    pub server_initial_stream_window_size: Option<u32>,
    /// Accepted values of the `x-token` request metadata; an empty set disables the check.
    #[serde(deserialize_with = "deserialize_x_tokens_set")]
    pub x_tokens: HashSet<Vec<u8>>,
}
106
107impl Default for ConfigGrpcServer {
108 fn default() -> Self {
109 Self {
110 endpoint: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10100),
111 tls_config: None,
112 compression: ConfigGrpcCompression::default(),
113 max_decoding_message_size: 4 * 1024 * 1024, server_tcp_keepalive: Some(Duration::from_secs(15)),
115 server_tcp_nodelay: true,
116 server_http2_adaptive_window: None,
117 server_http2_keepalive_interval: None,
118 server_http2_keepalive_timeout: None,
119 server_initial_connection_window_size: None,
120 server_initial_stream_window_size: None,
121 x_tokens: HashSet::new(),
122 }
123 }
124}
125
126impl ConfigGrpcServer {
127 pub fn deserialize_tls_config<'de, D>(
128 deserializer: D,
129 ) -> Result<Option<ServerTlsConfig>, D::Error>
130 where
131 D: Deserializer<'de>,
132 {
133 #[derive(Debug, Deserialize)]
134 #[serde(deny_unknown_fields)]
135 struct ConfigTls<'a> {
136 cert: &'a str,
137 key: &'a str,
138 }
139
140 Option::<ConfigTls>::deserialize(deserializer)?
141 .map(|config| {
142 let cert = fs::read(config.cert).map_err(|error| {
143 de::Error::custom(format!("failed to read cert {}: {error:?}", config.cert))
144 })?;
145 let key = fs::read(config.key).map_err(|error| {
146 de::Error::custom(format!("failed to read key {}: {error:?}", config.key))
147 })?;
148
149 Ok(ServerTlsConfig::new().identity(Identity::from_pem(cert, key)))
150 })
151 .transpose()
152 }
153
154 pub fn create_server_builder(&self) -> Result<(TcpIncoming, Server), CreateServerError> {
155 let incoming = TcpIncoming::bind(self.endpoint)
157 .map_err(|error| CreateServerError::Bind {
158 error,
159 endpoint: self.endpoint,
160 })?
161 .with_nodelay(Some(self.server_tcp_nodelay))
162 .with_keepalive(self.server_tcp_keepalive);
163
164 let mut server_builder = Server::builder();
166 if let Some(tls_config) = self.tls_config.clone() {
167 server_builder = server_builder.tls_config(tls_config)?;
168 }
169 if let Some(enabled) = self.server_http2_adaptive_window {
170 server_builder = server_builder.http2_adaptive_window(Some(enabled));
171 }
172 if let Some(http2_keepalive_interval) = self.server_http2_keepalive_interval {
173 server_builder =
174 server_builder.http2_keepalive_interval(Some(http2_keepalive_interval));
175 }
176 if let Some(http2_keepalive_timeout) = self.server_http2_keepalive_timeout {
177 server_builder = server_builder.http2_keepalive_timeout(Some(http2_keepalive_timeout));
178 }
179 if let Some(sz) = self.server_initial_connection_window_size {
180 server_builder = server_builder.initial_connection_window_size(sz);
181 }
182 if let Some(sz) = self.server_initial_stream_window_size {
183 server_builder = server_builder.initial_stream_window_size(sz);
184 }
185
186 Ok((incoming, server_builder))
187 }
188}
189
/// Errors returned while building the gRPC server (see
/// `ConfigGrpcServer::create_server_builder`).
#[derive(Debug, Error)]
pub enum CreateServerError {
    /// Binding the TCP listener to `endpoint` failed.
    #[error("failed to bind {endpoint}: {error}")]
    Bind {
        error: std::io::Error,
        endpoint: SocketAddr,
    },
    /// The tonic server builder rejected the configured TLS settings.
    #[error("failed to apply tls_config: {0}")]
    Tls(#[from] tonic::transport::Error),
}
200
/// Implementation of the `Geyser` gRPC service; its service impls require `S: Subscribe`
/// and use `messages` as the source of subscription streams.
pub struct GrpcServer<S, F1, F2> {
    // Source of subscription streams.
    messages: S,
    // Incremented once per `subscribe` call; the value tags that connection's log lines.
    subscribe_id: AtomicU64,
    // Invoked when a subscription stream is created.
    on_conn_new_cb: F1,
    // Cloned into each subscription stream and invoked when that stream is dropped.
    on_conn_drop_cb: F2,
    // Version information returned by `get_version`.
    version: Version<'static>,
}
208
209impl<S, F1, F2> fmt::Debug for GrpcServer<S, F1, F2> {
210 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211 f.debug_struct("GrpcServer")
212 .field("subscribe_id", &self.subscribe_id)
213 .field("version", &self.version)
214 .finish()
215 }
216}
217
218impl<S, F1, F2> GrpcServer<S, F1, F2>
219where
220 S: Subscribe + Send + Sync + 'static,
221 F1: Fn() + Clone + Unpin + Send + Sync + 'static,
222 F2: Fn() + Clone + Unpin + Send + Sync + 'static,
223{
224 pub async fn spawn(
225 config: ConfigGrpcServer,
226 messages: S,
227 on_conn_new_cb: F1,
228 on_conn_drop_cb: F2,
229 version: Version<'static>,
230 shutdown: CancellationToken,
231 ) -> Result<impl Future<Output = Result<(), JoinError>>, CreateServerError> {
232 let (incoming, server_builder) = config.create_server_builder()?;
233 info!("start server at {}", config.endpoint);
234
235 let mut service = geyser_gen::geyser_server::GeyserServer::new(Self {
236 messages,
237 subscribe_id: AtomicU64::new(0),
238 on_conn_new_cb,
239 on_conn_drop_cb,
240 version,
241 })
242 .max_decoding_message_size(config.max_decoding_message_size);
243 for encoding in config.compression.accept {
244 service = service.accept_compressed(encoding);
245 }
246 for encoding in config.compression.send {
247 service = service.send_compressed(encoding);
248 }
249
250 Ok(tokio::spawn(async move {
252 if let Err(error) = server_builder
253 .layer(InterceptorLayer::new(move |request: Request<()>| {
254 if config.x_tokens.is_empty() {
255 Ok(request)
256 } else {
257 match request.metadata().get("x-token") {
258 Some(token) if config.x_tokens.contains(token.as_bytes()) => {
259 Ok(request)
260 }
261 _ => Err(Status::unauthenticated("No valid auth token")),
262 }
263 }
264 }))
265 .add_service(service)
266 .serve_with_incoming_shutdown(incoming, shutdown.cancelled())
267 .await
268 {
269 error!("server error: {error:?}")
270 } else {
271 info!("shutdown")
272 }
273 }))
274 }
275}
276
277#[tonic::async_trait]
278impl<S, F1, F2> geyser_gen::geyser_server::Geyser for GrpcServer<S, F1, F2>
279where
280 S: Subscribe + Send + Sync + 'static,
281 F2: Fn() + Clone + Unpin + Send + Sync + 'static,
282 F1: Fn() + Clone + Unpin + Send + Sync + 'static,
283{
284 type SubscribeStream = ReceiverStream<F2>;
285
286 async fn subscribe(
287 &self,
288 mut request: Request<Streaming<GrpcSubscribeRequest>>,
289 ) -> Result<Response<Self::SubscribeStream>, Status> {
290 let id = self.subscribe_id.fetch_add(1, Ordering::Relaxed);
291 info!("#{id}: new connection from {:?}", request.remote_addr());
292
293 let (replay_from_slot, filter) = match request.get_mut().message().await {
294 Ok(Some(GrpcSubscribeRequest {
295 replay_from_slot,
296 filter,
297 })) => (replay_from_slot, filter),
298 Ok(None) => {
299 info!("#{id}: connection closed before receiving request");
300 return Err(Status::aborted("stream closed before request received"));
301 }
302 Err(error) => {
303 error!("#{id}: error receiving request {error}");
304 return Err(Status::aborted("recv error"));
305 }
306 };
307
308 match self.messages.subscribe(replay_from_slot, filter) {
309 Ok(rx) => {
310 let pos = replay_from_slot
311 .map(|slot| format!("slot {slot}").into())
312 .unwrap_or(Cow::Borrowed("latest"));
313 info!("#{id}: subscribed from {pos}");
314 Ok(Response::new(ReceiverStream::new(
315 rx.boxed(),
316 id,
317 self.on_conn_new_cb.clone(), self.on_conn_drop_cb.clone(), )))
320 }
321 Err(SubscribeError::NotInitialized) => Err(Status::internal("not initialized")),
322 Err(SubscribeError::SlotNotAvailable { first_available }) => Err(
323 Status::invalid_argument(format!("first available slot: {first_available}")),
324 ),
325 }
326 }
327
328 async fn get_version(
329 &self,
330 _request: Request<GetVersionRequest>,
331 ) -> Result<Response<GetVersionResponse>, Status> {
332 Ok(Response::new(GetVersionResponse {
333 version: self.version.create_grpc_version_info().json(),
334 }))
335 }
336}
337
/// Response stream for one `subscribe` call. It forwards items from the underlying
/// `RecvStream` and invokes `on_conn_drop_cb` when dropped.
///
/// The `Fn()` bound stays on the struct because its `Drop` impl needs it, and a `Drop` impl
/// must use exactly the struct's own bounds.
pub struct ReceiverStream<F2: Fn()> {
    rx: RecvStream,
    // Connection id used in log messages.
    id: u64,
    // Invoked from `Drop`.
    on_conn_drop_cb: F2,
}
343
344impl<F2: Fn()> fmt::Debug for ReceiverStream<F2> {
345 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
346 f.debug_struct("ReceiverStream").finish()
347 }
348}
349
350impl<F2: Fn()> ReceiverStream<F2> {
351 fn new<F1: Fn()>(rx: RecvStream, id: u64, on_conn_new_cb: F1, on_conn_drop_cb: F2) -> Self {
352 on_conn_new_cb();
353 Self {
354 rx,
355 id,
356 on_conn_drop_cb,
357 }
358 }
359}
360
361impl<F2: Fn()> Drop for ReceiverStream<F2> {
362 fn drop(&mut self) {
363 info!("#{}: send stream closed", self.id);
364 (self.on_conn_drop_cb)();
365 }
366}
367
368impl<F2: Fn() + Unpin> Stream for ReceiverStream<F2> {
369 type Item = Result<Arc<Vec<u8>>, Status>;
370
371 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
372 match ready!(self.rx.poll_next_unpin(cx)) {
373 Some(Ok(value)) => Poll::Ready(Some(Ok(value))),
374 Some(Err(error)) => {
375 error!("#{}: failed to get message: {error}", self.id);
376 match error {
377 RecvError::Lagged => Poll::Ready(Some(Err(Status::out_of_range("lagged")))),
378 RecvError::Closed => Poll::Ready(Some(Err(Status::out_of_range("closed")))),
379 }
380 }
381 None => Poll::Ready(None),
382 }
383 }
384}
385
/// Raw byte payloads that are copied into the outgoing gRPC message verbatim, with no
/// re-encoding (presumably already protobuf-encoded by the producer — confirm upstream).
trait SubscribeMessage {
    /// Appends the bytes of `self` to `buf`.
    fn encode(self, buf: &mut EncodeBuf<'_>);
}
389
390impl SubscribeMessage for &[u8] {
391 fn encode(self, buf: &mut EncodeBuf<'_>) {
392 let required = self.len();
393 let remaining = buf.remaining_mut();
394 if required > remaining {
395 panic!("SubscribeMessage only errors if not enough space");
396 }
397 buf.put_slice(self.as_ref());
398 }
399}
400
401impl SubscribeMessage for Vec<u8> {
402 fn encode(self, buf: &mut EncodeBuf<'_>) {
403 self.as_slice().encode(buf);
404 }
405}
406
407impl SubscribeMessage for Arc<Vec<u8>> {
408 fn encode(self, buf: &mut EncodeBuf<'_>) {
409 self.as_slice().encode(buf);
410 }
411}
412
/// tonic codec that writes outgoing `T` payloads as raw bytes and decodes incoming `U`
/// messages with prost.
pub struct SubscribeCodec<T, U> {
    // Marker only; the codec holds no data.
    _pd: PhantomData<(T, U)>,
}
416
417impl<T, U> Default for SubscribeCodec<T, U> {
418 fn default() -> Self {
419 Self { _pd: PhantomData }
420 }
421}
422
423impl<T, U> Codec for SubscribeCodec<T, U>
424where
425 T: SubscribeMessage + Send + 'static,
426 U: Message + Default + Send + 'static,
427{
428 type Encode = T;
429 type Decode = U;
430
431 type Encoder = SubscribeEncoder<T>;
432 type Decoder = ProstDecoder<U>;
433
434 fn encoder(&mut self) -> Self::Encoder {
435 SubscribeEncoder(PhantomData)
436 }
437
438 fn decoder(&mut self) -> Self::Decoder {
439 ProstDecoder(PhantomData)
440 }
441}
442
/// Encoder that copies [`SubscribeMessage`] payloads into the outgoing buffer as-is.
#[derive(Debug, Clone, Default)]
pub struct SubscribeEncoder<T>(PhantomData<T>);
446
447impl<T: SubscribeMessage> Encoder for SubscribeEncoder<T> {
448 type Item = T;
449 type Error = Status;
450
451 fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
452 item.encode(buf);
453 Ok(())
454 }
455}
456
/// Decoder that parses incoming messages of type `U` with prost.
#[derive(Debug, Clone, Default)]
pub struct ProstDecoder<U>(PhantomData<U>);
460
461impl<U: Message + Default> Decoder for ProstDecoder<U> {
462 type Item = U;
463 type Error = Status;
464
465 fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
466 let item = Message::decode(buf)
467 .map(Option::Some)
468 .map_err(from_decode_error)?;
469
470 Ok(item)
471 }
472}
473
474fn from_decode_error(error: prost::DecodeError) -> Status {
475 Status::new(tonic::Code::Internal, error.to_string())
478}