use std::collections::HashMap;
use std::sync::Arc;
use livekit::id::{ParticipantIdentity, TrackSid};
use smallvec::SmallVec;
use tracing::{debug, info};
use crate::protocol::v2::server::advertise;
use crate::remote_access::channel_subscription::ChannelSubscription;
use crate::remote_access::participant::Participant;
use crate::remote_access::session::{VideoInputSchema, VideoPublisher};
use crate::{ChannelId, RawChannel};
/// Bookkeeping for a remote-access session.
///
/// Tracks the registered participants, the channels known to the session, which participants
/// subscribe to which channels, and per-channel video state (schema, publisher, track SID).
pub(crate) struct SessionState {
    /// Registered participants, keyed by identity. A second insert for the same identity keeps
    /// the first participant.
    participants: HashMap<ParticipantIdentity, Arc<Participant>>,
    /// Channels currently registered with the session, keyed by channel id.
    channels: HashMap<ChannelId, Arc<RawChannel>>,
    /// Subscriber sets per channel. The impl removes an entry as soon as its last subscriber
    /// leaves, so every entry present here has at least one subscriber.
    subscriptions: HashMap<ChannelId, ChannelSubscription>,
    /// Video input schema recorded per channel. Channels present here are advertised with
    /// `foxglove.hasVideoTrack` metadata.
    video_schemas: HashMap<ChannelId, VideoInputSchema>,
    /// Video publisher associated with each channel.
    video_publishers: HashMap<ChannelId, Arc<VideoPublisher>>,
    /// Track SID recorded per channel (presumably the LiveKit track published for that
    /// channel's video — TODO confirm against callers).
    video_track_sids: HashMap<ChannelId, TrackSid>,
}
impl SessionState {
    /// Creates an empty state: no participants, channels, subscriptions, or video bookkeeping.
    pub fn new() -> Self {
        Self {
            participants: HashMap::new(),
            channels: HashMap::new(),
            subscriptions: HashMap::new(),
            video_schemas: HashMap::new(),
            video_publishers: HashMap::new(),
            video_track_sids: HashMap::new(),
        }
    }

    /// Registers `participant` under `identity` unless that identity is already registered.
    ///
    /// Returns the participant stored for `identity` after the call. If the identity was
    /// already present, this is the existing participant, and the `participant` argument is
    /// dropped without replacing it. Otherwise it is `participant` itself (same `Arc`).
    pub fn insert_participant(
        &mut self,
        identity: ParticipantIdentity,
        participant: Arc<Participant>,
    ) -> Arc<Participant> {
        // `or_insert` leaves an occupied entry untouched, so a duplicate insert never replaces
        // the participant that was registered first.
        self.participants.entry(identity).or_insert(participant).clone()
    }

    /// Unregisters the participant with `identity` and removes it from every subscription.
    ///
    /// Returns the channels for which this participant was the last subscriber. Their
    /// subscription entries are deleted. If `identity` is not registered, nothing is changed
    /// and an empty list is returned.
    pub fn remove_participant(
        &mut self,
        identity: &ParticipantIdentity,
    ) -> SmallVec<[ChannelId; 4]> {
        if self.participants.remove(identity).is_none() {
            return SmallVec::new();
        }
        info!("removed participant {identity:?}");
        let mut last_unsubscribed: SmallVec<[ChannelId; 4]> = SmallVec::new();
        self.subscriptions.retain(|&channel_id, sub| {
            sub.remove(identity);
            // Drop subscriptions left without subscribers and report them to the caller.
            if sub.is_empty() {
                last_unsubscribed.push(channel_id);
                false
            } else {
                true
            }
        });
        last_unsubscribed
    }

    /// Returns the participant registered under `identity`, if any.
    pub fn get_participant(&self, identity: &ParticipantIdentity) -> Option<Arc<Participant>> {
        self.participants.get(identity).cloned()
    }

    /// Returns a snapshot of all registered participants, in unspecified (hash-map) order.
    pub fn collect_participants(&self) -> SmallVec<[Arc<Participant>; 8]> {
        self.participants.values().cloned().collect()
    }

    /// Registers `channel` under its own id, replacing any channel previously stored there.
    pub fn insert_channel(&mut self, channel: &Arc<RawChannel>) {
        self.channels.insert(channel.id(), channel.clone());
    }

    /// Unregisters the channel with `channel_id`. Returns `true` if it was registered.
    ///
    /// Only the channel map is touched; subscriptions and video state for the channel are
    /// left as they are.
    pub fn remove_channel(&mut self, channel_id: ChannelId) -> bool {
        self.channels.remove(&channel_id).is_some()
    }

