Skip to main content

ntp_proto/
extension.rs

1// Copyright 2026 U.S. Federal Government (in countries where recognized)
2// SPDX-License-Identifier: Apache-2.0
3
4//! NTP extension field parsing and NTS (Network Time Security) extension types.
5//!
6//! Extension fields follow the NTPv4 extension field format defined in RFC 7822,
7//! appended after the 48-byte NTP packet header. NTS (RFC 8915) defines specific
8//! extension field types for authenticated NTP.
9//!
10//! # Extension Field Format (RFC 7822)
11//!
12//! ```text
13//!  0                   1                   2                   3
14//!  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
15//! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
16//! |          Field Type           |        Field Length           |
17//! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
18//! .                                                               .
19//! .                       Field Value (variable)                  .
20//! .                                                               .
21//! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
22//! ```
23
24#[cfg(all(feature = "alloc", not(feature = "std")))]
25use alloc::vec;
26#[cfg(all(feature = "alloc", not(feature = "std")))]
27use alloc::vec::Vec;
28#[cfg(feature = "std")]
29use std::io;
30
31use crate::error::ParseError;
32
/// Zero-copy view of a single extension field.
///
/// `value` borrows straight from the buffer that was parsed, so producing one
/// of these never allocates (unlike the owned [`ExtensionField`]).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ExtensionFieldRef<'a> {
    /// Type code read from the first two octets of the field header.
    pub field_type: u16,
    /// Field body, i.e. the variable-length data after the 4-byte header.
    pub value: &'a [u8],
}
44
/// Iterator over the extension fields packed into a byte buffer.
///
/// Yields [`ExtensionFieldRef`] values that borrow from the buffer, so no heap
/// allocation takes place. Created by [`iter_extension_fields`].
#[derive(Clone, Debug)]
pub struct ExtensionFieldIter<'a> {
    /// The complete buffer being walked.
    data: &'a [u8],
    /// Byte offset of the next field header within `data`.
    offset: usize,
}
53
54impl<'a> Iterator for ExtensionFieldIter<'a> {
55    type Item = Result<ExtensionFieldRef<'a>, ParseError>;
56
57    fn next(&mut self) -> Option<Self::Item> {
58        let remaining = &self.data[self.offset..];
59        if remaining.len() < 4 {
60            return None;
61        }
62
63        let field_type = u16::from_be_bytes([remaining[0], remaining[1]]);
64        let field_length = u16::from_be_bytes([remaining[2], remaining[3]]);
65
66        if field_length < 4 {
67            // Advance past this header to avoid infinite iteration.
68            self.offset = self.data.len();
69            return Some(Err(ParseError::InvalidExtensionLength {
70                declared: field_length,
71            }));
72        }
73
74        let value_length = (field_length - 4) as usize;
75        let value_start = self.offset + 4;
76
77        if value_start + value_length > self.data.len() {
78            self.offset = self.data.len();
79            return Some(Err(ParseError::ExtensionOverflow));
80        }
81
82        let value = &self.data[value_start..value_start + value_length];
83
84        // Advance past value and padding to 4-byte boundary.
85        let padded = (field_length as usize + 3) & !3;
86        let next_offset = self.offset + padded;
87        self.offset = next_offset.min(self.data.len());
88
89        Some(Ok(ExtensionFieldRef { field_type, value }))
90    }
91}
92
93/// Create an iterator over extension fields without allocating.
94///
95/// This is the zero-allocation alternative to [`parse_extension_fields`].
96/// Each item yields a borrowed view of the extension field data.
97pub fn iter_extension_fields(data: &[u8]) -> ExtensionFieldIter<'_> {
98    ExtensionFieldIter { data, offset: 0 }
99}
100
/// Smallest extension field length, in octets, permitted by RFC 7822.
///
/// Note: [`iter_extension_fields`] does not enforce this minimum; it only
/// rejects declared lengths shorter than the 4-byte field header.
pub const MIN_EXTENSION_FIELD_LENGTH: u16 = 16;
103
// Extension field type codes assigned to NTS by RFC 8915.

/// Field type of the NTS Unique Identifier extension field.
pub const UNIQUE_IDENTIFIER: u16 = 0x0104;

/// Field type of the NTS Cookie extension field.
pub const NTS_COOKIE: u16 = 0x0204;

/// Field type of the NTS Cookie Placeholder extension field.
pub const NTS_COOKIE_PLACEHOLDER: u16 = 0x0304;

