use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::fmt::Debug;
use std::sync::Arc;
use parking_lot::RwLock;
use smallvec::SmallVec;
use tracing::warn;
use crate::{ChannelBuilder, ChannelId, McapWriteOptions, McapWriter, RawChannel, Sink, SinkId};
mod lazy_context;
mod subscriptions;
pub use lazy_context::LazyContext;
use subscriptions::Subscriptions;
/// Mutable state of a [`Context`]: its registered channels, its registered sinks, and the
/// subscriptions linking them. [`Context`] wraps this in a `RwLock`.
#[derive(Default)]
struct ContextInner {
    /// Every registered channel, keyed by its id.
    channels: HashMap<ChannelId, Arc<RawChannel>>,
    /// Registered channels grouped by topic, in registration order. The inline capacity of one
    /// matches the expected case of a single channel per topic (sharing a topic logs a warning).
    channels_by_topic: HashMap<String, SmallVec<[Arc<RawChannel>; 1]>>,
    /// Every registered sink, keyed by its id.
    sinks: HashMap<SinkId, Arc<dyn Sink>>,
    /// Which sinks receive which channels, either globally or per channel.
    subs: Subscriptions,
}
impl ContextInner {
    /// Returns the earliest-registered channel that is still registered on `topic`, if any.
    fn get_channel_by_topic(&self, topic: &str) -> Option<&Arc<RawChannel>> {
        self.channels_by_topic.get(topic)?.first()
    }

    /// Registers `channel` and returns the channel callers should log to.
    ///
    /// If a channel already registered on the same topic `matches` `channel`, that existing
    /// channel is returned and `channel` is not registered. Otherwise `channel` is registered,
    /// with a warning if its topic is already used by a non-matching channel. Every sink is
    /// offered it via [`Sink::add_channel`], and its sink list is refreshed from the current
    /// subscriptions before it is returned.
    fn add_channel(&mut self, channel: Arc<RawChannel>) -> Arc<RawChannel> {
        let topic = channel.topic();
        let topic_channels = self.channels_by_topic.entry(topic.to_string()).or_default();
        // Deduplicate: hand back an equivalent channel rather than registering a second one.
        if let Some(matching) = topic_channels.iter().find(|c| channel.matches(c)) {
            return matching.clone();
        }
        // Sharing a topic between non-matching channels is permitted but discouraged.
        if !topic_channels.is_empty() {
            warn!(
                "Channel with topic {topic} already exists in this context; \
                 use a unique topic for each channel"
            );
        }
        self.channels.insert(channel.id(), channel.clone());
        topic_channels.push(channel.clone());
        for sink in self.sinks.values() {
            // Auto-subscribing sinks are covered by their global subscription (see `add_sink`).
            // Other sinks opt in per channel through the return value of `add_channel`.
            if sink.add_channel(&channel) && !sink.auto_subscribe() {
                self.subs.subscribe_channels(sink, &[channel.id()]);
            }
        }
        let sinks = self.subs.get_subscribers(channel.id());
        channel.update_sinks(sinks);
        channel
    }

    /// Unregisters the channel with `channel_id`.
    ///
    /// Returns `false` if no such channel is registered. Otherwise it removes the channel from
    /// both indexes, drops its subscriptions, and calls `remove_from_context` on it. It then
    /// notifies every sink via [`Sink::remove_channel`] and returns `true`.
    fn remove_channel(&mut self, channel_id: ChannelId) -> bool {
        let Some(channel) = self.channels.remove(&channel_id) else {
            return false;
        };
        if let Some(topic_channels) = self.channels_by_topic.get_mut(channel.topic()) {
            topic_channels.retain(|c| c.id() != channel_id);
            // Drop the topic entry once empty so the map does not accumulate empty buckets.
            if topic_channels.is_empty() {
                self.channels_by_topic.remove(channel.topic());
            }
        }
        self.subs.remove_channel_subscriptions(channel.id());
        channel.remove_from_context();
        for sink in self.sinks.values() {
            sink.remove_channel(&channel);
        }
        true
    }

