1use std::{
4 future::Future,
5 pin::Pin,
6 sync::Arc,
7 task::{Context, Poll},
8 time::Instant,
9};
10
11use bytes::BytesMut;
12use http::Request;
13use obs_core::{Observer, ScopeFrame, ScopeFrameBuilder, with_observer_task_sync};
14use obs_proto::obs::v1::{ObsEnvelope, ObsHttpRequestCompleted, ObsHttpRequestStarted};
15use pin_project_lite::pin_project;
16use tower::Service;
17
18use crate::propagator::{TraceContext, W3cPropagator, fresh_span_id, fresh_trace_id, status_class};
19
20type RouteFn<B> = Arc<dyn Fn(&Request<B>) -> String + Send + Sync>;
21type ObserverFn<B> = Arc<dyn Fn(&Request<B>) -> Option<Arc<dyn Observer>> + Send + Sync>;
22type StatusFn = Arc<dyn Fn(u16) -> &'static str + Send + Sync>;
23
/// Tower layer that instruments HTTP requests with trace context and `obs.v1`
/// request envelopes.
///
/// `B` is the request body type seen by the configured closures; it defaults to
/// `()`. Start from [`ObsHttpLayer::server`] and adjust with the `with_*` methods.
pub struct ObsHttpLayer<B = ()> {
    /// Derives the `route` label from a request (length-capped before use).
    route_extractor: RouteFn<B>,
    /// Extracts inbound trace context from the request headers.
    propagator: Arc<W3cPropagator>,
    /// Emit an `ObsHttpRequestStarted` envelope before calling the inner service.
    emit_started: bool,
    /// Emit an `ObsHttpRequestCompleted` envelope when the response future resolves.
    emit_metrics: bool,
    /// Maps an HTTP status code to the `status_class` label.
    status_classifier: StatusFn,
    /// Optional hook that may choose a dedicated observer for a given request.
    per_request_observer: Option<ObserverFn<B>>,
}
33
34impl<B> std::fmt::Debug for ObsHttpLayer<B> {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 f.debug_struct("ObsHttpLayer")
37 .field("emit_started", &self.emit_started)
38 .field("emit_metrics", &self.emit_metrics)
39 .finish_non_exhaustive()
40 }
41}
42
43impl<B> Clone for ObsHttpLayer<B> {
44 fn clone(&self) -> Self {
45 Self {
46 route_extractor: Arc::clone(&self.route_extractor),
47 propagator: Arc::clone(&self.propagator),
48 emit_started: self.emit_started,
49 emit_metrics: self.emit_metrics,
50 status_classifier: Arc::clone(&self.status_classifier),
51 per_request_observer: self.per_request_observer.clone(),
52 }
53 }
54}
55
56impl<B> ObsHttpLayer<B> {
57 #[must_use]
60 pub fn server() -> Self {
61 Self {
62 route_extractor: Arc::new(|req: &Request<B>| req.uri().path().to_string()),
63 propagator: Arc::new(W3cPropagator::new()),
64 emit_started: false,
65 emit_metrics: true,
66 status_classifier: Arc::new(|s| status_class(s)),
67 per_request_observer: None,
68 }
69 }
70
71 #[must_use]
73 pub fn with_route_extractor<F>(mut self, f: F) -> Self
74 where
75 F: Fn(&Request<B>) -> String + Send + Sync + 'static,
76 {
77 self.route_extractor = Arc::new(f);
78 self
79 }
80
81 #[must_use]
83 pub fn with_emit_started(mut self, on: bool) -> Self {
84 self.emit_started = on;
85 self
86 }
87
88 #[must_use]
91 pub fn with_emit_metrics(mut self, on: bool) -> Self {
92 self.emit_metrics = on;
93 self
94 }
95
96 #[must_use]
98 pub fn with_propagator(mut self, p: W3cPropagator) -> Self {
99 self.propagator = Arc::new(p);
100 self
101 }
102
103 #[must_use]
105 pub fn with_status_classifier<F>(mut self, f: F) -> Self
106 where
107 F: Fn(u16) -> &'static str + Send + Sync + 'static,
108 {
109 self.status_classifier = Arc::new(f);
110 self
111 }
112
113 #[must_use]
115 pub fn with_per_request_observer<F>(mut self, f: F) -> Self
116 where
117 F: Fn(&Request<B>) -> Option<Arc<dyn Observer>> + Send + Sync + 'static,
118 {
119 self.per_request_observer = Some(Arc::new(f));
120 self
121 }
122}
123
124impl<S, B> tower::Layer<S> for ObsHttpLayer<B>
125where
126 S: Service<Request<B>>,
127 S::Future: Send,
128 B: 'static,
129{
130 type Service = ObsHttpService<S, B>;
131 fn layer(&self, inner: S) -> Self::Service {
132 ObsHttpService {
133 inner,
134 layer: self.clone(),
135 }
136 }
137}
138
/// Service produced by [`ObsHttpLayer`].
///
/// Forwards each request to `inner` and returns an [`ObsHttpFuture`] that
/// reports the request's outcome.
pub struct ObsHttpService<S, B> {
    /// The wrapped service.
    inner: S,
    /// Clone of the configuring layer; its closures and propagator are shared via `Arc`.
    layer: ObsHttpLayer<B>,
}
144
145impl<S, B> std::fmt::Debug for ObsHttpService<S, B> {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 f.debug_struct("ObsHttpService")
148 .field("layer", &self.layer)
149 .finish_non_exhaustive()
150 }
151}
152
153impl<S, B> Clone for ObsHttpService<S, B>
154where
155 S: Clone,
156{
157 fn clone(&self) -> Self {
158 Self {
159 inner: self.inner.clone(),
160 layer: self.layer.clone(),
161 }
162 }
163}
164
165impl<S, B, ResBody> Service<Request<B>> for ObsHttpService<S, B>
166where
167 S: Service<Request<B>, Response = http::Response<ResBody>> + Send,
168 S::Future: Send + 'static,
169 S::Error: Send + 'static,
170 B: Send + 'static,
171{
172 type Response = S::Response;
173 type Error = S::Error;
174 type Future = ObsHttpFuture<S::Future>;
175
176 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
177 self.inner.poll_ready(cx)
178 }
179
180 fn call(&mut self, req: Request<B>) -> Self::Future {
181 let started = Instant::now();
182 let cap: u16 = 256;
188 let route = obs_core::cap_external_string("route", (self.layer.route_extractor)(&req), cap);
189 let method =
190 obs_core::cap_external_string("method", req.method().as_str().to_string(), cap);
191 let propagator = Arc::clone(&self.layer.propagator);
192 let status_classifier = Arc::clone(&self.layer.status_classifier);
193 let emit_started = self.layer.emit_started;
194 let emit_metrics = self.layer.emit_metrics;
195 let observer_override = self
196 .layer
197 .per_request_observer
198 .as_ref()
199 .and_then(|f| f(&req));
200
201 let mut ctx = propagator
203 .extract(req.headers())
204 .unwrap_or_else(|| TraceContext {
205 trace_id: fresh_trace_id(),
206 span_id: fresh_span_id(),
207 flags: "01".to_string(),
208 tracestate: String::new(),
209 });
210 let parent_span = if !ctx.span_id.is_empty() && propagator.extract(req.headers()).is_some()
213 {
214 ctx.span_id.clone()
215 } else {
216 String::new()
217 };
218 ctx.span_id = fresh_span_id();
219 let trace_id = ctx.trace_id.clone();
220 let span_id = ctx.span_id.clone();
221
222 if emit_started {
223 emit_request_started(
224 &route,
225 &method,
226 &trace_id,
227 &parent_span,
228 observer_override.as_ref(),
229 );
230 }
231
232 let inner_fut = self.inner.call(req);
233
234 let scope_seed = ScopeFrameBuilder::new()
239 .context()
240 .trace_id(trace_id.clone())
241 .span_id(span_id.clone())
242 .parent_span_id(parent_span.clone())
243 .into_frame();
244
245 ObsHttpFuture {
246 inner: inner_fut,
247 started: Some(started),
248 route,
249 method,
250 trace_id,
251 span_id,
252 parent_span,
253 status_classifier,
254 emit_metrics,
255 observer_override,
256 scope_seed: Some(scope_seed),
257 }
258 }
259}
260
pin_project! {
    /// Response future returned by [`ObsHttpService`].
    ///
    /// Wraps the inner service's future. Once that future resolves, this emits an
    /// `ObsHttpRequestCompleted` envelope built from the fields captured in `call`,
    /// provided `emit_metrics` is set.
    pub struct ObsHttpFuture<F> {
        // The inner service's response future (structurally pinned).
        #[pin]
        inner: F,
        // Request start time; taken out when the inner future resolves.
        started: Option<Instant>,
        // Length-capped route label.
        route: String,
        // Length-capped HTTP method label.
        method: String,
        // Inbound trace id, or a freshly generated one.
        trace_id: String,
        // Span id generated for this request.
        span_id: String,
        // Inbound caller span id; empty when there was none.
        parent_span: String,
        // Maps the response status code to a `status_class` label.
        status_classifier: StatusFn,
        // Whether the completed envelope is emitted.
        emit_metrics: bool,
        // Observer chosen for this request, if any; otherwise the global observer is used.
        observer_override: Option<Arc<dyn Observer>>,
        // Scope frame pushed around every poll of `inner`.
        scope_seed: Option<ScopeFrame>,
    }
}
281
282impl<F, ResBody, E> Future for ObsHttpFuture<F>
283where
284 F: Future<Output = Result<http::Response<ResBody>, E>>,
285{
286 type Output = F::Output;
287
288 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
289 let mut this = self.project();
290 let _scope_guard = this
294 .scope_seed
295 .as_ref()
296 .map(|seed| RequestScopeGuard::push(seed.clone()));
297 let result = if let Some(o) = this.observer_override.clone() {
300 with_observer_task_sync(o, || this.inner.as_mut().poll(cx))
301 } else {
302 this.inner.as_mut().poll(cx)
303 };
304 match result {
305 Poll::Pending => Poll::Pending,
306 Poll::Ready(out) => {
307 let started = this.started.take().unwrap_or_else(Instant::now);
308 let elapsed_ms = started.elapsed().as_millis() as u64;
309 match &out {
310 Ok(resp) => {
311 if *this.emit_metrics {
312 let status = resp.status().as_u16();
313 let class = (this.status_classifier)(status);
314 emit_request_completed(
315 this.route,
316 this.method,
317 class,
318 elapsed_ms,
319 this.trace_id,
320 this.span_id,
321 this.parent_span,
322 this.observer_override.as_ref(),
323 );
324 }
325 }
326 Err(_) => {
327 if *this.emit_metrics {
328 emit_request_completed(
329 this.route,
330 this.method,
331 "err",
332 elapsed_ms,
333 this.trace_id,
334 this.span_id,
335 this.parent_span,
336 this.observer_override.as_ref(),
337 );
338 }
339 }
340 }
341 Poll::Ready(out)
342 }
343 }
344 }
345}
346
347struct RequestScopeGuard;
353
354impl RequestScopeGuard {
355 fn push(frame: ScopeFrame) -> Self {
356 obs_core::scope::push_frame_pub(frame);
357 Self
358 }
359}
360
361impl Drop for RequestScopeGuard {
362 fn drop(&mut self) {
363 let _ = obs_core::scope::pop_frame_pub();
364 }
365}
366
367fn encode_into<M: ::buffa::Message>(msg: &M, out: &mut Vec<u8>) {
369 let mut cache = ::buffa::SizeCache::default();
370 let size = msg.compute_size(&mut cache);
371 let mut buf = BytesMut::with_capacity(size as usize);
372 msg.write_to(&mut cache, &mut buf);
373 out.clear();
374 out.extend_from_slice(&buf);
375}
376
377fn emit_request_started(
378 route: &str,
379 method: &str,
380 trace_id: &str,
381 parent_span: &str,
382 observer: Option<&Arc<dyn Observer>>,
383) {
384 let typed = ObsHttpRequestStarted {
388 method: method.to_string(),
389 route: route.to_string(),
390 __buffa_unknown_fields: Default::default(),
391 };
392 let mut env = ObsEnvelope {
393 full_name: "obs.v1.ObsHttpRequestStarted".to_string(),
394 tier: ::buffa::EnumValue::Known(obs_proto::obs::v1::Tier::TIER_LOG),
395 sev: ::buffa::EnumValue::Known(obs_proto::obs::v1::Severity::SEVERITY_INFO),
396 trace_id: trace_id.to_string(),
397 parent_span_id: parent_span.to_string(),
398 ..Default::default()
399 };
400 encode_into(&typed, &mut env.payload);
401 env.labels.insert("route".to_string(), route.to_string());
402 env.labels.insert("method".to_string(), method.to_string());
403 if let Some(o) = observer {
404 o.emit_envelope(env);
405 } else {
406 obs_core::observer().emit_envelope(env);
407 }
408}
409
410#[allow(clippy::too_many_arguments)]
411fn emit_request_completed(
412 route: &str,
413 method: &str,
414 status_class: &str,
415 latency_ms: u64,
416 trace_id: &str,
417 span_id: &str,
418 parent_span: &str,
419 observer: Option<&Arc<dyn Observer>>,
420) {
421 let typed = ObsHttpRequestCompleted {
427 method: method.to_string(),
428 route: route.to_string(),
429 status_class: status_class.to_string(),
430 latency_ms,
431 bytes_out: 0,
432 __buffa_unknown_fields: Default::default(),
433 };
434 let mut env = ObsEnvelope {
435 full_name: "obs.v1.ObsHttpRequestCompleted".to_string(),
436 tier: ::buffa::EnumValue::Known(obs_proto::obs::v1::Tier::TIER_LOG),
437 sev: ::buffa::EnumValue::Known(obs_proto::obs::v1::Severity::SEVERITY_INFO),
438 trace_id: trace_id.to_string(),
439 span_id: span_id.to_string(),
440 parent_span_id: parent_span.to_string(),
441 ..Default::default()
442 };
443 encode_into(&typed, &mut env.payload);
444 env.labels.insert("route".to_string(), route.to_string());
447 env.labels.insert("method".to_string(), method.to_string());
448 env.labels
449 .insert("status_class".to_string(), status_class.to_string());
450 if let Some(o) = observer {
451 o.emit_envelope(env);
452 } else {
453 obs_core::observer().emit_envelope(env);
454 }
455}