Skip to main content

sap_rs/
lib.rs

1/*
2 *  Copyright (C) 2024 Michael Bachmann
3 *
4 *  This program is free software: you can redistribute it and/or modify
5 *  it under the terms of the GNU Affero General Public License as published by
6 *  the Free Software Foundation, either version 3 of the License, or
7 *  (at your option) any later version.
8 *
9 *  This program is distributed in the hope that it will be useful,
10 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
11 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 *  GNU Affero General Public License for more details.
13 *
14 *  You should have received a copy of the GNU Affero General Public License
15 *  along with this program.  If not, see <https://www.gnu.org/licenses/>.
16 */
17
18use error::{Error, SapResult};
19use lazy_static::lazy_static;
20use murmur3::murmur3_32;
21use sdp::SessionDescription;
22use socket2::{Domain, Protocol, SockAddr, Socket, Type};
23use std::{
24    collections::HashMap,
25    io::Cursor,
26    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
27    time::{Duration, SystemTime, UNIX_EPOCH},
28};
29use tokio::{
30    net::UdpSocket,
31    select,
32    sync::{mpsc, oneshot},
33    time::interval,
34};
35use tosub::SubsystemHandle;
36use tracing::{debug, error, info};
37
38pub mod error;
39
40const DEFAULT_PAYLOAD_TYPE: &str = "application/sdp";
41const DEFAULT_SAP_PORT: u16 = 9875;
42const DEFAULT_MULTICAST_ADDRESS: &str = "239.255.255.255";
43
44lazy_static! {
45    static ref HASH_SEED: u32 = SystemTime::now()
46        .duration_since(UNIX_EPOCH)
47        .expect("something is wrong with the system clock")
48        .as_secs() as u32;
49}
50
/// A single SAP (RFC 2974) announcement or deletion, together with the SDP it
/// carries.
#[derive(Debug, Clone)]
pub struct SessionAnnouncement {
    /// Message type bit: `true` for a session deletion, `false` for an announcement.
    pub deletion: bool,
    /// Encryption bit; encrypted packets are currently rejected by [`decode_sap`].
    pub encrypted: bool,
    /// Compression bit; compressed packets are currently rejected by [`decode_sap`].
    pub compressed: bool,
    /// 16-bit message identifier hash; [`SessionAnnouncement::new`] derives it
    /// from the marshalled SDP.
    pub msg_id_hash: u16,
    /// Authentication data, if any (decoded lossily as UTF-8 by [`decode_sap`]).
    pub auth_data: Option<String>,
    /// Address of the announcer, written into the SAP header.
    pub originating_source: IpAddr,
    /// MIME payload type (e.g. `application/sdp`); `None` when the packet carries none.
    pub payload_type: Option<String>,
    /// The announced session description.
    pub sdp: SessionDescription,
}
62
63impl SessionAnnouncement {
64    pub fn new(sdp: SessionDescription) -> SapResult<Self> {
65        Ok(Self {
66            deletion: false,
67            encrypted: false,
68            compressed: false,
69            msg_id_hash: sdp_hash(&sdp),
70            auth_data: None,
71            originating_source: sdp.origin.unicast_address.parse()?,
72            payload_type: Some(DEFAULT_PAYLOAD_TYPE.to_owned()),
73            sdp,
74        })
75    }
76
77    pub fn deletion(sdp: SessionDescription) -> SapResult<Self> {
78        Ok(Self {
79            deletion: true,
80            encrypted: false,
81            compressed: false,
82            msg_id_hash: sdp_hash(&sdp),
83            auth_data: None,
84            originating_source: sdp.origin.unicast_address.parse()?,
85            payload_type: Some(DEFAULT_PAYLOAD_TYPE.to_owned()),
86            sdp,
87        })
88    }
89}
90
/// Actor owning the SAP state. It serves API requests from [`Sap`] handles,
/// runs one announcement subsystem per announced session, and reports decoded
/// incoming packets as [`Event`]s.
pub struct SapActor {
    /// Subsystem handle of the actor; announcement subsystems are spawned under it.
    subsys: SubsystemHandle,
    /// Raw datagrams received by the socket I/O loop.
    rx: mpsc::Receiver<Vec<u8>>,
    // NOTE(review): `multicast_addr`, `active_sessions` and `foreign_sessions`
    // are currently not read by the actor.
    multicast_addr: SocketAddr,
    active_sessions: HashMap<u64, SessionAnnouncement>,
    foreign_sessions: HashMap<u64, SessionAnnouncement>,
    /// Running announcement subsystems keyed by SDP origin session id; shutting
    /// one down makes it send a deletion announcement.
    deletion_announcements: HashMap<u64, SubsystemHandle>,
    /// Channel on which decoded incoming announcements are reported.
    event_tx: mpsc::Sender<Event>,
    /// API requests (`Message`) from `Sap` handles.
    msg_rx: mpsc::Receiver<Message>,
    /// Announcements to be encoded and sent by the socket I/O loop.
    announcement_sender: mpsc::Sender<SessionAnnouncement>,
}
102
103pub enum Event {
104    SessionFound(SessionAnnouncement),
105    SessionLost(SessionAnnouncement),
106}
107
/// Requests sent from [`Sap`] handles to the [`SapActor`]. Each carries a
/// oneshot sender on which the actor returns the outcome.
enum Message {
    /// Start announcing the given session, replacing any running announcement
    /// with the same SDP origin session id.
    AnnounceSession(Box<SessionAnnouncement>, oneshot::Sender<SapResult<()>>),
    /// Stop announcing the session with this SDP origin session id.
    DeleteSession(u64, oneshot::Sender<SapResult<()>>),
    /// Stop announcing all sessions.
    DeleteAllSessions(oneshot::Sender<SapResult<()>>),
}
113
114impl SapActor {
115    async fn run(mut self) -> SapResult<()> {
116        loop {
117            select! {
118                recv = self.msg_rx.recv() => if let Some(msg) = recv {
119                    self.process_api_msg(msg).await?;
120                } else {
121                    info!("Message channel closed, shutting down SAP actor.");
122                    break;
123                },
124                recv = self.rx.recv() => if let Some(data) = recv {
125                    self.forward_announcement(&data).await;
126                } else {
127                    info!("Socket channel closed, shutting down SAP actor.");
128                    break;
129                },
130                _ = self.subsys.shutdown_requested() => {
131                    info!("Shutdown requested, shutting down SAP actor.");
132                    break;
133                },
134            }
135        }
136
137        info!("SAP actor stopped.");
138
139        Ok(())
140    }
141
142    async fn process_api_msg(&mut self, msg: Message) -> SapResult<()> {
143        match msg {
144            Message::AnnounceSession(sa, tx) => {
145                tx.send(self.announce_session(*sa).await).ok();
146            }
147            Message::DeleteSession(id, tx) => {
148                tx.send(self.delete_session(id).await).ok();
149            }
150            Message::DeleteAllSessions(tx) => {
151                tx.send(self.delete_all_sessions().await).ok();
152            }
153        }
154
155        Ok(())
156    }
157
158    async fn forward_announcement(&self, buf: &[u8]) {
159        debug!("forwarding SAP message");
160        match decode_sap(buf) {
161            Ok(sap) => {
162                let event = if sap.deletion {
163                    Event::SessionLost(sap)
164                } else {
165                    Event::SessionFound(sap)
166                };
167                if let Err(e) = self.event_tx.send(event).await {
168                    error!("Error forwarding SAP message error: {e}");
169                } else {
170                    debug!("SAP message forwarded");
171                }
172            }
173            Err(e) => {
174                error!("error decoding SAP message: {e}");
175            }
176        }
177    }
178
179    async fn announce_session(&mut self, announcement: SessionAnnouncement) -> SapResult<()> {
180        let session_id = announcement.sdp.origin.session_id;
181
182        info!(
183            "Announcing new session with hash {}.",
184            announcement.msg_id_hash
185        );
186
187        self.delete_session(announcement.sdp.origin.session_id)
188            .await?;
189
190        let mut deletion_announcement = announcement.clone();
191        deletion_announcement.deletion = true;
192
193        let tx = self.announcement_sender.clone();
194
195        let announcement = self.subsys.spawn(
196            format!("announcement/{}", announcement.msg_id_hash),
197            |s| async move {
198                let mut interval = interval(Duration::from_secs(5));
199
200                loop {
201                    // TODO receive other announcements and update delay
202                    // TODO send announcement in according intervals
203                    //
204                    select! {
205                        _ = interval.tick() => tx.send(announcement.clone()).await?,
206                        _ = s.shutdown_requested() => break,
207                    }
208                }
209
210                tx.send(deletion_announcement).await.ok();
211
212                Ok::<(), error::Error>(())
213            },
214        );
215
216        self.deletion_announcements.insert(session_id, announcement);
217
218        Ok(())
219    }
220
221    async fn delete_session(&mut self, session_id: u64) -> SapResult<()> {
222        if let Some(subsys) = self.deletion_announcements.remove(&session_id) {
223            info!("Deleting active session {session_id}.");
224            subsys.request_local_shutdown();
225        } else {
226            debug!("No session active, nothing to delete.");
227        }
228
229        Ok(())
230    }
231
232    async fn delete_all_sessions(&mut self) -> SapResult<()> {
233        let sessions = self.deletion_announcements.drain().collect::<Vec<_>>();
234
235        for (session_id, subsys) in sessions {
236            info!("Deleting active session {session_id}.");
237            subsys.request_local_shutdown();
238        }
239
240        Ok(())
241    }
242}
243
244async fn send_announcement(
245    socket: &UdpSocket,
246    multicast_addr: &SocketAddr,
247    announcement: &SessionAnnouncement,
248) -> SapResult<()> {
249    debug!(
250        "Broadcasting session description:\n{}\n",
251        announcement.sdp.marshal()
252    );
253    let msg = encode_sap(announcement);
254    socket.send_to(&msg, multicast_addr).await?;
255    Ok(())
256}
257
/// Cloneable handle for controlling the SAP actor started by [`Sap::new`].
///
/// Each method sends a request to the actor and waits for its reply.
#[derive(Clone)]
pub struct Sap {
    msg_tx: mpsc::Sender<Message>,
}
262
263impl Sap {
264    pub async fn new(subsys: &SubsystemHandle) -> SapResult<(Self, mpsc::Receiver<Event>)> {
265        let multicast_addr = SocketAddr::new(
266            IpAddr::V4(DEFAULT_MULTICAST_ADDRESS.parse()?),
267            DEFAULT_SAP_PORT,
268        );
269        let socket = create_socket().await?;
270
271        let active_sessions = HashMap::new();
272        let foreign_sessions = HashMap::new();
273        let deletion_announcements = HashMap::new();
274
275        let (event_tx, event_rx) = mpsc::channel(1);
276        let (msg_tx, msg_rx) = mpsc::channel(100);
277        let (socket_tx, socket_rx) = mpsc::channel(100);
278
279        subsys.spawn("sap", move |s| {
280            let (announce_tx, announce_rx) = mpsc::channel(1);
281
282            s.spawn("socket", move |s| {
283                IoLoop {
284                    s,
285                    socket,
286                    multicast_addr,
287                    socket_tx,
288                    announce_rx,
289                }
290                .io_loop()
291            });
292
293            SapActor {
294                subsys: s,
295                multicast_addr,
296                active_sessions,
297                foreign_sessions,
298                deletion_announcements,
299                event_tx,
300                msg_rx,
301                announcement_sender: announce_tx,
302                rx: socket_rx,
303            }
304            .run()
305        });
306
307        Ok((Sap { msg_tx }, event_rx))
308    }
309
310    pub async fn announce_session(&self, sd: SessionDescription) -> SapResult<()> {
311        let sa = SessionAnnouncement::new(sd)?;
312        let (tx, rx) = oneshot::channel();
313        self.msg_tx
314            .send(Message::AnnounceSession(Box::new(sa), tx))
315            .await?;
316        rx.await?
317    }
318
319    pub async fn delete_session(&self, session_id: u64) -> SapResult<()> {
320        let (tx, rx) = oneshot::channel();
321        self.msg_tx
322            .send(Message::DeleteSession(session_id, tx))
323            .await?;
324        rx.await?
325    }
326
327    pub async fn delete_all_sessions(&self) -> SapResult<()> {
328        let (tx, rx) = oneshot::channel();
329        self.msg_tx.send(Message::DeleteAllSessions(tx)).await?;
330        rx.await?
331    }
332}
333
/// State of the socket I/O loop spawned by [`Sap::new`].
struct IoLoop {
    /// Handle of the `socket` subsystem running this loop.
    s: SubsystemHandle,
    /// Multicast UDP socket used for both receiving and sending.
    socket: UdpSocket,
    /// Destination of outgoing announcements.
    multicast_addr: SocketAddr,
    /// Forwards received datagrams to the SAP actor.
    socket_tx: mpsc::Sender<Vec<u8>>,
    /// Announcements to encode and send.
    announce_rx: mpsc::Receiver<SessionAnnouncement>,
}
341impl IoLoop {
342    async fn io_loop(mut self) -> SapResult<()> {
343        let mut buf = [0; 1024];
344
345        loop {
346            select! {
347                len = self.socket.recv(&mut buf) => self.socket_tx.send(buf[..len?].to_vec()).await?,
348                recv = self.announce_rx.recv() => if let Some(announcement) = recv {
349                    send_announcement(&self.socket, &self.multicast_addr, &announcement).await?
350                } else {
351                    break;
352                },
353            }
354        }
355
356        self.s.request_local_shutdown();
357
358        info!("SAP socket closed.");
359
360        Ok(())
361    }
362}
363
364pub fn decode_sap(msg: &[u8]) -> SapResult<SessionAnnouncement> {
365    let mut min_length = 4;
366
367    if msg.len() < min_length {
368        return Err(Error::MalformedPacket(msg.to_owned()));
369    }
370
371    let header = msg[0];
372    let auth_len = msg[1];
373    let msg_id_hash = u16::from_be_bytes([msg[2], msg[3]]);
374
375    let ipv6 = (header & 0b00001000) >> 3 == 1;
376    let deletion = (header & 0b00000100) >> 2 == 1;
377    let encrypted = (header & 0b00000010) >> 1 == 1;
378    let compressed = header & 0b00000001 == 1;
379
380    // TODO implement decryption
381    if encrypted {
382        return Err(Error::NotImplemented("encryption"));
383    }
384    // TODO implement decompression
385    if compressed {
386        return Err(Error::NotImplemented("encryption"));
387    }
388
389    if ipv6 {
390        min_length += 16;
391    } else {
392        min_length += 4;
393    }
394
395    if msg.len() < min_length {
396        return Err(Error::MalformedPacket(msg.to_owned()));
397    }
398
399    let originating_source = if ipv6 {
400        let bits = u128::from_be_bytes([
401            msg[4], msg[5], msg[6], msg[7], msg[8], msg[9], msg[10], msg[11], msg[12], msg[13],
402            msg[14], msg[15], msg[16], msg[17], msg[18], msg[19],
403        ]);
404        IpAddr::V6(Ipv6Addr::from_bits(bits))
405    } else {
406        let bits = u32::from_be_bytes([msg[4], msg[5], msg[6], msg[7]]);
407        IpAddr::V4(Ipv4Addr::from_bits(bits))
408    };
409
410    let auth_data_start = min_length;
411
412    min_length += auth_len as usize;
413
414    if msg.len() <= min_length {
415        return Err(Error::MalformedPacket(msg.to_owned()));
416    }
417
418    let auth_data = if auth_len > 0 {
419        Some(String::from_utf8_lossy(&msg[auth_data_start..min_length]).to_string())
420    } else {
421        None
422    };
423
424    let payload = String::from_utf8_lossy(&msg[min_length..]).to_string();
425    let split: Vec<&str> = payload.split('\0').collect();
426
427    let payload_type = if split.len() >= 2 {
428        Some(split[0].to_owned())
429    } else {
430        None
431    };
432
433    let payload = if split.len() == 1 {
434        split[0]
435    } else {
436        &split[1..].join("\0")
437    };
438
439    let sdp = SessionDescription::unmarshal(&mut Cursor::new(payload))?;
440
441    Ok(SessionAnnouncement {
442        deletion,
443        encrypted,
444        compressed,
445        msg_id_hash,
446        auth_data,
447        originating_source,
448        payload_type,
449        sdp,
450    })
451}
452
453pub fn encode_sap(msg: &SessionAnnouncement) -> Vec<u8> {
454    let v = 1u8;
455    let (a, originating_source): (u8, &[u8]) = match msg.originating_source {
456        IpAddr::V4(addr) => (0u8, &addr.octets()),
457        IpAddr::V6(addr) => (1u8, &addr.octets()),
458    };
459    let r = 0u8;
460    let t = if msg.deletion { 1u8 } else { 0u8 };
461    let e = if msg.encrypted { 1u8 } else { 0u8 };
462    let c = if msg.compressed { 1u8 } else { 0u8 };
463    let header = v << 5 | a << 4 | r << 3 | t << 2 | e << 1 | c;
464    let auth_len = msg.auth_data.as_ref().map(|d| d.len()).unwrap_or(0) as u8;
465    let msg_id_hash = msg.msg_id_hash.to_be_bytes();
466
467    let mut data = Vec::new();
468    data.push(header);
469    data.push(auth_len);
470    data.extend_from_slice(&msg_id_hash);
471    data.extend_from_slice(originating_source);
472    if let Some(auth_data) = &msg.auth_data {
473        data.extend_from_slice(auth_data.as_bytes());
474    }
475    if let Some(payload_type) = &msg.payload_type {
476        data.extend_from_slice(payload_type.as_bytes());
477        data.push(b'\0');
478    }
479    debug!("marshalling sdp ...");
480    data.extend_from_slice(msg.sdp.marshal().as_bytes());
481    debug!("marshalling sdp done.");
482
483    data
484}
485
486fn sdp_hash(sdp: &SessionDescription) -> u16 {
487    info!("computing message hash ...");
488    let res = murmur3_32(&mut Cursor::new(sdp.marshal()), *HASH_SEED).unwrap_or(0) as u16;
489    info!("computing message hash done");
490    res
491}
492
493async fn create_socket() -> SapResult<UdpSocket> {
494    let multicast_addr: Ipv4Addr = DEFAULT_MULTICAST_ADDRESS.parse()?;
495    let local_ip = Ipv4Addr::UNSPECIFIED;
496    let local_addr = SocketAddr::new(IpAddr::V4(local_ip), DEFAULT_SAP_PORT);
497
498    let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
499    socket.set_reuse_address(true)?;
500    socket.set_nonblocking(true)?;
501    socket.bind(&SockAddr::from(local_addr))?;
502    socket.join_multicast_v4(&multicast_addr, &local_ip)?;
503
504    let socket = UdpSocket::from_std(socket.into())?;
505
506    Ok(socket)
507}
508
#[cfg(test)]
mod tests {

