//! Tower service implementations for the rio_rs client.
//! (source path: rio_rs/client/tower_services.rs)

1use std::marker::PhantomData;
2use std::task::Poll;
3use std::time::Duration;
4
5use futures::future::BoxFuture;
6use futures::{pin_mut, FutureExt, SinkExt, StreamExt};
7use log::{debug, error, info, warn};
8use serde::de::DeserializeOwned;
9use tower::Service as TowerService;
10
11use crate::cluster::storage::MembershipStorage;
12use crate::protocol::RequestError;
13use crate::protocol::{ClientError, RequestEnvelope, ResponseEnvelope, ResponseError};
14
15use super::Client;
16
/// Requests have only a single response from the server (for streaming back result, see [crate::client::Client::subscribe])
///
/// This wraps a [Client], which owns the connection streams and the
/// placement cache used to route each request.
///
/// - `'a` is the lifetime for the tower service's box future
/// - `S` is the MembershipStorage for the internal client
/// - `E` is the generic for the RequestError
#[derive(Clone)]
pub struct Request<'a, S, E> {
    // Underlying client used to open/reuse streams to cluster members.
    client: Client<S>,
    // `'a` only appears in the associated `Future` type of the
    // TowerService impl, so it must be anchored here via PhantomData.
    _lifetime_marker: PhantomData<&'a ()>,
    // Ties the application error type `E` to this service without storing one.
    _error_marker: PhantomData<E>,
}
30
31impl<'a, S, E> Request<'a, S, E>
32where
33    S: MembershipStorage,
34{
35    pub fn new(client: Client<S>) -> Self {
36        Request {
37            client,
38            _lifetime_marker: PhantomData,
39            _error_marker: PhantomData,
40        }
41    }
42}
43
impl<'a, S, E: std::error::Error + DeserializeOwned> TowerService<RequestEnvelope>
    for Request<'a, S, E>
