1use std::{
4 cell::{Cell, RefCell},
5 collections::{HashMap, HashSet},
6 rc::{Rc, Weak},
7};
8
9use derive_more::with_trait::{Display, From};
10use futures::{
11 FutureExt as _, StreamExt as _, future, future::LocalBoxFuture,
12 stream::LocalBoxStream,
13};
14use medea_client_api_proto::{
15 self as proto, ConnectionQualityScore, MemberId, PeerConnectionState,
16 TrackId,
17};
18use tracerr::Traced;
19
20use crate::{
21 api,
22 media::{MediaKind, MediaSourceKind, RecvConstraints, track::remote},
23 peer::{
24 MediaState, MediaStateControllable as _, ProhibitedStateError,
25 TransceiverSide as _, media_exchange_state, receiver,
26 },
27 platform,
28 utils::{Caused, TaskHandle},
29};
30
/// Errors which may occur when changing the media state of a connection
/// through a [`ConnectionHandle`].
#[derive(Caused, Clone, Copy, Debug, Display, From)]
#[cause(error = platform::Error)]
pub enum ChangeMediaStateError {
    /// The [`Weak`] reference held by the [`ConnectionHandle`] could not be
    /// upgraded, so there is no connection to operate on anymore.
    #[display("`ConnectionHandle` is in detached state")]
    Detached,

    /// The [`MediaState`] transits to the opposite of the requested one.
    ///
    /// Carries the [`MediaState`] which is being transited to.
    #[display(
        "`MediaState` transits to opposite ({_0}) of the requested \
         `MediaExchangeState`"
    )]
    TransitionIntoOppositeState(MediaState),

    /// The requested state change was rejected with the wrapped
    /// [`ProhibitedStateError`].
    ProhibitedState(ProhibitedStateError),
}
57
/// Result of a media state change requested through a [`ConnectionHandle`]:
/// `()` on success, or a [`Traced`] [`ChangeMediaStateError`] on failure.
type ChangeMediaStateResult = Result<(), Traced<ChangeMediaStateError>>;
60
/// Registry of [`Connection`]s with remote members, indexed by the tracks
/// which are shared with those members.
#[derive(Debug)]
pub struct Connections {
    /// Remote [`MemberId`]s partnered with each known [`TrackId`].
    tracks_to_members: RefCell<HashMap<TrackId, HashSet<MemberId>>>,

    /// [`TrackId`]s associated with each remote [`MemberId`].
    ///
    /// Once the set of a member becomes empty during a track removal, the
    /// [`Connection`] with this member is removed and its `on_close` callback
    /// is invoked.
    members_to_tracks: RefCell<HashMap<MemberId, HashSet<TrackId>>>,

    /// [`Connection`]s keyed by the [`MemberId`] of the remote member.
    members_to_conns: RefCell<HashMap<MemberId, Connection>>,

    /// Room-level [`RecvConstraints`] passed to every [`Connection`] created
    /// by this registry.
    room_recv_constraints: Rc<RecvConstraints>,

