network_protocol/service/
multiplex.rs1use crate::error::{ProtocolError, Result};
23use dashmap::DashMap;
24use std::sync::atomic::{AtomicU64, Ordering};
25use std::sync::Arc;
26use std::time::{Duration, Instant};
27use tokio::io::{AsyncReadExt, AsyncWriteExt};
28use tokio::sync::{mpsc, oneshot, Semaphore};
29use tracing::{debug, error, warn};
30
/// Identifier correlating a request frame with its response frame.
/// Allocated from a monotonically increasing atomic counter (see `Multiplexer::request`).
pub type RequestId = u64;
33
/// A single request or response frame carried over the multiplexed transport.
///
/// On the wire each frame is encoded as: `request_id` (u64) + payload length
/// (u32) + payload bytes (see `send_loop` / `receive_loop`).
#[derive(Debug, Clone)]
pub struct MultiplexFrame {
    /// Correlates this frame with a pending request.
    pub request_id: RequestId,
    /// Opaque application payload.
    pub payload: Vec<u8>,
}
42
/// Tuning parameters for a multiplexed connection.
///
/// Use [`MultiplexConfig::validate`] to sanity-check values before use.
#[derive(Debug, Clone)]
pub struct MultiplexConfig {
    /// Maximum number of concurrently outstanding requests
    /// (enforced with a semaphore in `Multiplexer::request`).
    pub max_in_flight: usize,
    /// How long a request may wait for its response before timing out.
    pub request_timeout: Duration,
    /// Capacity of the bounded channel feeding the send loop.
    pub send_buffer_size: usize,
}
53
54impl MultiplexConfig {
55 pub fn validate(&self) -> Result<()> {
57 let mut errors = Vec::new();
58
59 if self.max_in_flight == 0 {
61 errors.push("max_in_flight must be greater than 0".to_string());
62 }
63
64 if self.max_in_flight > 1_000_000 {
65 errors.push(format!(
66 "max_in_flight ({}) exceeds recommended limit (1,000,000)",
67 self.max_in_flight
68 ));
69 }
70
71 if self.request_timeout.is_zero() {
73 errors.push("request_timeout must be greater than 0".to_string());
74 }
75
76 if self.request_timeout.as_millis() < 100 {
77 errors.push(format!(
78 "request_timeout ({} ms) is too short (minimum: 100ms)",
79 self.request_timeout.as_millis()
80 ));
81 }
82
83 if self.request_timeout.as_secs() > 300 {
84 errors.push(format!(
85 "request_timeout ({} seconds) is unusually long (recommended: < 5 minutes)",
86 self.request_timeout.as_secs()
87 ));
88 }
89
90 if self.send_buffer_size == 0 {
92 errors.push("send_buffer_size must be greater than 0".to_string());
93 }
94
95 if self.send_buffer_size > 10_000 {
96 errors.push(format!(
97 "send_buffer_size ({}) is unusually large (recommended: < 10,000)",
98 self.send_buffer_size
99 ));
100 }
101
102 if errors.is_empty() {
104 Ok(())
105 } else {
106 Err(ProtocolError::ConfigError(format!(
107 "Multiplex configuration validation failed:\n - {}",
108 errors.join("\n - ")
109 )))
110 }
111 }
112}
113
114impl Default for MultiplexConfig {
115 fn default() -> Self {
116 Self {
117 max_in_flight: 10_000, request_timeout: Duration::from_secs(30),
119 send_buffer_size: 100,
120 }
121 }
122}
123
/// Book-keeping entry for a request that has been sent but not yet answered.
struct PendingRequest {
    /// One-shot channel the receive loop uses to deliver the response payload.
    response_tx: oneshot::Sender<Vec<u8>>,
    /// When the request was registered; used by stale-request cleanup.
    created_at: Instant,
}
129
/// Lock-free activity counters for a multiplexed connection.
///
/// All counters are updated with `Ordering::Relaxed`, so momentary readings
/// may be slightly out of sync with one another under concurrency.
#[derive(Debug, Default)]
pub struct MultiplexMetrics {
    /// Frames successfully handed to the send loop.
    pub requests_sent: AtomicU64,
    /// Responses matched to a pending request.
    pub responses_received: AtomicU64,
    /// Requests that timed out (caller-side timeout or janitor cleanup).
    pub timeouts: AtomicU64,
    /// Requests that failed with a non-timeout error.
    pub errors: AtomicU64,
    /// Requests currently awaiting a response.
    pub in_flight: AtomicU64,
}
144
/// Multiplexes many concurrent request/response exchanges over a single
/// reader/writer pair, correlating frames by [`RequestId`].
pub struct Multiplexer<R, W>
where
    R: AsyncReadExt + Send + Unpin + 'static,
    W: AsyncWriteExt + Send + Unpin + 'static,
{
    config: MultiplexConfig,
    /// Source of unique request ids; starts at 1 and increments per request.
    next_request_id: Arc<AtomicU64>,
    /// Requests awaiting a response, keyed by id; shared with the receive loop
    /// and the cleanup task.
    pending: Arc<DashMap<RequestId, PendingRequest>>,
    /// Bounded channel feeding frames to the send loop.
    send_tx: mpsc::Sender<MultiplexFrame>,
    /// Limits concurrently outstanding requests to `config.max_in_flight`.
    backpressure: Arc<Semaphore>,
    metrics: Arc<MultiplexMetrics>,
    // `new` takes these out immediately and hands them to the spawned I/O
    // loops, so both fields are `None` for the lifetime of the value.
    reader: Option<R>,
    writer: Option<W>,
}
160
161impl<R, W> Multiplexer<R, W>
162where
163 R: AsyncReadExt + Send + Unpin + 'static,
164 W: AsyncWriteExt + Send + Unpin + 'static,
165{
166 pub fn new(reader: R, writer: W, config: MultiplexConfig) -> Self {
168 let (send_tx, send_rx) = mpsc::channel(config.send_buffer_size);
169
170 let pending = Arc::new(DashMap::new());
171 let metrics = Arc::new(MultiplexMetrics::default());
172 let backpressure = Arc::new(Semaphore::new(config.max_in_flight));
173
174 let mut multiplexer = Self {
175 config: config.clone(),
176 next_request_id: Arc::new(AtomicU64::new(1)),
177 pending: pending.clone(),
178 send_tx,
179 backpressure,
180 metrics: metrics.clone(),
181 reader: Some(reader),
182 writer: Some(writer),
183 };
184
185 #[allow(clippy::expect_used)] let writer = multiplexer.writer.take().expect("Writer should exist");
188 tokio::spawn(Self::send_loop(writer, send_rx, metrics.clone()));
189
190 #[allow(clippy::expect_used)] let reader = multiplexer.reader.take().expect("Reader should exist");
193 tokio::spawn(Self::receive_loop(reader, pending.clone(), metrics.clone()));
194
195 let pending_clone = pending.clone();
197 let timeout = config.request_timeout;
198 let metrics_clone = metrics.clone();
199 tokio::spawn(async move {
200 let mut interval = tokio::time::interval(Duration::from_secs(5));
201 loop {
202 interval.tick().await;
203 Self::cleanup_stale_requests(&pending_clone, timeout, &metrics_clone);
204 }
205 });
206
207 multiplexer
208 }
209
210 pub async fn request(&self, payload: Vec<u8>) -> Result<Vec<u8>> {
212 let _permit = self
214 .backpressure
215 .acquire()
216 .await
217 .map_err(|_| ProtocolError::PoolExhausted)?;
218
219 let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
221
222 let (response_tx, response_rx) = oneshot::channel();
224
225 self.pending.insert(
227 request_id,
228 PendingRequest {
229 response_tx,
230 created_at: Instant::now(),
231 },
232 );
233
234 self.metrics.in_flight.fetch_add(1, Ordering::Relaxed);
235
236 let frame = MultiplexFrame {
238 request_id,
239 payload,
240 };
241
242 self.send_tx
243 .send(frame)
244 .await
245 .map_err(|_| ProtocolError::ConnectionClosed)?;
246
247 self.metrics.requests_sent.fetch_add(1, Ordering::Relaxed);
248
249 tokio::time::timeout(self.config.request_timeout, response_rx)
251 .await
252 .map_err(|_| {
253 self.pending.remove(&request_id);
254 self.metrics.timeouts.fetch_add(1, Ordering::Relaxed);
255 self.metrics.in_flight.fetch_sub(1, Ordering::Relaxed);
256 ProtocolError::Timeout
257 })?
258 .map_err(|_| {
259 self.metrics.errors.fetch_add(1, Ordering::Relaxed);
260 self.metrics.in_flight.fetch_sub(1, Ordering::Relaxed);
261 ProtocolError::ConnectionClosed
262 })
263 }
264
265 async fn send_loop(
267 mut writer: W,
268 mut send_rx: mpsc::Receiver<MultiplexFrame>,
269 _metrics: Arc<MultiplexMetrics>,
270 ) {
271 while let Some(frame) = send_rx.recv().await {
272 let payload_len = frame.payload.len() as u32;
274
275 if let Err(e) = writer.write_u64(frame.request_id).await {
276 error!("Failed to write request ID: {}", e);
277 break;
278 }
279
280 if let Err(e) = writer.write_u32(payload_len).await {
281 error!("Failed to write payload length: {}", e);
282 break;
283 }
284
285 if let Err(e) = writer.write_all(&frame.payload).await {
286 error!("Failed to write payload: {}", e);
287 break;
288 }
289
290 if let Err(e) = writer.flush().await {
291 error!("Failed to flush writer: {}", e);
292 break;
293 }
294
295 debug!("Sent multiplexed request {}", frame.request_id);
296 }
297 }
298
299 async fn receive_loop(
301 mut reader: R,
302 pending: Arc<DashMap<RequestId, PendingRequest>>,
303 metrics: Arc<MultiplexMetrics>,
304 ) {
305 loop {
306 let request_id = match reader.read_u64().await {
308 Ok(id) => id,
309 Err(e) => {
310 error!("Failed to read request ID: {}", e);
311 break;
312 }
313 };
314
315 let payload_len = match reader.read_u32().await {
316 Ok(len) => len as usize,
317 Err(e) => {
318 error!("Failed to read payload length: {}", e);
319 break;
320 }
321 };
322
323 let mut payload = vec![0u8; payload_len];
324 if let Err(e) = reader.read_exact(&mut payload).await {
325 error!("Failed to read payload: {}", e);
326 break;
327 }
328
329 debug!("Received multiplexed response {}", request_id);
330
331 if let Some((_, pending_req)) = pending.remove(&request_id) {
333 metrics.responses_received.fetch_add(1, Ordering::Relaxed);
334 metrics.in_flight.fetch_sub(1, Ordering::Relaxed);
335
336 if pending_req.response_tx.send(payload).is_err() {
337 warn!("Failed to send response to waiting request {}", request_id);
338 }
339 } else {
340 warn!("Received response for unknown request {}", request_id);
341 }
342 }
343 }
344
345 fn cleanup_stale_requests(
347 pending: &Arc<DashMap<RequestId, PendingRequest>>,
348 timeout: Duration,
349 metrics: &Arc<MultiplexMetrics>,
350 ) {
351 let now = Instant::now();
352 let mut stale_count = 0;
353
354 pending.retain(|_id, req| {
355 let is_stale = now.duration_since(req.created_at) > timeout;
356 if is_stale {
357 stale_count += 1;
358 metrics.timeouts.fetch_add(1, Ordering::Relaxed);
359 metrics.in_flight.fetch_sub(1, Ordering::Relaxed);
360 }
361 !is_stale
362 });
363
364 if stale_count > 0 {
365 warn!("Cleaned up {} stale requests", stale_count);
366 }
367 }
368
369 pub fn metrics(&self) -> Arc<MultiplexMetrics> {
371 self.metrics.clone()
372 }
373}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Spawns an echo server on the server half of a duplex pipe: every frame
    /// (id + length-prefixed payload) is written back unchanged. Extracted
    /// because the loop was duplicated verbatim in two tests.
    fn spawn_echo_server(server_stream: tokio::io::DuplexStream) {
        tokio::spawn(async move {
            let (mut server_reader, mut server_writer) = tokio::io::split(server_stream);
            loop {
                let request_id = match server_reader.read_u64().await {
                    Ok(id) => id,
                    Err(_) => break,
                };
                let payload_len = match server_reader.read_u32().await {
                    Ok(len) => len,
                    Err(_) => break,
                };
                let mut payload = vec![0u8; payload_len as usize];
                if server_reader.read_exact(&mut payload).await.is_err() {
                    break;
                }

                // Echo the frame back verbatim.
                if server_writer.write_u64(request_id).await.is_err()
                    || server_writer.write_u32(payload_len).await.is_err()
                    || server_writer.write_all(&payload).await.is_err()
                    || server_writer.flush().await.is_err()
                {
                    break;
                }
            }
        });
    }

    #[tokio::test]
    async fn test_multiplex_single_request() {
        let (client_stream, server_stream) = tokio::io::duplex(1024);
        let (client_reader, client_writer) = tokio::io::split(client_stream);

        let config = MultiplexConfig::default();
        let multiplexer = Multiplexer::new(client_reader, client_writer, config);

        spawn_echo_server(server_stream);

        let response = multiplexer
            .request(b"hello".to_vec())
            .await
            .expect("echo request should succeed");
        assert_eq!(response, b"hello");

        let metrics = multiplexer.metrics();
        assert_eq!(metrics.requests_sent.load(Ordering::Relaxed), 1);
        assert_eq!(metrics.responses_received.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_multiplex_concurrent_requests() {
        let (client_stream, server_stream) = tokio::io::duplex(8192);
        let (client_reader, client_writer) = tokio::io::split(client_stream);

        let config = MultiplexConfig::default();
        let multiplexer = Arc::new(Multiplexer::new(client_reader, client_writer, config));

        spawn_echo_server(server_stream);

        // Ten concurrent requesters sharing one connection.
        let mut tasks = vec![];
        for i in 0..10 {
            let multiplexer_clone = multiplexer.clone();
            tasks.push(tokio::spawn(async move {
                let payload = format!("request_{}", i).into_bytes();
                multiplexer_clone
                    .request(payload)
                    .await
                    .expect("concurrent echo request should succeed")
            }));
        }

        for task in tasks {
            task.await.expect("request task should not panic");
        }

        let metrics = multiplexer.metrics();
        assert_eq!(metrics.requests_sent.load(Ordering::Relaxed), 10);
        assert_eq!(metrics.responses_received.load(Ordering::Relaxed), 10);
    }

    #[tokio::test]
    async fn test_multiplex_config_validation() {
        let config = MultiplexConfig::default();
        assert!(config.validate().is_ok());
    }

    #[tokio::test]
    async fn test_multiplex_config_validation_zero_in_flight() {
        let config = MultiplexConfig {
            max_in_flight: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[tokio::test]
    async fn test_multiplex_config_validation_zero_timeout() {
        let config = MultiplexConfig {
            request_timeout: Duration::from_secs(0),
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[tokio::test]
    async fn test_multiplex_config_validation_short_timeout() {
        let config = MultiplexConfig {
            request_timeout: Duration::from_millis(50),
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[tokio::test]
    async fn test_multiplex_config_validation_zero_buffer() {
        let config = MultiplexConfig {
            send_buffer_size: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }
}