1use crate::error::{Error, Result};
8use alloc::vec::Vec;
9use dvb_common::{Parse, Serialize};
10
/// `table_id` carried by every program_association_section.
pub const TABLE_ID: u8 = 0x00;
/// PID on which the PAT is transmitted.
pub const PID: u16 = 0x0000;
/// `program_number` whose PAT entry points at the network PID instead of a PMT.
pub const PROGRAM_NUMBER_NIT: u16 = 0x0000;

/// table_id (1 byte) plus the two bytes holding the flags and `section_length`.
const MIN_HEADER_LEN: usize = 3;
/// transport_stream_id (2) + version/current_next (1) + section_number (1)
/// + last_section_number (1).
const EXTENSION_HEADER_LEN: usize = 5;
/// Trailing CRC_32 field.
const CRC_LEN: usize = 4;
/// Total size of a PAT section that carries no program entries.
const MIN_SECTION_LEN: usize = CRC_LEN + EXTENSION_HEADER_LEN + MIN_HEADER_LEN;
/// One program-loop record: program_number (2) + reserved bits / PID (2).
const ENTRY_LEN: usize = 4;
23
/// One record of the PAT program loop: a `program_number` and the PID it maps to.
///
/// An entry whose `program_number` is `PROGRAM_NUMBER_NIT` (0) carries the network
/// PID; any other entry carries the PID of that programme's PMT.
///
/// The record is four bytes of plain data, so it is `Copy` and `Hash` to make it
/// cheap to pass around and usable as a set/map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct PatEntry {
    /// 16-bit `program_number`; 0 is reserved for the network PID entry.
    pub program_number: u16,
    /// 13-bit PID. The parser masks off the three reserved bits above it and the
    /// serializer writes only the low 13 bits.
    pub pid: u16,
}
33
34#[derive(Debug, Clone, Default, PartialEq, Eq)]
36#[cfg_attr(feature = "serde", derive(serde::Serialize))]
37pub struct PatSection {
38 pub transport_stream_id: u16,
40 pub version_number: u8,
42 pub current_next_indicator: bool,
44 pub section_number: u8,
46 pub last_section_number: u8,
48 pub entries: Vec<PatEntry>,
50}
51
52impl<'a> Parse<'a> for PatSection {
53 type Error = crate::error::Error;
54 fn parse(bytes: &'a [u8]) -> Result<Self> {
55 if bytes.len() < MIN_HEADER_LEN + EXTENSION_HEADER_LEN + CRC_LEN {
56 return Err(Error::BufferTooShort {
57 need: MIN_HEADER_LEN + EXTENSION_HEADER_LEN + CRC_LEN,
58 have: bytes.len(),
59 what: "PatSection",
60 });
61 }
62
63 if bytes[0] != TABLE_ID {
64 return Err(Error::UnexpectedTableId {
65 table_id: bytes[0],
66 what: "PatSection",
67 expected: &[TABLE_ID],
68 });
69 }
70
71 let section_length = ((bytes[1] & 0x0F) as u16) << 8 | bytes[2] as u16;
72 let total = super::check_section_length(
73 bytes.len(),
74 MIN_HEADER_LEN,
75 section_length as usize,
76 MIN_SECTION_LEN,
77 )?;
78
79 let transport_stream_id = u16::from_be_bytes([bytes[3], bytes[4]]);
80 let version_number = (bytes[5] >> 1) & 0x1F;
81 let current_next_indicator = (bytes[5] & 0x01) != 0;
82 let section_number = bytes[6];
83 let last_section_number = bytes[7];
84
85 let end = total - CRC_LEN;
86 let mut entries = Vec::new();
87 let mut pos = 8;
88 while pos < end {
89 if pos + ENTRY_LEN > end {
90 return Err(Error::BufferTooShort {
91 need: ENTRY_LEN,
92 have: end - pos,
93 what: "PatSection trailing entry bytes",
94 });
95 }
96 let chunk = &bytes[pos..pos + ENTRY_LEN];
97 let program_number = u16::from_be_bytes([chunk[0], chunk[1]]);
98 let pid = (((chunk[2] & 0x1F) as u16) << 8) | chunk[3] as u16;
99 entries.push(PatEntry {
100 program_number,
101 pid,
102 });
103 pos += ENTRY_LEN;
104 }
105
106 Ok(PatSection {
107 transport_stream_id,
108 version_number,
109 current_next_indicator,
110 section_number,
111 last_section_number,
112 entries,
113 })
114 }
115}
116
117impl Serialize for PatSection {
118 type Error = crate::error::Error;
119 fn serialized_len(&self) -> usize {
120 MIN_HEADER_LEN + EXTENSION_HEADER_LEN + self.entries.len() * ENTRY_LEN + CRC_LEN
121 }
122
123 fn serialize_into(&self, buf: &mut [u8]) -> Result<usize> {
124 let len = self.serialized_len();
125 if buf.len() < len {
126 return Err(Error::OutputBufferTooSmall {
127 need: len,
128 have: buf.len(),
129 });
130 }
131
132 let section_length: u16 =
133 (EXTENSION_HEADER_LEN + self.entries.len() * ENTRY_LEN + CRC_LEN) as u16;
134
135 buf[0] = TABLE_ID;
136 buf[1] = super::SECTION_B1_FLAGS_PSI | ((section_length >> 8) as u8 & 0x0F);
137 buf[2] = (section_length & 0xFF) as u8;
138 buf[3..5].copy_from_slice(&self.transport_stream_id.to_be_bytes());
139 buf[5] = 0xC0 | ((self.version_number & 0x1F) << 1) | u8::from(self.current_next_indicator);
140 buf[6] = self.section_number;
141 buf[7] = self.last_section_number;
142
143 let mut pos = 8;
144 for entry in &self.entries {
145 buf[pos..pos + 2].copy_from_slice(&entry.program_number.to_be_bytes());
146 buf[pos + 2] = 0xE0 | ((entry.pid >> 8) as u8 & 0x1F);
147 buf[pos + 3] = (entry.pid & 0xFF) as u8;
148 pos += ENTRY_LEN;
149 }
150
151 let crc_pos = len - CRC_LEN;
152 let crc = dvb_common::crc32_mpeg2::compute(&buf[..crc_pos]);
153 buf[crc_pos..len].copy_from_slice(&crc.to_be_bytes());
154
155 Ok(len)
156 }
157}
/// Table metadata for the PAT, exposed through `crate::traits::TableDef`.
///
/// The PAT claims exactly one table_id, [`TABLE_ID`].
impl<'a> crate::traits::TableDef<'a> for PatSection {
    /// table_id ranges accepted for this table; presumably inclusive
    /// `(first, last)` pairs, so a single id is written as `(id, id)` — confirm
    /// against the trait definition.
    const TABLE_ID_RANGES: &'static [(u8, u8)] = &[(TABLE_ID, TABLE_ID)];
    /// Human-readable table name.
    const NAME: &'static str = "PROGRAM_ASSOCIATION";
}
162
163impl PatSection {
164 pub fn programmes(&self) -> impl Iterator<Item = &PatEntry> {
166 self.entries
167 .iter()
168 .filter(|e| e.program_number != PROGRAM_NUMBER_NIT)
169 }
170
171 pub fn nit_pid(&self) -> Option<u16> {
173 self.entries
174 .iter()
175 .find(|e| e.program_number == PROGRAM_NUMBER_NIT)
176 .map(|e| e.pid)
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the raw bytes of a single-section PAT (section 0 of 0,
    /// current_next_indicator = 1) carrying `entries` as `(program_number, pid)`.
    ///
    /// The CRC_32 field is filled with zeros; `PatSection::parse` does not read it.
    fn build_pat(tsid: u16, version: u8, entries: &[(u16, u16)]) -> Vec<u8> {
        let section_length: u16 =
            (EXTENSION_HEADER_LEN + entries.len() * ENTRY_LEN + CRC_LEN) as u16;
        let mut v = Vec::new();
        v.push(TABLE_ID);
        v.push(super::super::SECTION_B1_FLAGS_PSI | ((section_length >> 8) as u8 & 0x0F));
        v.push((section_length & 0xFF) as u8);
        v.extend_from_slice(&tsid.to_be_bytes());
        // reserved '11', version_number, current_next_indicator = 1
        v.push(0xC0 | ((version & 0x1F) << 1) | 0x01);
        // section_number
        v.push(0x00);
        // last_section_number
        v.push(0x00);
        for &(pn, pid) in entries {
            v.extend_from_slice(&pn.to_be_bytes());
            // reserved '111' above the 13-bit PID
            v.push(0xE0 | ((pid >> 8) as u8 & 0x1F));
            v.push((pid & 0xFF) as u8);
        }
        // Placeholder CRC_32.
        v.extend_from_slice(&[0, 0, 0, 0]);
        v
    }

    #[test]
    fn parse_empty_pat_zero_programs() {
        let bytes = build_pat(0x1234, 5, &[]);
        let pat = PatSection::parse(&bytes).expect("parse");
        assert_eq!(pat.transport_stream_id, 0x1234);
        assert_eq!(pat.version_number, 5);
        assert!(pat.current_next_indicator);
        assert_eq!(pat.section_number, 0);
        assert_eq!(pat.last_section_number, 0);
        assert_eq!(pat.entries.len(), 0);
    }

    #[test]
    fn parse_single_program_extracts_pmt_pid() {
        let bytes = build_pat(1, 0, &[(42, 0x1234)]);
        let pat = PatSection::parse(&bytes).unwrap();
        assert_eq!(pat.entries.len(), 1);
        assert_eq!(pat.entries[0].program_number, 42);
        assert_eq!(pat.entries[0].pid, 0x1234);
    }

    #[test]
    fn parse_many_programs_preserves_order() {
        let entries: Vec<(u16, u16)> = (1..=10).map(|i| (i, 0x1000 + i)).collect();
        let bytes = build_pat(1, 0, &entries);
        let pat = PatSection::parse(&bytes).unwrap();
        assert_eq!(pat.entries.len(), 10);
        for (i, e) in pat.entries.iter().enumerate() {
            assert_eq!(e.program_number, (i + 1) as u16);
            assert_eq!(e.pid, 0x1000 + (i + 1) as u16);
        }
    }

    #[test]
    fn parse_rejects_wrong_table_id() {
        let mut bytes = build_pat(1, 0, &[]);
        // Any table_id other than TABLE_ID must be refused.
        bytes[0] = 0x02;
        let err = PatSection::parse(&bytes).unwrap_err();
        assert!(matches!(
            err,
            Error::UnexpectedTableId { table_id: 0x02, .. }
        ));
    }

    #[test]
    fn parse_rejects_short_buffer() {
        let err = PatSection::parse(&[0x00, 0x00]).unwrap_err();
        assert!(matches!(err, Error::BufferTooShort { .. }));
    }

    #[test]
    fn serialize_round_trip_empty() {
        let pat = PatSection {
            transport_stream_id: 0x0001,
            version_number: 0,
            current_next_indicator: true,
            section_number: 0,
            last_section_number: 0,
            entries: vec![],
        };
        let mut buf = vec![0u8; pat.serialized_len()];
        pat.serialize_into(&mut buf).expect("serialize");
        let reparsed = PatSection::parse(&buf).expect("reparse");
        assert_eq!(pat, reparsed);
    }

    #[test]
    fn serialize_round_trip_many_programs() {
        let entries: Vec<PatEntry> = (1..=5)
            .map(|i| PatEntry {
                program_number: i,
                pid: 0x1000 + i,
            })
            .collect();
        let pat = PatSection {
            transport_stream_id: 0xABCD,
            version_number: 3,
            current_next_indicator: true,
            section_number: 0,
            last_section_number: 0,
            entries,
        };
        let mut buf = vec![0u8; pat.serialized_len()];
        pat.serialize_into(&mut buf).unwrap();
        let reparsed = PatSection::parse(&buf).unwrap();
        assert_eq!(pat, reparsed);
    }

    #[test]
    fn serialize_rejects_undersized_output_buffer() {
        let pat = PatSection {
            entries: vec![PatEntry {
                program_number: 1,
                pid: 0x0100,
            }],
            ..PatSection::default()
        };
        // One entry => 3 + 5 + 4 + 4 = 16 bytes; offer one byte fewer.
        let mut buf = vec![0u8; pat.serialized_len() - 1];
        assert!(matches!(
            pat.serialize_into(&mut buf).unwrap_err(),
            Error::OutputBufferTooSmall {
                need: 16,
                have: 15,
                ..
            }
        ));
    }

    #[test]
    fn serialize_masks_pid_to_13_bits() {
        let pat = PatSection {
            entries: vec![PatEntry {
                program_number: 7,
                pid: 0xFFFF,
            }],
            ..PatSection::default()
        };
        let mut buf = vec![0u8; pat.serialized_len()];
        let written = pat.serialize_into(&mut buf).unwrap();
        assert_eq!(written, buf.len());
        // program_number 7, then reserved '111' + the low 13 bits of the PID.
        assert_eq!(&buf[8..12], &[0x00, 0x07, 0xFF, 0xFF]);
        let reparsed = PatSection::parse(&buf).unwrap();
        assert_eq!(reparsed.entries[0].pid, 0x1FFF);
    }

    #[test]
    fn parse_rejects_zero_section_length() {
        let mut buf = vec![0u8; 64];
        buf[0] = TABLE_ID;
        buf[1] = 0xB0;
        buf[2] = 0x00;
        for b in &mut buf[3..] {
            *b = 0xFF;
        }
        assert!(matches!(
            PatSection::parse(&buf).unwrap_err(),
            Error::SectionLengthOverflow { .. }
        ));
    }

    #[test]
    fn network_pid_entry_identified_by_program_number_0() {
        let bytes = build_pat(1, 0, &[(0, 0x0010), (1, 0x0100)]);
        let pat = PatSection::parse(&bytes).unwrap();
        assert_eq!(pat.nit_pid(), Some(0x0010));
        assert_eq!(pat.programmes().count(), 1);
        assert_eq!(pat.programmes().next().unwrap().program_number, 1);
    }

    #[test]
    fn parse_rejects_trailing_slack_bytes() {
        // Grow section_length by 2 so the program loop ends with a partial entry.
        let mut bytes = build_pat(1, 0, &[(1, 0x100)]);
        let sl = (bytes.len() - MIN_HEADER_LEN) as u16;
        bytes[1] = (bytes[1] & 0xF0) | (((sl + 2) >> 8) as u8 & 0x0F);
        bytes[2] = ((sl + 2) & 0xFF) as u8;
        bytes.extend_from_slice(&[0xFF, 0xFF]);
        let crc_pos = bytes.len();
        let crc = dvb_common::crc32_mpeg2::compute(&bytes[..crc_pos - CRC_LEN]);
        bytes[crc_pos - CRC_LEN..crc_pos].copy_from_slice(&crc.to_be_bytes());
        assert!(matches!(
            PatSection::parse(&bytes).unwrap_err(),
            Error::BufferTooShort {
                what: "PatSection trailing entry bytes",
                ..
            }
        ));
    }
}