1#![deny(rust_2018_idioms, warnings)]
3
4use std::marker::PhantomData;
5use std::task::{Context, Poll};
6
7use actix_service::{
8 apply, dev::ApplyTransform, IntoServiceFactory, Service, ServiceFactory, Transform,
9};
10use futures_util::future::{ok, Either, Ready};
11use tracing_futures::{Instrument, Instrumented};
12
/// Service wrapper that asks `make_span` for an optional [`tracing::Span`] on
/// every request and, when one is returned, traces the wrapped service's work
/// under it.
#[derive(Clone)]
pub struct TracingService<S, F> {
    /// The wrapped service that requests are forwarded to.
    inner: S,
    /// Called with a reference to each request; returns the span to trace the
    /// request under, or `None` to leave the request untraced.
    make_span: F,
}
20
21impl<S, F> TracingService<S, F> {
22 pub fn new(inner: S, make_span: F) -> Self {
23 TracingService { inner, make_span }
24 }
25}
26
27impl<S, F> Service for TracingService<S, F>
28where
29 S: Service,
30 F: Fn(&S::Request) -> Option<tracing::Span>,
31{
32 type Request = S::Request;
33 type Response = S::Response;
34 type Error = S::Error;
35 type Future = Either<S::Future, Instrumented<S::Future>>;
36
37 fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
38 self.inner.poll_ready(ctx)
39 }
40
41 fn call(&mut self, req: Self::Request) -> Self::Future {
42 let span = (self.make_span)(&req);
43 let _enter = span.as_ref().map(|s| s.enter());
44
45 let fut = self.inner.call(req);
46
47 if let Some(span) = span
49 .clone()
50 .map(|span| tracing::span!(parent: &span, tracing::Level::INFO, "future"))
51 {
52 Either::Right(fut.instrument(span))
53 } else {
54 Either::Left(fut)
55 }
56 }
57}
58
/// Transform that wraps a service of type `S` in a [`TracingService`], handing
/// each wrapped service its own clone of `make_span`.
pub struct TracingTransform<S, U, F> {
    /// Span factory cloned into every service this transform produces.
    make_span: F,
    /// Ties the otherwise unused `S` and `U` parameters to the type without
    /// storing values of either.
    _p: PhantomData<fn(S, U)>,
}
66
67impl<S, U, F> TracingTransform<S, U, F> {
68 pub fn new(make_span: F) -> Self {
69 TracingTransform {
70 make_span,
71 _p: PhantomData,
72 }
73 }
74}
75
76impl<S, U, F> Transform<S> for TracingTransform<S, U, F>
77where
78 S: Service,
79 U: ServiceFactory<
80 Request = S::Request,
81 Response = S::Response,
82 Error = S::Error,
83 Service = S,
84 >,
85 F: Fn(&S::Request) -> Option<tracing::Span> + Clone,
86{
87 type Request = S::Request;
88 type Response = S::Response;
89 type Error = S::Error;
90 type Transform = TracingService<S, F>;
91 type InitError = U::InitError;
92 type Future = Ready<Result<Self::Transform, Self::InitError>>;
93
94 fn new_transform(&self, service: S) -> Self::Future {
95 ok(TracingService::new(service, self.make_span.clone()))
96 }
97}
98
99pub fn trace<S, U, F>(
113 service_factory: U,
114 make_span: F,
115) -> ApplyTransform<TracingTransform<S::Service, S, F>, S>
116where
117 S: ServiceFactory,
118 F: Fn(&S::Request) -> Option<tracing::Span> + Clone,
119 U: IntoServiceFactory<S>,
120{
121 apply(
122 TracingTransform::new(make_span),
123 service_factory.into_factory(),
124 )
125}
126
#[cfg(test)]
mod test {
    use super::*;

    use std::cell::RefCell;
    use std::collections::{BTreeMap, BTreeSet};
    use std::sync::{Arc, RwLock};

    use actix_service::{fn_factory, fn_service};
    use slab::Slab;
    use tracing::{span, Event, Level, Metadata, Subscriber};

