1use crate::errors::{RemoteError, ThrottledError, UnavailableError};
15use crate::service::{Layer, Service};
16use crate::{builder, Builder, ServerQos, ServiceError};
17use bytes::{BufMut, BytesMut};
18use conjure_error::{Error, ErrorType, Internal, SerializableError};
19use conjure_http::client::{ConjureRuntime, JsonEncoding};
20use futures::StreamExt;
21use http::header::RETRY_AFTER;
22use http::{Request, Response, StatusCode};
23use http_body::Body;
24use http_body_util::BodyExt;
25use serde::Deserialize;
26use std::error;
27use std::pin::pin;
28use std::sync::Arc;
29use std::time::Duration;
30use witchcraft_log::{debug, info};
31
32pub struct HttpErrorLayer {
43 server_qos: ServerQos,
44 service_error: ServiceError,
45 conjure_runtime: Arc<ConjureRuntime>,
46}
47
48impl HttpErrorLayer {
49 pub fn new(builder: &Builder<builder::Complete>) -> HttpErrorLayer {
50 HttpErrorLayer {
51 server_qos: builder.get_server_qos(),
52 service_error: builder.get_service_error(),
53 conjure_runtime: builder.get_conjure_runtime().clone(),
54 }
55 }
56}
57
58impl<S> Layer<S> for HttpErrorLayer {
59 type Service = HttpErrorService<S>;
60
61 fn layer(self, inner: S) -> Self::Service {
62 HttpErrorService {
63 inner,
64 server_qos: self.server_qos,
65 service_error: self.service_error,
66 conjure_runtime: self.conjure_runtime,
67 }
68 }
69}
70
71pub struct HttpErrorService<S> {
72 inner: S,
73 server_qos: ServerQos,
74 service_error: ServiceError,
75 conjure_runtime: Arc<ConjureRuntime>,
76}
77
78impl<S, B1, B2> Service<Request<B1>> for HttpErrorService<S>
79where
80 S: Service<Request<B1>, Response = Response<B2>, Error = Error>,
81 B2: Body,
82 B2::Error: Into<Box<dyn error::Error + Sync + Send>>,
83{
84 type Response = Response<B2>;
85 type Error = Error;
86
87 async fn call(&self, req: Request<B1>) -> Result<Self::Response, Self::Error> {
88 let response = self.inner.call(req).await?;
89
90 if response.status().is_success() {
91 return Ok(response);
92 }
93
94 match response.status() {
95 StatusCode::TOO_MANY_REQUESTS => {
96 let retry_after = response
97 .headers()
98 .get(RETRY_AFTER)
99 .and_then(|h| h.to_str().ok())
100 .and_then(|s| s.parse().ok())
101 .map(Duration::from_secs);
102 let error = ThrottledError { retry_after };
103
104 let e = match self.server_qos {
105 ServerQos::AutomaticRetry => Error::internal_safe(error),
106 ServerQos::Propagate429And503ToCaller => match retry_after {
107 Some(retry_after) => Error::throttle_for_safe(error, retry_after),
108 None => Error::throttle_safe(error),
109 },
110 };
111
112 Err(e)
113 }
114 StatusCode::SERVICE_UNAVAILABLE => {
115 let error = UnavailableError(());
116
117 let e = match self.server_qos {
118 ServerQos::AutomaticRetry => Error::internal_safe(error),
119 ServerQos::Propagate429And503ToCaller => Error::unavailable_safe(error),
120 };
121
122 Err(e)
123 }
124 _ => {
125 let (parts, body) = response.into_parts();
126 let mut stream = pin!(body.into_data_stream());
127
128 let mut body = BytesMut::new();
129 while let Some(chunk) = stream.next().await {
130 match chunk {
131 Ok(chunk) => {
132 body.put(chunk);
133 if body.len() > 500 * 1024 {
134 break;
135 }
136 }
137 Err(e) => {
138 info!("error reading response body", error: Error::internal(e));
139 break;
140 }
141 }
142 }
143
144 let error_encoding = self
145 .conjure_runtime
146 .response_body_encoding(&parts.headers)
147 .unwrap_or_else(|e| {
148 debug!("falling back to json deserializer for error body", error: e);
151 &JsonEncoding
152 });
153
154 let error = match SerializableError::deserialize(
155 error_encoding.deserializer(&body).deserializer(),
156 ) {
157 Ok(error) => Some(error),
158 Err(e) => {
159 debug!("failed to deserialize response body as a Conjure error", error: Error::internal(e));
160 None
161 }
162 };
163
164 let error = RemoteError {
165 status: parts.status,
166 error,
167 };
168 let log_body = error.error.is_none();
169
170 let mut error = match (&error.error, self.service_error) {
171 (Some(e), ServiceError::PropagateToCaller) => {
172 let e = e.clone();
173 Error::propagated_service_safe(error, e)
174 }
175 (Some(e), ServiceError::WrapInNewError) => {
176 let instance_id = e.error_instance_id();
177 Error::service_safe(error, Internal::new().with_instance_id(instance_id))
178 }
179 (None, _) => Error::internal_safe(error),
180 };
181
182 if log_body {
183 error = error.with_unsafe_param("body", String::from_utf8_lossy(&body));
184 }
185
186 Err(error)
187 }
188 }
189 }
190}
191
#[cfg(test)]
mod test {
    use super::*;
    use crate::service;
    use bytes::Bytes;
    use conjure_error::{ErrorCode, ErrorKind, SerializableError};
    use conjure_object::Uuid;
    use conjure_serde::{json, smile};
    use http::header::CONTENT_TYPE;
    use http_body_util::{Empty, Full};

