1use error::{Error, SapResult};
19use lazy_static::lazy_static;
20use murmur3::murmur3_32;
21use sdp::SessionDescription;
22use socket2::{Domain, Protocol, SockAddr, Socket, Type};
23use std::{
24 collections::HashMap,
25 io::Cursor,
26 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
27 time::{Duration, SystemTime, UNIX_EPOCH},
28};
29use tokio::{
30 net::UdpSocket,
31 select,
32 sync::{mpsc, oneshot},
33 time::interval,
34};
35use tosub::SubsystemHandle;
36use tracing::{debug, error, info};
37
38pub mod error;
39
40const DEFAULT_PAYLOAD_TYPE: &str = "application/sdp";
41const DEFAULT_SAP_PORT: u16 = 9875;
42const DEFAULT_MULTICAST_ADDRESS: &str = "239.255.255.255";
43
44lazy_static! {
45 static ref HASH_SEED: u32 = SystemTime::now()
46 .duration_since(UNIX_EPOCH)
47 .expect("something is wrong with the system clock")
48 .as_secs() as u32;
49}
50
51#[derive(Debug, Clone)]
52pub struct SessionAnnouncement {
53 pub deletion: bool,
54 pub encrypted: bool,
55 pub compressed: bool,
56 pub msg_id_hash: u16,
57 pub auth_data: Option<String>,
58 pub originating_source: IpAddr,
59 pub payload_type: Option<String>,
60 pub sdp: SessionDescription,
61}
62
63impl SessionAnnouncement {
64 pub fn new(sdp: SessionDescription) -> SapResult<Self> {
65 Ok(Self {
66 deletion: false,
67 encrypted: false,
68 compressed: false,
69 msg_id_hash: sdp_hash(&sdp),
70 auth_data: None,
71 originating_source: sdp.origin.unicast_address.parse()?,
72 payload_type: Some(DEFAULT_PAYLOAD_TYPE.to_owned()),
73 sdp,
74 })
75 }
76
77 pub fn deletion(sdp: SessionDescription) -> SapResult<Self> {
78 Ok(Self {
79 deletion: true,
80 encrypted: false,
81 compressed: false,
82 msg_id_hash: sdp_hash(&sdp),
83 auth_data: None,
84 originating_source: sdp.origin.unicast_address.parse()?,
85 payload_type: Some(DEFAULT_PAYLOAD_TYPE.to_owned()),
86 sdp,
87 })
88 }
89}
90
/// State owned by the SAP actor task (see `SapActor::run`).
pub struct SapActor {
    /// Subsystem the actor runs in; per-session announcer tasks are spawned as its children.
    subsys: SubsystemHandle,
    /// Raw datagrams received by the socket loop.
    rx: mpsc::Receiver<Vec<u8>>,
    // NOTE(review): `multicast_addr`, `active_sessions` and `foreign_sessions`
    // are initialised but never read in the code visible here — confirm whether
    // they are still needed.
    multicast_addr: SocketAddr,
    active_sessions: HashMap<u64, SessionAnnouncement>,
    foreign_sessions: HashMap<u64, SessionAnnouncement>,
    /// Handles of the running announcer subsystems, keyed by SDP origin session id.
    deletion_announcements: HashMap<u64, SubsystemHandle>,
    /// Destination for decoded incoming announcements.
    event_tx: mpsc::Sender<Event>,
    /// API requests coming from `Sap` handles.
    msg_rx: mpsc::Receiver<Message>,
    /// Queue of announcements for the socket loop to transmit.
    announcement_sender: mpsc::Sender<SessionAnnouncement>,
}
102
103pub enum Event {
104 SessionFound(SessionAnnouncement),
105 SessionLost(SessionAnnouncement),
106}
107
/// Requests sent from `Sap` handles to the actor. Each request carries a
/// oneshot sender on which the actor reports the outcome.
enum Message {
    /// Start announcing a session, replacing any announcer for the same SDP origin session id.
    AnnounceSession(Box<SessionAnnouncement>, oneshot::Sender<SapResult<()>>),
    /// Stop announcing the session with the given SDP origin session id.
    DeleteSession(u64, oneshot::Sender<SapResult<()>>),
    /// Stop announcing every session started through this actor.
    DeleteAllSessions(oneshot::Sender<SapResult<()>>),
}
113
114impl SapActor {
115 async fn run(mut self) -> SapResult<()> {
116 loop {
117 select! {
118 recv = self.msg_rx.recv() => if let Some(msg) = recv {
119 self.process_api_msg(msg).await?;
120 } else {
121 info!("Message channel closed, shutting down SAP actor.");
122 break;
123 },
124 recv = self.rx.recv() => if let Some(data) = recv {
125 self.forward_announcement(&data).await;
126 } else {
127 info!("Socket channel closed, shutting down SAP actor.");
128 break;
129 },
130 _ = self.subsys.shutdown_requested() => {
131 info!("Shutdown requested, shutting down SAP actor.");
132 break;
133 },
134 }
135 }
136
137 info!("SAP actor stopped.");
138
139 Ok(())
140 }
141
142 async fn process_api_msg(&mut self, msg: Message) -> SapResult<()> {
143 match msg {
144 Message::AnnounceSession(sa, tx) => {
145 tx.send(self.announce_session(*sa).await).ok();
146 }
147 Message::DeleteSession(id, tx) => {
148 tx.send(self.delete_session(id).await).ok();
149 }
150 Message::DeleteAllSessions(tx) => {
151 tx.send(self.delete_all_sessions().await).ok();
152 }
153 }
154
155 Ok(())
156 }
157
158 async fn forward_announcement(&self, buf: &[u8]) {
159 debug!("forwarding SAP message");
160 match decode_sap(buf) {
161 Ok(sap) => {
162 let event = if sap.deletion {
163 Event::SessionLost(sap)
164 } else {
165 Event::SessionFound(sap)
166 };
167 if let Err(e) = self.event_tx.send(event).await {
168 error!("Error forwarding SAP message error: {e}");
169 } else {
170 debug!("SAP message forwarded");
171 }
172 }
173 Err(e) => {
174 error!("error decoding SAP message: {e}");
175 }
176 }
177 }
178
179 async fn announce_session(&mut self, announcement: SessionAnnouncement) -> SapResult<()> {
180 let session_id = announcement.sdp.origin.session_id;
181
182 info!(
183 "Announcing new session with hash {}.",
184 announcement.msg_id_hash
185 );
186
187 self.delete_session(announcement.sdp.origin.session_id)
188 .await?;
189
190 let mut deletion_announcement = announcement.clone();
191 deletion_announcement.deletion = true;
192
193 let tx = self.announcement_sender.clone();
194
195 let announcement = self.subsys.spawn(
196 format!("announcement/{}", announcement.msg_id_hash),
197 |s| async move {
198 let mut interval = interval(Duration::from_secs(5));
199
200 loop {
201 select! {
205 _ = interval.tick() => tx.send(announcement.clone()).await?,
206 _ = s.shutdown_requested() => break,
207 }
208 }
209
210 tx.send(deletion_announcement).await.ok();
211
212 Ok::<(), error::Error>(())
213 },
214 );
215
216 self.deletion_announcements.insert(session_id, announcement);
217
218 Ok(())
219 }
220
221 async fn delete_session(&mut self, session_id: u64) -> SapResult<()> {
222 if let Some(subsys) = self.deletion_announcements.remove(&session_id) {
223 info!("Deleting active session {session_id}.");
224 subsys.request_local_shutdown();
225 } else {
226 debug!("No session active, nothing to delete.");
227 }
228
229 Ok(())
230 }
231
232 async fn delete_all_sessions(&mut self) -> SapResult<()> {
233 let sessions = self.deletion_announcements.drain().collect::<Vec<_>>();
234
235 for (session_id, subsys) in sessions {
236 info!("Deleting active session {session_id}.");
237 subsys.request_local_shutdown();
238 }
239
240 Ok(())
241 }
242}
243
244async fn send_announcement(
245 socket: &UdpSocket,
246 multicast_addr: &SocketAddr,
247 announcement: &SessionAnnouncement,
248) -> SapResult<()> {
249 debug!(
250 "Broadcasting session description:\n{}\n",
251 announcement.sdp.marshal()
252 );
253 let msg = encode_sap(announcement);
254 socket.send_to(&msg, multicast_addr).await?;
255 Ok(())
256}
257
258#[derive(Clone)]
259pub struct Sap {
260 msg_tx: mpsc::Sender<Message>,
261}
262
263impl Sap {
264 pub async fn new(subsys: &SubsystemHandle) -> SapResult<(Self, mpsc::Receiver<Event>)> {
265 let multicast_addr = SocketAddr::new(
266 IpAddr::V4(DEFAULT_MULTICAST_ADDRESS.parse()?),
267 DEFAULT_SAP_PORT,
268 );
269 let socket = create_socket().await?;
270
271 let active_sessions = HashMap::new();
272 let foreign_sessions = HashMap::new();
273 let deletion_announcements = HashMap::new();
274
275 let (event_tx, event_rx) = mpsc::channel(1);
276 let (msg_tx, msg_rx) = mpsc::channel(100);
277 let (socket_tx, socket_rx) = mpsc::channel(100);
278
279 subsys.spawn("sap", move |s| {
280 let (announce_tx, announce_rx) = mpsc::channel(1);
281
282 s.spawn("socket", move |s| {
283 IoLoop {
284 s,
285 socket,
286 multicast_addr,
287 socket_tx,
288 announce_rx,
289 }
290 .io_loop()
291 });
292
293 SapActor {
294 subsys: s,
295 multicast_addr,
296 active_sessions,
297 foreign_sessions,
298 deletion_announcements,
299 event_tx,
300 msg_rx,
301 announcement_sender: announce_tx,
302 rx: socket_rx,
303 }
304 .run()
305 });
306
307 Ok((Sap { msg_tx }, event_rx))
308 }
309
310 pub async fn announce_session(&self, sd: SessionDescription) -> SapResult<()> {
311 let sa = SessionAnnouncement::new(sd)?;
312 let (tx, rx) = oneshot::channel();
313 self.msg_tx
314 .send(Message::AnnounceSession(Box::new(sa), tx))
315 .await?;
316 rx.await?
317 }
318
319 pub async fn delete_session(&self, session_id: u64) -> SapResult<()> {
320 let (tx, rx) = oneshot::channel();
321 self.msg_tx
322 .send(Message::DeleteSession(session_id, tx))
323 .await?;
324 rx.await?
325 }
326
327 pub async fn delete_all_sessions(&self) -> SapResult<()> {
328 let (tx, rx) = oneshot::channel();
329 self.msg_tx.send(Message::DeleteAllSessions(tx)).await?;
330 rx.await?
331 }
332}
333
334struct IoLoop {
335 s: SubsystemHandle,
336 socket: UdpSocket,
337 multicast_addr: SocketAddr,
338 socket_tx: mpsc::Sender<Vec<u8>>,
339 announce_rx: mpsc::Receiver<SessionAnnouncement>,
340}
341impl IoLoop {
342 async fn io_loop(mut self) -> SapResult<()> {
343 let mut buf = [0; 1024];
344
345 loop {
346 select! {
347 len = self.socket.recv(&mut buf) => self.socket_tx.send(buf[..len?].to_vec()).await?,
348 recv = self.announce_rx.recv() => if let Some(announcement) = recv {
349 send_announcement(&self.socket, &self.multicast_addr, &announcement).await?
350 } else {
351 break;
352 },
353 }
354 }
355
356 self.s.request_local_shutdown();
357
358 info!("SAP socket closed.");
359
360 Ok(())
361 }
362}
363
364pub fn decode_sap(msg: &[u8]) -> SapResult<SessionAnnouncement> {
365 let mut min_length = 4;
366
367 if msg.len() < min_length {
368 return Err(Error::MalformedPacket(msg.to_owned()));
369 }
370
371 let header = msg[0];
372 let auth_len = msg[1];
373 let msg_id_hash = u16::from_be_bytes([msg[2], msg[3]]);
374
375 let ipv6 = (header & 0b00001000) >> 3 == 1;
376 let deletion = (header & 0b00000100) >> 2 == 1;
377 let encrypted = (header & 0b00000010) >> 1 == 1;
378 let compressed = header & 0b00000001 == 1;
379
380 if encrypted {
382 return Err(Error::NotImplemented("encryption"));
383 }
384 if compressed {
386 return Err(Error::NotImplemented("encryption"));
387 }
388
389 if ipv6 {
390 min_length += 16;
391 } else {
392 min_length += 4;
393 }
394
395 if msg.len() < min_length {
396 return Err(Error::MalformedPacket(msg.to_owned()));
397 }
398
399 let originating_source = if ipv6 {
400 let bits = u128::from_be_bytes([
401 msg[4], msg[5], msg[6], msg[7], msg[8], msg[9], msg[10], msg[11], msg[12], msg[13],
402 msg[14], msg[15], msg[16], msg[17], msg[18], msg[19],
403 ]);
404 IpAddr::V6(Ipv6Addr::from_bits(bits))
405 } else {
406 let bits = u32::from_be_bytes([msg[4], msg[5], msg[6], msg[7]]);
407 IpAddr::V4(Ipv4Addr::from_bits(bits))
408 };
409
410 let auth_data_start = min_length;
411
412 min_length += auth_len as usize;
413
414 if msg.len() <= min_length {
415 return Err(Error::MalformedPacket(msg.to_owned()));
416 }
417
418 let auth_data = if auth_len > 0 {
419 Some(String::from_utf8_lossy(&msg[auth_data_start..min_length]).to_string())
420 } else {
421 None
422 };
423
424 let payload = String::from_utf8_lossy(&msg[min_length..]).to_string();
425 let split: Vec<&str> = payload.split('\0').collect();
426
427 let payload_type = if split.len() >= 2 {
428 Some(split[0].to_owned())
429 } else {
430 None
431 };
432
433 let payload = if split.len() == 1 {
434 split[0]
435 } else {
436 &split[1..].join("\0")
437 };
438
439 let sdp = SessionDescription::unmarshal(&mut Cursor::new(payload))?;
440
441 Ok(SessionAnnouncement {
442 deletion,
443 encrypted,
444 compressed,
445 msg_id_hash,
446 auth_data,
447 originating_source,
448 payload_type,
449 sdp,
450 })
451}
452
453pub fn encode_sap(msg: &SessionAnnouncement) -> Vec<u8> {
454 let v = 1u8;
455 let (a, originating_source): (u8, &[u8]) = match msg.originating_source {
456 IpAddr::V4(addr) => (0u8, &addr.octets()),
457 IpAddr::V6(addr) => (1u8, &addr.octets()),
458 };
459 let r = 0u8;
460 let t = if msg.deletion { 1u8 } else { 0u8 };
461 let e = if msg.encrypted { 1u8 } else { 0u8 };
462 let c = if msg.compressed { 1u8 } else { 0u8 };
463 let header = v << 5 | a << 4 | r << 3 | t << 2 | e << 1 | c;
464 let auth_len = msg.auth_data.as_ref().map(|d| d.len()).unwrap_or(0) as u8;
465 let msg_id_hash = msg.msg_id_hash.to_be_bytes();
466
467 let mut data = Vec::new();
468 data.push(header);
469 data.push(auth_len);
470 data.extend_from_slice(&msg_id_hash);
471 data.extend_from_slice(originating_source);
472 if let Some(auth_data) = &msg.auth_data {
473 data.extend_from_slice(auth_data.as_bytes());
474 }
475 if let Some(payload_type) = &msg.payload_type {
476 data.extend_from_slice(payload_type.as_bytes());
477 data.push(b'\0');
478 }
479 debug!("marshalling sdp ...");
480 data.extend_from_slice(msg.sdp.marshal().as_bytes());
481 debug!("marshalling sdp done.");
482
483 data
484}
485
486fn sdp_hash(sdp: &SessionDescription) -> u16 {
487 info!("computing message hash ...");
488 let res = murmur3_32(&mut Cursor::new(sdp.marshal()), *HASH_SEED).unwrap_or(0) as u16;
489 info!("computing message hash done");
490 res
491}
492
493async fn create_socket() -> SapResult<UdpSocket> {
494 let multicast_addr: Ipv4Addr = DEFAULT_MULTICAST_ADDRESS.parse()?;
495 let local_ip = Ipv4Addr::UNSPECIFIED;
496 let local_addr = SocketAddr::new(IpAddr::V4(local_ip), DEFAULT_SAP_PORT);
497
498 let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
499 socket.set_reuse_address(true)?;
500 socket.set_nonblocking(true)?;
501 socket.bind(&SockAddr::from(local_addr))?;
502 socket.join_multicast_v4(&multicast_addr, &local_ip)?;
503
504 let socket = UdpSocket::from_std(socket.into())?;
505
506 Ok(socket)
507}
508
#[cfg(test)]
mod tests {

