orlando_cluster/
connection_pool.rs1use std::sync::Arc;
2use std::time::Duration;
3
4use dashmap::DashMap;
5
6use crate::error::ClusterError;
7use crate::proto::cluster_gateway_client::ClusterGatewayClient;
8use crate::proto::grain_transport_client::GrainTransportClient;
9use crate::proto::membership_client::MembershipClient;
10
/// Maximum time `connect_channel` waits while establishing a new channel;
/// handed to `tonic::transport::Endpoint::connect_timeout`.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);

14pub struct ConnectionPool {
15 transports: DashMap<String, GrainTransportClient<tonic::transport::Channel>>,
16 memberships: DashMap<String, MembershipClient<tonic::transport::Channel>>,
17 gateways: DashMap<String, ClusterGatewayClient<tonic::transport::Channel>>,
18 tls_config: Option<tonic::transport::ClientTlsConfig>,
19 auth_token: Option<String>,
20}
21
22impl Default for ConnectionPool {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl ConnectionPool {
29 pub fn new() -> Self {
30 Self {
31 transports: DashMap::new(),
32 memberships: DashMap::new(),
33 gateways: DashMap::new(),
34 tls_config: None,
35 auth_token: None,
36 }
37 }
38
39 pub fn with_tls(tls: tonic::transport::ClientTlsConfig) -> Self {
40 Self {
41 tls_config: Some(tls),
42 ..Self::new()
43 }
44 }
45
46 pub fn with_tls_and_auth(tls: tonic::transport::ClientTlsConfig, token: String) -> Self {
47 Self {
48 tls_config: Some(tls),
49 auth_token: Some(token),
50 ..Self::new()
51 }
52 }
53
54 pub fn with_auth(token: String) -> Self {
55 Self {
56 auth_token: Some(token),
57 ..Self::new()
58 }
59 }
60
61 pub fn authorized_request<T>(&self, inner: T) -> tonic::Request<T> {
63 let mut request = tonic::Request::new(inner);
64 if let Some(ref token) = self.auth_token
65 && let Ok(value) = token.parse()
66 {
67 request.metadata_mut().insert("authorization", value);
68 }
69 request
70 }
71
72 async fn connect_channel(&self, endpoint: &str) -> Result<tonic::transport::Channel, ClusterError> {
73 let uri = if self.tls_config.is_some() {
74 format!("https://{}", endpoint)
75 } else {
76 format!("http://{}", endpoint)
77 };
78 let mut ep = tonic::transport::Endpoint::from_shared(uri)
79 .map_err(|e| ClusterError::Transport(e.to_string()))?
80 .connect_timeout(CONNECT_TIMEOUT);
81
82 if let Some(ref tls) = self.tls_config {
83 ep = ep
84 .tls_config(tls.clone())
85 .map_err(|e| ClusterError::Transport(e.to_string()))?;
86 }
87
88 ep.connect()
89 .await
90 .map_err(|e| ClusterError::Transport(e.to_string()))
91 }
92
93 pub async fn get_transport(
94 self: &Arc<Self>,
95 endpoint: &str,
96 ) -> Result<GrainTransportClient<tonic::transport::Channel>, ClusterError> {
97 if let Some(client) = self.transports.get(endpoint) {
98 return Ok(client.clone());
99 }
100
101 let channel = self.connect_channel(endpoint).await?;
102 let client = GrainTransportClient::new(channel);
103
104 self.transports
105 .insert(endpoint.to_string(), client.clone());
106 Ok(client)
107 }
108
109 pub async fn get_membership(
110 self: &Arc<Self>,
111 endpoint: &str,
112 ) -> Result<MembershipClient<tonic::transport::Channel>, ClusterError> {
113 if let Some(client) = self.memberships.get(endpoint) {
114 return Ok(client.clone());
115 }
116
117 let channel = self.connect_channel(endpoint).await?;
118 let client = MembershipClient::new(channel);
119
120 self.memberships
121 .insert(endpoint.to_string(), client.clone());
122 Ok(client)
123 }
124
125 pub async fn get_gateway(
126 self: &Arc<Self>,
127 endpoint: &str,
128 ) -> Result<ClusterGatewayClient<tonic::transport::Channel>, ClusterError> {
129 if let Some(client) = self.gateways.get(endpoint) {
130 return Ok(client.clone());
131 }
132
133 let channel = self.connect_channel(endpoint).await?;
134 let client = ClusterGatewayClient::new(channel);
135
136 self.gateways
137 .insert(endpoint.to_string(), client.clone());
138 Ok(client)
139 }
140
141 pub fn remove(&self, endpoint: &str) {
150 self.transports.remove(endpoint);
151 self.memberships.remove(endpoint);
152 self.gateways.remove(endpoint);
153 }
154}