1use super::{
16 events::*,
17 pipeline::{Pipeline, PipelineOptions},
18 RemoteDataTrack, RemoteTrackInner,
19};
20use crate::{
21 api::{DataTrackFrame, DataTrackInfo, DataTrackSid, DataTrackSubscribeError, InternalError},
22 e2ee::DecryptionProvider,
23 packet::{Handle, Packet},
24};
25use anyhow::{anyhow, Context};
26use bytes::Bytes;
27use std::{
28 collections::{HashMap, HashSet},
29 mem,
30 sync::Arc,
31};
32use tokio::sync::{broadcast, mpsc, oneshot, watch};
33use tokio_stream::{wrappers::ReceiverStream, Stream};
34
35#[derive(Debug)]
37pub struct ManagerOptions {
38 pub decryption_provider: Option<Arc<dyn DecryptionProvider>>,
44}
45
46pub struct Manager {
48 decryption_provider: Option<Arc<dyn DecryptionProvider>>,
49 event_in_tx: mpsc::Sender<InputEvent>,
50 event_in_rx: mpsc::Receiver<InputEvent>,
51 event_out_tx: mpsc::Sender<OutputEvent>,
52
53 descriptors: HashMap<DataTrackSid, Descriptor>,
55
56 sub_handles: HashMap<Handle, DataTrackSid>,
62}
63
64impl Manager {
65 pub fn new(options: ManagerOptions) -> (Self, ManagerInput, impl Stream<Item = OutputEvent>) {
74 let (event_in_tx, event_in_rx) = mpsc::channel(Self::EVENT_BUFFER_COUNT);
75 let (event_out_tx, event_out_rx) = mpsc::channel(Self::EVENT_BUFFER_COUNT);
76
77 let event_in = ManagerInput::new(event_in_tx.clone());
78 let manager = Manager {
79 decryption_provider: options.decryption_provider,
80 event_in_tx,
81 event_in_rx,
82 event_out_tx,
83 descriptors: HashMap::default(),
84 sub_handles: HashMap::default(),
85 };
86
87 let event_out = ReceiverStream::new(event_out_rx);
88 (manager, event_in, event_out)
89 }
90
91 pub async fn run(mut self) {
96 log::debug!("Task started");
97 while let Some(event) = self.event_in_rx.recv().await {
98 match event {
99 InputEvent::SubscribeRequest(event) => self.on_subscribe_request(event).await,
100 InputEvent::UnsubscribeRequest(event) => self.on_unsubscribe_request(event).await,
101 InputEvent::SfuPublicationUpdates(event) => {
102 self.on_sfu_publication_updates(event).await
103 }
104 InputEvent::SfuSubscriberHandles(event) => self.on_sfu_subscriber_handles(event),
105 InputEvent::PacketReceived(bytes) => self.on_packet_received(bytes),
106 InputEvent::ResendSubscriptionUpdates => {
107 self.on_resend_subscription_updates().await
108 }
109 InputEvent::Shutdown => break,
110 }
111 }
112 self.shutdown().await;
113 log::debug!("Task ended");
114 }
115
116 async fn on_subscribe_request(&mut self, event: SubscribeRequest) {
117 let Some(descriptor) = self.descriptors.get_mut(&event.sid) else {
118 let error = DataTrackSubscribeError::Internal(
119 anyhow!("Cannot subscribe to unknown track").into(),
120 );
121 _ = event.result_tx.send(Err(error));
122 return;
123 };
124 match &mut descriptor.subscription {
125 SubscriptionState::None => {
126 let update_event = SfuUpdateSubscription { sid: event.sid, subscribe: true };
127 _ = self.event_out_tx.send(update_event.into()).await;
128 descriptor.subscription = SubscriptionState::Pending {
129 result_txs: vec![event.result_tx],
130 buffer_size: event.options.buffer_size,
131 };
132 }
134 SubscriptionState::Pending { result_txs, .. } => {
135 result_txs.push(event.result_tx);
136 }
137 SubscriptionState::Active { frame_tx, .. } => {
138 let frame_rx = frame_tx.subscribe();
139 _ = event.result_tx.send(Ok(frame_rx))
140 }
141 }
142 }
143
144 async fn on_unsubscribe_request(&mut self, event: UnsubscribeRequest) {
145 let Some(descriptor) = self.descriptors.get_mut(&event.sid) else {
146 return;
147 };
148
149 let SubscriptionState::Active { sub_handle, .. } = descriptor.subscription else {
150 log::warn!("Unexpected state");
151 return;
152 };
153 descriptor.subscription = SubscriptionState::None;
154 self.sub_handles.remove(&sub_handle);
155
156 let event = SfuUpdateSubscription { sid: event.sid, subscribe: false };
157 _ = self.event_out_tx.send(event.into()).await;
158 }
159
160 async fn on_sfu_publication_updates(&mut self, event: SfuPublicationUpdates) {
161 if event.updates.is_empty() {
162 return;
163 }
164 let mut participant_to_sids: HashMap<String, HashSet<DataTrackSid>> = HashMap::new();
165
166 for (publisher_identity, tracks) in event.updates {
168 let sids_in_update = participant_to_sids.entry(publisher_identity.clone()).or_default();
169 for info in tracks {
170 let sid = info.sid();
171 sids_in_update.insert(sid.clone());
172 if self.descriptors.contains_key(&sid) {
173 continue;
174 }
175 self.handle_track_published(publisher_identity.clone(), info).await;
176 }
177 }
178
179 for (publisher_identity, sids_in_update) in &participant_to_sids {
181 let unpublished_sids: Vec<_> = self
182 .descriptors
183 .iter()
184 .filter(|(_, desc)| desc.publisher_identity.as_ref() == publisher_identity)
185 .filter(|(sid, _)| !sids_in_update.contains(*sid))
186 .map(|(sid, _)| sid.clone())
187 .collect();
188 for sid in unpublished_sids {
189 self.handle_track_unpublished(sid).await;
190 }
191 }
192 }
193
194 async fn handle_track_published(&mut self, publisher_identity: String, info: DataTrackInfo) {
195 let sid = info.sid();
196 if self.descriptors.contains_key(&sid) {
197 log::error!("Existing descriptor for track {}", sid);
198 return;
199 }
200 let info = Arc::new(info);
201 let publisher_identity: Arc<str> = publisher_identity.into();
202
203 let (published_tx, published_rx) = watch::channel(true);
204
205 let descriptor = Descriptor {
206 info: info.clone(),
207 publisher_identity: publisher_identity.clone(),
208 published_tx,
209 subscription: SubscriptionState::None,
210 };
211 self.descriptors.insert(sid, descriptor);
212
213 let inner = RemoteTrackInner {
214 published_rx,
215 event_in_tx: self.event_in_tx.downgrade(), publisher_identity,
217 };
218 let track = RemoteDataTrack::new(info, inner);
219 _ = self.event_out_tx.send(TrackPublished { track }.into()).await;
220 }
221
222 async fn handle_track_unpublished(&mut self, sid: DataTrackSid) {
223 let Some(descriptor) = self.descriptors.remove(&sid) else {
224 log::error!("Unknown track {}", sid);
225 return;
226 };
227 if let SubscriptionState::Active { sub_handle, .. } = descriptor.subscription {
228 self.sub_handles.remove(&sub_handle);
229 };
230 _ = descriptor.published_tx.send(false);
231 _ = self.event_out_tx.send(TrackUnpublished { sid }.into()).await;
232 }
233
234 fn on_sfu_subscriber_handles(&mut self, event: SfuSubscriberHandles) {
235 for (handle, sid) in event.mapping {
236 self.register_subscriber_handle(handle, sid);
237 }
238 }
239
240 fn register_subscriber_handle(&mut self, assigned_handle: Handle, sid: DataTrackSid) {
241 let Some(descriptor) = self.descriptors.get_mut(&sid) else {
242 log::warn!("Unknown track: {}", sid);
243 return;
244 };
245 let (result_txs, buffer_size) = match &mut descriptor.subscription {
246 SubscriptionState::None => {
247 log::warn!("No subscription for {}", sid);
249 return;
250 }
251 SubscriptionState::Active { sub_handle, .. } => {
252 self.sub_handles.remove(sub_handle);
254 *sub_handle = assigned_handle;
255 self.sub_handles.insert(assigned_handle, sid);
256 return;
257 }
258 SubscriptionState::Pending { result_txs, buffer_size } => {
259 (mem::take(result_txs), *buffer_size)
261 }
262 };
263
264 let (packet_tx, packet_rx) = mpsc::channel(Self::PACKET_BUFFER_COUNT);
265 let (frame_tx, frame_rx) = broadcast::channel(buffer_size);
266
267 let decryption_provider = if descriptor.info.uses_e2ee() {
268 self.decryption_provider.as_ref().map(Arc::clone)
269 } else {
270 None
271 };
272
273 let pipeline_opts = PipelineOptions {
274 info: descriptor.info.clone(),
275 publisher_identity: descriptor.publisher_identity.clone(),
276 decryption_provider,
277 };
278 let pipeline = Pipeline::new(pipeline_opts);
279
280 let track_task = TrackTask {
281 info: descriptor.info.clone(),
282 pipeline,
283 published_rx: descriptor.published_tx.subscribe(),
284 packet_rx,
285 frame_tx: frame_tx.clone(),
286 event_in_tx: self.event_in_tx.clone(),
287 };
288 let task_handle = livekit_runtime::spawn(track_task.run());
289
290 descriptor.subscription = SubscriptionState::Active {
291 sub_handle: assigned_handle,
292 packet_tx,
293 frame_tx,
294 task_handle,
295 };
296 self.sub_handles.insert(assigned_handle, sid);
297
298 for result_tx in result_txs {
299 _ = result_tx.send(Ok(frame_rx.resubscribe()));
300 }
301 }
302
303 fn on_packet_received(&mut self, bytes: Bytes) {
304 let packet = match Packet::deserialize(bytes) {
305 Ok(packet) => packet,
306 Err(err) => {
307 log::error!("Failed to deserialize packet: {}", err);
308 return;
309 }
310 };
311 let Some(sid) = self.sub_handles.get(&packet.header.track_handle) else {
312 log::warn!("Unknown subscriber handle {}", packet.header.track_handle);
313 return;
314 };
315 let Some(descriptor) = self.descriptors.get(sid) else {
316 log::warn!("Missing descriptor for track {}", sid);
317 return;
318 };
319 let SubscriptionState::Active { packet_tx, .. } = &descriptor.subscription else {
320 log::warn!("Received packet for track {} without subscription", sid);
321 return;
322 };
323 _ = packet_tx
324 .try_send(packet)
325 .inspect_err(|err| log::debug!("Cannot send packet to track pipeline: {}", err));
326 }
327
328 async fn on_resend_subscription_updates(&self) {
329 let update_events =
330 self.descriptors.iter().filter_map(|(sid, descriptor)| match descriptor.subscription {
331 SubscriptionState::None => None,
332 SubscriptionState::Pending { .. } | SubscriptionState::Active { .. } => {
333 Some(SfuUpdateSubscription { sid: sid.clone(), subscribe: true })
334 }
335 });
336 for event in update_events {
337 _ = self.event_out_tx.send(event.into()).await;
338 }
339 }
340
341 async fn shutdown(self) {
343 for (_, descriptor) in self.descriptors {
344 _ = descriptor.published_tx.send(false);
345 match descriptor.subscription {
346 SubscriptionState::None => {}
347 SubscriptionState::Pending { result_txs, .. } => {
348 for result_tx in result_txs {
349 _ = result_tx.send(Err(DataTrackSubscribeError::Disconnected));
350 }
351 }
352 SubscriptionState::Active { task_handle, .. } => task_handle.await,
353 }
354 }
355 }
356
357 const PACKET_BUFFER_COUNT: usize = 16;
360
361 const EVENT_BUFFER_COUNT: usize = 16;
363}
364
365#[derive(Debug)]
367struct Descriptor {
368 info: Arc<DataTrackInfo>,
369 publisher_identity: Arc<str>,
370 published_tx: watch::Sender<bool>,
371 subscription: SubscriptionState,
372}
373
374#[derive(Debug)]
375enum SubscriptionState {
376 None,
378 Pending {
380 result_txs: Vec<oneshot::Sender<SubscribeResult>>,
382 buffer_size: usize,
384 },
385 Active {
387 sub_handle: Handle,
388 packet_tx: mpsc::Sender<Packet>,
389 frame_tx: broadcast::Sender<DataTrackFrame>,
390 task_handle: livekit_runtime::JoinHandle<()>,
391 },
392}
393
394struct TrackTask {
396 info: Arc<DataTrackInfo>,
397 pipeline: Pipeline,
398 published_rx: watch::Receiver<bool>,
399 packet_rx: mpsc::Receiver<Packet>,
400 frame_tx: broadcast::Sender<DataTrackFrame>,
401 event_in_tx: mpsc::Sender<InputEvent>,
402}
403
404impl TrackTask {
405 async fn run(mut self) {
406 log::debug!("Track task started: name={}", self.info.name);
407
408 let mut is_published = *self.published_rx.borrow();
409 while is_published {
410 tokio::select! {
411 biased; _ = self.published_rx.changed() => {
413 is_published = *self.published_rx.borrow();
414 },
415 _ = self.frame_tx.closed() => {
416 let event = UnsubscribeRequest { sid: self.info.sid() };
417 _ = self.event_in_tx.send(event.into()).await;
418 break; },
420 Some(packet) = self.packet_rx.recv() => {
421 self.receive(packet);
422 },
423 else => break
424 }
425 }
426
427 log::debug!("Track task ended: name={}", self.info.name);
428 }
429
430 fn receive(&mut self, packet: Packet) {
431 let Some(frame) = self.pipeline.process_packet(packet) else { return };
432 _ = self
433 .frame_tx
434 .send(frame)
435 .inspect_err(|err| log::debug!("Cannot send frame to subscribers: {}", err));
436 }
437}
438
439#[derive(Debug, Clone)]
441pub struct ManagerInput {
442 event_in_tx: mpsc::Sender<InputEvent>,
443 _drop_guard: Arc<DropGuard>,
444}
445
446#[derive(Debug)]
448struct DropGuard {
449 event_in_tx: mpsc::Sender<InputEvent>,
450}
451
452impl Drop for DropGuard {
453 fn drop(&mut self) {
454 _ = self.event_in_tx.try_send(InputEvent::Shutdown);
455 }
456}
457
458impl ManagerInput {
459 fn new(event_in_tx: mpsc::Sender<InputEvent>) -> Self {
460 Self { event_in_tx: event_in_tx.clone(), _drop_guard: DropGuard { event_in_tx }.into() }
461 }
462
463 pub fn send(&self, event: InputEvent) -> Result<(), InternalError> {
465 Ok(self.event_in_tx.try_send(event).context("Failed to send input event")?)
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        api::DataTrackSubscribeOptions,
        e2ee::{DecryptionError, DecryptionProvider, EncryptedPayload},
        packet::{E2eeExt, Extensions, FrameMarker, Header, Timestamp},
        utils::testing::expect_event,
    };
    use fake::{Fake, Faker};
    use futures_util::{future::join, StreamExt};
    use std::{collections::HashMap, sync::RwLock, time::Duration};
    use test_case::test_case;
    use tokio::time;

    /// Test decryptor that "decrypts" by dropping the first four bytes of the payload.
    #[derive(Debug)]
    struct PrefixStrippingDecryptor;

    impl DecryptionProvider for PrefixStrippingDecryptor {
        fn decrypt(
            &self,
            payload: EncryptedPayload,
            _sender_identity: &str,
        ) -> Result<Bytes, DecryptionError> {
            Ok(payload.payload.slice(4..))
        }
    }

    /// The manager task exits promptly after an explicit Shutdown event.
    #[tokio::test]
    async fn test_manager_task_shutdown() {
        let options = ManagerOptions { decryption_provider: None };
        let (manager, input, _) = Manager::new(options);

        let join_handle = livekit_runtime::spawn(manager.run());
        _ = input.send(InputEvent::Shutdown);

        time::timeout(Duration::from_secs(1), join_handle).await.unwrap();
    }

    /// A track task stops either when the track is unpublished or when its last frame
    /// receiver is dropped (in which case it must emit an UnsubscribeRequest first).
    #[test_case(true; "via_unpublish")]
    #[test_case(false; "via_unsubscribe")]
    #[tokio::test]
    async fn test_track_task_shutdown(via_unpublish: bool) {
        let mut info: DataTrackInfo = Faker.fake();
        info.uses_e2ee = false;

        let info = Arc::new(info);
        let sid = info.sid();
        let publisher_identity: Arc<str> = Faker.fake::<String>().into();

        let pipeline_opts =
            PipelineOptions { info: info.clone(), publisher_identity, decryption_provider: None };
        let pipeline = Pipeline::new(pipeline_opts);

        let (published_tx, published_rx) = watch::channel(true);
        let (_packet_tx, packet_rx) = mpsc::channel(4);
        let (frame_tx, frame_rx) = broadcast::channel(4);
        let (event_in_tx, mut event_in_rx) = mpsc::channel(4);

        let task =
            TrackTask { info: info, pipeline, published_rx, packet_rx, frame_tx, event_in_tx };
        let task_handle = livekit_runtime::spawn(task.run());

        let trigger_shutdown = async {
            if via_unpublish {
                published_tx.send(false).unwrap();
                return;
            }
            // Dropping the only frame receiver closes the broadcast channel.
            mem::drop(frame_rx);

            // The first (and only) event back to the manager must be the unsubscribe.
            while let Some(event) = event_in_rx.recv().await {
                let InputEvent::UnsubscribeRequest(event) = event else {
                    panic!("Unexpected event type");
                };
                assert_eq!(event.sid, sid);
                return;
            }
            panic!("Did not receive unsubscribe");
        };
        time::timeout(Duration::from_secs(1), join(task_handle, trigger_shutdown)).await.unwrap();
    }

    /// End-to-end subscribe through the public track handle, with a simulated SFU that
    /// answers the subscription update with a subscriber handle after a short delay.
    #[tokio::test]
    async fn test_subscribe() {
        let publisher_identity: String = Faker.fake();
        let track_name: String = Faker.fake();
        let track_sid: DataTrackSid = Faker.fake();
        let sub_handle: Handle = Faker.fake();

        let options = ManagerOptions { decryption_provider: None };
        let (manager, input, mut output) = Manager::new(options);
        livekit_runtime::spawn(manager.run());

        let event = SfuPublicationUpdates {
            updates: HashMap::from([(
                publisher_identity.clone(),
                vec![DataTrackInfo {
                    sid: RwLock::new(track_sid.clone()).into(),
                    pub_handle: Faker.fake(),
                    name: track_name.clone(),
                    uses_e2ee: false,
                }],
            )]),
        };
        _ = input.send(event.into());

        let wait_for_track = async {
            while let Some(event) = output.next().await {
                match event {
                    OutputEvent::TrackPublished(track) => return track,
                    _ => continue,
                }
            }
            panic!("No track received");
        };

        let track = wait_for_track.await.track;
        assert!(track.is_published());
        assert_eq!(track.info().name, track_name);
        assert_eq!(track.info().sid(), track_sid);
        assert_eq!(track.publisher_identity(), publisher_identity);

        // Plays the SFU: on each subscription update, reply with a handle mapping.
        let simulate_subscriber_handles = async {
            while let Some(event) = output.next().await {
                match event {
                    OutputEvent::SfuUpdateSubscription(event) => {
                        assert!(event.subscribe);
                        assert_eq!(event.sid, track_sid);
                        time::sleep(Duration::from_millis(20)).await;

                        let event = SfuSubscriberHandles {
                            mapping: HashMap::from([(sub_handle, track_sid.clone())]),
                        };
                        _ = input.send(event.into());
                    }
                    _ => {}
                }
            }
        };

        // The simulator never finishes on its own, so the select completes via subscribe().
        time::timeout(Duration::from_secs(1), async {
            tokio::select! {
                _ = simulate_subscriber_handles => {}
                _ = track.subscribe() => {}
            }
        })
        .await
        .unwrap();
    }

    /// A track appears on publication and is unpublished when a later update for the same
    /// publisher no longer lists it.
    #[tokio::test]
    async fn test_track_publication_add_and_remove() {
        let options = ManagerOptions { decryption_provider: None };
        let (manager, input, mut output) = Manager::new(options);
        livekit_runtime::spawn(manager.run());

        let track_sid: DataTrackSid = Faker.fake();
        let info = DataTrackInfo {
            sid: RwLock::new(track_sid.clone()).into(),
            pub_handle: Faker.fake(),
            name: "test".into(),
            uses_e2ee: false,
        };

        let event =
            SfuPublicationUpdates { updates: HashMap::from([("identity1".into(), vec![info])]) };
        input.send(event.into()).unwrap();

        let track = expect_event!(output, OutputEvent::TrackPublished).track;
        assert_eq!(track.info().sid(), track_sid);
        assert_eq!(track.info().name, "test");
        assert!(track.is_published());

        // Same publisher, empty track list: the track is gone.
        let event =
            SfuPublicationUpdates { updates: HashMap::from([("identity1".into(), vec![])]) };
        input.send(event.into()).unwrap();

        time::timeout(Duration::from_secs(1), track.wait_for_unpublish()).await.unwrap();
        assert!(!track.is_published());

        let event = expect_event!(output, OutputEvent::TrackUnpublished);
        assert_eq!(event.sid, track_sid);
    }

    /// Repeating the same publication update yields exactly one TrackPublished event.
    #[tokio::test]
    async fn test_sfu_publication_updates_idempotent() {
        let options = ManagerOptions { decryption_provider: None };
        let (manager, input, mut output) = Manager::new(options);
        livekit_runtime::spawn(manager.run());

        let track_sid: DataTrackSid = Faker.fake();
        let info = DataTrackInfo {
            sid: RwLock::new(track_sid.clone()).into(),
            pub_handle: Faker.fake(),
            name: "test".into(),
            uses_e2ee: false,
        };

        for _ in 0..3 {
            let event = SfuPublicationUpdates {
                updates: HashMap::from([("identity1".into(), vec![info.clone()])]),
            };
            input.send(event.into()).unwrap();
        }

        expect_event!(output, OutputEvent::TrackPublished);

        input.send(InputEvent::Shutdown).unwrap();
        // Drain until the manager drops its output sender; no further publications allowed.
        while let Some(event) = output.next().await {
            assert!(!matches!(event, OutputEvent::TrackPublished(_)));
        }
    }

    /// A packet routed by subscriber handle comes out as a frame on the subscription.
    #[tokio::test]
    async fn test_subscribe_receives_frame() {
        let options = ManagerOptions { decryption_provider: None };
        let (manager, input, mut output) = Manager::new(options);
        livekit_runtime::spawn(manager.run());

        let track_sid: DataTrackSid = Faker.fake();
        let sub_handle: Handle = Faker.fake();
        let info = DataTrackInfo {
            sid: RwLock::new(track_sid.clone()).into(),
            pub_handle: Faker.fake(),
            name: "test".into(),
            uses_e2ee: false,
        };

        let event = SfuPublicationUpdates { updates: HashMap::from([("id".into(), vec![info])]) };
        input.send(event.into()).unwrap();
        expect_event!(output, OutputEvent::TrackPublished);

        let (result_tx, result_rx) = oneshot::channel();
        let event = SubscribeRequest {
            sid: track_sid.clone(),
            options: DataTrackSubscribeOptions::default(),
            result_tx,
        };
        input.send(event.into()).unwrap();

        let event = expect_event!(output, OutputEvent::SfuUpdateSubscription);
        assert!(event.subscribe);
        assert_eq!(event.sid, track_sid);

        // Simulated SFU reply that activates the pending subscription.
        let event = SfuSubscriberHandles { mapping: HashMap::from([(sub_handle, track_sid)]) };
        input.send(event.into()).unwrap();

        let mut frame_rx =
            time::timeout(Duration::from_secs(1), result_rx).await.unwrap().unwrap().unwrap();

        let packet = Packet {
            header: Header {
                marker: FrameMarker::Single,
                track_handle: sub_handle,
                sequence: 0,
                frame_number: 0,
                timestamp: Timestamp::from_ticks(0),
                extensions: Extensions::default(),
            },
            payload: Bytes::from_static(&[1, 2, 3, 4, 5]),
        };
        input.send(InputEvent::PacketReceived(packet.serialize())).unwrap();

        let frame = time::timeout(Duration::from_secs(1), frame_rx.recv()).await.unwrap().unwrap();
        assert_eq!(frame.payload.as_ref(), &[1, 2, 3, 4, 5]);
    }

    /// For an E2EE track the configured decryptor is applied to the payload.
    #[tokio::test]
    async fn test_subscribe_with_e2ee() {
        let options =
            ManagerOptions { decryption_provider: Some(Arc::new(PrefixStrippingDecryptor)) };
        let (manager, input, mut output) = Manager::new(options);
        livekit_runtime::spawn(manager.run());

        let track_sid: DataTrackSid = Faker.fake();
        let sub_handle: Handle = Faker.fake();
        let info = DataTrackInfo {
            sid: RwLock::new(track_sid.clone()).into(),
            pub_handle: Faker.fake(),
            name: "test".into(),
            uses_e2ee: true,
        };

        let event = SfuPublicationUpdates { updates: HashMap::from([("id".into(), vec![info])]) };
        input.send(event.into()).unwrap();
        expect_event!(output, OutputEvent::TrackPublished);

        let (result_tx, result_rx) = oneshot::channel();
        let event = SubscribeRequest {
            sid: track_sid.clone(),
            options: DataTrackSubscribeOptions::default(),
            result_tx,
        };
        input.send(event.into()).unwrap();

        let event = expect_event!(output, OutputEvent::SfuUpdateSubscription);
        assert!(event.subscribe);

        let event = SfuSubscriberHandles { mapping: HashMap::from([(sub_handle, track_sid)]) };
        input.send(event.into()).unwrap();

        let mut frame_rx =
            time::timeout(Duration::from_secs(1), result_rx).await.unwrap().unwrap().unwrap();

        // Payload carries a 4-byte prefix that PrefixStrippingDecryptor removes.
        let packet = Packet {
            header: Header {
                marker: FrameMarker::Single,
                track_handle: sub_handle,
                sequence: 0,
                frame_number: 0,
                timestamp: Timestamp::from_ticks(0),
                extensions: Extensions {
                    e2ee: Some(E2eeExt { key_index: 0, iv: [0; 12] }),
                    ..Default::default()
                },
            },
            payload: Bytes::from_static(&[0xDE, 0xAD, 0xBE, 0xEF, 1, 2, 3, 4, 5]),
        };
        input.send(InputEvent::PacketReceived(packet.serialize())).unwrap();

        let frame = time::timeout(Duration::from_secs(1), frame_rx.recv()).await.unwrap().unwrap();
        // Only the bytes after the stripped prefix should remain.
        assert_eq!(frame.payload.as_ref(), &[1, 2, 3, 4, 5]);
    }

    /// Later subscribers join the already-active subscription and all see the same frame.
    #[tokio::test]
    async fn test_subscribe_fan_out_to_multiple_subscribers() {
        let options = ManagerOptions { decryption_provider: None };
        let (manager, input, mut output) = Manager::new(options);
        livekit_runtime::spawn(manager.run());

        let track_sid: DataTrackSid = Faker.fake();
        let sub_handle: Handle = Faker.fake();
        let info = DataTrackInfo {
            sid: RwLock::new(track_sid.clone()).into(),
            pub_handle: Faker.fake(),
            name: "test".into(),
            uses_e2ee: false,
        };

        let event = SfuPublicationUpdates { updates: HashMap::from([("id".into(), vec![info])]) };
        input.send(event.into()).unwrap();
        expect_event!(output, OutputEvent::TrackPublished);

        let (result_tx1, result_rx1) = oneshot::channel();
        let event = SubscribeRequest {
            sid: track_sid.clone(),
            options: DataTrackSubscribeOptions::default(),
            result_tx: result_tx1,
        };
        input.send(event.into()).unwrap();

        let event = expect_event!(output, OutputEvent::SfuUpdateSubscription);
        assert!(event.subscribe);

        let event =
            SfuSubscriberHandles { mapping: HashMap::from([(sub_handle, track_sid.clone())]) };
        input.send(event.into()).unwrap();

        let mut rx1 =
            time::timeout(Duration::from_secs(1), result_rx1).await.unwrap().unwrap().unwrap();

        // The subscription is active now, so these resolve without another SFU round trip.
        let (result_tx2, result_rx2) = oneshot::channel();
        let event = SubscribeRequest {
            sid: track_sid.clone(),
            options: DataTrackSubscribeOptions::default(),
            result_tx: result_tx2,
        };
        input.send(event.into()).unwrap();
        let mut rx2 = result_rx2.await.unwrap().unwrap();

        let (result_tx3, result_rx3) = oneshot::channel();
        let event = SubscribeRequest {
            sid: track_sid.clone(),
            options: DataTrackSubscribeOptions::default(),
            result_tx: result_tx3,
        };
        input.send(event.into()).unwrap();
        let mut rx3 = result_rx3.await.unwrap().unwrap();

        let packet = Packet {
            header: Header {
                marker: FrameMarker::Single,
                track_handle: sub_handle,
                sequence: 0,
                frame_number: 0,
                timestamp: Timestamp::from_ticks(0),
                extensions: Extensions::default(),
            },
            payload: Bytes::from_static(&[1, 2, 3, 4, 5]),
        };
        input.send(InputEvent::PacketReceived(packet.serialize())).unwrap();

        for rx in [&mut rx1, &mut rx2, &mut rx3] {
            let frame = time::timeout(Duration::from_secs(1), rx.recv()).await.unwrap().unwrap();
            assert_eq!(frame.payload.as_ref(), &[1, 2, 3, 4, 5]);
        }
    }

    /// Subscribing to a SID the manager has never seen fails immediately.
    #[tokio::test]
    async fn test_subscribe_unknown_track_fails() {
        let options = ManagerOptions { decryption_provider: None };
        let (manager, input, _) = Manager::new(options);
        livekit_runtime::spawn(manager.run());

        let (result_tx, result_rx) = oneshot::channel();
        let event = SubscribeRequest {
            sid: Faker.fake(),
            options: DataTrackSubscribeOptions::default(),
            result_tx,
        };
        input.send(event.into()).unwrap();

        let result = result_rx.await.unwrap();
        assert!(result.is_err());
    }

    /// Unpublishing a track while its subscription is still pending fails the waiter.
    #[tokio::test]
    async fn test_unpublish_terminates_pending_subscription() {
        let options = ManagerOptions { decryption_provider: None };
        let (manager, input, mut output) = Manager::new(options);
        livekit_runtime::spawn(manager.run());

        let track_sid: DataTrackSid = Faker.fake();
        let info = DataTrackInfo {
            sid: RwLock::new(track_sid.clone()).into(),
            pub_handle: Faker.fake(),
            name: "test".into(),
            uses_e2ee: false,
        };

        let event = SfuPublicationUpdates { updates: HashMap::from([("id".into(), vec![info])]) };
        input.send(event.into()).unwrap();
        expect_event!(output, OutputEvent::TrackPublished);

        let (result_tx, result_rx) = oneshot::channel();
        let event = SubscribeRequest {
            sid: track_sid.clone(),
            options: DataTrackSubscribeOptions::default(),
            result_tx,
        };
        input.send(event.into()).unwrap();

        let event = expect_event!(output, OutputEvent::SfuUpdateSubscription);
        assert!(event.subscribe);

        let event = SfuPublicationUpdates { updates: HashMap::from([("id".into(), vec![])]) };
        input.send(event.into()).unwrap();

        // The pending result sender is dropped with the descriptor, so the oneshot errors.
        let result = time::timeout(Duration::from_secs(1), result_rx).await.unwrap();
        assert!(result.is_err());

        let event = expect_event!(output, OutputEvent::TrackUnpublished);
        assert_eq!(event.sid, track_sid);
    }

    /// Unpublishing a track with an active subscription closes the frame channel.
    #[tokio::test]
    async fn test_unpublish_terminates_active_subscription() {
        let options = ManagerOptions { decryption_provider: None };
        let (manager, input, mut output) = Manager::new(options);
        livekit_runtime::spawn(manager.run());

        let track_sid: DataTrackSid = Faker.fake();
        let sub_handle: Handle = Faker.fake();
        let info = DataTrackInfo {
            sid: RwLock::new(track_sid.clone()).into(),
            pub_handle: Faker.fake(),
            name: "test".into(),
            uses_e2ee: false,
        };

        let event = SfuPublicationUpdates { updates: HashMap::from([("id".into(), vec![info])]) };
        input.send(event.into()).unwrap();
        expect_event!(output, OutputEvent::TrackPublished);

        let (result_tx, result_rx) = oneshot::channel();
        let event = SubscribeRequest {
            sid: track_sid.clone(),
            options: DataTrackSubscribeOptions::default(),
            result_tx,
        };
        input.send(event.into()).unwrap();

        let event = expect_event!(output, OutputEvent::SfuUpdateSubscription);
        assert!(event.subscribe);

        let event =
            SfuSubscriberHandles { mapping: HashMap::from([(sub_handle, track_sid.clone())]) };
        input.send(event.into()).unwrap();

        let mut frame_rx =
            time::timeout(Duration::from_secs(1), result_rx).await.unwrap().unwrap().unwrap();

        let event = SfuPublicationUpdates { updates: HashMap::from([("id".into(), vec![])]) };
        input.send(event.into()).unwrap();

        // Once the track task and the descriptor drop their senders, recv reports closure.
        let result = time::timeout(Duration::from_secs(1), frame_rx.recv()).await.unwrap();
        assert!(result.is_err());

        let event = expect_event!(output, OutputEvent::TrackUnpublished);
        assert_eq!(event.sid, track_sid);
    }

    /// Dropping every frame receiver makes the manager tell the SFU to unsubscribe.
    #[tokio::test]
    async fn test_all_subscribers_dropped_terminates_sfu_subscription() {
        let options = ManagerOptions { decryption_provider: None };
        let (manager, input, mut output) = Manager::new(options);
        livekit_runtime::spawn(manager.run());

        let track_sid: DataTrackSid = Faker.fake();
        let sub_handle: Handle = Faker.fake();
        let info = DataTrackInfo {
            sid: RwLock::new(track_sid.clone()).into(),
            pub_handle: Faker.fake(),
            name: "test".into(),
            uses_e2ee: false,
        };

        let event = SfuPublicationUpdates { updates: HashMap::from([("id".into(), vec![info])]) };
        input.send(event.into()).unwrap();
        expect_event!(output, OutputEvent::TrackPublished);

        let (result_tx, result_rx) = oneshot::channel();
        let event = SubscribeRequest {
            sid: track_sid.clone(),
            options: DataTrackSubscribeOptions::default(),
            result_tx,
        };
        input.send(event.into()).unwrap();

        let event = expect_event!(output, OutputEvent::SfuUpdateSubscription);
        assert!(event.subscribe);

        let event =
            SfuSubscriberHandles { mapping: HashMap::from([(sub_handle, track_sid.clone())]) };
        input.send(event.into()).unwrap();

        let frame_rx =
            time::timeout(Duration::from_secs(1), result_rx).await.unwrap().unwrap().unwrap();

        // Last receiver gone: the track task should request an unsubscribe.
        drop(frame_rx);

        let event = expect_event!(output, OutputEvent::SfuUpdateSubscription);
        // This time the update withdraws the subscription.
        assert!(!event.subscribe);
        assert_eq!(event.sid, track_sid);
    }
}