// hyper_client_pool/transaction.rs

1use std::fmt;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use futures::prelude::*;
6use hyper::client::connect::Connect;
7use hyper::{self, client::Client, Request};
8use hyper::{Body, Response};
9
10use crate::deliverable::Deliverable;
11use raii_counter::Counter;
12use tracing::{span, trace, Instrument};
13
14/// The result of the transaction, a message sent to the
15/// deliverable.
16///
17/// This must be sent to the deliverable in any case
18/// in order to prevent data loss.
19#[derive(Debug)]
20pub enum DeliveryResult {
21    /// The delivery was dropped, unknown if it was sent or not.
22    Dropped,
23
24    /// Received a response from the external server.
25    Response {
26        response: Response<Body>,
27        body: Option<Vec<u8>>,
28        body_size: usize,
29        duration: Duration,
30    },
31
32    /// Failed to connect within the timeout limit.
33    Timeout { duration: Duration },
34
35    /// Sending a request through hyper encountered an error.
36    HyperError {
37        error: hyper::Error,
38        duration: Duration,
39    },
40}
41
42/// A container type for a [`hyper::Request`] as well as the deliverable
43/// which receives the result of the request.
44pub struct Transaction<D: Deliverable> {
45    deliverable: D,
46    request: Request<Body>,
47    requires_body: bool,
48    span_id: Option<tracing::Id>,
49}
50
51struct DeliverableDropGuard<D: Deliverable> {
52    deliverable: Option<D>,
53    span_id: Option<tracing::Id>,
54}
55
56impl<D: Deliverable> Drop for DeliverableDropGuard<D> {
57    fn drop(&mut self) {
58        self.deliverable.take().map(|deliverable| {
59            trace!(parent: self.span_id.clone(), "Dropping transaction..");
60            deliverable.complete(DeliveryResult::Dropped);
61        });
62    }
63}
64
65impl<D: Deliverable> DeliverableDropGuard<D> {
66    fn new(deliverable: D, span_id: Option<tracing::Id>) -> Self {
67        Self {
68            deliverable: Some(deliverable),
69            span_id,
70        }
71    }
72
73    fn take(mut self) -> D {
74        self.deliverable
75            .take()
76            .expect("take cannot be called more than once")
77    }
78}
79
80impl<D: Deliverable> fmt::Debug for Transaction<D> {
81    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
82        write!(
83            f,
84            "Transaction {{ deliverable: (unknown), request: {:?} }}",
85            self.request
86        )
87    }
88}
89
90impl<D: Deliverable> Transaction<D> {
91    pub fn new(deliverable: D, request: Request<Body>, requires_body: bool) -> Transaction<D> {
92        Transaction {
93            deliverable,
94            request,
95            requires_body,
96            span_id: None,
97        }
98    }
99
100    /// Report tracing events for this transaction within the `tracing::Span`
101    /// with the provided ID. Most interesting of these events is the
102    /// debug-level `http_request` span, which tries to have fields provided in
103    /// the opentelemetry HTTP conventions document. This event will be reported
104    /// wether this method is called or not, but it will be much more useful if
105    /// you provide a parent span so that you can determine why the request is
106    /// being made.
107    ///
108    /// https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/http.md
109    pub fn with_parent_span(mut self, span_id: impl Into<Option<tracing::Id>>) -> Self {
110        self.span_id = span_id.into();
111
112        self
113    }
114
115    pub(crate) fn spawn_request<C: 'static + Connect + Clone + Send + Sync>(
116        self,
117        client: Arc<Client<C>>,
118        timeout: Duration,
119        counter: Counter,
120    ) {
121        let Transaction {
122            deliverable,
123            request,
124            requires_body,
125            span_id,
126        } = self;
127
128        // Creating a span per transaction can be a non-trivial CPU cost so hide it behind a feature flag
129        const TRANSACTION_SPAN_LEVEL: tracing::Level = if cfg!(feature = "transaction-tracing") {
130            tracing::Level::INFO
131        } else {
132            tracing::Level::TRACE
133        };
134
135        let outer_span = span!(
136            parent: span_id,
137            TRANSACTION_SPAN_LEVEL,
138            "http_request",
139            otel.kind = "client",
140            http.url = %request.uri(),
141            http.host = request.uri().host().unwrap_or(""),
142            http.scheme = request.uri().scheme_str().unwrap_or(""),
143            http.method = request.method().as_str(),
144            http.flavor = ?request.version(),
145            http.status_code = tracing::field::Empty,
146            http.request_content_length = tracing::field::Empty,
147            outcome = tracing::field::Empty,
148        );
149
150        let deliverable_guard = DeliverableDropGuard::new(deliverable, outer_span.id());
151
152        let start_time = Instant::now();
153
154        let inner_span1 = outer_span.clone();
155        let inner_span2 = outer_span.clone();
156
157        let request_future = async move {
158            trace!("Sending request");
159            match client.request(request).await {
160                Ok(response) => {
161                    if requires_body {
162                        let (parts, mut body) = response.into_parts();
163                        let mut body_vec = Vec::new();
164
165                        while let Some(Ok(chunk)) = body.next().await {
166                            body_vec.extend_from_slice(&*chunk);
167                        }
168
169                        let body_size = body_vec.len();
170
171                        inner_span1.record("http.request_content_length", &body_size);
172
173                        Ok((
174                            Response::from_parts(parts, Body::empty()),
175                            Some(body_vec),
176                            body_size,
177                        ))
178                    } else {
179                        // Note that you must consume the body if you want keepalive
180                        // to take affect.
181                        let (parts, mut body) = response.into_parts();
182
183                        let mut body_len = 0;
184
185                        while let Some(Ok(chunk)) = body.next().await {
186                            body_len += chunk.len();
187                        }
188
189                        inner_span1.record("http.request_content_length", &body_len);
190
191                        Ok((Response::from_parts(parts, Body::empty()), None, body_len))
192                    }
193                }
194                Err(e) => Err(e),
195            }
196        };
197
198        tokio::spawn(
199            async move {
200                let result = tokio::time::timeout(timeout, request_future).await;
201                let duration = start_time.elapsed();
202
203                let delivery_result = match result {
204                    Ok(Ok((response, body, body_size))) => {
205                        inner_span2.record("http.status_code", &response.status().as_u16());
206                        inner_span2.record("outcome", &"http success");
207                        trace!(?response, ?duration, "Finished transaction",);
208                        DeliveryResult::Response {
209                            response,
210                            body,
211                            body_size,
212                            duration,
213                        }
214                    }
215
216                    Ok(Err(hyper_error)) => {
217                        inner_span2.record("outcome", &"http error");
218                        trace!(
219                            error = ?hyper_error,
220                            ?duration,
221                            "Transaction errored during delivery",
222                        );
223                        DeliveryResult::HyperError {
224                            error: hyper_error,
225                            duration,
226                        }
227                    }
228
229                    Err(_) => {
230                        inner_span2.record("outcome", &"timeout");
231                        trace!(
232                            ?duration,
233                            timeout_limit = ?timeout,
234                            "Transaction timed out",
235                        );
236                        DeliveryResult::Timeout { duration }
237                    }
238                };
239
240                deliverable_guard.take().complete(delivery_result);
241
242                drop(counter);
243            }
244            .instrument(outer_span),
245        );
246    }
247}
248
#[cfg(test)]
mod tests {
    use hyper::client::connect::HttpConnector;
    use hyper::Request;
    use hyper_tls::HttpsConnector;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use tokio::time::sleep;
    use tracing::info;

