Skip to main content

obs_tower/
server.rs

1//! Server-side `tower::Layer`. Spec 40 § 1.
2
3use std::{
4    future::Future,
5    pin::Pin,
6    sync::Arc,
7    task::{Context, Poll},
8    time::Instant,
9};
10
11use bytes::BytesMut;
12use http::Request;
13use obs_core::{Observer, ScopeFrame, ScopeFrameBuilder, with_observer_task_sync};
14use obs_proto::obs::v1::{ObsEnvelope, ObsHttpRequestCompleted, ObsHttpRequestStarted};
15use pin_project_lite::pin_project;
16use tower::Service;
17
18use crate::propagator::{TraceContext, W3cPropagator, fresh_span_id, fresh_trace_id, status_class};
19
20type RouteFn<B> = Arc<dyn Fn(&Request<B>) -> String + Send + Sync>;
21type ObserverFn<B> = Arc<dyn Fn(&Request<B>) -> Option<Arc<dyn Observer>> + Send + Sync>;
22type StatusFn = Arc<dyn Fn(u16) -> &'static str + Send + Sync>;
23
24/// HTTP server-side layer. Spec 40 § 1.
25pub struct ObsHttpLayer<B = ()> {
26    route_extractor: RouteFn<B>,
27    propagator: Arc<W3cPropagator>,
28    emit_started: bool,
29    emit_metrics: bool,
30    status_classifier: StatusFn,
31    per_request_observer: Option<ObserverFn<B>>,
32}
33
34impl<B> std::fmt::Debug for ObsHttpLayer<B> {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        f.debug_struct("ObsHttpLayer")
37            .field("emit_started", &self.emit_started)
38            .field("emit_metrics", &self.emit_metrics)
39            .finish_non_exhaustive()
40    }
41}
42
43impl<B> Clone for ObsHttpLayer<B> {
44    fn clone(&self) -> Self {
45        Self {
46            route_extractor: Arc::clone(&self.route_extractor),
47            propagator: Arc::clone(&self.propagator),
48            emit_started: self.emit_started,
49            emit_metrics: self.emit_metrics,
50            status_classifier: Arc::clone(&self.status_classifier),
51            per_request_observer: self.per_request_observer.clone(),
52        }
53    }
54}
55
56impl<B> ObsHttpLayer<B> {
57    /// Construct a server-side layer with sensible defaults.
58    /// `emit_started` is off; `emit_metrics` is on.
59    #[must_use]
60    pub fn server() -> Self {
61        Self {
62            route_extractor: Arc::new(|req: &Request<B>| req.uri().path().to_string()),
63            propagator: Arc::new(W3cPropagator::new()),
64            emit_started: false,
65            emit_metrics: true,
66            status_classifier: Arc::new(|s| status_class(s)),
67            per_request_observer: None,
68        }
69    }
70
71    /// Override the route extractor.
72    #[must_use]
73    pub fn with_route_extractor<F>(mut self, f: F) -> Self
74    where
75        F: Fn(&Request<B>) -> String + Send + Sync + 'static,
76    {
77        self.route_extractor = Arc::new(f);
78        self
79    }
80
81    /// Toggle emission of `ObsHttpRequestStarted`. Default off.
82    #[must_use]
83    pub fn with_emit_started(mut self, on: bool) -> Self {
84        self.emit_started = on;
85        self
86    }
87
88    /// Toggle emission of `ObsHttpRequestCompleted` metrics fields.
89    /// Default on.
90    #[must_use]
91    pub fn with_emit_metrics(mut self, on: bool) -> Self {
92        self.emit_metrics = on;
93        self
94    }
95
96    /// Override the W3C propagator.
97    #[must_use]
98    pub fn with_propagator(mut self, p: W3cPropagator) -> Self {
99        self.propagator = Arc::new(p);
100        self
101    }
102
103    /// Override the status classifier.
104    #[must_use]
105    pub fn with_status_classifier<F>(mut self, f: F) -> Self
106    where
107        F: Fn(u16) -> &'static str + Send + Sync + 'static,
108    {
109        self.status_classifier = Arc::new(f);
110        self
111    }
112
113    /// Per-request observer hook. Spec 40 § 3.1.
114    #[must_use]
115    pub fn with_per_request_observer<F>(mut self, f: F) -> Self
116    where
117        F: Fn(&Request<B>) -> Option<Arc<dyn Observer>> + Send + Sync + 'static,
118    {
119        self.per_request_observer = Some(Arc::new(f));
120        self
121    }
122}
123
124impl<S, B> tower::Layer<S> for ObsHttpLayer<B>
125where
126    S: Service<Request<B>>,
127    S::Future: Send,
128    B: 'static,
129{
130    type Service = ObsHttpService<S, B>;
131    fn layer(&self, inner: S) -> Self::Service {
132        ObsHttpService {
133            inner,
134            layer: self.clone(),
135        }
136    }
137}
138
139/// The wrapped service.
140pub struct ObsHttpService<S, B> {
141    inner: S,
142    layer: ObsHttpLayer<B>,
143}
144
145impl<S, B> std::fmt::Debug for ObsHttpService<S, B> {
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        f.debug_struct("ObsHttpService")
148            .field("layer", &self.layer)
149            .finish_non_exhaustive()
150    }
151}
152
153impl<S, B> Clone for ObsHttpService<S, B>
154where
155    S: Clone,
156{
157    fn clone(&self) -> Self {
158        Self {
159            inner: self.inner.clone(),
160            layer: self.layer.clone(),
161        }
162    }
163}
164
165impl<S, B, ResBody> Service<Request<B>> for ObsHttpService<S, B>
166where
167    S: Service<Request<B>, Response = http::Response<ResBody>> + Send,
168    S::Future: Send + 'static,
169    S::Error: Send + 'static,
170    B: Send + 'static,
171{
172    type Response = S::Response;
173    type Error = S::Error;
174    type Future = ObsHttpFuture<S::Future>;
175
176    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
177        self.inner.poll_ready(cx)
178    }
179
180    fn call(&mut self, req: Request<B>) -> Self::Future {
181        let started = Instant::now();
182        // Spec 95 § 3.10 / P2-AH: cap externally-supplied strings
183        // (route + method) at `max_external_string_bytes` (default 256
184        // per CLAUDE.md) so a hostile caller cannot blow up
185        // `env.labels` and downstream consumers. The aggregate
186        // `max_payload_bytes` still applies as a backstop.
187        let cap: u16 = 256;
188        let route = obs_core::cap_external_string("route", (self.layer.route_extractor)(&req), cap);
189        let method =
190            obs_core::cap_external_string("method", req.method().as_str().to_string(), cap);
191        let propagator = Arc::clone(&self.layer.propagator);
192        let status_classifier = Arc::clone(&self.layer.status_classifier);
193        let emit_started = self.layer.emit_started;
194        let emit_metrics = self.layer.emit_metrics;
195        let observer_override = self
196            .layer
197            .per_request_observer
198            .as_ref()
199            .and_then(|f| f(&req));
200
201        // Extract or generate trace context.
202        let mut ctx = propagator
203            .extract(req.headers())
204            .unwrap_or_else(|| TraceContext {
205                trace_id: fresh_trace_id(),
206                span_id: fresh_span_id(),
207                flags: "01".to_string(),
208                tracestate: String::new(),
209            });
210        // Always assign a fresh `span_id` at the boundary (the
211        // extracted span becomes the parent if present).
212        let parent_span = if !ctx.span_id.is_empty() && propagator.extract(req.headers()).is_some()
213        {
214            ctx.span_id.clone()
215        } else {
216            String::new()
217        };
218        ctx.span_id = fresh_span_id();
219        let trace_id = ctx.trace_id.clone();
220        let span_id = ctx.span_id.clone();
221
222        if emit_started {
223            emit_request_started(
224                &route,
225                &method,
226                &trace_id,
227                &parent_span,
228                observer_override.as_ref(),
229            );
230        }
231
232        let inner_fut = self.inner.call(req);
233
234        // Spec 94 § 2.1 / P0-A: build an `obs::scope!` frame so handler
235        // emits inherit `trace_id`/`span_id`/`parent_span_id`. The
236        // frame is re-entered on every poll via `Instrumented<F>`-style
237        // push/pop in `Future::poll`.
238        let scope_seed = ScopeFrameBuilder::new()
239            .context()
240            .trace_id(trace_id.clone())
241            .span_id(span_id.clone())
242            .parent_span_id(parent_span.clone())
243            .into_frame();
244
245        ObsHttpFuture {
246            inner: inner_fut,
247            started: Some(started),
248            route,
249            method,
250            trace_id,
251            span_id,
252            parent_span,
253            status_classifier,
254            emit_metrics,
255            observer_override,
256            scope_seed: Some(scope_seed),
257        }
258    }
259}
260
261pin_project! {
262    /// Future returned by [`ObsHttpService::call`].
263    pub struct ObsHttpFuture<F> {
264        #[pin]
265        inner: F,
266        started: Option<Instant>,
267        route: String,
268        method: String,
269        trace_id: String,
270        span_id: String,
271        parent_span: String,
272        status_classifier: StatusFn,
273        emit_metrics: bool,
274        observer_override: Option<Arc<dyn Observer>>,
275        // Cloned per poll into a fresh `ScopeFrame`; the frame is
276        // pushed at poll-start and popped at poll-end so handler emits
277        // inherit the request's trace context (spec 94 P0-A).
278        scope_seed: Option<ScopeFrame>,
279    }
280}
281
282impl<F, ResBody, E> Future for ObsHttpFuture<F>
283where
284    F: Future<Output = Result<http::Response<ResBody>, E>>,
285{
286    type Output = F::Output;
287
288    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
289        let mut this = self.project();
290        // Spec 94 § 2.1 / P0-A: push a fresh `obs::scope!` frame per
291        // poll so handler emits inherit trace context across `.await`
292        // and thread migration. The guard pops the frame on drop.
293        let _scope_guard = this
294            .scope_seed
295            .as_ref()
296            .map(|seed| RequestScopeGuard::push(seed.clone()));
297        // If a per-request observer override is present, install it
298        // for this poll. Otherwise just poll directly.
299        let result = if let Some(o) = this.observer_override.clone() {
300            with_observer_task_sync(o, || this.inner.as_mut().poll(cx))
301        } else {
302            this.inner.as_mut().poll(cx)
303        };
304        match result {
305            Poll::Pending => Poll::Pending,
306            Poll::Ready(out) => {
307                let started = this.started.take().unwrap_or_else(Instant::now);
308                let elapsed_ms = started.elapsed().as_millis() as u64;
309                match &out {
310                    Ok(resp) => {
311                        if *this.emit_metrics {
312                            let status = resp.status().as_u16();
313                            let class = (this.status_classifier)(status);
314                            emit_request_completed(
315                                this.route,
316                                this.method,
317                                class,
318                                elapsed_ms,
319                                this.trace_id,
320                                this.span_id,
321                                this.parent_span,
322                                this.observer_override.as_ref(),
323                            );
324                        }
325                    }
326                    Err(_) => {
327                        if *this.emit_metrics {
328                            emit_request_completed(
329                                this.route,
330                                this.method,
331                                "err",
332                                elapsed_ms,
333                                this.trace_id,
334                                this.span_id,
335                                this.parent_span,
336                                this.observer_override.as_ref(),
337                            );
338                        }
339                    }
340                }
341                Poll::Ready(out)
342            }
343        }
344    }
345}
346
347/// Per-poll RAII guard that pushes a request-scope frame at poll-start
348/// and pops it at poll-end. Mirrors the `Instrumented<F>` pattern in
349/// `obs-core`'s instrumented module so handler emits inherit
350/// `trace_id`/`span_id`/`parent_span_id` across thread migration.
351/// Spec 94 § 2.1.
352struct RequestScopeGuard;
353
354impl RequestScopeGuard {
355    fn push(frame: ScopeFrame) -> Self {
356        obs_core::scope::push_frame_pub(frame);
357        Self
358    }
359}
360
361impl Drop for RequestScopeGuard {
362    fn drop(&mut self) {
363        let _ = obs_core::scope::pop_frame_pub();
364    }
365}
366
367/// Encode a buffa message into a `Vec<u8>` payload. Spec 94 P1-B / P1-G.
368fn encode_into<M: ::buffa::Message>(msg: &M, out: &mut Vec<u8>) {
369    let mut cache = ::buffa::SizeCache::default();
370    let size = msg.compute_size(&mut cache);
371    let mut buf = BytesMut::with_capacity(size as usize);
372    msg.write_to(&mut cache, &mut buf);
373    out.clear();
374    out.extend_from_slice(&buf);
375}
376
377fn emit_request_started(
378    route: &str,
379    method: &str,
380    trace_id: &str,
381    parent_span: &str,
382    observer: Option<&Arc<dyn Observer>>,
383) {
384    // Spec 94 P1-G: encode typed `ObsHttpRequestStarted` via buffa
385    // rather than overloading `env.labels`. Mirror `route`/`method`
386    // onto labels for downstream filter operators (D7-4).
387    let typed = ObsHttpRequestStarted {
388        method: method.to_string(),
389        route: route.to_string(),
390        __buffa_unknown_fields: Default::default(),
391    };
392    let mut env = ObsEnvelope {
393        full_name: "obs.v1.ObsHttpRequestStarted".to_string(),
394        tier: ::buffa::EnumValue::Known(obs_proto::obs::v1::Tier::TIER_LOG),
395        sev: ::buffa::EnumValue::Known(obs_proto::obs::v1::Severity::SEVERITY_INFO),
396        trace_id: trace_id.to_string(),
397        parent_span_id: parent_span.to_string(),
398        ..Default::default()
399    };
400    encode_into(&typed, &mut env.payload);
401    env.labels.insert("route".to_string(), route.to_string());
402    env.labels.insert("method".to_string(), method.to_string());
403    if let Some(o) = observer {
404        o.emit_envelope(env);
405    } else {
406        obs_core::observer().emit_envelope(env);
407    }
408}
409
410#[allow(clippy::too_many_arguments)]
411fn emit_request_completed(
412    route: &str,
413    method: &str,
414    status_class: &str,
415    latency_ms: u64,
416    trace_id: &str,
417    span_id: &str,
418    parent_span: &str,
419    observer: Option<&Arc<dyn Observer>>,
420) {
421    // Spec 94 § 3.7 / P1-G: encode typed `ObsHttpRequestCompleted` so
422    // the MEASUREMENT fields (`latency_ms`, `bytes_out`) live in the
423    // typed payload — `project_metrics` can then dispatch them. The
424    // bytes_out counter is currently unknown at this layer (we'd need
425    // a wrapping body), so it ships as 0 until that plumbing lands.
426    let typed = ObsHttpRequestCompleted {
427        method: method.to_string(),
428        route: route.to_string(),
429        status_class: status_class.to_string(),
430        latency_ms,
431        bytes_out: 0,
432        __buffa_unknown_fields: Default::default(),
433    };
434    let mut env = ObsEnvelope {
435        full_name: "obs.v1.ObsHttpRequestCompleted".to_string(),
436        tier: ::buffa::EnumValue::Known(obs_proto::obs::v1::Tier::TIER_LOG),
437        sev: ::buffa::EnumValue::Known(obs_proto::obs::v1::Severity::SEVERITY_INFO),
438        trace_id: trace_id.to_string(),
439        span_id: span_id.to_string(),
440        parent_span_id: parent_span.to_string(),
441        ..Default::default()
442    };
443    encode_into(&typed, &mut env.payload);
444    // Mirror low-cardinality labels for filter operators (D7-4).
445    // `latency_ms` and `bytes_out` live only in the typed payload now.
446    env.labels.insert("route".to_string(), route.to_string());
447    env.labels.insert("method".to_string(), method.to_string());
448    env.labels
449        .insert("status_class".to_string(), status_class.to_string());
450    if let Some(o) = observer {
451        o.emit_envelope(env);
452    } else {
453        obs_core::observer().emit_envelope(env);
454    }
455}