where
    S: MembershipStorage + 'static, // TODO remove 'static
{
    /// Raw (still bincode-serialized) response body returned by the server.
    type Response = Vec<u8>;
    type Error = RequestError<E>;
    type Future = BoxFuture<'a, Result<Self::Response, Self::Error>>;

    /// Waits for members to be available
    ///
    /// Readiness is defined as `Client::fetch_active_servers` resolving
    /// successfully; its error is converted into this service's error type.
    ///
    /// NOTE(review): a fresh `fetch_active_servers` future is created and
    /// polled once on every `poll_ready` call; if it returns `Pending` the
    /// future is dropped and rebuilt on the next poll. This is only sound if
    /// that future re-registers the waker on each fresh poll — confirm, and
    /// consider storing the in-flight future in the service instead.
    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        let fetch_active_servers = self.client.fetch_active_servers();
        pin_mut!(fetch_active_servers);
        fetch_active_servers
            .poll_unpin(cx)
            .map_ok(|_| ())
            .map_err(|e| e.into())
    }

    /// Sends one request to the service object identified by
    /// `req.handler_type`/`req.handler_id` and awaits exactly one response
    /// frame.
    ///
    /// The request is bincode-serialized onto the object's stream; the first
    /// frame read back is deserialized into a [ResponseEnvelope] and its body
    /// is returned. A stream error maps to [ClientError::IoError]; an
    /// exhausted stream (TCP disconnect) maps to [ClientError::Disconnect].
    fn call(&mut self, req: RequestEnvelope) -> Self::Future {
        // Clone the client so the returned future does not borrow `self`.
        let mut client = self.client.clone();
        Box::pin(async move {
            let mut stream = client
                .service_object_stream(&req.handler_type, &req.handler_id)
                .await?;

            let ser_request = bincode::serialize(&req)
                .map_err(|e| ClientError::SeralizationError(e.to_string()))?;

            stream.send(ser_request.into()).await?;
            match stream.next().await {
                Some(Ok(frame)) => {
                    let message: ResponseEnvelope = bincode::deserialize(&frame)
                        .map_err(|e| ClientError::DeseralizationError(e.to_string()))?;
                    Ok(message.body?)
                }
                Some(Err(e)) => Err(RequestError::ClientError(ClientError::IoError(
                    e.to_string(),
                ))),
                // When there are no more items on the stream, it means the TCP stream was
                // disconnected
                None => Err(RequestError::ClientError(ClientError::Disconnect)),
            }
        })
    }
}
91
92/// This type wraps a [Request], and it retries its call under some conditions:
93///
94/// - When the object is not on the cached/expected placement
95/// - When the object is not yet allocated
96pub struct RequestRedirect<'a, S, E>
97where
98    S: MembershipStorage,
99    Request<'a, S, E>: Clone,
100{
101    inner: Request<'a, S, E>,
102}
103
104impl<'a, S, E> RequestRedirect<'a, S, E>
105where
106    S: MembershipStorage,
107    Request<'a, S, E>: Clone,
108{
109    pub fn new(inner: Request<'a, S, E>) -> Self {
110        RequestRedirect { inner }
111    }
112}
113
impl<'a, S, E> TowerService<RequestEnvelope> for RequestRedirect<'a, S, E>
where
    E: std::error::Error + DeserializeOwned + Send + Sync + 'a,
    S: MembershipStorage + 'static,
    Request<'a, S, E>: Clone,
{
    type Response = Vec<u8>;
    type Error = RequestError<E>;
    type Future = BoxFuture<'a, Result<Self::Response, Self::Error>>;

    /// Delegates readiness to the wrapped [Request] service.
    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Calls the inner [Request] in a loop, retrying on redirects and on
    /// errors indicating the chosen server is gone. The backoff doubles each
    /// retry, clamped between `retry_min_duration` and `retry_max_duration`,
    /// and gives up after `max_retries` attempts (when set).
    ///
    /// <div class="warning">There are tons of extra allocations to avoid race conditions</div>
    ///
    /// <div class="warning">
    /// This method works over a clone of [Request], we need to make sure
    /// it keeps in sync with the original version
    /// </div>
    fn call(&mut self, req: RequestEnvelope) -> Self::Future {
        // Clones a bunch of stuff so the future returned by this
        // function lives shorter than the actual service (as in `'b: 'a`)
        let handler_type = req.handler_type.clone();
        let handler_id = req.handler_id.clone();
        let request = req.clone();
        let mut inner_service = self.inner.clone();

        // TODO move this to config
        let retry_min_duration = Duration::from_nanos(1_000); // 1 µs
        let retry_max_duration = Duration::from_secs(2);
        let max_retries = Some(20);
        // END: TODO move this to config

        let mut retry_count = 0;
        let mut retry_duration = retry_min_duration;
        Box::pin(async move {
            loop {
                // Used a cloned request, so it can be used in a loop
                let response = inner_service.call(request.clone()).await;
                match response {
                    // This case happens when there is a mismatch between the client and the
                    // servers regarding where the service object is allocated
                    // Ps.: No need to sleep on redirect
                    Err(RequestError::ResponseError(ResponseError::Redirect(to))) => {
                        // Add the new address to the placement so in the next iteration
                        // it will use the right server
                        info!("Redirect to {}", to);
                        inner_service
                            .client
                            .placement
                            .write()
                            .map_err(|_| ClientError::PlacementLock)?
                            .put((handler_type.clone(), handler_id.clone()), to);
                    }
                    // All these errors indicate that the server we've tried is no longer available
                    // When facing one of the errors below, we need to retry the
                    // request so it picks up a new Server on the cluster
                    Err(e)
                        if matches!(
                            e,
                            RequestError::ResponseError(ResponseError::DeallocateServiceObject,)
                                | RequestError::ClientError(ClientError::Disconnect)
                                | RequestError::ClientError(ClientError::ServerNotAvailable(_))
                                | RequestError::ClientError(ClientError::IoError(_))
                        ) =>
                    {
                        // early quitting if max_retries reached
                        if let Some(max_retries) = max_retries {
                            if retry_count > max_retries {
                                error!("Max retries ({}) reached: {:?}", max_retries, e);
                                return Err(e);
                            }
                        }
                        warn!("{:?}", e);

                        // update retry info (exponential backoff, clamped to the
                        // configured min/max window)
                        debug!("Retry in {:?}", retry_duration);
                        tokio::time::sleep(retry_duration).await;
                        retry_count += 1;
                        retry_duration *= 2;
                        retry_duration =
                            retry_duration.clamp(retry_min_duration, retry_max_duration);

                        // Removed the old placement, the next request
                        // will pickup a new placement to try from
                        warn!("Refresh the list of servers");
                        inner_service.client.ts_active_servers_refresh = 0; // forces re-fetching the
                                                                            // active servers
                        inner_service
                            .client
                            .placement
                            .write()
                            .map_err(|_| ClientError::PlacementLock)?
                            .pop(&(handler_type.clone(), handler_id.clone()));
                    }
                    Err(e) => {
                        // Have a separate case to avoid spamming error logs with the application
                        // error in binary format
                        if let RequestError::ResponseError(ResponseError::ApplicationError(_)) = e {
                            error!("Uncaught error ResponseError::ApplicationError(...)");
                        } else {
                            error!("Uncaught error {:#?}", e);
                        }
                        return Err(e);
                    }
                    // Return as is (the Ok path)
                    rest => return rest,
                }
            }
        })
    }
}
227
#[cfg(test)]
mod test {
    use async_trait::async_trait;
    use chrono::{DateTime, Utc};
    use lru::LruCache;
    use serde::{Deserialize, Serialize};
    use std::{
        num::NonZero,
        sync::{Arc, RwLock},
    };
    use thiserror::Error;
    use tower::ServiceExt;

