use crate::Update;
use std::{
collections::BTreeSet,
fmt::Debug,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use tracing_core::span::Id;
/// Creates a connected [`Sender`] / [`Updates`] pair over a bounded channel.
///
/// * `id` is stored on the returned [`Sender`]; it is the only thing the
///   sender's equality, ordering and hashing look at.
/// * `capacity` is handed to `flume::bounded` unchanged.
///
/// Both halves hold a clone of the same [`OverflowFlag`], so a flag raised by
/// the sender is visible to the receiver.
pub(crate) fn bounded(id: Id, capacity: usize) -> (Sender, Updates) {
    let overflow_flag = OverflowFlag::default();
    let (tx, rx) = flume::bounded(capacity);
    (
        Sender {
            id,
            sender: tx,
            // The sender keeps a clone; the original moves into `Updates`.
            overflow_flag: overflow_flag.clone(),
        },
        Updates {
            receiver: rx,
            overflow_flag,
        },
    )
}
/// A boolean flag that starts out `false` and can be raised through any clone.
///
/// Cloning copies the `Arc`, not the boolean, so every clone observes the same
/// underlying `AtomicBool`. Nothing in this type ever resets the flag.
#[derive(Clone, Debug, Default)]
struct OverflowFlag {
    /// Shared storage for the flag.
    flag: Arc<AtomicBool>,
}

impl OverflowFlag {
    /// Raises the flag (stores `true` with `Release` ordering).
    fn set(&self) {
        AtomicBool::store(&self.flag, true, Ordering::Release);
    }

    /// Returns the current value of the flag (loaded with `Acquire` ordering).
    fn check(&self) -> bool {
        AtomicBool::load(&self.flag, Ordering::Acquire)
    }
}
#[cfg(test)]
mod test_overflow_flag {
    use super::OverflowFlag;

    /// A fresh flag is clear, `set` raises it, and setting it again leaves it
    /// raised.
    #[test]
    fn set_and_check() {
        let flag = OverflowFlag::default();
        assert!(!flag.check());
        flag.set();
        assert!(flag.check());
        // A second `set` must not toggle the flag back.
        flag.set();
        assert!(flag.check());
    }

    /// Clones share one underlying flag; `bounded` relies on this so that the
    /// receiver sees an overflow recorded by the sender.
    #[test]
    fn clones_share_state() {
        let flag = OverflowFlag::default();
        let clone = flag.clone();
        clone.set();
        assert!(flag.check());
    }
}
/// The receiving half of an update channel, as produced by `bounded`.
///
/// Pairs a `flume` receiver of `Update`s with the overflow flag it shares
/// with the matching `Sender`.
pub struct Updates {
// Channel end from which updates are read.
receiver: flume::Receiver<Update>,
// Raised by `Sender::try_send` when the channel is full; `into_iter` and
// `into_stream` stop yielding once it is raised.
overflow_flag: OverflowFlag,
}
impl Updates {
    /// Returns `true` if no updates are currently queued in the channel.
    pub fn is_empty(&self) -> bool {
        self.receiver.is_empty()
    }

    /// Returns `true` if the underlying receiver reports the channel as
    /// disconnected, i.e. once the sending side has been dropped.
    pub fn is_disconnected(&self) -> bool {
        self.receiver.is_disconnected()
    }

    /// Returns the next queued update without blocking.
    ///
    /// Returns `None` whenever the non-blocking receive fails; an empty
    /// channel and a disconnected channel are not distinguished. The overflow
    /// flag is not consulted.
    pub fn next(&self) -> Option<Update> {
        self.receiver.try_recv().ok()
    }

    /// Consumes the receiver and returns an iterator over incoming updates.
    ///
    /// The overflow flag is checked each time an update is received. Once the
    /// flag is raised, that update is discarded and the iterator ends. The
    /// iterator also ends when the underlying `flume` iterator does.
    pub fn into_iter(self) -> impl Iterator<Item = Update> {
        // Split `self` up front so the closure owns only the flag.
        let Self {
            receiver,
            overflow_flag,
        } = self;
        receiver
            .into_iter()
            .take_while(move |_| !overflow_flag.check())
    }

    /// Consumes the receiver and returns an asynchronous stream of updates.
    ///
    /// This is the stream counterpart of [`Updates::into_iter`]: the stream
    /// ends as soon as an update arrives while the overflow flag is raised
    /// (that update is discarded), or when the underlying `flume` stream ends.
    ///
    /// The return type names `Item = Update` so that callers can actually use
    /// the yielded values as updates.
    pub fn into_stream(self) -> impl futures_core::stream::Stream<Item = Update> {
        use futures::stream::StreamExt;

        let Self {
            receiver,
            overflow_flag,
        } = self;
        receiver
            .into_stream()
            .take_while(move |_| std::future::ready(!overflow_flag.check()))
    }

    /// Returns an iterator that removes the currently queued updates from the
    /// channel.
    ///
    /// The overflow flag is not consulted.
    pub fn drain(&self) -> impl ExactSizeIterator<Item = Update> + '_ {
        self.receiver.drain()
    }
}
impl Default for Updates {
fn default() -> Self {
let (_, receiver) = flume::bounded(0);
let overflow_flag = OverflowFlag::default();
Self {
receiver,
overflow_flag,
}
}
}
/// The sending half of an update channel, as produced by `bounded`.
///
/// The `PartialEq`, `Eq`, `Ord`, `PartialOrd` and `Hash` impls in this file
/// look only at `id`; the channel and the flag play no part in them.
#[derive(Clone, Debug)]
pub(crate) struct Sender {
// Identifier supplied to `bounded`; the sole key for comparison and hashing.
id: Id,
// Channel end through which updates are sent.
sender: flume::Sender<Update>,
// Raised by `try_send` when the channel is full; shared with the `Updates` half.
overflow_flag: OverflowFlag,
}
impl Sender {
fn try_send(&self, update: Update) -> Result<(), ()> {
use flume::TrySendError::{Disconnected, Full};
self.sender
.try_send(update)
.map_err(|err| match err {
Full(_) => {
self.overflow_flag.set();
}
Disconnected(_) => {
}
})
.map(|_| {})
}
pub(crate) fn broadcast(listeners: &mut BTreeSet<Self>, update: Update) {
listeners.retain(|listener| listener.try_send(update.clone()).is_ok());
}
}
/// Hashes a `Sender` by its `id` alone, matching the `PartialEq` impl.
impl std::hash::Hash for Sender {
    fn hash<H>(&self, state: &mut H)
    where
        H: std::hash::Hasher,
    {
        std::hash::Hash::hash(&self.id, state)
    }
}
// Marker impl: equality is the `id` comparison defined by `PartialEq for Sender`.
impl Eq for Sender {}
/// Orders `Sender`s by the numeric value of their `id`.
impl Ord for Sender {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let (lhs, rhs) = (self.id.into_u64(), other.id.into_u64());
        lhs.cmp(&rhs)
    }
}
/// Partial ordering for `Sender`; always `Some`, and always agrees with `Ord`.
impl PartialOrd for Sender {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Delegate to `Ord` so the two orderings cannot drift apart.
        Some(Ord::cmp(self, other))
    }
}
/// Two `Sender`s are equal exactly when their `id`s are equal.
impl PartialEq for Sender {
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id
    }
}
#[cfg(test)]
mod test_sender {
    use super::*;
    use tracing_core::span::Id;
    use tracing_subscriber::{prelude::*, registry::Registry};

