use std::{
cell::RefCell,
collections::{HashMap, HashSet},
rc::Rc,
};
use derive_more::with_trait::From;
use futures::{
FutureExt as _, TryFutureExt as _, future, future::LocalBoxFuture,
stream::LocalBoxStream,
};
use medea_client_api_proto::TrackId;
use medea_reactive::{AllProcessed, Guarded, ProgressableHashMap};
use tracerr::Traced;
use super::sender;
use crate::{
media::LocalTracksConstraints,
peer::UpdateLocalStreamError,
utils::{AsProtoState, SynchronizableState, Updatable},
};
/// Reactive repository of track states, keyed by [`TrackId`].
///
/// Every state is stored as an `Rc<S>` inside a [`ProgressableHashMap`],
/// which is wrapped in a [`RefCell`] so the repository can be mutated through
/// a shared reference.
#[derive(Debug, From)]
pub struct TracksRepository<S>(RefCell<ProgressableHashMap<TrackId, Rc<S>>>)
where
    S: 'static;
impl<S> TracksRepository<S> {
    /// Creates a new empty [`TracksRepository`].
    #[must_use]
    pub fn new() -> Self {
        let tracks = ProgressableHashMap::new();
        Self(RefCell::new(tracks))
    }

    /// Returns the [`AllProcessed`] of the underlying
    /// [`ProgressableHashMap`] (see its `when_all_processed()`).
    pub fn when_all_processed(&self) -> AllProcessed<'static> {
        let tracks = self.0.borrow();
        tracks.when_all_processed()
    }

    /// Stores the provided `track` under the provided `id`.
    ///
    /// A previously stored state with the same `id` (if any) is dropped.
    pub fn insert(&self, id: TrackId, track: Rc<S>) {
        let mut tracks = self.0.borrow_mut();
        drop(tracks.insert(id, track));
    }

    /// Returns a clone of the `Rc` stored under the provided `id`, if any.
    #[must_use]
    pub fn get(&self, id: TrackId) -> Option<Rc<S>> {
        let tracks = self.0.borrow();
        tracks.get(&id).cloned()
    }

    /// Returns the [`TrackId`]s of all the stored tracks.
    pub fn ids(&self) -> Vec<TrackId> {
        let tracks = self.0.borrow();
        tracks.iter().map(|(track_id, _)| *track_id).collect()
    }

    /// Returns a stream of track insertions.
    ///
    /// Uses `on_insert_with_replay()` of the underlying map, so (judging by
    /// the name) already stored tracks are emitted as well.
    pub fn on_insert(
        &self,
    ) -> LocalBoxStream<'static, Guarded<(TrackId, Rc<S>)>> {
        let tracks = self.0.borrow();
        tracks.on_insert_with_replay()
    }

    /// Returns a stream of track removals.
    pub fn on_remove(
        &self,
    ) -> LocalBoxStream<'static, Guarded<(TrackId, Rc<S>)>> {
        let tracks = self.0.borrow();
        tracks.on_remove()
    }

    /// Removes the track stored under the provided `id`.
    ///
    /// Returns `true` if such a track was present.
    pub fn remove(&self, id: TrackId) -> bool {
        self.0.borrow_mut().remove(&id).is_some()
    }
}
impl TracksRepository<sender::State> {
    /// Returns all the stored [`sender::State`]s for which
    /// `is_local_stream_update_needed()` returns `true`.
    #[must_use]
    pub fn get_outdated(&self) -> Vec<Rc<sender::State>> {
        let senders = self.0.borrow();
        senders
            .values()
            .filter(|sender| sender.is_local_stream_update_needed())
            .cloned()
            .collect()
    }

