1use std::marker::PhantomData;
2use std::task::Poll;
3use std::time::Duration;
4
5use futures::future::BoxFuture;
6use futures::{pin_mut, FutureExt, SinkExt, StreamExt};
7use log::{debug, error, info, warn};
8use serde::de::DeserializeOwned;
9use tower::Service as TowerService;
10
11use crate::cluster::storage::MembershipStorage;
12use crate::protocol::RequestError;
13use crate::protocol::{ClientError, RequestEnvelope, ResponseEnvelope, ResponseError};
14
15use super::Client;
16
17#[derive(Clone)]
25pub struct Request<'a, S, E> {
26 client: Client<S>,
27 _lifetime_marker: PhantomData<&'a ()>,
28 _error_marker: PhantomData<E>,
29}
30
31impl<'a, S, E> Request<'a, S, E>
32where
33 S: MembershipStorage,
34{
35 pub fn new(client: Client<S>) -> Self {
36 Request {
37 client,
38 _lifetime_marker: PhantomData,
39 _error_marker: PhantomData,
40 }
41 }
42}
43
44impl<'a, S, E: std::error::Error + DeserializeOwned> TowerService<RequestEnvelope>
45 for Request<'a, S, E>
46where
47 S: MembershipStorage + 'static, {
49 type Response = Vec<u8>;
50 type Error = RequestError<E>;
51 type Future = BoxFuture<'a, Result<Self::Response, Self::Error>>;
52
53 fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
55 let fetch_active_servers = self.client.fetch_active_servers();
56 pin_mut!(fetch_active_servers);
57 fetch_active_servers
58 .poll_unpin(cx)
59 .map_ok(|_| ())
60 .map_err(|e| e.into())
61 }
62
63 fn call(&mut self, req: RequestEnvelope) -> Self::Future {
65 let mut client = self.client.clone();
66 Box::pin(async move {
67 let mut stream = client
68 .service_object_stream(&req.handler_type, &req.handler_id)
69 .await?;
70
71 let ser_request = bincode::serialize(&req)
72 .map_err(|e| ClientError::SeralizationError(e.to_string()))?;
73
74 stream.send(ser_request.into()).await?;
75 match stream.next().await {
76 Some(Ok(frame)) => {
77 let message: ResponseEnvelope = bincode::deserialize(&frame)
78 .map_err(|e| ClientError::DeseralizationError(e.to_string()))?;
79 Ok(message.body?)
80 }
81 Some(Err(e)) => Err(RequestError::ClientError(ClientError::IoError(
82 e.to_string(),
83 ))),
84 None => Err(RequestError::ClientError(ClientError::Disconnect)),
87 }
88 })
89 }
90}
91
92pub struct RequestRedirect<'a, S, E>
97where
98 S: MembershipStorage,
99 Request<'a, S, E>: Clone,
100{
101 inner: Request<'a, S, E>,
102}
103
104impl<'a, S, E> RequestRedirect<'a, S, E>
105where
106 S: MembershipStorage,
107 Request<'a, S, E>: Clone,
108{
109 pub fn new(inner: Request<'a, S, E>) -> Self {
110 RequestRedirect { inner }
111 }
112}
113
114impl<'a, S, E> TowerService<RequestEnvelope> for RequestRedirect<'a, S, E>
115where
116 E: std::error::Error + DeserializeOwned + Send + Sync + 'a,
117 S: MembershipStorage + 'static,
118 Request<'a, S, E>: Clone,
119{
120 type Response = Vec<u8>;
121 type Error = RequestError<E>;
122 type Future = BoxFuture<'a, Result<Self::Response, Self::Error>>;
123
124 fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
125 self.inner.poll_ready(cx)
126 }
127
128 fn call(&mut self, req: RequestEnvelope) -> Self::Future {
135 let handler_type = req.handler_type.clone();
138 let handler_id = req.handler_id.clone();
139 let request = req.clone();
140 let mut inner_service = self.inner.clone();
141
142 let retry_min_duration = Duration::from_nanos(1_000); let retry_max_duration = Duration::from_secs(2);
145 let max_retries = Some(20);
146 let mut retry_count = 0;
149 let mut retry_duration = retry_min_duration;
150 Box::pin(async move {
151 loop {
152 let response = inner_service.call(request.clone()).await;
154 match response {
155 Err(RequestError::ResponseError(ResponseError::Redirect(to))) => {
159 info!("Redirect to {}", to);
162 inner_service
163 .client
164 .placement
165 .write()
166 .map_err(|_| ClientError::PlacementLock)?
167 .put((handler_type.clone(), handler_id.clone()), to);
168 }
169 Err(e)
173 if matches!(
174 e,
175 RequestError::ResponseError(ResponseError::DeallocateServiceObject,)
176 | RequestError::ClientError(ClientError::Disconnect)
177 | RequestError::ClientError(ClientError::ServerNotAvailable(_))
178 | RequestError::ClientError(ClientError::IoError(_))
179 ) =>
180 {
181 if let Some(max_retries) = max_retries {
183 if retry_count > max_retries {
184 error!("Max retries ({}) reached: {:?}", max_retries, e);
185 return Err(e);
186 }
187 }
188 warn!("{:?}", e);
189
190 debug!("Retry in {:?}", retry_duration);
192 tokio::time::sleep(retry_duration).await;
193 retry_count += 1;
194 retry_duration *= 2;
195 retry_duration =
196 retry_duration.clamp(retry_min_duration, retry_max_duration);
197
198 warn!("Refresh the list of servers");
201 inner_service.client.ts_active_servers_refresh = 0; inner_service
204 .client
205 .placement
206 .write()
207 .map_err(|_| ClientError::PlacementLock)?
208 .pop(&(handler_type.clone(), handler_id.clone()));
209 }
210 Err(e) => {
211 if let RequestError::ResponseError(ResponseError::ApplicationError(_)) = e {
214 error!("Uncaught error ResponseError::ApplicationError(...)");
215 } else {
216 error!("Uncaught error {:#?}", e);
217 }
218 return Err(e);
219 }
220 rest => return rest,
222 }
223 }
224 })
225 }
226}
227
#[cfg(test)]
mod test {
    use async_trait::async_trait;
    use chrono::{DateTime, Utc};
    use lru::LruCache;
    use serde::{Deserialize, Serialize};
    use std::{
        num::NonZero,
        sync::{Arc, RwLock},
    };
    use thiserror::Error;
    use tower::ServiceExt;