    /// Callback invoked with a handle of every newly created [`Connection`].
    on_new_connection: platform::Callback<api::ConnectionHandle>,
}
79
80impl Connections {
81 pub fn new(room_recv_constraints: Rc<RecvConstraints>) -> Self {
83 Self {
84 tracks_to_members: RefCell::default(),
85 members_to_tracks: RefCell::default(),
86 members_to_conns: RefCell::default(),
87 room_recv_constraints,
88 on_new_connection: platform::Callback::default(),
89 }
90 }
91
92 pub fn on_new_connection(
95 &self,
96 f: platform::Function<api::ConnectionHandle>,
97 ) {
98 self.on_new_connection.set_func(f);
99 }
100
101 #[must_use]
109 pub fn update_connections(
110 &self,
111 track_id: &TrackId,
112 partner_members: HashSet<MemberId>,
113 ) -> Vec<Connection> {
114 if let Some(partners) =
115 self.tracks_to_members.borrow_mut().get_mut(track_id)
116 {
117 let mut connections = self.members_to_conns.borrow_mut();
118 let mut members_to_tracks = self.members_to_tracks.borrow_mut();
119
120 if partners == &partner_members {
122 return partners
123 .iter()
124 .filter_map(|partner| {
125 _ = members_to_tracks
126 .get_mut(partner)
127 .map(|tracks| tracks.insert(*track_id));
128 connections.get(partner).cloned()
129 })
130 .collect();
131 }
132
133 let added: Vec<_> =
135 partner_members.difference(partners).cloned().collect();
136 for mid in added {
137 _ = members_to_tracks
138 .entry(mid.clone())
139 .or_default()
140 .insert(*track_id);
141
142 if !connections.contains_key(&mid) {
143 let connection = Connection::new(
144 mid.clone(),
145 &self.room_recv_constraints,
146 );
147 self.on_new_connection.call1(connection.new_handle());
148 drop(connections.insert(mid.clone(), connection));
149 }
150 _ = partners.insert(mid);
151 }
152
153 partners.retain(|partner| {
155 let to_remove = !partner_members.contains(partner);
156
157 if to_remove {
158 if let Some(tracks) = members_to_tracks.get_mut(partner) {
159 _ = tracks.remove(track_id);
160
161 if tracks.is_empty() {
162 _ = connections
163 .remove(partner)
164 .map(|conn| conn.0.on_close.call0());
165 }
166 }
167 }
168
169 !to_remove
170 });
171
172 return partner_members
173 .into_iter()
174 .filter_map(|partner| connections.get(&partner).cloned())
175 .collect();
176 }
177
178 self.add_connections(*track_id, &partner_members)
179 }
180
181 #[must_use]
188 fn add_connections(
189 &self,
190 track_id: TrackId,
191 partner_members: &HashSet<MemberId>,
192 ) -> Vec<Connection> {
193 let mut connections = self.members_to_conns.borrow_mut();
194
195 #[expect(clippy::iter_over_hash_type, reason = "order doesn't matter")]
196 for partner in partner_members {
197 _ = self
198 .members_to_tracks
199 .borrow_mut()
200 .entry(partner.clone())
201 .or_default()
202 .insert(track_id);
203 if !connections.contains_key(partner) {
204 let connection = Connection::new(
205 partner.clone(),
206 &self.room_recv_constraints,
207 );
208 self.on_new_connection.call1(connection.new_handle());
209 drop(connections.insert(partner.clone(), connection));
210 }
211 }
212
213 drop(
214 self.tracks_to_members.borrow_mut().insert(
215 track_id,
216 partner_members.clone().into_iter().collect(),
217 ),
218 );
219
220 partner_members
221 .iter()
222 .filter_map(|p| connections.get(p).cloned())
223 .collect()
224 }
225
226 pub fn remove_track(&self, track_id: &TrackId) {
231 let mut tracks = self.tracks_to_members.borrow_mut();
232
233 if let Some(partners) = tracks.remove(track_id) {
234 #[expect(clippy::iter_over_hash_type, reason = "doesn't matter")]
235 for p in partners {
236 if let Some(member_tracks) =
237 self.members_to_tracks.borrow_mut().get_mut(&p)
238 {
239 _ = member_tracks.remove(track_id);
240
241 if member_tracks.is_empty() {
242 _ = self
243 .members_to_conns
244 .borrow_mut()
245 .remove(&p)
246 .map(|conn| conn.0.on_close.call0());
247 }
248 }
249 }
250 }
251 }
252
253 #[must_use]
255 pub fn get(&self, remote_member_id: &MemberId) -> Option<Connection> {
256 self.members_to_conns.borrow().get(remote_member_id).cloned()
257 }
258
259 pub fn iter_by_track(
261 &self,
262 track_id: &TrackId,
263 ) -> impl Iterator<Item = Connection> + use<'_> {
264 self.tracks_to_members
265 .borrow()
266 .get(track_id)
267 .cloned()
268 .into_iter()
269 .flat_map(|member_ids| {
270 member_ids.into_iter().filter_map(|member_id| {
271 self.members_to_conns.borrow().get(&member_id).cloned()
272 })
273 })
274 }
275
276 pub fn apply(&self, new_state: &proto::state::Room) {
278 #[expect(clippy::iter_over_hash_type, reason = "order doesn't matter")]
279 for peer in new_state.peers.values() {
280 for (track_id, sender) in &peer.senders {
281 if let Some(partners) =
282 self.tracks_to_members.borrow().get(track_id)
283 {
284 for member in partners {
285 if let Some(member_tracks) =
286 self.members_to_tracks.borrow_mut().get_mut(member)
287 {
288 if !sender.receivers.contains(member) {
289 _ = member_tracks.remove(track_id);
290 }
291 }
292 }
293 }
294 }
295 }
296 }
297}
298
/// Error of a [`ConnectionHandle`] being unable to upgrade its [`Weak`]
/// reference to the underlying connection.
#[derive(Caused, Clone, Copy, Debug, Display)]
#[cause(error = platform::Error)]
#[display("`ConnectionHandle` is in detached state")]
pub struct HandleDetachedError;
304
/// Handle to a [`Connection`] holding only a [`Weak`] reference to its inner
/// state.
///
/// Its methods report a detached state error once that reference cannot be
/// upgraded anymore.
#[derive(Clone, Debug)]
pub struct ConnectionHandle(Weak<InnerConnection>);
310
/// Connection quality score reported to the `on_quality_score_update`
/// callback of a [`Connection`].
#[derive(Clone, Copy, Debug, Display, Eq, From, Ord, PartialEq, PartialOrd)]
pub enum ClientConnectionQualityScore {
    /// The peer connection is in the disconnected, failed or closed state.
    Disconnected,