    thread_local! {
        // Stack of span ids entered on this thread, most recently entered last.
        static SPAN: RefCell<Vec<span::Id>> = RefCell::new(Vec::new());
    }

    /// Counters collected by `TestSubscriber`, keyed by the span id's `u64`.
    #[derive(Default)]
    struct Stats {
        /// Ids of every span that was entered at least once.
        entered_spans: BTreeSet<u64>,
        /// Ids of every span that was exited at least once.
        exited_spans: BTreeSet<u64>,
        /// Number of events attributed to each span id.
        events_count: BTreeMap<u64, usize>,
    }

    /// Mutable state shared between clones of `TestSubscriber`.
    #[derive(Default)]
    struct Inner {
        /// Metadata of every span created so far; the slab key determines the
        /// span id.
        spans: Slab<&'static Metadata<'static>>,
        stats: Stats,
    }

    /// Minimal subscriber that records which spans were entered and exited and
    /// how many events each span received.
    #[derive(Clone, Default)]
    struct TestSubscriber {
        inner: Arc<RwLock<Inner>>,
    }

    impl Subscriber for TestSubscriber {
        fn enabled(&self, _metadata: &Metadata<'_>) -> bool {
            true
        }

        fn new_span(&self, attrs: &span::Attributes<'_>) -> span::Id {
            let slot = {
                let mut state = self.inner.write().unwrap();
                state.spans.insert(attrs.metadata())
            };
            // The `+ 1` guarantees the resulting id is never zero.
            span::Id::from_u64(slot as u64 + 1)
        }

        fn record(&self, _span: &span::Id, _values: &span::Record<'_>) {}

        fn record_follows_from(&self, _span: &span::Id, _follows: &span::Id) {}

        fn event(&self, event: &Event<'_>) {
            // Prefer the event's explicit parent; otherwise attribute it to the
            // innermost span entered on this thread. Panics if neither exists.
            let owner = match event.parent() {
                Some(explicit) => Some(explicit.clone()),
                None => SPAN.with(|stack| stack.borrow().last().cloned()),
            }
            .unwrap();

            let mut state = self.inner.write().unwrap();
            *state
                .stats
                .events_count
                .entry(owner.into_u64())
                .or_insert(0) += 1;
        }

        fn enter(&self, id: &span::Id) {
            {
                let mut state = self.inner.write().unwrap();
                state.stats.entered_spans.insert(id.into_u64());
            }

            SPAN.with(|stack| {
                stack.borrow_mut().push(id.clone());
            });
        }

        fn exit(&self, id: &span::Id) {
            {
                let mut state = self.inner.write().unwrap();
                state.stats.exited_spans.insert(id.into_u64());
            }

            SPAN.with(|stack| {
                // Spans must be exited in the reverse order they were entered.
                let popped = stack.borrow_mut().pop();
                let leaving = popped.expect("told to exit span when not in span");
                assert_eq!(
                    &leaving, id,
                    "told to exit span that was not most recently entered"
                );
            });
        }
    }

    /// A traced service call enters and exits the supplied span, and the event
    /// emitted by the wrapped service is attributed to that span.
    #[actix_rt::test]
    async fn service_call() {
        let service_factory = fn_factory(|| {
            ok::<_, ()>(fn_service(|req: &'static str| {
                tracing::event!(Level::TRACE, "It's happening - {}!", req);
                ok::<_, ()>(())
            }))
        });

        let subscriber = TestSubscriber::default();
        let _guard = tracing::subscriber::set_default(subscriber.clone());

        let span_svc = span!(Level::TRACE, "span_svc");
        let traced_factory = trace(service_factory, |_: &&str| Some(span_svc.clone()));
        let mut traced = traced_factory.new_service(()).await.unwrap();
        traced.call("boo").await.unwrap();

        let id = span_svc.id().unwrap().into_u64();
        let state = subscriber.inner.read().unwrap();
        assert!(state.stats.entered_spans.contains(&id));
        assert!(state.stats.exited_spans.contains(&id));
        assert_eq!(state.stats.events_count[&id], 1);
    }
}