/// Field type of the NTS Authenticator and Encrypted Extension Fields
/// extension field.
pub const NTS_AUTHENTICATOR: u16 = 0x0404;
117
/// Owned representation of one NTP extension field.
#[cfg(any(feature = "alloc", feature = "std"))]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ExtensionField {
    /// Type code identifying what the field carries.
    pub field_type: u16,
    /// Field body, without the 4-byte type/length header.
    pub value: Vec<u8>,
}
127
/// Decode every extension field in `data` into owned values, without `std::io`.
///
/// Decoding ends quietly when fewer than four bytes (one header) remain. The
/// first malformed field aborts decoding and its error is returned.
#[cfg(any(feature = "alloc", feature = "std"))]
pub fn parse_extension_fields_buf(data: &[u8]) -> Result<Vec<ExtensionField>, ParseError> {
    let mut fields = Vec::new();
    for item in iter_extension_fields(data) {
        let borrowed = item?;
        fields.push(ExtensionField {
            field_type: borrowed.field_type,
            value: borrowed.value.to_vec(),
        });
    }
    Ok(fields)
}
143
/// Serialize extension fields into a byte buffer without using `std::io`.
///
/// Each field is written as a 4-byte header (big-endian type, then the
/// unpadded length `4 + value.len()`), followed by its value and zero padding
/// up to the next 4-byte boundary.
///
/// Returns the number of bytes written.
///
/// # Errors
///
/// - [`ParseError::ExtensionOverflow`] if a field's length does not fit in
///   the 16-bit Field Length header.
/// - [`ParseError::BufferTooShort`] if `buf` cannot hold all fields.
///
/// On error, fields preceding the failing one may already have been written
/// into `buf`.
#[cfg(any(feature = "alloc", feature = "std"))]
pub fn write_extension_fields_buf(
    fields: &[ExtensionField],
    buf: &mut [u8],
) -> Result<usize, ParseError> {
    let mut offset = 0;

    for field in fields {
        let field_length = 4 + field.value.len();
        // The header stores the length as a u16. Truncating it would emit a
        // header that disagrees with the bytes actually written.
        let declared = u16::try_from(field_length).map_err(|_| ParseError::ExtensionOverflow)?;
        let padded = field_length.next_multiple_of(4);
        let end = offset + padded;

        if end > buf.len() {
            return Err(ParseError::BufferTooShort {
                needed: end,
                available: buf.len(),
            });
        }

        let slot = &mut buf[offset..end];
        slot[0..2].copy_from_slice(&field.field_type.to_be_bytes());
        slot[2..4].copy_from_slice(&declared.to_be_bytes());
        slot[4..field_length].copy_from_slice(&field.value);
        // Zero-fill padding.
        slot[field_length..].fill(0);

        offset = end;
    }

    Ok(offset)
}
181
/// Decode the extension fields that follow the 48-byte NTP header.
///
/// `std::io` flavour of [`parse_extension_fields_buf`]: decoding ends when
/// fewer than four bytes remain, and a malformed field is reported as an
/// [`io::Error`] converted from the underlying parse error.
#[cfg(feature = "std")]
pub fn parse_extension_fields(data: &[u8]) -> io::Result<Vec<ExtensionField>> {
    let fields = parse_extension_fields_buf(data)?;
    Ok(fields)
}
190
/// Encode extension fields into a freshly allocated byte vector.
///
/// Every field is zero-padded to a 4-byte boundary. The vector is sized up
/// front from the padded field lengths and filled by
/// [`write_extension_fields_buf`].
#[cfg(feature = "std")]
pub fn write_extension_fields(fields: &[ExtensionField]) -> io::Result<Vec<u8>> {
    let mut total = 0usize;
    for field in fields {
        total += (field.value.len() + 4).next_multiple_of(4);
    }
    let mut out = vec![0u8; total];
    write_extension_fields_buf(fields, &mut out)?;
    Ok(out)
}
202
/// NTS Unique Identifier extension field (RFC 8915 Section 5.3).
///
/// Holds the random bytes a client attaches to a request for NTS-level replay
/// protection; the server echoes the same bytes back in its response.
#[cfg(any(feature = "alloc", feature = "std"))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UniqueIdentifier(pub Vec<u8>);
210
#[cfg(any(feature = "alloc", feature = "std"))]
impl UniqueIdentifier {
    /// Wrap `data` as a Unique Identifier.
    pub fn new(data: Vec<u8>) -> Self {
        Self(data)
    }

    /// Build the generic [`ExtensionField`] carrying a copy of this identifier.
    pub fn to_extension_field(&self) -> ExtensionField {
        let value = self.0.clone();
        ExtensionField {
            field_type: UNIQUE_IDENTIFIER,
            value,
        }
    }

    /// Recover a Unique Identifier from `ef`, or `None` if `ef` has another type.
    pub fn from_extension_field(ef: &ExtensionField) -> Option<Self> {
        (ef.field_type == UNIQUE_IDENTIFIER).then(|| Self(ef.value.clone()))
    }
}
235
/// NTS Cookie extension field (RFC 8915 Section 5.4).
///
/// Wraps an opaque cookie obtained from the NTS-KE server; each cookie is
/// meant to be spent on exactly one NTP request.
#[cfg(any(feature = "alloc", feature = "std"))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NtsCookie(pub Vec<u8>);
243
#[cfg(any(feature = "alloc", feature = "std"))]
impl NtsCookie {
    /// Wrap raw cookie bytes.
    pub fn new(data: Vec<u8>) -> Self {
        Self(data)
    }