    use super::*;

    /// SDP shared by the roundtrip tests. The trailing newline matches the
    /// marshalled form once `\r` is stripped.
    const SAMPLE_SDP: &str = "v=0
o=- 123456 123458 IN IP4 10.0.1.2
s=My sample flow
i=4 channels: c1, c2, c3, c4
t=0 0
a=recvonly
m=audio 5004 RTP/AVP 98
c=IN IP4 239.69.11.44/32
a=rtpmap:98 L24/48000/4
a=ptime:1
a=ts-refclk:ptp=IEEE1588-2008:00-11-22-FF-FE-33-44-55:0
a=mediaclk:direct=0
";

    /// Plain announcement of `SAMPLE_SDP` from `originating_source`, with no
    /// flags, auth data or payload type.
    fn sample_announcement(originating_source: &str) -> SessionAnnouncement {
        SessionAnnouncement {
            auth_data: None,
            payload_type: None,
            compressed: false,
            deletion: false,
            encrypted: false,
            msg_id_hash: 1234,
            originating_source: originating_source.parse().unwrap(),
            sdp: SessionDescription::unmarshal(&mut Cursor::new(SAMPLE_SDP)).unwrap(),
        }
    }

    #[test]
    fn sdp_gets_hashed_correctly() {
        let sdp = SessionDescription::unmarshal(&mut Cursor::new(
            "v=0
o=- 123456 123458 IN IP4 10.0.1.2
s=My sample flow
i=4 channels: c1, c2, c3, c4
t=0 0
a=recvonly
m=audio 5004 RTP/AVP 98
c=IN IP4 239.69.11.44/32
a=rtpmap:98 L24/48000/4
a=ptime:1
a=ts-refclk:ptp=IEEE1588-2008:00-11-22-FF-FE-33-44-55:0
a=mediaclk:direct=0",
        ))
        .unwrap();
        assert!(sdp_hash(&sdp) != 0);
    }