    /// The peer connection is connected and has the wrapped
    /// [`ConnectionQualityScore`].
    Connected(ConnectionQualityScore),
}
320
321impl ClientConnectionQualityScore {
322 #[must_use]
324 pub const fn into_u8(self) -> u8 {
325 match self {
326 Self::Disconnected => 0,
327 #[expect(clippy::as_conversions, reason = "needs refactoring")]
329 Self::Connected(score) => score as u8,
330 }
331 }
332}
333
/// Inner state shared by a [`Connection`] and its [`ConnectionHandle`]s.
#[derive(Debug)]
struct InnerConnection {
    /// ID of the remote member this connection is established with.
    remote_id: MemberId,

    /// Latest [`ConnectionQualityScore`] provided via
    /// [`Connection::update_quality_score()`], if any.
    quality_score: Cell<Option<ConnectionQualityScore>>,

    /// Latest [`ClientConnectionQualityScore`] computed out of the
    /// `quality_score` and the `peer_state`, if any.
    client_quality_score: Cell<Option<ClientConnectionQualityScore>>,

    /// Latest [`PeerConnectionState`] provided via
    /// [`Connection::update_peer_state()`], if any.
    peer_state: Cell<Option<PeerConnectionState>>,

    /// Callback invoked by [`Connection::add_remote_track()`].
    on_remote_track_added: platform::Callback<api::RemoteMediaTrack>,

    /// [`RecvConstraints`] of this connection.
    ///
    /// Created in [`Connection::new()`] as a clone of the room-level ones.
    recv_constraints: Rc<RecvConstraints>,

    /// States of the receivers added via [`Connection::add_receiver()`].
    receivers: RefCell<Vec<Rc<receiver::State>>>,

    /// Callback invoked with the [`u8`] representation of the
    /// `client_quality_score` whenever it changes.
    on_quality_score_update: platform::Callback<u8>,

    /// Callback invoked by [`Connections`] when this connection is removed.
    on_close: platform::Callback<()>,

