richat-shared 8.0.1

Richat Shared code
Documentation
use {
    crate::jsonrpc::{
        helpers::{
            RpcResponse, get_x_bigtable_disabled, get_x_subscription_id, response_200,
            response_400, response_500, to_vec,
        },
        metrics::{
            RPC_REQUESTS_DURATION_SECONDS, RPC_REQUESTS_GENERATED_BYTES_TOTAL, RPC_REQUESTS_TOTAL,
        },
    },
    futures::{
        future::BoxFuture,
        stream::{FuturesOrdered, StreamExt},
    },
    http_body_util::{BodyExt, Limited},
    hyper::{
        HeaderMap,
        body::{Bytes, Incoming as BodyIncoming},
        http::Result as HttpResult,
    },
    jsonrpsee_types::{
        Extensions, Request, Response, ResponsePayload, TwoPointZero, error::ErrorCode,
    },
    metrics::{counter, histogram},
    quanta::Instant,
    richat_metrics::duration_to_seconds,
    std::{collections::HashMap, fmt, sync::Arc},
};

pub type RpcRequestResult = anyhow::Result<Vec<u8>>;

pub type RpcRequestHandler<S> =
    Box<dyn Fn(S, Arc<str>, bool, Request<'_>) -> BoxFuture<'_, RpcRequestResult> + Send + Sync>;

/// A parsed JSON-RPC request body: either one request or a batch of them.
///
/// The requests borrow from the raw body bytes (lifetime `'a`).
#[derive(Debug)]
enum RpcRequests<'a> {
    /// A body holding a single request object.
    Single(Request<'a>),
    /// A body holding an array of request objects.
    Batch(Vec<Request<'a>>),
}

impl<'a> RpcRequests<'a> {
    /// Parses `bytes` as a single request or as a batch of requests.
    ///
    /// The shape is chosen by whichever of `[` or `{` appears first in the
    /// body. If `[` comes first, the body is parsed as an array of requests.
    /// Otherwise, including when neither byte occurs, it is parsed as one
    /// request object.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error when the body does not deserialize into
    /// the chosen shape.
    fn parse(bytes: &'a Bytes) -> serde_json::Result<Self> {
        let first_delimiter = bytes
            .iter()
            .copied()
            .find(|&byte| matches!(byte, b'[' | b'{'));

        if first_delimiter == Some(b'[') {
            serde_json::from_slice::<Vec<Request<'_>>>(bytes).map(Self::Batch)
        } else {
            serde_json::from_slice::<Request<'_>>(bytes).map(Self::Single)
        }
    }
}

/// JSON-RPC request dispatcher.
///
/// It reads an HTTP request body, parses one request or a batch, and routes
/// each request to the handler registered for its method name. The serialized
/// responses are returned as the HTTP response body.
pub struct RpcRequestsProcessor<S> {
    /// Maximum number of body bytes read from an incoming request.
    body_limit: usize,
    /// State cloned and handed to every handler invocation.
    state: S,
    /// Headers passed to `response_200` when building successful responses.
    extra_headers: HeaderMap,
    /// Registered handlers, keyed by JSON-RPC method name.
    methods: HashMap<&'static str, RpcRequestHandler<S>>,
}

impl<S> fmt::Debug for RpcRequestsProcessor<S> {
    /// Shows the configuration and the registered method names.
    ///
    /// `state` is omitted because `S` is not required to implement `Debug`.
    /// `finish_non_exhaustive` makes that omission visible as `..`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Sort so the output is stable despite HashMap iteration order.
        let mut methods = self.methods.keys().copied().collect::<Vec<&'static str>>();
        methods.sort_unstable();
        f.debug_struct("RpcRequestsProcessor")
            .field("body_limit", &self.body_limit)
            .field("extra_headers", &self.extra_headers)
            .field("methods", &methods)
            .finish_non_exhaustive()
    }
}

impl<S: Clone> RpcRequestsProcessor<S> {
    pub fn new(body_limit: usize, state: S, extra_headers: HeaderMap) -> Self {
        Self {
            body_limit,
            state,
            extra_headers,
            methods: HashMap::new(),
        }
    }

    pub fn add_handler(
        &mut self,
        method: &'static str,
        handler: RpcRequestHandler<S>,
    ) -> &mut Self {
        self.methods.insert(method, handler);
        self
    }

    pub async fn on_request(&self, req: hyper::Request<BodyIncoming>) -> HttpResult<RpcResponse> {
        let (parts, body) = req.into_parts();

        let x_subscription_id = get_x_subscription_id(&parts.headers);
        let upstream_disabled = get_x_bigtable_disabled(&parts.headers);

        let bytes = match Limited::new(body, self.body_limit).collect().await {
            Ok(body) => body.to_bytes(),
            Err(error) => return response_400(error),
        };
        let requests = match RpcRequests::parse(&bytes) {
            Ok(requests) => requests,
            Err(error) => return response_400(error),
        };

        let mut buffer = match requests {
            RpcRequests::Single(request) => {
                match self
                    .process(Arc::clone(&x_subscription_id), upstream_disabled, request)
                    .await
                {
                    Ok(response) => response,
                    Err(error) => return response_500(error),
                }
            }
            RpcRequests::Batch(requests) => {
                let mut futures = FuturesOrdered::new();
                for request in requests {
                    let x_subscription_id = Arc::clone(&x_subscription_id);
                    futures.push_back(self.process(
                        Arc::clone(&x_subscription_id),
                        upstream_disabled,
                        request,
                    ));
                }

                let mut buffer = Vec::new();
                buffer.push(b'[');
                while let Some(result) = futures.next().await {
                    match result {
                        Ok(mut response) => {
                            buffer.append(&mut response);
                        }
                        Err(error) => return response_500(error),
                    }
                    if !futures.is_empty() {
                        buffer.push(b',');
                    }
                }
                buffer.push(b']');
                buffer
            }
        };
        buffer.push(b'\n');
        counter!(
            RPC_REQUESTS_GENERATED_BYTES_TOTAL,
            "x_subscription_id" => x_subscription_id,
        )
        .increment(buffer.len() as u64);
        response_200(buffer, &self.extra_headers)
    }

    async fn process<'a>(
        &'a self,
        x_subscription_id: Arc<str>,
        upstream_disabled: bool,
        request: Request<'a>,
    ) -> anyhow::Result<Vec<u8>> {
        let Some((method, handle)) = self.methods.get_key_value(request.method.as_ref()) else {
            return Ok(to_vec(&Response {
                jsonrpc: Some(TwoPointZero),
                payload: ResponsePayload::<()>::error(ErrorCode::MethodNotFound),
                id: request.id.into_owned(),
                extensions: Extensions::default(), // doesn't matter, as it is not used in serialize
            }));
        };

        let ts = Instant::now();
        let result = handle(
            self.state.clone(),
            Arc::clone(&x_subscription_id),
            upstream_disabled,
            request,
        )
        .await;
        counter!(
            RPC_REQUESTS_TOTAL,
            "x_subscription_id" => Arc::clone(&x_subscription_id),
            "method" => *method,
        )
        .increment(1);
        histogram!(
            RPC_REQUESTS_DURATION_SECONDS,
            "x_subscription_id" => x_subscription_id,
            "method" => *method,
        )
        .record(duration_to_seconds(ts.elapsed()));
        result
    }
}