1use std::{
4 cell::{Cell, RefCell},
5 collections::{HashMap, HashSet},
6 rc::{Rc, Weak},
7};
8
9use derive_more::with_trait::{Display, From};
10use futures::{
11 FutureExt as _, StreamExt as _, future, future::LocalBoxFuture,
12 stream::LocalBoxStream,
13};
14use medea_client_api_proto::{
15 self as proto, ConnectionMode, ConnectionQualityScore, MemberId,
16 PeerConnectionState, PeerId, TrackId,
17};
18use tracerr::Traced;
19
20use crate::{
21 api,
22 media::{MediaKind, MediaSourceKind, RecvConstraints, track::remote},
23 peer::{
24 MediaState, MediaStateControllable as _, ProhibitedStateError,
25 TransceiverSide as _, media_exchange_state, receiver,
26 },
27 platform,
28 utils::{Caused, TaskHandle},
29};
30
/// Error of changing the [`MediaState`] of remote media through a
/// [`ConnectionHandleImpl`].
#[derive(Caused, Clone, Copy, Debug, Display, From)]
#[cause(error = platform::Error)]
pub enum ChangeMediaStateError {
    /// The [`ConnectionHandleImpl`] could not upgrade its weak reference,
    /// i.e. the underlying [`Connection`] has already been dropped.
    #[display("`ConnectionHandle` is in detached state")]
    Detached,

    /// While waiting for the requested state to become stable, a receiver
    /// ended up in the contained (opposite) [`MediaState`] instead.
    // NOTE(review): presumably produced from the error of
    // `when_media_state_stable()`; confirm against `receiver::State`.
    #[display(
        "`MediaState` transits to opposite ({_0}) of the requested \
         `MediaExchangeState`"
    )]
    TransitionIntoOppositeState(MediaState),

    /// A receiver rejected the requested [`MediaState`] transition.
    ProhibitedState(ProhibitedStateError),
}
57
/// Result of changing the [`MediaState`] of remote media through a
/// [`ConnectionHandleImpl`].
type ChangeMediaStateResult = Result<(), Traced<ChangeMediaStateError>>;
60
/// Registry of [`Connection`]s with remote `Member`s, indexed both by track
/// and by `Member`.
#[derive(Debug)]
pub struct Connections {
    /// Remote `Member`s partnered with each [`TrackId`].
    tracks_to_members: RefCell<HashMap<TrackId, HashSet<MemberId>>>,

    /// [`TrackId`]s associated with each remote `Member`.
    ///
    /// A `Member`'s [`Connection`] is removed once a track removal leaves
    /// this set empty.
    members_to_tracks: RefCell<HashMap<MemberId, HashSet<TrackId>>>,

    /// [`Connection`] with each remote `Member`.
    members_to_conns: RefCell<HashMap<MemberId, Connection>>,

    /// Room-wide [`RecvConstraints`]. Every newly created [`Connection`]
    /// starts from a copy of them.
    room_recv_constraints: Rc<RecvConstraints>,

