1use core::time::Duration;
2
3pub use async_trait::async_trait;
4pub use http_api_client_endpoint::{http, Body, Request, Response};
5use http_api_client_endpoint::{Endpoint, RetryableEndpoint, RetryableEndpointRetry};
6
7#[async_trait]
8pub trait Client {
9 type RespondError: std::error::Error + Send + Sync + 'static;
10
11 async fn respond(&self, request: Request<Body>) -> Result<Response<Body>, Self::RespondError>;
12
13 async fn respond_endpoint<EP>(
14 &self,
15 endpoint: &EP,
16 ) -> Result<
17 EP::ParseResponseOutput,
18 ClientRespondEndpointError<
19 Self::RespondError,
20 EP::RenderRequestError,
21 EP::ParseResponseError,
22 >,
23 >
24 where
25 EP: Endpoint + Send + Sync,
26 {
27 self.respond_endpoint_with_callback(endpoint, |req| req, |_| {})
28 .await
29 }
30
31 async fn respond_endpoint_with_callback<EP, PreRCB, PostRCB>(
32 &self,
33 endpoint: &EP,
34 pre_request_callback: PreRCB,
35 post_request_callback: PostRCB,
36 ) -> Result<
37 EP::ParseResponseOutput,
38 ClientRespondEndpointError<
39 Self::RespondError,
40 EP::RenderRequestError,
41 EP::ParseResponseError,
42 >,
43 >
44 where
45 EP: Endpoint + Send + Sync,
46 PreRCB: FnMut(Request<Body>) -> Request<Body> + Send,
47 PostRCB: FnMut(&Response<Body>) + Send,
48 {
49 self.respond_dyn_endpoint_with_callback(
50 endpoint,
51 pre_request_callback,
52 post_request_callback,
53 )
54 .await
55 }
56
57 async fn respond_dyn_endpoint<RRE, PRO, PRE>(
58 &self,
59 endpoint: &(dyn Endpoint<
60 RenderRequestError = RRE,
61 ParseResponseOutput = PRO,
62 ParseResponseError = PRE,
63 > + Send
64 + Sync),
65 ) -> Result<PRO, ClientRespondEndpointError<Self::RespondError, RRE, PRE>>
66 where
67 RRE: std::error::Error + Send + Sync + 'static,
68 PRE: std::error::Error + Send + Sync + 'static,
69 {
70 self.respond_dyn_endpoint_with_callback(endpoint, |req| req, |_| {})
71 .await
72 }
73
74 async fn respond_dyn_endpoint_with_callback<RRE, PRO, PRE, PreRCB, PostRCB>(
75 &self,
76 endpoint: &(dyn Endpoint<
77 RenderRequestError = RRE,
78 ParseResponseOutput = PRO,
79 ParseResponseError = PRE,
80 > + Send
81 + Sync),
82 mut pre_request_callback: PreRCB,
83 mut post_request_callback: PostRCB,
84 ) -> Result<PRO, ClientRespondEndpointError<Self::RespondError, RRE, PRE>>
85 where
86 RRE: std::error::Error + Send + Sync + 'static,
87 PRE: std::error::Error + Send + Sync + 'static,
88 PreRCB: FnMut(Request<Body>) -> Request<Body> + Send,
89 PostRCB: FnMut(&Response<Body>) + Send,
90 {
91 let request = endpoint
92 .render_request()
93 .map_err(ClientRespondEndpointError::EndpointRenderRequestFailed)?;
94
95 let request = pre_request_callback(request);
96
97 let response = self
98 .respond(request)
99 .await
100 .map_err(ClientRespondEndpointError::RespondFailed)?;
101
102 post_request_callback(&response);
103
104 endpoint
105 .parse_response(response)
106 .map_err(ClientRespondEndpointError::EndpointParseResponseFailed)
107 }
108}
109
110#[async_trait]
111pub trait RetryableClient: Client {
112 async fn sleep(&self, dur: Duration);
113
114 async fn respond_endpoint_until_done<EP>(
115 &self,
116 endpoint: &EP,
117 ) -> Result<
118 EP::ParseResponseOutput,
119 RetryableClientRespondEndpointUntilDoneError<
120 Self::RespondError,
121 EP::RenderRequestError,
122 EP::ParseResponseError,
123 >,
124 >
125 where
126 EP: RetryableEndpoint + Send + Sync,
127 {
128 self.respond_endpoint_until_done_with_callback(endpoint, |req, _| req, |_, _| {})
129 .await
130 }
131
132 async fn respond_endpoint_until_done_with_callback<EP, PreRCB, PostRCB>(
133 &self,
134 endpoint: &EP,
135 mut pre_request_callback: PreRCB,
136 mut post_request_callback: PostRCB,
137 ) -> Result<
138 EP::ParseResponseOutput,
139 RetryableClientRespondEndpointUntilDoneError<
140 Self::RespondError,
141 EP::RenderRequestError,
142 EP::ParseResponseError,
143 >,
144 >
145 where
146 EP: RetryableEndpoint + Send + Sync,
147 PreRCB: FnMut(Request<Body>, Option<&RetryableEndpointRetry<EP::RetryReason>>) -> Request<Body>
148 + Send,
149 PostRCB: FnMut(&Response<Body>, Option<&RetryableEndpointRetry<EP::RetryReason>>) + Send,
150 {
151 let mut retry = None;
152
153 loop {
154 let request = endpoint.render_request(retry.as_ref()).map_err(
155 RetryableClientRespondEndpointUntilDoneError::EndpointRenderRequestFailed,
156 )?;
157
158 let request = pre_request_callback(request, retry.as_ref());
159
160 let response = self
161 .respond(request)
162 .await
163 .map_err(RetryableClientRespondEndpointUntilDoneError::RespondFailed)?;
164
165 post_request_callback(&response, retry.as_ref());
166
167 match endpoint.parse_response(response, retry.as_ref()).map_err(
168 RetryableClientRespondEndpointUntilDoneError::EndpointParseResponseFailed,
169 )? {
170 Ok(output) => return Ok(output),
171 Err(reason) => {
172 let x = retry.get_or_insert(RetryableEndpointRetry::new(0, reason.clone()));
173 x.count += 1;
174 x.reason = reason;
175 }
176 }
177
178 if let Some(retry) = &retry {
180 if retry.count >= endpoint.max_retry_count() {
181 return Err(RetryableClientRespondEndpointUntilDoneError::ReachedMaxRetries);
182 }
183
184 self.sleep(endpoint.next_retry_in(retry)).await;
185 }
186 }
187 }
188}
189
/// Failure of a single render / send / parse round trip.
///
/// * `RE` - error type of the client's send step.
/// * `EPRRE` - error type of the endpoint's request rendering.
/// * `EPPRE` - error type of the endpoint's response parsing.
#[derive(Debug)]
pub enum ClientRespondEndpointError<RE, EPRRE, EPPRE>
where
    RE: std::error::Error + Send + Sync + 'static,
    EPRRE: std::error::Error + Send + Sync + 'static,
    EPPRE: std::error::Error + Send + Sync + 'static,
{
    /// Sending the request failed.
    RespondFailed(RE),
    /// The endpoint could not render its request.
    EndpointRenderRequestFailed(EPRRE),
    /// The endpoint could not parse the response.
    EndpointParseResponseFailed(EPPRE),
}
202impl<RE, EPRRE, EPPRE> core::fmt::Display for ClientRespondEndpointError<RE, EPRRE, EPPRE>
203where
204 RE: std::error::Error + Send + Sync + 'static,
205 EPRRE: std::error::Error + Send + Sync + 'static,
206 EPPRE: std::error::Error + Send + Sync + 'static,
207{
208 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
209 write!(f, "{:?}", self)
210 }
211}
212impl<RE, EPRRE, EPPRE> std::error::Error for ClientRespondEndpointError<RE, EPRRE, EPPRE>
213where
214 RE: std::error::Error + Send + Sync + 'static,
215 EPRRE: std::error::Error + Send + Sync + 'static,
216 EPPRE: std::error::Error + Send + Sync + 'static,
217{
218}
219
/// Failure of a retrying render / send / parse loop.
///
/// * `RE` - error type of the client's send step.
/// * `EPRRE` - error type of the endpoint's request rendering.
/// * `EPPRE` - error type of the endpoint's response parsing.
#[derive(Debug)]
pub enum RetryableClientRespondEndpointUntilDoneError<RE, EPRRE, EPPRE>
where
    RE: std::error::Error + Send + Sync + 'static,
    EPRRE: std::error::Error + Send + Sync + 'static,
    EPPRE: std::error::Error + Send + Sync + 'static,
{
    /// Sending the request failed.
    RespondFailed(RE),
    /// The endpoint could not render its request.
    EndpointRenderRequestFailed(EPRRE),
    /// The endpoint could not parse the response.
    EndpointParseResponseFailed(EPPRE),
    /// The retry limit was reached before the endpoint produced an output.
    ReachedMaxRetries,
}
233impl<RE, EPRRE, EPPRE> core::fmt::Display
234 for RetryableClientRespondEndpointUntilDoneError<RE, EPRRE, EPPRE>
235where
236 RE: std::error::Error + Send + Sync + 'static,
237 EPRRE: std::error::Error + Send + Sync + 'static,
238 EPPRE: std::error::Error + Send + Sync + 'static,
239{
240 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
241 write!(f, "{:?}", self)
242 }
243}
244impl<RE, EPRRE, EPPRE> std::error::Error
245 for RetryableClientRespondEndpointUntilDoneError<RE, EPRRE, EPPRE>
246where
247 RE: std::error::Error + Send + Sync + 'static,
248 EPRRE: std::error::Error + Send + Sync + 'static,
249 EPPRE: std::error::Error + Send + Sync + 'static,
250{
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    use std::{collections::HashMap, io::Error as IoError, panic};

    use futures_executor::block_on;

    /// Type-erased endpoint as stored in the registry of the test below.
    type BoxedEndpoint = Box<
        dyn Endpoint<
                RenderRequestError = IoError,
                ParseResponseOutput = (),
                ParseResponseError = IoError,
            > + Send
            + Sync,
    >;

    /// Endpoint whose request rendering panics with "not implemented".
    #[derive(Clone)]
    struct MyEndpoint;
    impl Endpoint for MyEndpoint {
        type RenderRequestError = IoError;

        type ParseResponseOutput = ();
        type ParseResponseError = IoError;

        fn render_request(&self) -> Result<Request<Body>, Self::RenderRequestError> {
            unimplemented!()
        }

        fn parse_response(
            &self,
            _response: Response<Body>,
        ) -> Result<Self::ParseResponseOutput, Self::ParseResponseError> {
            unreachable!()
        }
    }

    /// Client whose `respond` must never be reached by the test.
    #[derive(Clone)]
    struct MyClient;
    #[async_trait]
    impl Client for MyClient {
        type RespondError = IoError;

        async fn respond(
            &self,
            _request: Request<Body>,
        ) -> Result<Response<Body>, Self::RespondError> {
            unreachable!()
        }
    }

    /// Calls `respond_dyn_endpoint` with an endpoint taken out of a map of
    /// boxed trait objects and expects the panic raised by `render_request`.
    #[test]
    fn test_respond_dyn_endpoint() {
        // Silence the panic hook while the expected panic is provoked, then
        // restore the previous hook before asserting.
        let saved_hook = panic::take_hook();
        panic::set_hook(Box::new(|_| {}));
        let outcome = panic::catch_unwind(|| {
            block_on(async move {
                let mut registry: HashMap<&'static str, BoxedEndpoint> = HashMap::new();

                let name = "x";
                registry.insert(name, Box::new(MyEndpoint));
                let client = MyClient;

                let boxed = registry.get(name).unwrap();
                client.respond_dyn_endpoint(boxed.as_ref()).await
            })
        });
        panic::set_hook(saved_hook);

        match outcome {
            Err(payload) => match payload.downcast_ref::<&str>() {
                Some(message) => assert!(message.contains("not implemented")),
                None => panic!("{:?}", payload),
            },
            unexpected => panic!("{:?}", unexpected),
        }
    }
}