use std::{
collections::HashMap,
sync::{Arc, Weak},
time::Duration,
};
use bytes::Bytes;
use futures_util::StreamExt;
use livekit::{id::ParticipantIdentity, ByteStreamReader, Room, StreamByteOptions};
use parking_lot::RwLock;
use smallvec::SmallVec;
use tokio::io::AsyncReadExt;
use tokio_util::{io::StreamReader, sync::CancellationToken};
use tracing::{debug, error, info, warn};
use crate::{
protocol::v2::{
client::{self, ClientMessage},
server::{advertise, MessageData as ServerMessageData, ServerInfo, Unadvertise},
BinaryMessage, JsonMessage,
},
remote_access::{participant::Participant, RemoteAccessError},
ChannelId, Context, FoxgloveError, Metadata, RawChannel, Sink, SinkChannelFilter, SinkId,
};
/// Topic of the LiveKit byte streams that carry protocol frames to participants.
const WS_PROTOCOL_TOPIC: &str = "ws-protocol";
/// Frame header size: one opcode byte followed by a little-endian `u32` payload length.
const MESSAGE_FRAME_SIZE: usize = 1 + std::mem::size_of::<u32>();
/// Largest frame payload accepted from a client (16 MiB); larger frames end the stream.
const MAX_MESSAGE_SIZE: usize = 16 << 20;
/// Maximum number of times a full data-plane queue is retried before the message is dropped.
const MAX_SEND_RETRIES: usize = 3;
/// A pre-encoded data-plane message for one channel, queued until `run_sender`
/// delivers it to that channel's current subscribers.
struct ChannelMessage {
    /// Channel the message was logged on; used to look up subscribers.
    channel_id: ChannelId,
    /// Framed binary message (opcode byte, `u32` length, encoded payload).
    data: Bytes,
}
/// A pre-framed control-plane message addressed to a single participant
/// (e.g. server info, channel advertisements, unadvertisements).
struct ControlPlaneMessage {
    /// Recipient of the message.
    participant: Arc<Participant>,
    /// Framed message bytes, ready to be written to the participant's stream.
    data: Bytes,
}
/// Opcode stored in the first byte of every frame this module reads or writes.
#[derive(Clone, Copy, Debug)]
#[repr(u8)]
enum OpCode {
    /// Payload is UTF-8 text holding a JSON protocol message.
    Text = 1,
    /// Payload is a binary-encoded protocol message.
    Binary = 2,
}
/// Wraps `payload` in a text frame: the `Text` opcode byte, the payload length as a
/// little-endian `u32`, then the payload bytes.
///
/// # Panics
///
/// Panics if `payload` is longer than `u32::MAX` bytes.
fn frame_text_message(payload: &[u8]) -> Bytes {
    let length_prefix = u32::try_from(payload.len())
        .expect("message too large")
        .to_le_bytes();
    let mut framed = Vec::with_capacity(MESSAGE_FRAME_SIZE + payload.len());
    framed.push(OpCode::Text as u8);
    framed.extend_from_slice(&length_prefix);
    framed.extend_from_slice(payload);
    Bytes::from(framed)
}
/// Encodes `message` into a binary frame: the `Binary` opcode byte, the encoded
/// length as a little-endian `u32`, then the encoded message.
///
/// # Panics
///
/// Panics if the encoded message is longer than `u32::MAX` bytes.
fn encode_binary_message<'a>(message: &impl BinaryMessage<'a>) -> Bytes {
    let body_len = message.encoded_len();
    let length_prefix = u32::try_from(body_len)
        .expect("message too large")
        .to_le_bytes();
    let mut framed = Vec::with_capacity(MESSAGE_FRAME_SIZE + body_len);
    framed.push(OpCode::Binary as u8);
    framed.extend_from_slice(&length_prefix);
    message.encode(&mut framed);
    Bytes::from(framed)
}
/// Sink that forwards logged channel data to participants of a LiveKit room.
///
/// Outgoing traffic is split across two bounded queues that `run_sender` drains:
/// a lossy data plane for logged messages, and a control plane for per-participant
/// protocol messages (server info, advertise/unadvertise).
pub(crate) struct RemoteAccessSession {
    /// Identifier returned from `Sink::id`.
    sink_id: SinkId,
    /// Room whose local participant opens the outgoing byte streams.
    room: Room,
    /// Weak handle to the owning context, used to (un)subscribe channels for this sink.
    context: Weak<Context>,
    /// Connected participants keyed by identity.
    participants: RwLock<HashMap<ParticipantIdentity, Arc<Participant>>>,
    /// Channels that passed the filter and were advertised to participants.
    channels: RwLock<HashMap<ChannelId, Arc<RawChannel>>>,
    /// Identities of the participants subscribed to each channel.
    subscriptions: RwLock<HashMap<ChannelId, SmallVec<[ParticipantIdentity; 1]>>>,
    /// Optional filter deciding which channels are advertised.
    channel_filter: Option<Arc<dyn SinkChannelFilter>>,
    /// Cancelling this stops `run_sender` and the client stream readers.
    cancellation_token: CancellationToken,
    /// Data-plane queue, producer side (fed by `log`).
    data_plane_tx: flume::Sender<ChannelMessage>,
    /// Data-plane queue, consumer side. Drained by `run_sender`; also used by
    /// `send_data_lossy` to evict the oldest entry when the queue is full.
    data_plane_rx: flume::Receiver<ChannelMessage>,
    /// Control-plane queue, producer side (fed by `send_control`).
    control_plane_tx: flume::Sender<ControlPlaneMessage>,
    /// Control-plane queue, consumer side (drained by `run_sender`).
    control_plane_rx: flume::Receiver<ControlPlaneMessage>,
}
impl Sink for RemoteAccessSession {
    /// Returns the identifier assigned to this session at construction.
    fn id(&self) -> SinkId {
        self.sink_id
    }