    use super::*;

    /// A cloneable deliverable that tallies every result it receives, per kind.
    /// All clones share the same counters.
    #[derive(Debug, Clone)]
    struct DeliveryCounter {
        total_count: Arc<AtomicUsize>,
        response_count: Arc<AtomicUsize>,
        dropped_count: Arc<AtomicUsize>,
        hyper_error_count: Arc<AtomicUsize>,
        timeout_count: Arc<AtomicUsize>,
    }

    impl DeliveryCounter {
        fn new() -> DeliveryCounter {
            DeliveryCounter {
                total_count: Arc::new(AtomicUsize::new(0)),
                response_count: Arc::new(AtomicUsize::new(0)),
                dropped_count: Arc::new(AtomicUsize::new(0)),
                hyper_error_count: Arc::new(AtomicUsize::new(0)),
                timeout_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        fn timeout_count(&self) -> usize {
            self.timeout_count.load(Ordering::Acquire)
        }

        fn total_count(&self) -> usize {
            self.total_count.load(Ordering::Acquire)
        }

        fn response_count(&self) -> usize {
            self.response_count.load(Ordering::Acquire)
        }
    }

    impl Deliverable for DeliveryCounter {
        fn complete(self, result: DeliveryResult) {
            let bucket = match result {
                DeliveryResult::Response { .. } => &self.response_count,
                DeliveryResult::Dropped => &self.dropped_count,
                DeliveryResult::HyperError { .. } => &self.hyper_error_count,
                DeliveryResult::Timeout { .. } => &self.timeout_count,
            };
            bucket.fetch_add(1, Ordering::AcqRel);

            self.total_count.fetch_add(1, Ordering::AcqRel);
        }
    }

    const TRANSACTION_SPAWN_COUNT: usize = 200;
    const TIMEOUT_COUNT: usize = 50;

    /// Upper bound on how long the test waits for every transaction to be
    /// delivered. Each request times out after 2s; the extra slack only
    /// matters on a slow machine.
    const DELIVERY_DEADLINE: Duration = Duration::from_secs(10);

    /// Spawns `TRANSACTION_SPAWN_COUNT` requests with a 2s timeout. The first
    /// `TIMEOUT_COUNT` hit an endpoint that delays 4s and so should time out.
    fn make_requests<C>(client: Client<C>, counter: &DeliveryCounter)
    where
        C: 'static + Connect + Clone + Send + Sync,
    {
        let client = Arc::new(client);

        for i in 0..TRANSACTION_SPAWN_COUNT {
            let url = if i < TIMEOUT_COUNT {
                "https://httpbin.org/delay/4"
            } else {
                "https://httpbin.org/delay/0"
            };

            let transaction = Transaction::new(
                counter.clone(),
                Request::get(url).body(Body::empty()).unwrap(),
                false,
            );
            transaction.spawn_request(Arc::clone(&client), Duration::from_secs(2), Counter::new());
        }
    }

    fn test_hyper_client() -> Client<HttpsConnector<HttpConnector>> {
        let connector = HttpsConnector::new();
        Client::builder().build(connector)
    }

    #[tokio::test]
    async fn timed_out_transactions_get_sent_to_deliverable() {
        let _ = tracing_subscriber::fmt::try_init();

        info!("test start");

        let counter = DeliveryCounter::new();

        let client = test_hyper_client();

        make_requests(client, &counter);

        // Poll rather than sleeping a fixed time: this returns as soon as every
        // transaction has been delivered and tolerates a slow scheduler.
        let deadline = Instant::now() + DELIVERY_DEADLINE;
        while counter.total_count() < TRANSACTION_SPAWN_COUNT && Instant::now() < deadline {
            sleep(Duration::from_millis(50)).await;
        }

        assert_ne!(counter.response_count(), TRANSACTION_SPAWN_COUNT);
        assert_eq!(counter.timeout_count(), TIMEOUT_COUNT);
        assert_eq!(counter.total_count(), TRANSACTION_SPAWN_COUNT);
    }
}