    #[test]
    fn encode_decode_roundtrip_is_successful() {
        let sa = SessionAnnouncement {
            deletion: true,
            ..sample_announcement("127.0.0.1")
        };

        let sa_msg = encode_sap(&sa);

        let decoded = decode_sap(&sa_msg).unwrap();

        assert_eq!(sa.auth_data, decoded.auth_data);
        assert_eq!(sa.compressed, decoded.compressed);
        assert_eq!(sa.deletion, decoded.deletion);
        assert_eq!(sa.encrypted, decoded.encrypted);
        assert_eq!(sa.msg_id_hash, decoded.msg_id_hash);
        assert_eq!(sa.originating_source, decoded.originating_source);
        assert_eq!(sa.payload_type, decoded.payload_type);
        assert_eq!(sa.sdp.marshal().replace('\r', ""), SAMPLE_SDP);
        // The decoded SDP itself must survive the roundtrip as well.
        assert_eq!(decoded.sdp.marshal(), sa.sdp.marshal());
    }

    #[test]
    fn roundtrip_preserves_payload_type_and_auth_data() {
        let sa = SessionAnnouncement {
            auth_data: Some("abcd".to_owned()),
            payload_type: Some(DEFAULT_PAYLOAD_TYPE.to_owned()),
            ..sample_announcement("10.0.1.2")
        };

        let decoded = decode_sap(&encode_sap(&sa)).unwrap();

        assert_eq!(decoded.auth_data, sa.auth_data);
        assert_eq!(decoded.payload_type, sa.payload_type);
        assert_eq!(decoded.sdp.marshal(), sa.sdp.marshal());
    }

