// danube_client/connection_manager.rs

1use crate::errors::{DanubeError, Result};
2
3use std::{
4    collections::{hash_map::Entry, HashMap},
5    sync::Arc,
6};
7use tokio::sync::Mutex;
8use tonic::transport::{Channel, ClientTlsConfig, Uri};
9use tracing::info;
10
11/// holds connection information for a broker
12#[derive(Debug, Clone, Eq, Hash, PartialEq)]
13pub struct BrokerAddress {
14    /// URL we're using for connection (can be the proxy's URL)
15    pub connect_url: Uri,
16    /// Danube URL for the broker we're actually contacting
17    pub broker_url: Uri,
18    /// true if the connection is through a proxy
19    pub proxy: bool,
20}
21
22#[derive(Debug, Clone)]
23#[allow(dead_code)]
24enum ConnectionStatus {
25    Connected(Arc<RpcConnection>),
26    Disconnected,
27}
28
/// Callback that yields the current authentication token.
///
/// The callback runs on every request rather than once, so the token can change
/// at runtime. For example, it can be re-read from a file that infrastructure
/// such as Kubernetes projected volumes refreshes periodically.
pub type TokenSupplier = std::sync::Arc<dyn Fn() -> String + Send + Sync + 'static>;
33
34#[derive(Clone, Default)]
35pub(crate) struct ConnectionOptions {
36    pub(crate) tls_config: Option<ClientTlsConfig>,
37    pub(crate) token: Option<String>,
38    pub(crate) token_supplier: Option<TokenSupplier>,
39    pub(crate) internal_broker: Option<String>,
40    pub(crate) use_tls: bool,
41}
42
43impl ConnectionOptions {
44    /// Resolve the current token. If a supplier is set, calls it to get a fresh
45    /// token (enabling runtime rotation). Otherwise falls back to the static token.
46    pub(crate) fn resolve_token(&self) -> Option<String> {
47        if let Some(ref supplier) = self.token_supplier {
48            Some(supplier())
49        } else {
50            self.token.clone()
51        }
52    }
53}
54
55impl std::fmt::Debug for ConnectionOptions {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        f.debug_struct("ConnectionOptions")
58            .field("tls_config", &self.tls_config)
59            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
60            .field("token_supplier", &self.token_supplier.as_ref().map(|_| "<supplier>"))
61            .field("internal_broker", &self.internal_broker)
62            .field("use_tls", &self.use_tls)
63            .finish()
64    }
65}
66
67#[derive(Debug, Clone)]
68pub struct ConnectionManager {
69    connections: Arc<Mutex<HashMap<BrokerAddress, ConnectionStatus>>>,
70    pub(crate) connection_options: ConnectionOptions,
71}
72
73impl ConnectionManager {
74    pub(crate) fn new(connection_options: ConnectionOptions) -> Self {
75        ConnectionManager {
76            connections: Arc::new(Mutex::new(HashMap::new())),
77            connection_options,
78        }
79    }
80
81    pub(crate) async fn get_connection(
82        &self,
83        broker_url: &Uri,
84        connect_url: &Uri,
85    ) -> Result<Arc<RpcConnection>> {
86        let proxy = broker_url != connect_url;
87        let broker = BrokerAddress {
88            connect_url: connect_url.clone(),
89            broker_url: broker_url.clone(),
90            proxy,
91        };
92
93        let mut cnx = self.connections.lock().await;
94
95        match cnx.entry(broker) {
96            Entry::Occupied(mut occupied_entry) => match occupied_entry.get() {
97                ConnectionStatus::Connected(rpc_cnx) => Ok(rpc_cnx.clone()),
98                ConnectionStatus::Disconnected => {
99                    let new_rpc_cnx =
100                        new_rpc_connection(&self.connection_options, connect_url).await?;
101                    let rpc_cnx = Arc::new(new_rpc_cnx);
102                    *occupied_entry.get_mut() = ConnectionStatus::Connected(rpc_cnx.clone());
103                    Ok(rpc_cnx)
104                }
105            },
106            Entry::Vacant(vacant_entry) => {
107                let new_rpc_cnx = new_rpc_connection(&self.connection_options, connect_url).await?;
108                let rpc_cnx = Arc::new(new_rpc_cnx);
109                vacant_entry.insert(ConnectionStatus::Connected(rpc_cnx.clone()));
110                Ok(rpc_cnx)
111            }
112        }
113    }
114}
115
116#[derive(Debug, Clone)]
117pub(crate) struct RpcConnection {
118    pub(crate) grpc_cnx: Channel,
119}
120
121pub(crate) async fn new_rpc_connection(
122    cnx_options: &ConnectionOptions,
123    connect_url: &Uri,
124) -> Result<RpcConnection> {
125    info!("Establishing new RPC connection to {}", connect_url);
126
127    let channel = match cnx_options.use_tls {
128        false => {
129            // Plain TCP connection
130            Channel::from_shared(connect_url.to_string())?
131                .connect()
132                .await?
133        }
134        true => {
135            // TLS is enabled, tls_config must be present
136            let tls_config = cnx_options.tls_config.as_ref().ok_or_else(|| {
137                DanubeError::Unrecoverable(
138                    "TLS is enabled but no TLS config provided. Use with_tls() or with_mtls() before enabling TLS".to_string(),
139                )
140            })?;
141
142            Channel::from_shared(connect_url.to_string())?
143                .tls_config(tls_config.clone())?
144                .connect()
145                .await?
146        }
147    };
148
149    Ok(RpcConnection { grpc_cnx: channel })
150}