    use super::*;

    /// Hashing a sample SDP yields a non-zero message identifier hash.
    // NOTE(review): `HASH_SEED` is derived from the clock, so the hash changes
    // between runs. If the truncated 16-bit hash is roughly uniform, this
    // assertion fails spuriously about once in 65536 runs.
    #[test]
    fn sdp_gets_hashed_correctly() {
        let sdp = SessionDescription::unmarshal(&mut Cursor::new(
            "v=0
o=- 123456 123458 IN IP4 10.0.1.2
s=My sample flow
i=4 channels: c1, c2, c3, c4
t=0 0
a=recvonly
m=audio 5004 RTP/AVP 98
c=IN IP4 239.69.11.44/32
a=rtpmap:98 L24/48000/4
a=ptime:1
a=ts-refclk:ptp=IEEE1588-2008:00-11-22-FF-FE-33-44-55:0
a=mediaclk:direct=0",
        ))
        .unwrap();
        assert!(sdp_hash(&sdp) != 0);
    }

    /// Encodes a deletion announcement and decodes it back. Checks that every
    /// header field and the SDP survive the round trip.
    // Only an IPv4 originating source is exercised here, and there is no auth
    // data or payload type. The SDP is compared with `\r` removed because the
    // marshalled form appears to use CRLF line endings while the literal uses LF.
    #[test]
    fn encode_decode_roundtrip_is_successful() {
        let sdp = "v=0
o=- 123456 123458 IN IP4 10.0.1.2
s=My sample flow
i=4 channels: c1, c2, c3, c4
t=0 0
a=recvonly
m=audio 5004 RTP/AVP 98
c=IN IP4 239.69.11.44/32
a=rtpmap:98 L24/48000/4
a=ptime:1
a=ts-refclk:ptp=IEEE1588-2008:00-11-22-FF-FE-33-44-55:0
a=mediaclk:direct=0
";

        let sa = SessionAnnouncement {
            auth_data: None,
            payload_type: None,
            compressed: false,
            deletion: true,
            encrypted: false,
            msg_id_hash: 1234,
            originating_source: "127.0.0.1".parse().unwrap(),
            sdp: SessionDescription::unmarshal(&mut Cursor::new(sdp)).unwrap(),
        };

        let sa_msg = encode_sap(&sa);

        let decoded = decode_sap(&sa_msg).unwrap();

        assert_eq!(sa.auth_data, decoded.auth_data);
        assert_eq!(sa.compressed, decoded.compressed);
        assert_eq!(sa.deletion, decoded.deletion);
        assert_eq!(sa.encrypted, decoded.encrypted);
        assert_eq!(sa.msg_id_hash, decoded.msg_id_hash);
        assert_eq!(sa.originating_source, decoded.originating_source);
        assert_eq!(sa.payload_type, decoded.payload_type);
        assert_eq!(sa.sdp.marshal().replace('\r', ""), sdp);
    }
}