// cronback_lib/grpc_client_provider.rs

1use std::collections::HashMap;
2use std::str::FromStr;
3use std::sync::RwLock;
4
5use async_trait::async_trait;
6use derive_more::{Deref, DerefMut};
7use thiserror::Error;
8use tonic::transport::{Channel, Endpoint};
9
10use crate::config::MainConfig;
11use crate::model::ValidShardedId;
12use crate::prelude::{GrpcRequestInterceptor, Shard};
13use crate::service::ServiceContext;
14use crate::types::{ProjectId, RequestId};
15
/// Errors that can occur while resolving an endpoint address or
/// establishing a gRPC connection.
#[derive(Debug, Error)]
pub enum GrpcClientError {
    /// Transport-level failure while connecting (propagated from tonic).
    #[error(transparent)]
    Connect(#[from] tonic::transport::Error),
    /// No endpoint could be resolved for the target shard in the config.
    #[error("Internal data routing error: {0}")]
    Routing(String),
    /// The configured endpoint string could not be parsed as an `Endpoint`.
    #[error("Malformed grpc endpoint address: {0}")]
    BadAddress(String),
}
25
// Wraps a raw gRPC client with the project ID and request ID, allows users to
// access project_id and request_id at any time.
//
// Deref/DerefMut allow ScopedGrpcClient to be used as a T (smart pointer-like)
#[derive(Deref, DerefMut)]
pub struct ScopedGrpcClient<T> {
    // The project this client's calls are scoped to. Also injected into
    // outgoing requests via GrpcRequestInterceptor (see get_client impls).
    pub project_id: ValidShardedId<ProjectId>,
    // The originating request id; likewise carried by the interceptor.
    pub request_id: RequestId,
    // The raw generated gRPC client. All of `T`'s methods are reachable
    // directly on `ScopedGrpcClient<T>` through Deref/DerefMut.
    #[deref]
    #[deref_mut]
    inner: T,
}
38
39impl<T> ScopedGrpcClient<T> {
40    pub fn new(
41        project_id: ValidShardedId<ProjectId>,
42        request_id: RequestId,
43        inner: T,
44    ) -> Self {
45        Self {
46            project_id,
47            request_id,
48            inner,
49        }
50    }
51}
52
/// Per-service glue implemented by each concrete gRPC client type: how to
/// wrap a channel into a scoped client, and where (in config) its
/// shard-to-address map lives. Ships default implementations for channel
/// creation and address resolution.
#[async_trait]
pub trait GrpcClientType: Sync + Send {
    /// The tonic-generated client type being wrapped.
    type RawGrpcClient;

    /// Builds the concrete scoped client from an established channel and
    /// the interceptor that stamps project/request ids on outgoing calls.
    fn create_scoped_client(
        project_id: ValidShardedId<ProjectId>,
        request_id: RequestId,
        channel: tonic::transport::Channel,
        interceptor: GrpcRequestInterceptor,
    ) -> Self;

    /// Mutable access to the underlying scoped client wrapper.
    fn get_mut(&mut self) -> &mut ScopedGrpcClient<Self::RawGrpcClient>;

    // Concrete default implementations
    /// Connects a new channel to `address`.
    ///
    /// Returns `BadAddress` if the address fails to parse (the original
    /// parse error is discarded), or `Connect` if dialing fails.
    async fn create_channel(
        address: &str,
    ) -> Result<tonic::transport::Channel, GrpcClientError> {
        let channel = Endpoint::from_str(address)
            .map_err(|_| GrpcClientError::BadAddress(address.to_string()))?
            .connect()
            .await?;
        Ok(channel)
    }

    /// The shard-id -> endpoint-address map for this client type, read
    /// from the main config section.
    fn address_map(config: &MainConfig) -> &HashMap<u64, String>;