    /// Build the generic [`ExtensionField`] carrying a copy of this cookie.
    pub fn to_extension_field(&self) -> ExtensionField {
        let value = self.0.clone();
        ExtensionField {
            field_type: NTS_COOKIE,
            value,
        }
    }

    /// Recover a cookie from `ef`, or `None` if `ef` is not an NTS Cookie field.
    pub fn from_extension_field(ef: &ExtensionField) -> Option<Self> {
        if ef.field_type != NTS_COOKIE {
            return None;
        }
        Some(Self(ef.value.clone()))
    }
}
268
/// NTS Cookie Placeholder extension field (RFC 8915 Section 5.5).
///
/// Tells the server that the client would like additional cookies. The
/// placeholder's body size should match the size of the cookies expected back.
#[cfg(any(feature = "alloc", feature = "std"))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NtsCookiePlaceholder {
    /// Length of the placeholder body, in bytes.
    pub size: usize,
}
279
#[cfg(any(feature = "alloc", feature = "std"))]
impl NtsCookiePlaceholder {
    /// Create a cookie placeholder whose body is `size` bytes long.
    pub fn new(size: usize) -> Self {
        NtsCookiePlaceholder { size }
    }

    /// Convert to a generic extension field whose body is `size` zero bytes.
    pub fn to_extension_field(&self) -> ExtensionField {
        ExtensionField {
            field_type: NTS_COOKIE_PLACEHOLDER,
            value: vec![0u8; self.size],
        }
    }

    /// Try to extract from a generic extension field.
    ///
    /// Returns `None` if `ef` is not a cookie placeholder. Only the body length
    /// is kept; the body bytes themselves are not inspected.
    pub fn from_extension_field(ef: &ExtensionField) -> Option<Self> {
        (ef.field_type == NTS_COOKIE_PLACEHOLDER).then(|| NtsCookiePlaceholder {
            size: ef.value.len(),
        })
    }
}
295
/// NTS Authenticator and Encrypted Extension Fields extension field
/// (RFC 8915 Section 5.6).
///
/// Carries the AEAD nonce alongside the AEAD output. The ciphertext covers any
/// encrypted extension fields and includes the authentication tag.
#[cfg(any(feature = "alloc", feature = "std"))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NtsAuthenticator {
    /// Nonce handed to the AEAD algorithm.
    pub nonce: Vec<u8>,
    /// AEAD output: encrypted extension fields followed by the tag.
    pub ciphertext: Vec<u8>,
}
308
#[cfg(any(feature = "alloc", feature = "std"))]
impl NtsAuthenticator {
    /// Create an NTS Authenticator.
    pub fn new(nonce: Vec<u8>, ciphertext: Vec<u8>) -> Self {
        NtsAuthenticator { nonce, ciphertext }
    }

    /// Convert to a generic extension field.
    ///
    /// The body uses the RFC 8915 Section 5.6 layout:
    ///
    /// ```text
    /// | Nonce Length (u16 BE)  | Ciphertext Length (u16 BE) |
    /// | Nonce, zero-padded to a 4-byte boundary             |
    /// | Ciphertext, zero-padded to a 4-byte boundary        |
    /// ```
    ///
    /// Both length fields hold the unpadded lengths. No "Additional Padding"
    /// is appended after the ciphertext.
    ///
    /// # Panics
    ///
    /// Panics if the nonce or ciphertext is longer than `u16::MAX` bytes,
    /// because such a length cannot be encoded in the 16-bit length fields.
    pub fn to_extension_field(&self) -> ExtensionField {
        let nonce_len = u16::try_from(self.nonce.len())
            .expect("NTS authenticator nonce longer than u16::MAX bytes");
        let ciphertext_len = u16::try_from(self.ciphertext.len())
            .expect("NTS authenticator ciphertext longer than u16::MAX bytes");

        let nonce_padded = self.nonce.len().next_multiple_of(4);
        let ciphertext_padded = self.ciphertext.len().next_multiple_of(4);

        let mut value = Vec::with_capacity(4 + nonce_padded + ciphertext_padded);
        // Both lengths come first, ahead of either payload.
        value.extend_from_slice(&nonce_len.to_be_bytes());
        value.extend_from_slice(&ciphertext_len.to_be_bytes());
        value.extend_from_slice(&self.nonce);
        value.resize(4 + nonce_padded, 0);
        value.extend_from_slice(&self.ciphertext);
        value.resize(4 + nonce_padded + ciphertext_padded, 0);

        ExtensionField {
            field_type: NTS_AUTHENTICATOR,
            value,
        }
    }

    /// Try to extract from a generic extension field.
    #[cfg(feature = "std")]
    pub fn from_extension_field(ef: &ExtensionField) -> io::Result<Option<Self>> {
        Self::from_extension_field_buf(ef).map_err(io::Error::from)
    }

