sap_rs/
lib.rs

1/*
2 *  Copyright (C) 2024 Michael Bachmann
3 *
4 *  This program is free software: you can redistribute it and/or modify
5 *  it under the terms of the GNU Affero General Public License as published by
6 *  the Free Software Foundation, either version 3 of the License, or
7 *  (at your option) any later version.
8 *
9 *  This program is distributed in the hope that it will be useful,
10 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
11 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 *  GNU Affero General Public License for more details.
13 *
14 *  You should have received a copy of the GNU Affero General Public License
15 *  along with this program.  If not, see <https://www.gnu.org/licenses/>.
16 */
17
18use error::{Error, SapResult};
19use lazy_static::lazy_static;
20use murmur3::murmur3_32;
21use sdp::SessionDescription;
22use socket2::{Domain, Protocol, SockAddr, Socket, Type};
23use std::{
24    collections::HashMap,
25    io::Cursor,
26    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
27    time::{Duration, SystemTime, UNIX_EPOCH},
28};
29use tokio::{
30    net::UdpSocket,
31    select, spawn,
32    sync::{mpsc, oneshot},
33    time::interval,
34};
35
36pub mod error;
37
/// MIME type written into the SAP payload-type field of locally built announcements.
const DEFAULT_PAYLOAD_TYPE: &str = "application/sdp";
/// UDP port the SAP socket binds to and the port announcements are sent to.
const DEFAULT_SAP_PORT: u16 = 9875;
/// IPv4 multicast group that is joined for receiving and used as the send destination.
// NOTE(review): presumably the RFC 2974 SAP address for the IPv4 local administrative
// scope (239.255.0.0/16); confirm this is the intended scope.
const DEFAULT_MULTICAST_ADDRESS: &str = "239.255.255.255";