    /// Registers `sink`, returning `false` if a sink with the same id is already registered.
    ///
    /// Existing channels are offered to the sink via [`Sink::add_channels`]; this is skipped
    /// when there are none. An auto-subscribing sink is subscribed globally. Any other sink is
    /// subscribed to those returned channel ids that refer to registered channels.
    fn add_sink(&mut self, sink: Arc<dyn Sink>) -> bool {
        let sink_id = sink.id();
        let Entry::Vacant(entry) = self.sinks.entry(sink_id) else {
            return false;
        };
        entry.insert(sink.clone());
        let channels: Vec<_> = self.channels.values().collect();
        let ids = if !channels.is_empty() {
            sink.add_channels(&channels)
        } else {
            None
        };
        if sink.auto_subscribe() {
            if self.subs.subscribe_global(sink.clone()) {
                self.update_channel_sinks(&channels);
            }
        } else if let Some(mut ids) = ids {
            // Ignore returned ids that do not refer to a registered channel.
            ids.retain(|id| self.channels.contains_key(id));
            if !ids.is_empty() && self.subs.subscribe_channels(&sink, &ids) {
                self.update_channel_sinks_by_ids(&ids);
            }
        }
        true
    }

    /// Unregisters the sink with `sink_id`, returning whether it was registered.
    ///
    /// The sink's subscriptions are dropped. When `remove_subscriber` reports a change, every
    /// channel's sink list is refreshed.
    fn remove_sink(&mut self, sink_id: SinkId) -> bool {
        if self.subs.remove_subscriber(sink_id) {
            self.update_channel_sinks(self.channels.values());
        }
        self.sinks.remove(&sink_id).is_some()
    }

    /// Subscribes the sink with `sink_id` to the registered channels among `channel_ids`.
    ///
    /// Does nothing if the sink is not registered. Ids that do not refer to a registered
    /// channel are ignored, matching `add_sink`. Otherwise a subscription to a removed channel
    /// would outlive that channel's cleanup in `remove_channel`, and would reattach the sink if
    /// the channel were registered again.
    fn subscribe_channels(&mut self, sink_id: SinkId, channel_ids: &[ChannelId]) {
        let Some(sink) = self.sinks.get(&sink_id) else {
            return;
        };
        let registered: Vec<ChannelId> = channel_ids
            .iter()
            .copied()
            .filter(|id| self.channels.contains_key(id))
            .collect();
        if !registered.is_empty() && self.subs.subscribe_channels(sink, &registered) {
            self.update_channel_sinks_by_ids(&registered);
        }
    }

    /// Unsubscribes the sink with `sink_id` from `channel_ids`. The affected channels' sink
    /// lists are refreshed when `Subscriptions` reports a change.
    fn unsubscribe_channels(&mut self, sink_id: SinkId, channel_ids: &[ChannelId]) {
        if self.subs.unsubscribe_channels(sink_id, channel_ids) {
            self.update_channel_sinks_by_ids(channel_ids);
        }
    }

    /// Refreshes the sink lists of the channels in `channel_ids`; unregistered ids are skipped.
    fn update_channel_sinks_by_ids(&self, channel_ids: &[ChannelId]) {
        let channels = channel_ids.iter().filter_map(|id| self.channels.get(id));
        self.update_channel_sinks(channels);
    }

    /// Pushes each channel's current subscriber set to the channel via `update_sinks`.
    fn update_channel_sinks(&self, channels: impl IntoIterator<Item = impl AsRef<RawChannel>>) {
        for channel in channels {
            let channel = channel.as_ref();
            let sinks = self.subs.get_subscribers(channel.id());
            channel.update_sinks(sinks);
        }
    }

