Skip to main content

ntp_proto/
extension.rs

1// Copyright 2026 U.S. Federal Government (in countries where recognized)
2// SPDX-License-Identifier: Apache-2.0
3
4//! NTP extension field parsing and NTS (Network Time Security) extension types.
5//!
6//! Extension fields follow the NTPv4 extension field format defined in RFC 7822,
7//! appended after the 48-byte NTP packet header. NTS (RFC 8915) defines specific
8//! extension field types for authenticated NTP.
9//!
10//! # Extension Field Format (RFC 7822)
11//!
12//! ```text
13//!  0                   1                   2                   3
14//!  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
15//! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
16//! |          Field Type           |        Field Length           |
17//! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
18//! .                                                               .
19//! .                       Field Value (variable)                  .
20//! .                                                               .
21//! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
22//! ```
23
24#[cfg(all(feature = "alloc", not(feature = "std")))]
25use alloc::vec;
26#[cfg(all(feature = "alloc", not(feature = "std")))]
27use alloc::vec::Vec;
28#[cfg(feature = "std")]
29use std::io;
30
31use crate::error::ParseError;
32
/// A borrowed view of an extension field (no allocation).
///
/// This type references data within the original byte buffer, avoiding
/// the heap allocation required by [`ExtensionField`]. It holds only a type
/// code and a slice reference, so it is cheap to copy and can be hashed.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ExtensionFieldRef<'a> {
    /// The extension field type code.
    pub field_type: u16,
    /// The extension field value (variable length, excluding the 4-byte header
    /// and any trailing alignment padding).
    pub value: &'a [u8],
}
44
/// Iterator over extension fields in a byte buffer.
///
/// Yields [`ExtensionFieldRef`] values without heap allocation.
/// Created by [`iter_extension_fields`].
#[derive(Clone, Debug)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct ExtensionFieldIter<'a> {
    /// The complete buffer being walked.
    data: &'a [u8],
    /// Byte offset of the next field header within `data` (never past its end).
    offset: usize,
}
53
54impl<'a> Iterator for ExtensionFieldIter<'a> {
55    type Item = Result<ExtensionFieldRef<'a>, ParseError>;
56
57    fn next(&mut self) -> Option<Self::Item> {
58        let remaining = &self.data[self.offset..];
59        if remaining.len() < 4 {
60            return None;
61        }
62
63        let field_type = u16::from_be_bytes([remaining[0], remaining[1]]);
64        let field_length = u16::from_be_bytes([remaining[2], remaining[3]]);
65
66        if field_length < 4 {
67            return Some(Err(ParseError::InvalidExtensionLength {
68                declared: field_length,
69            }));
70        }
71
72        let value_length = (field_length - 4) as usize;
73        let value_start = self.offset + 4;
74
75        if value_start + value_length > self.data.len() {
76            return Some(Err(ParseError::ExtensionOverflow));
77        }
78
79        let value = &self.data[value_start..value_start + value_length];
80
81        // Advance past value and padding to 4-byte boundary.
82        let padded = (field_length as usize + 3) & !3;
83        let next_offset = self.offset + padded;
84        self.offset = next_offset.min(self.data.len());
85
86        Some(Ok(ExtensionFieldRef { field_type, value }))
87    }
88}
89
90/// Create an iterator over extension fields without allocating.
91///
92/// This is the zero-allocation alternative to [`parse_extension_fields`].
93/// Each item yields a borrowed view of the extension field data.
94pub fn iter_extension_fields(data: &[u8]) -> ExtensionFieldIter<'_> {
95    ExtensionFieldIter { data, offset: 0 }
96}
97
/// Minimum extension field length per RFC 7822.
///
/// Note: the parsing code in this module does not enforce this minimum.
/// `iter_extension_fields` only rejects declared lengths shorter than the
/// 4-byte field header, so shorter fields (e.g. NTS fields) are accepted.
pub const MIN_EXTENSION_FIELD_LENGTH: u16 = 16;
100
// NTS extension field type codes (RFC 8915 Section 5.7).
// NOTE(review): the code points themselves are registered in RFC 8915's IANA
// considerations; confirm the section cited above.

/// Unique Identifier extension field type.
///
/// Used by `UniqueIdentifier` when converting to and from `ExtensionField`.
pub const UNIQUE_IDENTIFIER: u16 = 0x0104;

/// NTS Cookie extension field type.
///
/// Used by `NtsCookie` when converting to and from `ExtensionField`.
pub const NTS_COOKIE: u16 = 0x0204;