    /// Callback invoked with a handle of every newly created [`Connection`].
    on_new_connection: platform::Callback<api::ConnectionHandle>,
}
79
80impl Connections {
81 pub fn new(room_recv_constraints: Rc<RecvConstraints>) -> Self {
83 Self {
84 tracks_to_members: RefCell::default(),
85 members_to_tracks: RefCell::default(),
86 members_to_conns: RefCell::default(),
87 room_recv_constraints,
88 on_new_connection: platform::Callback::default(),
89 }
90 }
91
92 pub fn on_new_connection(
95 &self,
96 f: platform::Function<api::ConnectionHandle>,
97 ) {
98 self.on_new_connection.set_func(f);
99 }
100
101 #[must_use]
109 pub fn update_connections(
110 &self,
111 track_id: &TrackId,
112 partner_members: HashSet<MemberId>,
113 connection_mode: ConnectionMode,
114 ) -> Vec<Connection> {
115 if let Some(partners) =
116 self.tracks_to_members.borrow_mut().get_mut(track_id)
117 {
118 let mut connections = self.members_to_conns.borrow_mut();
119 let mut members_to_tracks = self.members_to_tracks.borrow_mut();
120
121 if partners == &partner_members {
123 return partners
124 .iter()
125 .filter_map(|partner| {
126 _ = members_to_tracks
127 .get_mut(partner)
128 .map(|tracks| tracks.insert(*track_id));
129 connections.get(partner).cloned()
130 })
131 .collect();
132 }
133
134 let added: Vec<_> =
136 partner_members.difference(partners).cloned().collect();
137 for mid in added {
138 _ = members_to_tracks
139 .entry(mid.clone())
140 .or_default()
141 .insert(*track_id);
142
143 if !connections.contains_key(&mid) {
144 let connection = Connection::new(
145 mid.clone(),
146 &self.room_recv_constraints,
147 connection_mode,
148 );
149 self.on_new_connection.call1(connection.new_handle());
150 drop(connections.insert(mid.clone(), connection));
151 }
152 _ = partners.insert(mid);
153 }
154
155 partners.retain(|partner| {
157 let to_remove = !partner_members.contains(partner);
158
159 if to_remove {
160 if let Some(tracks) = members_to_tracks.get_mut(partner) {
161 _ = tracks.remove(track_id);
162
163 if tracks.is_empty() {
164 _ = connections
165 .remove(partner)
166 .map(|conn| conn.0.on_close.call0());
167 }
168 }
169 }
170
171 !to_remove
172 });
173
174 return partner_members
175 .into_iter()
176 .filter_map(|partner| connections.get(&partner).cloned())
177 .collect();
178 }
179
180 self.add_connections(*track_id, &partner_members, connection_mode)
181 }
182
183 #[must_use]
190 fn add_connections(
191 &self,
192 track_id: TrackId,
193 partner_members: &HashSet<MemberId>,
194 connection_mode: ConnectionMode,
195 ) -> Vec<Connection> {
196 let mut connections = self.members_to_conns.borrow_mut();
197
198 #[expect(clippy::iter_over_hash_type, reason = "order doesn't matter")]
199 for partner in partner_members {
200 _ = self
201 .members_to_tracks
202 .borrow_mut()
203 .entry(partner.clone())
204 .or_default()
205 .insert(track_id);
206 if !connections.contains_key(partner) {
207 let connection = Connection::new(
208 partner.clone(),
209 &self.room_recv_constraints,
210 connection_mode,
211 );
212 self.on_new_connection.call1(connection.new_handle());
213 drop(connections.insert(partner.clone(), connection));
214 }
215 }
216
217 drop(
218 self.tracks_to_members.borrow_mut().insert(
219 track_id,
220 partner_members.clone().into_iter().collect(),
221 ),
222 );
223
224 partner_members
225 .iter()
226 .filter_map(|p| connections.get(p).cloned())
227 .collect()
228 }
229
230 pub fn remove_track(&self, track_id: &TrackId) {
235 let mut tracks = self.tracks_to_members.borrow_mut();
236
237 if let Some(partners) = tracks.remove(track_id) {
238 #[expect(clippy::iter_over_hash_type, reason = "doesn't matter")]
239 for p in partners {
240 if let Some(member_tracks) =
241 self.members_to_tracks.borrow_mut().get_mut(&p)
242 {
243 _ = member_tracks.remove(track_id);
244
245 if member_tracks.is_empty() {
246 _ = self
247 .members_to_conns
248 .borrow_mut()
249 .remove(&p)
250 .map(|conn| conn.0.on_close.call0());
251 }
252 }
253 }
254 }
255 }
256
257 #[must_use]
259 pub fn get(&self, remote_member_id: &MemberId) -> Option<Connection> {
260 self.members_to_conns.borrow().get(remote_member_id).cloned()
261 }
262
263 pub fn iter_by_track(
265 &self,
266 track_id: &TrackId,
267 ) -> impl Iterator<Item = Connection> + use<'_> {
268 self.tracks_to_members
269 .borrow()
270 .get(track_id)
271 .cloned()
272 .into_iter()
273 .flat_map(|member_ids| {
274 member_ids.into_iter().filter_map(|member_id| {
275 self.members_to_conns.borrow().get(&member_id).cloned()
276 })
277 })
278 }
279
280 pub fn apply(&self, new_state: &proto::state::Room) {
282 #[expect(clippy::iter_over_hash_type, reason = "order doesn't matter")]
283 for peer in new_state.peers.values() {
284 for (track_id, sender) in &peer.senders {
285 if let Some(partners) =
286 self.tracks_to_members.borrow().get(track_id)
287 {
288 for member in partners {
289 if let Some(member_tracks) =
290 self.members_to_tracks.borrow_mut().get_mut(member)
291 {
292 if !sender.receivers.contains(member) {
293 _ = member_tracks.remove(track_id);
294 }
295 }
296 }
297 }
298 }
299 }
300 }
301}
302
/// Error of a [`ConnectionHandleImpl`] whose [`Connection`] has already been
/// dropped, so its weak reference can't be upgraded.
#[derive(Caused, Clone, Copy, Debug, Display)]
#[cause(error = platform::Error)]
#[display("`ConnectionHandle` is in detached state")]
pub struct HandleDetachedError;
308
/// Handle to a [`Connection`].
///
/// Holds only a [`Weak`] reference, so it doesn't keep the [`Connection`]
/// alive. Its methods fail once the [`Connection`] is dropped.
#[derive(Clone, Debug)]
pub struct ConnectionHandleImpl(Weak<InnerConnection>);
314
/// Connection quality score as reported to the client.
///
/// Variant declaration order defines the derived [`Ord`]: `Disconnected`
/// compares lower than any `Connected` score.
#[derive(Clone, Copy, Debug, Display, Eq, From, Ord, PartialEq, PartialOrd)]
pub enum ClientConnectionQualityScore {
    /// Some peer of the connection is disconnected, failed or closed.
    Disconnected,

