tracing_causality/
channel.rs1use crate::Update;
2use std::{
3 collections::BTreeSet,
4 fmt::Debug,
5 sync::{
6 atomic::{AtomicBool, Ordering},
7 Arc,
8 },
9};
10use tracing_core::span::Id;
11
12pub(crate) fn bounded(id: Id, capacity: usize) -> (Sender, Updates) {
14 let (sender, receiver) = flume::bounded(capacity);
15 let overflow_flag = OverflowFlag::default();
16 let sender = Sender {
17 id,
18 sender,
19 overflow_flag: overflow_flag.clone(),
20 };
21 let updates = Updates {
22 receiver,
23 overflow_flag,
24 };
25 (sender, updates)
26}
27
/// A shared, sticky marker recording that a channel overflowed.
///
/// Every clone points at the same atomic, so raising the flag through one
/// clone is visible through all others. No method clears it.
#[derive(Clone, Debug, Default)]
struct OverflowFlag {
    /// `true` once any clone has called [`OverflowFlag::set`].
    flag: Arc<AtomicBool>,
}

impl OverflowFlag {
    /// Raises the flag. Calling this repeatedly has no further effect.
    fn set(&self) {
        AtomicBool::store(&self.flag, true, Ordering::Release);
    }

    /// Returns whether any clone has raised the flag.
    fn check(&self) -> bool {
        AtomicBool::load(&self.flag, Ordering::Acquire)
    }
}
42
#[cfg(test)]
mod test_overflow_flag {
    use super::OverflowFlag;

    /// The flag starts cleared, becomes set after `set`, and stays set when
    /// `set` is called again.
    #[test]
    fn set_and_check() {
        let flag = OverflowFlag::default();
        assert!(!flag.check());

        flag.set();
        assert!(flag.check());

        // Setting an already-set flag is a no-op.
        flag.set();
        assert!(flag.check());
    }

    /// Clones share the same underlying flag.
    #[test]
    fn clones_share_state() {
        let flag = OverflowFlag::default();
        let clone = flag.clone();
        assert!(!flag.check());

        clone.set();
        assert!(flag.check());
    }
}
60
61pub struct Updates {
63 receiver: flume::Receiver<Update>,
64 overflow_flag: OverflowFlag,
65}
66
67impl Updates {
68 pub fn is_empty(&self) -> bool {
69 self.receiver.is_empty()
70 }
71
72 pub fn is_disconnected(&self) -> bool {
76 self.receiver.is_disconnected()
77 }
78
79 pub fn next(&self) -> Option<Update> {
80 self.receiver.try_recv().ok()
81 }
82
83 pub fn into_iter(self) -> impl Iterator<Item = Update> {
84 self.receiver
85 .into_iter()
86 .take_while(move |_| !self.overflow_flag.check())
87 }
88
89 pub fn into_stream(self) -> impl futures_core::stream::Stream {
90 use futures::stream::StreamExt;
91 self.receiver
92 .into_stream()
93 .take_while(move |_| std::future::ready(!self.overflow_flag.check()))
94 }
95
96 pub fn drain(&self) -> impl ExactSizeIterator<Item = Update> + '_ {
99 self.receiver.drain()
100 }
101}
102
103impl Default for Updates {
104 fn default() -> Self {
105 let (_, receiver) = flume::bounded(0);
106 let overflow_flag = OverflowFlag::default();
107
108 Self {
109 receiver,
110 overflow_flag,
111 }
112 }
113}
114
115#[derive(Clone, Debug)]
116pub(crate) struct Sender {
117 id: Id,
118 sender: flume::Sender<Update>,
119 overflow_flag: OverflowFlag,
120}
121
122impl Sender {
123 fn try_send(&self, update: Update) -> Result<(), ()> {
124 use flume::TrySendError::{Disconnected, Full};
125
126 self.sender
127 .try_send(update)
128 .map_err(|err| match err {
129 Full(_) => {
130 self.overflow_flag.set();
132 }
133 Disconnected(_) => {
134 }
136 })
137 .map(|_| {})
138 }
139
140 pub(crate) fn broadcast(listeners: &mut BTreeSet<Self>, update: Update) {
141 listeners.retain(|listener| listener.try_send(update.clone()).is_ok());
142 }
143}
144
145impl std::hash::Hash for Sender {
146 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
147 self.id.hash(state);
148 }
149}
150
151impl Eq for Sender {}
152
153impl Ord for Sender {
154 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
155 self.id.into_u64().cmp(&other.id.into_u64())
156 }
157}
158
159impl PartialOrd for Sender {
160 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
161 self.id.into_u64().partial_cmp(&other.id.into_u64())
162 }
163}
164
165impl PartialEq for Sender {
166 fn eq(&self, other: &Self) -> bool {
167 self.id.eq(&other.id)
168 }
169}
170
#[cfg(test)]
mod test_sender {
    use super::*;
    use tracing_core::span::Id;
    use tracing_subscriber::{prelude::*, registry::Registry};

