Skip to main content

conjure_runtime/service/
http_error.rs

1// Copyright 2020 Palantir Technologies, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use crate::errors::{RemoteError, ThrottledError, UnavailableError};
15use crate::service::{Layer, Service};
16use crate::{builder, Builder, ServerQos, ServiceError};
17use bytes::{BufMut, BytesMut};
18use conjure_error::{Error, ErrorType, Internal, SerializableError};
19use conjure_http::client::{ConjureRuntime, JsonEncoding};
20use futures::StreamExt;
21use http::header::RETRY_AFTER;
22use http::{Request, Response, StatusCode};
23use http_body::Body;
24use http_body_util::BodyExt;
25use serde::Deserialize;
26use std::error;
27use std::pin::pin;
28use std::sync::Arc;
29use std::time::Duration;
30use witchcraft_log::{debug, info};
31
32/// A layer which maps raw HTTP responses into Conjure `Error`s.
33///
34/// If `server_qos` is `ServerQos::Propagate429And503ToCaller`, 429 and 503 responses will be turned into Conjure
35/// "throttle" and "service unavailable" errors respectively. Otherwise, they run into service errors. In both cases,
36/// the error's cause will be the `ThrottledError` and `UnavailableError` types respectvely. If a `Retry-After` header
37/// is present on a 429 response it will be included in the error.
38///
39/// If `service_error` is `ServiceError::PropagateToCaller`, Conjure errors returned by the server will be propagated,
40/// with the new `Error` inheriting the incoming error's code, name, instance ID, and parameters. Otherwise it will be
41/// treated as a generic internal error. In both cases, the cause will be a `RemoteError`.
42pub struct HttpErrorLayer {
43    server_qos: ServerQos,
44    service_error: ServiceError,
45    conjure_runtime: Arc<ConjureRuntime>,
46}
47
48impl HttpErrorLayer {
49    pub fn new(builder: &Builder<builder::Complete>) -> HttpErrorLayer {
50        HttpErrorLayer {
51            server_qos: builder.get_server_qos(),
52            service_error: builder.get_service_error(),
53            conjure_runtime: builder.get_conjure_runtime().clone(),
54        }
55    }
56}
57
58impl<S> Layer<S> for HttpErrorLayer {
59    type Service = HttpErrorService<S>;
60
61    fn layer(self, inner: S) -> Self::Service {
62        HttpErrorService {
63            inner,
64            server_qos: self.server_qos,
65            service_error: self.service_error,
66            conjure_runtime: self.conjure_runtime,
67        }
68    }
69}
70
71pub struct HttpErrorService<S> {
72    inner: S,
73    server_qos: ServerQos,
74    service_error: ServiceError,
75    conjure_runtime: Arc<ConjureRuntime>,
76}
77
78impl<S, B1, B2> Service<Request<B1>> for HttpErrorService<S>
79where
80    S: Service<Request<B1>, Response = Response<B2>, Error = Error>,
81    B2: Body,
82    B2::Error: Into<Box<dyn error::Error + Sync + Send>>,
83{
84    type Response = Response<B2>;
85    type Error = Error;
86
87    async fn call(&self, req: Request<B1>) -> Result<Self::Response, Self::Error> {
88        let response = self.inner.call(req).await?;
89
90        if response.status().is_success() {
91            return Ok(response);
92        }
93
94        match response.status() {
95            StatusCode::TOO_MANY_REQUESTS => {
96                let retry_after = response
97                    .headers()
98                    .get(RETRY_AFTER)
99                    .and_then(|h| h.to_str().ok())
100                    .and_then(|s| s.parse().ok())
101                    .map(Duration::from_secs);
102                let error = ThrottledError { retry_after };
103
104                let e = match self.server_qos {
105                    ServerQos::AutomaticRetry => Error::internal_safe(error),
106                    ServerQos::Propagate429And503ToCaller => match retry_after {
107                        Some(retry_after) => Error::throttle_for_safe(error, retry_after),
108                        None => Error::throttle_safe(error),
109                    },
110                };
111
112                Err(e)
113            }
114            StatusCode::SERVICE_UNAVAILABLE => {
115                let error = UnavailableError(());
116
117                let e = match self.server_qos {
118                    ServerQos::AutomaticRetry => Error::internal_safe(error),
119                    ServerQos::Propagate429And503ToCaller => Error::unavailable_safe(error),
120                };
121
122                Err(e)
123            }
124            _ => {
125                let (parts, body) = response.into_parts();
126                let mut stream = pin!(body.into_data_stream());
127
128                let mut body = BytesMut::new();
129                while let Some(chunk) = stream.next().await {
130                    match chunk {
131                        Ok(chunk) => {
132                            body.put(chunk);
133                            if body.len() > 500 * 1024 {
134                                break;
135                            }
136                        }
137                        Err(e) => {
138                            info!("error reading response body", error: Error::internal(e));
139                            break;
140                        }
141                    }
142                }
143
144                let error_encoding = self
145                    .conjure_runtime
146                    .response_body_encoding(&parts.headers)
147                    .unwrap_or_else(|e| {
148                        // Conjure servers always use JSON for errors, so we can run into this case if e.g. someone
149                        // configures a Smile-only runtime.
150                        debug!("falling back to json deserializer for error body", error: e);
151                        &JsonEncoding
152                    });
153
154                let error = match SerializableError::deserialize(
155                    error_encoding.deserializer(&body).deserializer(),
156                ) {
157                    Ok(error) => Some(error),
158                    Err(e) => {
159                        debug!("failed to deserialize response body as a Conjure error", error: Error::internal(e));
160                        None
161                    }
162                };
163
164                let error = RemoteError {
165                    status: parts.status,
166                    error,
167                };
168                let log_body = error.error.is_none();
169
170                let mut error = match (&error.error, self.service_error) {
171                    (Some(e), ServiceError::PropagateToCaller) => {
172                        let e = e.clone();
173                        Error::propagated_service_safe(error, e)
174                    }
175                    (Some(e), ServiceError::WrapInNewError) => {
176                        let instance_id = e.error_instance_id();
177                        Error::service_safe(error, Internal::new().with_instance_id(instance_id))
178                    }
179                    (None, _) => Error::internal_safe(error),
180                };
181
182                if log_body {
183                    error = error.with_unsafe_param("body", String::from_utf8_lossy(&body));
184                }
185
186                Err(error)
187            }
188        }
189    }
190}
191
#[cfg(test)]
mod test {
    use super::*;
    use crate::service;
    use bytes::Bytes;
    use conjure_error::{ErrorCode, ErrorKind, SerializableError};
    use conjure_object::Uuid;
    use conjure_serde::{json, smile};
    use http::header::CONTENT_TYPE;
    use http_body_util::{Empty, Full};

