Skip to main content

rio_rs/client/
mod.rs

//! Talk to a rio-rs server
//!
//! Provides a client to interact with a cluster for both request/response and pub/sub.
//!
//! A pooled client is also available (see [`Pool`]).
//! The client performs its own placement lookups and controls its own
//! caching strategy.
8
9mod builder;
10mod pool;
11pub mod tower_services;
12
13use async_stream::stream;
14pub use builder::ClientBuilder;
15pub use pool::ClientConnectionManager;
16pub use pool::Pool;
17pub use pool::PooledConnection;
18
19use dashmap::mapref::one::RefMut;
20use dashmap::DashMap;
21use futures::SinkExt;
22use futures::{Stream, StreamExt};
23use lru::LruCache;
24use rand::rng;
25use rand::seq::IndexedRandom;
26use serde::de::DeserializeOwned;
27use serde::Serialize;
28use std::collections::HashSet;
29use std::marker::PhantomData;
30use std::num::NonZeroUsize;
31use std::sync::{Arc, RwLock};
32use tokio::net::TcpStream;
33use tokio::time::timeout;
34use tokio_util::codec::{Framed, LengthDelimitedCodec};
35use tower::Service as TowerService;
36
37use crate::cluster::storage::MembershipStorage;
38use crate::protocol::pubsub::{SubscriptionRequest, SubscriptionResponse};
39use crate::protocol::{ClientError, RequestEnvelope, RequestError, ResponseError};
40use crate::registry::IdentifiableType;
41
42pub const DEFAULT_TIMEOUT_MILLIS: u64 = 500;
43
44/// Client struct to interact with a cluster for requests and subscriptions
45///
46/// S is the MembershipStorage implementation to fetch the cluster members
47#[derive(Clone)]
48pub struct Client<S> {
49    timeout_millis: u64,
50
51    /// Membership view used for Server's service discovery
52    membership_storage: S,
53
54    /// List of servers that are accepting requests
55    active_servers: HashSet<String>,
56
57    /// Timestamp of the last time self.active_servers was refresh
58    ts_active_servers_refresh: u64,
59
60    /// Framed TCP Stream mapped by ip+port address
61    streams: Arc<DashMap<String, Framed<TcpStream, LengthDelimitedCodec>>>,
62
63    /// Cached location of objects previously used by  the client
64    placement: Arc<RwLock<LruCache<(String, String), String>>>,
65}
66
67/// Stream of subscription messages. This is used for pub/sub.
68pub struct SubscriptionStream<T>
69where
70    T: DeserializeOwned,
71{
72    // TODO make this over an impl G instead of Framed
73    pub tcp_stream: Framed<TcpStream, LengthDelimitedCodec>,
74    _phantom: PhantomData<T>,
75}
76
77impl<T> SubscriptionStream<T>
78where
79    T: DeserializeOwned,
80{
81    pub fn new(tcp_stream: Framed<TcpStream, LengthDelimitedCodec>) -> Self {
82        SubscriptionStream {
83            tcp_stream,
84            _phantom: PhantomData {},
85        }
86    }
87}
88
89/// <div class="warning">
90/// Remove Unpin
91/// </div>
92impl<T> Stream for SubscriptionStream<T>
93where
94    T: DeserializeOwned + std::marker::Unpin + std::fmt::Debug,
95{
96    type Item = Result<T, ResponseError>;
97    fn poll_next(
98        self: std::pin::Pin<&mut Self>,
99        cx: &mut std::task::Context<'_>,
100    ) -> std::task::Poll<Option<Self::Item>> {
101        let self_mut = self.get_mut();
102        self_mut.tcp_stream.poll_next_unpin(cx).map(|maybe_bytes| {
103            let bytes_result = maybe_bytes?;
104            let bytes_message = match bytes_result {
105                Ok(bytes_message) => bytes_message,
106                Err(err) => {
107                    return Some(Err(ResponseError::DeseralizationError(err.to_string())));
108                }
109            };
110
111            let sub_response: SubscriptionResponse = match bincode::deserialize(&bytes_message) {
112                Ok(sub_response) => sub_response,
113                Err(err) => return Some(Err(ResponseError::DeseralizationError(err.to_string()))),
114            };
115
116            let final_message = match sub_response.body {
117                Ok(v) => {
118                    let response: Result<T, _> = bincode::deserialize(&v)
119                        .map_err(|e| ResponseError::DeseralizationError(e.to_string()));
120                    response
121                }
122                Err(err) => Err(err),
123            };
124            Some(final_message)
125        })
126    }
127}
128
129type ClientResult<T> = Result<T, ClientError>;
130
131impl<S> Client<S>
132where
133    S: 'static + MembershipStorage,
134{
135    /// Create a new Client from a MembershipStorage
136    pub fn new(members_storage: S) -> Self {
137        let lru_limit = NonZeroUsize::new(1_000).expect("LruCache limit must be greater than 0");
138
139        Client {
140            membership_storage: members_storage,
141            timeout_millis: DEFAULT_TIMEOUT_MILLIS,
142            active_servers: Default::default(),
143            ts_active_servers_refresh: 0,
144            streams: Arc::default(),
145            placement: Arc::new(RwLock::new(LruCache::new(lru_limit))),
146        }
147    }
148
149    /// Fetch a list of active servers if it hasn't done yet, or if the current list is too old
150    ///
151    /// Note that this is not an incremental operation. It will replace all the current active servers
152    /// cached on the client
153    async fn fetch_active_servers(&mut self) -> ClientResult<()> {
154        // if there are active servers and the refresh time stamp has changed
155        // We assume the cache is good
156        if !self.active_servers.is_empty() && self.ts_active_servers_refresh > 0 {
157            return Ok(());
158        }
159
160        let active_servers: HashSet<String> = self
161            .membership_storage
162            .active_members()
163            .await
164            .map_err(|_| ClientError::RendevouzUnavailable)?
165            .iter()
166            .map(|member| member.address())
167            .collect();
168
169        self.active_servers = active_servers;
170        self.ts_active_servers_refresh = 1;
171        Ok(())
172    }
173
174    async fn ensure_stream_exists(&mut self, address: &str) -> ClientResult<()> {
175        self.fetch_active_servers().await?;
176
177        // We start this method fetching the active servers, so if there are no active servers we
178        // fail
179        if self.active_servers.is_empty() {
180            return Err(ClientError::NoServersAvailable);
181        }
182
183        // If we do have items but the asked address is not there, the active_servers might be
184        // outdated and it will reset the refresh time and fetch it again
185        if !self.active_servers.contains(address) {
186            self.ts_active_servers_refresh = 0;
187            self.fetch_active_servers().await?;
188        }
189
190        // After fetch and re-fetch, if the asked address is not on the list, it means the caller
191        // is outdated
192        if !self.active_servers.contains(address) {
193            return Err(ClientError::ServerNotAvailable(address.to_string()));
194        }
195
196        // If there are no stream for the address, create a new one
197        // This is on a nested block so it controlls the guards in `self.stream`
198        if self.streams.get(address).is_none() {
199            let stream = TcpStream::connect(&address)
200                .await
201                .map_err(|_| ClientError::Disconnect)?;
202            let stream = Framed::new(stream, LengthDelimitedCodec::new());
203            self.streams.insert(address.to_string(), stream);
204        };
205        Ok(())
206    }
207
208    /// Get an existing connection to server `address` or create a new one
209    ///
210    /// If the address is not one of the known online servers, it will fetch
211    /// the list of active servers again
212    async fn server_stream(
213        &mut self,
214        address: &String,
215    ) -> ClientResult<RefMut<'_, String, Framed<TcpStream, LengthDelimitedCodec>>> {
216        self.ensure_stream_exists(address).await?;
217        self.streams
218            .get_mut(address)
219            .ok_or(ClientError::Connectivity)
220    }
221
222    /// Same as [Self::server_stream], but it pops from the stream cache
223    async fn pop_server_stream(
224        &mut self,
225        address: &String,
226    ) -> ClientResult<Framed<TcpStream, LengthDelimitedCodec>> {
227        self.ensure_stream_exists(address).await?;
228        self.streams
229            .remove(address)
230            .map(|(_, v)| v)
231            .ok_or(ClientError::Connectivity)
232    }
233
234    /// Returns the address for a given service object
235    async fn get_service_object_address(
236        &mut self,
237        service_object_type: impl ToString,
238        service_object_id: impl ToString,
239    ) -> ClientResult<String> {
240        self.fetch_active_servers().await?;
241        let object_id = (
242            service_object_id.to_string(),
243            service_object_type.to_string(),
244        );
245        let address = {
246            let mut placement_guard = self
247                .placement
248                .write()
249                .map_err(|_| ClientError::PlacementLock)?;
250
251            let cached_address = placement_guard.get(&object_id);
252            match cached_address {
253                Some(address) => address.clone(),
254                None => {
255                    // If there is no address associated with this service,
256                    // it will pick one at random (allowing the server to 'correct' it)
257                    let mut rng = rng();
258                    let servers: Vec<String> = self.active_servers.iter().cloned().collect();
259                    let random_server = servers
260                        .choose(&mut rng)
261                        .ok_or(ClientError::NoServersAvailable)?;
262                    random_server.clone()
263                }
264            }
265        };
266        Ok(address)
267    }
268
269    /// Returns a stream to the server that a given ServiceObject might be allocated into
270    async fn service_object_stream(
271        &mut self,
272        service_object_type: impl ToString,
273        service_object_id: impl ToString,
274    ) -> ClientResult<RefMut<'_, String, Framed<TcpStream, LengthDelimitedCodec>>> {
275        self.fetch_active_servers().await?;
276        let address = self
277            .get_service_object_address(service_object_type, service_object_id)
278            .await?;
279        self.server_stream(&address).await
280    }
281
282    /// Send a request to the cluster transparently (the caller doesn't need to know where the
283    /// object is placed)
284    ///
285    /// <div class="warning">
286    /// <b>TODO</b>
287    ///
288    /// When the cached or selected server are not available, it needs to refresh all the
289    /// cache and try a different server, this process needs to repeat until it finds a new
290    /// available server
291    /// </div>
292    pub async fn send<T, E>(
293        &mut self,
294        handler_type: impl AsRef<str>,
295        handler_id: impl AsRef<str>,
296        payload: &(impl Serialize + IdentifiableType + Send + Sync),
297    ) -> Result<T, RequestError<E>>
298    where
299        T: DeserializeOwned,
300        E: std::error::Error + DeserializeOwned + Clone + Send + Sync,
301    {
302        // TODO move fetch_active_servers into poll_ready self.ready().await?;
303        self.fetch_active_servers().await?;
304
305        let handler_type = handler_type.as_ref().to_string();
306        let handler_id = handler_id.as_ref().to_string();
307        let ser_payload = bincode::serialize(&payload)
308            .map_err(|e| ClientError::SeralizationError(e.to_string()))?;
309        let message_type = payload.instance_type_id().to_string();
310
311        let request = RequestEnvelope::new(
312            handler_type.clone(),
313            handler_id.clone(),
314            message_type.clone(),
315            ser_payload.clone(),
316        );
317        let tower_svc = tower_services::Request::new(self.clone());
318        let mut tower_svc = tower_services::RequestRedirect::new(tower_svc);
319        let response = tower_svc.call(request).await;
320        response.and_then(|x| {
321            let body: T = bincode::deserialize(&x)
322                .map_err(|e| ClientError::DeseralizationError(e.to_string()))?;
323            Ok(body)
324        })
325    }
326
327    /// Same as [Self::send], but it uses the [RequestEnvelope] ready for serialization
328    pub async fn send_request<E: std::error::Error + DeserializeOwned + Clone + Send + Sync>(
329        &mut self,
330        request: RequestEnvelope,
331    ) -> Result<Vec<u8>, RequestError<E>> {
332        // TODO move fetch_active_servers into poll_ready self.ready().await?;
333        self.fetch_active_servers().await?;
334
335        let tower_svc = tower_services::Request::new(self.clone());
336        let mut tower_svc = tower_services::RequestRedirect::new(tower_svc);
337        let response = tower_svc.call(request).await?;
338        Ok(response)
339    }
340
341    async fn _subscribe<'a, T>(
342        &'a mut self,
343        handler_type: &str,
344        handler_id: &str,
345        address: &str,
346    ) -> SubscriptionStream<T>
347    where
348        Self: 'a,
349        T: DeserializeOwned + std::marker::Unpin + 'a + std::fmt::Debug,
350    {
351        let mut svc_stream = self.pop_server_stream(&address.to_string()).await.unwrap();
352        let req = SubscriptionRequest {
353            handler_type: handler_type.to_string(),
354            handler_id: handler_id.to_string(),
355        };
356        let ser_request = bincode::serialize(&req).unwrap();
357        svc_stream.send(ser_request.into()).await.unwrap();
358        SubscriptionStream::<T>::new(svc_stream)
359    }
360
361    /// Subscribe to events from a service object
362    ///
363    /// <div class="warning">
364    /// <b>TODO</b>
365    ///
366    /// - [x] Returns async iter
367    /// - [x] Handle redirects
368    /// - [ ] Move this logic into a tower service
369    /// - [ ] Support moving service object (after you connect to a node and the handler you are listening to moves to some other node)
370    /// - [x] Use dedicated connection
371    ///
372    /// </div>
373    pub async fn subscribe<'a, T>(
374        &'a mut self,
375        handler_type: impl AsRef<str>,
376        handler_id: impl AsRef<str>,
377    ) -> Result<impl Stream<Item = Result<T, ResponseError>> + 'a, ClientError>
378    where
379        Self: 'a,
380        T: DeserializeOwned + std::marker::Unpin + 'a + std::fmt::Debug,
381    {
382        let handler_type = handler_type.as_ref().to_string();
383        let handler_id = handler_id.as_ref().to_string();
384        let mut address = self
385            .get_service_object_address(&handler_type, &handler_id)
386            .await?;
387
388        let stream = stream! {
389            loop {
390                let mut subscription_stream = self._subscribe(&handler_type, &handler_id, &address).await;
391                while let Some(v) = subscription_stream.next().await {
392                    if let Err(ResponseError::Redirect(to)) = v {
393                        address = to;
394                        break;
395                    }
396                    yield v;
397                }
398            }
399        };
400        Ok(stream)
401    }
402
403    /// Connects to a the first server of the MembershipStorage
404    ///
405    /// This is used mostly by the PeerToPeerClusterProvider to check whether
406    /// a set of servers is reacheable and alive
407    pub async fn ping(&mut self) -> Result<(), ClientError> {
408        let servers = self
409            .membership_storage
410            .members()
411            .await
412            .map_err(|_| ClientError::Connectivity)?;
413        let server = servers.first().ok_or(ClientError::NoServersAvailable)?;
414
415        async fn conn(address: &str) -> Result<(), ClientError> {
416            TcpStream::connect(&address)
417                .await
418                .map(|_stream| Ok(()))
419                .map_err(|_e| ClientError::Connectivity)?
420        }
421
422        match timeout(
423            std::time::Duration::from_millis(self.timeout_millis),
424            conn(&server.address()),
425        )
426        .await
427        {
428            Ok(x) => x,
429            Err(_elapsed) => Err(ClientError::Connectivity),
430        }
431    }
432}
433
434#[cfg(test)]
435mod test {
436    use super::*;
437    use crate::cluster::storage::{local::LocalStorage, Member, MembershipStorage};
438
439    fn client() -> Client<LocalStorage> {
440        Client {
441            timeout_millis: 1000,
442            membership_storage: LocalStorage::default(),
443            active_servers: Default::default(),
444            ts_active_servers_refresh: 0,
445            streams: Arc::default(),
446            placement: Arc::new(RwLock::new(LruCache::new(NonZeroUsize::new(10).unwrap()))),
447        }
448    }
449
450    async fn client_with_members() -> Client<LocalStorage> {
451        let client = client();
452        let mut server = Member::new("0.0.0.0".to_string(), "1234".to_string());
453        server.set_active(true);
454        client
455            .membership_storage
456            .push(server)
457            .await
458            .expect("add member");
459        client
460    }
461
462    #[tokio::test]
463    async fn test_server_stream_no_servers_available_error() {
464        let mut client = client();
465        let stream_err = client
466            .server_stream(&"0.0.0.0:6000".to_string())
467            .await
468            .unwrap_err();
469        assert_eq!(stream_err, ClientError::NoServersAvailable);
470    }
471
472    #[tokio::test]
473    async fn test_server_stream_server_not_available_error() {
474        let mut client = client_with_members().await;
475        let stream_err = client
476            .server_stream(&"0.0.0.0:6000".to_string())
477            .await
478            .unwrap_err();
479        assert_eq!(
480            stream_err,
481            ClientError::ServerNotAvailable("0.0.0.0:6000".to_string())
482        );
483    }
484
485    #[tokio::test]
486    async fn test_server_stream_cant_connect_to_server() {
487        let mut client = client_with_members().await;
488        let stream = client.server_stream(&"0.0.0.0:1234".to_string()).await;
489
490        // TODO
491        //  this test used to match against ClientError::Unknown,
492        //  I don't recall why, so I need to investigate wether it was
493        //  broken before or it is broken now
494        assert!(matches!(stream, Err(ClientError::Disconnect)));
495    }
496
497    #[tokio::test]
498    async fn test_service_clone() {
499        let client = client_with_members().await;
500        let _ = client.clone();
501    }
502}