    use crate as tracing_causality;

    /// Builds an `Update::OpenDirect` from two freshly created spans, with
    /// "consequence" created inside the scope of "cause".
    ///
    /// The caller must have installed a default subscriber first; otherwise
    /// the spans have no id and the `unwrap`s below panic. The returned update
    /// holds only the span ids and their static metadata, so the spans need
    /// not outlive this function.
    fn open_direct_update() -> Update {
        let cause = tracing::trace_span!("cause");
        let consequence = cause.in_scope(|| tracing::trace_span!("consequence"));
        Update::OpenDirect {
            cause: tracing_causality::Span {
                id: cause.id().unwrap(),
                metadata: cause.metadata().unwrap(),
            },
            consequence: tracing_causality::Span {
                id: consequence.id().unwrap(),
                metadata: consequence.metadata().unwrap(),
            },
        }
    }

    /// Dropping the sender disconnects the channel, but an update that was
    /// already queued can still be received.
    #[test]
    fn should_disconnect_if_sender_dropped() {
        let _guard = Registry::default().set_default();
        let (sender, updates) = bounded(Id::from_u64(1), 1);
        assert!(!updates.is_disconnected());

        let update = open_direct_update();
        sender
            .try_send(update.clone())
            .expect("sending should succeed");
        assert!(!updates.is_disconnected());

        drop(sender);
        assert!(updates.is_disconnected());
        assert_eq!(updates.next(), Some(update));
    }

    /// A send into a channel with spare capacity succeeds and the update is
    /// received exactly once.
    #[test]
    fn try_send_success() {
        let _guard = Registry::default().set_default();
        let (sender, updates) = bounded(Id::from_u64(1), 1);

        let update = open_direct_update();
        let send_result = sender.try_send(update.clone());
        assert!(send_result.is_ok());
        assert_eq!(updates.next(), Some(update));
        assert!(updates.is_empty());
    }

    /// Sending fails once the receiving half has been dropped.
    #[test]
    fn try_send_err_disconnected() {
        let _guard = Registry::default().set_default();
        // The `_` pattern drops the `Updates` half immediately.
        let (sender, _) = bounded(Id::from_u64(1), 1);

        let send_result = sender.try_send(open_direct_update());
        assert!(send_result.is_err());
    }

    /// Sending into a zero-capacity channel fails, raises the overflow flag,
    /// and leaves nothing to receive.
    #[test]
    fn try_send_err_full() {
        let _guard = Registry::default().set_default();
        let (sender, updates) = bounded(Id::from_u64(1), 0);

        assert!(!updates.overflow_flag.check());
        let send_result = sender.try_send(open_direct_update());
        assert!(send_result.is_err());
        assert!(updates.overflow_flag.check());
        assert_eq!(updates.next(), None);
    }
}