    /// Try to extract from a generic extension field without using `std::io`.
    ///
    /// Returns `Ok(None)` if `ef` is not an NTS Authenticator field. Expects the
    /// RFC 8915 Section 5.6 layout written by [`Self::to_extension_field`].
    /// Any bytes after the ciphertext, such as padding, are ignored.
    ///
    /// # Errors
    ///
    /// - [`ParseError::BufferTooShort`] if the body is shorter than the 4-byte
    ///   pair of length fields.
    /// - [`ParseError::ExtensionOverflow`] if the declared nonce or ciphertext
    ///   length runs past the end of the body.
    pub fn from_extension_field_buf(ef: &ExtensionField) -> Result<Option<Self>, ParseError> {
        if ef.field_type != NTS_AUTHENTICATOR {
            return Ok(None);
        }

        let data = ef.value.as_slice();
        if data.len() < 4 {
            return Err(ParseError::BufferTooShort {
                needed: 4,
                available: data.len(),
            });
        }

        let nonce_len = usize::from(u16::from_be_bytes([data[0], data[1]]));
        let ciphertext_len = usize::from(u16::from_be_bytes([data[2], data[3]]));

        let nonce_start = 4;
        let nonce = data
            .get(nonce_start..nonce_start + nonce_len)
            .ok_or(ParseError::ExtensionOverflow)?
            .to_vec();

        // The ciphertext starts after the nonce's padding to a 4-byte boundary.
        let ciphertext_start = nonce_start + nonce_len.next_multiple_of(4);
        let ciphertext = data
            .get(ciphertext_start..ciphertext_start + ciphertext_len)
            .ok_or(ParseError::ExtensionOverflow)?
            .to_vec();

        Ok(Some(NtsAuthenticator { nonce, ciphertext }))
    }
}
387
388// ============================================================================
389// Generic extension field registry and dispatch (RFC 7822)
390// ============================================================================
391
/// Callback interface for processing one kind of extension field.
///
/// Implementors can be registered with an [`ExtensionRegistry`], which routes
/// each incoming field to the handler whose [`field_type`](Self::field_type)
/// matches.
pub trait ExtensionHandler: Send + Sync {
    /// The extension field type code this handler is responsible for.
    fn field_type(&self) -> u16;

    /// Process the body of a matching extension field.
    ///
    /// Return `Ok(())` on success, or an error explaining why the field could
    /// not be processed.
    #[cfg(feature = "std")]
    fn handle(&self, value: &[u8]) -> io::Result<()>;
}
407
/// A registry for extension field handlers per RFC 7822.
///
/// This allows applications to register custom handlers for non-NTS extension
/// field types. The registry dispatches incoming extension fields to the
/// appropriate handler based on field type.
///
/// # Examples
///
/// ```
/// # #[cfg(feature = "std")] {
/// use ntp_proto::extension::{ExtensionField, ExtensionHandler, ExtensionRegistry};
/// use std::io;
///
/// // Define a custom handler
/// struct MyHandler;
///
/// impl ExtensionHandler for MyHandler {
///     fn field_type(&self) -> u16 {
///         0x4000 // Custom field type
///     }
///
///     fn handle(&self, value: &[u8]) -> io::Result<()> {
///         println!("Received extension field with {} bytes", value.len());
///         Ok(())
///     }
/// }
///
/// // Create registry and register handler
/// let mut registry = ExtensionRegistry::new();
/// registry.register(Box::new(MyHandler));
///
/// // Dispatch an extension field
/// let field = ExtensionField {
///     field_type: 0x4000,
///     value: vec![1, 2, 3, 4],
/// };
/// registry.dispatch(&field).unwrap();
/// # }
/// ```
#[cfg(feature = "std")]
#[derive(Default)]
pub struct ExtensionRegistry {
    /// Registered handlers; `register` keeps at most one per field type.
    handlers: Vec<Box<dyn ExtensionHandler>>,
}
452
#[cfg(feature = "std")]
impl ExtensionRegistry {
    /// Create a new empty extension field registry.
    pub fn new() -> Self {
        ExtensionRegistry {
            handlers: Vec::new(),
        }
    }

    /// Register a handler for a specific extension field type.
    ///
    /// If a handler for this field type is already registered, it is
    /// replaced.
    pub fn register(&mut self, handler: Box<dyn ExtensionHandler>) {
        let field_type = handler.field_type();
        // Keep at most one handler per field type.
        self.handlers.retain(|h| h.field_type() != field_type);
        self.handlers.push(handler);
    }

    /// Dispatch an extension field to the registered handler.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::Unsupported`] error if no handler is
    /// registered for `field.field_type`. Otherwise returns whatever the
    /// handler itself returns.
    pub fn dispatch(&self, field: &ExtensionField) -> io::Result<()> {
        match self.find_handler(field.field_type) {
            Some(handler) => handler.handle(&field.value),
            None => Err(Self::no_handler_error(field.field_type)),
        }
    }

