ntex-grpc 0.2.0-b.2

GRPC Client/Server framework
Documentation
use std::{cell::RefCell, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll};

use ntex_bytes::{Buf, BufMut};
use ntex_h2::{self as h2, frame::StreamId};
use ntex_http::{HeaderMap, HeaderValue, StatusCode};
use ntex_io::{Filter, Io, IoBoxed};
use ntex_util::{future::Either, future::Ready, HashMap};

pub use ntex_bytes::{ByteString, Bytes, BytesMut};
pub use ntex_service::{Service, ServiceFactory};

pub use crate::error::ServerError;
use crate::{consts, status::GrpcStatus, utils::Data};

/// Decoded gRPC request handed to the application service.
#[derive(Debug)]
pub struct Request {
    /// Method name, taken from the second segment of the HTTP/2 `:path`.
    pub name: ByteString,
    /// Message bytes with the 5-byte frame prefix (flag byte + length) already removed.
    pub payload: Bytes,
    /// Request headers; trailers received at end of stream are merged into this map.
    pub headers: HeaderMap,
}

/// Reply produced by the application service for a single request.
#[derive(Debug)]
pub struct Response {
    /// Message bytes; the server prepends the compression flag and length before sending.
    pub payload: Bytes,
}

impl Response {
    #[inline]
    pub fn new(payload: Bytes) -> Response {
        Response { payload }
    }
}

/// gRPC server: a connection-level service factory built around an application
/// service factory.
pub struct GrpcServer<T> {
    // Shared with every `GrpcService` created by this server.
    factory: Rc<T>,
}

impl<T> GrpcServer<T> {
    /// Builds a gRPC server around the supplied application service factory.
    ///
    /// The factory is placed behind an `Rc` so it can be shared with the
    /// services this server hands out.
    pub fn new(factory: T) -> Self {
        let factory = Rc::new(factory);
        Self { factory }
    }
}

impl<T> GrpcServer<T>
where
    T: ServiceFactory<Request, Response = Response, Error = ServerError>,
    T::Service: Clone,
{
    /// Create default server
    pub fn make_server(&self) -> GrpcService<T> {
        GrpcService {
            factory: self.factory.clone(),
        }
    }
}

impl<F, T> ServiceFactory<Io<F>> for GrpcServer<T>
where
    F: Filter,
    T: ServiceFactory<Request, Response = Response, Error = ServerError> + 'static,
    T::Service: Clone,
{
    type Response = ();
    type Error = T::InitError;
    type Service = GrpcService<T>;
    type InitError = ();
    type Future = Ready<Self::Service, Self::InitError>;

    fn new_service(&self, _: ()) -> Self::Future {
        Ready::Ok(self.make_server())
    }
}

/// Per-connection service: for each accepted connection it creates an
/// application service from `factory` and serves HTTP/2 on that connection.
pub struct GrpcService<T> {
    // Application service factory shared with the originating `GrpcServer`.
    factory: Rc<T>,
}

/// Serves a single filtered `Io<F>` connection as a gRPC (HTTP/2) endpoint.
impl<T, F> Service<Io<F>> for GrpcService<T>
where
    F: Filter,
    T: ServiceFactory<Request, Response = Response, Error = ServerError> + 'static,
{
    type Response = ();
    type Error = T::InitError;
    type Future = Pin<Box<dyn Future<Output = Result<(), Self::Error>>>>;

    /// Always ready: no per-connection resources are reserved up front.
    #[inline]
    fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Creates the application service, then drives the HTTP/2 connection to completion.
    ///
    /// Fails only if the application service cannot be created; the outcome of
    /// the HTTP/2 session itself is discarded.
    fn call(&self, io: Io<F>) -> Self::Future {
        let init = self.factory.new_service(());

        Box::pin(async move {
            // The application service must exist before any stream is handled.
            let publish = PublishService::new(init.await?);

            // Connection-level result is intentionally ignored.
            let _ = h2::server::handle_one(io.into(), h2::Config::server(), ControlService, publish)
                .await;

            Ok(())
        })
    }
}

/// Serves a single type-erased `IoBoxed` connection as a gRPC (HTTP/2) endpoint.
impl<T> Service<IoBoxed> for GrpcService<T>
where
    T: ServiceFactory<Request, Response = Response, Error = ServerError> + 'static,
{
    type Response = ();
    type Error = T::InitError;
    type Future = Pin<Box<dyn Future<Output = Result<(), Self::Error>>>>;

    /// Always ready: no per-connection resources are reserved up front.
    #[inline]
    fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Creates the application service, then drives the HTTP/2 connection to completion.
    ///
    /// Fails only if the application service cannot be created; the outcome of
    /// the HTTP/2 session itself is discarded.
    fn call(&self, io: IoBoxed) -> Self::Future {
        let pending = self.factory.new_service(());

        Box::pin(async move {
            // Propagate factory failures to the caller.
            let app = pending.await?;
            let publish = PublishService::new(app);

            // Connection-level result is intentionally ignored.
            let _ = h2::server::handle_one(io, h2::Config::server(), ControlService, publish).await;

            Ok(())
        })
    }
}

