// nomad_protocol/extensions/batching.rs
use super::negotiation::{ext_type, Extension, NegotiationError};
use std::time::Duration;
23
/// Default value for `BatchingConfig::max_batch_size` (updates per batch).
pub const DEFAULT_MAX_BATCH_SIZE: u16 = 32;

/// Default value for `BatchingConfig::max_batch_bytes`: 16 KiB.
pub const DEFAULT_MAX_BATCH_BYTES: u16 = 16 * 1024;

/// Default value for `BatchingConfig::max_delay_ms`, in milliseconds.
pub const DEFAULT_MAX_DELAY_MS: u16 = 50;
32
/// Limits that control how updates are grouped into a batch.
///
/// All three fields are `u16` so the configuration serializes to a fixed
/// six-byte extension payload.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BatchingConfig {
    /// Maximum number of updates a single batch may hold.
    pub max_batch_size: u16,
    /// Maximum accumulated batch size in bytes. Every update counts as its
    /// payload length plus a 2-byte length prefix.
    pub max_batch_bytes: u16,
    /// Maximum delay in milliseconds. Presumably how long a batch may be held
    /// before it is flushed; the flushing logic is not part of this file.
    pub max_delay_ms: u16,
}
43
44impl Default for BatchingConfig {
45 fn default() -> Self {
46 Self {
47 max_batch_size: DEFAULT_MAX_BATCH_SIZE,
48 max_batch_bytes: DEFAULT_MAX_BATCH_BYTES,
49 max_delay_ms: DEFAULT_MAX_DELAY_MS,
50 }
51 }
52}
53
54impl BatchingConfig {
55 pub fn low_latency() -> Self {
57 Self {
58 max_batch_size: 8,
59 max_batch_bytes: 4096,
60 max_delay_ms: 10,
61 }
62 }
63
64 pub fn high_throughput() -> Self {
66 Self {
67 max_batch_size: 128,
68 max_batch_bytes: 65535,
69 max_delay_ms: 100,
70 }
71 }
72
73 pub fn max_delay(&self) -> Duration {
75 Duration::from_millis(self.max_delay_ms as u64)
76 }
77
78 pub const fn wire_size() -> usize {
80 6 }
82
83 pub fn to_extension(&self) -> Extension {
85 let mut data = Vec::with_capacity(Self::wire_size());
86 data.extend_from_slice(&self.max_batch_size.to_le_bytes());
87 data.extend_from_slice(&self.max_batch_bytes.to_le_bytes());
88 data.extend_from_slice(&self.max_delay_ms.to_le_bytes());
89 Extension::new(ext_type::BATCHING, data)
90 }
91
92 pub fn from_extension(ext: &Extension) -> Option<Self> {
94 if ext.ext_type != ext_type::BATCHING || ext.data.len() < Self::wire_size() {
95 return None;
96 }
97 Some(Self {
98 max_batch_size: u16::from_le_bytes([ext.data[0], ext.data[1]]),
99 max_batch_bytes: u16::from_le_bytes([ext.data[2], ext.data[3]]),
100 max_delay_ms: u16::from_le_bytes([ext.data[4], ext.data[5]]),
101 })
102 }
103
104 pub fn negotiate(client: &Self, server: &Self) -> Self {
108 Self {
109 max_batch_size: client.max_batch_size.min(server.max_batch_size),
110 max_batch_bytes: client.max_batch_bytes.min(server.max_batch_bytes),
111 max_delay_ms: client.max_delay_ms.min(server.max_delay_ms),
112 }
113 }
114}
115
/// An ordered collection of opaque update payloads together with a running
/// count of the bytes they occupy once length-prefixed.
#[derive(Clone, Debug)]
pub struct Batch {
    /// Update payloads, in the order they were added or decoded.
    updates: Vec<Vec<u8>>,
    /// Sum over all updates of payload length plus the 2-byte length prefix.
    total_bytes: usize,
}
122
123impl Default for Batch {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
129impl Batch {
130 pub fn new() -> Self {
132 Self {
133 updates: Vec::new(),
134 total_bytes: 0,
135 }
136 }
137
138 pub fn with_capacity(capacity: usize) -> Self {
140 Self {
141 updates: Vec::with_capacity(capacity),
142 total_bytes: 0,
143 }
144 }
145
146 pub fn len(&self) -> usize {
148 self.updates.len()
149 }
150
151 pub fn is_empty(&self) -> bool {
153 self.updates.is_empty()
154 }
155
156 pub fn total_bytes(&self) -> usize {
158 self.total_bytes
159 }
160
161 pub fn try_add(&mut self, update: Vec<u8>, config: &BatchingConfig) -> bool {
165 let new_size = self.updates.len() + 1;
166 let new_bytes = self.total_bytes + update.len() + 2; if new_size > config.max_batch_size as usize
169 || new_bytes > config.max_batch_bytes as usize
170 {
171 return false;
172 }
173
174 self.total_bytes += update.len() + 2;
175 self.updates.push(update);
176 true
177 }
178
179 pub fn encode(&self) -> Vec<u8> {
181 let mut buf = Vec::with_capacity(2 + self.total_bytes);
182 buf.extend_from_slice(&(self.updates.len() as u16).to_le_bytes());
183
184 for update in &self.updates {
185 buf.extend_from_slice(&(update.len() as u16).to_le_bytes());
186 buf.extend_from_slice(update);
187 }
188
189 buf
190 }
191
192 pub fn decode(data: &[u8]) -> Result<Self, NegotiationError> {
194 if data.len() < 2 {
195 return Err(NegotiationError::TooShort {
196 expected: 2,
197 actual: data.len(),
198 });
199 }
200
201 let count = u16::from_le_bytes([data[0], data[1]]) as usize;
202 let mut offset = 2;
203 let mut updates = Vec::with_capacity(count);
204 let mut total_bytes = 0;
205
206 for _ in 0..count {
207 if offset + 2 > data.len() {
208 return Err(NegotiationError::TooShort {
209 expected: offset + 2,
210 actual: data.len(),
211 });
212 }
213
214 let len = u16::from_le_bytes([data[offset], data[offset + 1]]) as usize;
215 offset += 2;
216
217 if offset + len > data.len() {
218 return Err(NegotiationError::TooShort {
219 expected: offset + len,
220 actual: data.len(),
221 });
222 }
223
224 updates.push(data[offset..offset + len].to_vec());
225 total_bytes += len + 2;
226 offset += len;
227 }
228
229 Ok(Self {
230 updates,
231 total_bytes,
232 })
233 }
234
235 pub fn into_updates(self) -> Vec<Vec<u8>> {
237 self.updates
238 }
239
240 pub fn iter(&self) -> impl Iterator<Item = &[u8]> {
242 self.updates.iter().map(|v| v.as_slice())
243 }
244
245 pub fn clear(&mut self) {
247 self.updates.clear();
248 self.total_bytes = 0;
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` must mirror the module-level constants.
    #[test]
    fn test_config_default() {
        let cfg = BatchingConfig::default();
        assert_eq!(cfg.max_batch_size, DEFAULT_MAX_BATCH_SIZE);
        assert_eq!(cfg.max_batch_bytes, DEFAULT_MAX_BATCH_BYTES);
        assert_eq!(cfg.max_delay_ms, DEFAULT_MAX_DELAY_MS);
    }

    /// A configuration survives conversion to an extension and back.
    #[test]
    fn test_config_extension_roundtrip() {
        let original = BatchingConfig {
            max_batch_size: 100,
            max_batch_bytes: 8192,
            max_delay_ms: 25,
        };

        let extension = original.to_extension();
        assert_eq!(extension.ext_type, ext_type::BATCHING);

        let parsed = BatchingConfig::from_extension(&extension).unwrap();
        assert_eq!(parsed, original);
    }

    /// Negotiation picks the smaller value of every field.
    #[test]
    fn test_config_negotiate() {
        let client = BatchingConfig {
            max_batch_size: 64,
            max_batch_bytes: 32768,
            max_delay_ms: 100,
        };
        let server = BatchingConfig {
            max_batch_size: 32,
            max_batch_bytes: 16384,
            max_delay_ms: 50,
        };

        let agreed = BatchingConfig::negotiate(&client, &server);
        assert_eq!(agreed.max_batch_size, 32);
        assert_eq!(agreed.max_batch_bytes, 16384);
        assert_eq!(agreed.max_delay_ms, 50);
    }

    /// Updates added to a batch come back unchanged after encode + decode.
    #[test]
    fn test_batch_add_and_encode() {
        let config = BatchingConfig::default();
        let mut batch = Batch::new();

        assert!(batch.try_add(vec![1, 2, 3], &config));
        assert!(batch.try_add(vec![4, 5], &config));
        assert_eq!(batch.len(), 2);

        let wire = batch.encode();
        let decoded = Batch::decode(&wire).unwrap();

        assert_eq!(decoded.len(), 2);
        let mut payloads = decoded.iter();
        assert_eq!(payloads.next(), Some(&[1u8, 2, 3][..]));
        assert_eq!(payloads.next(), Some(&[4u8, 5][..]));
    }

    /// The update-count limit rejects the third update.
    #[test]
    fn test_batch_size_limit() {
        let config = BatchingConfig {
            max_batch_size: 2,
            max_batch_bytes: 1000,
            max_delay_ms: 50,
        };
        let mut batch = Batch::new();

        assert!(batch.try_add(vec![1], &config));
        assert!(batch.try_add(vec![2], &config));
        // Third update would exceed max_batch_size.
        assert!(!batch.try_add(vec![3], &config));
        assert_eq!(batch.len(), 2);
    }

    /// The byte limit rejects an update that would push the total past it.
    #[test]
    fn test_batch_bytes_limit() {
        let config = BatchingConfig {
            max_batch_size: 100,
            max_batch_bytes: 10,
            max_delay_ms: 50,
        };
        let mut batch = Batch::new();

        // 3 payload bytes + 2 prefix bytes = 5, within the limit of 10.
        assert!(batch.try_add(vec![1, 2, 3], &config));
        // 5 + (6 + 2) = 13, over the limit.
        assert!(!batch.try_add(vec![1, 2, 3, 4, 5, 6], &config));
        assert_eq!(batch.len(), 1);
    }

    /// An empty batch encodes and decodes to an empty batch.
    #[test]
    fn test_batch_empty() {
        let batch = Batch::new();
        assert!(batch.is_empty());
        assert_eq!(batch.len(), 0);

        let wire = batch.encode();
        let decoded = Batch::decode(&wire).unwrap();
        assert!(decoded.is_empty());
    }

    /// Truncated input is reported as `TooShort` at every stage of parsing.
    #[test]
    fn test_batch_decode_truncated() {
        let truncated: [&[u8]; 3] = [
            // Count header incomplete.
            &[0],
            // One update declared, but no length prefix follows.
            &[1, 0],
            // Length prefix says 5 bytes, only 2 are present.
            &[1, 0, 5, 0, 1, 2],
        ];

        for bytes in truncated {
            assert!(
                matches!(
                    Batch::decode(bytes),
                    Err(NegotiationError::TooShort { .. })
                ),
                "input {:?} should be rejected as too short",
                bytes
            );
        }
    }

    /// The low-latency preset is strictly tighter than the high-throughput one.
    #[test]
    fn test_presets() {
        let low_latency = BatchingConfig::low_latency();
        let high_throughput = BatchingConfig::high_throughput();

        assert!(low_latency.max_batch_size < high_throughput.max_batch_size);
        assert!(low_latency.max_delay_ms < high_throughput.max_delay_ms);
    }
}