1use std::marker::PhantomData;
2use std::task::Poll;
3use std::time::Duration;
4
5use futures::future::BoxFuture;
6use futures::{pin_mut, FutureExt, SinkExt, StreamExt};
7use log::{debug, error, info, warn};
8use serde::de::DeserializeOwned;
9use tower::Service as TowerService;
10
11use crate::cluster::storage::MembershipStorage;
12use crate::protocol::RequestError;
13use crate::protocol::{ClientError, RequestEnvelope, ResponseEnvelope, ResponseError};
14
15use super::Client;
16
17#[derive(Clone)]
25pub struct Request<'a, S, E>
26where
27    S: MembershipStorage,
28{
29    client: Client<S>,
30    _lifetime_marker: PhantomData<&'a ()>,
31    _error_marker: PhantomData<E>,
32}
33
34impl<'a, S, E> Request<'a, S, E>
35where
36    S: MembershipStorage,
37{
38    pub fn new(client: Client<S>) -> Self {
39        Request {
40            client,
41            _lifetime_marker: PhantomData,
42            _error_marker: PhantomData,
43        }
44    }
45}
46
47impl<'a, S, E: std::error::Error + DeserializeOwned> TowerService<RequestEnvelope>
48    for Request<'a, S, E>
49where
50    S: MembershipStorage + 'static, {
52    type Response = Vec<u8>;
53    type Error = RequestError<E>;
54    type Future = BoxFuture<'a, Result<Self::Response, Self::Error>>;
55
56    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
58        let fetch_active_servers = self.client.fetch_active_servers();
59        pin_mut!(fetch_active_servers);
60        fetch_active_servers
61            .poll_unpin(cx)
62            .map_ok(|_| ())
63            .map_err(|e| e.into())
64    }
65
66    fn call(&mut self, req: RequestEnvelope) -> Self::Future {
68        let mut client = self.client.clone();
69        Box::pin(async move {
70            let mut stream = client
71                .service_object_stream(&req.handler_type, &req.handler_id)
72                .await?;
73
74            let ser_request = bincode::serialize(&req)
75                .map_err(|e| ClientError::SeralizationError(e.to_string()))?;
76
77            stream.send(ser_request.into()).await?;
78            match stream.next().await {
79                Some(Ok(frame)) => {
80                    let message: ResponseEnvelope = bincode::deserialize(&frame)
81                        .map_err(|e| ClientError::DeseralizationError(e.to_string()))?;
82                    Ok(message.body?)
83                }
84                Some(Err(e)) => Err(RequestError::ClientError(ClientError::IoError(
85                    e.to_string(),
86                ))),
87                None => Err(RequestError::ClientError(ClientError::Disconnect)),
90            }
91        })
92    }
93}
94
95pub struct RequestRedirect<'a, S, E>
100where
101    S: MembershipStorage,
102    Request<'a, S, E>: Clone,
103{
104    inner: Request<'a, S, E>,
105}
106
107impl<'a, S, E> RequestRedirect<'a, S, E>
108where
109    S: MembershipStorage,
110    Request<'a, S, E>: Clone,
111{
112    pub fn new(inner: Request<'a, S, E>) -> Self {
113        RequestRedirect { inner }
114    }
115}
116
117impl<'a, S, E> TowerService<RequestEnvelope> for RequestRedirect<'a, S, E>
118where
119    E: std::error::Error + DeserializeOwned + Send + Sync + 'a,
120    S: MembershipStorage + 'static,
121    Request<'a, S, E>: Clone,
122{
123    type Response = Vec<u8>;
124    type Error = RequestError<E>;
125    type Future = BoxFuture<'a, Result<Self::Response, Self::Error>>;
126
127    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
128        self.inner.poll_ready(cx)
129    }
130
131    fn call(&mut self, req: RequestEnvelope) -> Self::Future {
138        let handler_type = req.handler_type.clone();
141        let handler_id = req.handler_id.clone();
142        let request = req.clone();
143        let mut inner_service = self.inner.clone();
144
145        let retry_min_duration = Duration::from_nanos(1_000); let retry_max_duration = Duration::from_secs(2);
148        let max_retries = Some(20);
149        let mut retry_count = 0;
152        let mut retry_duration = retry_min_duration.clone();
153        Box::pin(async move {
154            loop {
155                let response = inner_service.call(request.clone()).await;
157                match response {
158                    Err(RequestError::ResponseError(ResponseError::Redirect(to))) => {
162                        info!("Redirect to {}", to);
165                        inner_service
166                            .client
167                            .placement
168                            .write()
169                            .map_err(|_| ClientError::PlacementLock)?
170                            .put((handler_type.clone(), handler_id.clone()), to);
171                    }
172                    Err(e)
176                        if matches!(
177                            e,
178                            RequestError::ResponseError(ResponseError::DeallocateServiceObject,)
179                                | RequestError::ClientError(ClientError::Disconnect)
180                                | RequestError::ClientError(ClientError::ServerNotAvailable(_))
181                                | RequestError::ClientError(ClientError::IoError(_))
182                        ) =>
183                    {
184                        if let Some(max_retries) = max_retries {
186                            if retry_count > max_retries {
187                                error!("Max retries ({}) reached: {:?}", max_retries, e);
188                                return Err(e);
189                            }
190                        }
191                        warn!("{:?}", e);
192
193                        debug!("Retry in {:?}", retry_duration);
195                        tokio::time::sleep(retry_duration).await;
196                        retry_count += 1;
197                        retry_duration *= 2;
198                        retry_duration =
199                            retry_duration.clamp(retry_min_duration, retry_max_duration);
200
201                        warn!("Refresh the list of servers");
204                        inner_service.client.ts_active_servers_refresh = 0; inner_service
207                            .client
208                            .placement
209                            .write()
210                            .map_err(|_| ClientError::PlacementLock)?
211                            .pop(&(handler_type.clone(), handler_id.clone()));
212                    }
213                    Err(e) => {
214                        if let RequestError::ResponseError(ResponseError::ApplicationError(_)) = e {
217                            error!("Uncaught error ResponseError::ApplicationError(...)");
218                        } else {
219                            error!("Uncaught error {:#?}", e);
220                        }
221                        return Err(e);
222                    }
223                    rest => return rest,
225                }
226            }
227        })
228    }
229}
230
#[cfg(test)]
mod test {
    use async_trait::async_trait;
    use chrono::{DateTime, Utc};
    use lru::LruCache;
    use serde::{Deserialize, Serialize};
    use std::sync::{Arc, RwLock};
    use thiserror::Error;
    use tower::ServiceExt;