/// NTS Cookie Placeholder extension field type.
///
/// Used by `NtsCookiePlaceholder::to_extension_field`.
pub const NTS_COOKIE_PLACEHOLDER: u16 = 0x0304;

/// NTS Authenticator and Encrypted Extensions extension field type.
///
/// Used by `NtsAuthenticator` when converting to and from `ExtensionField`.
pub const NTS_AUTHENTICATOR: u16 = 0x0404;
114
/// An owned, generic NTP extension field.
///
/// Owned counterpart of [`ExtensionFieldRef`]: the value bytes are stored in
/// a `Vec` rather than borrowed from the packet buffer.
#[cfg(any(feature = "alloc", feature = "std"))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExtensionField {
    /// Type code taken from the first two bytes of the field header.
    pub field_type: u16,
    /// Field body of variable length; the 4-byte header is not included.
    pub value: Vec<u8>,
}
124
/// Parse extension fields from a byte buffer without using `std::io`.
///
/// Owning counterpart of [`iter_extension_fields`]: each field's value is
/// copied into its own `Vec`. Parsing ends quietly once fewer than four bytes
/// (one header) remain. The first malformed field aborts parsing and its
/// [`ParseError`] is returned.
#[cfg(any(feature = "alloc", feature = "std"))]
pub fn parse_extension_fields_buf(data: &[u8]) -> Result<Vec<ExtensionField>, ParseError> {
    let mut parsed = Vec::new();
    for item in iter_extension_fields(data) {
        let borrowed = item?;
        parsed.push(ExtensionField {
            field_type: borrowed.field_type,
            value: borrowed.value.to_vec(),
        });
    }
    Ok(parsed)
}
140
/// Serialize extension fields into a byte buffer without using `std::io`.
///
/// Each field is written as a 4-byte header (type, then the length of header
/// plus value, both big-endian), followed by the value. It is then zero-padded
/// to a 4-byte boundary. Returns the number of bytes written.
///
/// # Errors
///
/// * [`ParseError::ExtensionOverflow`] if a field's header plus value exceeds
///   `u16::MAX` bytes and therefore cannot be described by the 16-bit length.
/// * [`ParseError::BufferTooShort`] if `buf` cannot hold every padded field.
///   Fields before the failing one have already been written into `buf`.
#[cfg(any(feature = "alloc", feature = "std"))]
pub fn write_extension_fields_buf(
    fields: &[ExtensionField],
    buf: &mut [u8],
) -> Result<usize, ParseError> {
    let mut offset = 0;

    for field in fields {
        let field_length = 4 + field.value.len();
        // The length field is 16 bits wide; truncating it would emit a header
        // that disagrees with the bytes actually written.
        let length_field =
            u16::try_from(field_length).map_err(|_| ParseError::ExtensionOverflow)?;
        let padded = (field_length + 3) & !3;
        let end = offset + padded;

        if end > buf.len() {
            return Err(ParseError::BufferTooShort {
                needed: end,
                available: buf.len(),
            });
        }

        let slot = &mut buf[offset..end];
        slot[..2].copy_from_slice(&field.field_type.to_be_bytes());
        slot[2..4].copy_from_slice(&length_field.to_be_bytes());
        slot[4..field_length].copy_from_slice(&field.value);
        // Zero-fill the alignment padding.
        slot[field_length..].fill(0);

        offset = end;
    }

    Ok(offset)
}
178
/// Parse extension fields from data following the 48-byte NTP header.
///
/// `std::io` flavoured wrapper around [`parse_extension_fields_buf`]. The same
/// parsing rules apply, and any [`ParseError`] is converted into an
/// [`io::Error`].
#[cfg(feature = "std")]
pub fn parse_extension_fields(data: &[u8]) -> io::Result<Vec<ExtensionField>> {
    let fields = parse_extension_fields_buf(data)?;
    Ok(fields)
}
187
/// Serialize extension fields to a byte vector.
///
/// Allocates exactly the bytes the fields occupy, each padded with zeros to a
/// 4-byte boundary, and fills them via [`write_extension_fields_buf`]. Any
/// [`ParseError`] from that call is converted into an [`io::Error`].
#[cfg(feature = "std")]
pub fn write_extension_fields(fields: &[ExtensionField]) -> io::Result<Vec<u8>> {
    let mut total = 0usize;
    for field in fields {
        // 4-byte header plus value, rounded up to a multiple of four.
        total += (4 + field.value.len() + 3) & !3;
    }
    let mut out = vec![0u8; total];
    write_extension_fields_buf(fields, &mut out)?;
    Ok(out)
}
199
/// NTS Unique Identifier extension field (RFC 8915 Section 5.3).
///
/// Holds random data used for replay protection at the NTS level: the
/// client generates it and the server echoes it back.
#[cfg(any(feature = "alloc", feature = "std"))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UniqueIdentifier(pub Vec<u8>);