lazy_static! {
    /// Seed for the message id hash: seconds since the UNIX epoch (truncated to `u32`),
    /// captured once on first use. Panics if the system clock reports a time before the
    /// epoch. Because the seed is time based, the same SDP generally hashes differently
    /// in different process runs.
    static ref HASH_SEED: u32 = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("something is wrong with the system clock")
        .as_secs() as u32;
}
48
49#[derive(Debug, Clone)]
50pub struct SessionAnnouncement {
51    pub deletion: bool,
52    pub encrypted: bool,
53    pub compressed: bool,
54    pub msg_id_hash: u16,
55    pub auth_data: Option<String>,
56    pub originating_source: IpAddr,
57    pub payload_type: Option<String>,
58    pub sdp: SessionDescription,
59}
60
61impl SessionAnnouncement {
62    pub fn new(sdp: SessionDescription) -> SapResult<Self> {
63        Ok(Self {
64            deletion: false,
65            encrypted: false,
66            compressed: false,
67            msg_id_hash: sdp_hash(&sdp),
68            auth_data: None,
69            originating_source: sdp.origin.unicast_address.parse()?,
70            payload_type: Some(DEFAULT_PAYLOAD_TYPE.to_owned()),
71            sdp,
72        })
73    }
74
75    pub fn deletion(sdp: SessionDescription) -> SapResult<Self> {
76        Ok(Self {
77            deletion: true,
78            encrypted: false,
79            compressed: false,
80            msg_id_hash: sdp_hash(&sdp),
81            auth_data: None,
82            originating_source: sdp.origin.unicast_address.parse()?,
83            payload_type: Some(DEFAULT_PAYLOAD_TYPE.to_owned()),
84            sdp,
85        })
86    }
87}
88
89pub struct SapActor {
90    socket: UdpSocket,
91    multicast_addr: SocketAddr,
92    active_sessions: HashMap<u16, SessionAnnouncement>,
93    foreign_sessions: HashMap<u16, SessionAnnouncement>,
94    deletion_announcements: HashMap<u16, SessionAnnouncement>,
95    event_tx: mpsc::Sender<Event>,
96    msg_rx: mpsc::Receiver<Message>,
97}
98
99pub enum Event {
100    SessionFound(SessionAnnouncement),
101    SessionLost(SessionAnnouncement),
102}
103
104enum Message {
105    AnnounceSession(Box<SessionAnnouncement>, oneshot::Sender<SapResult<()>>),
106    DeleteSession(u16, oneshot::Sender<SapResult<()>>),
107}
108
109impl SapActor {
110    async fn run(mut self) {
111        let mut buf = [0; 1024];
112
113        loop {
114            select! {
115                Some(msg) = self.msg_rx.recv() => {
116                    match msg {
117                        Message::AnnounceSession(sa, tx) => {
118                            tx.send(self.announce_session(*sa).await).ok();
119                        },
120                        Message::DeleteSession(hash, tx) => {
121                            tx.send(self.delete_session(hash).await).ok();
122                        },
123                    }
124                },
125                Ok(len) = async {
126                    log::debug!("receiving SAP broadcast message …");
127                    let recv = self.socket.recv(&mut buf).await;
128                    log::debug!("broadcast message received");
129                    recv
130                } => self.forward_announcement(&buf[0..len]).await,
131                else => break,
132            }
133        }
134    }
135
136    async fn forward_announcement(&self, buf: &[u8]) {
137        log::debug!("forwarding SAP message");
138        match decode_sap(buf) {
139            Ok(sap) => {
140                let event = if sap.deletion {
141                    Event::SessionLost(sap)
142                } else {
143                    Event::SessionFound(sap)
144                };
145                if let Err(e) = self.event_tx.send(event).await {
146                    log::error!("Error forwarding SAP message error: {e}");
147                } else {
148                    log::debug!("SAP message forwarded");
149                }
150            }
151            Err(e) => {
152                log::error!("error decoding SAP message: {e}");
153            }
154        }
155    }
156
157    async fn announce_session(&mut self, announcement: SessionAnnouncement) -> SapResult<()> {
158        self.delete_session(announcement.msg_id_hash).await?;
159
160        let mut deletion_announcement = announcement.clone();
161        deletion_announcement.deletion = true;
162        self.deletion_announcements
163            .insert(deletion_announcement.msg_id_hash, deletion_announcement);
164
165        let mut interval = interval(Duration::from_secs(5));
166
167        loop {
168            // TODO receive other announcements and update delay
169            // TODO send announcement in according intervals
170            //
171            select! {
172                _ = interval.tick() => self.send_announcement(&announcement).await?,
173            }
174        }
175    }
176
177    async fn delete_session(&mut self, hash: u16) -> SapResult<()> {
178        if let Some(deletion_announcement) = self.deletion_announcements.remove(&hash) {
179            log::info!("Deleting active session {hash}.");
180            let msg = encode_sap(&deletion_announcement);
181            self.socket.send_to(&msg, &self.multicast_addr).await?;
182        } else {
183            log::debug!("No session active, nothing to delete.");
184        }
185
186        Ok(())
187    }
188
189    async fn send_announcement(&self, announcement: &SessionAnnouncement) -> SapResult<()> {
190        log::info!("Broadcasting session description.");
191        let msg = encode_sap(announcement);
192        self.socket.send_to(&msg, &self.multicast_addr).await?;
193        Ok(())
194    }
195}
196
197#[derive(Clone)]
198pub struct Sap {
199    msg_tx: mpsc::Sender<Message>,
200}
201
202impl Sap {
203    pub async fn new() -> SapResult<(Self, mpsc::Receiver<Event>)> {
204        let multicast_addr = SocketAddr::new(
205            IpAddr::V4(DEFAULT_MULTICAST_ADDRESS.parse()?),
206            DEFAULT_SAP_PORT,
207        );
208        let socket = create_socket().await?;
209
210        let active_sessions = HashMap::new();
211        let foreign_sessions = HashMap::new();
212        let deletion_announcements = HashMap::new();
213
214        let (event_tx, event_rx) = mpsc::channel(1);
215        let (msg_tx, msg_rx) = mpsc::channel(100);
216
217        let actor = SapActor {
218            socket,
219            multicast_addr,
220            active_sessions,
221            foreign_sessions,
222            deletion_announcements,
223            event_tx,
224            msg_rx,
225        };
226
227        spawn(actor.run());
228
229        Ok((Sap { msg_tx }, event_rx))
230    }
231
232    pub async fn announce_session(&self, sd: SessionDescription) -> SapResult<()> {
233        let sa = SessionAnnouncement::new(sd)?;
234        let (tx, rx) = oneshot::channel();
235        self.msg_tx
236            .send(Message::AnnounceSession(Box::new(sa), tx))
237            .await?;
238        rx.await?
239    }
240
241    pub async fn delete_session(&self, hash: u16) -> SapResult<()> {
242        let (tx, rx) = oneshot::channel();
243        self.msg_tx.send(Message::DeleteSession(hash, tx)).await?;
244        rx.await?
245    }
246}
247
248pub fn decode_sap(msg: &[u8]) -> SapResult<SessionAnnouncement> {
249    let mut min_length = 4;
250
251    if msg.len() < min_length {
252        return Err(Error::MalformedPacket(msg.to_owned()));
253    }
254
255    let header = msg[0];
256    let auth_len = msg[1];
257    let msg_id_hash = u16::from_be_bytes([msg[2], msg[3]]);
258
259    let ipv6 = (header & 0b00001000) >> 3 == 1;
260    let deletion = (header & 0b00000100) >> 2 == 1;
261    let encrypted = (header & 0b00000010) >> 1 == 1;
262    let compressed = header & 0b00000001 == 1;
263
264    // TODO implement decryption
265    if encrypted {
266        return Err(Error::NotImplemented("encryption"));
267    }
268    // TODO implement decompression
269    if compressed {
270        return Err(Error::NotImplemented("encryption"));
271    }
272
273    if ipv6 {
274        min_length += 16;
275    } else {
276        min_length += 4;
277    }
278
279    if msg.len() < min_length {
280        return Err(Error::MalformedPacket(msg.to_owned()));
281    }
282
283    let originating_source = if ipv6 {
284        let bits = u128::from_be_bytes([
285            msg[4], msg[5], msg[6], msg[7], msg[8], msg[9], msg[10], msg[11], msg[12], msg[13],
286            msg[14], msg[15], msg[16], msg[17], msg[18], msg[19],
287        ]);
288        IpAddr::V6(Ipv6Addr::from_bits(bits))
289    } else {
290        let bits = u32::from_be_bytes([msg[4], msg[5], msg[6], msg[7]]);
291        IpAddr::V4(Ipv4Addr::from_bits(bits))
292    };
293
294    let auth_data_start = min_length;
295
296    min_length += auth_len as usize;
297
298    if msg.len() <= min_length {
299        return Err(Error::MalformedPacket(msg.to_owned()));
300    }
301
302    let auth_data = if auth_len > 0 {
303        Some(String::from_utf8_lossy(&msg[auth_data_start..min_length]).to_string())
304    } else {
305        None
306    };
307
308    let payload = String::from_utf8_lossy(&msg[min_length..]).to_string();
309    let split: Vec<&str> = payload.split('\0').collect();
310
311    let payload_type = if split.len() >= 2 {
312        Some(split[0].to_owned())
313    } else {
314        None
315    };
316
317    let payload = if split.len() == 1 {
318        split[0]
319    } else {
320        &split[1..].join("\0")
321    };
322
323    let sdp = SessionDescription::unmarshal(&mut Cursor::new(payload))?;
324
325    Ok(SessionAnnouncement {
326        deletion,
327        encrypted,
328        compressed,
329        msg_id_hash,
330        auth_data,
331        originating_source,
332        payload_type,
333        sdp,
334    })
335}
336
337pub fn encode_sap(msg: &SessionAnnouncement) -> Vec<u8> {
338    let v = 1u8;
339    let (a, originating_source): (u8, &[u8]) = match msg.originating_source {
340        IpAddr::V4(addr) => (0u8, &addr.octets()),
341        IpAddr::V6(addr) => (1u8, &addr.octets()),
342    };
343    let r = 0u8;
344    let t = if msg.deletion { 1u8 } else { 0u8 };
345    let e = if msg.encrypted { 1u8 } else { 0u8 };
346    let c = if msg.compressed { 1u8 } else { 0u8 };
347    let header = v << 5 | a << 4 | r << 3 | t << 2 | e << 1 | c;
348    let auth_len = msg.auth_data.as_ref().map(|d| d.len()).unwrap_or(0) as u8;
349    let msg_id_hash = msg.msg_id_hash.to_be_bytes();
350
351    let mut data = Vec::new();
352    data.push(header);
353    data.push(auth_len);
354    data.extend_from_slice(&msg_id_hash);
355    data.extend_from_slice(originating_source);
356    if let Some(auth_data) = &msg.auth_data {
357        data.extend_from_slice(auth_data.as_bytes());
358    }
359    if let Some(payload_type) = &msg.payload_type {
360        data.extend_from_slice(payload_type.as_bytes());
361        data.push(b'\0');
362    }
363    log::info!("marshalling sdp ...");
364    data.extend_from_slice(msg.sdp.marshal().as_bytes());
365    log::info!("marshalling sdp done.");
366
367    data
368}
369
370fn sdp_hash(sdp: &SessionDescription) -> u16 {
371    log::info!("computing message hash ...");
372    let res = murmur3_32(&mut Cursor::new(sdp.marshal()), *HASH_SEED).unwrap_or(0) as u16;
373    log::info!("computing message hash done");
374    res
375}
376
377async fn create_socket() -> SapResult<UdpSocket> {
378    let multicast_addr: Ipv4Addr = DEFAULT_MULTICAST_ADDRESS.parse()?;
379    let local_ip = Ipv4Addr::UNSPECIFIED;
380    let local_addr = SocketAddr::new(IpAddr::V4(local_ip), DEFAULT_SAP_PORT);
381
382    let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
383    socket.set_reuse_address(true)?;
384    socket.set_nonblocking(true)?;
385    socket.bind(&SockAddr::from(local_addr))?;
386    socket.join_multicast_v4(&multicast_addr, &local_ip)?;
387
388    let socket = UdpSocket::from_std(socket.into())?;
389
390    Ok(socket)
391}
392
#[cfg(test)]
mod tests {