    use super::*;
    use crate::{
        cluster::storage::{
            local::LocalStorage, Member, MembershipResult, MembershipStorage, MembershipUnitResult,
        },
        errors::MembershipError,
    };

    /// Application error type used to instantiate `Request` in the tests.
    #[derive(Error, Debug, Serialize, Deserialize, PartialEq)]
    enum NoopError {
        #[error("No-op")]
        Noop,
    }

    /// Membership storage whose `members` call always fails; every other
    /// operation succeeds without doing anything.
    #[derive(Clone, Default, Debug)]
    struct FailMembershipStorage {}

    #[async_trait]
    impl MembershipStorage for FailMembershipStorage {
        async fn push(&self, _: Member) -> MembershipUnitResult {
            Ok(())
        }

        async fn remove(&self, _: &str, _: &str) -> MembershipUnitResult {
            Ok(())
        }

        async fn set_is_active(&self, _: &str, _: &str, _: bool) -> MembershipUnitResult {
            Ok(())
        }

        async fn members(&self) -> MembershipResult<Vec<Member>> {
            Err(MembershipError::Unknown("".to_string()))
        }

        async fn notify_failure(&self, _: &str, _: &str) -> MembershipUnitResult {
            Ok(())
        }

        async fn member_failures(&self, _: &str, _: &str) -> MembershipResult<Vec<DateTime<Utc>>> {
            Ok(vec![])
        }
    }

    /// Builds a `Client` over `membership_storage` with the settings shared by
    /// every test: 1000ms timeout, no known servers, a placement cache of 10.
    fn client_with<S>(membership_storage: S) -> Client<S>
    where
        S: MembershipStorage,
    {
        Client {
            timeout_millis: 1000,
            membership_storage,
            active_servers: Default::default(),
            ts_active_servers_refresh: 0,
            streams: Arc::default(),
            placement: Arc::new(RwLock::new(LruCache::new(NonZero::new(10).unwrap()))),
        }
    }

    /// Client backed by an empty `LocalStorage`.
    fn client() -> Client<LocalStorage> {
        client_with(LocalStorage::default())
    }

    /// With no members registered, readiness succeeds and no server is listed.
    #[tokio::test]
    async fn test_poll_ready_no_active_server() {
        let mut request = Request::<_, NoopError>::new(client());
        request.ready().await.expect("poll_ready");
        assert!(request.client.active_servers.is_empty());
    }

    /// An active member pushed to the storage shows up after readiness.
    #[tokio::test]
    async fn test_poll_ready_with_active_servers() {
        let local_client = client();
        assert!(local_client.active_servers.is_empty());

        let mut member = Member::new(String::from("0.0.0.0"), String::from("1234"));
        member.set_active(true);
        local_client
            .membership_storage
            .push(member)
            .await
            .expect("add member");

        let mut request = Request::<_, NoopError>::new(local_client);
        request.ready().await.expect("poll_ready");
        assert_eq!(request.client.active_servers.len(), 1);
    }

    /// A failing membership storage surfaces as `RendevouzUnavailable` on the
    /// very first poll.
    #[tokio::test]
    async fn test_poll_ready_error() {
        let mut request = Request::<_, NoopError>::new(client_with(FailMembershipStorage {}));

        let waker = futures::task::noop_waker();
        let mut context = std::task::Context::from_waker(&waker);
        let expected = Poll::Ready(Err(RequestError::ClientError(
            ClientError::RendevouzUnavailable,
        )));

        assert_eq!(request.poll_ready(&mut context), expected);
    }
}