conjure_runtime/service/
http_error.rs1use crate::errors::{RemoteError, ThrottledError, UnavailableError};
15use crate::service::{Layer, Service};
16use crate::{builder, Builder, ServerQos, ServiceError};
17use bytes::{BufMut, BytesMut};
18use conjure_error::{Error, ErrorType, Internal};
19use conjure_serde::json;
20use futures::StreamExt;
21use http::header::RETRY_AFTER;
22use http::{Request, Response, StatusCode};
23use http_body::Body;
24use http_body_util::BodyExt;
25use std::error;
26use std::pin::pin;
27use std::time::Duration;
28use witchcraft_log::info;
29
30pub struct HttpErrorLayer {
41 server_qos: ServerQos,
42 service_error: ServiceError,
43}
44
45impl HttpErrorLayer {
46 pub fn new(builder: &Builder<builder::Complete>) -> HttpErrorLayer {
47 HttpErrorLayer {
48 server_qos: builder.get_server_qos(),
49 service_error: builder.get_service_error(),
50 }
51 }
52}
53
54impl<S> Layer<S> for HttpErrorLayer {
55 type Service = HttpErrorService<S>;
56
57 fn layer(self, inner: S) -> Self::Service {
58 HttpErrorService {
59 inner,
60 server_qos: self.server_qos,
61 service_error: self.service_error,
62 }
63 }
64}
65
66pub struct HttpErrorService<S> {
67 inner: S,
68 server_qos: ServerQos,
69 service_error: ServiceError,
70}
71
72impl<S, B1, B2> Service<Request<B1>> for HttpErrorService<S>
73where
74 S: Service<Request<B1>, Response = Response<B2>, Error = Error>,
75 B2: Body,
76 B2::Error: Into<Box<dyn error::Error + Sync + Send>>,
77{
78 type Response = Response<B2>;
79 type Error = Error;
80
81 async fn call(&self, req: Request<B1>) -> Result<Self::Response, Self::Error> {
82 let response = self.inner.call(req).await?;
83
84 if response.status().is_success() {
85 return Ok(response);
86 }
87
88 match response.status() {
89 StatusCode::TOO_MANY_REQUESTS => {
90 let retry_after = response
91 .headers()
92 .get(RETRY_AFTER)
93 .and_then(|h| h.to_str().ok())
94 .and_then(|s| s.parse().ok())
95 .map(Duration::from_secs);
96 let error = ThrottledError { retry_after };
97
98 let e = match self.server_qos {
99 ServerQos::AutomaticRetry => Error::internal_safe(error),
100 ServerQos::Propagate429And503ToCaller => match retry_after {
101 Some(retry_after) => Error::throttle_for_safe(error, retry_after),
102 None => Error::throttle_safe(error),
103 },
104 };
105
106 Err(e)
107 }
108 StatusCode::SERVICE_UNAVAILABLE => {
109 let error = UnavailableError(());
110
111 let e = match self.server_qos {
112 ServerQos::AutomaticRetry => Error::internal_safe(error),
113 ServerQos::Propagate429And503ToCaller => Error::unavailable_safe(error),
114 };
115
116 Err(e)
117 }
118 _ => {
119 let (parts, body) = response.into_parts();
120 let mut stream = pin!(body.into_data_stream());
121
122 let mut body = BytesMut::new();
123 while let Some(chunk) = stream.next().await {
124 match chunk {
125 Ok(chunk) => {
126 body.put(chunk);
127 if body.len() > 500 * 1024 {
128 break;
129 }
130 }
131 Err(e) => {
132 info!("error reading response body", error: Error::internal(e));
133 break;
134 }
135 }
136 }
137
138 let error = RemoteError {
139 status: parts.status,
140 error: json::client_from_slice(&body).ok(),
141 };
142 let log_body = error.error.is_none();
143
144 let mut error = match (&error.error, self.service_error) {
145 (Some(e), ServiceError::PropagateToCaller) => {
146 let e = e.clone();
147 Error::propagated_service_safe(error, e)
148 }
149 (Some(e), ServiceError::WrapInNewError) => {
150 let instance_id = e.error_instance_id();
151 Error::service_safe(error, Internal::new().with_instance_id(instance_id))
152 }
153 (None, _) => Error::internal_safe(error),
154 };
155
156 if log_body {
157 error = error.with_unsafe_param("body", String::from_utf8_lossy(&body));
158 }
159
160 Err(error)
161 }
162 }
163 }
164}
165
#[cfg(test)]
mod test {
    use super::*;
    use crate::service;
    use bytes::Bytes;
    use conjure_error::{ErrorCode, ErrorKind, SerializableError};
    use conjure_object::Uuid;
    use http::header::CONTENT_TYPE;
    use http_body_util::{Empty, Full};

