1use std::io;
2use std::marker::PhantomData;
3use std::sync::Mutex;
4use std::time::{Duration, Instant};
5
6use tracing::Subscriber;
7use tracing_subscriber::filter::EnvFilter;
8use tracing_subscriber::fmt::format::{DefaultFields, Format, Full};
9use tracing_subscriber::fmt::{self, FormatFields, MakeWriter};
10use tracing_subscriber::registry::LookupSpan;
11
12use crate::capture::CaptureMakeWriter;
13use crate::layer::{SamplingLayer, State, Stats};
14use crate::reservoir::Reservoir;
15
16pub struct SamplingLayerBuilder<S, N = DefaultFields, E = Format<Full>, W = fn() -> io::Stderr> {
20 budgets: Vec<(EnvFilter, u64)>,
21 bucket_duration: Duration,
22 writer: W,
23 fmt_layer: fmt::Layer<S, N, E, CaptureMakeWriter>,
24 _subscriber: PhantomData<fn(S)>,
25}
26
27impl<S> SamplingLayer<S> {
28 pub fn builder() -> SamplingLayerBuilder<S> {
30 SamplingLayerBuilder {
31 budgets: Vec::new(),
32 bucket_duration: Duration::from_millis(50),
33 writer: io::stderr as fn() -> io::Stderr,
34 fmt_layer: fmt::Layer::default().with_writer(CaptureMakeWriter::default()),
35 _subscriber: PhantomData,
36 }
37 }
38}
39
40impl<S, N, E, W> SamplingLayerBuilder<S, N, E, W> {
41 pub fn budget(mut self, filter: EnvFilter, limit_per_second: u64) -> Self {
45 self.budgets.push((filter, limit_per_second));
46 self
47 }
48
49 pub fn bucket_duration(mut self, duration: Duration) -> Self {
51 self.bucket_duration = duration;
52 self
53 }
54
55 pub fn writer<W2>(self, writer: W2) -> SamplingLayerBuilder<S, N, E, W2> {
57 SamplingLayerBuilder {
58 budgets: self.budgets,
59 bucket_duration: self.bucket_duration,
60 writer,
61 fmt_layer: self.fmt_layer,
62 _subscriber: PhantomData,
63 }
64 }
65}
66
67impl<S, N, E, W> SamplingLayerBuilder<S, N, E, W>
68where
69 S: Subscriber + for<'a> LookupSpan<'a>,
70 N: for<'writer> FormatFields<'writer> + 'static,
71 E: fmt::FormatEvent<S, N> + 'static,
72{
73 pub fn event_format<E2>(self, e: E2) -> SamplingLayerBuilder<S, N, E2, W>
75 where
76 E2: fmt::FormatEvent<S, N> + 'static,
77 {
78 SamplingLayerBuilder {
79 budgets: self.budgets,
80 bucket_duration: self.bucket_duration,
81 writer: self.writer,
82 fmt_layer: self.fmt_layer.event_format(e),
83 _subscriber: PhantomData,
84 }
85 }
86
87 pub fn map_event_format<E2>(self, f: impl FnOnce(E) -> E2) -> SamplingLayerBuilder<S, N, E2, W>
89 where
90 E2: fmt::FormatEvent<S, N> + 'static,
91 {
92 SamplingLayerBuilder {
93 budgets: self.budgets,
94 bucket_duration: self.bucket_duration,
95 writer: self.writer,
96 fmt_layer: self.fmt_layer.map_event_format(f),
97 _subscriber: PhantomData,
98 }
99 }
100
101 pub fn fmt_fields<N2>(self, fmt_fields: N2) -> SamplingLayerBuilder<S, N2, E, W>
103 where
104 N2: for<'writer> FormatFields<'writer> + 'static,
105 {
106 SamplingLayerBuilder {
107 budgets: self.budgets,
108 bucket_duration: self.bucket_duration,
109 writer: self.writer,
110 fmt_layer: self.fmt_layer.fmt_fields(fmt_fields),
111 _subscriber: PhantomData,
112 }
113 }
114}
115
116impl<S, N, L, T, W> SamplingLayerBuilder<S, N, Format<L, T>, W>
117where
118 N: for<'writer> FormatFields<'writer> + 'static,
119{
120 pub fn without_time(self) -> SamplingLayerBuilder<S, N, Format<L, ()>, W> {
122 SamplingLayerBuilder {
123 budgets: self.budgets,
124 bucket_duration: self.bucket_duration,
125 writer: self.writer,
126 fmt_layer: self.fmt_layer.without_time(),
127 _subscriber: PhantomData,
128 }
129 }
130
131 pub fn with_target(self, display_target: bool) -> Self {
133 SamplingLayerBuilder {
134 fmt_layer: self.fmt_layer.with_target(display_target),
135 ..self
136 }
137 }
138
139 pub fn with_level(self, display_level: bool) -> Self {
141 SamplingLayerBuilder {
142 fmt_layer: self.fmt_layer.with_level(display_level),
143 ..self
144 }
145 }
146
147 pub fn compact(
149 self,
150 ) -> SamplingLayerBuilder<S, N, Format<tracing_subscriber::fmt::format::Compact, T>, W>
151 where
152 N: for<'writer> FormatFields<'writer> + 'static,
153 {
154 SamplingLayerBuilder {
155 budgets: self.budgets,
156 bucket_duration: self.bucket_duration,
157 writer: self.writer,
158 fmt_layer: self.fmt_layer.compact(),
159 _subscriber: PhantomData,
160 }
161 }
162}
163
164impl<S, N, E, W> SamplingLayerBuilder<S, N, E, W>
165where
166 W: for<'a> MakeWriter<'a> + 'static,
167 S: Subscriber + for<'a> LookupSpan<'a>,
168 N: for<'writer> FormatFields<'writer> + 'static,
169 E: fmt::FormatEvent<S, N> + 'static,
170{
171 pub fn build(self) -> (SamplingLayer<S, N, E, W>, Stats) {
174 assert!(
175 !self.bucket_duration.is_zero(),
176 "bucket_duration must be > 0"
177 );
178
179 let bucket_secs = self.bucket_duration.as_secs_f64();
180 let mut filters = Vec::new();
181 let mut reservoirs = Vec::new();
182 for (filter, limit_per_second) in self.budgets {
183 let limit_per_bucket = (limit_per_second as f64 * bucket_secs).ceil() as usize;
184 if limit_per_bucket == 0 {
185 continue;
186 }
187 filters.push(filter);
188 reservoirs.push(Reservoir::new(limit_per_bucket));
189 }
190
191 let now = Instant::now();
192 let stats = Stats::new();
193 let layer = SamplingLayer {
194 filters,
195 state: Mutex::new(State {
196 bucket_start: now,
197 seq: 0,
198 reservoirs,
199 pending: Vec::new().into_iter(),
200 last_release: now,
201 }),
202 bucket_duration: self.bucket_duration,
203 writer: self.writer,
204 fmt_layer: self.fmt_layer,
205 stats: stats.clone(),
206 _subscriber: PhantomData,
207 };
208 (layer, stats)
209 }
210}