mod collection;
mod registry;
pub(super) use registry::ParticipantRegistry;
use std::collections::HashSet;
use std::sync::Arc;
use bytes::Bytes;
use livekit::{
ByteStreamWriter, StreamWriter,
id::{ParticipantIdentity, ParticipantSid},
};
use tokio_util::sync::CancellationToken;
use crate::protocol::v2::server::FetchAssetResponse;
use crate::remote_access::RemoteAccessError;
use crate::remote_access::session::encode_binary_message;
use crate::remote_common::ClientId;
use crate::remote_common::semaphore::Semaphore;
/// Result type used by this module; errors are `Box<RemoteAccessError>`.
type Result<T> = std::result::Result<T, Box<RemoteAccessError>>;
/// Value passed to `Semaphore::new` for each participant's service-call semaphore.
const DEFAULT_SERVICE_CALLS_PER_PARTICIPANT: usize = 32;
/// Value passed to `Semaphore::new` for each participant's fetch-asset semaphore.
const DEFAULT_FETCH_ASSET_PER_PARTICIPANT: usize = 32;
/// Per-participant state: identity, the bounded control-message queue drained by a
/// flush task, reset signalling shared with the session, a cancellation token, and
/// per-participant semaphores.
// NOTE: field declaration order is also drop order in Rust; keep it stable.
pub(super) struct Participant {
    /// Id allocated with `ClientId::next()` when the participant is created.
    client_id: ClientId,
    /// LiveKit identity of the participant (used in log messages and `Display`).
    participant_id: ParticipantIdentity,
    /// LiveKit SID; this is the value inserted into `pending_resets` on failure.
    participant_sid: ParticipantSid,
    /// Join time supplied by the caller (0 in the test constructor).
    // NOTE(review): units/epoch are not visible here — confirm at the call site.
    joined_at: i64,
    /// Sender half of the bounded control queue; the flush task owns the receiver.
    control_tx: flume::Sender<Bytes>,
    /// Shared set of participant SIDs that have requested a reset.
    pending_resets: Arc<parking_lot::Mutex<HashSet<ParticipantSid>>>,
    /// Notified after a SID is inserted into `pending_resets`.
    reset_notify: Arc<tokio::sync::Notify>,
    /// Participant-level token (a child of the session token in `spawn`);
    /// cancelling it stops the flush task.
    cancel: CancellationToken,
    /// Semaphore created with `DEFAULT_SERVICE_CALLS_PER_PARTICIPANT`;
    /// presumably limits concurrent service calls — confirm at use sites.
    service_call_sem: Semaphore,
    /// Semaphore created with `DEFAULT_FETCH_ASSET_PER_PARTICIPANT`;
    /// presumably limits concurrent asset fetches — confirm at use sites.
    fetch_asset_sem: Semaphore,
}
impl Participant {
    /// Creates a participant and spawns the task that flushes its control queue.
    ///
    /// Messages passed to [`Participant::send_control`] are buffered in a bounded
    /// channel of `queue_size` entries and written to `writer` in FIFO order by the
    /// spawned task. The task stops when:
    /// - the participant is cancelled (its token is a child of `session_cancel`), or
    /// - every sender of the channel has been dropped, or
    /// - a write fails.
    ///
    /// A failed write takes the same reset path as a full queue in `send_control`.
    /// It cancels the participant, records `participant_sid` in `pending_resets`,
    /// and notifies `reset_notify`.
    ///
    /// Returns the shared participant together with the flush task's join handle.
    // Every argument is independent per-participant state supplied by the caller.
    #[allow(clippy::too_many_arguments)]
    pub fn spawn(
        identity: ParticipantIdentity,
        participant_sid: ParticipantSid,
        joined_at: i64,
        writer: ParticipantWriter,
        queue_size: usize,
        pending_resets: Arc<parking_lot::Mutex<HashSet<ParticipantSid>>>,
        reset_notify: Arc<tokio::sync::Notify>,
        session_cancel: &CancellationToken,
    ) -> (Arc<Self>, tokio::task::JoinHandle<()>) {
        let (control_tx, control_rx) = flume::bounded::<Bytes>(queue_size);
        let cancel = session_cancel.child_token();
        let client_id = ClientId::next();

        // Clones moved into the flush task; the originals are stored in the struct.
        let cancel_for_task = cancel.clone();
        let identity_for_task = identity.clone();
        let sid_for_task = participant_sid.clone();
        let pending_resets_for_task = Arc::clone(&pending_resets);
        let reset_notify_for_task = Arc::clone(&reset_notify);

        let flush_handle = tokio::spawn(async move {
            loop {
                // `biased` polls cancellation first, so a cancelled participant
                // never starts another write.
                let data = tokio::select! {
                    biased;
                    () = cancel_for_task.cancelled() => break,
                    msg = control_rx.recv_async() => match msg {
                        Ok(data) => data,
                        // All senders dropped: nothing more will arrive.
                        Err(_) => break,
                    },
                };
                // Abandon an in-flight write as soon as the participant is cancelled.
                let write_result = tokio::select! {
                    biased;
                    () = cancel_for_task.cancelled() => break,
                    result = writer.write(&data) => result,
                };
                if let Err(e) = write_result {
                    tracing::warn!(
                        "control write failed for {:?}, requesting reset: {e:?}",
                        identity_for_task,
                    );
                    Self::request_reset(
                        &cancel_for_task,
                        &pending_resets_for_task,
                        &reset_notify_for_task,
                        sid_for_task,
                    );
                    break;
                }
            }
        });

        let participant = Arc::new(Self {
            client_id,
            participant_id: identity,
            participant_sid,
            joined_at,
            control_tx,
            pending_resets,
            reset_notify,
            cancel,
            service_call_sem: Semaphore::new(DEFAULT_SERVICE_CALLS_PER_PARTICIPANT),
            fetch_asset_sem: Semaphore::new(DEFAULT_FETCH_ASSET_PER_PARTICIPANT),
        });
        (participant, flush_handle)
    }

