1use std::marker::PhantomData;
2use std::task::Poll;
3use std::time::Duration;
4
5use futures::future::BoxFuture;
6use futures::{pin_mut, FutureExt, SinkExt, StreamExt};
7use log::{debug, error, info, warn};
8use serde::de::DeserializeOwned;
9use tower::Service as TowerService;
10
11use crate::cluster::storage::MembershipStorage;
12use crate::protocol::RequestError;
13use crate::protocol::{ClientError, RequestEnvelope, ResponseEnvelope, ResponseError};
14
15use super::Client;
16
17#[derive(Clone)]
25pub struct Request<'a, S, E>
26where
27 S: MembershipStorage,
28{
29 client: Client<S>,
30 _lifetime_marker: PhantomData<&'a ()>,
31 _error_marker: PhantomData<E>,
32}
33
34impl<'a, S, E> Request<'a, S, E>
35where
36 S: MembershipStorage,
37{
38 pub fn new(client: Client<S>) -> Self {
39 Request {
40 client,
41 _lifetime_marker: PhantomData,
42 _error_marker: PhantomData,
43 }
44 }
45}
46
47impl<'a, S, E: std::error::Error + DeserializeOwned> TowerService<RequestEnvelope>
48 for Request<'a, S, E>
49where
50 S: MembershipStorage + 'static, {
52 type Response = Vec<u8>;
53 type Error = RequestError<E>;
54 type Future = BoxFuture<'a, Result<Self::Response, Self::Error>>;
55
56 fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
58 let fetch_active_servers = self.client.fetch_active_servers();
59 pin_mut!(fetch_active_servers);
60 fetch_active_servers
61 .poll_unpin(cx)
62 .map_ok(|_| ())
63 .map_err(|e| e.into())
64 }
65
66 fn call(&mut self, req: RequestEnvelope) -> Self::Future {
68 let mut client = self.client.clone();
69 Box::pin(async move {
70 let mut stream = client
71 .service_object_stream(&req.handler_type, &req.handler_id)
72 .await?;
73
74 let ser_request = bincode::serialize(&req)
75 .map_err(|e| ClientError::SeralizationError(e.to_string()))?;
76
77 stream.send(ser_request.into()).await?;
78 match stream.next().await {
79 Some(Ok(frame)) => {
80 let message: ResponseEnvelope = bincode::deserialize(&frame)
81 .map_err(|e| ClientError::DeseralizationError(e.to_string()))?;
82 Ok(message.body?)
83 }
84 Some(Err(e)) => Err(RequestError::ClientError(ClientError::IoError(
85 e.to_string(),
86 ))),
87 None => Err(RequestError::ClientError(ClientError::Disconnect)),
90 }
91 })
92 }
93}
94
95pub struct RequestRedirect<'a, S, E>
100where
101 S: MembershipStorage,
102 Request<'a, S, E>: Clone,
103{
104 inner: Request<'a, S, E>,
105}
106
107impl<'a, S, E> RequestRedirect<'a, S, E>
108where
109 S: MembershipStorage,
110 Request<'a, S, E>: Clone,
111{
112 pub fn new(inner: Request<'a, S, E>) -> Self {
113 RequestRedirect { inner }
114 }
115}
116
117impl<'a, S, E> TowerService<RequestEnvelope> for RequestRedirect<'a, S, E>
118where
119 E: std::error::Error + DeserializeOwned + Send + Sync + 'a,
120 S: MembershipStorage + 'static,
121 Request<'a, S, E>: Clone,
122{
123 type Response = Vec<u8>;
124 type Error = RequestError<E>;
125 type Future = BoxFuture<'a, Result<Self::Response, Self::Error>>;
126
127 fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
128 self.inner.poll_ready(cx)
129 }
130
131 fn call(&mut self, req: RequestEnvelope) -> Self::Future {
138 let handler_type = req.handler_type.clone();
141 let handler_id = req.handler_id.clone();
142 let request = req.clone();
143 let mut inner_service = self.inner.clone();
144
145 let retry_min_duration = Duration::from_nanos(1_000); let retry_max_duration = Duration::from_secs(2);
148 let max_retries = Some(20);
149 let mut retry_count = 0;
152 let mut retry_duration = retry_min_duration.clone();
153 Box::pin(async move {
154 loop {
155 let response = inner_service.call(request.clone()).await;
157 match response {
158 Err(RequestError::ResponseError(ResponseError::Redirect(to))) => {
162 info!("Redirect to {}", to);
165 inner_service
166 .client
167 .placement
168 .write()
169 .map_err(|_| ClientError::PlacementLock)?
170 .put((handler_type.clone(), handler_id.clone()), to);
171 }
172 Err(e)
176 if matches!(
177 e,
178 RequestError::ResponseError(ResponseError::DeallocateServiceObject,)
179 | RequestError::ClientError(ClientError::Disconnect)
180 | RequestError::ClientError(ClientError::ServerNotAvailable(_))
181 | RequestError::ClientError(ClientError::IoError(_))
182 ) =>
183 {
184 if let Some(max_retries) = max_retries {
186 if retry_count > max_retries {
187 error!("Max retries ({}) reached: {:?}", max_retries, e);
188 return Err(e);
189 }
190 }
191 warn!("{:?}", e);
192
193 debug!("Retry in {:?}", retry_duration);
195 tokio::time::sleep(retry_duration).await;
196 retry_count += 1;
197 retry_duration *= 2;
198 retry_duration =
199 retry_duration.clamp(retry_min_duration, retry_max_duration);
200
201 warn!("Refresh the list of servers");
204 inner_service.client.ts_active_servers_refresh = 0; inner_service
207 .client
208 .placement
209 .write()
210 .map_err(|_| ClientError::PlacementLock)?
211 .pop(&(handler_type.clone(), handler_id.clone()));
212 }
213 Err(e) => {
214 if let RequestError::ResponseError(ResponseError::ApplicationError(_)) = e {
217 error!("Uncaught error ResponseError::ApplicationError(...)");
218 } else {
219 error!("Uncaught error {:#?}", e);
220 }
221 return Err(e);
222 }
223 rest => return rest,
225 }
226 }
227 })
228 }
229}
230
// Unit tests for `Request::poll_ready`: an empty membership storage, a storage
// with one active member, and a storage whose `members()` call fails.
#[cfg(test)]
mod test {
    use async_trait::async_trait;
    use chrono::{DateTime, Utc};
    use lru::LruCache;
    use serde::{Deserialize, Serialize};
    use std::sync::{Arc, RwLock};
    use thiserror::Error;
    use tower::ServiceExt;