    /// Calls `f` with the map of registered channels and returns its result.
    ///
    /// Returns `None`, without calling `f`, when no channels are registered.
    pub fn with_channels<R>(
        &self,
        f: impl FnOnce(&HashMap<ChannelId, Arc<RawChannel>>) -> R,
    ) -> Option<R> {
        if self.channels.is_empty() {
            return None;
        }
        Some(f(&self.channels))
    }

    /// Records the video input schema for `channel_id`, replacing any previous one.
    pub fn insert_video_schema(&mut self, channel_id: ChannelId, schema: VideoInputSchema) {
        self.video_schemas.insert(channel_id, schema);
    }

    /// Returns a copy of the video input schema recorded for `channel_id`, if any.
    pub fn get_video_schema(&self, channel_id: &ChannelId) -> Option<VideoInputSchema> {
        self.video_schemas.get(channel_id).copied()
    }

    /// Forgets the video input schema for `channel_id` (no-op if none is recorded).
    pub fn remove_video_schema(&mut self, channel_id: &ChannelId) {
        self.video_schemas.remove(channel_id);
    }

    /// Stores the video publisher for `channel_id`, replacing any previous one.
    pub fn insert_video_publisher(
        &mut self,
        channel_id: ChannelId,
        publisher: Arc<VideoPublisher>,
    ) {
        self.video_publishers.insert(channel_id, publisher);
    }

    /// Returns the video publisher stored for `channel_id`, if any.
    pub fn get_video_publisher(&self, channel_id: &ChannelId) -> Option<Arc<VideoPublisher>> {
        self.video_publishers.get(channel_id).cloned()
    }

    /// Drops the stored video publisher for `channel_id` (no-op if none is stored).
    pub fn remove_video_publisher(&mut self, channel_id: &ChannelId) {
        self.video_publishers.remove(channel_id);
    }

    /// Records the track SID for `channel_id`, replacing any previous one.
    pub fn insert_video_track_sid(&mut self, channel_id: ChannelId, sid: TrackSid) {
        self.video_track_sids.insert(channel_id, sid);
    }

    /// Removes and returns the track SID recorded for `channel_id`, if any.
    pub fn remove_video_track_sid(&mut self, channel_id: &ChannelId) -> Option<TrackSid> {
        self.video_track_sids.remove(channel_id)
    }

    /// Adds `foxglove.hasVideoTrack = "true"` metadata to every advertised channel that has a
    /// recorded video input schema.
    ///
    /// Other channels are left unchanged; an existing `foxglove.hasVideoTrack` key on them is
    /// not removed.
    pub fn inject_video_track_metadata(&self, advertise: &mut advertise::Advertise<'_>) {
        for ch in &mut advertise.channels {
            if self.video_schemas.contains_key(&ChannelId::new(ch.id)) {
                ch.metadata
                    .insert("foxglove.hasVideoTrack".to_string(), "true".to_string());
            }
        }
    }

    /// Subscribes `participant` to each channel in `channel_ids`.
    ///
    /// Channels the participant is already subscribed to are skipped with an info log. Returns
    /// the channels that had no subscribers before this call and now have one.
    pub fn subscribe(
        &mut self,
        participant: &Participant,
        channel_ids: &[ChannelId],
    ) -> SmallVec<[ChannelId; 4]> {
        let mut first_subscribed: SmallVec<[ChannelId; 4]> = SmallVec::new();
        for &channel_id in channel_ids {
            let sub = self
                .subscriptions
                .entry(channel_id)
                .or_insert_with(ChannelSubscription::new);
            if sub.subscribers().contains(participant.identity()) {
                info!("{participant} is already subscribed to channel {channel_id:?}; ignoring");
                continue;
            }
            // Check emptiness before adding: an empty set means this is the first subscriber.
            let is_first = sub.is_empty();
            sub.add(participant.identity().clone());
            debug!("{participant} subscribed to channel {channel_id:?}");
            if is_first {
                first_subscribed.push(channel_id);
            }
        }
        first_subscribed
    }

    /// Unsubscribes `participant` from each channel in `channel_ids`.
    ///
    /// Channels the participant is not subscribed to are skipped with an info log. Returns the
    /// channels whose last subscriber was removed by this call. Their subscription entries are
    /// deleted.
    pub fn unsubscribe(
        &mut self,
        participant: &Participant,
        channel_ids: &[ChannelId],
    ) -> SmallVec<[ChannelId; 4]> {
        let mut last_unsubscribed: SmallVec<[ChannelId; 4]> = SmallVec::new();
        for &channel_id in channel_ids {
            let Some(sub) = self.subscriptions.get_mut(&channel_id) else {
                info!("{participant} is not subscribed to channel {channel_id:?}; ignoring");
                continue;
            };
            if !sub.remove(participant.identity()) {
                info!("{participant} is not subscribed to channel {channel_id:?}; ignoring");
                continue;
            }
            debug!("{participant} unsubscribed from channel {channel_id:?}");
            // Keep the invariant that every stored subscription has at least one subscriber.
            if sub.is_empty() {
                self.subscriptions.remove(&channel_id);
                last_unsubscribed.push(channel_id);
            }
        }
        last_unsubscribed
    }