    use super::*;
    use crate::{
        cluster::storage::{
            local::LocalStorage, Member, MembershipResult, MembershipStorage, MembershipUnitResult,
        },
        errors::MembershipError,
    };

    /// Minimal application-error type used to fill `Request`'s `E` parameter.
    #[derive(Error, Debug, Serialize, Deserialize, PartialEq)]
    enum NoopError {
        #[error("No-op")]
        Noop,
    }

    /// Membership storage stub whose `members()` always errors, used to
    /// drive `poll_ready` down its failure path.
    #[derive(Clone, Default, Debug)]
    struct FailMembershipStorage {}

    #[async_trait]
    impl MembershipStorage for FailMembershipStorage {
        async fn push(&self, _: Member) -> MembershipUnitResult {
            Ok(())
        }
        async fn remove(&self, _: &str, _: &str) -> MembershipUnitResult {
            Ok(())
        }
        async fn set_is_active(&self, _: &str, _: &str, _: bool) -> MembershipUnitResult {
            Ok(())
        }
        async fn members(&self) -> MembershipResult<Vec<Member>> {
            Err(MembershipError::Unknown("".to_string()))
        }
        async fn notify_failure(&self, _: &str, _: &str) -> MembershipUnitResult {
            Ok(())
        }
        async fn member_failures(&self, _: &str, _: &str) -> MembershipResult<Vec<DateTime<Utc>>> {
            Ok(vec![])
        }
    }

    /// Builds a client backed by an empty in-memory membership storage.
    fn build_client() -> Client<LocalStorage> {
        Client {
            timeout_millis: 1000,
            membership_storage: LocalStorage::default(),
            active_servers: Default::default(),
            ts_active_servers_refresh: 0,
            streams: Arc::default(),
            placement: Arc::new(RwLock::new(LruCache::new(NonZero::new(10).unwrap()))),
        }
    }

    #[tokio::test]
    async fn test_poll_ready_no_active_server() {
        let mut service: Request<_, NoopError> = Request::new(build_client());
        service.ready().await.expect("poll_ready");
        // No members were registered, so the refreshed list stays empty.
        assert!(service.client.active_servers.is_empty());
    }

    #[tokio::test]
    async fn test_poll_ready_with_active_servers() {
        // Starts with no active servers.
        let client = build_client();
        assert!(client.active_servers.is_empty());

        let mut member = Member::new("0.0.0.0".to_string(), "1234".to_string());
        member.set_active(true);
        client
            .membership_storage
            .push(member)
            .await
            .expect("add member");

        // When poll_ready is called, it fetches the active servers.
        let mut service: Request<_, NoopError> = Request::new(client);
        service.ready().await.expect("poll_ready");
        assert_eq!(service.client.active_servers.len(), 1);
    }

    #[tokio::test]
    async fn test_poll_ready_error() {
        let failing_client = Client {
            timeout_millis: 1000,
            membership_storage: FailMembershipStorage {},
            active_servers: Default::default(),
            ts_active_servers_refresh: 0,
            streams: Arc::default(),
            placement: Arc::new(RwLock::new(LruCache::new(NonZero::new(10).unwrap()))),
        };
        let mut service: Request<_, NoopError> = Request::new(failing_client);
        let waker = futures::task::noop_waker();
        let mut context = std::task::Context::from_waker(&waker);
        // A failing storage should surface as RendevouzUnavailable.
        let polled = service.poll_ready(&mut context);
        assert_eq!(
            polled,
            Poll::Ready(Err(RequestError::ClientError(
                ClientError::RendevouzUnavailable
            )))
        );
    }
}
340}