1use std::{
4 future::Future,
5 pin::Pin,
6 sync::Arc,
7 task::{Context, Poll},
8 time::Instant,
9};
10
11use bytes::BytesMut;
12use http::Request;
13use obs_proto::obs::v1::{ObsEnvelope, ObsHttpClientCompleted};
14use pin_project_lite::pin_project;
15use tower::Service;
16
17use crate::propagator::{TraceContext, W3cPropagator, fresh_span_id, fresh_trace_id, status_class};
18
19type StatusFn = Arc<dyn Fn(u16) -> &'static str + Send + Sync>;
20type RouteFn<B> = Arc<dyn Fn(&Request<B>) -> String + Send + Sync>;
21
22pub struct ObsHttpClientLayer<B = ()> {
24 propagator: Arc<W3cPropagator>,
25 target_extractor: RouteFn<B>,
26 status_classifier: StatusFn,
27}
28
29impl<B> std::fmt::Debug for ObsHttpClientLayer<B> {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("ObsHttpClientLayer").finish_non_exhaustive()
32 }
33}
34
35impl<B> Clone for ObsHttpClientLayer<B> {
36 fn clone(&self) -> Self {
37 Self {
38 propagator: Arc::clone(&self.propagator),
39 target_extractor: Arc::clone(&self.target_extractor),
40 status_classifier: Arc::clone(&self.status_classifier),
41 }
42 }
43}
44
45impl<B> ObsHttpClientLayer<B> {
46 #[must_use]
48 pub fn new() -> Self {
49 Self {
50 propagator: Arc::new(W3cPropagator::new()),
51 target_extractor: Arc::new(|req: &Request<B>| {
52 req.uri()
53 .host()
54 .map(ToString::to_string)
55 .unwrap_or_else(|| req.uri().to_string())
56 }),
57 status_classifier: Arc::new(|s| status_class(s)),
58 }
59 }
60
61 #[must_use]
63 pub fn with_target_extractor<F>(mut self, f: F) -> Self
64 where
65 F: Fn(&Request<B>) -> String + Send + Sync + 'static,
66 {
67 self.target_extractor = Arc::new(f);
68 self
69 }
70}
71
72impl<B> Default for ObsHttpClientLayer<B> {
73 fn default() -> Self {
74 Self::new()
75 }
76}
77
78impl<S, B> tower::Layer<S> for ObsHttpClientLayer<B>
79where
80 S: Service<Request<B>>,
81{
82 type Service = ObsHttpClientService<S, B>;
83 fn layer(&self, inner: S) -> Self::Service {
84 ObsHttpClientService {
85 inner,
86 layer: self.clone(),
87 }
88 }
89}
90
91pub struct ObsHttpClientService<S, B> {
93 inner: S,
94 layer: ObsHttpClientLayer<B>,
95}
96
97impl<S, B> std::fmt::Debug for ObsHttpClientService<S, B> {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 f.debug_struct("ObsHttpClientService")
100 .field("layer", &self.layer)
101 .finish_non_exhaustive()
102 }
103}
104
105impl<S, B> Clone for ObsHttpClientService<S, B>
106where
107 S: Clone,
108{
109 fn clone(&self) -> Self {
110 Self {
111 inner: self.inner.clone(),
112 layer: self.layer.clone(),
113 }
114 }
115}
116
117impl<S, B, ResBody> Service<Request<B>> for ObsHttpClientService<S, B>
118where
119 S: Service<Request<B>, Response = http::Response<ResBody>>,
120 S::Future: Send + 'static,
121 B: Send + 'static,
122{
123 type Response = S::Response;
124 type Error = S::Error;
125 type Future = ObsHttpClientFuture<S::Future>;
126
127 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
128 self.inner.poll_ready(cx)
129 }
130
131 fn call(&mut self, mut req: Request<B>) -> Self::Future {
132 let started = Instant::now();
133 let target = (self.layer.target_extractor)(&req);
134 let method = req.method().as_str().to_string();
135 let propagator = Arc::clone(&self.layer.propagator);
136 let status_classifier = Arc::clone(&self.layer.status_classifier);
137
138 let sampled = obs_core::scope::active_sampled().unwrap_or(true);
144 let flags = if sampled { "01" } else { "00" };
145 let (ctx, parent_span_id) = match obs_core::scope::active_correlation() {
146 Some((trace_id, parent_span)) => (
147 TraceContext {
148 trace_id,
149 span_id: fresh_span_id(),
150 flags: flags.to_string(),
151 tracestate: format!("parent={parent_span}"),
152 },
153 parent_span,
154 ),
155 None => (
156 TraceContext {
157 trace_id: fresh_trace_id(),
158 span_id: fresh_span_id(),
159 flags: flags.to_string(),
160 tracestate: String::new(),
161 },
162 String::new(),
163 ),
164 };
165 propagator.inject(req.headers_mut(), &ctx);
166 let trace_id = ctx.trace_id.clone();
167 let span_id = ctx.span_id.clone();
168 emit_client_started(&target, &method, &trace_id, &parent_span_id);
169
170 ObsHttpClientFuture {
171 inner: self.inner.call(req),
172 started: Some(started),
173 target,
174 method,
175 trace_id,
176 span_id,
177 parent_span_id,
178 status_classifier,
179 }
180 }
181}
182
183pin_project! {
184 pub struct ObsHttpClientFuture<F> {
186 #[pin]
187 inner: F,
188 started: Option<Instant>,
189 target: String,
190 method: String,
191 trace_id: String,
192 span_id: String,
193 parent_span_id: String,
194 status_classifier: StatusFn,
195 }
196}
197
198impl<F, ResBody, E> Future for ObsHttpClientFuture<F>
199where
200 F: Future<Output = Result<http::Response<ResBody>, E>>,
201{
202 type Output = F::Output;
203 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
204 let this = self.project();
205 match this.inner.poll(cx) {
206 Poll::Pending => Poll::Pending,
207 Poll::Ready(out) => {
208 let started = this.started.take().unwrap_or_else(Instant::now);
209 let elapsed_ms = started.elapsed().as_millis() as u64;
210 let class = match &out {
211 Ok(resp) => (this.status_classifier)(resp.status().as_u16()),
212 Err(_) => "err",
213 };
214 emit_client_completed(
215 this.target,
216 this.method,
217 class,
218 elapsed_ms,
219 this.trace_id,
220 this.span_id,
221 this.parent_span_id,
222 );
223 Poll::Ready(out)
224 }
225 }
226 }
227}
228
229fn encode_into<M: ::buffa::Message>(msg: &M, out: &mut Vec<u8>) {
232 let mut cache = ::buffa::SizeCache::default();
233 let size = msg.compute_size(&mut cache);
234 let mut buf = BytesMut::with_capacity(size as usize);
235 msg.write_to(&mut cache, &mut buf);
236 out.clear();
237 out.extend_from_slice(&buf);
238}
239
240fn emit_client_started(target: &str, method: &str, trace_id: &str, parent_span_id: &str) {
241 let typed = obs_proto::obs::v1::ObsHttpClientStarted {
242 method: method.to_string(),
243 host: target.to_string(),
244 __buffa_unknown_fields: Default::default(),
245 };
246 let mut env = ObsEnvelope {
247 full_name: "obs.v1.ObsHttpClientStarted".to_string(),
248 tier: ::buffa::EnumValue::Known(obs_proto::obs::v1::Tier::TIER_LOG),
249 sev: ::buffa::EnumValue::Known(obs_proto::obs::v1::Severity::SEVERITY_INFO),
250 trace_id: trace_id.to_string(),
251 parent_span_id: parent_span_id.to_string(),
252 ..Default::default()
253 };
254 encode_into(&typed, &mut env.payload);
255 env.labels.insert("host".to_string(), target.to_string());
256 env.labels.insert("method".to_string(), method.to_string());
257 obs_core::observer().emit_envelope(env);
258}
259
260fn emit_client_completed(
261 target: &str,
262 method: &str,
263 status_class: &str,
264 latency_ms: u64,
265 trace_id: &str,
266 span_id: &str,
267 parent_span_id: &str,
268) {
269 let typed = ObsHttpClientCompleted {
275 method: method.to_string(),
276 host: target.to_string(),
277 status_class: status_class.to_string(),
278 latency_ms,
279 __buffa_unknown_fields: Default::default(),
280 };
281 let mut env = ObsEnvelope {
282 full_name: "obs.v1.ObsHttpClientCompleted".to_string(),
283 tier: ::buffa::EnumValue::Known(obs_proto::obs::v1::Tier::TIER_LOG),
284 sev: ::buffa::EnumValue::Known(obs_proto::obs::v1::Severity::SEVERITY_INFO),
285 trace_id: trace_id.to_string(),
286 span_id: span_id.to_string(),
287 parent_span_id: parent_span_id.to_string(),
288 ..Default::default()
289 };
290 encode_into(&typed, &mut env.payload);
291 env.labels.insert("host".to_string(), target.to_string());
292 env.labels.insert("method".to_string(), method.to_string());
293 env.labels
294 .insert("status_class".to_string(), status_class.to_string());
295 obs_core::observer().emit_envelope(env);
296}