danube-client 0.12.0

The async client for the Danube Messaging Broker platform.
Documentation
use crate::errors::{DanubeError, Result};

use std::{
    collections::{hash_map::Entry, HashMap},
    sync::Arc,
};
use tokio::sync::Mutex;
use tonic::transport::{Channel, ClientTlsConfig, Uri};
use tracing::info;

/// Addressing details for a single broker connection.
///
/// Used as the key of the connection cache in `ConnectionManager`, so
/// lookups with the same broker/endpoint pair map to the same cached entry.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct BrokerAddress {
    /// Endpoint the gRPC channel actually dials (may be a proxy's URL).
    pub connect_url: Uri,
    /// URL of the Danube broker that is ultimately being contacted.
    pub broker_url: Uri,
    /// Whether the connection goes through a proxy. `get_connection` sets
    /// this when `connect_url` differs from `broker_url`.
    pub proxy: bool,
}

/// Cached state of one broker entry in the connection manager.
// `Disconnected` is matched in `get_connection` but is not constructed
// anywhere in the code visible here, hence the lint allowance.
#[allow(dead_code)]
#[derive(Clone, Debug)]
enum ConnectionStatus {
    /// A live channel that is shared with every caller asking for this broker.
    Connected(Arc<RpcConnection>),
    /// No usable channel; the next lookup establishes a new one.
    Disconnected,
}

/// Callback that yields the authentication token as a `String`.
///
/// It is invoked on every request rather than once at startup. The token can
/// therefore change at runtime, e.g. when the closure reads a file that
/// infrastructure such as Kubernetes projected volumes periodically rewrites.
pub type TokenSupplier = Arc<dyn Fn() -> String + Send + Sync + 'static>;

/// Settings applied to every connection opened by the connection manager.
#[derive(Default, Clone)]
pub(crate) struct ConnectionOptions {
    /// TLS settings; must be present when `use_tls` is `true`.
    pub(crate) tls_config: Option<ClientTlsConfig>,
    /// Static token, used only when no `token_supplier` is configured.
    pub(crate) token: Option<String>,
    /// Dynamic token source; takes precedence over `token`.
    pub(crate) token_supplier: Option<TokenSupplier>,
    // NOTE(review): not read in this file — presumably an internal broker
    // address/identifier; confirm against callers.
    pub(crate) internal_broker: Option<String>,
    /// Whether new channels are established over TLS.
    pub(crate) use_tls: bool,
}

impl ConnectionOptions {
    /// Returns the token callers should use right now, if any.
    ///
    /// When a `token_supplier` is configured it wins and is called afresh on
    /// each invocation, so rotated tokens are picked up at runtime. Otherwise
    /// the static `token` is cloned (possibly `None`).
    pub(crate) fn resolve_token(&self) -> Option<String> {
        match &self.token_supplier {
            Some(supplier) => Some(supplier()),
            None => self.token.clone(),
        }
    }
}

impl std::fmt::Debug for ConnectionOptions {
    /// Formats the options without leaking credentials: a configured static
    /// token prints as `<redacted>` and a configured supplier as `<supplier>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to the struct breaks the
        // build here, so nobody can forget to decide whether it is secret.
        let Self {
            tls_config,
            token,
            token_supplier,
            internal_broker,
            use_tls,
        } = self;

        let token_display = token.as_ref().map(|_| "<redacted>");
        let supplier_display = token_supplier.as_ref().map(|_| "<supplier>");

        f.debug_struct("ConnectionOptions")
            .field("tls_config", tls_config)
            .field("token", &token_display)
            .field("token_supplier", &supplier_display)
            .field("internal_broker", internal_broker)
            .field("use_tls", use_tls)
            .finish()
    }
}

/// Caches one gRPC connection per broker address so repeated lookups for the
/// same broker reuse a single channel.
///
/// Clones share the same cache, because it lives behind an `Arc`.
#[derive(Clone, Debug)]
pub struct ConnectionManager {
    /// Known broker addresses and their connection state.
    connections: Arc<Mutex<HashMap<BrokerAddress, ConnectionStatus>>>,
    /// Options used whenever a new connection has to be established.
    pub(crate) connection_options: ConnectionOptions,
}