    use super::*;
    use crate::{
        cluster::storage::{
            local::LocalStorage, Member, MembershipResult, MembershipStorage, MembershipUnitResult,
        },
        errors::MembershipError,
    };

    /// Error type used only to fill the `E` parameter of `Request` in these tests.
    #[derive(Error, Debug, Serialize, Deserialize, PartialEq)]
    enum NoopError {
        #[error("No-op")]
        Noop,
    }

    /// Membership storage whose `members` call always fails; every other operation
    /// succeeds without doing anything.
    #[derive(Clone, Default)]
    struct FailMembershipStorage {}

    #[async_trait]
    impl MembershipStorage for FailMembershipStorage {
        async fn push(&self, _: Member) -> MembershipUnitResult {
            Ok(())
        }
        async fn remove(&self, _: &str, _: &str) -> MembershipUnitResult {
            Ok(())
        }
        async fn set_is_active(&self, _: &str, _: &str, _: bool) -> MembershipUnitResult {
            Ok(())
        }
        async fn members(&self) -> MembershipResult<Vec<Member>> {
            Err(MembershipError::Unknown(String::new()))
        }
        async fn notify_failure(&self, _: &str, _: &str) -> MembershipUnitResult {
            Ok(())
        }
        async fn member_failures(&self, _: &str, _: &str) -> MembershipResult<Vec<DateTime<Utc>>> {
            Ok(Vec::new())
        }
    }

    /// Client backed by an empty `LocalStorage`, with no cached servers.
    fn client() -> Client<LocalStorage> {
        let placement = Arc::new(RwLock::new(LruCache::new(10)));
        Client {
            timeout_millis: 1000,
            membership_storage: LocalStorage::default(),
            active_servers: Default::default(),
            ts_active_servers_refresh: 0,
            streams: Arc::default(),
            placement,
        }
    }

    /// Readiness succeeds with an empty storage and leaves the server list empty.
    #[tokio::test]
    async fn test_poll_ready_no_active_server() {
        let mut service: Request<_, NoopError> = Request::new(client());
        service.ready().await.expect("poll_ready");
        assert!(service.client.active_servers.is_empty());
    }

    /// Readiness picks up an active member that was pushed to the storage beforehand.
    #[tokio::test]
    async fn test_poll_ready_with_active_servers() {
        let local_client = client();
        assert!(local_client.active_servers.is_empty());

        let mut member = Member::new(String::from("0.0.0.0"), String::from("1234"));
        member.set_active(true);
        local_client
            .membership_storage
            .push(member)
            .await
            .expect("add member");

        let mut service: Request<_, NoopError> = Request::new(local_client);
        service.ready().await.expect("poll_ready");
        assert_eq!(service.client.active_servers.len(), 1);
    }

    /// A failing storage surfaces as `ClientError::RendevouzUnavailable` on the first poll.
    #[tokio::test]
    async fn test_poll_ready_error() {
        let placement = Arc::new(RwLock::new(LruCache::new(10)));
        let failing_client = Client {
            timeout_millis: 1000,
            membership_storage: FailMembershipStorage::default(),
            active_servers: Default::default(),
            ts_active_servers_refresh: 0,
            streams: Arc::default(),
            placement,
        };
        let mut service: Request<_, NoopError> = Request::new(failing_client);

        // A no-op waker is enough: the test only inspects the result of a single poll.
        let noop = futures::task::noop_waker();
        let mut cx = std::task::Context::from_waker(&noop);
        let outcome = service.poll_ready(&mut cx);

        let expected = Poll::Ready(Err(RequestError::ClientError(
            ClientError::RendevouzUnavailable,
        )));
        assert_eq!(outcome, expected);
    }
}