#[cfg(any(feature = "alloc", feature = "std"))]
impl UniqueIdentifier {
    /// Wrap raw identifier bytes; no length or content check is performed.
    pub fn new(data: Vec<u8>) -> Self {
        Self(data)
    }

    /// Build a `UNIQUE_IDENTIFIER` extension field carrying a copy of the bytes.
    pub fn to_extension_field(&self) -> ExtensionField {
        let value = self.0.clone();
        ExtensionField {
            field_type: UNIQUE_IDENTIFIER,
            value,
        }
    }

    /// Copy the bytes out of `ef` when it is a `UNIQUE_IDENTIFIER` field.
    ///
    /// Returns `None` for any other type code.
    pub fn from_extension_field(ef: &ExtensionField) -> Option<Self> {
        (ef.field_type == UNIQUE_IDENTIFIER).then(|| Self(ef.value.clone()))
    }
}
232
/// NTS Cookie extension field (RFC 8915 Section 5.4).
///
/// Wraps an opaque cookie handed out by the NTS-KE server. Each cookie is
/// meant to be used for exactly one NTP request.
#[cfg(any(feature = "alloc", feature = "std"))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NtsCookie(pub Vec<u8>);

#[cfg(any(feature = "alloc", feature = "std"))]
impl NtsCookie {
    /// Wrap raw cookie bytes; the contents are not inspected.
    pub fn new(data: Vec<u8>) -> Self {
        Self(data)
    }

    /// Build an `NTS_COOKIE` extension field holding a copy of the cookie.
    pub fn to_extension_field(&self) -> ExtensionField {
        ExtensionField {
            value: self.0.clone(),
            field_type: NTS_COOKIE,
        }
    }

    /// Copy the cookie out of `ef` when it is an `NTS_COOKIE` field.
    ///
    /// Returns `None` for any other type code.
    pub fn from_extension_field(ef: &ExtensionField) -> Option<Self> {
        if ef.field_type != NTS_COOKIE {
            return None;
        }
        Some(Self(ef.value.clone()))
    }
}
265
/// NTS Cookie Placeholder extension field (RFC 8915 Section 5.5).
///
/// Tells the server that the client would like additional cookies. The
/// placeholder body should be as large as the cookie it expects back.
#[cfg(any(feature = "alloc", feature = "std"))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NtsCookiePlaceholder {
    /// Number of body bytes (all zero) to emit.
    pub size: usize,
}

#[cfg(any(feature = "alloc", feature = "std"))]
impl NtsCookiePlaceholder {
    /// Describe a placeholder whose body is `size` bytes long.
    pub fn new(size: usize) -> Self {
        Self { size }
    }

    /// Build an `NTS_COOKIE_PLACEHOLDER` extension field whose body is
    /// `self.size` zero bytes.
    pub fn to_extension_field(&self) -> ExtensionField {
        let zeros = vec![0u8; self.size];
        ExtensionField {
            field_type: NTS_COOKIE_PLACEHOLDER,
            value: zeros,
        }
    }
}
292
/// NTS Authenticator and Encrypted Extensions extension field (RFC 8915 Section 5.6).
///
/// Contains the AEAD nonce and ciphertext. The ciphertext includes any
/// encrypted extension fields plus the AEAD authentication tag.
#[cfg(any(feature = "alloc", feature = "std"))]
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct NtsAuthenticator {
    /// The AEAD nonce.
    pub nonce: Vec<u8>,
    /// The AEAD ciphertext (encrypted extensions + authentication tag).
    pub ciphertext: Vec<u8>,
}

#[cfg(any(feature = "alloc", feature = "std"))]
impl NtsAuthenticator {
    /// Size of the fixed body header: Nonce Length (u16) + Ciphertext Length (u16).
    const BODY_HEADER_LEN: usize = 4;

    /// Round `len` up to the next multiple of four (one 32-bit word).
    fn word_padded(len: usize) -> usize {
        (len + 3) & !3
    }

    /// Create an NTS Authenticator.
    pub fn new(nonce: Vec<u8>, ciphertext: Vec<u8>) -> Self {
        NtsAuthenticator { nonce, ciphertext }
    }