    /// [`TaskHandle`]s of the tasks spawned in [`Connection::new()`] for
    /// forwarding the room-level [`RecvConstraints`] changes into the
    /// `recv_constraints`.
    ///
    /// NOTE(review): presumably these tasks are aborted once the
    /// [`TaskHandle`]s are dropped. Confirm against [`TaskHandle`].
    _task_handles: Vec<TaskHandle>,
}
370
371impl InnerConnection {
372 async fn change_media_state(
381 &self,
382 desired_state: MediaState,
383 kind: MediaKind,
384 source_kind: Option<MediaSourceKind>,
385 ) -> ChangeMediaStateResult {
386 let receivers = self.receivers.borrow().clone();
387 let mut change_tasks = Vec::new();
388 for r in receivers {
389 let source_filter =
390 source_kind.is_none_or(|skind| skind == r.source_kind().into());
391
392 if r.is_subscription_needed(desired_state)
393 && r.kind() == kind
394 && source_filter
395 {
396 r.media_state_transition_to(desired_state)
397 .map_err(tracerr::map_from_and_wrap!())?;
398 change_tasks.push(r.when_media_state_stable(desired_state));
399 }
400 }
401
402 drop(
403 future::try_join_all(change_tasks)
404 .await
405 .map_err(tracerr::from_and_wrap!())?,
406 );
407
408 if let MediaState::MediaExchange(desired_state) = desired_state {
409 self.recv_constraints.set_enabled(
410 desired_state == media_exchange_state::Stable::Enabled,
411 kind,
412 source_kind.map(Into::into),
413 );
414 }
415
416 Ok(())
417 }
418}
419
420impl ConnectionHandle {
421 pub fn on_close(
427 &self,
428 f: platform::Function<()>,
429 ) -> Result<(), Traced<HandleDetachedError>> {
430 self.0
431 .upgrade()
432 .ok_or_else(|| tracerr::new!(HandleDetachedError))
433 .map(|inner| inner.on_close.set_func(f))
434 }
435
436 pub fn get_remote_member_id(
442 &self,
443 ) -> Result<String, Traced<HandleDetachedError>> {
444 self.0
445 .upgrade()
446 .ok_or_else(|| tracerr::new!(HandleDetachedError))
447 .map(|inner| inner.remote_id.0.clone())
448 }
449
450 pub fn on_remote_track_added(
457 &self,
458 f: platform::Function<api::RemoteMediaTrack>,
459 ) -> Result<(), Traced<HandleDetachedError>> {
460 self.0
461 .upgrade()
462 .ok_or_else(|| tracerr::new!(HandleDetachedError))
463 .map(|inner| inner.on_remote_track_added.set_func(f))
464 }
465
466 pub fn on_quality_score_update(
473 &self,
474 f: platform::Function<u8>,
475 ) -> Result<(), Traced<HandleDetachedError>> {
476 self.0
477 .upgrade()
478 .ok_or_else(|| tracerr::new!(HandleDetachedError))
479 .map(|inner| inner.on_quality_score_update.set_func(f))
480 }
481
482 pub fn enable_remote_video(
493 &self,
494 source_kind: Option<MediaSourceKind>,
495 ) -> impl Future<Output = ChangeMediaStateResult> + 'static + use<> {
496 self.change_media_state(
497 media_exchange_state::Stable::Enabled.into(),
498 MediaKind::Video,
499 source_kind,
500 )
501 }
502
503 pub fn disable_remote_video(
514 &self,
515 source_kind: Option<MediaSourceKind>,
516 ) -> impl Future<Output = ChangeMediaStateResult> + 'static + use<> {
517 self.change_media_state(
518 media_exchange_state::Stable::Disabled.into(),
519 MediaKind::Video,
520 source_kind,
521 )
522 }
523
524 pub fn enable_remote_audio(
535 &self,
536 ) -> impl Future<Output = ChangeMediaStateResult> + 'static + use<> {
537 self.change_media_state(
538 media_exchange_state::Stable::Enabled.into(),
539 MediaKind::Audio,
540 None,
541 )
542 }
543
544 pub fn disable_remote_audio(
555 &self,
556 ) -> impl Future<Output = ChangeMediaStateResult> + 'static + use<> {
557 self.change_media_state(
558 media_exchange_state::Stable::Disabled.into(),
559 MediaKind::Audio,
560 None,
561 )
562 }
563
564 fn change_media_state(
570 &self,
571 desired_state: MediaState,
572 kind: MediaKind,
573 source_kind: Option<MediaSourceKind>,
574 ) -> LocalBoxFuture<'static, ChangeMediaStateResult> {
575 let inner = self
576 .0
577 .upgrade()
578 .ok_or_else(|| tracerr::new!(ChangeMediaStateError::Detached));
579 let inner = match inner {
580 Ok(inner) => inner,
581 Err(e) => return Box::pin(future::err(e)),
582 };
583
584 Box::pin(async move {
585 inner.change_media_state(desired_state, kind, source_kind).await
586 })
587 }
588}
589
/// Connection with a single remote member.
///
/// Cloning is cheap: all the clones share the same [`InnerConnection`]
/// through an [`Rc`].
#[derive(Clone, Debug)]
pub struct Connection(Rc<InnerConnection>);
593
594impl Connection {
595 #[must_use]
600 pub fn new(
601 remote_id: MemberId,
602 room_recv_constraints: &Rc<RecvConstraints>,
603 ) -> Self {
604 let recv_constraints = Rc::new(room_recv_constraints.as_ref().clone());
606
607 Self(Rc::new(InnerConnection {
608 _task_handles: vec![
609 Self::spawn_constraints_synchronizer(
610 Rc::clone(&recv_constraints),
611 room_recv_constraints.on_video_device_enabled_change(),
612 MediaKind::Video,
613 MediaSourceKind::Device,
614 ),
615 Self::spawn_constraints_synchronizer(
616 Rc::clone(&recv_constraints),
617 room_recv_constraints.on_video_display_enabled_change(),
618 MediaKind::Video,
619 MediaSourceKind::Display,
620 ),
621 Self::spawn_constraints_synchronizer(
622 Rc::clone(&recv_constraints),
623 room_recv_constraints.on_audio_enabled_change(),
624 MediaKind::Audio,
625 MediaSourceKind::Device,
626 ),
627 ],
628 remote_id,
629 quality_score: Cell::default(),
630 client_quality_score: Cell::default(),
631 peer_state: Cell::default(),
632 on_quality_score_update: platform::Callback::default(),
633 recv_constraints,
634 on_close: platform::Callback::default(),
635 on_remote_track_added: platform::Callback::default(),
636 receivers: RefCell::default(),
637 }))
638 }
639
640 fn spawn_constraints_synchronizer(
647 recv_constraints: Rc<RecvConstraints>,
648 mut changes_stream: LocalBoxStream<'static, bool>,
649 kind: MediaKind,
650 source_kind: MediaSourceKind,
651 ) -> TaskHandle {
652 let (fut, abort) = future::abortable(async move {
653 while let Some(is_enabled) = changes_stream.next().await {
654 recv_constraints.set_enabled(
655 is_enabled,
656 kind,
657 Some(source_kind.into()),
658 );
659 }
660 });
661 platform::spawn(fut.map(drop));
662
663 TaskHandle::from(abort)
664 }
665
666 pub fn add_receiver(&self, receiver: Rc<receiver::State>) {
673 let enabled_in_cons = match &receiver.kind() {
674 MediaKind::Audio => self.0.recv_constraints.is_audio_enabled(),
675 MediaKind::Video => {
676 self.0.recv_constraints.is_video_device_enabled()
677 || self.0.recv_constraints.is_video_display_enabled()
678 }
679 };
680 receiver
681 .media_exchange_state_controller()
682 .transition_to(enabled_in_cons.into());
683
684 self.0.receivers.borrow_mut().push(receiver);
685 }
686
687 pub fn add_remote_track(&self, track: remote::Track) {
690 self.0.on_remote_track_added.call1(track);
691 }
692
693 #[must_use]
695 pub fn new_handle(&self) -> ConnectionHandle {
696 ConnectionHandle(Rc::downgrade(&self.0))
697 }
698
699 pub fn update_quality_score(&self, score: ConnectionQualityScore) {
701 if self.0.quality_score.replace(Some(score)) == Some(score) {
702 return;
703 }
704
705 self.refresh_client_conn_quality_score();
706 }
707
708 pub fn update_peer_state(&self, state: PeerConnectionState) {
710 if self.0.peer_state.replace(Some(state)) == Some(state) {
711 return;
712 }
713
714 self.refresh_client_conn_quality_score();
715 }
716
717 fn refresh_client_conn_quality_score(&self) {
719 use PeerConnectionState as S;
720
721 let state = self.0.peer_state.get();
722 let quality_score = self.0.quality_score.get();
723 let score = match (state, quality_score) {
724 (Some(S::Connected), Some(quality_score)) => quality_score.into(),
725 (Some(S::Disconnected | S::Failed | S::Closed), _) => {
726 ClientConnectionQualityScore::Disconnected
727 }
728 (Some(S::Connecting | S::New) | None, _)
729 | (Some(S::Connected), None) => return,
730 };
731
732 let is_score_changed =
733 self.0.client_quality_score.replace(Some(score)) != Some(score);
734 if is_score_changed {
735 self.0.on_quality_score_update.call1(score.into_u8());
736 }
737 }
738}