cronback_lib/
grpc_client_provider.rs1use std::collections::HashMap;
2use std::str::FromStr;
3use std::sync::RwLock;
4
5use async_trait::async_trait;
6use derive_more::{Deref, DerefMut};
7use thiserror::Error;
8use tonic::transport::{Channel, Endpoint};
9
10use crate::config::MainConfig;
11use crate::model::ValidShardedId;
12use crate::prelude::{GrpcRequestInterceptor, Shard};
13use crate::service::ServiceContext;
14use crate::types::{ProjectId, RequestId};
15
16#[derive(Debug, Error)]
17pub enum GrpcClientError {
18 #[error(transparent)]
19 Connect(#[from] tonic::transport::Error),
20 #[error("Internal data routing error: {0}")]
21 Routing(String),
22 #[error("Malformed grpc endpoint address: {0}")]
23 BadAddress(String),
24}
25
/// A gRPC client of type `T` tagged with the project and request it was
/// created for.
///
/// The wrapper derefs (mutably) to the wrapped client, so the client's
/// methods can be called directly on a `ScopedGrpcClient`.
#[derive(Deref, DerefMut)]
pub struct ScopedGrpcClient<T> {
    /// Project this client was created for.
    pub project_id: ValidShardedId<ProjectId>,
    /// Request this client was created for.
    pub request_id: RequestId,
    // Selected as the target of the derived `Deref`/`DerefMut` impls.
    #[deref]
    #[deref_mut]
    inner: T,
}
38
39impl<T> ScopedGrpcClient<T> {
40 pub fn new(
41 project_id: ValidShardedId<ProjectId>,
42 request_id: RequestId,
43 inner: T,
44 ) -> Self {
45 Self {
46 project_id,
47 request_id,
48 inner,
49 }
50 }
51}
52
53#[async_trait]
54pub trait GrpcClientType: Sync + Send {
55 type RawGrpcClient;
56
57 fn create_scoped_client(
58 project_id: ValidShardedId<ProjectId>,
59 request_id: RequestId,
60 channel: tonic::transport::Channel,
61 interceptor: GrpcRequestInterceptor,
62 ) -> Self;
63
64 fn get_mut(&mut self) -> &mut ScopedGrpcClient<Self::RawGrpcClient>;
65
66 async fn create_channel(
68 address: &str,
69 ) -> Result<tonic::transport::Channel, GrpcClientError> {
70 let channel = Endpoint::from_str(address)
71 .map_err(|_| GrpcClientError::BadAddress(address.to_string()))?
72 .connect()
73 .await?;
74 Ok(channel)
75 }
76
77 fn address_map(config: &MainConfig) -> &HashMap<u64, String>;
78
79 fn get_address(
80 config: &MainConfig,
81 _project_id: &ValidShardedId<ProjectId>,
82 ) -> Result<String, GrpcClientError> {
83 let shard = Shard(0);
86
87 let address =
88 Self::address_map(config).get(&shard.0).ok_or_else(|| {
89 GrpcClientError::Routing(format!(
90 "No endpoint was found for shard {shard} in config (grpc \
91 client type {:?})",
92 std::any::type_name::<Self>(),
93 ))
94 })?;
95 Ok(address.clone())
96 }
97}
98
/// Hands out clients of type [`Self::ClientType`] for a given request and
/// project.
#[async_trait]
pub trait GrpcClientFactory: Send + Sync {
    /// The client type produced by this factory.
    type ClientType;
    /// Returns a client for the given request and project.
    ///
    /// # Errors
    ///
    /// Failures are reported as [`GrpcClientError`]; which variants are
    /// possible depends on the implementation.
    async fn get_client(
        &self,
        request_id: &RequestId,
        project_id: &ValidShardedId<ProjectId>,
    ) -> Result<Self::ClientType, GrpcClientError>;
}
108
/// [`GrpcClientFactory`] that resolves endpoint addresses from the service
/// configuration and reuses one connected channel per address.
pub struct GrpcClientProvider<T> {
    // Source of the `MainConfig` used to look up endpoint addresses.
    service_context: ServiceContext,
    // Connected channels keyed by endpoint address, shared across calls.
    channel_cache: RwLock<HashMap<String, Channel>>,
    // Ties the provider to client type `T` without storing a value of it.
    phantom: std::marker::PhantomData<T>,
}
116
117impl<T: GrpcClientType> GrpcClientProvider<T> {
118 pub fn new(service_context: ServiceContext) -> Self {
119 Self {
120 service_context,
121 channel_cache: Default::default(),
122 phantom: Default::default(),
123 }
124 }
125}
126
127#[async_trait]
128impl<T: GrpcClientType> GrpcClientFactory for GrpcClientProvider<T> {
129 type ClientType = T;
130
131 async fn get_client(
132 &self,
133 request_id: &RequestId,
134 project_id: &ValidShardedId<ProjectId>,
135 ) -> Result<Self::ClientType, GrpcClientError> {
136 let address = T::get_address(
138 &self.service_context.get_config().main,
139 project_id,
140 )?;
141
142 let mut channel = None;
143 {
144 let cache = self.channel_cache.read().unwrap();
145 if let Some(ch) = cache.get(&address) {
146 channel = Some(ch.clone());
147 }
148 }
149 if channel.is_none() {
150 let temp_new_ch = T::create_channel(&address).await?;
153 {
154 let mut cache = self.channel_cache.write().unwrap();
156 if let Some(ch) = cache.get(&address) {
158 channel = Some(ch.clone());
159 } else {
161 cache.insert(address.clone(), temp_new_ch.clone());
162 channel = Some(temp_new_ch);
163 }
164 }
165 }
166
167 assert!(channel.is_some());
168
169 let interceptor = GrpcRequestInterceptor {
170 project_id: Some(project_id.clone()),
171 request_id: Some(request_id.clone()),
172 };
173
174 Ok(T::create_scoped_client(
175 project_id.clone(),
176 request_id.clone(),
177 channel.unwrap(),
178 interceptor,
179 ))
180 }
181}
182
183pub mod test_helpers {
184 use std::sync::Arc;
185
186 use hyper::Uri;
187 use tempfile::TempPath;
188 use tokio::net::UnixStream;
189 use tower::service_fn;
190
191 use super::*;
192 pub struct TestGrpcClientProvider<T> {
195 cell_to_socket_path: HashMap<u16, Arc<TempPath>>,
196 channel_cache: RwLock<HashMap<u16, Channel>>,
197 phantom: std::marker::PhantomData<T>,
198 }
199
200 impl<T: GrpcClientType> TestGrpcClientProvider<T> {
201 pub fn new(cell_to_socket_path: HashMap<u16, Arc<TempPath>>) -> Self {
202 Self {
203 cell_to_socket_path,
204 channel_cache: Default::default(),
205 phantom: Default::default(),
206 }
207 }
208
209 pub fn new_single_shard(socket_path: Arc<TempPath>) -> Self {
210 let mut cell_to_socket_path = HashMap::with_capacity(1);
211 cell_to_socket_path.insert(0, socket_path);
212
213 Self {
214 cell_to_socket_path,
215 channel_cache: Default::default(),
216 phantom: Default::default(),
217 }
218 }
219 }
220
221 #[async_trait]
222 impl<T: GrpcClientType> GrpcClientFactory for TestGrpcClientProvider<T> {
223 type ClientType = T;
224
225 async fn get_client(
226 &self,
227 request_id: &RequestId,
228 project_id: &ValidShardedId<ProjectId>,
229 ) -> Result<Self::ClientType, GrpcClientError> {
230 let _shard = Shard(0);
233 let cell: u16 = 0;
234
235 let socket = self
236 .cell_to_socket_path
237 .get(&cell)
238 .expect("Cell not found!")
239 .clone();
240
241 let mut channel = None;
242 {
243 let cache = self.channel_cache.read().unwrap();
244 if let Some(ch) = cache.get(&cell) {
245 channel = Some(ch.clone());
246 }
247 }
248 if channel.is_none() {
249 let temp_new_ch = Endpoint::try_from("http://example.url")
255 .unwrap()
256 .connect_with_connector(service_fn(move |_: Uri| {
257 let socket = Arc::clone(&socket);
258 async move { UnixStream::connect(&*socket).await }
259 }))
260 .await?;
261 {
262 let mut cache = self.channel_cache.write().unwrap();
264 if let Some(ch) = cache.get(&cell) {
267 channel = Some(ch.clone());
268 } else {
270 cache.insert(cell, temp_new_ch.clone());
271 channel = Some(temp_new_ch);
272 }
273 }
274 }
275
276 assert!(channel.is_some());
277
278 let interceptor = GrpcRequestInterceptor {
279 project_id: Some(project_id.clone()),
280 request_id: Some(request_id.clone()),
281 };
282
283 Ok(T::create_scoped_client(
284 project_id.clone(),
285 request_id.clone(),
286 channel.unwrap(),
287 interceptor,
288 ))
289 }
290 }
291}