Skip to main content

network_protocol/service/
multiplex.rs

1//! # Request Multiplexing
2//!
3//! High-performance request multiplexing over a single connection using request ID tagging.
4//! This is the key to beating Oracle's OLTP performance:
5//! - Eliminates connection pool exhaustion under high concurrency
6//! - Allows thousands of concurrent requests over a handful of connections
7//! - Sub-millisecond request routing via lockless hash map
//! - Per-request oneshot channels deliver response payloads directly to waiters
9//!
10//! ## Architecture
11//! - Each request gets a unique u64 ID (atomic counter)
12//! - Sender tags outgoing requests with ID + payload
13//! - Background task demuxes incoming responses by ID
14//! - Per-request oneshot channels for response delivery
15//!
16//! ## Performance Characteristics
17//! - O(1) request routing (DashMap lockless concurrent hashmap)
//! - Low per-request allocation (one payload buffer plus a oneshot channel)
//! - Automatic cleanup of stale requests (periodic timeout-based sweep)
20//! - Backpressure via in-flight limit (prevents OOM)
21
22use crate::error::{ProtocolError, Result};
23use dashmap::DashMap;
24use std::sync::atomic::{AtomicU64, Ordering};
25use std::sync::Arc;
26use std::time::{Duration, Instant};
27use tokio::io::{AsyncReadExt, AsyncWriteExt};
28use tokio::sync::{mpsc, oneshot, Semaphore};
29use tracing::{debug, error, warn};
30
/// Unique request identifier (64-bit for collision-free namespace).
/// Allocated from a monotonically increasing atomic counter in `Multiplexer`.
pub type RequestId = u64;
33
/// Multiplexed request frame.
///
/// On the wire (see `send_loop`/`receive_loop`) a frame is encoded as
/// `[request_id: u64][payload_len: u32][payload: bytes]`.
#[derive(Debug, Clone)]
pub struct MultiplexFrame {
    /// Request ID for correlating a response with its request
    pub request_id: RequestId,
    /// Request/response payload (opaque bytes to the multiplexer)
    pub payload: Vec<u8>,
}
42
/// Configuration for the multiplexer.
///
/// Call [`MultiplexConfig::validate`] to sanity-check values before use;
/// `Multiplexer::new` does not validate on its own.
#[derive(Debug, Clone)]
pub struct MultiplexConfig {
    /// Maximum concurrent in-flight requests (backpressure).
    /// Enforced via a semaphore permit acquired at the start of each `request` call.
    pub max_in_flight: usize,
    /// Request timeout (detect stale/abandoned requests).
    /// Used both as the per-call deadline in `request` and as the eviction
    /// age for the background stale-request sweeper.
    pub request_timeout: Duration,
    /// Channel buffer size for outgoing requests
    /// (capacity of the mpsc channel feeding the send loop).
    pub send_buffer_size: usize,
}
53
54impl MultiplexConfig {
55    /// Validate configuration parameters
56    pub fn validate(&self) -> Result<()> {
57        let mut errors = Vec::new();
58
59        // Validate max in-flight
60        if self.max_in_flight == 0 {
61            errors.push("max_in_flight must be greater than 0".to_string());
62        }
63
64        if self.max_in_flight > 1_000_000 {
65            errors.push(format!(
66                "max_in_flight ({}) exceeds recommended limit (1,000,000)",
67                self.max_in_flight
68            ));
69        }
70
71        // Validate request timeout
72        if self.request_timeout.is_zero() {
73            errors.push("request_timeout must be greater than 0".to_string());
74        }
75
76        if self.request_timeout.as_millis() < 100 {
77            errors.push(format!(
78                "request_timeout ({} ms) is too short (minimum: 100ms)",
79                self.request_timeout.as_millis()
80            ));
81        }
82
83        if self.request_timeout.as_secs() > 300 {
84            errors.push(format!(
85                "request_timeout ({} seconds) is unusually long (recommended: < 5 minutes)",
86                self.request_timeout.as_secs()
87            ));
88        }
89
90        // Validate send buffer size
91        if self.send_buffer_size == 0 {
92            errors.push("send_buffer_size must be greater than 0".to_string());
93        }
94
95        if self.send_buffer_size > 10_000 {
96            errors.push(format!(
97                "send_buffer_size ({}) is unusually large (recommended: < 10,000)",
98                self.send_buffer_size
99            ));
100        }
101
102        // Return aggregated errors
103        if errors.is_empty() {
104            Ok(())
105        } else {
106            Err(ProtocolError::ConfigError(format!(
107                "Multiplex configuration validation failed:\n  - {}",
108                errors.join("\n  - ")
109            )))
110        }
111    }
112}
113
114impl Default for MultiplexConfig {
115    fn default() -> Self {
116        Self {
117            max_in_flight: 10_000, // Oracle-scale concurrency
118            request_timeout: Duration::from_secs(30),
119            send_buffer_size: 100,
120        }
121    }
122}
123
/// Pending request awaiting a response.
///
/// Stored in the `pending` map keyed by `RequestId`. Dropping an entry drops
/// `response_tx`, which wakes the waiter in `request` with a closed-channel
/// error.
struct PendingRequest {
    // Completed by the receive loop with the response payload.
    response_tx: oneshot::Sender<Vec<u8>>,
    // Registration time, compared against the timeout by the sweeper.
    created_at: Instant,
}
129
/// Multiplexer metrics.
///
/// All counters in this module are updated with `Ordering::Relaxed`: they are
/// statistics, not synchronization points, so cross-counter reads may be
/// momentarily inconsistent.
#[derive(Debug, Default)]
pub struct MultiplexMetrics {
    /// Total requests successfully handed to the send loop
    pub requests_sent: AtomicU64,
    /// Total responses routed back to a waiting request
    pub responses_received: AtomicU64,
    /// Total timeouts
    pub timeouts: AtomicU64,
    /// Total errors
    pub errors: AtomicU64,
    /// Current in-flight requests
    pub in_flight: AtomicU64,
}
144
/// Request multiplexer for a single connection.
///
/// `request()` tags each payload with a unique ID; background tasks spawned
/// in `new` write frames to the connection and route responses back to the
/// per-request oneshot channels.
pub struct Multiplexer<R, W>
where
    R: AsyncReadExt + Send + Unpin + 'static,
    W: AsyncWriteExt + Send + Unpin + 'static,
{
    // Kept for the per-request timeout used in `request`.
    config: MultiplexConfig,
    // Monotonic request-ID source, starts at 1.
    next_request_id: Arc<AtomicU64>,
    // In-flight requests keyed by ID; entries are removed by the receive
    // loop on reply, by `request` on timeout, or by the sweeper when stale.
    pending: Arc<DashMap<RequestId, PendingRequest>>,
    // Producer side of the channel drained by the send loop.
    send_tx: mpsc::Sender<MultiplexFrame>,
    // Bounds concurrent in-flight requests (max_in_flight permits).
    backpressure: Arc<Semaphore>,
    // Shared counters, also updated by the background tasks.
    metrics: Arc<MultiplexMetrics>,
    // The I/O halves are moved into the spawned tasks during `new`;
    // both fields are `None` for the lifetime of the struct afterwards.
    reader: Option<R>,
    writer: Option<W>,
}
160
161impl<R, W> Multiplexer<R, W>
162where
163    R: AsyncReadExt + Send + Unpin + 'static,
164    W: AsyncWriteExt + Send + Unpin + 'static,
165{
166    /// Create a new multiplexer over a connection
167    pub fn new(reader: R, writer: W, config: MultiplexConfig) -> Self {
168        let (send_tx, send_rx) = mpsc::channel(config.send_buffer_size);
169
170        let pending = Arc::new(DashMap::new());
171        let metrics = Arc::new(MultiplexMetrics::default());
172        let backpressure = Arc::new(Semaphore::new(config.max_in_flight));
173
174        let mut multiplexer = Self {
175            config: config.clone(),
176            next_request_id: Arc::new(AtomicU64::new(1)),
177            pending: pending.clone(),
178            send_tx,
179            backpressure,
180            metrics: metrics.clone(),
181            reader: Some(reader),
182            writer: Some(writer),
183        };
184
185        // Spawn send task
186        #[allow(clippy::expect_used)] // Writer guaranteed to exist during initialization
187        let writer = multiplexer.writer.take().expect("Writer should exist");
188        tokio::spawn(Self::send_loop(writer, send_rx, metrics.clone()));
189
190        // Spawn receive task
191        #[allow(clippy::expect_used)] // Reader guaranteed to exist during initialization
192        let reader = multiplexer.reader.take().expect("Reader should exist");
193        tokio::spawn(Self::receive_loop(reader, pending.clone(), metrics.clone()));
194
195        // Spawn cleanup task for stale requests
196        let pending_clone = pending.clone();
197        let timeout = config.request_timeout;
198        let metrics_clone = metrics.clone();
199        tokio::spawn(async move {
200            let mut interval = tokio::time::interval(Duration::from_secs(5));
201            loop {
202                interval.tick().await;
203                Self::cleanup_stale_requests(&pending_clone, timeout, &metrics_clone);
204            }
205        });
206
207        multiplexer
208    }
209
210    /// Send a request and wait for response
211    pub async fn request(&self, payload: Vec<u8>) -> Result<Vec<u8>> {
212        // Enforce backpressure
213        let _permit = self
214            .backpressure
215            .acquire()
216            .await
217            .map_err(|_| ProtocolError::PoolExhausted)?;
218
219        // Generate unique request ID
220        let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
221
222        // Create oneshot channel for response
223        let (response_tx, response_rx) = oneshot::channel();
224
225        // Register pending request
226        self.pending.insert(
227            request_id,
228            PendingRequest {
229                response_tx,
230                created_at: Instant::now(),
231            },
232        );
233
234        self.metrics.in_flight.fetch_add(1, Ordering::Relaxed);
235
236        // Send request frame
237        let frame = MultiplexFrame {
238            request_id,
239            payload,
240        };
241
242        self.send_tx
243            .send(frame)
244            .await
245            .map_err(|_| ProtocolError::ConnectionClosed)?;
246
247        self.metrics.requests_sent.fetch_add(1, Ordering::Relaxed);
248
249        // Wait for response with timeout
250        tokio::time::timeout(self.config.request_timeout, response_rx)
251            .await
252            .map_err(|_| {
253                self.pending.remove(&request_id);
254                self.metrics.timeouts.fetch_add(1, Ordering::Relaxed);
255                self.metrics.in_flight.fetch_sub(1, Ordering::Relaxed);
256                ProtocolError::Timeout
257            })?
258            .map_err(|_| {
259                self.metrics.errors.fetch_add(1, Ordering::Relaxed);
260                self.metrics.in_flight.fetch_sub(1, Ordering::Relaxed);
261                ProtocolError::ConnectionClosed
262            })
263    }
264
265    /// Send loop: writes outgoing frames to connection
266    async fn send_loop(
267        mut writer: W,
268        mut send_rx: mpsc::Receiver<MultiplexFrame>,
269        _metrics: Arc<MultiplexMetrics>,
270    ) {
271        while let Some(frame) = send_rx.recv().await {
272            // Frame format: [request_id: u64][payload_len: u32][payload: bytes]
273            let payload_len = frame.payload.len() as u32;
274
275            if let Err(e) = writer.write_u64(frame.request_id).await {
276                error!("Failed to write request ID: {}", e);
277                break;
278            }
279
280            if let Err(e) = writer.write_u32(payload_len).await {
281                error!("Failed to write payload length: {}", e);
282                break;
283            }
284
285            if let Err(e) = writer.write_all(&frame.payload).await {
286                error!("Failed to write payload: {}", e);
287                break;
288            }
289
290            if let Err(e) = writer.flush().await {
291                error!("Failed to flush writer: {}", e);
292                break;
293            }
294
295            debug!("Sent multiplexed request {}", frame.request_id);
296        }
297    }
298
299    /// Receive loop: reads incoming frames and routes to waiting requests
300    async fn receive_loop(
301        mut reader: R,
302        pending: Arc<DashMap<RequestId, PendingRequest>>,
303        metrics: Arc<MultiplexMetrics>,
304    ) {
305        loop {
306            // Read frame: [request_id: u64][payload_len: u32][payload: bytes]
307            let request_id = match reader.read_u64().await {
308                Ok(id) => id,
309                Err(e) => {
310                    error!("Failed to read request ID: {}", e);
311                    break;
312                }
313            };
314
315            let payload_len = match reader.read_u32().await {
316                Ok(len) => len as usize,
317                Err(e) => {
318                    error!("Failed to read payload length: {}", e);
319                    break;
320                }
321            };
322
323            let mut payload = vec![0u8; payload_len];
324            if let Err(e) = reader.read_exact(&mut payload).await {
325                error!("Failed to read payload: {}", e);
326                break;
327            }
328
329            debug!("Received multiplexed response {}", request_id);
330
331            // Route to waiting request
332            if let Some((_, pending_req)) = pending.remove(&request_id) {
333                metrics.responses_received.fetch_add(1, Ordering::Relaxed);
334                metrics.in_flight.fetch_sub(1, Ordering::Relaxed);
335
336                if pending_req.response_tx.send(payload).is_err() {
337                    warn!("Failed to send response to waiting request {}", request_id);
338                }
339            } else {
340                warn!("Received response for unknown request {}", request_id);
341            }
342        }
343    }
344
345    /// Cleanup stale requests that exceeded timeout
346    fn cleanup_stale_requests(
347        pending: &Arc<DashMap<RequestId, PendingRequest>>,
348        timeout: Duration,
349        metrics: &Arc<MultiplexMetrics>,
350    ) {
351        let now = Instant::now();
352        let mut stale_count = 0;
353
354        pending.retain(|_id, req| {
355            let is_stale = now.duration_since(req.created_at) > timeout;
356            if is_stale {
357                stale_count += 1;
358                metrics.timeouts.fetch_add(1, Ordering::Relaxed);
359                metrics.in_flight.fetch_sub(1, Ordering::Relaxed);
360            }
361            !is_stale
362        });
363
364        if stale_count > 0 {
365            warn!("Cleaned up {} stale requests", stale_count);
366        }
367    }
368
369    /// Get current metrics
370    pub fn metrics(&self) -> Arc<MultiplexMetrics> {
371        self.metrics.clone()
372    }
373}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Spawn a detached echo server on the server half of a duplex pipe:
    /// reads `[request_id][len][payload]` frames and mirrors each back
    /// verbatim until the first read/write error. Extracted because the old
    /// tests duplicated this ~30-line loop verbatim.
    fn spawn_echo_server(server_stream: tokio::io::DuplexStream) {
        tokio::spawn(async move {
            let (mut server_reader, mut server_writer) = tokio::io::split(server_stream);
            loop {
                let request_id = match server_reader.read_u64().await {
                    Ok(id) => id,
                    Err(_) => break,
                };
                let payload_len = match server_reader.read_u32().await {
                    Ok(len) => len,
                    Err(_) => break,
                };
                let mut payload = vec![0u8; payload_len as usize];
                if server_reader.read_exact(&mut payload).await.is_err() {
                    break;
                }

                // Echo the frame back unchanged.
                if server_writer.write_u64(request_id).await.is_err() {
                    break;
                }
                if server_writer.write_u32(payload_len).await.is_err() {
                    break;
                }
                if server_writer.write_all(&payload).await.is_err() {
                    break;
                }
                if server_writer.flush().await.is_err() {
                    break;
                }
            }
        });
    }

    #[tokio::test]
    #[allow(clippy::unwrap_used)] // Test code
    async fn test_multiplex_single_request() {
        let (client_stream, server_stream) = tokio::io::duplex(1024);
        let (client_reader, client_writer) = tokio::io::split(client_stream);

        let multiplexer =
            Multiplexer::new(client_reader, client_writer, MultiplexConfig::default());
        spawn_echo_server(server_stream);

        let response = multiplexer.request(b"hello".to_vec()).await.unwrap();
        assert_eq!(response, b"hello");

        let metrics = multiplexer.metrics();
        assert_eq!(metrics.requests_sent.load(Ordering::Relaxed), 1);
        assert_eq!(metrics.responses_received.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    #[allow(clippy::unwrap_used)] // Test code
    async fn test_multiplex_concurrent_requests() {
        let (client_stream, server_stream) = tokio::io::duplex(8192);
        let (client_reader, client_writer) = tokio::io::split(client_stream);

        let multiplexer = Arc::new(Multiplexer::new(
            client_reader,
            client_writer,
            MultiplexConfig::default(),
        ));
        spawn_echo_server(server_stream);

        // Send 10 concurrent requests and verify each round-trips intact
        // (the old test cloned the payload but never compared the echo).
        let mut tasks = vec![];
        for i in 0..10 {
            let multiplexer_clone = multiplexer.clone();
            tasks.push(tokio::spawn(async move {
                let payload = format!("request_{}", i).into_bytes();
                let response = multiplexer_clone.request(payload.clone()).await.unwrap();
                assert_eq!(response, payload);
            }));
        }

        // Wait for all responses.
        for task in tasks {
            task.await.unwrap();
        }

        let metrics = multiplexer.metrics();
        assert_eq!(metrics.requests_sent.load(Ordering::Relaxed), 10);
        assert_eq!(metrics.responses_received.load(Ordering::Relaxed), 10);
    }

    // Config validation is purely synchronous — plain #[test] avoids
    // spinning up a tokio runtime for nothing.
    #[test]
    fn test_multiplex_config_validation() {
        assert!(MultiplexConfig::default().validate().is_ok());
    }

    #[test]
    fn test_multiplex_config_validation_zero_in_flight() {
        let config = MultiplexConfig {
            max_in_flight: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_multiplex_config_validation_zero_timeout() {
        let config = MultiplexConfig {
            request_timeout: Duration::from_secs(0),
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_multiplex_config_validation_short_timeout() {
        let config = MultiplexConfig {
            request_timeout: Duration::from_millis(50),
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_multiplex_config_validation_zero_buffer() {
        let config = MultiplexConfig {
            send_buffer_size: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }
}