    use super::*;

    /// The message id hash of a parsed SDP must never be the reserved value zero.
    #[test]
    fn sdp_gets_hashed_correctly() {
        let text = "v=0
o=- 123456 123458 IN IP4 10.0.1.2
s=My sample flow
i=4 channels: c1, c2, c3, c4
t=0 0
a=recvonly
m=audio 5004 RTP/AVP 98
c=IN IP4 239.69.11.44/32
a=rtpmap:98 L24/48000/4
a=ptime:1
a=ts-refclk:ptp=IEEE1588-2008:00-11-22-FF-FE-33-44-55:0
a=mediaclk:direct=0";

        let parsed = SessionDescription::unmarshal(&mut Cursor::new(text)).unwrap();

        assert_ne!(sdp_hash(&parsed), 0);
    }

    /// Encoding an IPv4 deletion packet and decoding it again preserves every field.
    #[test]
    fn encode_decode_roundtrip_is_successful() {
        let sdp = "v=0
o=- 123456 123458 IN IP4 10.0.1.2
s=My sample flow
i=4 channels: c1, c2, c3, c4
t=0 0
a=recvonly
m=audio 5004 RTP/AVP 98
c=IN IP4 239.69.11.44/32
a=rtpmap:98 L24/48000/4
a=ptime:1
a=ts-refclk:ptp=IEEE1588-2008:00-11-22-FF-FE-33-44-55:0
a=mediaclk:direct=0
";

        let original = SessionAnnouncement {
            auth_data: None,
            payload_type: None,
            compressed: false,
            deletion: true,
            encrypted: false,
            msg_id_hash: 1234,
            originating_source: "127.0.0.1".parse().unwrap(),
            sdp: SessionDescription::unmarshal(&mut Cursor::new(sdp)).unwrap(),
        };

        let decoded = decode_sap(&encode_sap(&original)).unwrap();

        assert_eq!(decoded.auth_data, original.auth_data);
        assert_eq!(decoded.compressed, original.compressed);
        assert_eq!(decoded.deletion, original.deletion);
        assert_eq!(decoded.encrypted, original.encrypted);
        assert_eq!(decoded.msg_id_hash, original.msg_id_hash);
        assert_eq!(decoded.originating_source, original.originating_source);
        assert_eq!(decoded.payload_type, original.payload_type);
        // The marshaller emits CRLF line endings; the fixture uses LF.
        assert_eq!(original.sdp.marshal().replace('\r', ""), sdp);
    }
}