    /// Detaches every channel, then drops all topics, sinks, and subscriptions.
    ///
    /// Detaching calls `remove_from_context` on each channel and notifies every sink via
    /// [`Sink::remove_channel`].
    fn clear(&mut self) {
        for (_, channel) in self.channels.drain() {
            channel.remove_from_context();
            for sink in self.sinks.values() {
                sink.remove_channel(&channel);
            }
        }
        self.channels_by_topic.clear();
        self.sinks.clear();
        self.subs.clear();
    }
}
/// A registry of channels and sinks; messages logged on a channel reach the sinks subscribed
/// to it.
///
/// All state lives behind a read-write lock. Dropping the context detaches every registered
/// channel and drops all sinks and subscriptions.
pub struct Context(RwLock<ContextInner>);
impl Debug for Context {
    /// Renders as `Context(..)`; the inner registry is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("Context");
        tuple.finish_non_exhaustive()
    }
}
impl Context {
#[allow(clippy::new_without_default)] pub fn new() -> Arc<Self> {
Arc::new(Self(RwLock::default()))
}
pub fn get_default() -> Arc<Self> {
Arc::clone(LazyContext::get_default())
}
pub fn channel_builder(self: &Arc<Self>, topic: impl Into<String>) -> ChannelBuilder {
ChannelBuilder::new(topic).context(self)
}
pub fn mcap_writer(self: &Arc<Self>) -> McapWriter {
McapWriter::new().context(self)
}
pub fn mcap_writer_with_options(self: &Arc<Self>, options: McapWriteOptions) -> McapWriter {
McapWriter::with_options(options).context(self)
}
#[cfg(feature = "websocket")]
pub fn websocket_server(self: &Arc<Self>) -> crate::WebSocketServer {
crate::WebSocketServer::new().context(self)
}
pub fn get_channel_by_topic(&self, topic: &str) -> Option<Arc<RawChannel>> {
self.0.read().get_channel_by_topic(topic).cloned()
}
pub(crate) fn add_channel(&self, channel: Arc<RawChannel>) -> Arc<RawChannel> {
self.0.write().add_channel(channel)
}
pub(crate) fn remove_channel(&self, channel_id: ChannelId) -> bool {
self.0.write().remove_channel(channel_id)
}
#[doc(hidden)] pub fn add_sink(&self, sink: Arc<dyn Sink>) -> bool {
self.0.write().add_sink(sink)
}
#[doc(hidden)] pub fn remove_sink(&self, sink_id: SinkId) -> bool {
self.0.write().remove_sink(sink_id)
}
#[doc(hidden)] pub fn subscribe_channels(&self, sink_id: SinkId, channel_ids: &[ChannelId]) {
self.0.write().subscribe_channels(sink_id, channel_ids);
}
#[doc(hidden)] pub fn unsubscribe_channels(&self, sink_id: SinkId, channel_ids: &[ChannelId]) {
self.0.write().unsubscribe_channels(sink_id, channel_ids);
}
pub(crate) fn clear(&self) {
self.0.write().clear();
}
}
impl Drop for Context {
    /// Clears the registry when the context is dropped.
    ///
    /// Every channel has `remove_from_context` called on it and every sink is notified of its
    /// removal; then all sinks and subscriptions are dropped.
    fn drop(&mut self) {
        // `&mut self` proves no other reference to the lock exists, so the inner state is
        // reached directly instead of acquiring the write lock.
        self.0.get_mut().clear();
    }
}
#[cfg(test)]
mod tests {
    use crate::context::*;
    use crate::log_sink_set::ERROR_LOGGING_MESSAGE;
    use crate::testutil::{ErrorSink, MockSink, RecordingSink};
    use crate::{ChannelBuilder, FoxgloveError};
    use crate::{PartialMetadata, RawChannel, Schema, nanoseconds_since_epoch};
    use std::sync::Arc;
    use tracing_test::traced_test;

