nomad_protocol/extensions/
rate_hints.rs1use super::negotiation::{ext_type, Extension, NegotiationError};
27use std::time::Duration;
28
/// Bit masks for the `flags` byte of `RateHintsConfig`.
pub mod rate_hint_flags {
    /// Bit 0 — the mask checked by `RateHintsConfig::supports_dynamic`.
    pub const DYNAMIC_HINTS: u8 = 1 << 0;
    /// Bit 1 — the mask checked by `RateHintsConfig::supports_per_region`.
    pub const PER_REGION_HINTS: u8 = 1 << 1;
}
36
/// Rate-hint parameters carried in the rate-hints extension.
///
/// On the wire this is five bytes: `flags`, then `target_rate_x10` and
/// `burst_allowance`, each as a little-endian `u16`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RateHintsConfig {
    /// Bit set built from the `rate_hint_flags` constants.
    pub flags: u8,
    /// Target rate in updates per second, multiplied by ten.
    pub target_rate_x10: u16,
    /// Burst allowance; the constructors and `negotiate` only copy this
    /// value around, they never interpret it.
    pub burst_allowance: u16,
}
47
48impl Default for RateHintsConfig {
49 fn default() -> Self {
50 Self {
51 flags: rate_hint_flags::DYNAMIC_HINTS,
52 target_rate_x10: 100, burst_allowance: 20,
54 }
55 }
56}
57
58impl RateHintsConfig {
59 pub fn with_rate(updates_per_second: f32) -> Self {
61 Self {
62 flags: rate_hint_flags::DYNAMIC_HINTS,
63 target_rate_x10: (updates_per_second * 10.0) as u16,
64 burst_allowance: (updates_per_second * 2.0) as u16,
65 }
66 }
67
68 pub fn target_rate(&self) -> f32 {
70 self.target_rate_x10 as f32 / 10.0
71 }
72
73 pub fn supports_dynamic(&self) -> bool {
75 (self.flags & rate_hint_flags::DYNAMIC_HINTS) != 0
76 }
77
78 pub fn supports_per_region(&self) -> bool {
80 (self.flags & rate_hint_flags::PER_REGION_HINTS) != 0
81 }
82
83 pub const fn wire_size() -> usize {
85 5 }
87
88 pub fn to_extension(&self) -> Extension {
90 let mut data = Vec::with_capacity(Self::wire_size());
91 data.push(self.flags);
92 data.extend_from_slice(&self.target_rate_x10.to_le_bytes());
93 data.extend_from_slice(&self.burst_allowance.to_le_bytes());
94 Extension::new(ext_type::RATE_HINTS, data)
95 }
96
97 pub fn from_extension(ext: &Extension) -> Option<Self> {
99 if ext.ext_type != ext_type::RATE_HINTS || ext.data.len() < Self::wire_size() {
100 return None;
101 }
102 Some(Self {
103 flags: ext.data[0],
104 target_rate_x10: u16::from_le_bytes([ext.data[1], ext.data[2]]),
105 burst_allowance: u16::from_le_bytes([ext.data[3], ext.data[4]]),
106 })
107 }
108
109 pub fn negotiate(client: &Self, server: &Self) -> Self {
111 Self {
112 flags: client.flags & server.flags,
114 target_rate_x10: client.target_rate_x10.min(server.target_rate_x10),
116 burst_allowance: client.burst_allowance.min(server.burst_allowance),
117 }
118 }
119}
120
/// Kind of a rate hint; its `u8` value is the first byte of an encoded hint.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum RateHintType {
    /// Hint without a region id (wire value `0x00`).
    Global = 0x00,
    /// Hint that carries a region id (wire value `0x01`).
    Region = 0x01,
}