    /// The Conjure error returned by the fake servers in the service-error tests.
    fn conflict_error() -> SerializableError {
        SerializableError::builder()
            .error_code(ErrorCode::Conflict)
            .error_name("Default:Conflict")
            .error_instance_id(Uuid::nil())
            .build()
    }

    #[tokio::test]
    async fn success_is_ok() {
        let service =
            HttpErrorLayer::new(&Builder::for_test()).layer(service::service_fn(|_| async move {
                Ok(Response::builder()
                    .status(StatusCode::OK)
                    .body(Empty::<Bytes>::new())
                    .unwrap())
            }));

        let request = Request::new(());
        let out = service.call(request).await.unwrap();
        assert_eq!(out.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn default_throttle_handling() {
        let service =
            HttpErrorLayer::new(&Builder::for_test()).layer(service::service_fn(|_| async move {
                Ok(Response::builder()
                    .status(StatusCode::TOO_MANY_REQUESTS)
                    .header(RETRY_AFTER, "100")
                    .body(Empty::<Bytes>::new())
                    .unwrap())
            }));

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        match error.kind() {
            ErrorKind::Service(_) => {}
            _ => panic!("expected a service error"),
        }
        let cause = error.cause().downcast_ref::<ThrottledError>().unwrap();
        assert_eq!(cause.retry_after, Some(Duration::from_secs(100)));
    }

    #[tokio::test]
    async fn propagated_throttle_handling() {
        let service = HttpErrorLayer::new(
            &Builder::for_test().server_qos(ServerQos::Propagate429And503ToCaller),
        )
        .layer(service::service_fn(|_| async move {
            Ok(Response::builder()
                .status(StatusCode::TOO_MANY_REQUESTS)
                .header(RETRY_AFTER, "100")
                .body(Empty::<Bytes>::new())
                .unwrap())
        }));

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        let throttle = match error.kind() {
            ErrorKind::Throttle(throttle) => throttle,
            _ => panic!("expected a throttle error"),
        };
        assert_eq!(throttle.duration(), Some(Duration::from_secs(100)));
    }

    #[tokio::test]
    async fn default_unavailable_handling() {
        let service =
            HttpErrorLayer::new(&Builder::for_test()).layer(service::service_fn(|_| async move {
                Ok(Response::builder()
                    .status(StatusCode::SERVICE_UNAVAILABLE)
                    .body(Empty::<Bytes>::new())
                    .unwrap())
            }));

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        match error.kind() {
            ErrorKind::Service(_) => {}
            _ => panic!("expected a service error"),
        }
        error.cause().downcast_ref::<UnavailableError>().unwrap();
    }

    #[tokio::test]
    async fn propagated_unavailable_handling() {
        let service = HttpErrorLayer::new(
            &Builder::for_test().server_qos(ServerQos::Propagate429And503ToCaller),
        )
        .layer(service::service_fn(|_| async move {
            Ok(Response::builder()
                .status(StatusCode::SERVICE_UNAVAILABLE)
                .body(Empty::<Bytes>::new())
                .unwrap())
        }));

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        match error.kind() {
            ErrorKind::Unavailable(_) => {}
            _ => panic!("expected an unavailable error"),
        }
    }

    #[tokio::test]
    async fn default_service_handling() {
        let service_error = conflict_error();

        let service = HttpErrorLayer::new(&Builder::for_test()).layer({
            let service_error = service_error.clone();
            service::service_fn(move |_| {
                let json = json::to_vec(&service_error).unwrap();
                async move {
                    Ok(Response::builder()
                        .status(StatusCode::CONFLICT)
                        .header(CONTENT_TYPE, "application/json")
                        .body(Full::new(Bytes::from(json)))
                        .unwrap())
                }
            })
        });

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        let service = match error.kind() {
            ErrorKind::Service(service) => service,
            _ => panic!("expected a service error"),
        };
        assert_eq!(*service.error_code(), ErrorCode::Internal);
        assert_eq!(
            service.error_instance_id(),
            service_error.error_instance_id()
        );

        let remote_error = error.cause().downcast_ref::<RemoteError>().unwrap();
        assert_eq!(remote_error.error(), Some(&service_error));
    }

    #[tokio::test]
    async fn propagated_service_handling() {
        let service_error = conflict_error();

        let service = HttpErrorLayer::new(
            &Builder::for_test().service_error(ServiceError::PropagateToCaller),
        )
        .layer({
            let service_error = service_error.clone();
            service::service_fn(move |_| {
                let json = json::to_vec(&service_error).unwrap();
                async move {
                    Ok(Response::builder()
                        .status(StatusCode::CONFLICT)
                        .header(CONTENT_TYPE, "application/json")
                        .body(Full::new(Bytes::from(json)))
                        .unwrap())
                }
            })
        });

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        let service = match error.kind() {
            ErrorKind::Service(service) => service,
            _ => panic!("expected a service error"),
        };
        assert_eq!(service_error, *service);

        let remote_error = error.cause().downcast_ref::<RemoteError>().unwrap();
        assert_eq!(remote_error.error(), Some(&service_error));
    }

    #[tokio::test]
    async fn smile_service_handling() {
        let service_error = conflict_error();

        let service = HttpErrorLayer::new(&Builder::for_test()).layer({
            let service_error = service_error.clone();
            service::service_fn(move |_| {
                let smile = smile::to_vec(&service_error).unwrap();
                async move {
                    Ok(Response::builder()
                        .status(StatusCode::CONFLICT)
                        .header(CONTENT_TYPE, "application/x-jackson-smile")
                        .body(Full::new(Bytes::from(smile)))
                        .unwrap())
                }
            })
        });

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        let service = match error.kind() {
            ErrorKind::Service(service) => service,
            _ => panic!("expected a service error"),
        };
        assert_eq!(*service.error_code(), ErrorCode::Internal);
        assert_eq!(
            service.error_instance_id(),
            service_error.error_instance_id()
        );

        let remote_error = error.cause().downcast_ref::<RemoteError>().unwrap();
        assert_eq!(remote_error.error(), Some(&service_error));
    }

    #[tokio::test]
    async fn non_conjure_body_handling() {
        let service =
            HttpErrorLayer::new(&Builder::for_test()).layer(service::service_fn(|_| async move {
                Ok(Response::builder()
                    .status(StatusCode::INTERNAL_SERVER_ERROR)
                    .header(CONTENT_TYPE, "text/plain")
                    .body(Full::new(Bytes::from_static(b"not a conjure error")))
                    .unwrap())
            }));

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        let service = match error.kind() {
            ErrorKind::Service(service) => service,
            _ => panic!("expected a service error"),
        };
        assert_eq!(*service.error_code(), ErrorCode::Internal);

        let remote_error = error.cause().downcast_ref::<RemoteError>().unwrap();
        assert!(remote_error.error().is_none());
    }
}