    /// Convert to a generic extension field.
    ///
    /// The value uses the body layout of RFC 8915 Section 5.6:
    ///
    /// ```text
    /// Nonce Length (u16 BE) | Ciphertext Length (u16 BE)
    /// Nonce, zero-padded to a 4-byte boundary
    /// Ciphertext, zero-padded to a 4-byte boundary
    /// ```
    ///
    /// Both length fields hold the unpadded lengths. No trailing "Additional
    /// Padding" is emitted. Lengths above `u16::MAX` cannot be represented:
    /// debug builds assert on this, and release builds truncate the length field.
    pub fn to_extension_field(&self) -> ExtensionField {
        let nonce_len = self.nonce.len();
        let ct_len = self.ciphertext.len();
        debug_assert!(nonce_len <= usize::from(u16::MAX), "NTS nonce too long");
        debug_assert!(ct_len <= usize::from(u16::MAX), "NTS ciphertext too long");

        let nonce_end = Self::BODY_HEADER_LEN + Self::word_padded(nonce_len);
        let total = nonce_end + Self::word_padded(ct_len);

        let mut value = Vec::with_capacity(total);
        // Both lengths come first, ahead of either payload.
        value.extend_from_slice(&(nonce_len as u16).to_be_bytes());
        value.extend_from_slice(&(ct_len as u16).to_be_bytes());
        value.extend_from_slice(&self.nonce);
        value.resize(nonce_end, 0);
        value.extend_from_slice(&self.ciphertext);
        value.resize(total, 0);

        ExtensionField {
            field_type: NTS_AUTHENTICATOR,
            value,
        }
    }

    /// Try to extract from a generic extension field.
    #[cfg(feature = "std")]
    pub fn from_extension_field(ef: &ExtensionField) -> io::Result<Option<Self>> {
        Self::from_extension_field_buf(ef).map_err(io::Error::from)
    }

    /// Try to extract from a generic extension field without using `std::io`.
    ///
    /// Returns `Ok(None)` if `ef` is not an `NTS_AUTHENTICATOR` field.
    /// Otherwise it parses the RFC 8915 Section 5.6 layout produced by
    /// [`to_extension_field`](Self::to_extension_field). Padding bytes and any
    /// trailing additional padding are skipped without being inspected.
    ///
    /// # Errors
    ///
    /// * [`ParseError::BufferTooShort`] if the value is shorter than the
    ///   4-byte length header.
    /// * [`ParseError::ExtensionOverflow`] if either declared length runs past
    ///   the end of the value.
    pub fn from_extension_field_buf(ef: &ExtensionField) -> Result<Option<Self>, ParseError> {
        if ef.field_type != NTS_AUTHENTICATOR {
            return Ok(None);
        }

        let data = ef.value.as_slice();
        if data.len() < Self::BODY_HEADER_LEN {
            return Err(ParseError::BufferTooShort {
                needed: Self::BODY_HEADER_LEN,
                available: data.len(),
            });
        }

        let nonce_len = usize::from(u16::from_be_bytes([data[0], data[1]]));
        let ct_len = usize::from(u16::from_be_bytes([data[2], data[3]]));

        let nonce_start = Self::BODY_HEADER_LEN;
        let nonce = data
            .get(nonce_start..nonce_start + nonce_len)
            .ok_or(ParseError::ExtensionOverflow)?
            .to_vec();

        // The ciphertext starts after the nonce's word padding.
        let ct_start = nonce_start + Self::word_padded(nonce_len);
        let ciphertext = data
            .get(ct_start..ct_start + ct_len)
            .ok_or(ParseError::ExtensionOverflow)?
            .to_vec();

        Ok(Some(NtsAuthenticator { nonce, ciphertext }))
    }
}
384
385// ============================================================================
386// Generic extension field registry and dispatch (RFC 7822)
387// ============================================================================
388
/// A handler trait for processing extension fields.
///
/// Implement this trait to create custom extension field handlers that can
/// be registered in an [`ExtensionRegistry`]. Implementors must be
/// `Send + Sync`. The `handle` method exists only when the `std` feature is
/// enabled.
pub trait ExtensionHandler: Send + Sync {
    /// Return the extension field type code this handler processes.
    ///
    /// The registry calls this both at registration time and on every
    /// dispatch, so it should return the same value every time.
    fn field_type(&self) -> u16;