    use super::*;
    use crate::{
        cluster::storage::{
            local::LocalStorage, Member, MembershipResult, MembershipStorage, MembershipUnitResult,
        },
        errors::MembershipError,
    };

    /// Minimal application error type used as the `E` parameter of `Request`.
    #[derive(Error, Debug, Serialize, Deserialize, PartialEq)]
    enum NoopError {
        #[error("No-op")]
        Noop,
    }

    /// Membership storage whose `members()` always fails; every other
    /// operation succeeds as a no-op. Used to exercise the `poll_ready` error path.
    #[derive(Clone, Default)]
    struct FailMembershipStorage {}

    #[async_trait]
    impl MembershipStorage for FailMembershipStorage {
        async fn push(&self, _: Member) -> MembershipUnitResult {
            Ok(())
        }
        async fn remove(&self, _: &str, _: &str) -> MembershipUnitResult {
            Ok(())
        }
        async fn set_is_active(&self, _: &str, _: &str, _: bool) -> MembershipUnitResult {
            Ok(())
        }
        // The only failing operation.
        async fn members(&self) -> MembershipResult<Vec<Member>> {
            Err(MembershipError::Unknown("".to_string()))
        }
        async fn notify_failure(&self, _: &str, _: &str) -> MembershipUnitResult {
            Ok(())
        }
        async fn member_failures(&self, _: &str, _: &str) -> MembershipResult<Vec<DateTime<Utc>>> {
            Ok(vec![])
        }
    }

    /// Builds a `Client` over an empty `LocalStorage` with no cached active
    /// servers, a zero refresh timestamp, and a 10-entry placement cache.
    fn client() -> Client<LocalStorage> {
        Client {
            timeout_millis: 1000,
            membership_storage: LocalStorage::default(),
            active_servers: Default::default(),
            ts_active_servers_refresh: 0,
            streams: Arc::default(),
            placement: Arc::new(RwLock::new(LruCache::new(10))),
        }
    }

    // With no members registered, readiness succeeds and the list stays empty.
    #[tokio::test]
    async fn test_poll_ready_no_active_server() {
        let client = client();
        let mut request: Request<_, NoopError> = Request::new(client);
        request.ready().await.expect("poll_ready");
        assert!(request.client.active_servers.is_empty());
    }

    // One active member in storage is picked up by `poll_ready`.
    #[tokio::test]
    async fn test_poll_ready_with_active_servers() {
        let client = client();
        assert!(client.active_servers.is_empty());

        let mut server = Member::new("0.0.0.0".to_string(), "1234".to_string());
        server.set_active(true);
        client
            .membership_storage
            .push(server)
            .await
            .expect("add member");

        let mut request: Request<_, NoopError> = Request::new(client);
        request.ready().await.expect("poll_ready");
        assert_eq!(request.client.active_servers.len(), 1);
    }

    // A failing `members()` surfaces as `ClientError::RendevouzUnavailable`
    // on the first poll (polled manually with a no-op waker).
    #[tokio::test]
    async fn test_poll_ready_error() {
        let client = Client {
            timeout_millis: 1000,
            membership_storage: FailMembershipStorage {},
            active_servers: Default::default(),
            ts_active_servers_refresh: 0,
            streams: Arc::default(),
            placement: Arc::new(RwLock::new(LruCache::new(10))),
        };
        let mut request: Request<_, NoopError> = Request::new(client);
        let waker = futures::task::noop_waker();
        let mut context = std::task::Context::from_waker(&waker);
        let poll_ready = request.poll_ready(&mut context);
        assert_eq!(
            poll_ready,
            Poll::Ready(Err(RequestError::ClientError(
                ClientError::RendevouzUnavailable
            )))
        );
    }
}