impl RateHintType {
    /// Maps a wire byte back to its variant, or `None` for an unknown value.
    pub fn from_byte(b: u8) -> Option<Self> {
        // Every variant, looked up by its `repr(u8)` discriminant.
        const KNOWN: [RateHintType; 2] = [RateHintType::Global, RateHintType::Region];
        KNOWN.iter().copied().find(|kind| *kind as u8 == b)
    }
}
141
142#[derive(Debug, Clone, PartialEq, Eq)]
144pub struct RateHint {
145 pub hint_type: RateHintType,
147 pub target_rate_x10: u16,
149 pub burst_allowance: u16,
151 pub duration_secs: u16,
153 pub region_id: Option<u32>,
155}
156
157impl RateHint {
158 pub fn global(rate: f32, burst: u16, duration: Duration) -> Self {
160 Self {
161 hint_type: RateHintType::Global,
162 target_rate_x10: (rate * 10.0) as u16,
163 burst_allowance: burst,
164 duration_secs: duration.as_secs().min(u16::MAX as u64) as u16,
165 region_id: None,
166 }
167 }
168
169 pub fn region(region_id: u32, rate: f32, burst: u16, duration: Duration) -> Self {
171 Self {
172 hint_type: RateHintType::Region,
173 target_rate_x10: (rate * 10.0) as u16,
174 burst_allowance: burst,
175 duration_secs: duration.as_secs().min(u16::MAX as u64) as u16,
176 region_id: Some(region_id),
177 }
178 }
179
180 pub fn target_rate(&self) -> f32 {
182 self.target_rate_x10 as f32 / 10.0
183 }
184
185 pub fn duration(&self) -> Option<Duration> {
187 if self.duration_secs == 0 {
188 None
189 } else {
190 Some(Duration::from_secs(self.duration_secs as u64))
191 }
192 }
193
194 pub fn wire_size(&self) -> usize {
196 match self.hint_type {
197 RateHintType::Global => 7, RateHintType::Region => 11, }
200 }
201
202 pub fn encode(&self) -> Vec<u8> {
204 let mut buf = Vec::with_capacity(self.wire_size());
205 buf.push(self.hint_type as u8);
206 buf.extend_from_slice(&self.target_rate_x10.to_le_bytes());
207 buf.extend_from_slice(&self.burst_allowance.to_le_bytes());
208 buf.extend_from_slice(&self.duration_secs.to_le_bytes());
209
210 if let Some(region_id) = self.region_id {
211 buf.extend_from_slice(®ion_id.to_le_bytes());
212 }
213
214 buf
215 }
216
217 pub fn decode(data: &[u8]) -> Result<Self, NegotiationError> {
219 if data.len() < 7 {
220 return Err(NegotiationError::TooShort {
221 expected: 7,
222 actual: data.len(),
223 });
224 }
225
226 let hint_type = RateHintType::from_byte(data[0]).ok_or(NegotiationError::InvalidData)?;
227 let target_rate_x10 = u16::from_le_bytes([data[1], data[2]]);
228 let burst_allowance = u16::from_le_bytes([data[3], data[4]]);
229 let duration_secs = u16::from_le_bytes([data[5], data[6]]);
230
231 let region_id = if hint_type == RateHintType::Region {
232 if data.len() < 11 {
233 return Err(NegotiationError::TooShort {
234 expected: 11,
235 actual: data.len(),
236 });
237 }
238 Some(u32::from_le_bytes([data[7], data[8], data[9], data[10]]))
239 } else {
240 None
241 };
242
243 Ok(Self {
244 hint_type,
245 target_rate_x10,
246 burst_allowance,
247 duration_secs,
248 region_id,
249 })
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::rate_hint_flags::{DYNAMIC_HINTS, PER_REGION_HINTS};
    use super::*;

    /// The default configuration is dynamic-only at 10 updates per second.
    #[test]
    fn test_config_default() {
        let defaults = RateHintsConfig::default();
        assert_eq!(defaults.target_rate(), 10.0);
        assert!(defaults.supports_dynamic());
        assert!(!defaults.supports_per_region());
    }

    /// `with_rate` stores the rate in tenths and `target_rate` converts back.
    #[test]
    fn test_config_with_rate() {
        let from_rate = RateHintsConfig::with_rate(5.5);
        assert_eq!(from_rate.target_rate_x10, 55);
        assert_eq!(from_rate.target_rate(), 5.5);
    }

    /// A configuration survives a trip through `Extension` unchanged.
    #[test]
    fn test_config_extension_roundtrip() {
        let config = RateHintsConfig {
            flags: DYNAMIC_HINTS | PER_REGION_HINTS,
            target_rate_x10: 150,
            burst_allowance: 30,
        };

        let decoded = RateHintsConfig::from_extension(&config.to_extension()).unwrap();
        assert_eq!(decoded, config);
    }

    /// Negotiation intersects the flags and takes the lower rate and burst.
    #[test]
    fn test_config_negotiate() {
        let client = RateHintsConfig {
            flags: DYNAMIC_HINTS | PER_REGION_HINTS,
            target_rate_x10: 200,
            burst_allowance: 50,
        };
        // The server offers dynamic hints only, at a lower rate and burst.
        let server = RateHintsConfig {
            flags: DYNAMIC_HINTS,
            target_rate_x10: 100,
            burst_allowance: 20,
        };

        let agreed = RateHintsConfig::negotiate(&client, &server);
        assert!(agreed.supports_dynamic());
        // Per-region support was offered by the client alone, so it is dropped.
        assert!(!agreed.supports_per_region());
        assert_eq!(agreed.target_rate_x10, 100);
        assert_eq!(agreed.burst_allowance, 20);
    }

    /// A global hint encodes to seven bytes and decodes to the same values.
    #[test]
    fn test_global_hint_roundtrip() {
        let wire = RateHint::global(5.0, 10, Duration::from_secs(60)).encode();
        assert_eq!(wire.len(), 7);

        let parsed = RateHint::decode(&wire).unwrap();
        assert_eq!(parsed.hint_type, RateHintType::Global);
        assert_eq!(parsed.target_rate(), 5.0);
        assert_eq!(parsed.burst_allowance, 10);
        assert_eq!(parsed.duration(), Some(Duration::from_secs(60)));
        assert_eq!(parsed.region_id, None);
    }

    /// A region hint encodes to eleven bytes and keeps its region id.
    #[test]
    fn test_region_hint_roundtrip() {
        let wire = RateHint::region(42, 2.5, 5, Duration::from_secs(120)).encode();
        assert_eq!(wire.len(), 11);

        let parsed = RateHint::decode(&wire).unwrap();
        assert_eq!(parsed.hint_type, RateHintType::Region);
        assert_eq!(parsed.target_rate(), 2.5);
        assert_eq!(parsed.region_id, Some(42));
    }

    /// A zero duration is stored as `0` and reported as `None`.
    #[test]
    fn test_indefinite_duration() {
        let unbounded = RateHint::global(10.0, 20, Duration::ZERO);
        assert_eq!(unbounded.duration_secs, 0);
        assert_eq!(unbounded.duration(), None);
    }

    /// Inputs that end too early are rejected with `TooShort`.
    #[test]
    fn test_decode_truncated() {
        // Six bytes: one short of the seven-byte fixed part.
        let short_header: [u8; 6] = [0, 1, 2, 3, 4, 5];
        let outcome = RateHint::decode(&short_header);
        assert!(matches!(outcome, Err(NegotiationError::TooShort { .. })));

        // Region type byte, but the four region-id bytes are missing.
        let region_without_id: [u8; 7] = [0x01, 1, 2, 3, 4, 5, 6];
        let outcome = RateHint::decode(&region_without_id);
        assert!(matches!(outcome, Err(NegotiationError::TooShort { .. })));
    }

    /// An unknown type byte is rejected with `InvalidData`.
    #[test]
    fn test_invalid_hint_type() {
        let unknown_type: [u8; 7] = [0xFF, 1, 2, 3, 4, 5, 6];
        let outcome = RateHint::decode(&unknown_type);
        assert!(matches!(outcome, Err(NegotiationError::InvalidData)));
    }
}