danube_client/
connection_manager.rs1use crate::errors::{DanubeError, Result};
2
3use std::{
4 collections::{hash_map::Entry, HashMap},
5 sync::Arc,
6};
7use tokio::sync::Mutex;
8use tonic::transport::{Channel, ClientTlsConfig, Uri};
9use tracing::info;
10
11#[derive(Debug, Clone, Eq, Hash, PartialEq)]
13pub struct BrokerAddress {
14 pub connect_url: Uri,
16 pub broker_url: Uri,
18 pub proxy: bool,
20}
21
22#[derive(Debug, Clone)]
23#[allow(dead_code)]
24enum ConnectionStatus {
25 Connected(Arc<RpcConnection>),
26 Disconnected,
27}
28
/// Shared, thread-safe callback that produces a token string each time it is
/// called.
///
/// `ConnectionOptions::resolve_token` invokes it on every call, so the
/// supplier decides whether the value is cached or freshly generated.
pub type TokenSupplier = Arc<dyn Fn() -> String + Send + Sync>;
33
34#[derive(Clone, Default)]
35pub(crate) struct ConnectionOptions {
36 pub(crate) tls_config: Option<ClientTlsConfig>,
37 pub(crate) token: Option<String>,
38 pub(crate) token_supplier: Option<TokenSupplier>,
39 pub(crate) internal_broker: Option<String>,
40 pub(crate) use_tls: bool,
41}
42
43impl ConnectionOptions {
44 pub(crate) fn resolve_token(&self) -> Option<String> {
47 if let Some(ref supplier) = self.token_supplier {
48 Some(supplier())
49 } else {
50 self.token.clone()
51 }
52 }
53}
54
55impl std::fmt::Debug for ConnectionOptions {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 f.debug_struct("ConnectionOptions")
58 .field("tls_config", &self.tls_config)
59 .field("token", &self.token.as_ref().map(|_| "<redacted>"))
60 .field("token_supplier", &self.token_supplier.as_ref().map(|_| "<supplier>"))
61 .field("internal_broker", &self.internal_broker)
62 .field("use_tls", &self.use_tls)
63 .finish()
64 }
65}
66
67#[derive(Debug, Clone)]
68pub struct ConnectionManager {
69 connections: Arc<Mutex<HashMap<BrokerAddress, ConnectionStatus>>>,
70 pub(crate) connection_options: ConnectionOptions,
71}
72
73impl ConnectionManager {
74 pub(crate) fn new(connection_options: ConnectionOptions) -> Self {
75 ConnectionManager {
76 connections: Arc::new(Mutex::new(HashMap::new())),
77 connection_options,
78 }
79 }
80
81 pub(crate) async fn get_connection(
82 &self,
83 broker_url: &Uri,
84 connect_url: &Uri,
85 ) -> Result<Arc<RpcConnection>> {
86 let proxy = broker_url != connect_url;
87 let broker = BrokerAddress {
88 connect_url: connect_url.clone(),
89 broker_url: broker_url.clone(),
90 proxy,
91 };
92
93 let mut cnx = self.connections.lock().await;
94
95 match cnx.entry(broker) {
96 Entry::Occupied(mut occupied_entry) => match occupied_entry.get() {
97 ConnectionStatus::Connected(rpc_cnx) => Ok(rpc_cnx.clone()),
98 ConnectionStatus::Disconnected => {
99 let new_rpc_cnx =
100 new_rpc_connection(&self.connection_options, connect_url).await?;
101 let rpc_cnx = Arc::new(new_rpc_cnx);
102 *occupied_entry.get_mut() = ConnectionStatus::Connected(rpc_cnx.clone());
103 Ok(rpc_cnx)
104 }
105 },
106 Entry::Vacant(vacant_entry) => {
107 let new_rpc_cnx = new_rpc_connection(&self.connection_options, connect_url).await?;
108 let rpc_cnx = Arc::new(new_rpc_cnx);
109 vacant_entry.insert(ConnectionStatus::Connected(rpc_cnx.clone()));
110 Ok(rpc_cnx)
111 }
112 }
113 }
114}
115
116#[derive(Debug, Clone)]
117pub(crate) struct RpcConnection {
118 pub(crate) grpc_cnx: Channel,
119}
120
121pub(crate) async fn new_rpc_connection(
122 cnx_options: &ConnectionOptions,
123 connect_url: &Uri,
124) -> Result<RpcConnection> {
125 info!("Establishing new RPC connection to {}", connect_url);
126
127 let channel = match cnx_options.use_tls {
128 false => {
129 Channel::from_shared(connect_url.to_string())?
131 .connect()
132 .await?
133 }
134 true => {
135 let tls_config = cnx_options.tls_config.as_ref().ok_or_else(|| {
137 DanubeError::Unrecoverable(
138 "TLS is enabled but no TLS config provided. Use with_tls() or with_mtls() before enabling TLS".to_string(),
139 )
140 })?;
141
142 Channel::from_shared(connect_url.to_string())?
143 .tls_config(tls_config.clone())?
144 .connect()
145 .await?
146 }
147 };
148
149 Ok(RpcConnection { grpc_cnx: channel })
150}