    /// Test-only constructor that wires a participant to a caller-owned control
    /// sender without spawning a flush task. `joined_at` is fixed to 0.
    #[cfg(test)]
    pub fn new(
        identity: ParticipantIdentity,
        participant_sid: ParticipantSid,
        control_tx: flume::Sender<Bytes>,
        pending_resets: Arc<parking_lot::Mutex<HashSet<ParticipantSid>>>,
        reset_notify: Arc<tokio::sync::Notify>,
        cancel: CancellationToken,
    ) -> Self {
        Self {
            client_id: ClientId::next(),
            participant_id: identity,
            participant_sid,
            joined_at: 0,
            control_tx,
            pending_resets,
            reset_notify,
            cancel,
            service_call_sem: Semaphore::new(DEFAULT_SERVICE_CALLS_PER_PARTICIPANT),
            fetch_asset_sem: Semaphore::new(DEFAULT_FETCH_ASSET_PER_PARTICIPANT),
        }
    }

    /// Cancels the participant and flags it for a reset.
    ///
    /// Shared by both control-delivery failure paths (full queue in
    /// [`Participant::send_control`], failed write in the flush task), so they
    /// signal identically. It cancels `cancel`, inserts `sid` into
    /// `pending_resets`, then notifies `reset_notify`.
    fn request_reset(
        cancel: &CancellationToken,
        pending_resets: &parking_lot::Mutex<HashSet<ParticipantSid>>,
        reset_notify: &tokio::sync::Notify,
        sid: ParticipantSid,
    ) {
        cancel.cancel();
        pending_resets.lock().insert(sid);
        reset_notify.notify_one();
    }

    /// Id allocated with `ClientId::next()` when this participant was created.
    pub fn client_id(&self) -> ClientId {
        self.client_id
    }

    /// Semaphore created with `DEFAULT_SERVICE_CALLS_PER_PARTICIPANT`.
    pub fn service_call_sem(&self) -> &Semaphore {
        &self.service_call_sem
    }

    /// Semaphore created with `DEFAULT_FETCH_ASSET_PER_PARTICIPANT`.
    pub fn fetch_asset_sem(&self) -> &Semaphore {
        &self.fetch_asset_sem
    }

    /// Cancels this participant's token, which stops its flush task.
    pub(super) fn cancel(&self) {
        self.cancel.cancel();
    }

    /// LiveKit identity of this participant.
    pub fn participant_id(&self) -> &ParticipantIdentity {
        &self.participant_id
    }

    /// LiveKit SID of this participant.
    pub(super) fn participant_sid(&self) -> &ParticipantSid {
        &self.participant_sid
    }

