Skip to main content

orlando_cluster/
connection_pool.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use dashmap::DashMap;
5
6use crate::error::ClusterError;
7use crate::proto::cluster_gateway_client::ClusterGatewayClient;
8use crate::proto::grain_transport_client::GrainTransportClient;
9use crate::proto::membership_client::MembershipClient;
10
/// Upper bound on the time allowed to establish a fresh gRPC connection
/// (applied through `Endpoint::connect_timeout` when a channel is opened).
const CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
13
/// Per-endpoint cache of the three gRPC client types used by the cluster.
///
/// Each map is keyed by the endpoint string passed to the matching `get_*`
/// method. Every connection the pool opens shares the same optional TLS
/// configuration, and the optional auth token is attached to requests built
/// with `authorized_request`.
pub struct ConnectionPool {
    /// Cached `GrainTransportClient`s, keyed by endpoint.
    transports: DashMap<String, GrainTransportClient<tonic::transport::Channel>>,
    /// Cached `MembershipClient`s, keyed by endpoint.
    memberships: DashMap<String, MembershipClient<tonic::transport::Channel>>,
    /// Cached `ClusterGatewayClient`s, keyed by endpoint.
    gateways: DashMap<String, ClusterGatewayClient<tonic::transport::Channel>>,
    /// When `Some`, new connections use `https://` plus this TLS config;
    /// when `None`, they use plain `http://`.
    tls_config: Option<tonic::transport::ClientTlsConfig>,
    /// Sent verbatim as `authorization` metadata by `authorized_request`.
    auth_token: Option<String>,
}
21
22impl Default for ConnectionPool {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl ConnectionPool {
29    pub fn new() -> Self {
30        Self {
31            transports: DashMap::new(),
32            memberships: DashMap::new(),
33            gateways: DashMap::new(),
34            tls_config: None,
35            auth_token: None,
36        }
37    }
38
39    pub fn with_tls(tls: tonic::transport::ClientTlsConfig) -> Self {
40        Self {
41            tls_config: Some(tls),
42            ..Self::new()
43        }
44    }
45
46    pub fn with_tls_and_auth(tls: tonic::transport::ClientTlsConfig, token: String) -> Self {
47        Self {
48            tls_config: Some(tls),
49            auth_token: Some(token),
50            ..Self::new()
51        }
52    }
53
54    pub fn with_auth(token: String) -> Self {
55        Self {
56            auth_token: Some(token),
57            ..Self::new()
58        }
59    }
60
61    /// Wrap a value in a `tonic::Request` with the auth token attached (if configured).
62    pub fn authorized_request<T>(&self, inner: T) -> tonic::Request<T> {
63        let mut request = tonic::Request::new(inner);
64        if let Some(ref token) = self.auth_token
65            && let Ok(value) = token.parse()
66        {
67            request.metadata_mut().insert("authorization", value);
68        }
69        request
70    }
71
72    async fn connect_channel(&self, endpoint: &str) -> Result<tonic::transport::Channel, ClusterError> {
73        let uri = if self.tls_config.is_some() {
74            format!("https://{}", endpoint)
75        } else {
76            format!("http://{}", endpoint)
77        };
78        let mut ep = tonic::transport::Endpoint::from_shared(uri)
79            .map_err(|e| ClusterError::Transport(e.to_string()))?
80            .connect_timeout(CONNECT_TIMEOUT);
81
82        if let Some(ref tls) = self.tls_config {
83            ep = ep
84                .tls_config(tls.clone())
85                .map_err(|e| ClusterError::Transport(e.to_string()))?;
86        }
87
88        ep.connect()
89            .await
90            .map_err(|e| ClusterError::Transport(e.to_string()))
91    }
92
93    pub async fn get_transport(
94        self: &Arc<Self>,
95        endpoint: &str,
96    ) -> Result<GrainTransportClient<tonic::transport::Channel>, ClusterError> {
97        if let Some(client) = self.transports.get(endpoint) {
98            return Ok(client.clone());
99        }
100
101        let channel = self.connect_channel(endpoint).await?;
102        let client = GrainTransportClient::new(channel);
103
104        self.transports
105            .insert(endpoint.to_string(), client.clone());
106        Ok(client)
107    }
108
109    pub async fn get_membership(
110        self: &Arc<Self>,
111        endpoint: &str,
112    ) -> Result<MembershipClient<tonic::transport::Channel>, ClusterError> {
113        if let Some(client) = self.memberships.get(endpoint) {
114            return Ok(client.clone());
115        }
116
117        let channel = self.connect_channel(endpoint).await?;
118        let client = MembershipClient::new(channel);
119
120        self.memberships
121            .insert(endpoint.to_string(), client.clone());
122        Ok(client)
123    }
124
125    pub async fn get_gateway(
126        self: &Arc<Self>,
127        endpoint: &str,
128    ) -> Result<ClusterGatewayClient<tonic::transport::Channel>, ClusterError> {
129        if let Some(client) = self.gateways.get(endpoint) {
130            return Ok(client.clone());
131        }
132
133        let channel = self.connect_channel(endpoint).await?;
134        let client = ClusterGatewayClient::new(channel);
135
136        self.gateways
137            .insert(endpoint.to_string(), client.clone());
138        Ok(client)
139    }
140
141    /// Remove cached connections for an endpoint.
142    ///
143    /// Called automatically when SWIM declares a member dead. Callers should
144    /// also call this after persistent connection errors to force reconnection.
145    ///
146    /// Note: tonic `Channel` handles reconnection internally for transient
147    /// failures — this method is for permanent removals (dead silos) or
148    /// forcing a fresh connection after repeated errors.
149    pub fn remove(&self, endpoint: &str) {
150        self.transports.remove(endpoint);
151        self.memberships.remove(endpoint);
152        self.gateways.remove(endpoint);
153    }
154}