    /// Returns a future that joins the `local_stream_update_result()` futures
    /// of the senders with the provided `tracks_ids`.
    ///
    /// IDs without a stored sender are skipped. The returned future resolves
    /// to `Ok(())` when every joined future succeeds. Otherwise it resolves
    /// to the first error, wrapped into [`UpdateLocalStreamError`].
    pub fn local_stream_update_result(
        &self,
        tracks_ids: HashSet<TrackId>,
    ) -> LocalBoxFuture<'static, Result<(), Traced<UpdateLocalStreamError>>>
    {
        let senders = self.0.borrow();
        let update_results =
            tracks_ids.into_iter().filter_map(|track_id| {
                senders.get(&track_id).map(|sender| {
                    sender
                        .local_stream_update_result()
                        .map_err(tracerr::map_from_and_wrap!())
                })
            });
        // `try_join_all()` consumes the iterator right away, so every
        // per-sender future is created while `senders` is still borrowed.
        let joined = future::try_join_all(update_results);
        Box::pin(joined.map(|outcome| outcome.map(drop)))
    }
}
impl<S> SynchronizableState for TracksRepository<S>
where
    S: SynchronizableState,
{
    type Input = HashMap<TrackId, S::Input>;

    /// Builds a [`TracksRepository`] from `input`.
    ///
    /// Each entry is converted with `S::from_proto()` using the provided
    /// `send_constraints`.
    fn from_proto(
        input: Self::Input,
        send_constraints: &LocalTracksConstraints,
    ) -> Self {
        Self(RefCell::new(
            input
                .into_iter()
                .map(|(id, t)| {
                    (id, Rc::new(S::from_proto(t, send_constraints)))
                })
                .collect(),
        ))
    }

    /// Synchronizes this repository with the provided `input`.
    ///
    /// Tracks missing from `input` are removed via `remove_not_present()`.
    /// Tracks that are already stored get `S::apply()`. Tracks that are not
    /// stored yet are created with `S::from_proto()` and inserted.
    fn apply(&self, input: Self::Input, send_cons: &LocalTracksConstraints) {
        self.0.borrow_mut().remove_not_present(&input);

        #[expect(clippy::iter_over_hash_type, reason = "order doesn't matter")]
        for (id, track) in input {
            // Clone the `Rc` out in a separate statement, so the `Ref` guard
            // is dropped before either branch runs. With the guard in the
            // `if let` scrutinee, it lives through the `else` branch on
            // editions before 2024, which makes the `borrow_mut()` below
            // panic with `BorrowMutError`. In every edition it also stays
            // held while `S::apply()` runs.
            let existing = self.0.borrow().get(&id).cloned();
            if let Some(sync_track) = existing {
                sync_track.apply(track, send_cons);
            } else {
                // Build the new state before taking the mutable borrow, so
                // `S::from_proto()` never runs while the map is locked.
                let new_track = Rc::new(S::from_proto(track, send_cons));
                drop(self.0.borrow_mut().insert(id, new_track));
            }
        }
    }
}
impl<S> Updatable for TracksRepository<S>
where
    S: Updatable,
{
    /// Returns an [`AllProcessed`] that combines the `when_stabilized()` of
    /// every stored track.
    fn when_stabilized(&self) -> AllProcessed<'static> {
        let tracks = self.0.borrow();
        let pending: Vec<_> =
            tracks.values().map(|track| track.when_stabilized().into()).collect();
        // Release the borrow before combining, as the original did.
        drop(tracks);
        medea_reactive::when_all_processed(pending)
    }

    /// Returns an [`AllProcessed`] that combines the `when_updated()` of every
    /// stored track.
    fn when_updated(&self) -> AllProcessed<'static> {
        let tracks = self.0.borrow();
        let pending: Vec<_> =
            tracks.values().map(|track| track.when_updated().into()).collect();
        drop(tracks);
        medea_reactive::when_all_processed(pending)
    }

    /// Calls `connection_lost()` on every stored track.
    fn connection_lost(&self) {
        let tracks = self.0.borrow();
        tracks.values().for_each(|track| track.connection_lost());
    }

    /// Calls `connection_recovered()` on every stored track.
    fn connection_recovered(&self) {
        let tracks = self.0.borrow();
        tracks.values().for_each(|track| track.connection_recovered());
    }
}
impl<S> AsProtoState for TracksRepository<S>
where
    S: AsProtoState,
{
    type Output = HashMap<TrackId, S::Output>;

    /// Converts every stored track with `S::as_proto()`, keeping its
    /// [`TrackId`].
    fn as_proto(&self) -> Self::Output {
        let tracks = self.0.borrow();
        tracks
            .iter()
            .map(|(track_id, track)| (*track_id, track.as_proto()))
            .collect()
    }
}
#[cfg(feature = "mockable")]
impl<S> TracksRepository<S> {
    /// Returns the `Processed` from the underlying map's
    /// `when_insert_processed()`.
    pub fn when_insert_processed(&self) -> medea_reactive::Processed<'static> {
        let tracks = self.0.borrow();
        tracks.when_insert_processed()
    }
}
#[cfg(feature = "mockable")]
#[expect(clippy::multiple_inherent_impl, reason = "feature gated")]
impl TracksRepository<sender::State> {
    /// Calls `synced()` on every stored [`sender::State`].
    pub fn synced(&self) {
        let senders = self.0.borrow();
        senders.values().for_each(|sender| sender.synced());
    }
}
#[cfg(feature = "mockable")]
impl TracksRepository<super::receiver::State> {
    /// Calls `stabilize()` on every stored receiver state.
    pub fn stabilize_all(&self) {
        let receivers = self.0.borrow();
        receivers.values().for_each(|receiver| receiver.stabilize());
    }

    /// Calls `synced()` on every stored receiver state.
    pub fn synced(&self) {
        let receivers = self.0.borrow();
        receivers.values().for_each(|receiver| receiver.synced());
    }
}