    /// Dispatch all extension fields in a list.
    ///
    /// Processes each field in sequence and stops at the first handler error,
    /// which is returned unchanged regardless of its kind. Fields with no
    /// registered handler are skipped. If `require_handlers` is true, the
    /// first such field instead yields an [`io::ErrorKind::Unsupported`] error.
    pub fn dispatch_all(
        &self,
        fields: &[ExtensionField],
        require_handlers: bool,
    ) -> io::Result<()> {
        for field in fields {
            match self.find_handler(field.field_type) {
                Some(handler) => handler.handle(&field.value)?,
                // Decide "no handler" by lookup rather than by inspecting the
                // error kind. Otherwise a handler's own `Unsupported` error
                // would be mistaken for a missing handler and silently dropped.
                None if require_handlers => {
                    return Err(Self::no_handler_error(field.field_type));
                }
                None => {}
            }
        }
        Ok(())
    }

    /// Check if a handler is registered for the given field type.
    pub fn has_handler(&self, field_type: u16) -> bool {
        self.find_handler(field_type).is_some()
    }

    /// Return the number of registered handlers.
    pub fn len(&self) -> usize {
        self.handlers.len()
    }

    /// Return true if no handlers are registered.
    pub fn is_empty(&self) -> bool {
        self.handlers.is_empty()
    }

    /// Return the handler registered for `field_type`, if any.
    fn find_handler(&self, field_type: u16) -> Option<&dyn ExtensionHandler> {
        self.handlers
            .iter()
            .map(|handler| handler.as_ref())
            .find(|handler| handler.field_type() == field_type)
    }