/// Handler for HTTP/2 control messages: logs each one and acknowledges it.
struct ControlService;

impl Service<h2::ControlMessage<h2::StreamError>> for ControlService {
    type Response = h2::ControlResult;
    type Error = ();
    type Future = Ready<Self::Response, Self::Error>;

    /// Always ready.
    #[inline]
    fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Nothing to flush or release, so shutdown completes immediately.
    #[inline]
    fn poll_shutdown(&self, _: &mut Context<'_>, _: bool) -> Poll<()> {
        Poll::Ready(())
    }

    /// Traces the control message and replies with its acknowledgement.
    fn call(&self, msg: h2::ControlMessage<h2::StreamError>) -> Self::Future {
        log::trace!("Control message: {:?}", msg);
        let ack = msg.ack();
        Ready::Ok(ack)
    }
}

/// Turns HTTP/2 stream messages into [`Request`]s for the wrapped service.
struct PublishService<S: Service<Request>> {
    // Application service invoked once a request body is complete.
    service: S,
    // Streams whose headers were accepted and whose body is still arriving.
    streams: RefCell<HashMap<StreamId, Inflight>>,
}

/// State kept for a stream between its headers and its end-of-stream.
struct Inflight {
    // Method name (second `:path` segment).
    name: ByteString,
    // Service name (first `:path` segment).
    service: ByteString,
    // Body chunks received so far.
    data: Data,
    // Request headers; trailers are merged in at end of stream.
    headers: HeaderMap,
}

impl<S> PublishService<S>
where
    S: Service<Request, Response = Response, Error = ServerError>,
{
    /// Wraps `service` with an initially empty table of in-flight streams.
    fn new(service: S) -> Self {
        let streams = RefCell::new(HashMap::default());
        Self { service, streams }
    }
}

