use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::{Mutex, RwLock};
use crate::media::Packetizer;
/// Mount path that `extract_mount_path` returns when a URI carries no usable
/// path of its own: an `rtsp://`/`rtsps://` URI with no `/` after the
/// authority, or a URI that is neither an RTSP URL nor starts with `/`.
pub const DEFAULT_MOUNT_PATH: &str = "/stream";
/// A stream endpoint identified by a path. It owns the packetizer used for
/// that stream and tracks the IDs of the sessions subscribed to it.
pub struct Mount {
/// Path this mount was created with; exposed through `path()`.
path: String,
/// Packetizer for this mount. Every accessor on `Mount` locks this mutex for
/// the duration of a single call into the packetizer.
packetizer: Mutex<Box<dyn Packetizer>>,
/// IDs of the sessions currently subscribed. `subscribe` keeps entries
/// unique; `unsubscribe` removes with `swap_remove`, so insertion order is
/// not preserved.
session_ids: RwLock<Vec<String>>,
}
impl Mount {
    /// Creates a mount for `path` that wraps `packetizer` and has no
    /// subscribers yet.
    pub fn new(path: &str, packetizer: Box<dyn Packetizer>) -> Self {
        let path = path.to_owned();
        let packetizer = Mutex::new(packetizer);
        let session_ids = RwLock::new(Vec::new());
        Self {
            path,
            packetizer,
            session_ids,
        }
    }

    /// Runs `f` against the packetizer while holding its mutex, releasing the
    /// lock as soon as `f` returns.
    fn with_packetizer<T>(&self, f: impl FnOnce(&mut dyn Packetizer) -> T) -> T {
        let mut guard = self.packetizer.lock();
        f(&mut **guard)
    }

    /// Returns the path this mount was created with.
    pub fn path(&self) -> &str {
        self.path.as_str()
    }

    /// Locks the packetizer and forwards `data` and `timestamp_increment` to
    /// its `packetize` method, returning the packets it produces.
    pub fn packetize(&self, data: &[u8], timestamp_increment: u32) -> Vec<Vec<u8>> {
        self.with_packetizer(|packetizer| packetizer.packetize(data, timestamp_increment))
    }

    /// Returns the payload type reported by the packetizer.
    pub fn payload_type(&self) -> u8 {
        self.with_packetizer(|packetizer| packetizer.payload_type())
    }

    /// Returns the SDP attribute lines reported by the packetizer.
    pub fn sdp_attributes(&self) -> Vec<String> {
        self.with_packetizer(|packetizer| packetizer.sdp_attributes())
    }

    /// Returns the clock rate reported by the packetizer.
    pub fn clock_rate(&self) -> u32 {
        self.with_packetizer(|packetizer| packetizer.clock_rate())
    }

    /// Returns the result of the packetizer's `next_sequence()`.
    pub fn next_sequence(&self) -> u16 {
        self.with_packetizer(|packetizer| packetizer.next_sequence())
    }

    /// Returns the result of the packetizer's `next_rtp_timestamp()`.
    pub fn next_rtp_timestamp(&self) -> u32 {
        self.with_packetizer(|packetizer| packetizer.next_rtp_timestamp())
    }

    /// Adds `session_id` to the subscriber list.
    ///
    /// Subscribing an ID that is already present is a no-op (and logs nothing).
    pub fn subscribe(&self, session_id: &str) {
        let mut subscribers = self.session_ids.write();
        let already_subscribed = subscribers.iter().any(|existing| existing == session_id);
        if already_subscribed {
            return;
        }
        subscribers.push(session_id.to_owned());
        tracing::debug!(mount = %self.path, session_id, "session subscribed");
    }

    /// Removes `session_id` from the subscriber list if it is present.
    ///
    /// Removal uses `swap_remove`, so the relative order of the remaining
    /// subscribers may change.
    pub fn unsubscribe(&self, session_id: &str) {
        let mut subscribers = self.session_ids.write();
        let position = subscribers
            .iter()
            .position(|existing| existing == session_id);
        if let Some(index) = position {
            subscribers.swap_remove(index);
            tracing::debug!(mount = %self.path, session_id, "session unsubscribed");
        }
    }