    #[tokio::test]
    async fn success_is_ok() {
        let service =
            HttpErrorLayer::new(&Builder::for_test()).layer(service::service_fn(|_| async move {
                Ok(Response::builder()
                    .status(StatusCode::OK)
                    .body(Empty::<Bytes>::new())
                    .unwrap())
            }));

        let request = Request::new(());
        let out = service.call(request).await.unwrap();
        assert_eq!(out.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn default_throttle_handling() {
        let service =
            HttpErrorLayer::new(&Builder::for_test()).layer(service::service_fn(|_| async move {
                Ok(Response::builder()
                    .status(StatusCode::TOO_MANY_REQUESTS)
                    .header(RETRY_AFTER, "100")
                    .body(Empty::<Bytes>::new())
                    .unwrap())
            }));

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        match error.kind() {
            ErrorKind::Service(_) => {}
            _ => panic!("expected a service error"),
        }
        let cause = error.cause().downcast_ref::<ThrottledError>().unwrap();
        assert_eq!(cause.retry_after, Some(Duration::from_secs(100)));
    }

    #[tokio::test]
    async fn propagated_throttle_handling() {
        let service = HttpErrorLayer::new(
            &Builder::for_test().server_qos(ServerQos::Propagate429And503ToCaller),
        )
        .layer(service::service_fn(|_| async move {
            Ok(Response::builder()
                .status(StatusCode::TOO_MANY_REQUESTS)
                .header(RETRY_AFTER, "100")
                .body(Empty::<Bytes>::new())
                .unwrap())
        }));

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        let throttle = match error.kind() {
            ErrorKind::Throttle(throttle) => throttle,
            _ => panic!("expected a throttle error"),
        };
        assert_eq!(throttle.duration(), Some(Duration::from_secs(100)));
    }

    #[tokio::test]
    async fn default_unavailable_handling() {
        let service =
            HttpErrorLayer::new(&Builder::for_test()).layer(service::service_fn(|_| async move {
                Ok(Response::builder()
                    .status(StatusCode::SERVICE_UNAVAILABLE)
                    .body(Empty::<Bytes>::new())
                    .unwrap())
            }));

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        match error.kind() {
            ErrorKind::Service(_) => {}
            _ => panic!("expected a service error"),
        }
        error.cause().downcast_ref::<UnavailableError>().unwrap();
    }

    #[tokio::test]
    async fn propagated_unavailable_handling() {
        let service = HttpErrorLayer::new(
            &Builder::for_test().server_qos(ServerQos::Propagate429And503ToCaller),
        )
        .layer(service::service_fn(|_| async move {
            Ok(Response::builder()
                .status(StatusCode::SERVICE_UNAVAILABLE)
                .body(Empty::<Bytes>::new())
                .unwrap())
        }));

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        match error.kind() {
            ErrorKind::Unavailable(_) => {}
            _ => panic!("expected an unavailable error"),
        }
    }

    #[tokio::test]
    async fn default_service_handling() {
        let service_error = SerializableError::builder()
            .error_code(ErrorCode::Conflict)
            .error_name("Default:Conflict")
            .error_instance_id(Uuid::nil())
            .build();

        let service = HttpErrorLayer::new(&Builder::for_test()).layer({
            let service_error = service_error.clone();
            service::service_fn(move |_| {
                let json = json::to_vec(&service_error).unwrap();
                async move {
                    Ok(Response::builder()
                        .status(StatusCode::CONFLICT)
                        .header(CONTENT_TYPE, "application/json")
                        .body(Full::new(Bytes::from(json)))
                        .unwrap())
                }
            })
        });

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        let service = match error.kind() {
            ErrorKind::Service(service) => service,
            _ => panic!("expected a service error"),
        };
        assert_eq!(*service.error_code(), ErrorCode::Internal);
        assert_eq!(
            service.error_instance_id(),
            service_error.error_instance_id()
        );

        let remote_error = error.cause().downcast_ref::<RemoteError>().unwrap();
        assert_eq!(remote_error.error(), Some(&service_error));
    }

    #[tokio::test]
    async fn propagated_service_handling() {
        let service_error = SerializableError::builder()
            .error_code(ErrorCode::Conflict)
            .error_name("Default:Conflict")
            .error_instance_id(Uuid::nil())
            .build();

        let service = HttpErrorLayer::new(
            &Builder::for_test().service_error(ServiceError::PropagateToCaller),
        )
        .layer({
            let service_error = service_error.clone();
            service::service_fn(move |_| {
                let json = json::to_vec(&service_error).unwrap();
                async move {
                    Ok(Response::builder()
                        .status(StatusCode::CONFLICT)
                        .header(CONTENT_TYPE, "application/json")
                        .body(Full::new(Bytes::from(json)))
                        .unwrap())
                }
            })
        });

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        let service = match error.kind() {
            ErrorKind::Service(service) => service,
            _ => panic!("expected a service error"),
        };
        assert_eq!(service_error, *service);

        let remote_error = error.cause().downcast_ref::<RemoteError>().unwrap();
        assert_eq!(remote_error.error(), Some(&service_error));
    }

    // A body that is not a Conjure serialized error must still map to an internal service
    // error, with no parsed remote error attached.
    #[tokio::test]
    async fn non_json_error_body_is_internal() {
        let service =
            HttpErrorLayer::new(&Builder::for_test()).layer(service::service_fn(|_| async move {
                Ok(Response::builder()
                    .status(StatusCode::INTERNAL_SERVER_ERROR)
                    .body(Full::new(Bytes::from_static(b"not json")))
                    .unwrap())
            }));

        let request = Request::new(());
        let error = service.call(request).await.err().unwrap();
        let service = match error.kind() {
            ErrorKind::Service(service) => service,
            _ => panic!("expected a service error"),
        };
        assert_eq!(*service.error_code(), ErrorCode::Internal);

        let remote_error = error.cause().downcast_ref::<RemoteError>().unwrap();
        assert_eq!(remote_error.error(), None);
    }
}