1mod builder;
10mod pool;
11pub mod tower_services;
12
13use async_stream::stream;
14pub use builder::ClientBuilder;
15pub use pool::ClientConnectionManager;
16pub use pool::Pool;
17pub use pool::PooledConnection;
18
19use dashmap::mapref::one::RefMut;
20use dashmap::DashMap;
21use futures::SinkExt;
22use futures::{Stream, StreamExt};
23use lru::LruCache;
24use rand::rng;
25use rand::seq::IndexedRandom;
26use serde::de::DeserializeOwned;
27use serde::Serialize;
28use std::collections::HashSet;
29use std::marker::PhantomData;
30use std::num::NonZeroUsize;
31use std::sync::{Arc, RwLock};
32use tokio::net::TcpStream;
33use tokio::time::timeout;
34use tokio_util::codec::{Framed, LengthDelimitedCodec};
35use tower::Service as TowerService;
36
37use crate::cluster::storage::MembershipStorage;
38use crate::protocol::pubsub::{SubscriptionRequest, SubscriptionResponse};
39use crate::protocol::{ClientError, RequestEnvelope, RequestError, ResponseError};
40use crate::registry::IdentifiableType;
41
/// Timeout, in milliseconds, that [`Client::new`] assigns to a freshly built client.
pub const DEFAULT_TIMEOUT_MILLIS: u64 = 500;
43
44#[derive(Clone)]
48pub struct Client<S> {
49 timeout_millis: u64,
50
51 membership_storage: S,
53
54 active_servers: HashSet<String>,
56
57 ts_active_servers_refresh: u64,
59
60 streams: Arc<DashMap<String, Framed<TcpStream, LengthDelimitedCodec>>>,
62
63 placement: Arc<RwLock<LruCache<(String, String), String>>>,
65}
66
67pub struct SubscriptionStream<T>
69where
70 T: DeserializeOwned,
71{
72 pub tcp_stream: Framed<TcpStream, LengthDelimitedCodec>,
74 _phantom: PhantomData<T>,
75}
76
77impl<T> SubscriptionStream<T>
78where
79 T: DeserializeOwned,
80{
81 pub fn new(tcp_stream: Framed<TcpStream, LengthDelimitedCodec>) -> Self {
82 SubscriptionStream {
83 tcp_stream,
84 _phantom: PhantomData {},
85 }
86 }
87}
88
89impl<T> Stream for SubscriptionStream<T>
93where
94 T: DeserializeOwned + std::marker::Unpin + std::fmt::Debug,
95{
96 type Item = Result<T, ResponseError>;
97 fn poll_next(
98 self: std::pin::Pin<&mut Self>,
99 cx: &mut std::task::Context<'_>,
100 ) -> std::task::Poll<Option<Self::Item>> {
101 let self_mut = self.get_mut();
102 self_mut.tcp_stream.poll_next_unpin(cx).map(|maybe_bytes| {
103 let bytes_result = maybe_bytes?;
104 let bytes_message = match bytes_result {
105 Ok(bytes_message) => bytes_message,
106 Err(err) => {
107 return Some(Err(ResponseError::DeseralizationError(err.to_string())));
108 }
109 };
110
111 let sub_response: SubscriptionResponse = match bincode::deserialize(&bytes_message) {
112 Ok(sub_response) => sub_response,
113 Err(err) => return Some(Err(ResponseError::DeseralizationError(err.to_string()))),
114 };
115
116 let final_message = match sub_response.body {
117 Ok(v) => {
118 let response: Result<T, _> = bincode::deserialize(&v)
119 .map_err(|e| ResponseError::DeseralizationError(e.to_string()));
120 response
121 }
122 Err(err) => Err(err),
123 };
124 Some(final_message)
125 })
126 }
127}
128
129type ClientResult<T> = Result<T, ClientError>;
130
131impl<S> Client<S>
132where
133 S: 'static + MembershipStorage,
134{
135 pub fn new(members_storage: S) -> Self {
137 let lru_limit = NonZeroUsize::new(1_000).expect("LruCache limit must be greater than 0");
138
139 Client {
140 membership_storage: members_storage,
141 timeout_millis: DEFAULT_TIMEOUT_MILLIS,
142 active_servers: Default::default(),
143 ts_active_servers_refresh: 0,
144 streams: Arc::default(),
145 placement: Arc::new(RwLock::new(LruCache::new(lru_limit))),
146 }
147 }
148
149 async fn fetch_active_servers(&mut self) -> ClientResult<()> {
154 if !self.active_servers.is_empty() && self.ts_active_servers_refresh > 0 {
157 return Ok(());
158 }
159
160 let active_servers: HashSet<String> = self
161 .membership_storage
162 .active_members()
163 .await
164 .map_err(|_| ClientError::RendevouzUnavailable)?
165 .iter()
166 .map(|member| member.address())
167 .collect();
168
169 self.active_servers = active_servers;
170 self.ts_active_servers_refresh = 1;
171 Ok(())
172 }
173
174 async fn ensure_stream_exists(&mut self, address: &str) -> ClientResult<()> {
175 self.fetch_active_servers().await?;
176
177 if self.active_servers.is_empty() {
180 return Err(ClientError::NoServersAvailable);
181 }
182
183 if !self.active_servers.contains(address) {
186 self.ts_active_servers_refresh = 0;
187 self.fetch_active_servers().await?;
188 }
189
190 if !self.active_servers.contains(address) {
193 return Err(ClientError::ServerNotAvailable(address.to_string()));
194 }
195
196 if self.streams.get(address).is_none() {
199 let stream = TcpStream::connect(&address)
200 .await
201 .map_err(|_| ClientError::Disconnect)?;
202 let stream = Framed::new(stream, LengthDelimitedCodec::new());
203 self.streams.insert(address.to_string(), stream);
204 };
205 Ok(())
206 }
207
208 async fn server_stream(
213 &mut self,
214 address: &String,
215 ) -> ClientResult<RefMut<'_, String, Framed<TcpStream, LengthDelimitedCodec>>> {
216 self.ensure_stream_exists(address).await?;
217 self.streams
218 .get_mut(address)
219 .ok_or(ClientError::Connectivity)
220 }
221
222 async fn pop_server_stream(
224 &mut self,
225 address: &String,
226 ) -> ClientResult<Framed<TcpStream, LengthDelimitedCodec>> {
227 self.ensure_stream_exists(address).await?;
228 self.streams
229 .remove(address)
230 .map(|(_, v)| v)
231 .ok_or(ClientError::Connectivity)
232 }
233
234 async fn get_service_object_address(
236 &mut self,
237 service_object_type: impl ToString,
238 service_object_id: impl ToString,
239 ) -> ClientResult<String> {
240 self.fetch_active_servers().await?;
241 let object_id = (
242 service_object_id.to_string(),
243 service_object_type.to_string(),
244 );
245 let address = {
246 let mut placement_guard = self
247 .placement
248 .write()
249 .map_err(|_| ClientError::PlacementLock)?;
250
251 let cached_address = placement_guard.get(&object_id);
252 match cached_address {
253 Some(address) => address.clone(),
254 None => {
255 let mut rng = rng();
258 let servers: Vec<String> = self.active_servers.iter().cloned().collect();
259 let random_server = servers
260 .choose(&mut rng)
261 .ok_or(ClientError::NoServersAvailable)?;
262 random_server.clone()
263 }
264 }
265 };
266 Ok(address)
267 }
268
269 async fn service_object_stream(
271 &mut self,
272 service_object_type: impl ToString,
273 service_object_id: impl ToString,
274 ) -> ClientResult<RefMut<'_, String, Framed<TcpStream, LengthDelimitedCodec>>> {
275 self.fetch_active_servers().await?;
276 let address = self
277 .get_service_object_address(service_object_type, service_object_id)
278 .await?;
279 self.server_stream(&address).await
280 }
281
282 pub async fn send<T, E>(
293 &mut self,
294 handler_type: impl AsRef<str>,
295 handler_id: impl AsRef<str>,
296 payload: &(impl Serialize + IdentifiableType + Send + Sync),
297 ) -> Result<T, RequestError<E>>
298 where
299 T: DeserializeOwned,
300 E: std::error::Error + DeserializeOwned + Clone + Send + Sync,
301 {
302 self.fetch_active_servers().await?;
304
305 let handler_type = handler_type.as_ref().to_string();
306 let handler_id = handler_id.as_ref().to_string();
307 let ser_payload = bincode::serialize(&payload)
308 .map_err(|e| ClientError::SeralizationError(e.to_string()))?;
309 let message_type = payload.instance_type_id().to_string();
310
311 let request = RequestEnvelope::new(
312 handler_type.clone(),
313 handler_id.clone(),
314 message_type.clone(),
315 ser_payload.clone(),
316 );
317 let tower_svc = tower_services::Request::new(self.clone());
318 let mut tower_svc = tower_services::RequestRedirect::new(tower_svc);
319 let response = tower_svc.call(request).await;
320 response.and_then(|x| {
321 let body: T = bincode::deserialize(&x)
322 .map_err(|e| ClientError::DeseralizationError(e.to_string()))?;
323 Ok(body)
324 })
325 }
326
327 pub async fn send_request<E: std::error::Error + DeserializeOwned + Clone + Send + Sync>(
329 &mut self,
330 request: RequestEnvelope,
331 ) -> Result<Vec<u8>, RequestError<E>> {
332 self.fetch_active_servers().await?;
334
335 let tower_svc = tower_services::Request::new(self.clone());
336 let mut tower_svc = tower_services::RequestRedirect::new(tower_svc);
337 let response = tower_svc.call(request).await?;
338 Ok(response)
339 }
340
341 async fn _subscribe<'a, T>(
342 &'a mut self,
343 handler_type: &str,
344 handler_id: &str,
345 address: &str,
346 ) -> SubscriptionStream<T>
347 where
348 Self: 'a,
349 T: DeserializeOwned + std::marker::Unpin + 'a + std::fmt::Debug,
350 {
351 let mut svc_stream = self.pop_server_stream(&address.to_string()).await.unwrap();
352 let req = SubscriptionRequest {
353 handler_type: handler_type.to_string(),
354 handler_id: handler_id.to_string(),
355 };
356 let ser_request = bincode::serialize(&req).unwrap();
357 svc_stream.send(ser_request.into()).await.unwrap();
358 SubscriptionStream::<T>::new(svc_stream)
359 }
360
361 pub async fn subscribe<'a, T>(
374 &'a mut self,
375 handler_type: impl AsRef<str>,
376 handler_id: impl AsRef<str>,
377 ) -> Result<impl Stream<Item = Result<T, ResponseError>> + 'a, ClientError>
378 where
379 Self: 'a,
380 T: DeserializeOwned + std::marker::Unpin + 'a + std::fmt::Debug,
381 {
382 let handler_type = handler_type.as_ref().to_string();
383 let handler_id = handler_id.as_ref().to_string();
384 let mut address = self
385 .get_service_object_address(&handler_type, &handler_id)
386 .await?;
387
388 let stream = stream! {
389 loop {
390 let mut subscription_stream = self._subscribe(&handler_type, &handler_id, &address).await;
391 while let Some(v) = subscription_stream.next().await {
392 if let Err(ResponseError::Redirect(to)) = v {
393 address = to;
394 break;
395 }
396 yield v;
397 }
398 }
399 };
400 Ok(stream)
401 }
402
403 pub async fn ping(&mut self) -> Result<(), ClientError> {
408 let servers = self
409 .membership_storage
410 .members()
411 .await
412 .map_err(|_| ClientError::Connectivity)?;
413 let server = servers.first().ok_or(ClientError::NoServersAvailable)?;
414
415 async fn conn(address: &str) -> Result<(), ClientError> {
416 TcpStream::connect(&address)
417 .await
418 .map(|_stream| Ok(()))
419 .map_err(|_e| ClientError::Connectivity)?
420 }
421
422 match timeout(
423 std::time::Duration::from_millis(self.timeout_millis),
424 conn(&server.address()),
425 )
426 .await
427 {
428 Ok(x) => x,
429 Err(_elapsed) => Err(ClientError::Connectivity),
430 }
431 }
432}
433
#[cfg(test)]
mod test {
    use super::*;
    use crate::cluster::storage::{local::LocalStorage, Member, MembershipStorage};

