conjure_runtime/service/
http_error.rs

1// Copyright 2020 Palantir Technologies, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use crate::errors::{RemoteError, ThrottledError, UnavailableError};
15use crate::service::{Layer, Service};
16use crate::{builder, Builder, ServerQos, ServiceError};
17use bytes::{BufMut, BytesMut};
18use conjure_error::{Error, ErrorType, Internal};
19use conjure_serde::json;
20use futures::StreamExt;
21use http::header::RETRY_AFTER;
22use http::{Request, Response, StatusCode};
23use http_body::Body;
24use http_body_util::BodyExt;
25use std::error;
26use std::pin::pin;
27use std::time::Duration;
28use witchcraft_log::info;
29
/// A layer which maps raw HTTP responses into Conjure `Error`s.
///
/// Successful (2xx) responses are passed through unchanged; every other status is turned into an `Error`.
///
/// If `server_qos` is `ServerQos::Propagate429And503ToCaller`, 429 and 503 responses will be turned into Conjure
/// "throttle" and "service unavailable" errors respectively. Otherwise, they are turned into internal service
/// errors. In both cases, the error's cause will be the `ThrottledError` and `UnavailableError` types respectively.
/// If a `Retry-After` header holding a whole number of seconds is present on a 429 response, it will be included in
/// the error; other `Retry-After` formats are ignored.
///
/// If `service_error` is `ServiceError::PropagateToCaller`, Conjure errors returned by the server will be propagated,
/// with the new `Error` inheriting the incoming error's code, name, instance ID, and parameters. Otherwise it will be
/// treated as a generic internal error that keeps the remote error's instance ID. In both cases, the cause will be a
/// `RemoteError`. If the response body cannot be parsed as a Conjure error, an internal error is produced and the
/// (size-limited) body is attached to it as the unsafe `body` parameter.
pub struct HttpErrorLayer {
    /// How 429 and 503 responses are reported to the caller.
    server_qos: ServerQos,
    /// Whether Conjure service errors from the server are propagated or wrapped.
    service_error: ServiceError,
}
44
45impl HttpErrorLayer {
46    pub fn new(builder: &Builder<builder::Complete>) -> HttpErrorLayer {
47        HttpErrorLayer {
48            server_qos: builder.get_server_qos(),
49            service_error: builder.get_service_error(),
50        }
51    }
52}
53
54impl<S> Layer<S> for HttpErrorLayer {
55    type Service = HttpErrorService<S>;
56
57    fn layer(self, inner: S) -> Self::Service {
58        HttpErrorService {
59            inner,
60            server_qos: self.server_qos,
61            service_error: self.service_error,
62        }
63    }
64}
65
/// The service produced by `HttpErrorLayer`.
///
/// It forwards requests to `inner` and converts non-success responses into Conjure `Error`s according to
/// `server_qos` and `service_error`.
pub struct HttpErrorService<S> {
    /// The wrapped service that actually performs the request.
    inner: S,
    /// How 429 and 503 responses are reported to the caller.
    server_qos: ServerQos,
    /// Whether Conjure service errors from the server are propagated or wrapped.
    service_error: ServiceError,
}
71
72impl<S, B1, B2> Service<Request<B1>> for HttpErrorService<S>
73where
74    S: Service<Request<B1>, Response = Response<B2>, Error = Error>,
75    B2: Body,
76    B2::Error: Into<Box<dyn error::Error + Sync + Send>>,
77{
78    type Response = Response<B2>;
79    type Error = Error;
80
81    async fn call(&self, req: Request<B1>) -> Result<Self::Response, Self::Error> {
82        let response = self.inner.call(req).await?;
83
84        if response.status().is_success() {
85            return Ok(response);
86        }
87
88        match response.status() {
89            StatusCode::TOO_MANY_REQUESTS => {
90                let retry_after = response
91                    .headers()
92                    .get(RETRY_AFTER)
93                    .and_then(|h| h.to_str().ok())
94                    .and_then(|s| s.parse().ok())
95                    .map(Duration::from_secs);
96                let error = ThrottledError { retry_after };
97
98                let e = match self.server_qos {
99                    ServerQos::AutomaticRetry => Error::internal_safe(error),
100                    ServerQos::Propagate429And503ToCaller => match retry_after {
101                        Some(retry_after) => Error::throttle_for_safe(error, retry_after),
102                        None => Error::throttle_safe(error),
103                    },
104                };
105
106                Err(e)
107            }
108            StatusCode::SERVICE_UNAVAILABLE => {
109                let error = UnavailableError(());
110
111                let e = match self.server_qos {
112                    ServerQos::AutomaticRetry => Error::internal_safe(error),
113                    ServerQos::Propagate429And503ToCaller => Error::unavailable_safe(error),
114                };
115
116                Err(e)
117            }
118            _ => {
119                let (parts, body) = response.into_parts();
120                let mut stream = pin!(body.into_data_stream());
121
122                let mut body = BytesMut::new();
123                while let Some(chunk) = stream.next().await {
124                    match chunk {
125                        Ok(chunk) => {
126                            body.put(chunk);
127                            if body.len() > 500 * 1024 {
128                                break;
129                            }
130                        }
131                        Err(e) => {
132                            info!("error reading response body", error: Error::internal(e));
133                            break;
134                        }
135                    }
136                }
137
138                let error = RemoteError {
139                    status: parts.status,
140                    error: json::client_from_slice(&body).ok(),
141                };
142                let log_body = error.error.is_none();
143
144                let mut error = match (&error.error, self.service_error) {
145                    (Some(e), ServiceError::PropagateToCaller) => {
146                        let e = e.clone();
147                        Error::propagated_service_safe(error, e)
148                    }
149                    (Some(e), ServiceError::WrapInNewError) => {
150                        let instance_id = e.error_instance_id();
151                        Error::service_safe(error, Internal::new().with_instance_id(instance_id))
152                    }
153                    (None, _) => Error::internal_safe(error),
154                };
155
156                if log_body {
157                    error = error.with_unsafe_param("body", String::from_utf8_lossy(&body));
158                }
159
160                Err(error)
161            }
162        }
163    }
164}
165
#[cfg(test)]
mod test {
    use super::*;
    use crate::service;
    use bytes::Bytes;
    use conjure_error::{ErrorCode, ErrorKind, SerializableError};
    use conjure_object::Uuid;
    use http::header::CONTENT_TYPE;
    use http_body_util::{Empty, Full};