/// Translates HTTP/2 stream messages into calls to the wrapped gRPC service.
///
/// Per stream the flow is: `Headers` registers an in-flight entry, `Data`
/// appends body chunks to it, and `Eof` decodes the length-prefixed message,
/// calls the service and writes the reply followed by `grpc-status` trailers.
///
/// Everything that arrives here is controlled by the remote peer, so malformed
/// input is answered with an error response instead of panicking.
impl<S> Service<h2::Message> for PublishService<S>
where
    S: Service<Request, Response = Response, Error = ServerError> + 'static,
{
    type Response = ();
    type Error = h2::StreamError;
    type Future = Either<
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>,
        Ready<Self::Response, Self::Error>,
    >;

    /// Always ready.
    #[inline]
    fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Nothing to flush or release, so shutdown completes immediately.
    #[inline]
    fn poll_shutdown(&self, _: &mut Context<'_>, _: bool) -> Poll<()> {
        Poll::Ready(())
    }

    /// Handles one stream message.
    ///
    /// Returns an immediately-ready future for everything except a completed
    /// request, for which the returned future awaits the service and sends the
    /// response. The only error produced is a stream error reported by the
    /// peer side at end of stream.
    fn call(&self, mut msg: h2::Message) -> Self::Future {
        // Frame prefix: 1 byte compression flag + 4 byte message length.
        const FRAME_PREFIX_LEN: usize = 5;

        let id = msg.id();

        match msg.kind().take() {
            h2::MessageKind::Headers {
                headers,
                pseudo,
                eof,
            } => {
                // Expected path shape is "/<service>/<method>". A missing path, or one
                // that does not start with '/', is treated as "not found" rather than
                // unwrapped/sliced blindly.
                let route = match pseudo.path {
                    Some(mut raw) if raw.starts_with('/') => {
                        let mut rest = raw.split_off(1);
                        rest.find('/').map(|n| (rest.split_to(n), rest))
                    }
                    _ => None,
                };
                let (srvname, path) = match route {
                    Some(route) => route,
                    None => {
                        // not found
                        let _ = msg.stream().send_response(
                            StatusCode::NOT_FOUND,
                            HeaderMap::default(),
                            true,
                        );
                        return Either::Right(Ready::Ok(()));
                    }
                };

                // stream eof, cannot do anything: there is no body to decode
                if eof {
                    if msg
                        .stream()
                        .send_response(StatusCode::OK, HeaderMap::default(), false)
                        .is_ok()
                    {
                        let mut trailers = HeaderMap::default();
                        trailers.insert(consts::GRPC_STATUS, GrpcStatus::InvalidArgument.into());
                        trailers.insert(
                            consts::GRPC_MESSAGE,
                            HeaderValue::from_static("Cannot decode request message"),
                        );
                        msg.stream().send_trailers(trailers);
                    }
                    return Either::Right(Ready::Ok(()));
                }

                // `path` still begins with the '/' that separated service and method.
                let mut path = path;
                let mut path = path.split_off(1);
                let methodname = if let Some(n) = path.find('/') {
                    path.split_to(n)
                } else {
                    path
                };

                let _ = self.streams.borrow_mut().insert(
                    id,
                    Inflight {
                        headers,
                        data: Data::Empty,
                        name: methodname,
                        service: srvname,
                    },
                );
            }
            h2::MessageKind::Data(data, _cap) => {
                // Chunks for streams that were rejected or never registered are dropped.
                if let Some(inflight) = self.streams.borrow_mut().get_mut(&id) {
                    inflight.data.push(data);
                }
            }
            h2::MessageKind::Eof(data) => {
                // Release the RefCell borrow before doing anything else with the entry.
                let inflight = self.streams.borrow_mut().remove(&id);
                let mut inflight = match inflight {
                    Some(inflight) => inflight,
                    // End of a stream that was already answered (e.g. 404) or never
                    // registered; nothing left to do.
                    None => return Either::Right(Ready::Ok(())),
                };

                match data {
                    h2::StreamEof::Data(chunk) => inflight.data.push(chunk),
                    h2::StreamEof::Trailers(hdrs) => {
                        for (name, val) in hdrs.iter() {
                            inflight.headers.insert(name.clone(), val.clone());
                        }
                    }
                    h2::StreamEof::Error(err) => return Either::Right(Ready::Err(err)),
                }

                // Decode the frame only if the prefix and the declared length are
                // actually covered by the bytes received.
                let mut buf = inflight.data.get();
                let payload = if buf.len() >= FRAME_PREFIX_LEN {
                    let _compressed = buf.get_u8();
                    let len = buf.get_u32() as usize;
                    if len <= buf.len() {
                        Some(buf.split_to(len))
                    } else {
                        None
                    }
                } else {
                    None
                };

                if msg
                    .stream()
                    .send_response(StatusCode::OK, HeaderMap::default(), false)
                    .is_err()
                {
                    return Either::Right(Ready::Ok(()));
                }

                let payload = match payload {
                    Some(payload) => payload,
                    None => {
                        // Truncated or inconsistent frame: report it to the client.
                        let mut trailers = HeaderMap::default();
                        trailers.insert(consts::GRPC_STATUS, GrpcStatus::InvalidArgument.into());
                        trailers.insert(
                            consts::GRPC_MESSAGE,
                            HeaderValue::from_static("Cannot decode request message"),
                        );
                        msg.stream().send_trailers(trailers);
                        return Either::Right(Ready::Ok(()));
                    }
                };

                log::debug!("Call service {} method {}", inflight.service, inflight.name);
                let req = Request {
                    payload,
                    name: inflight.name,
                    headers: inflight.headers,
                };

                let fut = self.service.call(req);
                return Either::Left(Box::pin(async move {
                    match fut.await {
                        Ok(res) => {
                            log::debug!("Response is received {:?}", res);
                            let mut buf =
                                BytesMut::with_capacity(res.payload.len() + FRAME_PREFIX_LEN);
                            buf.put_u8(0); // compression
                            // NOTE(review): the length is truncated for payloads over u32::MAX bytes.
                            buf.put_u32(res.payload.len() as u32); // length
                            buf.extend_from_slice(&res.payload);

                            let _ = msg.stream().send_payload(buf.freeze(), false).await;

                            let mut trailers = HeaderMap::default();
                            trailers.insert(consts::GRPC_STATUS, GrpcStatus::Ok.into());
                            msg.stream().send_trailers(trailers);
                        }
                        Err(err) => {
                            let error = format!("Failure during service call: {}", err);
                            log::debug!("{}", error);
                            let mut trailers = HeaderMap::default();
                            trailers.insert(consts::GRPC_STATUS, GrpcStatus::Aborted.into());
                            // The message is skipped if it is not a valid header value.
                            if let Ok(val) = HeaderValue::from_str(&error) {
                                trailers.insert(consts::GRPC_MESSAGE, val);
                            }
                            msg.stream().send_trailers(trailers);
                        }
                    };

                    Ok(())
                }));
            }
            h2::MessageKind::Disconnect(_) | h2::MessageKind::Empty => {
                self.streams.borrow_mut().remove(&id);
            }
        }
        Either::Right(Ready::Ok(()))
    }
}