ydb 0.10.2

This crate contains generated low-level gRPC code from the YDB API protobuf definitions; it is used as the base for the ydb crate.
Documentation
use itertools::enumerate;
use std::any::Any;
use std::fmt::{Debug, Display, Formatter};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tonic::transport::Channel;

/// Result alias used by the interceptor hooks; the error type is always [`InterceptorError`].
pub(crate) type InterceptorResult<T> = std::result::Result<T, InterceptorError>;
/// Request type seen by interceptors: an HTTP request carrying a tonic body.
pub(crate) type InterceptorRequest = http::request::Request<tonic::body::Body>;

/// A tonic [`Channel`] paired with a [`MultiInterceptor`].
///
/// The `tower::Service` implementation for this type runs the interceptor's
/// `on_call` hook before a request is handed to the wrapped channel.
#[derive(Clone)]
pub(crate) struct InterceptedChannel {
    // Wrapped channel that requests are forwarded to.
    inner: Channel,
    // Interceptor chain invoked around every call.
    interceptor: MultiInterceptor,
}

impl InterceptedChannel {
    pub fn new(channel: Channel, interceptor: MultiInterceptor) -> Self {
        Self {
            inner: channel,
            interceptor,
        }
    }
}

impl tower::Service<InterceptorRequest> for InterceptedChannel {
    type Response = ChannelResponse;
    type Error = InterceptorError;
    type Future = ChannelFuture;

    /// Reports readiness of the wrapped channel, converting a channel error
    /// into [`InterceptorError::Transport`].
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match self.inner.poll_ready(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
            Poll::Ready(Err(transport_err)) => {
                Poll::Ready(Err(InterceptorError::Transport(transport_err)))
            }
        }
    }

    /// Runs the interceptor's `on_call` hook and, if it accepts the request,
    /// forwards the (possibly rewritten) request to the wrapped channel.
    ///
    /// When the hook fails, the channel is never called and the returned
    /// future carries the hook's error instead.
    fn call(&mut self, req: InterceptorRequest) -> Self::Future {
        // Filled in by the interceptor during `on_call`; travels with the
        // future so the response hook can read it back.
        let mut metadata: RequestMetadata = None;

        let prepared = match self.interceptor.on_call(&mut metadata, req) {
            Err(err) => return ChannelFuture::Error(Some(err)),
            Ok(prepared) => prepared,
        };

        let channel_future = self.inner.call(prepared);
        let interceptor = self.interceptor.clone();
        ChannelFuture::Future(ChannelFutureState {
            channel_future,
            interceptor,
            metadata,
        })
    }
}

impl Debug for InterceptedChannel {
    /// Writes the wrapped channel's debug representation followed by the
    /// number of registered interceptors.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Fixes the misspelled "incerceptors" in the emitted text.
        write!(
            f,
            "channel: {:?}, interceptors count: {}",
            self.inner,
            self.interceptor.interceptors.len()
        )
    }
}

/// Response type the tonic [`Channel`] service produces for an [`InterceptorRequest`].
pub(crate) type ChannelResponse = <Channel as tower::Service<InterceptorRequest>>::Response;

/// Future returned by the `tower::Service::call` implementation of `InterceptedChannel`.
pub(crate) enum ChannelFuture {
    /// The `on_call` hook rejected the request; the stored error is handed
    /// out when the future is polled.
    Error(Option<InterceptorError>),
    /// The request was passed to the channel; holds the in-flight call state.
    Future(ChannelFutureState),
}

/// State of a call that has been forwarded to the underlying channel.
pub(crate) struct ChannelFutureState {
    // Future returned by the wrapped channel for this request.
    channel_future: <Channel as tower::Service<InterceptorRequest>>::Future,
    // Clone of the interceptor chain, used to run the response hook.
    interceptor: MultiInterceptor,
    // Per-request metadata written during `on_call` and passed back to the response hook.
    metadata: RequestMetadata,
}

impl Future for ChannelFuture {
    type Output = std::result::Result<ChannelResponse, InterceptorError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let res: Poll<Self::Output> = match self.get_mut() {
            ChannelFuture::Error(None) => Poll::Ready(Err(InterceptorError::internal(
                "interceptor error is empty",
            ))),
            ChannelFuture::Error(err) => {
                let err_content = err.take().unwrap();
                *err = Some(InterceptorError::Internal(format!(
                    "interceptor err consumed already, prev err: '{err_content}'"
                )));
                Poll::Ready(Err(err_content))
            }
            ChannelFuture::Future(state) => {
                let poll_res = Future::poll(Pin::new(&mut state.channel_future), cx);

                match poll_res {
                    Poll::Ready(res) => {
                        let mut res = res.map_err(InterceptorError::Transport);
                        res = state
                            .interceptor
                            .on_feature_poll_ready(&mut state.metadata, res);
                        res.into()
                    }
                    Poll::Pending => Poll::Pending,
                }
            }
        };
        res
    }
}