    /// Join time supplied at construction.
    pub(super) fn joined_at(&self) -> i64 {
        self.joined_at
    }

    /// Tries to enqueue `data` on the control queue without waiting.
    ///
    /// Returns `false` only when the queue is full. If the queue is disconnected
    /// (the flush task has exited), the message is dropped and `true` is returned
    /// on purpose: the participant is already shutting down, so it must not be
    /// flagged for another reset.
    #[must_use]
    pub(super) fn try_queue_control(&self, data: Bytes) -> bool {
        match self.control_tx.try_send(data) {
            Ok(()) => true,
            Err(flume::TrySendError::Full(_)) => {
                tracing::warn!("control queue full for {}", self.participant_id);
                false
            }
            Err(flume::TrySendError::Disconnected(_)) => {
                tracing::debug!(
                    "control queue disconnected for {}, dropping message",
                    self.participant_id
                );
                true
            }
        }
    }

    /// Enqueues a control message. If the queue is full, the participant is
    /// cancelled and flagged for a reset instead of blocking.
    pub(super) fn send_control(&self, data: Bytes) {
        if !self.try_queue_control(data) {
            Self::request_reset(
                &self.cancel,
                &self.pending_resets,
                &self.reset_notify,
                self.participant_sid.clone(),
            );
        }
    }

    /// Sends a successful `FetchAssetResponse` carrying `data` for `request_id`.
    pub(super) fn send_asset_response(&self, data: &[u8], request_id: u32) {
        self.send_control(encode_binary_message(&FetchAssetResponse::asset_data(
            request_id, data,
        )));
    }

    /// Sends an error `FetchAssetResponse` carrying `error` for `request_id`.
    pub(super) fn send_asset_error(&self, error: &str, request_id: u32) {
        self.send_control(encode_binary_message(&FetchAssetResponse::error_message(
            request_id, error,
        )));
    }
}
impl std::fmt::Debug for Participant {
    /// Renders `Participant { identity: .. }`; only the identity is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Participant");
        builder.field("identity", &self.participant_id);
        builder.finish()
    }
}
impl std::fmt::Display for Participant {
    /// Renders `Participant(<identity>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("Participant({})", self.participant_id))
    }
}
/// Sink that a participant's flush task writes control bytes to.
pub(super) enum ParticipantWriter {
    /// Writes go to a LiveKit `ByteStreamWriter`.
    Livekit(ByteStreamWriter),
    /// Test sink that records every write in memory.
    #[cfg(test)]
    #[allow(dead_code)]
    Test(Arc<TestByteStreamWriter>),
}
impl ParticipantWriter {
    /// Writes `bytes` to the underlying sink.
    ///
    /// LiveKit stream errors are converted into this module's boxed
    /// `RemoteAccessError`. The test sink records the bytes and always succeeds.
    async fn write(&self, bytes: &[u8]) -> Result<()> {
        match self {
            Self::Livekit(stream) => {
                let outcome = stream.write(bytes).await;
                outcome.map_err(Into::into)
            }
            #[cfg(test)]
            Self::Test(writer) => {
                writer.record(bytes);
                Ok(())
            }
        }
    }
}
/// Builds a `ParticipantSid` of the form `PA_<label>` for tests.
///
/// # Panics
/// Panics if `ParticipantSid::try_from` rejects the resulting string.
#[cfg(test)]
pub(super) fn test_sid(label: &str) -> ParticipantSid {
    let raw = format!("PA_{label}");
    ParticipantSid::try_from(raw).expect("test_sid label should form a valid ParticipantSid")
}
#[cfg(test)]
/// In-memory test sink: each write is copied and appended to `writes`.
#[derive(Default)]
pub(super) struct TestByteStreamWriter {
    /// Recorded writes, in the order they were appended.
    writes: parking_lot::Mutex<Vec<Bytes>>,
}
#[cfg(test)]
impl TestByteStreamWriter {
    /// Appends an owned copy of `data` to the recorded writes.
    fn record(&self, data: &[u8]) {
        let chunk = Bytes::copy_from_slice(data);
        self.writes.lock().push(chunk);
    }

    /// Drains and returns every recorded write, leaving the log empty.
    #[allow(dead_code)]
    pub(super) fn writes(&self) -> Vec<Bytes> {
        let mut guard = self.writes.lock();
        std::mem::take(&mut *guard)
    }
}