ceresdb-client 1.0.2

Rust implementation of CeresDB client.
Documentation
// Copyright 2022 CeresDB Project Authors. Licensed under Apache-2.0.

//! Rpc client impl

use std::{sync::Arc, time::Duration};

use async_trait::async_trait;
use ceresdbproto::{
    common::ResponseHeader,
    storage::{
        storage_service_client::StorageServiceClient, RouteRequest as RouteRequestPb,
        RouteResponse as RouteResponsePb, SqlQueryRequest, SqlQueryResponse,
        WriteRequest as WriteRequestPb, WriteResponse as WriteResponsePb,
    },
};
use tonic::{
    transport::{Channel, Endpoint},
    Request,
};

use crate::{
    config::RpcConfig,
    errors::{Error, Result, ServerError},
    rpc_client::{RpcClient, RpcClientFactory, RpcContext},
    util::is_ok,
};

struct RpcClientImpl {
    channel: Channel,
    default_read_timeout: Duration,
    default_write_timeout: Duration,
}

impl RpcClientImpl {
    fn new(
        channel: Channel,
        default_read_timeout: Duration,
        default_write_timeout: Duration,
    ) -> Self {
        Self {
            channel,
            default_read_timeout,
            default_write_timeout,
        }
    }

    fn check_status(header: ResponseHeader) -> Result<()> {
        if !is_ok(header.code) {
            return Err(Error::Server(ServerError {
                code: header.code,
                msg: header.error,
            }));
        }

        Ok(())
    }

    fn make_request<T>(ctx: &RpcContext, req: T, default_timeout: Duration) -> Request<T> {
        let timeout = ctx.timeout.unwrap_or(default_timeout);
        let mut req = Request::new(req);
        req.set_timeout(timeout);
        req
    }

    fn make_query_request<T>(&self, ctx: &RpcContext, req: T) -> Request<T> {
        Self::make_request(ctx, req, self.default_read_timeout)
    }

    fn make_write_request<T>(&self, ctx: &RpcContext, req: T) -> Request<T> {
        Self::make_request(ctx, req, self.default_write_timeout)
    }
}

#[async_trait]
impl RpcClient for RpcClientImpl {
    async fn sql_query(&self, ctx: &RpcContext, req: SqlQueryRequest) -> Result<SqlQueryResponse> {
        let mut client = StorageServiceClient::<Channel>::new(self.channel.clone());

        let resp = client
            .sql_query(self.make_query_request(ctx, req))
            .await
            .map_err(Error::Rpc)?;
        let mut resp = resp.into_inner();

        if let Some(header) = resp.header.take() {
            Self::check_status(header)?;
        }

        Ok(resp)
    }

    async fn write(&self, ctx: &RpcContext, req: WriteRequestPb) -> Result<WriteResponsePb> {
        let mut client = StorageServiceClient::<Channel>::new(self.channel.clone());

        let resp = client
            .write(self.make_write_request(ctx, req))
            .await
            .map_err(Error::Rpc)?;
        let mut resp = resp.into_inner();

        if let Some(header) = resp.header.take() {
            Self::check_status(header)?;
        }

        Ok(resp)
    }

    async fn route(&self, ctx: &RpcContext, req: RouteRequestPb) -> Result<RouteResponsePb> {
        let mut client = StorageServiceClient::<Channel>::new(self.channel.clone());

        // use the write timeout for the route request.
        let route_req = Self::make_request(ctx, req, self.default_write_timeout);
        let resp = client.route(route_req).await.map_err(Error::Rpc)?;
        let mut resp = resp.into_inner();

        if let Some(header) = resp.header.take() {
            Self::check_status(header)?;
        }

        Ok(resp)
    }
}

/// Factory that connects new `RpcClient`s using the settings in an `RpcConfig`.
pub struct RpcClientImplFactory {
    /// Connection and timeout settings applied to every client it builds.
    rpc_config: RpcConfig,
}

impl RpcClientImplFactory {
    /// Creates a factory that configures every built client with `rpc_config`.
    pub fn new(rpc_config: RpcConfig) -> Self {
        RpcClientImplFactory { rpc_config }
    }

    /// Prepends the `http://` scheme to a bare `{ip_addr}:{port}` endpoint so
    /// it can be handed to `Endpoint::from_shared` as a full URI.
    #[inline]
    fn make_endpoint_with_scheme(endpoint: &str) -> String {
        ["http://", endpoint].concat()
    }
}

#[async_trait]
impl RpcClientFactory for RpcClientImplFactory {
    /// Connects to `endpoint` and returns a client bound to the new channel.
    ///
    /// The endpoint should be in the form: `{ip_addr}:{port}`; the `http://`
    /// scheme is added here.
    ///
    /// # Errors
    /// `Error::Connect` if the endpoint cannot be parsed as a URI or the
    /// connection cannot be established.
    async fn build(&self, endpoint: String) -> Result<Arc<dyn RpcClient>> {
        let config = &self.rpc_config;

        let endpoint_with_scheme = Self::make_endpoint_with_scheme(&endpoint);
        // The connect timeout applies regardless of keep-alive mode, so set it once.
        let configured_endpoint = Endpoint::from_shared(endpoint_with_scheme)
            .map_err(|e| Error::Connect {
                addr: endpoint.clone(),
                source: Box::new(e),
            })?
            .connect_timeout(config.connect_timeout);

        // Keep-alive timeout and interval from the config are only applied when
        // idle connections should be kept alive; otherwise only
        // `keep_alive_while_idle(false)` is set.
        let configured_endpoint = if config.keep_alive_while_idle {
            configured_endpoint
                .keep_alive_timeout(config.keep_alive_timeout)
                .keep_alive_while_idle(true)
                .http2_keep_alive_interval(config.keep_alive_interval)
        } else {
            configured_endpoint.keep_alive_while_idle(false)
        };

        let channel = configured_endpoint
            .connect()
            .await
            .map_err(|e| Error::Connect {
                addr: endpoint,
                source: Box::new(e),
            })?;

        Ok(Arc::new(RpcClientImpl::new(
            channel,
            config.default_sql_query_timeout,
            config.default_write_timeout,
        )))
    }
}