    /// Process an extension field value.
    ///
    /// `value` is the field body, without the 4-byte extension field header.
    ///
    /// Returns `Ok(())` if the field was successfully processed, or an error
    /// describing why processing failed.
    #[cfg(feature = "std")]
    fn handle(&self, value: &[u8]) -> io::Result<()>;
}
404
/// A registry for extension field handlers per RFC 7822.
///
/// This allows applications to register custom handlers for non-NTS extension
/// field types. The registry dispatches incoming extension fields to the
/// appropriate handler based on field type.
///
/// # Examples
///
/// ```
/// use ntp_proto::extension::{ExtensionRegistry, ExtensionHandler, ExtensionField};
/// # #[cfg(feature = "std")] {
/// use std::io;
///
/// // Define a custom handler
/// struct MyHandler;
///
/// impl ExtensionHandler for MyHandler {
///     fn field_type(&self) -> u16 {
///         0x4000  // Custom field type
///     }
///
///     fn handle(&self, value: &[u8]) -> io::Result<()> {
///         println!("Received extension field with {} bytes", value.len());
///         Ok(())
///     }
/// }
///
/// // Create registry and register handler
/// let mut registry = ExtensionRegistry::new();
/// registry.register(Box::new(MyHandler));
///
/// // Dispatch an extension field
/// let field = ExtensionField {
///     field_type: 0x4000,
///     value: vec![1, 2, 3, 4],
/// };
/// registry.dispatch(&field).unwrap();
/// # }
/// ```
#[cfg(feature = "std")]
#[derive(Default)]
pub struct ExtensionRegistry {
    /// Registered handlers; `register` keeps at most one per field type.
    handlers: Vec<Box<dyn ExtensionHandler>>,
}

#[cfg(feature = "std")]
impl ExtensionRegistry {
    /// Create a new empty extension field registry.
    pub fn new() -> Self {
        ExtensionRegistry {
            handlers: Vec::new(),
        }
    }

    /// Look up the handler registered for `field_type`, if any.
    fn handler_for(&self, field_type: u16) -> Option<&dyn ExtensionHandler> {
        self.handlers
            .iter()
            .find(|h| h.field_type() == field_type)
            .map(|h| &**h)
    }

    /// Build the error reported when no handler exists for `field_type`.
    fn no_handler_error(field_type: u16) -> io::Error {
        io::Error::new(
            io::ErrorKind::Unsupported,
            format!("no handler registered for extension field type 0x{field_type:04X}"),
        )
    }

    /// Register a handler for a specific extension field type.
    ///
    /// If a handler for this field type is already registered, it will be
    /// replaced.
    pub fn register(&mut self, handler: Box<dyn ExtensionHandler>) {
        let field_type = handler.field_type();
        // Remove any existing handler for this type
        self.handlers.retain(|h| h.field_type() != field_type);
        self.handlers.push(handler);
    }

    /// Dispatch an extension field to the registered handler.
    ///
    /// Returns `Ok(())` if a handler was found and successfully processed the
    /// field. Returns an [`io::ErrorKind::Unsupported`] error if no handler is
    /// registered for this field type, or the handler's own error if it fails.
    pub fn dispatch(&self, field: &ExtensionField) -> io::Result<()> {
        match self.handler_for(field.field_type) {
            Some(handler) => handler.handle(&field.value),
            None => Err(Self::no_handler_error(field.field_type)),
        }
    }

    /// Dispatch all extension fields in a list.
    ///
    /// Processes each field in sequence and stops at the first error. Fields
    /// with no registered handler are skipped unless `require_handlers` is
    /// true, in which case they fail with [`io::ErrorKind::Unsupported`].
    /// Errors returned by a handler are always propagated, whatever their kind.
    pub fn dispatch_all(
        &self,
        fields: &[ExtensionField],
        require_handlers: bool,
    ) -> io::Result<()> {
        for field in fields {
            // Decide "missing handler" by lookup, not by error kind, so a
            // handler's own `Unsupported` error is never mistaken for it.
            match self.handler_for(field.field_type) {
                Some(handler) => handler.handle(&field.value)?,
                None if require_handlers => {
                    return Err(Self::no_handler_error(field.field_type));
                }
                None => {}
            }
        }
        Ok(())
    }

    /// Check if a handler is registered for the given field type.
    pub fn has_handler(&self, field_type: u16) -> bool {
        self.handler_for(field_type).is_some()
    }

    /// Return the number of registered handlers.
    pub fn len(&self) -> usize {
        self.handlers.len()
    }