    #[tokio::test]
    async fn success_is_ok() {
        let service =
            HttpErrorLayer::new(&Builder::for_test()).layer(service::service_fn(|_| async move {
                Ok(Response::builder()
                    .status(StatusCode::OK)
                    .body(Empty::<Bytes>::new())
                    .unwrap())
            }));

        let request = Request::new(());
        let out = service.call(request).await.unwrap();
        assert_eq!(out.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn default_throttle_handling() {
        let service =
            HttpErrorLayer::new(&Builder::for_test()).layer(service::service_fn(|_| async move {
                Ok(Response::builder()
                    .status(StatusCode::TOO_MANY_REQUESTS)
                    .header(RETRY_AFTER, "100")
                    .body(Empty::<Bytes>::new())
                    .unwrap())
            }));

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        match error.kind() {
            ErrorKind::Service(_) => {}
            _ => panic!("expected a service error"),
        }
        let cause = error.cause().downcast_ref::<ThrottledError>().unwrap();
        assert_eq!(cause.retry_after, Some(Duration::from_secs(100)));
    }

    #[tokio::test]
    async fn propagated_throttle_handling() {
        let service = HttpErrorLayer::new(
            &Builder::for_test().server_qos(ServerQos::Propagate429And503ToCaller),
        )
        .layer(service::service_fn(|_| async move {
            Ok(Response::builder()
                .status(StatusCode::TOO_MANY_REQUESTS)
                .header(RETRY_AFTER, "100")
                .body(Empty::<Bytes>::new())
                .unwrap())
        }));

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        let throttle = match error.kind() {
            ErrorKind::Throttle(throttle) => throttle,
            _ => panic!("expected a throttle error"),
        };
        assert_eq!(throttle.duration(), Some(Duration::from_secs(100)));
    }

    #[tokio::test]
    async fn default_unavailable_handling() {
        let service =
            HttpErrorLayer::new(&Builder::for_test()).layer(service::service_fn(|_| async move {
                Ok(Response::builder()
                    .status(StatusCode::SERVICE_UNAVAILABLE)
                    .body(Empty::<Bytes>::new())
                    .unwrap())
            }));

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        match error.kind() {
            ErrorKind::Service(_) => {}
            _ => panic!("expected a service error"),
        }
        error.cause().downcast_ref::<UnavailableError>().unwrap();
    }

    #[tokio::test]
    async fn propagated_unavailable_handling() {
        let service = HttpErrorLayer::new(
            &Builder::for_test().server_qos(ServerQos::Propagate429And503ToCaller),
        )
        .layer(service::service_fn(|_| async move {
            Ok(Response::builder()
                .status(StatusCode::SERVICE_UNAVAILABLE)
                .body(Empty::<Bytes>::new())
                .unwrap())
        }));

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        match error.kind() {
            ErrorKind::Unavailable(_) => {}
            _ => panic!("expected an unavailable error"),
        }
    }

    #[tokio::test]
    async fn default_service_handling() {
        let service_error = SerializableError::builder()
            .error_code(ErrorCode::Conflict)
            .error_name("Default:Conflict")
            .error_instance_id(Uuid::nil())
            .build();

        let service = HttpErrorLayer::new(&Builder::for_test()).layer({
            let service_error = service_error.clone();
            service::service_fn(move |_| {
                let json = json::to_vec(&service_error).unwrap();
                async move {
                    Ok(Response::builder()
                        .status(StatusCode::CONFLICT)
                        .header(CONTENT_TYPE, "application/json")
                        .body(Full::new(Bytes::from(json)))
                        .unwrap())
                }
            })
        });

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        let service = match error.kind() {
            ErrorKind::Service(service) => service,
            _ => panic!("expected a service error"),
        };
        assert_eq!(*service.error_code(), ErrorCode::Internal);
        assert_eq!(
            service.error_instance_id(),
            service_error.error_instance_id()
        );

        let remote_error = error.cause().downcast_ref::<RemoteError>().unwrap();
        assert_eq!(remote_error.error(), Some(&service_error));
    }

    #[tokio::test]
    async fn propagated_service_handling() {
        let service_error = SerializableError::builder()
            .error_code(ErrorCode::Conflict)
            .error_name("Default:Conflict")
            .error_instance_id(Uuid::nil())
            .build();

        let service = HttpErrorLayer::new(
            &Builder::for_test().service_error(ServiceError::PropagateToCaller),
        )
        .layer({
            let service_error = service_error.clone();
            service::service_fn(move |_| {
                let json = json::to_vec(&service_error).unwrap();
                async move {
                    Ok(Response::builder()
                        .status(StatusCode::CONFLICT)
                        .header(CONTENT_TYPE, "application/json")
                        .body(Full::new(Bytes::from(json)))
                        .unwrap())
                }
            })
        });

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        let service = match error.kind() {
            ErrorKind::Service(service) => service,
            _ => panic!("expected a service error"),
        };
        assert_eq!(service_error, *service);

        let remote_error = error.cause().downcast_ref::<RemoteError>().unwrap();
        assert_eq!(remote_error.error(), Some(&service_error));
    }

    /// A body that is not a Conjure error must still yield an internal service error whose cause records the
    /// status and the absence of a parsed remote error.
    #[tokio::test]
    async fn unparseable_body_handling() {
        let service =
            HttpErrorLayer::new(&Builder::for_test()).layer(service::service_fn(|_| async move {
                Ok(Response::builder()
                    .status(StatusCode::INTERNAL_SERVER_ERROR)
                    .body(Full::new(Bytes::from_static(b"not json")))
                    .unwrap())
            }));

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        let service = match error.kind() {
            ErrorKind::Service(service) => service,
            _ => panic!("expected a service error"),
        };
        assert_eq!(*service.error_code(), ErrorCode::Internal);

        let remote_error = error.cause().downcast_ref::<RemoteError>().unwrap();
        assert_eq!(remote_error.status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(remote_error.error(), None);
    }
}