Skip to main content

ant_quic/constrained/
types.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7
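//! Core types for the constrained protocol engine
//!
//! This module defines fundamental types used throughout the constrained protocol:
//! - [`ConnectionId`] - Unique identifier for connections
//! - [`SequenceNumber`] - Packet sequence tracking
//! - [`PacketType`] - Individual packet-type bits (SYN, ACK, FIN, RST, DATA, PING, PONG)
//! - [`PacketFlags`] - Combination of packet-type bits stored in one header byte
//! - [`ConstrainedAddr`] - Wrapper around a transport address for constrained links
//! - [`ConstrainedError`] - Error handling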
15
16use std::fmt;
17use std::net::SocketAddr;
18use thiserror::Error;
19
20use crate::transport::TransportAddr;
21
/// Connection identifier for the constrained protocol
///
/// A 16-bit identifier that uniquely identifies a connection between two peers.
/// Connection IDs are locally generated and do not need to be globally unique.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(pub u16);

impl ConnectionId {
    /// Create a new connection ID from raw value
    pub const fn new(value: u16) -> Self {
        Self(value)
    }

    /// Get the raw u16 value
    pub const fn value(self) -> u16 {
        self.0
    }

    /// Serialize to bytes (big-endian)
    pub const fn to_bytes(self) -> [u8; 2] {
        self.0.to_be_bytes()
    }

    /// Deserialize from bytes (big-endian)
    pub const fn from_bytes(bytes: [u8; 2]) -> Self {
        Self(u16::from_be_bytes(bytes))
    }

    /// Generate a random connection ID.
    ///
    /// The current wall-clock time and a process-wide call counter are hashed with a
    /// freshly keyed [`RandomState`](std::collections::hash_map::RandomState) hasher, and
    /// the 64-bit hash is folded down to 16 bits. The counter guarantees that two calls
    /// never hash identical input, even when the clock has not advanced between them.
    ///
    /// This is not a cryptographically secure generator; it only aims for IDs that are
    /// well spread and unlikely to repeat between consecutive calls.
    pub fn random() -> Self {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hash, Hasher};
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::time::{SystemTime, UNIX_EPOCH};

        // Distinguishes calls that observe the same clock reading.
        static CALLS: AtomicUsize = AtomicUsize::new(0);

        let mut hasher = RandomState::new().build_hasher();
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos()
            .hash(&mut hasher);
        CALLS.fetch_add(1, Ordering::Relaxed).hash(&mut hasher);
        let h = hasher.finish();

        // Fold all four 16-bit lanes together so every hash bit affects the ID;
        // the final `as u16` truncation is intentional.
        Self((h ^ (h >> 16) ^ (h >> 32) ^ (h >> 48)) as u16)
    }
}

impl fmt::Display for ConnectionId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "CID:{:04X}", self.0)
    }
}
66
/// Sequence number for packet ordering and acknowledgment
///
/// Wraps an 8-bit counter: incrementing past 255 returns to 0. Comparisons between two
/// sequence numbers are performed modulo 256 (see [`SequenceNumber::distance_to`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SequenceNumber(pub u8);

impl SequenceNumber {
    /// Wrap a raw 8-bit sequence value.
    pub const fn new(value: u8) -> Self {
        SequenceNumber(value)
    }

    /// Return the raw 8-bit sequence value.
    pub const fn value(self) -> u8 {
        let SequenceNumber(raw) = self;
        raw
    }

    /// The following sequence number; 255 is followed by 0.
    pub const fn next(self) -> Self {
        Self::new(self.0.wrapping_add(1))
    }

    /// Signed step count from `self` forward to `other`, modulo 256.
    ///
    /// The modular difference is reinterpreted as an `i8`, so results lie in `-128..=127`:
    /// positive means `other` is ahead, negative means it is behind, and a gap of exactly
    /// 128 is reported as `-128`. Only meaningful while the two values are fewer than 128
    /// steps apart.
    pub fn distance_to(self, other: Self) -> i16 {
        let modular_gap = other.0.wrapping_sub(self.0);
        i16::from(modular_gap as i8)
    }

    /// Whether `other` lies in the inclusive range `self ..= self + window_size` (mod 256).
    ///
    /// Because this relies on [`distance_to`](Self::distance_to), positions more than 127
    /// steps ahead count as "behind", so a window of 128 or more reaches no further.
    pub fn is_in_window(self, other: Self, window_size: u8) -> bool {
        let ahead = self.distance_to(other);
        (0..=i16::from(window_size)).contains(&ahead)
    }
}

impl fmt::Display for SequenceNumber {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let SequenceNumber(raw) = *self;
        write!(f, "SEQ:{}", raw)
    }
}
111
/// Packet type flags for the constrained protocol
///
/// Each variant is a single bit. These flags are combined in a single byte in the
/// packet header (see [`PacketFlags`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum PacketType {
    /// Connection request (SYN)
    Syn = 0x01,
    /// Acknowledgment (ACK)
    Ack = 0x02,
    /// Connection close (FIN)
    Fin = 0x04,
    /// Connection reset (RST)
    Reset = 0x08,
    /// Data packet
    Data = 0x10,
    /// Keep-alive ping
    Ping = 0x20,
    /// Pong response to ping
    Pong = 0x40,
}

impl PacketType {
    /// Get the flag value (a single bit) for this packet type
    pub const fn flag(self) -> u8 {
        self as u8
    }
}

/// Packet flags combining multiple packet types
///
/// A packet can have multiple flags set (e.g., SYN+ACK). Bit `0x80` is not assigned to
/// any [`PacketType`]; it is kept in the raw value and shown in hex by `Display`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct PacketFlags(pub u8);

impl PacketFlags {
    /// No flags set
    pub const NONE: Self = Self(0);

    /// SYN flag
    pub const SYN: Self = Self(PacketType::Syn.flag());
    /// ACK flag
    pub const ACK: Self = Self(PacketType::Ack.flag());
    /// FIN flag
    pub const FIN: Self = Self(PacketType::Fin.flag());
    /// RST flag
    pub const RST: Self = Self(PacketType::Reset.flag());
    /// DATA flag
    pub const DATA: Self = Self(PacketType::Data.flag());
    /// PING flag
    pub const PING: Self = Self(PacketType::Ping.flag());
    /// PONG flag
    pub const PONG: Self = Self(PacketType::Pong.flag());

    /// SYN+ACK combination
    pub const SYN_ACK: Self = Self(PacketType::Syn.flag() | PacketType::Ack.flag());

    /// Create flags from raw value
    pub const fn new(value: u8) -> Self {
        Self(value)
    }

    /// Get raw value
    pub const fn value(self) -> u8 {
        self.0
    }

    /// Check if a specific flag is set
    pub const fn has(self, flag: PacketType) -> bool {
        self.0 & flag.flag() != 0
    }

    /// Check if SYN flag is set
    pub const fn is_syn(self) -> bool {
        self.has(PacketType::Syn)
    }

    /// Check if ACK flag is set
    pub const fn is_ack(self) -> bool {
        self.has(PacketType::Ack)
    }

    /// Check if FIN flag is set
    pub const fn is_fin(self) -> bool {
        self.has(PacketType::Fin)
    }

    /// Check if RST flag is set
    pub const fn is_rst(self) -> bool {
        self.has(PacketType::Reset)
    }

    /// Check if DATA flag is set
    pub const fn is_data(self) -> bool {
        self.has(PacketType::Data)
    }

    /// Check if PING flag is set
    pub const fn is_ping(self) -> bool {
        self.has(PacketType::Ping)
    }

    /// Check if PONG flag is set
    pub const fn is_pong(self) -> bool {
        self.has(PacketType::Pong)
    }

    /// Return a copy with `flag` additionally set
    pub const fn with(self, flag: PacketType) -> Self {
        Self(self.0 | flag.flag())
    }

    /// Combine two flag sets (bitwise OR)
    pub const fn union(self, other: Self) -> Self {
        Self(self.0 | other.0)
    }
}

/// Formats as `NONE` when no bit is set, otherwise as `|`-separated flag names in
/// bit order (e.g. `SYN|ACK`). Bits not assigned to any [`PacketType`] are appended
/// in hex (e.g. `SYN|0x80`) rather than silently dropped.
impl fmt::Display for PacketFlags {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Listed in ascending bit order, which fixes the printed order.
        const NAMES: [(PacketType, &str); 7] = [
            (PacketType::Syn, "SYN"),
            (PacketType::Ack, "ACK"),
            (PacketType::Fin, "FIN"),
            (PacketType::Reset, "RST"),
            (PacketType::Data, "DATA"),
            (PacketType::Ping, "PING"),
            (PacketType::Pong, "PONG"),
        ];

        if self.0 == 0 {
            return f.write_str("NONE");
        }

        let mut known = 0u8;
        let mut wrote_any = false;
        for &(flag, name) in NAMES.iter() {
            known |= flag.flag();
            if self.has(flag) {
                if wrote_any {
                    f.write_str("|")?;
                }
                f.write_str(name)?;
                wrote_any = true;
            }
        }

        // Without this, a byte carrying only unassigned bits would print as "NONE".
        let unknown = self.0 & !known;
        if unknown != 0 {
            if wrote_any {
                f.write_str("|")?;
            }
            write!(f, "{:#04X}", unknown)?;
        }
        Ok(())
    }
}
261
262/// Address wrapper for constrained protocol connections
263///
264/// This wraps `TransportAddr` to provide constrained-specific functionality
265/// while maintaining compatibility with the transport system.
266#[derive(Debug, Clone, PartialEq, Eq, Hash)]
267pub struct ConstrainedAddr(TransportAddr);
268
269impl ConstrainedAddr {
270    /// Create a new constrained address from a transport address
271    pub fn new(addr: TransportAddr) -> Self {
272        Self(addr)
273    }
274
275    /// Get the underlying transport address
276    pub fn transport_addr(&self) -> &TransportAddr {
277        &self.0
278    }
279
280    /// Consume self and return the underlying transport address
281    pub fn into_transport_addr(self) -> TransportAddr {
282        self.0
283    }
284
285    /// Check if this address supports the constrained protocol
286    ///
287    /// Constrained protocol is used for bandwidth-limited transports like BLE and LoRa.
288    pub fn is_constrained_transport(&self) -> bool {
289        matches!(
290            self.0,
291            TransportAddr::Ble { .. }
292                | TransportAddr::LoRa { .. }
293                | TransportAddr::Serial { .. }
294                | TransportAddr::Ax25 { .. }
295        )
296    }
297}
298
299impl From<TransportAddr> for ConstrainedAddr {
300    fn from(addr: TransportAddr) -> Self {
301        Self(addr)
302    }
303}
304
305impl From<ConstrainedAddr> for TransportAddr {
306    fn from(addr: ConstrainedAddr) -> Self {
307        addr.0
308    }
309}
310
311impl From<SocketAddr> for ConstrainedAddr {
312    fn from(addr: SocketAddr) -> Self {
313        Self(TransportAddr::Udp(addr))
314    }
315}
316
317impl fmt::Display for ConstrainedAddr {
318    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
319        write!(f, "{}", self.0)
320    }
321}
322
323/// Errors that can occur in the constrained protocol
324#[derive(Debug, Clone, Error)]
325pub enum ConstrainedError {
326    /// Packet too small to contain header
327    #[error("packet too small: expected at least {expected} bytes, got {actual}")]
328    PacketTooSmall {
329        /// Minimum expected size in bytes
330        expected: usize,
331        /// Actual size received
332        actual: usize,
333    },
334
335    /// Invalid header format
336    #[error("invalid header: {0}")]
337    InvalidHeader(String),
338
339    /// Connection not found
340    #[error("connection not found: {0}")]
341    ConnectionNotFound(ConnectionId),
342
343    /// Connection already exists
344    #[error("connection already exists: {0}")]
345    ConnectionExists(ConnectionId),
346
347    /// Invalid state transition
348    #[error("invalid state transition from {from} to {to}")]
349    InvalidStateTransition {
350        /// Current state name
351        from: String,
352        /// Attempted target state
353        to: String,
354    },
355
356    /// Connection reset by peer
357    #[error("connection reset by peer")]
358    ConnectionReset,
359
360    /// Connection timed out
361    #[error("connection timed out")]
362    Timeout,
363
364    /// Maximum retransmissions exceeded
365    #[error("maximum retransmissions exceeded ({count})")]
366    MaxRetransmissions {
367        /// Number of retransmissions attempted
368        count: u32,
369    },
370
371    /// Send buffer full
372    #[error("send buffer full")]
373    SendBufferFull,
374
375    /// Receive buffer full
376    #[error("receive buffer full")]
377    ReceiveBufferFull,
378
379    /// Transport error
380    #[error("transport error: {0}")]
381    Transport(String),
382
383    /// Sequence number out of window
384    #[error("sequence number {seq} out of window (expected {expected_min}-{expected_max})")]
385    SequenceOutOfWindow {
386        /// Received sequence number
387        seq: u8,
388        /// Minimum expected sequence number
389        expected_min: u8,
390        /// Maximum expected sequence number
391        expected_max: u8,
392    },
393
394    /// Connection closed
395    #[error("connection closed")]
396    ConnectionClosed,
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_constrained_addr_from_transport() {
        let ble_addr = TransportAddr::Ble {
            device_id: [0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF],
            service_uuid: None,
        };
        let constrained = ConstrainedAddr::from(ble_addr.clone());
        assert!(constrained.is_constrained_transport());
        assert_eq!(*constrained.transport_addr(), ble_addr);
    }

    #[test]
    fn test_constrained_addr_from_socket() {
        let socket: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let constrained = ConstrainedAddr::from(socket);
        assert!(!constrained.is_constrained_transport());
        assert_eq!(
            *constrained.transport_addr(),
            TransportAddr::Udp("127.0.0.1:8080".parse().unwrap())
        );
    }

    #[test]
    fn test_constrained_addr_into_transport() {
        let ble_addr = TransportAddr::Ble {
            device_id: [0x11, 0x22, 0x33, 0x44, 0x55, 0x66],
            service_uuid: None,
        };
        let constrained = ConstrainedAddr::new(ble_addr.clone());
        let back: TransportAddr = constrained.into();
        assert_eq!(back, ble_addr);
    }

    #[test]
    fn test_constrained_addr_transport_detection() {
        // BLE is constrained
        let ble = ConstrainedAddr::new(TransportAddr::Ble {
            device_id: [0; 6],
            service_uuid: None,
        });
        assert!(ble.is_constrained_transport());

        // LoRa is constrained
        let lora = ConstrainedAddr::new(TransportAddr::LoRa {
            device_addr: [0; 4],
            params: crate::transport::LoRaParams::default(),
        });
        assert!(lora.is_constrained_transport());

        // UDP is not constrained (uses QUIC)
        let udp = ConstrainedAddr::new(TransportAddr::Udp("0.0.0.0:0".parse().unwrap()));
        assert!(!udp.is_constrained_transport());
    }

    #[test]
    fn test_connection_id() {
        let cid = ConnectionId::new(0x1234);
        assert_eq!(cid.value(), 0x1234);
        assert_eq!(cid.to_bytes(), [0x12, 0x34]);
        assert_eq!(ConnectionId::from_bytes([0x12, 0x34]), cid);
    }

    #[test]
    fn test_connection_id_display() {
        let cid = ConnectionId::new(0xABCD);
        assert_eq!(format!("{}", cid), "CID:ABCD");
    }

    #[test]
    fn test_connection_id_random() {
        // Any u16 (including 0) is a legal ID, so a single draw cannot be checked.
        // Check instead that repeated draws do not all collapse to one value.
        let ids: HashSet<u16> = (0..256).map(|_| ConnectionId::random().value()).collect();
        assert!(
            ids.len() > 1,
            "ConnectionId::random() returned the same value for 256 draws"
        );
    }

    #[test]
    fn test_sequence_number_next() {
        assert_eq!(SequenceNumber::new(0).next(), SequenceNumber::new(1));
        assert_eq!(SequenceNumber::new(254).next(), SequenceNumber::new(255));
        assert_eq!(SequenceNumber::new(255).next(), SequenceNumber::new(0));
    }

    #[test]
    fn test_sequence_number_distance() {
        let a = SequenceNumber::new(10);
        let b = SequenceNumber::new(15);
        assert_eq!(a.distance_to(b), 5);
        assert_eq!(b.distance_to(a), -5);

        // Wrap-around case
        let x = SequenceNumber::new(250);
        let y = SequenceNumber::new(5);
        assert_eq!(x.distance_to(y), 11); // 5 is 11 ahead of 250 (wrapping)
    }

    #[test]
    fn test_sequence_number_distance_behind_across_wrap() {
        assert_eq!(
            SequenceNumber::new(5).distance_to(SequenceNumber::new(250)),
            -11
        );
        // Exactly half the sequence space away is reported as "behind" (i8 range).
        assert_eq!(
            SequenceNumber::new(0).distance_to(SequenceNumber::new(127)),
            127
        );
        assert_eq!(
            SequenceNumber::new(0).distance_to(SequenceNumber::new(128)),
            -128
        );
    }

    #[test]
    fn test_sequence_number_in_window() {
        let base = SequenceNumber::new(100);
        assert!(base.is_in_window(SequenceNumber::new(100), 16));
        assert!(base.is_in_window(SequenceNumber::new(110), 16));
        assert!(base.is_in_window(SequenceNumber::new(116), 16));
        assert!(!base.is_in_window(SequenceNumber::new(117), 16));
        assert!(!base.is_in_window(SequenceNumber::new(99), 16));
    }

    #[test]
    fn test_sequence_number_window_across_wrap() {
        let base = SequenceNumber::new(250);
        assert!(base.is_in_window(SequenceNumber::new(255), 16));
        assert!(base.is_in_window(SequenceNumber::new(0), 16));
        assert!(base.is_in_window(SequenceNumber::new(10), 16));
        assert!(!base.is_in_window(SequenceNumber::new(11), 16));
        assert!(!base.is_in_window(SequenceNumber::new(249), 16));
    }

    #[test]
    fn test_sequence_number_display() {
        assert_eq!(format!("{}", SequenceNumber::new(42)), "SEQ:42");
    }

    #[test]
    fn test_packet_flags() {
        let flags = PacketFlags::SYN;
        assert!(flags.is_syn());
        assert!(!flags.is_ack());

        let syn_ack = flags.with(PacketType::Ack);
        assert!(syn_ack.is_syn());
        assert!(syn_ack.is_ack());
        assert_eq!(syn_ack, PacketFlags::SYN_ACK);
    }

    #[test]
    fn test_packet_flags_match_packet_types() {
        let pairs = [
            (PacketFlags::SYN, PacketType::Syn),
            (PacketFlags::ACK, PacketType::Ack),
            (PacketFlags::FIN, PacketType::Fin),
            (PacketFlags::RST, PacketType::Reset),
            (PacketFlags::DATA, PacketType::Data),
            (PacketFlags::PING, PacketType::Ping),
            (PacketFlags::PONG, PacketType::Pong),
        ];
        for &(flags, ty) in pairs.iter() {
            assert_eq!(flags.value(), ty.flag());
            assert!(flags.has(ty));
            assert_eq!(PacketFlags::NONE.with(ty), flags);
        }
    }

    #[test]
    fn test_packet_flags_display() {
        assert_eq!(format!("{}", PacketFlags::NONE), "NONE");
        assert_eq!(format!("{}", PacketFlags::SYN), "SYN");
        assert_eq!(format!("{}", PacketFlags::SYN_ACK), "SYN|ACK");
        assert_eq!(
            format!("{}", PacketFlags::DATA.with(PacketType::Ack)),
            "ACK|DATA"
        );
        assert_eq!(
            format!("{}", PacketFlags::new(0x7F)),
            "SYN|ACK|FIN|RST|DATA|PING|PONG"
        );
    }

    #[test]
    fn test_packet_flags_union() {
        let a = PacketFlags::SYN;
        let b = PacketFlags::DATA;
        let combined = a.union(b);
        assert!(combined.is_syn());
        assert!(combined.is_data());
        assert!(!combined.is_ack());
    }

    #[test]
    fn test_constrained_error_display() {
        let err = ConstrainedError::PacketTooSmall {
            expected: 5,
            actual: 3,
        };
        assert!(format!("{}", err).contains("expected at least 5 bytes"));

        let err = ConstrainedError::ConnectionNotFound(ConnectionId::new(0x1234));
        assert!(format!("{}", err).contains("CID:1234"));
    }

    #[test]
    fn test_constrained_error_display_fields() {
        let err = ConstrainedError::SequenceOutOfWindow {
            seq: 200,
            expected_min: 10,
            expected_max: 26,
        };
        assert_eq!(
            err.to_string(),
            "sequence number 200 out of window (expected 10-26)"
        );

        let err = ConstrainedError::MaxRetransmissions { count: 5 };
        assert_eq!(err.to_string(), "maximum retransmissions exceeded (5)");
    }
}