// nomad_protocol/extensions/rate_hints.rs

//! Rate Hints extension (0x0004)
//!
//! Allows the server to hint acceptable update frequencies to clients,
//! enabling adaptive rate limiting without hard rejections.
//!
//! Wire format for extension negotiation (5 bytes total):
//! ```text
//! +0  Flags (1 byte)
//!     - bit 0: Dynamic hints supported (server may send hints mid-session)
//!     - bit 1: Per-region hints supported
//! +1  Initial target rate (2 bytes LE16) - updates per second * 10 (e.g. 15 = 1.5/s)
//! +3  Initial burst allowance (2 bytes LE16) - maximum burst size
//! ```
//!
//! Wire format for a dynamic hint sent in-band (7 bytes for a global hint,
//! 11 bytes for a region-specific hint):
//! ```text
//! +0  Hint type (1 byte)
//!     - 0x00: Global rate hint
//!     - 0x01: Region-specific hint (followed by region ID)
//! +1  Target rate (2 bytes LE16) - updates per second * 10
//! +3  Burst allowance (2 bytes LE16)
//! +5  Duration hint (2 bytes LE16) - suggested duration in seconds (0 = indefinite)
//! +7  [Optional] Region ID (4 bytes LE32) - present only if hint type is 0x01
//! ```