    /// Build the error reported when no handler is registered for `field_type`.
    fn no_handler_error(field_type: u16) -> io::Error {
        io::Error::new(
            io::ErrorKind::Unsupported,
            format!("no handler registered for extension field type 0x{field_type:04X}"),
        )
    }
}
529
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Fresh shared call counter for a `TestHandler`.
    fn counter() -> Arc<Mutex<usize>> {
        Arc::new(Mutex::new(0))
    }

    #[test]
    fn test_parse_empty() {
        let fields = parse_extension_fields(&[]).unwrap();
        assert!(fields.is_empty());
    }

    #[test]
    fn test_roundtrip_single_field() {
        let field = ExtensionField {
            field_type: UNIQUE_IDENTIFIER,
            value: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
        };
        let buf = write_extension_fields(std::slice::from_ref(&field)).unwrap();
        let parsed = parse_extension_fields(&buf).unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0], field);
    }

    #[test]
    fn test_roundtrip_multiple_fields() {
        let fields = vec![
            ExtensionField {
                field_type: UNIQUE_IDENTIFIER,
                value: vec![0xAA; 32],
            },
            ExtensionField {
                field_type: NTS_COOKIE,
                value: vec![0xBB; 64],
            },
        ];
        let buf = write_extension_fields(&fields).unwrap();
        let parsed = parse_extension_fields(&buf).unwrap();
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[0], fields[0]);
        assert_eq!(parsed[1], fields[1]);
    }

    #[test]
    fn test_roundtrip_unaligned_values() {
        // Values of 1, 2 and 3 bytes each need padding on the wire.
        let fields = vec![
            ExtensionField {
                field_type: 0x1001,
                value: vec![1],
            },
            ExtensionField {
                field_type: 0x1002,
                value: vec![2, 2],
            },
            ExtensionField {
                field_type: 0x1003,
                value: vec![3, 3, 3],
            },
        ];
        let buf = write_extension_fields(&fields).unwrap();
        assert_eq!(buf.len(), 24);
        assert_eq!(parse_extension_fields(&buf).unwrap(), fields);
    }

    #[test]
    fn test_padding() {
        // Value of 5 bytes: 4 header + 5 value = 9 bytes, padded to 12.
        let field = ExtensionField {
            field_type: 0x1234,
            value: vec![1, 2, 3, 4, 5],
        };
        let buf = write_extension_fields(&[field]).unwrap();
        assert_eq!(buf.len(), 12); // 4 header + 5 value + 3 padding
    }

    #[test]
    fn test_unique_identifier_conversion() {
        let uid = UniqueIdentifier::new(vec![0x42; 32]);
        let ef = uid.to_extension_field();
        assert_eq!(ef.field_type, UNIQUE_IDENTIFIER);
        let back = UniqueIdentifier::from_extension_field(&ef).unwrap();
        assert_eq!(back.0, vec![0x42; 32]);
    }

    #[test]
    fn test_unique_identifier_wrong_type() {
        let ef = ExtensionField {
            field_type: NTS_COOKIE,
            value: vec![1, 2, 3],
        };
        assert!(UniqueIdentifier::from_extension_field(&ef).is_none());
    }

    #[test]
    fn test_nts_cookie_conversion() {
        let cookie = NtsCookie::new(vec![0xDE, 0xAD, 0xBE, 0xEF]);
        let ef = cookie.to_extension_field();
        assert_eq!(ef.field_type, NTS_COOKIE);
        let back = NtsCookie::from_extension_field(&ef).unwrap();
        assert_eq!(back.0, vec![0xDE, 0xAD, 0xBE, 0xEF]);
    }

    #[test]
    fn test_nts_cookie_wrong_type() {
        let ef = ExtensionField {
            field_type: UNIQUE_IDENTIFIER,
            value: vec![1, 2, 3],
        };
        assert!(NtsCookie::from_extension_field(&ef).is_none());
    }

    // ── ExtensionRegistry tests ───────────────────────────────────

    struct TestHandler {
        field_type: u16,
        call_count: Arc<Mutex<usize>>,
    }

    impl ExtensionHandler for TestHandler {
        fn field_type(&self) -> u16 {
            self.field_type
        }

        fn handle(&self, _value: &[u8]) -> io::Result<()> {
            *self.call_count.lock().unwrap() += 1;
            Ok(())
        }
    }

    /// Handler that always rejects the field with `InvalidData`.
    struct FailingHandler {
        field_type: u16,
    }

    impl ExtensionHandler for FailingHandler {
        fn field_type(&self) -> u16 {
            self.field_type
        }

        fn handle(&self, _value: &[u8]) -> io::Result<()> {
            Err(io::Error::new(io::ErrorKind::InvalidData, "rejected"))
        }
    }

    #[test]
    fn test_registry_register_and_dispatch() {
        let call_count = counter();
        let mut registry = ExtensionRegistry::new();

        registry.register(Box::new(TestHandler {
            field_type: 0x1234,
            call_count: call_count.clone(),
        }));

        let field = ExtensionField {
            field_type: 0x1234,
            value: vec![1, 2, 3],
        };

        registry.dispatch(&field).unwrap();
        assert_eq!(*call_count.lock().unwrap(), 1);
    }

    #[test]
    fn test_registry_no_handler() {
        let registry = ExtensionRegistry::new();
        let field = ExtensionField {
            field_type: 0x9999,
            value: vec![],
        };

        let result = registry.dispatch(&field);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::Unsupported);
    }

    #[test]
    fn test_registry_has_handler() {
        let mut registry = ExtensionRegistry::new();
        assert!(!registry.has_handler(0x4000));

        registry.register(Box::new(TestHandler {
            field_type: 0x4000,
            call_count: counter(),
        }));

        assert!(registry.has_handler(0x4000));
        assert!(!registry.has_handler(0x4001));
    }

    #[test]
    fn test_registry_replace_handler() {
        let count1 = counter();
        let count2 = counter();

        let mut registry = ExtensionRegistry::new();
        registry.register(Box::new(TestHandler {
            field_type: 0x5000,
            call_count: count1.clone(),
        }));
        registry.register(Box::new(TestHandler {
            field_type: 0x5000,
            call_count: count2.clone(),
        }));

        let field = ExtensionField {
            field_type: 0x5000,
            value: vec![],
        };

        registry.dispatch(&field).unwrap();

        // Only the second handler should be called
        assert_eq!(*count1.lock().unwrap(), 0);
        assert_eq!(*count2.lock().unwrap(), 1);
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_dispatch_all() {
        let count1 = counter();
        let count2 = counter();

        let mut registry = ExtensionRegistry::new();
        registry.register(Box::new(TestHandler {
            field_type: 0x6000,
            call_count: count1.clone(),
        }));
        registry.register(Box::new(TestHandler {
            field_type: 0x6001,
            call_count: count2.clone(),
        }));

        let fields = vec![
            ExtensionField {
                field_type: 0x6000,
                value: vec![1],
            },
            ExtensionField {
                field_type: 0x6001,
                value: vec![2],
            },
        ];

        registry.dispatch_all(&fields, false).unwrap();
        assert_eq!(*count1.lock().unwrap(), 1);
        assert_eq!(*count2.lock().unwrap(), 1);
    }

    #[test]
    fn test_registry_dispatch_all_ignores_unknown() {
        let count = counter();
        let mut registry = ExtensionRegistry::new();
        registry.register(Box::new(TestHandler {
            field_type: 0x7000,
            call_count: count.clone(),
        }));

        let fields = vec![
            ExtensionField {
                field_type: 0x7000,
                value: vec![1],
            },
            ExtensionField {
                field_type: 0x9999, // Unknown field
                value: vec![2],
            },
        ];

        // Should succeed (ignore unknown fields when require_handlers = false)
        registry.dispatch_all(&fields, false).unwrap();
        assert_eq!(*count.lock().unwrap(), 1);
    }

    #[test]
    fn test_registry_dispatch_all_requires_handlers() {
        let count = counter();
        let mut registry = ExtensionRegistry::new();
        registry.register(Box::new(TestHandler {
            field_type: 0x8000,
            call_count: count.clone(),
        }));

        let fields = vec![
            ExtensionField {
                field_type: 0x8000,
                value: vec![1],
            },
            ExtensionField {
                field_type: 0x9999, // Unknown field
                value: vec![2],
            },
        ];

        // Should fail (unknown field when require_handlers = true)
        let result = registry.dispatch_all(&fields, true);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::Unsupported);
    }

    #[test]
    fn test_registry_dispatch_all_propagates_handler_error() {
        let count = counter();
        let mut registry = ExtensionRegistry::new();
        registry.register(Box::new(FailingHandler {
            field_type: 0xB000,
        }));
        registry.register(Box::new(TestHandler {
            field_type: 0xB001,
            call_count: count.clone(),
        }));

        let fields = vec![
            ExtensionField {
                field_type: 0xB000,
                value: vec![1],
            },
            ExtensionField {
                field_type: 0xB001,
                value: vec![2],
            },
        ];

        let err = registry.dispatch_all(&fields, false).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        // Processing stops at the first failure.
        assert_eq!(*count.lock().unwrap(), 0);
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let mut registry = ExtensionRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(Box::new(TestHandler {
            field_type: 0xA000,
            call_count: counter(),
        }));

        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    // ──────────────────────────────────────────────────────────────

    #[test]
    fn test_nts_authenticator_roundtrip() {
        let auth = NtsAuthenticator::new(vec![0x11; 16], vec![0x22; 48]);
        let ef = auth.to_extension_field();
        assert_eq!(ef.field_type, NTS_AUTHENTICATOR);
        let back = NtsAuthenticator::from_extension_field(&ef)
            .unwrap()
            .unwrap();
        assert_eq!(back.nonce, vec![0x11; 16]);
        assert_eq!(back.ciphertext, vec![0x22; 48]);
    }

    #[test]
    fn test_nts_authenticator_unaligned_roundtrip() {
        let auth = NtsAuthenticator::new(vec![0x33; 13], vec![0x44; 17]);
        let ef = auth.to_extension_field();
        let back = NtsAuthenticator::from_extension_field_buf(&ef)
            .unwrap()
            .unwrap();
        assert_eq!(back, auth);
    }

    #[test]
    fn test_nts_authenticator_wrong_type() {
        let ef = ExtensionField {
            field_type: NTS_COOKIE,
            value: vec![0; 8],
        };
        assert!(matches!(
            NtsAuthenticator::from_extension_field_buf(&ef),
            Ok(None)
        ));
    }

    #[test]
    fn test_nts_authenticator_truncated_body() {
        let ef = ExtensionField {
            field_type: NTS_AUTHENTICATOR,
            value: vec![0x00],
        };
        assert!(matches!(
            NtsAuthenticator::from_extension_field_buf(&ef),
            Err(ParseError::BufferTooShort { .. })
        ));
    }

    #[test]
    fn test_nts_authenticator_nonce_overflow() {
        // Declares a 32-byte nonce but carries no nonce bytes at all.
        let ef = ExtensionField {
            field_type: NTS_AUTHENTICATOR,
            value: vec![0x00, 0x20, 0x00, 0x00],
        };
        assert!(matches!(
            NtsAuthenticator::from_extension_field_buf(&ef),
            Err(ParseError::ExtensionOverflow)
        ));
    }

    #[test]
    fn test_cookie_placeholder() {
        let placeholder = NtsCookiePlaceholder::new(100);
        let ef = placeholder.to_extension_field();
        assert_eq!(ef.field_type, NTS_COOKIE_PLACEHOLDER);
        assert_eq!(ef.value.len(), 100);
        assert!(ef.value.iter().all(|&b| b == 0));
    }

    #[test]
    fn test_cookie_placeholder_zero_size() {
        let ef = NtsCookiePlaceholder::new(0).to_extension_field();
        assert_eq!(ef.field_type, NTS_COOKIE_PLACEHOLDER);
        assert!(ef.value.is_empty());
    }

    #[test]
    fn test_parse_truncated_field() {
        // Only 3 bytes: not enough for the 4-byte header.
        let data = [0x01, 0x04, 0x00];
        let fields = parse_extension_fields(&data).unwrap();
        assert!(fields.is_empty()); // Silently stops, not enough for header
    }

    #[test]
    fn test_parse_invalid_length() {
        // field_length=2 (less than 4).
        let data = [0x01, 0x04, 0x00, 0x02];
        let result = parse_extension_fields(&data);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_ignores_trailing_bytes() {
        let field = ExtensionField {
            field_type: 0x1234,
            value: vec![1, 2, 3, 4],
        };
        let mut buf = write_extension_fields(std::slice::from_ref(&field)).unwrap();
        // Fewer than 4 trailing bytes cannot form a header and are ignored.
        buf.extend_from_slice(&[0xFF, 0xFF, 0xFF]);
        let parsed = parse_extension_fields(&buf).unwrap();
        assert_eq!(parsed, vec![field]);
    }

    #[test]
    fn test_parse_last_field_missing_padding() {
        let field = ExtensionField {
            field_type: 0x1234,
            value: vec![1, 2, 3, 4, 5],
        };
        let buf = write_extension_fields(std::slice::from_ref(&field)).unwrap();
        // Drop the 3 padding bytes; the value itself is still complete.
        let parsed = parse_extension_fields(&buf[..9]).unwrap();
        assert_eq!(parsed, vec![field]);
    }

    // Buffer-based API tests.

    #[test]
    fn test_buf_parse_empty() {
        let fields = parse_extension_fields_buf(&[]).unwrap();
        assert!(fields.is_empty());
    }

    #[test]
    fn test_buf_roundtrip_single_field() {
        let field = ExtensionField {
            field_type: UNIQUE_IDENTIFIER,
            value: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
        };

        // Write to Vec via io API, then to fixed buffer via buf API.
        let io_buf = write_extension_fields(std::slice::from_ref(&field)).unwrap();
        let mut buf = vec![0u8; 256];
        let written = write_extension_fields_buf(std::slice::from_ref(&field), &mut buf).unwrap();
        assert_eq!(&io_buf[..], &buf[..written]);

        // Parse with buf API.
        let parsed = parse_extension_fields_buf(&buf[..written]).unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0], field);
    }

    #[test]
    fn test_buf_equivalence_with_io_api() {
        let fields = vec![
            ExtensionField {
                field_type: UNIQUE_IDENTIFIER,
                value: vec![0xAA; 32],
            },
            ExtensionField {
                field_type: NTS_COOKIE,
                value: vec![0xBB; 64],
            },
        ];

        let io_buf = write_extension_fields(&fields).unwrap();
        let mut raw_buf = vec![0u8; 512];
        let written = write_extension_fields_buf(&fields, &mut raw_buf).unwrap();

        // Same output.
        assert_eq!(&io_buf[..], &raw_buf[..written]);

        // Same parse result.
        let io_parsed = parse_extension_fields(&io_buf).unwrap();
        let buf_parsed = parse_extension_fields_buf(&raw_buf[..written]).unwrap();
        assert_eq!(io_parsed, buf_parsed);
    }

    #[test]
    fn test_buf_write_buffer_too_short() {
        let field = ExtensionField {
            field_type: UNIQUE_IDENTIFIER,
            value: vec![0xAA; 32],
        };
        let mut tiny_buf = [0u8; 4]; // Too small for 4 header + 32 value.
        let result = write_extension_fields_buf(&[field], &mut tiny_buf);
        assert!(result.is_err());
    }

    #[test]
    fn test_buf_write_zero_fills_padding() {
        let field = ExtensionField {
            field_type: 0x1234,
            value: vec![1, 2, 3, 4, 5],
        };
        let mut buf = [0xFFu8; 16];
        let written = write_extension_fields_buf(&[field], &mut buf).unwrap();
        assert_eq!(written, 12);
        assert_eq!(
            &buf[..12],
            &[0x12, 0x34, 0x00, 0x09, 1, 2, 3, 4, 5, 0, 0, 0][..]
        );
        // Bytes past the written region are left untouched.
        assert_eq!(&buf[12..], &[0xFF; 4][..]);
    }

    #[test]
    fn test_buf_parse_invalid_length() {
        let data = [0x01, 0x04, 0x00, 0x02]; // field_length=2 (< 4).
        let result = parse_extension_fields_buf(&data);
        assert!(matches!(
            result,
            Err(ParseError::InvalidExtensionLength { declared: 2 })
        ));
    }

    #[test]
    fn test_iter_extension_fields() {
        let fields = vec![
            ExtensionField {
                field_type: UNIQUE_IDENTIFIER,
                value: vec![0xAA; 32],
            },
            ExtensionField {
                field_type: NTS_COOKIE,
                value: vec![0xBB; 64],
            },
        ];
        let io_buf = write_extension_fields(&fields).unwrap();

        let mut iter = iter_extension_fields(&io_buf);

        let first = iter.next().unwrap().unwrap();
        assert_eq!(first.field_type, UNIQUE_IDENTIFIER);
        assert_eq!(first.value, &[0xAA; 32][..]);

        let second = iter.next().unwrap().unwrap();
        assert_eq!(second.field_type, NTS_COOKIE);
        assert_eq!(second.value, &[0xBB; 64][..]);

        assert!(iter.next().is_none());
    }

    #[test]
    fn test_iter_extension_fields_empty() {
        let mut iter = iter_extension_fields(&[]);
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_iter_yields_error_once_then_stops() {
        // An invalid header followed by an otherwise valid one: the iterator
        // reports the error and then terminates instead of resyncing.
        let data = [0x01, 0x04, 0x00, 0x02, 0x01, 0x04, 0x00, 0x04];
        let mut iter = iter_extension_fields(&data);
        assert!(matches!(
            iter.next(),
            Some(Err(ParseError::InvalidExtensionLength { declared: 2 }))
        ));
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_iter_value_overflow() {
        // Declares a 16-byte field but only 2 value bytes follow the header.
        let data = [0x02, 0x04, 0x00, 0x10, 0xAA, 0xBB];
        let mut iter = iter_extension_fields(&data);
        assert!(matches!(
            iter.next(),
            Some(Err(ParseError::ExtensionOverflow))
        ));
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_nts_authenticator_buf_roundtrip() {
        let auth = NtsAuthenticator::new(vec![0x11; 16], vec![0x22; 48]);
        let ef = auth.to_extension_field();
        let back = NtsAuthenticator::from_extension_field_buf(&ef)
            .unwrap()
            .unwrap();
        assert_eq!(back.nonce, vec![0x11; 16]);
        assert_eq!(back.ciphertext, vec![0x22; 48]);
    }
}