//! batata-client 0.0.2
//!
//! Rust client for Batata/Nacos service discovery and configuration management.
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use futures::StreamExt;
use parking_lot::RwLock;
use tokio::sync::mpsc;
use tokio::time::timeout;
use tonic::transport::Channel;
use tracing::{debug, error, info};

use crate::api::remote::{
    ConnectionSetupRequest, RequestTrait, ResponseTrait, ServerCheckRequest, ServerCheckResponse,
};
use crate::api::{
    bi_request_stream_client::BiRequestStreamClient, request_client::RequestClient, Payload,
};
use crate::common::{DEFAULT_HEARTBEAT_INTERVAL_MS, DEFAULT_TIMEOUT_MS, LABEL_APP_NAME, LABEL_SOURCE, LABEL_SOURCE_SDK};
use crate::error::{BatataError, Result};
use crate::remote::ServerAddress;

/// Lifecycle state of a gRPC connection to the server.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No stream is open.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The bidirectional stream is established.
    Connected,
    /// A reconnect is in progress (not assigned by `GrpcConnection` in this module).
    Reconnecting,
}

/// Callback invoked for a payload pushed by the server.
///
/// It receives the pushed `Payload`. Returning `Some(reply)` asks the connection to send
/// `reply` back over the stream; returning `None` sends nothing.
pub type ServerPushHandler = Arc<dyn Fn(Payload) -> Option<Payload> + Sync + Send>;

/// gRPC connection manager for bidirectional streaming
pub struct GrpcConnection {
    server_address: ServerAddress,
    connection_id: Arc<RwLock<String>>,
    state: Arc<RwLock<ConnectionState>>,
    client_ip: String,
    namespace: String,
    app_name: String,
    labels: HashMap<String, String>,

    // Channel sender for outgoing messages
    sender: Arc<RwLock<Option<mpsc::Sender<Payload>>>>,

    // Handler for server push messages
    push_handlers: Arc<RwLock<HashMap<String, ServerPushHandler>>>,

    // Timeout settings
    timeout_ms: u64,
    #[allow(dead_code)]
    heartbeat_interval_ms: u64,
}

impl GrpcConnection {
    /// Create a new gRPC connection
    pub fn new(server_address: ServerAddress) -> Self {
        Self {
            server_address,
            connection_id: Arc::new(RwLock::new(String::new())),
            state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
            client_ip: get_local_ip(),
            namespace: String::new(),
            app_name: String::new(),
            labels: HashMap::new(),
            sender: Arc::new(RwLock::new(None)),
            push_handlers: Arc::new(RwLock::new(HashMap::new())),
            timeout_ms: DEFAULT_TIMEOUT_MS,
            heartbeat_interval_ms: DEFAULT_HEARTBEAT_INTERVAL_MS,
        }
    }

    /// Set namespace
    pub fn with_namespace(mut self, namespace: &str) -> Self {
        self.namespace = namespace.to_string();
        self
    }

    /// Set app name
    pub fn with_app_name(mut self, app_name: &str) -> Self {
        self.app_name = app_name.to_string();
        self
    }

    /// Set labels
    pub fn with_labels(mut self, labels: HashMap<String, String>) -> Self {
        self.labels = labels;
        self
    }