    /// Returns a snapshot copy of the currently subscribed session IDs.
    pub fn subscribed_session_ids(&self) -> Vec<String> {
        let subscribers = self.session_ids.read();
        subscribers.to_vec()
    }
}
/// Registry of mounts keyed by path.
///
/// Both fields sit behind `Arc`, so cloning a registry produces another handle
/// to the same underlying mount table and default path.
#[derive(Clone)]
pub struct MountRegistry {
/// Registered mounts, keyed by the path passed to `add`.
mounts: Arc<RwLock<HashMap<String, Arc<Mount>>>>,
/// Path `resolve_from_uri` falls back to when the requested path has no
/// mount; `None` until `set_default` is called.
default_path: Arc<RwLock<Option<String>>>,
}
impl MountRegistry {
    /// Creates an empty registry with no default mount configured.
    pub fn new() -> Self {
        let mounts = Arc::new(RwLock::new(HashMap::new()));
        let default_path = Arc::new(RwLock::new(None));
        Self {
            mounts,
            default_path,
        }
    }

    /// Creates a mount for `path`, stores it under that key and returns a
    /// shared handle to it.
    ///
    /// A mount previously stored under the same path is replaced.
    pub fn add(&self, path: &str, packetizer: Box<dyn Packetizer>) -> Arc<Mount> {
        let mount = Arc::new(Mount::new(path, packetizer));
        {
            // Keep the write lock scoped to the insertion only.
            let mut mounts = self.mounts.write();
            mounts.insert(path.to_owned(), Arc::clone(&mount));
        }
        tracing::info!(path, "mount registered");
        mount
    }

    /// Records `path` as the fallback used by `resolve_from_uri`.
    ///
    /// The path is stored as given; it is not checked against the registered
    /// mounts.
    pub fn set_default(&self, path: &str) {
        let mut default_path = self.default_path.write();
        *default_path = Some(path.to_owned());
    }

    /// Returns the mount registered under exactly `path`, if any.
    pub fn get(&self, path: &str) -> Option<Arc<Mount>> {
        let mounts = self.mounts.read();
        mounts.get(path).map(Arc::clone)
    }

    /// Resolves a request URI to a mount.
    ///
    /// The mount path is taken from `uri` via `extract_mount_path`. If no
    /// mount is registered under it, the configured default path (if any) is
    /// looked up instead. Returns `None` when neither lookup succeeds.
    pub fn resolve_from_uri(&self, uri: &str) -> Option<Arc<Mount>> {
        let requested = extract_mount_path(uri);
        if let Some(mount) = self.get(requested) {
            return Some(mount);
        }
        let fallback = self.default_path.read();
        let default_path = fallback.as_deref()?;
        self.get(default_path)
    }

