batata-client 0.0.2

Rust client for Batata/Nacos service discovery and configuration management
Documentation
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use parking_lot::RwLock;
use tokio::time::timeout;
use tonic::transport::Channel;
use tracing::{debug, warn};

use crate::api::remote::{RequestTrait, ResponseTrait};
use crate::api::{request_client::RequestClient as GrpcRequestClient, Payload};
use crate::common::DEFAULT_TIMEOUT_MS;
use crate::error::{BatataError, Result};
use crate::remote::{GrpcConnection, ServerListManager};

/// RPC client used to talk to a Batata/Nacos server over gRPC.
///
/// Construct it with [`RpcClient::new`] and refine it with the `with_*`
/// builder methods. A live connection only exists after a successful
/// `start()` and is released again by `stop()`.
pub struct RpcClient {
    /// Manager over the configured server addresses; connections are opened
    /// against its current server.
    server_list: Arc<ServerListManager>,
    /// The active connection: `None` before `start()` and after `stop()`.
    connection: Arc<RwLock<Option<Arc<GrpcConnection>>>>,
    /// Namespace handed to the connection when it is created.
    namespace: String,
    /// Application name handed to the connection when it is created.
    app_name: String,
    /// Client labels handed to the connection when it is created.
    labels: HashMap<String, String>,
    /// Timeout in milliseconds: passed to the connection, and used by
    /// `unary_request` for both the connect and the response deadline.
    timeout_ms: u64,
    /// Number of extra attempts `request` makes after a retryable failure.
    retry_times: u32,
}

impl RpcClient {
    /// Create a new RPC client for the given server addresses.
    ///
    /// The client is not connected until [`RpcClient::start`] is called.
    /// Defaults: empty namespace, app name and labels; `DEFAULT_TIMEOUT_MS`;
    /// 3 retries.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `ServerListManager::new` for the given
    /// addresses.
    pub fn new(server_addresses: Vec<String>) -> Result<Self> {
        let server_list = Arc::new(ServerListManager::new(server_addresses)?);

        Ok(Self {
            server_list,
            connection: Arc::new(RwLock::new(None)),
            namespace: String::new(),
            app_name: String::new(),
            labels: HashMap::new(),
            timeout_ms: DEFAULT_TIMEOUT_MS,
            retry_times: 3,
        })
    }

    /// Set the namespace passed to the connection on `start()`.
    pub fn with_namespace(mut self, namespace: &str) -> Self {
        self.namespace = namespace.to_string();
        self
    }

    /// Set the application name passed to the connection on `start()`.
    pub fn with_app_name(mut self, app_name: &str) -> Self {
        self.app_name = app_name.to_string();
        self
    }

    /// Set the labels passed to the connection on `start()`.
    pub fn with_labels(mut self, labels: HashMap<String, String>) -> Self {
        self.labels = labels;
        self
    }