    /// Client backed by an empty in-memory membership storage.
    fn client() -> Client<LocalStorage> {
        let placement_capacity = NonZeroUsize::new(10).unwrap();
        Client {
            timeout_millis: 1000,
            membership_storage: LocalStorage::default(),
            active_servers: Default::default(),
            ts_active_servers_refresh: 0,
            streams: Arc::default(),
            placement: Arc::new(RwLock::new(LruCache::new(placement_capacity))),
        }
    }

    /// Client whose storage holds one active member, `0.0.0.0` on port `1234`.
    async fn client_with_members() -> Client<LocalStorage> {
        let fixture = client();

        let mut member = Member::new("0.0.0.0".to_string(), "1234".to_string());
        member.set_active(true);
        fixture
            .membership_storage
            .push(member)
            .await
            .expect("add member");

        fixture
    }

    /// With no member registered, asking for a connection reports that no
    /// server is available.
    #[tokio::test]
    async fn test_server_stream_no_servers_available_error() {
        let mut client = client();
        let address = "0.0.0.0:6000".to_string();

        let stream_err = client.server_stream(&address).await.unwrap_err();

        assert_eq!(stream_err, ClientError::NoServersAvailable);
    }

    /// An address that is not among the active members is rejected.
    #[tokio::test]
    async fn test_server_stream_server_not_available_error() {
        let mut client = client_with_members().await;
        let address = "0.0.0.0:6000".to_string();

        let stream_err = client.server_stream(&address).await.unwrap_err();

        assert_eq!(
            stream_err,
            ClientError::ServerNotAvailable("0.0.0.0:6000".to_string())
        );
    }

    /// The member is registered but the connection attempt is expected to fail.
    #[tokio::test]
    async fn test_server_stream_cant_connect_to_server() {
        let mut client = client_with_members().await;
        let address = "0.0.0.0:1234".to_string();

        let stream = client.server_stream(&address).await;

        assert!(matches!(stream, Err(ClientError::Disconnect)));
    }

    /// The client must stay cloneable.
    #[tokio::test]
    async fn test_service_clone() {
        let original = client_with_members().await;
        let _ = original.clone();
    }
}