    /// Return true if no handlers are registered.
    pub fn is_empty(&self) -> bool {
        self.handlers.is_empty()
    }
}
526
// Unit tests for extension field parsing/serialization, the NTS extension
// types and `ExtensionRegistry`. They call the `std`-only `io` wrappers
// (`parse_extension_fields`, `write_extension_fields`, `ExtensionHandler::handle`),
// so the module is compiled only when the `std` feature is enabled.
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    #[test]
    fn test_parse_empty() {
        let fields = parse_extension_fields(&[]).unwrap();
        assert!(fields.is_empty());
    }

    #[test]
    fn test_roundtrip_single_field() {
        // 4-byte header + 12-byte value is already 4-byte aligned: no padding.
        let field = ExtensionField {
            field_type: UNIQUE_IDENTIFIER,
            value: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
        };
        let buf = write_extension_fields(std::slice::from_ref(&field)).unwrap();
        let parsed = parse_extension_fields(&buf).unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0], field);
    }

    #[test]
    fn test_roundtrip_multiple_fields() {
        // Two back-to-back fields check that the parser advances to the next header.
        let fields = vec![
            ExtensionField {
                field_type: UNIQUE_IDENTIFIER,
                value: vec![0xAA; 32],
            },
            ExtensionField {
                field_type: NTS_COOKIE,
                value: vec![0xBB; 64],
            },
        ];
        let buf = write_extension_fields(&fields).unwrap();
        let parsed = parse_extension_fields(&buf).unwrap();
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[0], fields[0]);
        assert_eq!(parsed[1], fields[1]);
    }

    #[test]
    fn test_padding() {
        // Value of 5 bytes: 4 header + 5 value = 9 bytes, padded to 12.
        let field = ExtensionField {
            field_type: 0x1234,
            value: vec![1, 2, 3, 4, 5],
        };
        let buf = write_extension_fields(&[field]).unwrap();
        assert_eq!(buf.len(), 12); // 4 header + 5 value + 3 padding
    }

    #[test]
    fn test_unique_identifier_conversion() {
        let uid = UniqueIdentifier::new(vec![0x42; 32]);
        let ef = uid.to_extension_field();
        assert_eq!(ef.field_type, UNIQUE_IDENTIFIER);
        let back = UniqueIdentifier::from_extension_field(&ef).unwrap();
        assert_eq!(back.0, vec![0x42; 32]);
    }

    #[test]
    fn test_nts_cookie_conversion() {
        let cookie = NtsCookie::new(vec![0xDE, 0xAD, 0xBE, 0xEF]);
        let ef = cookie.to_extension_field();
        assert_eq!(ef.field_type, NTS_COOKIE);
        let back = NtsCookie::from_extension_field(&ef).unwrap();
        assert_eq!(back.0, vec![0xDE, 0xAD, 0xBE, 0xEF]);
    }

    // ── ExtensionRegistry tests ───────────────────────────────────

    // Handler that only counts its invocations. The counter is shared through
    // `Arc<Mutex<_>>` so a test can read it after the handler has been moved
    // into the registry.
    struct TestHandler {
        field_type: u16,
        call_count: std::sync::Arc<std::sync::Mutex<usize>>,
    }

    impl ExtensionHandler for TestHandler {
        fn field_type(&self) -> u16 {
            self.field_type
        }

        fn handle(&self, _value: &[u8]) -> io::Result<()> {
            *self.call_count.lock().unwrap() += 1;
            Ok(())
        }
    }

    #[test]
    fn test_registry_register_and_dispatch() {
        let call_count = std::sync::Arc::new(std::sync::Mutex::new(0));
        let mut registry = ExtensionRegistry::new();

        registry.register(Box::new(TestHandler {
            field_type: 0x1234,
            call_count: call_count.clone(),
        }));

        let field = ExtensionField {
            field_type: 0x1234,
            value: vec![1, 2, 3],
        };

        registry.dispatch(&field).unwrap();
        assert_eq!(*call_count.lock().unwrap(), 1);
    }

    #[test]
    fn test_registry_no_handler() {
        // An empty registry reports unknown types as `Unsupported`.
        let registry = ExtensionRegistry::new();
        let field = ExtensionField {
            field_type: 0x9999,
            value: vec![],
        };

        let result = registry.dispatch(&field);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::Unsupported);
    }

    #[test]
    fn test_registry_has_handler() {
        let mut registry = ExtensionRegistry::new();
        assert!(!registry.has_handler(0x4000));

        registry.register(Box::new(TestHandler {
            field_type: 0x4000,
            call_count: std::sync::Arc::new(std::sync::Mutex::new(0)),
        }));

        assert!(registry.has_handler(0x4000));
        assert!(!registry.has_handler(0x4001));
    }

    #[test]
    fn test_registry_replace_handler() {
        // Registering twice for the same type must keep only the newer handler.
        let count1 = std::sync::Arc::new(std::sync::Mutex::new(0));
        let count2 = std::sync::Arc::new(std::sync::Mutex::new(0));

        let mut registry = ExtensionRegistry::new();
        registry.register(Box::new(TestHandler {
            field_type: 0x5000,
            call_count: count1.clone(),
        }));
        registry.register(Box::new(TestHandler {
            field_type: 0x5000,
            call_count: count2.clone(),
        }));

        let field = ExtensionField {
            field_type: 0x5000,
            value: vec![],
        };

        registry.dispatch(&field).unwrap();

        // Only the second handler should be called
        assert_eq!(*count1.lock().unwrap(), 0);
        assert_eq!(*count2.lock().unwrap(), 1);
    }

    #[test]
    fn test_registry_dispatch_all() {
        let count1 = std::sync::Arc::new(std::sync::Mutex::new(0));
        let count2 = std::sync::Arc::new(std::sync::Mutex::new(0));

        let mut registry = ExtensionRegistry::new();
        registry.register(Box::new(TestHandler {
            field_type: 0x6000,
            call_count: count1.clone(),
        }));
        registry.register(Box::new(TestHandler {
            field_type: 0x6001,
            call_count: count2.clone(),
        }));

        let fields = vec![
            ExtensionField {
                field_type: 0x6000,
                value: vec![1],
            },
            ExtensionField {
                field_type: 0x6001,
                value: vec![2],
            },
        ];

        registry.dispatch_all(&fields, false).unwrap();
        assert_eq!(*count1.lock().unwrap(), 1);
        assert_eq!(*count2.lock().unwrap(), 1);
    }

    #[test]
    fn test_registry_dispatch_all_ignores_unknown() {
        let count = std::sync::Arc::new(std::sync::Mutex::new(0));
        let mut registry = ExtensionRegistry::new();
        registry.register(Box::new(TestHandler {
            field_type: 0x7000,
            call_count: count.clone(),
        }));

        let fields = vec![
            ExtensionField {
                field_type: 0x7000,
                value: vec![1],
            },
            ExtensionField {
                field_type: 0x9999, // Unknown field
                value: vec![2],
            },
        ];

        // Should succeed (ignore unknown fields when require_handlers = false)
        registry.dispatch_all(&fields, false).unwrap();
        assert_eq!(*count.lock().unwrap(), 1);
    }

    #[test]
    fn test_registry_dispatch_all_requires_handlers() {
        let count = std::sync::Arc::new(std::sync::Mutex::new(0));
        let mut registry = ExtensionRegistry::new();
        registry.register(Box::new(TestHandler {
            field_type: 0x8000,
            call_count: count.clone(),
        }));

        let fields = vec![
            ExtensionField {
                field_type: 0x8000,
                value: vec![1],
            },
            ExtensionField {
                field_type: 0x9999, // Unknown field
                value: vec![2],
            },
        ];

        // Should fail (unknown field when require_handlers = true)
        let result = registry.dispatch_all(&fields, true);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::Unsupported);
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let mut registry = ExtensionRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(Box::new(TestHandler {
            field_type: 0xA000,
            call_count: std::sync::Arc::new(std::sync::Mutex::new(0)),
        }));

        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    // ──────────────────────────────────────────────────────────────

    #[test]
    fn test_nts_authenticator_roundtrip() {
        // Checks that encoding and decoding agree with each other; it does not
        // pin the exact wire layout.
        let auth = NtsAuthenticator::new(vec![0x11; 16], vec![0x22; 48]);
        let ef = auth.to_extension_field();
        assert_eq!(ef.field_type, NTS_AUTHENTICATOR);
        let back = NtsAuthenticator::from_extension_field(&ef)
            .unwrap()
            .unwrap();
        assert_eq!(back.nonce, vec![0x11; 16]);
        assert_eq!(back.ciphertext, vec![0x22; 48]);
    }

    #[test]
    fn test_cookie_placeholder() {
        let placeholder = NtsCookiePlaceholder::new(100);
        let ef = placeholder.to_extension_field();
        assert_eq!(ef.field_type, NTS_COOKIE_PLACEHOLDER);
        assert_eq!(ef.value.len(), 100);
        assert!(ef.value.iter().all(|&b| b == 0));
    }

    #[test]
    fn test_parse_truncated_field() {
        // Only 3 bytes: not enough for the 4-byte header.
        let data = [0x01, 0x04, 0x00];
        let fields = parse_extension_fields(&data).unwrap();
        assert!(fields.is_empty()); // Silently stops, not enough for header
    }

    #[test]
    fn test_parse_invalid_length() {
        // field_length=2 (less than 4).
        let data = [0x01, 0x04, 0x00, 0x02];
        let result = parse_extension_fields(&data);
        assert!(result.is_err());
    }

    // Buffer-based API tests.

    #[test]
    fn test_buf_parse_empty() {
        let fields = parse_extension_fields_buf(&[]).unwrap();
        assert!(fields.is_empty());
    }

    #[test]
    fn test_buf_roundtrip_single_field() {
        let field = ExtensionField {
            field_type: UNIQUE_IDENTIFIER,
            value: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
        };

        // Write to Vec via io API, then to fixed buffer via buf API.
        let io_buf = write_extension_fields(std::slice::from_ref(&field)).unwrap();
        let mut buf = vec![0u8; 256];
        let written = write_extension_fields_buf(std::slice::from_ref(&field), &mut buf).unwrap();
        assert_eq!(&io_buf[..], &buf[..written]);

        // Parse with buf API.
        let parsed = parse_extension_fields_buf(&buf[..written]).unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0], field);
    }

    #[test]
    fn test_buf_equivalence_with_io_api() {
        let fields = vec![
            ExtensionField {
                field_type: UNIQUE_IDENTIFIER,
                value: vec![0xAA; 32],
            },
            ExtensionField {
                field_type: NTS_COOKIE,
                value: vec![0xBB; 64],
            },
        ];

        let io_buf = write_extension_fields(&fields).unwrap();
        let mut raw_buf = vec![0u8; 512];
        let written = write_extension_fields_buf(&fields, &mut raw_buf).unwrap();

        // Same output.
        assert_eq!(&io_buf[..], &raw_buf[..written]);

        // Same parse result.
        let io_parsed = parse_extension_fields(&io_buf).unwrap();
        let buf_parsed = parse_extension_fields_buf(&raw_buf[..written]).unwrap();
        assert_eq!(io_parsed, buf_parsed);
    }

    #[test]
    fn test_buf_write_buffer_too_short() {
        let field = ExtensionField {
            field_type: UNIQUE_IDENTIFIER,
            value: vec![0xAA; 32],
        };
        let mut tiny_buf = [0u8; 4]; // Too small for 4 header + 32 value.
        let result = write_extension_fields_buf(&[field], &mut tiny_buf);
        assert!(result.is_err());
    }

    #[test]
    fn test_buf_parse_invalid_length() {
        let data = [0x01, 0x04, 0x00, 0x02]; // field_length=2 (< 4).
        let result = parse_extension_fields_buf(&data);
        assert!(matches!(
            result,
            Err(ParseError::InvalidExtensionLength { declared: 2 })
        ));
    }

    #[test]
    fn test_iter_extension_fields() {
        // The borrowing iterator must yield the same fields the writer produced.
        let fields = vec![
            ExtensionField {
                field_type: UNIQUE_IDENTIFIER,
                value: vec![0xAA; 32],
            },
            ExtensionField {
                field_type: NTS_COOKIE,
                value: vec![0xBB; 64],
            },
        ];
        let io_buf = write_extension_fields(&fields).unwrap();

        let mut iter = iter_extension_fields(&io_buf);

        let first = iter.next().unwrap().unwrap();
        assert_eq!(first.field_type, UNIQUE_IDENTIFIER);
        assert_eq!(first.value, &[0xAA; 32][..]);

        let second = iter.next().unwrap().unwrap();
        assert_eq!(second.field_type, NTS_COOKIE);
        assert_eq!(second.value, &[0xBB; 64][..]);

        assert!(iter.next().is_none());
    }

    #[test]
    fn test_iter_extension_fields_empty() {
        let mut iter = iter_extension_fields(&[]);
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_nts_authenticator_buf_roundtrip() {
        let auth = NtsAuthenticator::new(vec![0x11; 16], vec![0x22; 48]);
        let ef = auth.to_extension_field();
        let back = NtsAuthenticator::from_extension_field_buf(&ef)
            .unwrap()
            .unwrap();
        assert_eq!(back.nonce, vec![0x11; 16]);
        assert_eq!(back.ciphertext, vec![0x22; 48]);
    }
}