allstak/integrations/
reqwest.rs1use std::time::Instant;
37
38use http::Extensions;
39use reqwest::{Request, Response};
40use reqwest_middleware::{Error, Middleware, Next, Result};
41
42use crate::hub::Hub;
43use crate::performance::Span;
44use crate::propagation;
45use crate::protocol::HttpRequestRecord;
46use crate::util;
47
/// Configuration for the Allstak instrumentation of outgoing `reqwest` calls.
///
/// Each boolean gates one step of the `Middleware::handle` implementation on
/// this type; the `Default` impl turns all of them on.
#[derive(Debug, Clone)]
pub struct AllstakHttpMiddleware {
    /// When `true`, a span is opened around the outgoing request.
    start_span: bool,
    /// When `true`, trace-propagation headers are written into the request.
    inject_headers: bool,
    /// When `true`, an `HttpRequestRecord` is handed to the hub's client
    /// after the request completes.
    record_request: bool,
    /// Operation name passed to the span when one is opened.
    operation: &'static str,
}
59
60impl Default for AllstakHttpMiddleware {
61 fn default() -> Self {
62 AllstakHttpMiddleware {
63 start_span: true,
64 inject_headers: true,
65 record_request: true,
66 operation: "http.client",
67 }
68 }
69}
70
71impl AllstakHttpMiddleware {
72 pub fn new() -> Self {
74 AllstakHttpMiddleware::default()
75 }
76
77 pub fn enable_span(mut self, enable: bool) -> Self {
79 self.start_span = enable;
80 self
81 }
82
83 pub fn enable_header_injection(mut self, enable: bool) -> Self {
85 self.inject_headers = enable;
86 self
87 }
88
89 pub fn enable_request_record(mut self, enable: bool) -> Self {
91 self.record_request = enable;
92 self
93 }
94}
95
96#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
97#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
98impl Middleware for AllstakHttpMiddleware {
99 async fn handle(
100 &self,
101 mut req: Request,
102 extensions: &mut Extensions,
103 next: Next<'_>,
104 ) -> Result<Response> {
105 let hub = Hub::current();
106
107 let mut ctx = hub.current_trace_context();
111 if ctx.trace_id.is_none() {
112 ctx.trace_id = Some(util::new_trace_id());
113 }
114
115 let method = req.method().to_string();
116 let url = req.url().clone();
117 let host = url.host_str().unwrap_or("").to_string();
118 let path = url.path().to_string();
119
120 let mut span = if self.start_span {
122 Some(Span::continued(
123 self.operation,
124 format!("{method} {host}{path}"),
125 ctx.trace_id.clone(),
126 ctx.parent_span_id.clone(),
127 ))
128 } else {
129 None
130 };
131 let span_id = span.as_ref().map(|s| s.span_id().to_string());
132
133 if self.inject_headers {
136 let headers = req.headers_mut();
137 propagation::inject(&ctx, span_id.as_deref(), |name, value| {
138 if let (Ok(hn), Ok(hv)) = (
139 http::HeaderName::from_bytes(name.as_bytes()),
140 http::HeaderValue::from_str(value),
141 ) {
142 headers.insert(hn, hv);
143 }
144 });
145 }
146
147 let started = Instant::now();
148 let result = next.run(req, extensions).await;
149 let duration_ms = started.elapsed().as_millis() as u64;
150
151 let status_code = match &result {
154 Ok(resp) => resp.status().as_u16(),
155 Err(Error::Reqwest(e)) => e.status().map(|s| s.as_u16()).unwrap_or(0),
156 Err(_) => 0,
157 };
158
159 if let Some(span) = span.as_mut() {
160 if status_code >= 500 || status_code == 0 {
161 span.set_status("internal_error");
162 } else {
163 span.set_status("ok");
164 }
165 span.set_tag("http.method", method.clone());
166 span.set_tag("http.host", host.clone());
167 span.set_tag("http.status_code", status_code.to_string());
168 }
169 if let Some(span) = span.take() {
171 span.finish();
172 }
173
174 if self.record_request {
175 let record = HttpRequestRecord {
176 trace_id: ctx.trace_id.clone(),
177 request_id: ctx.request_id.clone(),
178 direction: "outbound".to_string(),
179 method,
180 host,
181 path,
182 status_code,
183 duration_ms,
184 request_size: None,
185 response_size: None,
186 user_id: None,
187 error_fingerprint: None,
188 timestamp: util::now_iso8601(),
189 };
190 if let Some(client) = hub.client() {
191 client.capture_http_request(record);
192 }
193 }
194
195 result
196 }
197}
198
199pub fn instrumented_client() -> reqwest_middleware::ClientWithMiddleware {
202 instrumented_client_from(reqwest::Client::new())
203}
204
205pub fn instrumented_client_from(
207 client: reqwest::Client,
208) -> reqwest_middleware::ClientWithMiddleware {
209 reqwest_middleware::ClientBuilder::new(client)
210 .with(AllstakHttpMiddleware::new())
211 .build()
212}