    /// Encodes one logged message for `channel` and queues it on the data plane.
    ///
    /// Queueing goes through `send_data_lossy`, which never waits for queue space
    /// (it evicts or drops messages instead), so this always returns `Ok(())`.
    fn log(
        &self,
        channel: &RawChannel,
        msg: &[u8],
        metadata: &Metadata,
    ) -> std::result::Result<(), FoxgloveError> {
        let channel_id = channel.id();
        let message = ServerMessageData::new(u64::from(channel_id), metadata.log_time, msg);
        let data = encode_binary_message(&message);
        self.send_data_lossy(ChannelMessage { channel_id, data });
        Ok(())
    }

    /// Records and advertises newly added channels to every connected participant.
    ///
    /// Channels rejected by the optional channel filter, or omitted from the
    /// advertisement built by `advertise_channels`, are not recorded. Always returns
    /// `None`; subscriptions are driven by explicit participant requests instead.
    fn add_channels(&self, channels: &[&Arc<RawChannel>]) -> Option<Vec<ChannelId>> {
        let filtered: Vec<_> = channels
            .iter()
            .filter(|ch| {
                let Some(filter) = self.channel_filter.as_ref() else {
                    return true;
                };
                filter.should_subscribe(ch.descriptor())
            })
            .copied()
            .collect();
        if filtered.is_empty() {
            return None;
        }

        let advertise_msg = advertise::advertise_channels(filtered.iter().copied());
        if advertise_msg.channels.is_empty() {
            return None;
        }

        // Only remember channels that actually made it into the advertisement, so
        // later advertisements to new participants stay consistent with this one.
        let advertised_ids: std::collections::HashSet<u64> =
            advertise_msg.channels.iter().map(|ch| ch.id).collect();
        {
            let mut advertised_channels = self.channels.write();
            for &ch in &filtered {
                if advertised_ids.contains(&u64::from(ch.id())) {
                    advertised_channels.insert(ch.id(), ch.clone());
                }
            }
        }

        let framed = frame_text_message(advertise_msg.to_string().as_bytes());
        self.broadcast_control(framed);
        None
    }