    /// Set the timeout in milliseconds.
    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
        self.timeout_ms = timeout_ms;
        self
    }

    /// Set how many times `request` retries after a retryable failure.
    pub fn with_retry(mut self, retry_times: u32) -> Self {
        self.retry_times = retry_times;
        self
    }

    /// Start the client by connecting to the current server.
    ///
    /// If the client was already started, the new connection replaces the old
    /// one, and the old one is disconnected rather than leaked.
    ///
    /// # Errors
    ///
    /// Returns the error from `GrpcConnection::connect` if the connection
    /// cannot be established. In that case any previous connection is kept.
    pub async fn start(&self) -> Result<()> {
        let server = self.server_list.current_server().clone();

        let connection = GrpcConnection::new(server)
            .with_namespace(&self.namespace)
            .with_app_name(&self.app_name)
            .with_labels(self.labels.clone())
            .with_timeout(self.timeout_ms);

        connection.connect().await?;

        // The write guard is a temporary dropped at the end of this statement,
        // so it is not held across the `.await` below.
        let previous = self.connection.write().replace(Arc::new(connection));
        if let Some(previous) = previous {
            previous.disconnect().await;
        }

        Ok(())
    }

    /// Stop the client and disconnect the current connection, if any.
    pub async fn stop(&self) {
        let conn = self.connection.write().take();
        if let Some(conn) = conn {
            conn.disconnect().await;
        }
    }

    /// Return whether a connection exists and reports itself as connected.
    pub fn is_connected(&self) -> bool {
        self.connection
            .read()
            .as_ref()
            .map(|c| c.is_connected())
            .unwrap_or(false)
    }

    /// Return the ID of the current connection, or `None` if not started.
    pub fn connection_id(&self) -> Option<String> {
        self.connection.read().as_ref().map(|c| c.connection_id())
    }

    /// Send a request over the started connection, retrying retryable failures.
    ///
    /// Makes at most `retry_times + 1` attempts. Before retry number `k` it
    /// sleeps `100 * k` milliseconds. Errors for which `is_retryable()` is false
    /// are returned immediately.
    ///
    /// # Errors
    ///
    /// - `BatataError::ClientNotStarted` if `start()` has not succeeded, or
    ///   `stop()` was called.
    /// - Otherwise, the error from the final attempt.
    pub async fn request<Req, Resp>(&self, request: &Req) -> Result<Resp>
    where
        Req: RequestTrait + serde::Serialize + Clone,
        Resp: for<'de> serde::Deserialize<'de> + Default + ResponseTrait,
    {
        let connection = self
            .connection
            .read()
            .clone()
            .ok_or(BatataError::ClientNotStarted)?;

        // Number of retries performed so far.
        let mut attempt: u32 = 0;
        loop {
            match connection.request::<Req, Resp>(request).await {
                Ok(resp) => return Ok(resp),
                Err(e) if e.is_retryable() && attempt < self.retry_times => {
                    warn!("Request failed, will retry: {}", e);
                    attempt += 1;
                    // Linear backoff: 100 ms before the first retry, 200 ms before the second, ...
                    tokio::time::sleep(Duration::from_millis(100 * u64::from(attempt))).await;
                    debug!("Retry attempt {} for request", attempt);
                }
                Err(e) => return Err(e),
            }
        }
    }

    /// Send a payload over the started connection without retrying.
    ///
    /// # Errors
    ///
    /// - `BatataError::ClientNotStarted` if there is no active connection.
    /// - Otherwise, whatever `GrpcConnection::send` returns.
    pub async fn send(&self, payload: Payload) -> Result<()> {
        let connection = self
            .connection
            .read()
            .clone()
            .ok_or(BatataError::ClientNotStarted)?;

        connection.send(payload).await
    }

    /// Return a handle to the current connection, if one is active.
    pub fn connection(&self) -> Option<Arc<GrpcConnection>> {
        self.connection.read().clone()
    }

    /// Send one request over a freshly opened channel to the current server.
    ///
    /// This bypasses the managed connection, so `start()` is not required.
    /// A response with no body, or an empty body, is treated as
    /// `Resp::default()`.
    ///
    /// # Errors
    ///
    /// - A connection error if the endpoint is invalid, or the response body is
    ///   present but cannot be decoded as JSON into `Resp`.
    /// - A transport error if connecting fails.
    /// - `BatataError::Timeout` if no response arrives within `timeout_ms`.
    /// - The gRPC status error if the call itself fails.
    /// - A server error if the decoded response is not successful.
    pub async fn unary_request<Req, Resp>(&self, request: &Req) -> Result<Resp>
    where
        Req: RequestTrait + serde::Serialize,
        Resp: for<'de> serde::Deserialize<'de> + Default + ResponseTrait,
    {
        let server = self.server_list.current_server();
        let endpoint = server.grpc_endpoint();

        let channel = Channel::from_shared(endpoint)
            .map_err(|e| BatataError::connection_error(format!("Invalid endpoint: {}", e)))?
            .connect_timeout(Duration::from_millis(self.timeout_ms))
            .connect()
            .await?;

        let mut client = GrpcRequestClient::new(channel);

        let client_ip = get_local_ip();
        let payload = request.to_payload(&client_ip);

        let response = timeout(
            Duration::from_millis(self.timeout_ms),
            client.request(payload),
        )
        .await
        .map_err(|_| BatataError::Timeout {
            timeout_ms: self.timeout_ms,
        })??;

        let payload = response.into_inner();

        // Report a malformed body explicitly. Previously it was silently
        // replaced by `Resp::default()`, which produced a misleading result.
        let resp: Resp = match payload.body.as_ref() {
            Some(body) if !body.value.is_empty() => {
                serde_json::from_slice(&body.value).map_err(|e| {
                    BatataError::connection_error(format!(
                        "Failed to decode response body: {}",
                        e
                    ))
                })?
            }
            _ => Resp::default(),
        };

        if !resp.is_success() {
            return Err(BatataError::server_error(resp.error_code(), resp.message()));
        }

        Ok(resp)
    }
}

/// Get local IP address
fn get_local_ip() -> String {
    if let Ok(addrs) = if_addrs::get_if_addrs() {
        for iface in addrs {
            if !iface.is_loopback()
                && let std::net::IpAddr::V4(ipv4) = iface.ip()
            {
                return ipv4.to_string();
            }
        }
    }
    "127.0.0.1".to_string()
}