    /// Resolves the endpoint address for a project.
    ///
    /// `_project_id` is currently unused: all projects route to shard 0
    /// until multi-cell routing lands (see TODO below). Returns `Routing`
    /// if the shard has no configured endpoint.
    fn get_address(
        config: &MainConfig,
        _project_id: &ValidShardedId<ProjectId>,
    ) -> Result<String, GrpcClientError> {
        // For now, we'll assume everything is on Cell 0
        // TODO: support multiple cells
        let shard = Shard(0);

        let address =
            Self::address_map(config).get(&shard.0).ok_or_else(|| {
                GrpcClientError::Routing(format!(
                    "No endpoint was found for shard {shard} in config (grpc \
                     client type {:?})",
                    std::any::type_name::<Self>(),
                ))
            })?;
        Ok(address.clone())
    }
}
98
/// Factory for obtaining scoped gRPC clients. Implemented by the
/// channel-caching production provider and by the unix-socket test
/// provider in `test_helpers`.
#[async_trait]
pub trait GrpcClientFactory: Send + Sync {
    /// The concrete (scoped) client type this factory produces.
    type ClientType;
    /// Builds a client scoped to `request_id`/`project_id`, connecting
    /// (or reusing a cached connection) as needed.
    async fn get_client(
        &self,
        request_id: &RequestId,
        project_id: &ValidShardedId<ProjectId>,
    ) -> Result<Self::ClientType, GrpcClientError>;
}
108
// A concrete channel-caching implementation of the GrpcClientFactory used in
// production.
pub struct GrpcClientProvider<T> {
    // Source of the main config from which endpoint addresses are resolved.
    service_context: ServiceContext,
    // Established channels keyed by endpoint address; shared across clones
    // of the same channel (tonic Channels are cheaply cloneable).
    channel_cache: RwLock<HashMap<String, Channel>>,
    // `T` is only used at the type level; no `T` value is stored.
    phantom: std::marker::PhantomData<T>,
}
116
117impl<T: GrpcClientType> GrpcClientProvider<T> {
118    pub fn new(service_context: ServiceContext) -> Self {
119        Self {
120            service_context,
121            channel_cache: Default::default(),
122            phantom: Default::default(),
123        }
124    }
125}
126
127#[async_trait]
128impl<T: GrpcClientType> GrpcClientFactory for GrpcClientProvider<T> {
129    type ClientType = T;
130
131    async fn get_client(
132        &self,
133        request_id: &RequestId,
134        project_id: &ValidShardedId<ProjectId>,
135    ) -> Result<Self::ClientType, GrpcClientError> {
136        // resolve shard -> cell
137        let address = T::get_address(
138            &self.service_context.get_config().main,
139            project_id,
140        )?;
141
142        let mut channel = None;
143        {
144            let cache = self.channel_cache.read().unwrap();
145            if let Some(ch) = cache.get(&address) {
146                channel = Some(ch.clone());
147            }
148        }
149        if channel.is_none() {
150            // We attempt to create a new channel anyway because we don't want
151            // to block the write lock during connection.
152            let temp_new_ch = T::create_channel(&address).await?;
153            {
154                // Only upgrade to a write lock if we need to create a new
155                let mut cache = self.channel_cache.write().unwrap();
156                // check again, someone might have already created the channel
157                if let Some(ch) = cache.get(&address) {
158                    channel = Some(ch.clone());
159                    // temp_new_ch dropped here
160                } else {
161                    cache.insert(address.clone(), temp_new_ch.clone());
162                    channel = Some(temp_new_ch);
163                }
164            }
165        }
166
167        assert!(channel.is_some());
168
169        let interceptor = GrpcRequestInterceptor {
170            project_id: Some(project_id.clone()),
171            request_id: Some(request_id.clone()),
172        };
173
174        Ok(T::create_scoped_client(
175            project_id.clone(),
176            request_id.clone(),
177            channel.unwrap(),
178            interceptor,
179        ))
180    }
181}
182
183pub mod test_helpers {
184    use std::sync::Arc;
185
186    use hyper::Uri;
187    use tempfile::TempPath;
188    use tokio::net::UnixStream;
189    use tower::service_fn;
190
191    use super::*;
192    // An implementation of the GrpcClientFactory used in tests that uses unix
193    // domain socket.
194    pub struct TestGrpcClientProvider<T> {
195        cell_to_socket_path: HashMap<u16, Arc<TempPath>>,
196        channel_cache: RwLock<HashMap<u16, Channel>>,
197        phantom: std::marker::PhantomData<T>,
198    }
199
200    impl<T: GrpcClientType> TestGrpcClientProvider<T> {
201        pub fn new(cell_to_socket_path: HashMap<u16, Arc<TempPath>>) -> Self {
202            Self {
203                cell_to_socket_path,
204                channel_cache: Default::default(),
205                phantom: Default::default(),
206            }
207        }
208
209        pub fn new_single_shard(socket_path: Arc<TempPath>) -> Self {
210            let mut cell_to_socket_path = HashMap::with_capacity(1);
211            cell_to_socket_path.insert(0, socket_path);
212
213            Self {
214                cell_to_socket_path,
215                channel_cache: Default::default(),
216                phantom: Default::default(),
217            }
218        }
219    }
220
221    #[async_trait]
222    impl<T: GrpcClientType> GrpcClientFactory for TestGrpcClientProvider<T> {
223        type ClientType = T;
224
225        async fn get_client(
226            &self,
227            request_id: &RequestId,
228            project_id: &ValidShardedId<ProjectId>,
229        ) -> Result<Self::ClientType, GrpcClientError> {
230            // do we have a channel in cache?
231            // TODO: support multiple cells
232            let _shard = Shard(0);
233            let cell: u16 = 0;
234
235            let socket = self
236                .cell_to_socket_path
237                .get(&cell)
238                .expect("Cell not found!")
239                .clone();
240
241            let mut channel = None;
242            {
243                let cache = self.channel_cache.read().unwrap();
244                if let Some(ch) = cache.get(&cell) {
245                    channel = Some(ch.clone());
246                }
247            }
248            if channel.is_none() {
249                // We attempt to create a new channel anyway because we don't
250                // want to block the write lock during
251                // connection.
252                // Connect to the server over a Unix socket. The URL will be
253                // ignored.
254                let temp_new_ch = Endpoint::try_from("http://example.url")
255                    .unwrap()
256                    .connect_with_connector(service_fn(move |_: Uri| {
257                        let socket = Arc::clone(&socket);
258                        async move { UnixStream::connect(&*socket).await }
259                    }))
260                    .await?;
261                {
262                    // Only upgrade to a write lock if we need to create a new
263                    let mut cache = self.channel_cache.write().unwrap();
264                    // check again, someone might have already created the
265                    // channel
266                    if let Some(ch) = cache.get(&cell) {
267                        channel = Some(ch.clone());
268                        // temp_new_ch dropped here
269                    } else {
270                        cache.insert(cell, temp_new_ch.clone());
271                        channel = Some(temp_new_ch);
272                    }
273                }
274            }
275
276            assert!(channel.is_some());
277
278            let interceptor = GrpcRequestInterceptor {
279                project_id: Some(project_id.clone()),
280                request_id: Some(request_id.clone()),
281            };
282
283            Ok(T::create_scoped_client(
284                project_id.clone(),
285                request_id.clone(),
286                channel.unwrap(),
287                interceptor,
288            ))
289        }
290    }
291}