impl ConnectionManager {
    /// Creates a manager with an empty connection cache.
    pub(crate) fn new(connection_options: ConnectionOptions) -> Self {
        ConnectionManager {
            connections: Arc::new(Mutex::new(HashMap::new())),
            connection_options,
        }
    }

    /// Returns a shared connection to `broker_url`, dialed via `connect_url`.
    ///
    /// A cached `Connected` entry is reused as-is. Otherwise a new channel is
    /// established with the manager's options and stored in the cache. The
    /// cache key is marked as proxied when the two URLs differ.
    ///
    /// The cache lock is deliberately *not* held while dialing. Connecting is
    /// network I/O and can take a long time; holding the lock would stall every
    /// other lookup, including lookups for brokers that are already connected.
    /// If two tasks race to connect the same broker, the first one to store its
    /// channel wins. Everyone gets that channel, and the loser's channel is
    /// dropped.
    ///
    /// # Errors
    ///
    /// Propagates any error from `new_rpc_connection`; the cache is left
    /// untouched in that case.
    pub(crate) async fn get_connection(
        &self,
        broker_url: &Uri,
        connect_url: &Uri,
    ) -> Result<Arc<RpcConnection>> {
        let broker = BrokerAddress {
            connect_url: connect_url.clone(),
            broker_url: broker_url.clone(),
            proxy: broker_url != connect_url,
        };

        // Fast path: reuse an established channel. The guard is dropped at the
        // end of this scope, before any network I/O happens.
        {
            let cnx = self.connections.lock().await;
            if let Some(ConnectionStatus::Connected(rpc_cnx)) = cnx.get(&broker) {
                return Ok(Arc::clone(rpc_cnx));
            }
        }

        // Slow path: dial without holding the cache lock.
        let fresh = Arc::new(new_rpc_connection(&self.connection_options, connect_url).await?);

        let mut cnx = self.connections.lock().await;
        match cnx.entry(broker) {
            Entry::Occupied(mut occupied_entry) => {
                if let ConnectionStatus::Connected(existing) = occupied_entry.get() {
                    // Another task connected while we were dialing; keep its
                    // channel so all callers share one connection.
                    return Ok(Arc::clone(existing));
                }
                *occupied_entry.get_mut() = ConnectionStatus::Connected(Arc::clone(&fresh));
            }
            Entry::Vacant(vacant_entry) => {
                vacant_entry.insert(ConnectionStatus::Connected(Arc::clone(&fresh)));
            }
        }
        Ok(fresh)
    }
}

/// An established gRPC connection to a single endpoint.
#[derive(Clone, Debug)]
pub(crate) struct RpcConnection {
    /// Underlying tonic transport channel used by the generated RPC clients.
    pub(crate) grpc_cnx: Channel,
}

/// Opens a new gRPC channel to `connect_url` using `cnx_options`.
///
/// When `use_tls` is set, the configured `tls_config` is applied to the
/// endpoint. Otherwise a plain (non-TLS) connection is made and any
/// `tls_config` is ignored.
///
/// # Errors
///
/// - `DanubeError::Unrecoverable` if TLS is enabled but no TLS config is
///   present (checked before anything else).
/// - Otherwise, errors from parsing `connect_url`, applying the TLS config,
///   or connecting are converted via `?`.
pub(crate) async fn new_rpc_connection(
    cnx_options: &ConnectionOptions,
    connect_url: &Uri,
) -> Result<RpcConnection> {
    info!("Establishing new RPC connection to {}", connect_url);

    // Resolve TLS settings first so a misconfiguration is reported before the
    // endpoint is even built.
    let tls = if cnx_options.use_tls {
        let cfg = cnx_options.tls_config.as_ref().ok_or_else(|| {
            DanubeError::Unrecoverable(
                "TLS is enabled but no TLS config provided. Use with_tls() or with_mtls() before enabling TLS".to_string(),
            )
        })?;
        Some(cfg.clone())
    } else {
        None
    };

    let endpoint = Channel::from_shared(connect_url.to_string())?;
    let endpoint = match tls {
        Some(cfg) => endpoint.tls_config(cfg)?,
        None => endpoint,
    };
    let grpc_cnx = endpoint.connect().await?;

    Ok(RpcConnection { grpc_cnx })
}