26use super::negotiation::{ext_type, Extension, NegotiationError};
27use std::time::Duration;
28
/// Bit flags carried in [`RateHintsConfig::flags`].
///
/// Flags are combined with bitwise OR. During negotiation only the bits set
/// by both peers are kept.
pub mod rate_hint_flags {
    /// Bit 0: the server may push dynamic rate hints mid-session.
    pub const DYNAMIC_HINTS: u8 = 1 << 0;
    /// Bit 1: the server may push hints scoped to a single region.
    pub const PER_REGION_HINTS: u8 = 1 << 1;
}
36
/// Rate-hint parameters exchanged during extension negotiation.
///
/// On the wire this is 5 bytes: `flags`, then `target_rate_x10` and
/// `burst_allowance` as little-endian `u16`s.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RateHintsConfig {
    /// Bitwise OR of `rate_hint_flags` constants.
    pub flags: u8,
    /// Initial target rate in tenths of an update per second (15 means 1.5/s).
    pub target_rate_x10: u16,
    /// Initial burst allowance (maximum burst size).
    pub burst_allowance: u16,
}
47
48impl Default for RateHintsConfig {
49    fn default() -> Self {
50        Self {
51            flags: rate_hint_flags::DYNAMIC_HINTS,
52            target_rate_x10: 100, // 10 updates/second
53            burst_allowance: 20,
54        }
55    }
56}
57
58impl RateHintsConfig {
59    /// Create config with specific rate
60    pub fn with_rate(updates_per_second: f32) -> Self {
61        Self {
62            flags: rate_hint_flags::DYNAMIC_HINTS,
63            target_rate_x10: (updates_per_second * 10.0) as u16,
64            burst_allowance: (updates_per_second * 2.0) as u16,
65        }
66    }
67
68    /// Get target rate as updates per second
69    pub fn target_rate(&self) -> f32 {
70        self.target_rate_x10 as f32 / 10.0
71    }
72
73    /// Check if dynamic hints are supported
74    pub fn supports_dynamic(&self) -> bool {
75        (self.flags & rate_hint_flags::DYNAMIC_HINTS) != 0
76    }
77
78    /// Check if per-region hints are supported
79    pub fn supports_per_region(&self) -> bool {
80        (self.flags & rate_hint_flags::PER_REGION_HINTS) != 0
81    }
82
83    /// Wire size of config
84    pub const fn wire_size() -> usize {
85        5 // flags (1) + target_rate (2) + burst (2)
86    }
87
88    /// Encode to extension
89    pub fn to_extension(&self) -> Extension {
90        let mut data = Vec::with_capacity(Self::wire_size());
91        data.push(self.flags);
92        data.extend_from_slice(&self.target_rate_x10.to_le_bytes());
93        data.extend_from_slice(&self.burst_allowance.to_le_bytes());
94        Extension::new(ext_type::RATE_HINTS, data)
95    }
96
97    /// Decode from extension
98    pub fn from_extension(ext: &Extension) -> Option<Self> {
99        if ext.ext_type != ext_type::RATE_HINTS || ext.data.len() < Self::wire_size() {
100            return None;
101        }
102        Some(Self {
103            flags: ext.data[0],
104            target_rate_x10: u16::from_le_bytes([ext.data[1], ext.data[2]]),
105            burst_allowance: u16::from_le_bytes([ext.data[3], ext.data[4]]),
106        })
107    }
108
109    /// Negotiate between client and server configs
110    pub fn negotiate(client: &Self, server: &Self) -> Self {
111        Self {
112            // Only enable features both sides support
113            flags: client.flags & server.flags,
114            // Use the more restrictive rate
115            target_rate_x10: client.target_rate_x10.min(server.target_rate_x10),
116            burst_allowance: client.burst_allowance.min(server.burst_allowance),
117        }
118    }
119}
120
/// Scope of a dynamic rate hint, encoded as the first byte of the hint on the wire.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RateHintType {
    /// Type byte `0x00`: the hint applies to all updates.
    Global = 0x00,
    /// Type byte `0x01`: the hint applies to one region; a region ID follows on the wire.
    Region = 0x01,
}
130
131impl RateHintType {
132    /// Convert from byte
133    pub fn from_byte(b: u8) -> Option<Self> {
134        match b {
135            0x00 => Some(Self::Global),
136            0x01 => Some(Self::Region),
137            _ => None,
138        }
139    }
140}
141
142/// A dynamic rate hint sent during the session
143#[derive(Debug, Clone, PartialEq, Eq)]
144pub struct RateHint {
145    /// Type of hint
146    pub hint_type: RateHintType,
147    /// Target rate (updates per second * 10)
148    pub target_rate_x10: u16,
149    /// Burst allowance
150    pub burst_allowance: u16,
151    /// Suggested duration (0 = indefinite)
152    pub duration_secs: u16,
153    /// Region ID (only for Region type hints)
154    pub region_id: Option<u32>,
155}
156
157impl RateHint {
158    /// Create a global rate hint
159    pub fn global(rate: f32, burst: u16, duration: Duration) -> Self {
160        Self {
161            hint_type: RateHintType::Global,
162            target_rate_x10: (rate * 10.0) as u16,
163            burst_allowance: burst,
164            duration_secs: duration.as_secs().min(u16::MAX as u64) as u16,
165            region_id: None,
166        }
167    }
168
169    /// Create a region-specific rate hint
170    pub fn region(region_id: u32, rate: f32, burst: u16, duration: Duration) -> Self {
171        Self {
172            hint_type: RateHintType::Region,
173            target_rate_x10: (rate * 10.0) as u16,
174            burst_allowance: burst,
175            duration_secs: duration.as_secs().min(u16::MAX as u64) as u16,
176            region_id: Some(region_id),
177        }
178    }
179
180    /// Get target rate as updates per second
181    pub fn target_rate(&self) -> f32 {
182        self.target_rate_x10 as f32 / 10.0
183    }
184
185    /// Get duration (None if indefinite)
186    pub fn duration(&self) -> Option<Duration> {
187        if self.duration_secs == 0 {
188            None
189        } else {
190            Some(Duration::from_secs(self.duration_secs as u64))
191        }
192    }
193
194    /// Wire size
195    pub fn wire_size(&self) -> usize {
196        match self.hint_type {
197            RateHintType::Global => 7,  // type(1) + rate(2) + burst(2) + duration(2)
198            RateHintType::Region => 11, // + region_id(4)
199        }
200    }
201
202    /// Encode to bytes
203    pub fn encode(&self) -> Vec<u8> {
204        let mut buf = Vec::with_capacity(self.wire_size());
205        buf.push(self.hint_type as u8);
206        buf.extend_from_slice(&self.target_rate_x10.to_le_bytes());
207        buf.extend_from_slice(&self.burst_allowance.to_le_bytes());
208        buf.extend_from_slice(&self.duration_secs.to_le_bytes());
209
210        if let Some(region_id) = self.region_id {
211            buf.extend_from_slice(&region_id.to_le_bytes());
212        }
213
214        buf
215    }
216
217    /// Decode from bytes
218    pub fn decode(data: &[u8]) -> Result<Self, NegotiationError> {
219        if data.len() < 7 {
220            return Err(NegotiationError::TooShort {
221                expected: 7,
222                actual: data.len(),
223            });
224        }
225
226        let hint_type = RateHintType::from_byte(data[0]).ok_or(NegotiationError::InvalidData)?;
227        let target_rate_x10 = u16::from_le_bytes([data[1], data[2]]);
228        let burst_allowance = u16::from_le_bytes([data[3], data[4]]);
229        let duration_secs = u16::from_le_bytes([data[5], data[6]]);
230
231        let region_id = if hint_type == RateHintType::Region {
232            if data.len() < 11 {
233                return Err(NegotiationError::TooShort {
234                    expected: 11,
235                    actual: data.len(),
236                });
237            }
238            Some(u32::from_le_bytes([data[7], data[8], data[9], data[10]]))
239        } else {
240            None
241        };
242
243        Ok(Self {
244            hint_type,
245            target_rate_x10,
246            burst_allowance,
247            duration_secs,
248            region_id,
249        })
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `hint`, assert the frame is `expected_len` bytes, and decode it back.
    fn encode_then_decode(hint: &RateHint, expected_len: usize) -> RateHint {
        let frame = hint.encode();
        assert_eq!(frame.len(), expected_len);
        RateHint::decode(&frame).unwrap()
    }

    #[test]
    fn test_config_default() {
        let cfg = RateHintsConfig::default();
        assert_eq!(cfg.target_rate(), 10.0);
        assert!(cfg.supports_dynamic());
        assert!(!cfg.supports_per_region());
    }

    #[test]
    fn test_config_with_rate() {
        let cfg = RateHintsConfig::with_rate(5.5);
        assert_eq!(cfg.target_rate_x10, 55);
        assert_eq!(cfg.target_rate(), 5.5);
    }

    #[test]
    fn test_config_extension_roundtrip() {
        let both_flags = rate_hint_flags::DYNAMIC_HINTS | rate_hint_flags::PER_REGION_HINTS;
        let cfg = RateHintsConfig {
            flags: both_flags,
            target_rate_x10: 150,
            burst_allowance: 30,
        };

        let restored = RateHintsConfig::from_extension(&cfg.to_extension());
        assert_eq!(restored, Some(cfg));
    }

    #[test]
    fn test_config_negotiate() {
        let client = RateHintsConfig {
            flags: rate_hint_flags::DYNAMIC_HINTS | rate_hint_flags::PER_REGION_HINTS,
            target_rate_x10: 200,
            burst_allowance: 50,
        };
        // The server lacks per-region support and asks for lower limits.
        let server = RateHintsConfig {
            flags: rate_hint_flags::DYNAMIC_HINTS,
            target_rate_x10: 100,
            burst_allowance: 20,
        };

        let agreed = RateHintsConfig::negotiate(&client, &server);
        assert!(agreed.supports_dynamic());
        assert!(!agreed.supports_per_region());
        assert_eq!(agreed.target_rate_x10, 100);
        assert_eq!(agreed.burst_allowance, 20);
    }

    #[test]
    fn test_global_hint_roundtrip() {
        let sent = RateHint::global(5.0, 10, Duration::from_secs(60));
        let got = encode_then_decode(&sent, 7);

        assert_eq!(got.hint_type, RateHintType::Global);
        assert_eq!(got.target_rate(), 5.0);
        assert_eq!(got.burst_allowance, 10);
        assert_eq!(got.duration(), Some(Duration::from_secs(60)));
        assert_eq!(got.region_id, None);
    }

    #[test]
    fn test_region_hint_roundtrip() {
        let sent = RateHint::region(42, 2.5, 5, Duration::from_secs(120));
        let got = encode_then_decode(&sent, 11);

        assert_eq!(got.hint_type, RateHintType::Region);
        assert_eq!(got.target_rate(), 2.5);
        assert_eq!(got.region_id, Some(42));
    }

    #[test]
    fn test_indefinite_duration() {
        let hint = RateHint::global(10.0, 20, Duration::ZERO);
        assert_eq!(hint.duration_secs, 0);
        assert_eq!(hint.duration(), None);
    }

    #[test]
    fn test_decode_truncated() {
        // First frame: one byte short of the 7-byte fixed header.
        // Second frame: a region hint (type 0x01) missing its 4-byte region ID.
        let frames: [&[u8]; 2] = [&[0, 1, 2, 3, 4, 5], &[0x01, 1, 2, 3, 4, 5, 6]];
        for &frame in frames.iter() {
            let outcome = RateHint::decode(frame);
            assert!(matches!(outcome, Err(NegotiationError::TooShort { .. })));
        }
    }

    #[test]
    fn test_invalid_hint_type() {
        let unknown_type = [0xFF, 1, 2, 3, 4, 5, 6];
        let outcome = RateHint::decode(&unknown_type);
        assert!(matches!(outcome, Err(NegotiationError::InvalidData)));
    }
}