tycho-rpc 0.3.9

Public RPC service.
Documentation
use std::borrow::Cow;

use anyhow::Result;
use axum::body::Bytes;
use axum::extract::{FromRequest, Request};
use axum::http::header::CONTENT_TYPE;
use axum::http::{HeaderValue, StatusCode};
use axum::response::{IntoResponse, Response};
use prost::Message;

use crate::endpoint::{APPLICATION_PROTOBUF, proto};
use crate::util::error_codes::PARSE_ERROR_CODE;

// Counters
const METRIC_IN_REQ_FAIL_TOTAL: &str = "tycho_rpc_in_req_fail_total";

pub struct Protobuf<T>(pub T);

impl<S, T> FromRequest<S> for Protobuf<T>
where
    T: Message + Default,
    S: Send + Sync,
{
    type Rejection = ProtoErrorResponse;

    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        let bytes = match Bytes::from_request(req, state).await {
            Ok(bytes) => bytes,
            Err(e) => {
                return Err(ProtoErrorResponse {
                    code: PARSE_ERROR_CODE,
                    message: e.to_string().into(),
                });
            }
        };
        let message = match T::decode(bytes) {
            Ok(message) => message,
            Err(e) => {
                return Err(ProtoErrorResponse {
                    code: PARSE_ERROR_CODE,
                    message: e.to_string().into(),
                });
            }
        };
        Ok(Protobuf(message))
    }
}

impl<T> IntoResponse for Protobuf<T>
where
    T: Message,
{
    fn into_response(self) -> Response {
        let body = self.0.encode_to_vec();
        let mut res = Response::new(body.into());
        res.headers_mut()
            .insert(CONTENT_TYPE, HeaderValue::from_static(APPLICATION_PROTOBUF));
        res
    }
}

pub struct ProtobufRef<'a, T>(pub &'a T);

impl<T> IntoResponse for ProtobufRef<'_, T>
where
    T: Message,
{
    fn into_response(self) -> Response {
        let body = self.0.encode_to_vec();
        let mut res = Response::new(body.into());
        res.headers_mut()
            .insert(CONTENT_TYPE, HeaderValue::from_static(APPLICATION_PROTOBUF));
        res
    }
}

pub struct ProtoOkResponse(proto::rpc::Response);

impl ProtoOkResponse {
    pub fn new(result: proto::rpc::response::Result) -> Self {
        ProtoOkResponse(proto::rpc::Response {
            result: Some(result),
        })
    }

    pub fn into_raw(self) -> RawProtoOkResponse {
        RawProtoOkResponse::from(self)
    }

    pub fn as_raw(&self) -> RawProtoOkResponse {
        RawProtoOkResponse::from(self)
    }
}

impl IntoResponse for ProtoOkResponse {
    fn into_response(self) -> Response {
        (StatusCode::OK, Protobuf(self.0)).into_response()
    }
}

#[derive(Clone)]
pub struct RawProtoOkResponse(pub bytes::Bytes);

impl From<ProtoOkResponse> for RawProtoOkResponse {
    fn from(value: ProtoOkResponse) -> Self {
        RawProtoOkResponse(value.0.encode_to_vec().into())
    }
}

impl From<&ProtoOkResponse> for RawProtoOkResponse {
    fn from(value: &ProtoOkResponse) -> Self {
        RawProtoOkResponse(value.0.encode_to_vec().into())
    }
}

impl IntoResponse for RawProtoOkResponse {
    fn into_response(self) -> Response {
        let mut res = Response::new(self.0.into());
        res.headers_mut()
            .insert(CONTENT_TYPE, HeaderValue::from_static(APPLICATION_PROTOBUF));
        res
    }
}

pub struct ProtoErrorResponse {
    pub code: i32,
    pub message: Cow<'static, str>,
}

impl IntoResponse for ProtoErrorResponse {
    fn into_response(self) -> Response {
        metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);

        Protobuf(proto::rpc::Error {
            code: self.code,
            message: self.message.into(),
        })
        .into_response()
    }
}