    /// Returns a builder bound to `ctx` with a fixed message encoding, schema, and metadata.
    ///
    /// Building twice with unchanged settings on one topic yields the same channel (see
    /// `test_add_channel_or_return_matching_channel`). Tests override a single setting to get
    /// a non-matching channel on the same topic.
    fn new_test_channel_builder(ctx: &Arc<Context>, topic: &str) -> ChannelBuilder {
        ChannelBuilder::new(topic)
            .context(ctx)
            .message_encoding("message_encoding")
            // NOTE(review): the schema payload below has trailing commas, so it is not valid
            // JSON; presumably it is never parsed by these tests.
            .schema(Schema::new(
                "name",
                "encoding",
                br#"{
"type": "object",
"properties": {
"msg": {"type": "string"},
"count": {"type": "number"},
},
}"#,
            ))
            .metadata(maplit::btreemap! {"key".to_string() => "value".to_string()})
    }

    /// Builds a raw channel on `topic` from `new_test_channel_builder`.
    fn new_test_channel(ctx: &Arc<Context>, topic: &str) -> Result<Arc<RawChannel>, FoxgloveError> {
        new_test_channel_builder(ctx, topic).build_raw()
    }

    /// Adding the same sink twice, or removing a sink that was never added, reports `false`.
    #[test]
    fn test_add_and_remove_sink() {
        let ctx = Context::new();
        let sink = Arc::new(MockSink::default());
        let sink2 = Arc::new(MockSink::default());
        let sink3 = Arc::new(MockSink::default());
        assert!(ctx.add_sink(sink.clone()));
        assert!(!ctx.add_sink(sink.clone()));
        assert!(ctx.add_sink(sink2.clone()));
        assert!(ctx.remove_sink(sink.id()));
        // sink3 was never added.
        assert!(!ctx.remove_sink(sink3.id()));
        assert!(ctx.remove_sink(sink2.id()));
    }

    /// A logged message reaches every sink, and the log time defaults to "now".
    #[traced_test]
    #[test]
    fn test_log_calls_sinks() {
        let ctx = Context::new();
        let sink1 = Arc::new(RecordingSink::new());
        let sink2 = Arc::new(RecordingSink::new());
        assert!(ctx.add_sink(sink1.clone()));
        assert!(ctx.add_sink(sink2.clone()));
        let channel = new_test_channel(&ctx, "topic").unwrap();
        let msg = b"test_message";
        let now = nanoseconds_since_epoch();
        channel.log(msg);
        assert!(!logs_contain(ERROR_LOGGING_MESSAGE));
        let messages1 = sink1.take_messages();
        let messages2 = sink2.take_messages();
        assert_eq!(messages1.len(), 1);
        assert_eq!(messages2.len(), 1);
        assert_eq!(messages1[0].channel_id, channel.id());
        assert_eq!(messages1[0].msg, msg.to_vec());
        let metadata1 = &messages1[0].metadata;
        assert!(metadata1.log_time >= now);
        assert_eq!(messages2[0].channel_id, channel.id());
        assert_eq!(messages2[0].msg, msg.to_vec());
        let metadata2 = &messages2[0].metadata;
        assert!(metadata2.log_time >= now);
    }

    /// A failing sink is reported in the logs but does not stop delivery to the other sinks.
    /// An explicit log time is passed through unchanged.
    #[traced_test]
    #[test]
    fn test_log_calls_other_sinks_after_error() {
        let ctx = Context::new();
        let error_sink = Arc::new(ErrorSink::default());
        let recording_sink = Arc::new(RecordingSink::new());
        assert!(ctx.add_sink(error_sink.clone()));
        assert!(!ctx.add_sink(error_sink.clone()));
        assert!(ctx.add_sink(recording_sink.clone()));
        let channel = new_test_channel(&ctx, "topic").unwrap();
        let msg = b"test_message";
        let opts = PartialMetadata {
            log_time: Some(nanoseconds_since_epoch()),
        };
        channel.log_with_meta(msg, opts);
        assert!(logs_contain(ERROR_LOGGING_MESSAGE));
        assert!(logs_contain("ErrorSink always fails"));
        let messages = recording_sink.take_messages();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].channel_id, channel.id());
        assert_eq!(messages[0].msg, msg.to_vec());
        let metadata = &messages[0].metadata;
        assert_eq!(metadata.log_time, opts.log_time.unwrap());
    }

    /// Logging on a channel with no sinks does not emit the error log message.
    #[traced_test]
    #[test]
    fn test_log_msg_no_sinks() {
        let ctx = Context::new();
        let channel = new_test_channel(&ctx, "topic").unwrap();
        let msg = b"test_message";
        channel.log(msg);
        assert!(!logs_contain(ERROR_LOGGING_MESSAGE));
    }

    /// Removing a registered channel empties the channel index.
    #[test]
    fn test_remove_channel() {
        let ctx = Context::new();
        let ch = new_test_channel(&ctx, "topic").unwrap();
        assert!(ctx.remove_channel(ch.id()));
        assert!(ctx.0.read().channels.is_empty());
    }

    /// An auto-subscribing sink attaches to existing channels and to re-added channels.
    /// Removing a channel or the sink detaches it.
    #[test]
    fn test_auto_subscribe() {
        let ctx = Context::new();
        let c1 = new_test_channel(&ctx, "t1").unwrap();
        let c2 = new_test_channel(&ctx, "t2").unwrap();
        let sink = Arc::new(RecordingSink::new().auto_subscribe(true));
        assert!(!c1.has_sinks());
        assert!(!c2.has_sinks());
        ctx.add_sink(sink.clone());
        assert!(c1.has_sinks());
        assert!(c2.has_sinks());
        assert!(ctx.remove_channel(c1.id()));
        assert!(!c1.has_sinks());
        assert!(c2.has_sinks());
        ctx.add_channel(c1.clone());
        assert!(c1.has_sinks());
        assert!(c2.has_sinks());
        ctx.remove_sink(sink.id());
        assert!(!c1.has_sinks());
    }

    /// A non-auto-subscribing sink is attached only through explicit subscriptions.
    /// Removing a channel drops its subscriptions, and they are not restored when the channel
    /// is re-added.
    #[test]
    fn test_no_auto_subscribe() {
        let ctx = Context::new();
        let c1 = new_test_channel(&ctx, "t1").unwrap();
        let c2 = new_test_channel(&ctx, "t2").unwrap();
        let sink = Arc::new(RecordingSink::new().auto_subscribe(false));
        assert!(!c1.has_sinks());
        assert!(!c2.has_sinks());
        ctx.add_sink(sink.clone());
        assert!(!c1.has_sinks());
        assert!(!c2.has_sinks());
        assert!(ctx.remove_channel(c1.id()));
        ctx.add_channel(c1.clone());
        assert!(!c1.has_sinks());
        ctx.subscribe_channels(sink.id(), &[c1.id()]);
        assert!(c1.has_sinks());
        assert!(!c2.has_sinks());
        ctx.subscribe_channels(sink.id(), &[c2.id()]);
        assert!(c1.has_sinks());
        assert!(c2.has_sinks());
        assert!(ctx.remove_channel(c1.id()));
        assert!(!c1.has_sinks());
        assert!(c2.has_sinks());
        // Re-adding does not restore the subscription dropped on removal.
        ctx.add_channel(c1.clone());
        assert!(!c1.has_sinks());
        assert!(c2.has_sinks());
        ctx.subscribe_channels(sink.id(), &[c1.id()]);
        assert!(c1.has_sinks());
        assert!(c2.has_sinks());
        ctx.unsubscribe_channels(sink.id(), &[c1.id()]);
        assert!(!c1.has_sinks());
        assert!(c2.has_sinks());
        ctx.subscribe_channels(sink.id(), &[c1.id(), c2.id()]);
        assert!(c1.has_sinks());
        assert!(c2.has_sinks());
        ctx.remove_sink(sink.id());
        assert!(!c1.has_sinks());
        assert!(!c2.has_sinks());
    }

    /// The sink's `add_channels` callback selects which channels it subscribes to.
    ///
    /// The same topic filter also applies to channels added after the sink. A callback that
    /// returns `None` subscribes to nothing.
    #[test]
    fn test_sink_subscribe_on_channel_add() {
        let ctx = Context::new();
        // Subscribes only to channels on topic "t1".
        let s1 = Arc::new(
            RecordingSink::new()
                .auto_subscribe(false)
                .add_channels(|channels| {
                    Some(
                        channels
                            .iter()
                            .filter_map(|c| {
                                if c.topic() == "t1" {
                                    Some(c.id())
                                } else {
                                    None
                                }
                            })
                            .collect(),
                    )
                }),
        );
        ctx.add_sink(s1.clone());
        let c1 = new_test_channel(&ctx, "t1").unwrap();
        let c2 = new_test_channel(&ctx, "t2").unwrap();
        assert!(c1.has_sinks());
        assert!(!c2.has_sinks());
        ctx.remove_sink(s1.id());
        assert!(!c1.has_sinks());
        assert!(!c2.has_sinks());
        ctx.add_sink(s1.clone());
        assert!(c1.has_sinks());
        assert!(!c2.has_sinks());
        ctx.remove_sink(s1.id());
        assert!(!c1.has_sinks());
        // Subscribes to nothing.
        let s2 = Arc::new(
            RecordingSink::new()
                .auto_subscribe(false)
                .add_channels(|_| None),
        );
        ctx.add_sink(s2.clone());
        assert!(!c1.has_sinks());
        assert!(!c2.has_sinks());
        assert!(ctx.remove_channel(c1.id()));
        assert!(ctx.remove_channel(c2.id()));
        ctx.add_channel(c1.clone());
        ctx.add_channel(c2.clone());
        assert!(!c1.has_sinks());
        assert!(!c2.has_sinks());
    }

    /// `add_channels` is not invoked when the context has no channels.
    #[test]
    fn test_no_add_channels_cb() {
        let ctx = Context::new();
        let s1 = Arc::new(RecordingSink::new().add_channels(|_| unreachable!("no channels!")));
        ctx.add_sink(s1.clone());
    }

    /// Non-matching channels may share a topic and get distinct ids.
    #[test]
    fn test_supports_multiple_channels_with_same_topic() {
        let ctx = Context::new();
        let c1 = new_test_channel(&ctx, "topic").unwrap();
        let c2 = new_test_channel_builder(&ctx, "topic")
            .schema(None)
            .build_raw()
            .unwrap();
        assert_ne!(c1.id(), c2.id());
        assert_eq!(c1.topic(), c2.topic());
    }

    /// A duplicate topic logs a warning.
    ///
    /// Lookup by topic returns the earliest channel still registered, falls through to the
    /// next one after removal, and returns `None` once all are removed.
    #[test]
    #[traced_test]
    fn test_get_channel_by_topic_with_duplicate() {
        let ctx = Context::new();
        let c1 = new_test_channel(&ctx, "dupe").unwrap();
        let c2 = new_test_channel_builder(&ctx, "dupe")
            .message_encoding("different")
            .build_raw()
            .unwrap();
        assert!(logs_contain(
            "Channel with topic dupe already exists in this context"
        ));
        let channel = ctx.get_channel_by_topic("dupe");
        assert!(channel.is_some());
        assert_eq!(channel.unwrap().id(), c1.id());
        assert!(ctx.remove_channel(c1.id()));
        let channel = ctx.get_channel_by_topic("dupe");
        assert!(channel.is_some());
        assert_eq!(channel.unwrap().id(), c2.id());
        assert!(ctx.remove_channel(c2.id()));
        let channel = ctx.get_channel_by_topic("dupe");
        assert!(channel.is_none());
        let c3 = new_test_channel(&ctx, "dupe").unwrap();
        let channel = ctx.get_channel_by_topic("dupe");
        assert!(channel.is_some());
        assert_eq!(channel.unwrap().id(), c3.id());
    }

    /// Each differing setting (encoding, schema, metadata) produces a distinct channel.
    /// Identical settings return the very same `Arc` instead of registering a new channel.
    #[test]
    fn test_add_channel_or_return_matching_channel() {
        let ctx = Context::new();
        let _ = new_test_channel_builder(&ctx, "dupe")
            .message_encoding("different")
            .build_raw()
            .unwrap();
        let _ = new_test_channel_builder(&ctx, "dupe")
            .schema(None)
            .build_raw()
            .unwrap();
        let _ = new_test_channel_builder(&ctx, "dupe")
            .metadata(maplit::btreemap! {"it's".into() => "different".into()})
            .build_raw()
            .unwrap();
        let c1 = new_test_channel(&ctx, "dupe").unwrap();
        assert_eq!(ctx.0.read().channels.len(), 4);
        let c2 = new_test_channel(&ctx, "dupe").unwrap();
        assert_eq!(c1.id(), c2.id());
        assert_eq!(Arc::as_ptr(&c1), Arc::as_ptr(&c2));
        assert_eq!(ctx.0.read().channels.len(), 4);
        assert!(ctx.remove_channel(c1.id()));
        assert_eq!(ctx.0.read().channels.len(), 3);
        let _ = new_test_channel(&ctx, "dupe").unwrap();
        assert_eq!(ctx.0.read().channels.len(), 4);
    }
}