    /// Set timeout
    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
        self.timeout_ms = timeout_ms;
        self
    }

    /// Get connection ID
    pub fn connection_id(&self) -> String {
        self.connection_id.read().clone()
    }

    /// Get current state
    pub fn state(&self) -> ConnectionState {
        *self.state.read()
    }

    /// Check if connected
    pub fn is_connected(&self) -> bool {
        *self.state.read() == ConnectionState::Connected
    }

    /// Register a push handler for a specific message type
    pub fn register_push_handler(&self, message_type: &str, handler: ServerPushHandler) {
        self.push_handlers
            .write()
            .insert(message_type.to_string(), handler);
    }

    /// Connect to server
    pub async fn connect(&self) -> Result<()> {
        {
            let mut state = self.state.write();
            if *state == ConnectionState::Connected {
                return Ok(());
            }
            *state = ConnectionState::Connecting;
        }

        let endpoint = self.server_address.grpc_endpoint();
        info!("Connecting to server: {}", endpoint);

        // Create channel
        let channel = Channel::from_shared(endpoint.clone())
            .map_err(|e| BatataError::connection_error(format!("Invalid endpoint: {}", e)))?
            .connect_timeout(Duration::from_millis(self.timeout_ms))
            .connect()
            .await?;

        // First, do server check using unary RPC
        let connection_id = self.server_check(&channel).await?;
        *self.connection_id.write() = connection_id.clone();

        // Create bidirectional stream
        let mut bi_client = BiRequestStreamClient::new(channel.clone());

        // Create channel for sending messages
        let (tx, rx) = mpsc::channel::<Payload>(100);
        *self.sender.write() = Some(tx.clone());

        // Convert receiver to stream
        let outbound = tokio_stream::wrappers::ReceiverStream::new(rx);

        // Start bidirectional stream
        let response = bi_client.request_bi_stream(outbound).await?;
        let mut inbound = response.into_inner();

        // Send connection setup request
        self.send_connection_setup(&tx).await?;

        // Spawn task to handle incoming messages
        let push_handlers = self.push_handlers.clone();
        let state = self.state.clone();
        let tx_clone = tx.clone();

        tokio::spawn(async move {
            while let Some(result) = inbound.next().await {
                match result {
                    Ok(payload) => {
                        if let Some(metadata) = &payload.metadata {
                            let msg_type = &metadata.r#type;
                            debug!("Received message type: {}", msg_type);

                            // Handle server push - clone handler before await
                            let handler_opt = {
                                let handlers = push_handlers.read();
                                handlers.get(msg_type).cloned()
                            };

                            if let Some(handler) = handler_opt
                                && let Some(response) = handler(payload)
                                && let Err(e) = tx_clone.send(response).await
                            {
                                error!("Failed to send response: {}", e);
                            }
                        }
                    }
                    Err(e) => {
                        error!("Stream error: {}", e);
                        *state.write() = ConnectionState::Disconnected;
                        break;
                    }
                }
            }
        });

        // Update state
        *self.state.write() = ConnectionState::Connected;
        info!(
            "Connected to server, connection_id: {}",
            self.connection_id()
        );

        Ok(())
    }

    /// Disconnect from server
    pub async fn disconnect(&self) {
        *self.state.write() = ConnectionState::Disconnected;
        *self.sender.write() = None;
        info!("Disconnected from server");
    }

    /// Send a request and wait for response
    pub async fn request<Req, Resp>(&self, request: &Req) -> Result<Resp>
    where
        Req: RequestTrait + serde::Serialize,
        Resp: for<'de> serde::Deserialize<'de> + Default + ResponseTrait,
    {
        if !self.is_connected() {
            return Err(BatataError::ClientNotStarted);
        }

        let payload = request.to_payload(&self.client_ip);

        // Use unary request
        let channel = Channel::from_shared(self.server_address.grpc_endpoint())
            .map_err(|e| BatataError::connection_error(format!("Invalid endpoint: {}", e)))?
            .connect_timeout(Duration::from_millis(self.timeout_ms))
            .connect()
            .await?;

        let mut client = RequestClient::new(channel);

        let response = timeout(
            Duration::from_millis(self.timeout_ms),
            client.request(payload),
        )
        .await
        .map_err(|_| BatataError::Timeout {
            timeout_ms: self.timeout_ms,
        })??;

        let payload = response.into_inner();

        // Deserialize response
        let resp: Resp = payload
            .body
            .as_ref()
            .and_then(|body| serde_json::from_slice(&body.value).ok())
            .unwrap_or_default();

        if !resp.is_success() {
            return Err(BatataError::server_error(resp.error_code(), resp.message()));
        }

        Ok(resp)
    }

    /// Send a message through the stream (no response expected)
    pub async fn send(&self, payload: Payload) -> Result<()> {
        let sender = self.sender.read().clone();
        let sender = sender.ok_or(BatataError::ClientNotStarted)?;

        sender
            .send(payload)
            .await
            .map_err(|e| BatataError::connection_error(format!("Failed to send: {}", e)))
    }

    /// Server check using unary RPC
    async fn server_check(&self, channel: &Channel) -> Result<String> {
        let mut client = RequestClient::new(channel.clone());

        let request = ServerCheckRequest::new();
        let payload = request.to_payload(&self.client_ip);

        let response = timeout(
            Duration::from_millis(self.timeout_ms),
            client.request(payload),
        )
        .await
        .map_err(|_| BatataError::Timeout {
            timeout_ms: self.timeout_ms,
        })??;

        let payload = response.into_inner();
        let resp: ServerCheckResponse = payload
            .body
            .as_ref()
            .and_then(|body| serde_json::from_slice(&body.value).ok())
            .unwrap_or_default();

        if !resp.is_success() {
            return Err(BatataError::server_error(resp.error_code(), resp.message()));
        }

        Ok(resp.connection_id)
    }

    /// Send connection setup request
    async fn send_connection_setup(&self, sender: &mpsc::Sender<Payload>) -> Result<()> {
        let mut labels = self.labels.clone();
        labels.insert(LABEL_SOURCE.to_string(), LABEL_SOURCE_SDK.to_string());
        if !self.app_name.is_empty() {
            labels.insert(LABEL_APP_NAME.to_string(), self.app_name.clone());
        }

        let request = ConnectionSetupRequest::new()
            .with_labels(labels)
            .with_tenant(self.namespace.clone());

        let payload = request.to_payload(&self.client_ip);

        sender
            .send(payload)
            .await
            .map_err(|e| BatataError::connection_error(format!("Failed to send setup: {}", e)))
    }
}

/// Get local IP address
fn get_local_ip() -> String {
    if let Ok(addrs) = if_addrs::get_if_addrs() {
        for iface in addrs {
            if !iface.is_loopback()
                && let std::net::IpAddr::V4(ipv4) = iface.ip()
            {
                return ipv4.to_string();
            }
        }
    }
    "127.0.0.1".to_string()
}