    #[test]
    fn ipv6_source_sets_address_type_bit() {
        let encoded = encode_sap(&sample_announcement("::1"));

        // Version 1 in the top three bits, address-type bit (bit 4) set for IPv6.
        assert_eq!(encoded[0], 0b0011_0000);
        assert_eq!(&encoded[4..20], &Ipv6Addr::LOCALHOST.octets());
    }

    #[test]
    fn truncated_packets_are_rejected() {
        // Shorter than the fixed four-byte header.
        assert!(matches!(
            decode_sap(&[0x20, 0, 0]),
            Err(Error::MalformedPacket(_))
        ));
        // Header and IPv4 source, but no payload.
        assert!(matches!(
            decode_sap(&[0x20, 0, 0, 0, 127, 0, 0, 1]),
            Err(Error::MalformedPacket(_))
        ));
    }

    #[test]
    fn encrypted_and_compressed_packets_are_rejected() {
        assert!(matches!(
            decode_sap(&[0x22, 0, 0, 0, 127, 0, 0, 1, b'v']),
            Err(Error::NotImplemented(_))
        ));
        assert!(matches!(
            decode_sap(&[0x21, 0, 0, 0, 127, 0, 0, 1, b'v']),
            Err(Error::NotImplemented(_))
        ));
    }
}
574}