    use crate as tracing_causality;

    /// Builds an `Update::OpenDirect` linking a fresh `cause` span to a
    /// `consequence` span opened inside it.
    ///
    /// Call this while a subscriber is installed as the default. If the spans
    /// are disabled, the `unwrap`s below panic.
    fn open_direct_update() -> Update {
        let cause = tracing::trace_span!("cause");
        let consequence = cause.in_scope(|| tracing::trace_span!("consequence"));

        Update::OpenDirect {
            cause: tracing_causality::Span {
                id: cause.id().unwrap(),
                metadata: cause.metadata().unwrap(),
            },
            consequence: tracing_causality::Span {
                id: consequence.id().unwrap(),
                metadata: consequence.metadata().unwrap(),
            },
        }
    }

    #[test]
    fn should_disconnect_if_sender_dropped() {
        let _guard = Registry::default().set_default();

        let (sender, updates) = bounded(Id::from_u64(1), 1);
        assert!(!updates.is_disconnected());

        let update = open_direct_update();
        sender
            .try_send(update.clone())
            .expect("sending should succeed");
        assert!(!updates.is_disconnected());

        drop(sender);
        assert!(updates.is_disconnected());

        // Updates buffered before disconnection remain receivable.
        assert_eq!(updates.next(), Some(update));
    }

    #[test]
    fn try_send_success() {
        let _guard = Registry::default().set_default();

        let (sender, updates) = bounded(Id::from_u64(1), 1);
        let update = open_direct_update();

        assert!(sender.try_send(update.clone()).is_ok());
        assert_eq!(updates.next(), Some(update));
        assert!(updates.is_empty());
    }

    #[test]
    fn try_send_err_disconnected() {
        let _guard = Registry::default().set_default();

        // The `Updates` half is dropped immediately, disconnecting the channel.
        let (sender, _) = bounded(Id::from_u64(1), 1);

        assert!(sender.try_send(open_direct_update()).is_err());
    }

    #[test]
    fn try_send_err_full() {
        let _guard = Registry::default().set_default();

        // A zero-capacity channel with no waiting receiver rejects `try_send`
        // as full.
        let (sender, updates) = bounded(Id::from_u64(1), 0);
        let update = open_direct_update();

        assert!(!updates.overflow_flag.check());
        assert!(sender.try_send(update).is_err());
        assert!(updates.overflow_flag.check());
        assert_eq!(updates.next(), None);
    }

    #[test]
    fn broadcast_removes_failed_listeners() {
        let _guard = Registry::default().set_default();

        let (ok_sender, ok_updates) = bounded(Id::from_u64(1), 1);
        let (full_sender, full_updates) = bounded(Id::from_u64(2), 0);

        let mut listeners = BTreeSet::new();
        listeners.insert(ok_sender);
        listeners.insert(full_sender);

        let update = open_direct_update();
        Sender::broadcast(&mut listeners, update.clone());

        // The healthy listener stays and receives the update.
        assert_eq!(listeners.len(), 1);
        assert_eq!(ok_updates.next(), Some(update));

        // The full listener is flagged, removed and dropped, which disconnects
        // its receiver.
        assert!(full_updates.overflow_flag.check());
        assert!(full_updates.is_disconnected());
    }
}