    /// Removes `session_id` from the subscriber list of every registered
    /// mount.
    pub fn unsubscribe_all(&self, session_id: &str) {
        self.mounts
            .read()
            .values()
            .for_each(|mount| mount.unsubscribe(session_id));
    }
}
impl Default for MountRegistry {
fn default() -> Self {
Self::new()
}
}
/// Extracts the mount path from an RTSP request URI.
///
/// The path is chosen as follows:
/// - For `rtsp://` and `rtsps://` URIs, everything from the first `/` after
///   the scheme prefix onwards; if there is no such `/`, `DEFAULT_MOUNT_PATH`.
/// - A URI that already starts with `/` is used as the path unchanged.
/// - Anything else (for example `*`) yields `DEFAULT_MOUNT_PATH`.
///
/// A trailing track segment is then removed so the result names the mount
/// itself: `/stream/track1` becomes `/stream`. Only the *final* path segment
/// is examined (a single trailing `/` after it is tolerated), it must begin
/// with `track`, and it is left in place when removing it would produce an
/// empty path. Mount names that merely start with `track`, such as `/tracker`
/// or `/tracks/cam1`, are therefore returned intact instead of being
/// truncated.
///
/// The returned slice borrows from `uri` or is the static default.
pub fn extract_mount_path(uri: &str) -> &str {
    let after_scheme = uri
        .strip_prefix("rtsp://")
        .or_else(|| uri.strip_prefix("rtsps://"));

    let path = match after_scheme {
        // Skip the authority (host[:port]); the path starts at the first '/'.
        Some(rest) => match rest.find('/') {
            Some(slash) => &rest[slash..],
            None => DEFAULT_MOUNT_PATH,
        },
        None if uri.starts_with('/') => uri,
        None => DEFAULT_MOUNT_PATH,
    };

    // Tolerate one trailing slash after the track segment (`/stream/track1/`).
    let trimmed = path.strip_suffix('/').unwrap_or(path);
    match trimmed.rfind('/') {
        // `pos > 0` keeps a lone first segment (e.g. `/tracker`) from being
        // stripped down to an empty mount path. `pos + 1` is a valid char
        // boundary because '/' is a single byte.
        Some(pos) if pos > 0 && trimmed[pos + 1..].starts_with("track") => &trimmed[..pos],
        _ => path,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::media::h264::H264Packetizer;

    /// Builds a standalone `/test` mount backed by an `H264Packetizer`.
    fn test_mount() -> Mount {
        Mount::new("/test", Box::new(H264Packetizer::new(96, 0x1234)))
    }

    /// Builds a registry that holds a single `/stream` mount.
    fn registry_with_stream() -> MountRegistry {
        let registry = MountRegistry::new();
        registry.add("/stream", Box::new(H264Packetizer::new(96, 0x1234)));
        registry
    }

    // --- extract_mount_path -------------------------------------------------

    #[test]
    fn extract_path_full_uri() {
        let path = extract_mount_path("rtsp://localhost:8554/stream");
        assert_eq!(path, "/stream");
    }

    #[test]
    fn extract_path_with_track() {
        let path = extract_mount_path("rtsp://localhost:8554/stream/track1");
        assert_eq!(path, "/stream");
    }

    #[test]
    fn extract_path_no_path() {
        let path = extract_mount_path("rtsp://localhost:8554");
        assert_eq!(path, DEFAULT_MOUNT_PATH);
    }

    #[test]
    fn extract_path_star() {
        let path = extract_mount_path("*");
        assert_eq!(path, DEFAULT_MOUNT_PATH);
    }

    #[test]
    fn extract_path_bare_path() {
        let path = extract_mount_path("/camera1");
        assert_eq!(path, "/camera1");
    }

    #[test]
    fn extract_path_with_camera_track() {
        let path = extract_mount_path("rtsp://10.0.0.1:8554/camera1/track1");
        assert_eq!(path, "/camera1");
    }

    // --- Mount subscriptions ------------------------------------------------

    #[test]
    fn subscribe_unsubscribe() {
        let mount = test_mount();
        mount.subscribe("session1");
        mount.subscribe("session2");
        let subscribed = mount.subscribed_session_ids();
        assert_eq!(subscribed.len(), 2);

        mount.unsubscribe("session1");
        assert_eq!(mount.subscribed_session_ids(), vec!["session2"]);
    }

    #[test]
    fn subscribe_idempotent() {
        let mount = test_mount();
        mount.subscribe("session1");
        mount.subscribe("session1");
        let subscribed = mount.subscribed_session_ids();
        assert_eq!(subscribed.len(), 1);
    }

    // --- MountRegistry ------------------------------------------------------

    #[test]
    fn registry_add_and_get() {
        let registry = registry_with_stream();
        assert!(registry.get("/stream").is_some());
        assert!(registry.get("/other").is_none());
    }

    #[test]
    fn registry_resolve_from_uri() {
        let registry = registry_with_stream();

        let exact = registry.resolve_from_uri("rtsp://localhost:8554/stream");
        assert!(exact.is_some());

        let with_track = registry.resolve_from_uri("rtsp://localhost:8554/stream/track1");
        assert!(with_track.is_some());

        // No default is configured, so an unknown path must not resolve.
        let unknown = registry.resolve_from_uri("rtsp://localhost:8554/other");
        assert!(unknown.is_none());
    }

    #[test]
    fn registry_resolve_fallback_to_default() {
        let registry = registry_with_stream();
        registry.set_default("/stream");

        // The first URI matches directly; the others resolve via the default.
        for uri in [
            "rtsp://localhost:8554/stream",
            "rtsp://localhost:8554/test",
            "rtsp://localhost:8554/anything",
        ] {
            let mount = registry.resolve_from_uri(uri).unwrap();
            assert_eq!(mount.path(), "/stream");
        }
    }

    #[test]
    fn registry_unsubscribe_all() {
        let registry = MountRegistry::new();
        registry.add("/stream1", Box::new(H264Packetizer::new(96, 0x1234)));
        registry.add("/stream2", Box::new(H264Packetizer::new(96, 0x5678)));

        for path in ["/stream1", "/stream2"] {
            registry.get(path).unwrap().subscribe("sess1");
        }

        registry.unsubscribe_all("sess1");

        for path in ["/stream1", "/stream2"] {
            let mount = registry.get(path).unwrap();
            assert!(mount.subscribed_session_ids().is_empty());
        }
    }
}