    #[tokio::test]
    async fn success_is_ok() {
        let service =
            HttpErrorLayer::new(&Builder::for_test()).layer(service::service_fn(|_| async move {
                Ok(Response::builder()
                    .status(StatusCode::OK)
                    .body(Empty::<Bytes>::new())
                    .unwrap())
            }));

        let request = Request::new(());
        let out = service.call(request).await.unwrap();
        assert_eq!(out.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn default_throttle_handling() {
        let service =
            HttpErrorLayer::new(&Builder::for_test()).layer(service::service_fn(|_| async move {
                Ok(Response::builder()
                    .status(StatusCode::TOO_MANY_REQUESTS)
                    .header(RETRY_AFTER, "100")
                    .body(Empty::<Bytes>::new())
                    .unwrap())
            }));

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        match error.kind() {
            ErrorKind::Service(_) => {}
            _ => panic!("expected a service error"),
        }
        let cause = error.cause().downcast_ref::<ThrottledError>().unwrap();
        assert_eq!(cause.retry_after, Some(Duration::from_secs(100)));
    }

    #[tokio::test]
    async fn propagated_throttle_handling() {
        let service = HttpErrorLayer::new(
            &Builder::for_test().server_qos(ServerQos::Propagate429And503ToCaller),
        )
        .layer(service::service_fn(|_| async move {
            Ok(Response::builder()
                .status(StatusCode::TOO_MANY_REQUESTS)
                .header(RETRY_AFTER, "100")
                .body(Empty::<Bytes>::new())
                .unwrap())
        }));

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        let throttle = match error.kind() {
            ErrorKind::Throttle(throttle) => throttle,
            _ => panic!("expected a throttle error"),
        };
        assert_eq!(throttle.duration(), Some(Duration::from_secs(100)));
    }

    #[tokio::test]
    async fn default_unavailable_handling() {
        let service =
            HttpErrorLayer::new(&Builder::for_test()).layer(service::service_fn(|_| async move {
                Ok(Response::builder()
                    .status(StatusCode::SERVICE_UNAVAILABLE)
                    .body(Empty::<Bytes>::new())
                    .unwrap())
            }));

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        match error.kind() {
            ErrorKind::Service(_) => {}
            _ => panic!("expected a service error"),
        }
        error.cause().downcast_ref::<UnavailableError>().unwrap();
    }

    #[tokio::test]
    async fn propagated_unavailable_handling() {
        let service = HttpErrorLayer::new(
            &Builder::for_test().server_qos(ServerQos::Propagate429And503ToCaller),
        )
        .layer(service::service_fn(|_| async move {
            Ok(Response::builder()
                .status(StatusCode::SERVICE_UNAVAILABLE)
                .body(Empty::<Bytes>::new())
                .unwrap())
        }));

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        match error.kind() {
            ErrorKind::Unavailable(_) => {}
            _ => panic!("expected an unavailable error"),
        }
    }

    #[tokio::test]
    async fn default_service_handling() {
        let service_error = SerializableError::builder()
            .error_code(ErrorCode::Conflict)
            .error_name("Default:Conflict")
            .error_instance_id(Uuid::nil())
            .build();

        let service = HttpErrorLayer::new(&Builder::for_test()).layer({
            let service_error = service_error.clone();
            service::service_fn(move |_| {
                let json = json::to_vec(&service_error).unwrap();
                async move {
                    Ok(Response::builder()
                        .status(StatusCode::CONFLICT)
                        .header(CONTENT_TYPE, "application/json")
                        .body(Full::new(Bytes::from(json)))
                        .unwrap())
                }
            })
        });

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        let service = match error.kind() {
            ErrorKind::Service(service) => service,
            _ => panic!("expected a service error"),
        };
        assert_eq!(*service.error_code(), ErrorCode::Internal);
        assert_eq!(
            service.error_instance_id(),
            service_error.error_instance_id()
        );

        let remote_error = error.cause().downcast_ref::<RemoteError>().unwrap();
        assert_eq!(remote_error.error(), Some(&service_error));
    }

    #[tokio::test]
    async fn propagated_service_handling() {
        let service_error = SerializableError::builder()
            .error_code(ErrorCode::Conflict)
            .error_name("Default:Conflict")
            .error_instance_id(Uuid::nil())
            .build();

        let service = HttpErrorLayer::new(
            &Builder::for_test().service_error(ServiceError::PropagateToCaller),
        )
        .layer({
            let service_error = service_error.clone();
            service::service_fn(move |_| {
                let json = json::to_vec(&service_error).unwrap();
                async move {
                    Ok(Response::builder()
                        .status(StatusCode::CONFLICT)
                        .header(CONTENT_TYPE, "application/json")
                        .body(Full::new(Bytes::from(json)))
                        .unwrap())
                }
            })
        });

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        let service = match error.kind() {
            ErrorKind::Service(service) => service,
            _ => panic!("expected a service error"),
        };
        assert_eq!(service_error, *service);

        let remote_error = error.cause().downcast_ref::<RemoteError>().unwrap();
        assert_eq!(remote_error.error(), Some(&service_error));
    }

    #[tokio::test]
    async fn smile_service_handling() {
        let service_error = SerializableError::builder()
            .error_code(ErrorCode::Conflict)
            .error_name("Default:Conflict")
            .error_instance_id(Uuid::nil())
            .build();

        let service = HttpErrorLayer::new(&Builder::for_test()).layer({
            let service_error = service_error.clone();
            service::service_fn(move |_| {
                let smile = smile::to_vec(&service_error).unwrap();
                async move {
                    Ok(Response::builder()
                        .status(StatusCode::CONFLICT)
                        .header(CONTENT_TYPE, "application/x-jackson-smile")
                        .body(Full::new(Bytes::from(smile)))
                        .unwrap())
                }
            })
        });

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        let service = match error.kind() {
            ErrorKind::Service(service) => service,
            _ => panic!("expected a service error"),
        };
        assert_eq!(*service.error_code(), ErrorCode::Internal);
        assert_eq!(
            service.error_instance_id(),
            service_error.error_instance_id()
        );

        let remote_error = error.cause().downcast_ref::<RemoteError>().unwrap();
        assert_eq!(remote_error.error(), Some(&service_error));
    }

    #[tokio::test]
    async fn non_conjure_error_body_handling() {
        let service =
            HttpErrorLayer::new(&Builder::for_test()).layer(service::service_fn(|_| async move {
                Ok(Response::builder()
                    .status(StatusCode::INTERNAL_SERVER_ERROR)
                    .body(Full::new(Bytes::from_static(b"<html>not a conjure error</html>")))
                    .unwrap())
            }));

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        match error.kind() {
            ErrorKind::Service(_) => {}
            _ => panic!("expected a service error"),
        }

        // A body that is not a Conjure error still yields a RemoteError, just without a parsed
        // SerializableError.
        let remote_error = error.cause().downcast_ref::<RemoteError>().unwrap();
        assert_eq!(remote_error.error(), None);
    }
}