    /// All peers are connected, with the contained quality score.
    Connected(ConnectionQualityScore),
}
324
325impl ClientConnectionQualityScore {
326 #[must_use]
328 pub const fn into_u8(self) -> u8 {
329 match self {
330 Self::Disconnected => 0,
331 #[expect(clippy::as_conversions, reason = "needs refactoring")]
333 Self::Connected(score) => score as u8,
334 }
335 }
336}
337
/// State of the connection with a remote `Member`.
#[derive(Clone, Copy, Debug, Eq, From, PartialEq)]
pub enum MemberConnectionState {
    /// State of the peer connection with the `Member`, reported in
    /// [`ConnectionMode::Mesh`].
    P2P(PeerConnectionState),
}
346
/// Shared state of a [`Connection`].
///
/// Referenced strongly by [`Connection`] and weakly by
/// [`ConnectionHandleImpl`].
#[derive(Debug)]
struct InnerConnection {
    /// ID of the remote `Member` this connection is established with.
    remote_id: MemberId,

    /// Last [`ConnectionQualityScore`] passed to
    /// `Connection::update_quality_score()`, if any.
    quality_score: Cell<Option<ConnectionQualityScore>>,

    /// Last calculated [`ClientConnectionQualityScore`], if any.
    client_quality_score: Cell<Option<ClientConnectionQualityScore>>,

    /// Last known [`PeerConnectionState`] of each peer of this connection.
    peer_states: RefCell<HashMap<PeerId, PeerConnectionState>>,

    /// Callback invoked with every remote track added to this connection.
    on_remote_track_added: platform::Callback<api::RemoteMediaTrack>,

    /// Connection-local copy of the room-wide [`RecvConstraints`].
    recv_constraints: Rc<RecvConstraints>,

    /// Receivers registered via `Connection::add_receiver()`.
    receivers: RefCell<Vec<Rc<receiver::State>>>,

