rtcp_types/feedback/
fir.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3use std::collections::HashMap;
4
5use crate::feedback::FciFeedbackPacketType;
6use crate::prelude::*;
7use crate::utils::u32_from_be_bytes;
8use crate::{RtcpParseError, RtcpWriteError};
9
/// A single entry of a Full Intra Request (FIR) as specified in RFC 5104.
///
/// On the wire each entry occupies 8 bytes: a 32-bit SSRC, an 8-bit command
/// sequence number and 24 reserved bits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FirEntry {
    ssrc: u32,
    sequence: u8,
}

impl FirEntry {
    fn new(ssrc: u32, sequence: u8) -> Self {
        Self { ssrc, sequence }
    }

    /// The SSRC of the media sender that is requested to send a decoder refresh point.
    pub fn ssrc(&self) -> u32 {
        self.ssrc
    }

    /// The command sequence number of the request.
    ///
    /// A repeated request for the same SSRC carries the same value, so it can be
    /// used for deduplication.
    pub fn sequence(&self) -> u8 {
        self.sequence
    }
}
32
33pub struct FirParserEntryIter<'a> {
34    parser: &'a Fir<'a>,
35    i: usize,
36}
37
38impl<'a> FirParserEntryIter<'a> {
39    fn decode_entry(entry: &[u8]) -> FirEntry {
40        FirEntry::new(u32_from_be_bytes(&entry[0..4]), entry[4])
41    }
42}
43
44impl<'a> std::iter::Iterator for FirParserEntryIter<'a> {
45    type Item = FirEntry;
46
47    fn next(&mut self) -> Option<Self::Item> {
48        let idx = self.i * 8;
49        if idx + 7 >= self.parser.data.len() {
50            return None;
51        }
52        let entry = FirParserEntryIter::decode_entry(&self.parser.data[idx..]);
53        self.i += 1;
54        Some(entry)
55    }
56}
57
58/// FIR (Full Intra Refresh) information as specified in RFC 5104
59pub struct Fir<'a> {
60    data: &'a [u8],
61}
62
63impl<'a> Fir<'a> {
64    /// The list of RTP SSRCs that are requesting a Full Intra Refresh.
65    pub fn entries(&self) -> impl Iterator<Item = FirEntry> + '_ {
66        FirParserEntryIter { parser: self, i: 0 }
67    }
68
69    /// Create a new [`FirBuilder`]
70    pub fn builder() -> FirBuilder {
71        FirBuilder::default()
72    }
73}
74
75impl<'a> FciParser<'a> for Fir<'a> {
76    const PACKET_TYPE: FciFeedbackPacketType = FciFeedbackPacketType::PAYLOAD;
77    const FCI_FORMAT: u8 = 4;
78
79    fn parse(data: &'a [u8]) -> Result<Self, RtcpParseError> {
80        if data.len() < 8 {
81            return Err(RtcpParseError::Truncated {
82                expected: 8,
83                actual: data.len(),
84            });
85        }
86        Ok(Self { data })
87    }
88}
89
/// Builder for the FCI of a Full Intra Request (FIR) packet.
///
/// Each SSRC is stored at most once; see [`FirBuilder::add_ssrc`].
#[derive(Debug, Default)]
pub struct FirBuilder {
    // SSRC -> command sequence number.
    ssrc_seq: HashMap<u32, u8>,
}

