rtcp_types/feedback/
fir.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3use std::collections::HashMap;
4
5use crate::feedback::FciFeedbackPacketType;
6use crate::prelude::*;
7use crate::utils::u32_from_be_bytes;
8use crate::{RtcpParseError, RtcpWriteError};
9
/// An entry in a Full Intra Refresh
///
/// A plain value pairing an SSRC with the sequence number of the request. It is cheap to copy,
/// so it implements [`Clone`] and [`Copy`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FirEntry {
    /// The SSRC for this FIR.
    ssrc: u32,
    /// The sequence count of the request.
    sequence: u8,
}

impl FirEntry {
    /// Build an entry from its two fields.
    fn new(ssrc: u32, sequence: u8) -> Self {
        Self { ssrc, sequence }
    }

    /// The SSRC for this FIR
    pub fn ssrc(&self) -> u32 {
        self.ssrc
    }

    /// The sequence count of the request. Intended for deduplication.
    pub fn sequence(&self) -> u8 {
        self.sequence
    }
}
32
33pub struct FirParserEntryIter<'a> {
34    parser: &'a Fir<'a>,
35    i: usize,
36}
37
38impl FirParserEntryIter<'_> {
39    fn decode_entry(entry: &[u8]) -> FirEntry {
40        FirEntry::new(u32_from_be_bytes(&entry[0..4]), entry[4])
41    }
42}
43
44impl std::iter::Iterator for FirParserEntryIter<'_> {
45    type Item = FirEntry;
46
47    fn next(&mut self) -> Option<Self::Item> {
48        let idx = self.i * 8;
49        if idx + 7 >= self.parser.data.len() {
50            return None;
51        }
52        let entry = FirParserEntryIter::decode_entry(&self.parser.data[idx..]);
53        self.i += 1;
54        Some(entry)
55    }
56}
57
58/// FIR (Full Intra Refresh) information as specified in RFC 5104
59#[derive(Debug)]
60pub struct Fir<'a> {
61    data: &'a [u8],
62}
63
64impl Fir<'_> {
65    /// The list of RTP SSRCs that are requesting a Full Intra Refresh.
66    pub fn entries(&self) -> impl Iterator<Item = FirEntry> + '_ {
67        FirParserEntryIter { parser: self, i: 0 }
68    }
69
70    /// Create a new [`FirBuilder`]
71    pub fn builder() -> FirBuilder {
72        FirBuilder::default()
73    }
74}
75
76impl<'a> FciParser<'a> for Fir<'a> {
77    const PACKET_TYPE: FciFeedbackPacketType = FciFeedbackPacketType::PAYLOAD;
78    const FCI_FORMAT: u8 = 4;
79
80    fn parse(data: &'a [u8]) -> Result<Self, RtcpParseError> {
81        if data.len() < 8 {
82            return Err(RtcpParseError::Truncated {
83                expected: 8,
84                actual: data.len(),
85            });
86        }
87        Ok(Self { data })
88    }
89}
90
/// Builder for a Full Intra Refresh packet
///
/// Collects at most one sequence number per SSRC.
#[derive(Debug, Default)]
pub struct FirBuilder {
    /// Sequence number to send, keyed by SSRC.
    ssrc_seq: HashMap<u32, u8>,
}

impl FirBuilder {
    /// Add an SSRC to this FIR packet.  If the SSRC was already added, its sequence number is
    /// replaced by `sequence`.
    pub fn add_ssrc(mut self, ssrc: u32, sequence: u8) -> Self {
        // `insert` both adds a new key and overwrites the value of an existing one.
        self.ssrc_seq.insert(ssrc, sequence);
        self
    }
}
109
110impl FciBuilder<'_> for FirBuilder {
111    fn format(&self) -> u8 {
112        Fir::FCI_FORMAT
113    }
114
115    fn supports_feedback_type(&self) -> FciFeedbackPacketType {
116        Fir::PACKET_TYPE
117    }
118}
119
120impl RtcpPacketWriter for FirBuilder {
121    fn calculate_size(&self) -> Result<usize, RtcpWriteError> {
122        let entries = self.ssrc_seq.len();
123        if entries > u16::MAX as usize / 2 - 2 {
124            return Err(RtcpWriteError::TooManyFir);
125        }
126        Ok(entries * 2 * 4)
127    }
128
129    fn write_into_unchecked(&self, buf: &mut [u8]) -> usize {
130        let mut idx = 0;
131        let mut end = idx;
132
133        for (ssrc, sequence) in self.ssrc_seq.iter() {
134            end += 4;
135            buf[idx..end].copy_from_slice(&ssrc.to_be_bytes());
136            idx = end;
137            end += 4;
138            buf[idx..end].copy_from_slice(&[*sequence, 0, 0, 0]);
139            idx = end;
140        }
141        end
142    }
143
144    fn get_padding(&self) -> Option<u8> {
145        None
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use crate::PayloadFeedback;

    /// Builds a payload feedback packet carrying one FIR entry, checks the exact bytes written,
    /// then parses those bytes back and checks that the entry survives the round trip.
    #[test]
    fn fir_build_parse() {
        const SENDER_SSRC: u32 = 0x98765432;
        const FIR_SSRC: u32 = 0xfedcba98;
        const FIR_SEQUENCE: u8 = 0x30;
        // A single FIR entry adds eight bytes to the minimal feedback packet.
        const REQ_LEN: usize = PayloadFeedback::MIN_PACKET_LEN + 8;
        const EXPECTED: [u8; REQ_LEN] = [
            0x84, 0xce, 0x00, 0x04, 0x98, 0x76, 0x54, 0x32, 0x00, 0x00, 0x00, 0x00, 0xfe, 0xdc,
            0xba, 0x98, 0x30, 0x00, 0x00, 0x00,
        ];

        let fci = Fir::builder().add_ssrc(FIR_SSRC, FIR_SEQUENCE);
        let builder = PayloadFeedback::builder_owned(fci)
            .sender_ssrc(SENDER_SSRC)
            .media_ssrc(0);
        assert_eq!(builder.calculate_size().unwrap(), REQ_LEN);

        let mut data = [0u8; REQ_LEN];
        let written = builder.write_into(&mut data).unwrap();
        assert_eq!(written, REQ_LEN);
        assert_eq!(data, EXPECTED);

        let parsed = PayloadFeedback::parse(&data).unwrap();
        assert_eq!(parsed.sender_ssrc(), SENDER_SSRC);
        assert_eq!(parsed.media_ssrc(), 0);

        let fir = parsed.parse_fci::<Fir>().unwrap();
        let mut entries = fir.entries();
        assert_eq!(
            entries.next(),
            Some(FirEntry::new(FIR_SSRC, FIR_SEQUENCE))
        );
        assert_eq!(entries.next(), None);
    }
}
184}