    /// Callback invoked with the numeric [`ClientConnectionQualityScore`]
    /// whenever it changes.
    on_quality_score_update: platform::Callback<u8>,

    /// Callback invoked with the new [`MemberConnectionState`] on peer state
    /// changes (only in [`ConnectionMode::Mesh`]).
    on_state_change: platform::Callback<api::MemberConnectionState>,

    /// [`ConnectionMode`] this connection operates in.
    connection_mode: ConnectionMode,

    /// Callback invoked when this connection is closed.
    on_close: platform::Callback<()>,

    /// Handles of the tasks keeping `recv_constraints` in sync with the
    /// room-wide ones. They are stored here so their lifetime is tied to
    /// this connection.
    // NOTE(review): presumably `TaskHandle` aborts its task on drop —
    // confirm in `utils`.
    _task_handles: Vec<TaskHandle>,
}
397
398impl InnerConnection {
399 async fn change_media_state(
408 &self,
409 desired_state: MediaState,
410 kind: MediaKind,
411 source_kind: Option<MediaSourceKind>,
412 ) -> ChangeMediaStateResult {
413 let receivers = self.receivers.borrow().clone();
414 let mut change_tasks = Vec::new();
415 for r in receivers {
416 let source_filter =
417 source_kind.is_none_or(|skind| skind == r.source_kind().into());
418
419 if r.is_subscription_needed(desired_state)
420 && r.kind() == kind
421 && source_filter
422 {
423 r.media_state_transition_to(desired_state)
424 .map_err(tracerr::map_from_and_wrap!())?;
425 change_tasks.push(r.when_media_state_stable(desired_state));
426 }
427 }
428
429 drop(
430 future::try_join_all(change_tasks)
431 .await
432 .map_err(tracerr::from_and_wrap!())?,
433 );
434
435 if let MediaState::MediaExchange(desired_state) = desired_state {
436 self.recv_constraints.set_enabled(
437 desired_state == media_exchange_state::Stable::Enabled,
438 kind,
439 source_kind.map(Into::into),
440 );
441 }
442
443 Ok(())
444 }
445}
446
447impl ConnectionHandleImpl {
448 pub fn on_close(
454 &self,
455 f: platform::Function<()>,
456 ) -> Result<(), Traced<HandleDetachedError>> {
457 self.0
458 .upgrade()
459 .ok_or_else(|| tracerr::new!(HandleDetachedError))
460 .map(|inner| inner.on_close.set_func(f))
461 }
462
463 pub fn get_remote_member_id(
469 &self,
470 ) -> Result<String, Traced<HandleDetachedError>> {
471 self.0
472 .upgrade()
473 .ok_or_else(|| tracerr::new!(HandleDetachedError))
474 .map(|inner| inner.remote_id.0.clone())
475 }
476
477 pub fn get_state(
483 &self,
484 ) -> Result<Option<MemberConnectionState>, Traced<HandleDetachedError>>
485 {
486 self.0.upgrade().ok_or_else(|| tracerr::new!(HandleDetachedError)).map(
490 |inner| {
491 (inner.connection_mode == ConnectionMode::Mesh)
492 .then(|| {
493 inner
494 .peer_states
495 .borrow()
496 .values()
497 .next()
498 .map(|&s| MemberConnectionState::P2P(s))
499 })
500 .flatten()
501 },
502 )
503 }
504
505 pub fn on_state_change(
512 &self,
513 f: platform::Function<api::MemberConnectionState>,
514 ) -> Result<(), Traced<HandleDetachedError>> {
515 self.0
516 .upgrade()
517 .ok_or_else(|| tracerr::new!(HandleDetachedError))
518 .map(|inner| inner.on_state_change.set_func(f))
519 }
520
521 pub fn on_remote_track_added(
528 &self,
529 f: platform::Function<api::RemoteMediaTrack>,
530 ) -> Result<(), Traced<HandleDetachedError>> {
531 self.0
532 .upgrade()
533 .ok_or_else(|| tracerr::new!(HandleDetachedError))
534 .map(|inner| inner.on_remote_track_added.set_func(f))
535 }
536
537 pub fn on_quality_score_update(
544 &self,
545 f: platform::Function<u8>,
546 ) -> Result<(), Traced<HandleDetachedError>> {
547 self.0
548 .upgrade()
549 .ok_or_else(|| tracerr::new!(HandleDetachedError))
550 .map(|inner| inner.on_quality_score_update.set_func(f))
551 }
552
553 pub fn enable_remote_video(
564 &self,
565 source_kind: Option<MediaSourceKind>,
566 ) -> impl Future<Output = ChangeMediaStateResult> + 'static + use<> {
567 self.change_media_state(
568 media_exchange_state::Stable::Enabled.into(),
569 MediaKind::Video,
570 source_kind,
571 )
572 }
573
574 pub fn disable_remote_video(
585 &self,
586 source_kind: Option<MediaSourceKind>,
587 ) -> impl Future<Output = ChangeMediaStateResult> + 'static + use<> {
588 self.change_media_state(
589 media_exchange_state::Stable::Disabled.into(),
590 MediaKind::Video,
591 source_kind,
592 )
593 }
594
595 pub fn enable_remote_audio(
606 &self,
607 ) -> impl Future<Output = ChangeMediaStateResult> + 'static + use<> {
608 self.change_media_state(
609 media_exchange_state::Stable::Enabled.into(),
610 MediaKind::Audio,
611 None,
612 )
613 }
614
615 pub fn disable_remote_audio(
626 &self,
627 ) -> impl Future<Output = ChangeMediaStateResult> + 'static + use<> {
628 self.change_media_state(
629 media_exchange_state::Stable::Disabled.into(),
630 MediaKind::Audio,
631 None,
632 )
633 }
634
635 fn change_media_state(
641 &self,
642 desired_state: MediaState,
643 kind: MediaKind,
644 source_kind: Option<MediaSourceKind>,
645 ) -> LocalBoxFuture<'static, ChangeMediaStateResult> {
646 let inner = self
647 .0
648 .upgrade()
649 .ok_or_else(|| tracerr::new!(ChangeMediaStateError::Detached));
650 let inner = match inner {
651 Ok(inner) => inner,
652 Err(e) => return Box::pin(future::err(e)),
653 };
654
655 Box::pin(async move {
656 inner.change_media_state(desired_state, kind, source_kind).await
657 })
658 }
659}
660
/// Connection with a remote `Member`.
///
/// Cloning is cheap: clones share the same inner state through an [`Rc`].
#[derive(Clone, Debug)]
pub struct Connection(Rc<InnerConnection>);
664
665impl Connection {
666 #[must_use]
671 pub fn new(
672 remote_id: MemberId,
673 room_recv_constraints: &Rc<RecvConstraints>,
674 connection_mode: ConnectionMode,
675 ) -> Self {
676 let recv_constraints = Rc::new(room_recv_constraints.as_ref().clone());
678
679 Self(Rc::new(InnerConnection {
680 _task_handles: vec![
681 Self::spawn_constraints_synchronizer(
682 Rc::clone(&recv_constraints),
683 room_recv_constraints.on_video_device_enabled_change(),
684 MediaKind::Video,
685 MediaSourceKind::Device,
686 ),
687 Self::spawn_constraints_synchronizer(
688 Rc::clone(&recv_constraints),
689 room_recv_constraints.on_video_display_enabled_change(),
690 MediaKind::Video,
691 MediaSourceKind::Display,
692 ),
693 Self::spawn_constraints_synchronizer(
694 Rc::clone(&recv_constraints),
695 room_recv_constraints.on_audio_device_enabled_change(),
696 MediaKind::Audio,
697 MediaSourceKind::Device,
698 ),
699 Self::spawn_constraints_synchronizer(
700 Rc::clone(&recv_constraints),
701 room_recv_constraints.on_audio_display_enabled_change(),
702 MediaKind::Audio,
703 MediaSourceKind::Display,
704 ),
705 ],
706 remote_id,
707 quality_score: Cell::default(),
708 client_quality_score: Cell::default(),
709 peer_states: RefCell::default(),
710 on_quality_score_update: platform::Callback::default(),
711 on_state_change: platform::Callback::default(),
712 recv_constraints,
713 connection_mode,
714 on_close: platform::Callback::default(),
715 on_remote_track_added: platform::Callback::default(),
716 receivers: RefCell::default(),
717 }))
718 }
719
720 fn spawn_constraints_synchronizer(
727 recv_constraints: Rc<RecvConstraints>,
728 mut changes_stream: LocalBoxStream<'static, bool>,
729 kind: MediaKind,
730 source_kind: MediaSourceKind,
731 ) -> TaskHandle {
732 let (fut, abort) = future::abortable(async move {
733 while let Some(is_enabled) = changes_stream.next().await {
734 recv_constraints.set_enabled(
735 is_enabled,
736 kind,
737 Some(source_kind.into()),
738 );
739 }
740 });
741 platform::spawn(fut.map(drop));
742
743 TaskHandle::from(abort)
744 }
745
746 pub fn add_receiver(&self, receiver: Rc<receiver::State>) {
753 let enabled_in_cons = match &receiver.kind() {
754 MediaKind::Audio => {
755 self.0.recv_constraints.is_audio_device_enabled()
756 || self.0.recv_constraints.is_audio_display_enabled()
757 }
758 MediaKind::Video => {
759 self.0.recv_constraints.is_video_device_enabled()
760 || self.0.recv_constraints.is_video_display_enabled()
761 }
762 };
763 receiver
764 .media_exchange_state_controller()
765 .transition_to(enabled_in_cons.into());
766
767 self.0.receivers.borrow_mut().push(receiver);
768 }
769
770 pub fn add_remote_track(&self, track: remote::Track) {
773 self.0.on_remote_track_added.call1(track);
774 }
775
776 #[must_use]
778 pub fn new_handle(&self) -> ConnectionHandleImpl {
779 ConnectionHandleImpl(Rc::downgrade(&self.0))
780 }
781
782 pub fn update_quality_score(&self, score: ConnectionQualityScore) {
784 if self.0.quality_score.replace(Some(score)) == Some(score) {
785 return;
786 }
787
788 self.refresh_client_conn_quality_score();
789 }
790
791 pub fn update_peer_state(
793 &self,
794 peer_id: PeerId,
795 state: PeerConnectionState,
796 ) {
797 let old = self.0.peer_states.borrow_mut().insert(peer_id, state);
798 if old == Some(state) {
799 return;
800 }
801
802 self.refresh_client_conn_quality_score();
803 if self.0.connection_mode == ConnectionMode::Mesh {
804 self.0.on_state_change.call1::<api::MemberConnectionState>(
805 MemberConnectionState::P2P(state).into(),
806 );
807 } else {
808 }
812 }
813
814 fn refresh_client_conn_quality_score(&self) {
816 use PeerConnectionState as S;
817
818 let peer_states = self.0.peer_states.borrow();
819 let quality_score = self.0.quality_score.get();
820 let score = if peer_states.is_empty() {
821 return;
822 } else if peer_states
823 .values()
824 .any(|s| matches!(s, S::Disconnected | S::Failed | S::Closed))
825 {
826 ClientConnectionQualityScore::Disconnected
827 } else if peer_states.values().all(|s| matches!(s, S::Connected)) {
828 match quality_score {
829 Some(qs) => qs.into(),
830 None => return,
831 }
832 } else {
833 return;
834 };
835
836 let is_score_changed =
837 self.0.client_quality_score.replace(Some(score)) != Some(score);
838 if is_score_changed {
839 self.0.on_quality_score_update.call1(score.into_u8());
840 }
841 }
842}