impl FirBuilder {
    /// Add an SSRC to this FIR packet, requesting a decoder refresh point from that
    /// media sender.
    ///
    /// If `ssrc` was already added, its sequence number is replaced by `sequence`.
    #[must_use]
    pub fn add_ssrc(mut self, ssrc: u32, sequence: u8) -> Self {
        self.ssrc_seq.insert(ssrc, sequence);
        self
    }
}
108
109impl<'a> FciBuilder<'a> for FirBuilder {
110    fn format(&self) -> u8 {
111        Fir::FCI_FORMAT
112    }
113
114    fn supports_feedback_type(&self) -> FciFeedbackPacketType {
115        Fir::PACKET_TYPE
116    }
117}
118
119impl RtcpPacketWriter for FirBuilder {
120    fn calculate_size(&self) -> Result<usize, RtcpWriteError> {
121        let entries = self.ssrc_seq.len();
122        if entries > u16::MAX as usize / 2 - 2 {
123            return Err(RtcpWriteError::TooManyFir);
124        }
125        Ok(entries * 2 * 4)
126    }
127
128    fn write_into_unchecked(&self, buf: &mut [u8]) -> usize {
129        let mut idx = 0;
130        let mut end = idx;
131
132        for (ssrc, sequence) in self.ssrc_seq.iter() {
133            end += 4;
134            buf[idx..end].copy_from_slice(&ssrc.to_be_bytes());
135            idx = end;
136            end += 4;
137            buf[idx..end].copy_from_slice(&[*sequence, 0, 0, 0]);
138            idx = end;
139        }
140        end
141    }
142
143    fn get_padding(&self) -> Option<u8> {
144        None
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::PayloadFeedback;

    #[test]
    fn fir_build_parse() {
        const REQ_LEN: usize = PayloadFeedback::MIN_PACKET_LEN + 8;
        let mut data = [0; REQ_LEN];
        let fir = {
            let fci = Fir::builder().add_ssrc(0xfedcba98, 0x30);
            PayloadFeedback::builder_owned(fci)
                .sender_ssrc(0x98765432)
                .media_ssrc(0)
        };
        assert_eq!(fir.calculate_size().unwrap(), REQ_LEN);
        let len = fir.write_into(&mut data).unwrap();
        assert_eq!(len, REQ_LEN);
        assert_eq!(
            data,
            [
                0x84, 0xce, 0x00, 0x04, 0x98, 0x76, 0x54, 0x32, 0x00, 0x00, 0x00, 0x00, 0xfe, 0xdc,
                0xba, 0x98, 0x30, 0x00, 0x00, 0x00
            ]
        );

        let fb = PayloadFeedback::parse(&data).unwrap();

        assert_eq!(fb.sender_ssrc(), 0x98765432);
        assert_eq!(fb.media_ssrc(), 0);
        let fir = fb.parse_fci::<Fir>().unwrap();
        let mut entry_iter = fir.entries();
        assert_eq!(entry_iter.next(), Some(FirEntry::new(0xfedcba98, 0x30)));
        assert_eq!(entry_iter.next(), None);
    }

    #[test]
    fn fir_build_parse_multiple_entries() {
        // Two distinct SSRCs; the repeated SSRC must keep only its latest sequence number.
        const REQ_LEN: usize = PayloadFeedback::MIN_PACKET_LEN + 16;
        let mut data = [0; REQ_LEN];
        let fir = {
            let fci = Fir::builder()
                .add_ssrc(0x11111111, 0x01)
                .add_ssrc(0x22222222, 0x02)
                .add_ssrc(0x11111111, 0x03);
            PayloadFeedback::builder_owned(fci)
                .sender_ssrc(0x98765432)
                .media_ssrc(0)
        };
        assert_eq!(fir.calculate_size().unwrap(), REQ_LEN);
        assert_eq!(fir.write_into(&mut data).unwrap(), REQ_LEN);

        let fb = PayloadFeedback::parse(&data).unwrap();
        let fir = fb.parse_fci::<Fir>().unwrap();
        // Compare order-independently: entry order on the wire is not part of this test.
        let mut entries: Vec<(u32, u8)> = fir
            .entries()
            .map(|e| (e.ssrc(), e.sequence()))
            .collect();
        entries.sort_unstable();
        assert_eq!(entries, [(0x11111111, 0x03), (0x22222222, 0x02)]);
    }

    #[test]
    fn fir_parse_truncated() {
        assert!(matches!(
            Fir::parse(&[0; 7]),
            Err(RtcpParseError::Truncated {
                expected: 8,
                actual: 7
            })
        ));
    }

    #[test]
    fn fir_parse_ignores_trailing_partial_entry() {
        let data = [
            0x12, 0x34, 0x56, 0x78, 0x05, 0x00, 0x00, 0x00, 0xaa, 0xbb, 0xcc, 0xdd,
        ];
        let fir = Fir::parse(&data).unwrap();
        let mut entries = fir.entries();
        assert_eq!(entries.next(), Some(FirEntry::new(0x12345678, 0x05)));
        assert_eq!(entries.next(), None);
    }
}