    /// Forgets a removed channel and unadvertises it if it had been advertised.
    ///
    /// The channel's subscriber list is dropped as well; previously it was left in
    /// `subscriptions` for the rest of the session.
    fn remove_channel(&self, channel: &RawChannel) {
        let channel_id = channel.id();
        let was_advertised = self.channels.write().remove(&channel_id).is_some();
        // Done after releasing `channels` so no two locks are held at once here.
        self.subscriptions.write().remove(&channel_id);
        if !was_advertised {
            return;
        }

        let unadvertise = Unadvertise::new([u64::from(channel_id)]);
        let framed = frame_text_message(unadvertise.to_string().as_bytes());
        self.broadcast_control(framed);
    }

    /// Returns `false`: channel subscriptions for this sink are managed explicitly
    /// from participant Subscribe/Unsubscribe requests.
    fn auto_subscribe(&self) -> bool {
        false
    }
}
impl RemoteAccessSession {
pub(crate) fn new(
room: Room,
context: Weak<Context>,
channel_filter: Option<Arc<dyn SinkChannelFilter>>,
cancellation_token: CancellationToken,
message_backlog_size: usize,
) -> Self {
let (data_plane_tx, data_plane_rx) = flume::bounded(message_backlog_size);
let (control_plane_tx, control_plane_rx) = flume::bounded(message_backlog_size);
Self {
sink_id: SinkId::next(),
room,
context,
participants: RwLock::new(HashMap::new()),
channels: RwLock::new(HashMap::new()),
subscriptions: RwLock::new(HashMap::new()),
channel_filter,
cancellation_token,
data_plane_tx,
data_plane_rx,
control_plane_tx,
control_plane_rx,
}
}
pub(crate) fn sink_id(&self) -> SinkId {
self.sink_id
}
pub(crate) fn room(&self) -> &Room {
&self.room
}
fn send_data_lossy(&self, mut msg: ChannelMessage) {
static THROTTLER: parking_lot::Mutex<crate::throttler::Throttler> =
parking_lot::Mutex::new(crate::throttler::Throttler::new(Duration::from_secs(30)));
let mut dropped = 0;
loop {
match self.data_plane_tx.try_send(msg) {
Ok(_) => {
if dropped > 0 && THROTTLER.lock().try_acquire() {
info!("data plane queue full, dropped {dropped} message(s)");
}
return;
}
Err(flume::TrySendError::Disconnected(_)) => return,
Err(flume::TrySendError::Full(rejected)) => {
if dropped >= MAX_SEND_RETRIES {
if THROTTLER.lock().try_acquire() {
info!("data plane queue full, dropped message");
}
return;
}
msg = rejected;
let _ = self.data_plane_rx.try_recv();
dropped += 1;
}
}
}
}
fn send_control(&self, participant: Arc<Participant>, data: Bytes) {
let msg = ControlPlaneMessage { participant, data };
if let Err(e) = self.control_plane_tx.send(msg) {
warn!("control plane queue disconnected, dropping message: {e}");
}
}
fn broadcast_control(&self, data: Bytes) {
let participants = self.participants.read();
for participant in participants.values() {
self.send_control(participant.clone(), data.clone());
}
}
pub(crate) async fn run_sender(session: Arc<Self>) {
loop {
tokio::select! {
biased;
() = session.cancellation_token.cancelled() => break,
msg = session.control_plane_rx.recv_async() => {
let Ok(msg) = msg else { break };
if let Err(e) = msg.participant.send(&msg.data).await {
error!("failed to send control message to {:?}: {e:?}", msg.participant);
}
}
msg = session.data_plane_rx.recv_async() => {
let Ok(msg) = msg else { break };
let subscriber_ids: SmallVec<[ParticipantIdentity; 8]> = {
let subscriptions = session.subscriptions.read();
match subscriptions.get(&msg.channel_id) {
Some(ids) => ids.iter().cloned().collect(),
None => continue,
}
};
let participants: SmallVec<[Arc<Participant>; 8]> = {
let participants = session.participants.read();
subscriber_ids
.iter()
.filter_map(|id| participants.get(id).cloned())
.collect()
};
for participant in &participants {
if let Err(e) = participant.send(&msg.data).await {
error!("failed to send message data to {participant:?}: {e:?}");
}
}
}
}
}
}
pub(crate) async fn handle_byte_stream_from_client(
self: &Arc<Self>,
participant_identity: ParticipantIdentity,
reader: ByteStreamReader,
) {
let stream = reader.map(|result| result.map_err(std::io::Error::other));
let mut reader = StreamReader::new(stream);
loop {
let mut header = [0u8; MESSAGE_FRAME_SIZE];
let read_result = tokio::select! {
() = self.cancellation_token.cancelled() => break,
result = reader.read_exact(&mut header) => result,
};
match read_result {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
Err(e) => {
error!(
"Error reading from byte stream for client {:?}: {:?}",
participant_identity, e
);
break;
}
}
let opcode = header[0];
let length =
u32::from_le_bytes(header[1..MESSAGE_FRAME_SIZE].try_into().unwrap()) as usize;
if length > MAX_MESSAGE_SIZE {
error!(
"message too large ({length} bytes) from client {:?}, disconnecting",
participant_identity
);
return;
}
let mut payload = vec![0u8; length];
let read_result = tokio::select! {
() = self.cancellation_token.cancelled() => break,
result = reader.read_exact(&mut payload) => result,
};
match read_result {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
Err(e) => {
error!(
"Error reading from byte stream for client {:?}: {:?}",
participant_identity, e
);
break;
}
}
if !self.handle_client_message(&participant_identity, opcode, Bytes::from(payload)) {
return;
}
}
}
fn handle_client_message(
self: &Arc<Self>,
participant_identity: &ParticipantIdentity,
opcode: u8,
payload: Bytes,
) -> bool {
const TEXT: u8 = OpCode::Text as u8;
const BINARY: u8 = OpCode::Binary as u8;
let client_msg = match opcode {
TEXT => match std::str::from_utf8(&payload) {
Ok(text) => ClientMessage::parse_json(text),
Err(e) => {
error!("Invalid UTF-8 in text message: {e:?}");
return true;
}
},
BINARY => ClientMessage::parse_binary(&payload[..]),
_ => {
error!("Unrecognized message opcode ({opcode}) received, you likely need to upgrade to a newer version of the Foxglove SDK");
return false;
}
};
let client_msg = match client_msg {
Ok(msg) => msg,
Err(e) => {
error!("failed to parse client message: {e:?}");
return true;
}
};
let Some(participant) = ({
let participants = self.participants.read();
participants.get(participant_identity).cloned()
}) else {
error!("Unknown participant identity: {:?}", participant_identity);
return false;
};
match client_msg {
ClientMessage::Subscribe(msg) => {
self.handle_client_subscribe(&participant, msg);
}
ClientMessage::Unsubscribe(msg) => {
self.handle_client_unsubscribe(&participant, msg);
}
_ => {
warn!("Unhandled client message: {client_msg:?}");
}
}
true
}
fn handle_client_subscribe(&self, participant: &Participant, msg: client::Subscribe) {
let channel_ids: Vec<ChannelId> = msg
.channel_ids
.iter()
.map(|&id| ChannelId::new(id))
.collect();
let mut first_subscribed: SmallVec<[ChannelId; 4]> = SmallVec::new();
{
let mut subscriptions = self.subscriptions.write();
for &channel_id in &channel_ids {
let subscribers = subscriptions.entry(channel_id).or_default();
if subscribers.contains(participant.identity()) {
info!(
"{participant} is already subscribed to channel {channel_id:?}; ignoring",
);
continue;
}
let is_first = subscribers.is_empty();
subscribers.push(participant.identity().clone());
debug!("{participant} subscribed to channel {channel_id:?}",);
if is_first {
first_subscribed.push(channel_id);
}
}
}
if !first_subscribed.is_empty() {
if let Some(context) = self.context.upgrade() {
context.subscribe_channels(self.sink_id, &first_subscribed);
}
}
}
fn handle_client_unsubscribe(&self, participant: &Participant, msg: client::Unsubscribe) {
let channel_ids: Vec<ChannelId> = msg
.channel_ids
.iter()
.map(|&id| ChannelId::new(id))
.collect();
let mut last_unsubscribed: SmallVec<[ChannelId; 4]> = SmallVec::new();
{
let mut subscriptions = self.subscriptions.write();
for &channel_id in &channel_ids {
let Some(subscribers) = subscriptions.get_mut(&channel_id) else {
info!("{participant} is not subscribed to channel {channel_id:?}; ignoring",);
continue;
};
let Some(pos) = subscribers
.iter()
.position(|id| id == participant.identity())
else {
info!("{participant} is not subscribed to channel {channel_id:?}; ignoring",);
continue;
};
subscribers.swap_remove(pos);
debug!("{participant} unsubscribed from channel {channel_id:?}",);
if subscribers.is_empty() {
subscriptions.remove(&channel_id);
last_unsubscribed.push(channel_id);
}
}
}
if !last_unsubscribed.is_empty() {
if let Some(context) = self.context.upgrade() {
context.unsubscribe_channels(self.sink_id, &last_unsubscribed);
}
}
}
pub(crate) async fn add_participant(
&self,
participant_id: ParticipantIdentity,
) -> Result<Arc<Participant>, RemoteAccessError> {
use crate::remote_access::participant::ParticipantWriter;
{
if let Some(existing_participant) = self.participants.read().get(&participant_id) {
return Ok(existing_participant.clone());
}
}
let stream = match self
.room
.local_participant()
.stream_bytes(StreamByteOptions {
topic: WS_PROTOCOL_TOPIC.to_string(),
destination_identities: vec![participant_id.clone()],
..StreamByteOptions::default()
})
.await
{
Ok(stream) => stream,
Err(e) => {
error!("failed to create stream for participant {participant_id}: {e:?}");
return Err(e.into());
}
};
let participant = Arc::new(Participant::new(
participant_id.clone(),
ParticipantWriter::Livekit(stream),
));
self.participants
.write()
.insert(participant_id, participant.clone());
Ok(participant)
}
pub(crate) fn remove_participant(&self, participant_id: &ParticipantIdentity) {
if self.participants.write().remove(participant_id).is_none() {
return;
}
info!("removed participant {participant_id:?}");
let mut last_unsubscribed: SmallVec<[ChannelId; 4]> = SmallVec::new();
{
let mut subscriptions = self.subscriptions.write();
subscriptions.retain(|&channel_id, subscribers| {
subscribers.retain(|id| id != participant_id);
if subscribers.is_empty() {
last_unsubscribed.push(channel_id);
false
} else {
true
}
});
}
if !last_unsubscribed.is_empty() {
if let Some(context) = self.context.upgrade() {
context.unsubscribe_channels(self.sink_id, &last_unsubscribed);
}
}
}
pub(crate) fn send_info_and_advertisements(
&self,
participant: Arc<Participant>,
server_info: ServerInfo,
) {
info!("sending server info and advertisements to participant {participant:?}");
let framed = frame_text_message(server_info.to_string().as_bytes());
self.send_control(participant.clone(), framed);
self.send_channel_advertisements(participant);
}
fn send_channel_advertisements(&self, participant: Arc<Participant>) {
let framed = {
let channels = self.channels.read();
if channels.is_empty() {
return;
}
let advertise_msg = advertise::advertise_channels(channels.values());
if advertise_msg.channels.is_empty() {
return;
}
frame_text_message(advertise_msg.to_string().as_bytes())
};
self.send_control(participant, framed);
}
}