    /// Returns the subscription for `channel_id`.
    ///
    /// Returns `None` when no participant is subscribed, because empty subscriptions are
    /// removed.
    pub fn get_subscription(&self, channel_id: &ChannelId) -> Option<&ChannelSubscription> {
        self.subscriptions.get(channel_id)
    }

    /// Test helper: number of participants subscribed to `channel_id` (0 if none).
    #[cfg(test)]
    pub fn get_subscriber_count(&self, channel_id: &ChannelId) -> usize {
        self.subscriptions
            .get(channel_id)
            .map_or(0, |s| s.subscribers().len())
    }

    /// Test helper: number of registered channels.
    #[cfg(test)]
    pub fn channel_count(&self) -> usize {
        self.channels.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::remote_access::participant::ParticipantWriter;

    /// Builds a participant backed by a `TestByteStreamWriter`, returning it with its identity.
    fn make_participant(name: &str) -> (ParticipantIdentity, Arc<Participant>) {
        let identity = ParticipantIdentity(name.to_string());
        let writer = Arc::new(crate::remote_access::participant::TestByteStreamWriter::default());
        let participant = Arc::new(Participant::new(
            identity.clone(),
            ParticipantWriter::Test(writer),
        ));
        (identity, participant)
    }

    /// Builds a JSON-encoded raw channel on a fresh context.
    fn make_channel(topic: &str) -> Arc<RawChannel> {
        use crate::{ChannelBuilder, Context, Schema};
        let ctx = Context::new();
        ChannelBuilder::new(topic)
            .context(&ctx)
            .message_encoding("json")
            .schema(Schema::new("S", "jsonschema", b"{}"))
            .build_raw()
            .unwrap()
    }

    #[test]
    fn insert_new_participant() {
        let mut state = SessionState::new();
        let (id, p) = make_participant("alice");
        let stored = state.insert_participant(id.clone(), p);
        assert_eq!(stored.identity(), &id);
        assert!(Arc::ptr_eq(&stored, &state.get_participant(&id).unwrap()));
    }

    #[test]
    fn insert_duplicate_returns_existing() {
        let mut state = SessionState::new();
        let (id, p1) = make_participant("alice");
        let stored1 = state.insert_participant(id.clone(), p1);
        // A distinct participant inserted under the same key must not replace the first.
        let (_id2, p2) = make_participant("bob");
        let stored2 = state.insert_participant(id, p2);
        assert!(Arc::ptr_eq(&stored1, &stored2));
    }

    #[test]
    fn get_participant_returns_existing() {
        let mut state = SessionState::new();
        let (id, p) = make_participant("alice");
        state.insert_participant(id.clone(), p);
        assert!(state.get_participant(&id).is_some());
    }

    #[test]
    fn get_participant_returns_none_for_missing() {
        let state = SessionState::new();
        let id = ParticipantIdentity("nobody".to_string());
        assert!(state.get_participant(&id).is_none());
    }

    #[test]
    fn remove_missing_participant_is_noop() {
        let mut state = SessionState::new();
        let id = ParticipantIdentity("nobody".to_string());
        let last = state.remove_participant(&id);
        assert!(last.is_empty());
    }

    #[test]
    fn remove_participant_cleans_up_subscriptions() {
        let mut state = SessionState::new();
        let (id, p) = make_participant("alice");
        state.insert_participant(id.clone(), p.clone());
        let ch = ChannelId::new(1);
        state.subscribe(&p, &[ch]);
        let last = state.remove_participant(&id);
        assert_eq!(last.as_slice(), &[ch]);
        assert_eq!(state.get_subscriber_count(&ch), 0);
        // The count alone can't distinguish "empty entry" from "no entry"; check the entry is gone.
        assert!(state.get_subscription(&ch).is_none());
    }

    #[test]
    fn remove_participant_twice_reports_nothing_the_second_time() {
        let mut state = SessionState::new();
        let (id, p) = make_participant("alice");
        state.insert_participant(id.clone(), p.clone());
        let ch = ChannelId::new(1);
        state.subscribe(&p, &[ch]);
        assert_eq!(state.remove_participant(&id).as_slice(), &[ch]);
        assert!(state.remove_participant(&id).is_empty());
        assert!(state.get_participant(&id).is_none());
    }

    #[test]
    fn remove_participant_reports_only_last_unsubscribed_channels() {
        let mut state = SessionState::new();
        let (id_a, pa) = make_participant("alice");
        let (id_b, pb) = make_participant("bob");
        state.insert_participant(id_a.clone(), pa.clone());
        state.insert_participant(id_b.clone(), pb.clone());
        let ch1 = ChannelId::new(10);
        let ch2 = ChannelId::new(20);
        state.subscribe(&pa, &[ch1, ch2]);
        state.subscribe(&pb, &[ch1]);
        let last = state.remove_participant(&id_a);
        assert_eq!(last.as_slice(), &[ch2]);
        assert_eq!(state.get_subscriber_count(&ch1), 1);
    }

    #[test]
    fn insert_and_query_channel() {
        let mut state = SessionState::new();
        let ch = make_channel("/topic1");
        state.insert_channel(&ch);
        assert_eq!(state.channel_count(), 1);
    }

    #[test]
    fn remove_channel_returns_true_when_present() {
        let mut state = SessionState::new();
        let ch = make_channel("/topic1");
        state.insert_channel(&ch);
        assert!(state.remove_channel(ch.id()));
    }

    #[test]
    fn remove_channel_returns_false_when_absent() {
        let mut state = SessionState::new();
        assert!(!state.remove_channel(ChannelId::new(999)));
    }

    #[test]
    fn with_channels_returns_none_when_empty() {
        let state = SessionState::new();
        assert!(state.with_channels(|channels| channels.len()).is_none());
    }

    #[test]
    fn with_channels_passes_registered_channels() {
        let mut state = SessionState::new();
        let ch = make_channel("/topic1");
        state.insert_channel(&ch);
        let found = state.with_channels(|channels| channels.contains_key(&ch.id()));
        assert_eq!(found, Some(true));
    }

    #[test]
    fn first_subscriber_is_reported() {
        let mut state = SessionState::new();
        let (_id, p) = make_participant("alice");
        let ch = ChannelId::new(1);
        let first = state.subscribe(&p, &[ch]);
        assert_eq!(first.as_slice(), &[ch]);
    }

    #[test]
    fn second_subscriber_is_not_reported_as_first() {
        let mut state = SessionState::new();
        let (_id_a, pa) = make_participant("alice");
        let (_id_b, pb) = make_participant("bob");
        let ch = ChannelId::new(1);
        state.subscribe(&pa, &[ch]);
        let first = state.subscribe(&pb, &[ch]);
        assert!(first.is_empty());
    }

    #[test]
    fn duplicate_subscribe_is_idempotent() {
        let mut state = SessionState::new();
        let (_id, p) = make_participant("alice");
        let ch = ChannelId::new(1);
        state.subscribe(&p, &[ch]);
        let first = state.subscribe(&p, &[ch]);
        assert!(first.is_empty());
        assert_eq!(state.get_subscriber_count(&ch), 1);
    }

    #[test]
    fn subscribe_multiple_channels_at_once() {
        let mut state = SessionState::new();
        let (_id, p) = make_participant("alice");
        let ch1 = ChannelId::new(1);
        let ch2 = ChannelId::new(2);
        let first = state.subscribe(&p, &[ch1, ch2]);
        assert_eq!(first.len(), 2);
        assert!(first.contains(&ch1));
        assert!(first.contains(&ch2));
    }

    #[test]
    fn last_unsubscriber_is_reported() {
        let mut state = SessionState::new();
        let (_id, p) = make_participant("alice");
        let ch = ChannelId::new(1);
        state.subscribe(&p, &[ch]);
        let last = state.unsubscribe(&p, &[ch]);
        assert_eq!(last.as_slice(), &[ch]);
        // The emptied subscription entry must be dropped, not left behind.
        assert!(state.get_subscription(&ch).is_none());
    }

    #[test]
    fn unsubscribe_with_remaining_subscribers_is_not_reported() {
        let mut state = SessionState::new();
        let (_id_a, pa) = make_participant("alice");
        let (_id_b, pb) = make_participant("bob");
        let ch = ChannelId::new(1);
        state.subscribe(&pa, &[ch]);
        state.subscribe(&pb, &[ch]);
        let last = state.unsubscribe(&pa, &[ch]);
        assert!(last.is_empty());
        assert_eq!(state.get_subscriber_count(&ch), 1);
    }

    #[test]
    fn unsubscribe_when_not_subscribed_is_noop() {
        let mut state = SessionState::new();
        let (_id, p) = make_participant("alice");
        let ch = ChannelId::new(1);
        let last = state.unsubscribe(&p, &[ch]);
        assert!(last.is_empty());
    }

    #[test]
    fn unsubscribe_by_non_subscriber_leaves_subscription_intact() {
        // Exercises the branch where the channel has a subscription, but not for this participant.
        let mut state = SessionState::new();
        let (_id_a, pa) = make_participant("alice");
        let (_id_b, pb) = make_participant("bob");
        let ch = ChannelId::new(1);
        state.subscribe(&pa, &[ch]);
        let last = state.unsubscribe(&pb, &[ch]);
        assert!(last.is_empty());
        assert_eq!(state.get_subscriber_count(&ch), 1);
    }

    #[test]
    fn unsubscribe_multiple_channels_at_once() {
        let mut state = SessionState::new();
        let (_id, p) = make_participant("alice");
        let ch1 = ChannelId::new(1);
        let ch2 = ChannelId::new(2);
        state.subscribe(&p, &[ch1, ch2]);
        let last = state.unsubscribe(&p, &[ch1, ch2]);
        assert_eq!(last.len(), 2);
        assert!(last.contains(&ch1));
        assert!(last.contains(&ch2));
    }

    #[test]
    fn get_subscription_returns_none_for_no_subscriptions() {
        let state = SessionState::new();
        assert!(state.get_subscription(&ChannelId::new(1)).is_none());
    }

    #[test]
    fn get_subscription_returns_subscriber_identities() {
        let mut state = SessionState::new();
        let (id_a, pa) = make_participant("alice");
        let (id_b, pb) = make_participant("bob");
        let ch = ChannelId::new(1);
        state.subscribe(&pa, &[ch]);
        state.subscribe(&pb, &[ch]);
        let sub = state.get_subscription(&ch).unwrap();
        assert_eq!(sub.subscribers().len(), 2);
        assert!(sub.subscribers().contains(&id_a));
        assert!(sub.subscribers().contains(&id_b));
    }

    #[test]
    fn subscription_version_increments_on_subscribe() {
        let mut state = SessionState::new();
        let (_id_a, pa) = make_participant("alice");
        let (_id_b, pb) = make_participant("bob");
        let ch = ChannelId::new(1);
        state.subscribe(&pa, &[ch]);
        let v1 = state.get_subscription(&ch).unwrap().version();
        state.subscribe(&pb, &[ch]);
        let v2 = state.get_subscription(&ch).unwrap().version();
        assert_ne!(v1, v2);
    }

    #[test]
    fn subscription_version_does_not_increment_on_duplicate_subscribe() {
        let mut state = SessionState::new();
        let (_id, p) = make_participant("alice");
        let ch = ChannelId::new(1);
        state.subscribe(&p, &[ch]);
        let v1 = state.get_subscription(&ch).unwrap().version();
        state.subscribe(&p, &[ch]);
        let v2 = state.get_subscription(&ch).unwrap().version();
        assert_eq!(v1, v2);
    }

    #[test]
    fn subscription_version_increments_on_unsubscribe() {
        let mut state = SessionState::new();
        let (_id_a, pa) = make_participant("alice");
        let (_id_b, pb) = make_participant("bob");
        let ch = ChannelId::new(1);
        state.subscribe(&pa, &[ch]);
        state.subscribe(&pb, &[ch]);
        let v1 = state.get_subscription(&ch).unwrap().version();
        state.unsubscribe(&pa, &[ch]);
        let v2 = state.get_subscription(&ch).unwrap().version();
        assert_ne!(v1, v2);
    }

    #[test]
    fn subscription_version_increments_on_remove_participant() {
        let mut state = SessionState::new();
        let (id_a, pa) = make_participant("alice");
        let (id_b, pb) = make_participant("bob");
        let ch = ChannelId::new(1);
        state.insert_participant(id_a.clone(), pa.clone());
        state.insert_participant(id_b, pb.clone());
        state.subscribe(&pa, &[ch]);
        state.subscribe(&pb, &[ch]);
        let v1 = state.get_subscription(&ch).unwrap().version();
        state.remove_participant(&id_a);
        let v2 = state.get_subscription(&ch).unwrap().version();
        assert_ne!(v1, v2);
    }

    #[test]
    fn collect_participants_yields_all() {
        let mut state = SessionState::new();
        let (id_a, pa) = make_participant("alice");
        let (id_b, pb) = make_participant("bob");
        state.insert_participant(id_a, pa);
        state.insert_participant(id_b, pb);
        assert_eq!(state.collect_participants().len(), 2);
    }
}