/// Hook points around a single gRPC call.
///
/// Both methods have pass-through default implementations, so an implementor
/// only needs to override the hooks it uses.
pub(crate) trait GrpcInterceptor: Send + Sync {
    /// Called with the outgoing request before it is passed to the channel.
    ///
    /// `_metadata` is a per-request slot; the same slot is handed to
    /// [`GrpcInterceptor::on_feature_poll_ready`] for this request.
    /// Returning `Err` aborts the call. The default returns `req` unchanged.
    fn on_call(
        &self,
        _metadata: &mut RequestMetadata,
        req: InterceptorRequest,
    ) -> InterceptorResult<InterceptorRequest> {
        Ok(req)
    }

    /// Called with the call's result once the channel future is ready.
    ///
    /// Receives the metadata slot passed to `on_call` for the same request and
    /// may replace the result. The default returns `res` unchanged.
    fn on_feature_poll_ready(
        &self,
        _metadata: &mut RequestMetadata,
        res: Result<ChannelResponse, InterceptorError>,
    ) -> Result<ChannelResponse, InterceptorError> {
        res
    }
}

/// An ordered list of interceptors that itself implements [`GrpcInterceptor`].
///
/// Cloning copies the list of `Arc` handles, so clones share the same
/// interceptor instances.
#[derive(Clone)]
pub(crate) struct MultiInterceptor {
    // Interceptors in registration order; hooks are invoked in this order.
    interceptors: Vec<Arc<Box<dyn GrpcInterceptor>>>,
}

impl MultiInterceptor {
    pub fn new() -> Self {
        Self {
            interceptors: Vec::new(),
        }
    }

    pub fn with_interceptor<T: GrpcInterceptor + 'static>(mut self, interceptor: T) -> Self {
        let boxed_interceptor: Box<dyn GrpcInterceptor> = Box::new(interceptor);
        let arc_boxed_interceptor = Arc::new(boxed_interceptor);
        self.interceptors.push(arc_boxed_interceptor);
        self
    }
}

impl GrpcInterceptor for MultiInterceptor {
    fn on_call(
        &self,
        metadata: &mut RequestMetadata,
        mut req: InterceptorRequest,
    ) -> InterceptorResult<InterceptorRequest> {
        let mut metadatas: Vec<RequestMetadata> = Vec::new();
        metadatas.resize_with(self.interceptors.len(), || None);
        for (index, interceptor) in enumerate(self.interceptors.iter()) {
            req = interceptor.on_call(&mut metadatas[index], req)?;
        }
        *metadata = Some(Box::new(metadatas));
        Ok(req)
    }

    fn on_feature_poll_ready(
        &self,
        metadata: &mut RequestMetadata,
        mut res: Result<ChannelResponse, InterceptorError>,
    ) -> Result<ChannelResponse, InterceptorError> {
        let metadata = metadata
            .as_mut()
            .unwrap()
            .downcast_mut::<Vec<RequestMetadata>>()
            .unwrap();

        for (index, interceptor) in enumerate(self.interceptors.iter()) {
            let item_meta = &mut metadata[index];
            res = interceptor.on_feature_poll_ready(item_meta, res)
        }
        res
    }
}

/// Optional, type-erased per-request state that an interceptor can set in
/// `on_call` and read back in `on_feature_poll_ready`.
pub(crate) type RequestMetadata = Option<Box<dyn Any + Send>>;

/// Errors produced by the interceptor layer.
pub(crate) enum InterceptorError {
    /// Free-form error message; built with [`InterceptorError::custom`].
    /// NOTE(review): presumably meant for errors raised by interceptor implementations — confirm.
    Custom(String),
    /// Error raised by the interceptor machinery itself, e.g. when the stored
    /// error of a `ChannelFuture` is empty or was already consumed.
    Internal(String),
    /// Error reported by the underlying tonic transport.
    Transport(tonic::transport::Error),
}

impl InterceptorError {
    pub fn custom<S: Into<String>>(text: S) -> Self {
        InterceptorError::Custom(text.into())
    }

    pub fn internal<S: Into<String>>(text: S) -> Self {
        InterceptorError::Internal(text.into())
    }
}

impl Debug for InterceptorError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        use InterceptorError::*;

        match self {
            Custom(text) => write!(f, "interceptor custom error: '{text}'"),
            Internal(text) => write!(f, "interceptor internal error: '{text}'"),
            Transport(err) => write!(f, "interceptor transport error: {err:?}"),
        }
    }
}

impl Display for InterceptorError {
    /// Produces exactly the same text as the `Debug` implementation.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        Debug::fmt(self, f)
    }
}

// Marks `InterceptorError` as a standard error type. No method is overridden,
// so the trait's provided defaults apply.
impl std::error::Error for InterceptorError {}