Skip to main content

dynamo_runtime/pipeline/network/egress/
http_router.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! HTTP/2 client for request plane
5
6use super::unified_client::{Headers, RequestPlaneClient};
7use crate::Result;
8use async_trait::async_trait;
9use bytes::Bytes;
10use std::sync::Arc;
11use std::time::Duration;
12
/// Default timeout for HTTP requests, in seconds.
///
/// NOTE(review): this was originally described as covering the "ack only, not full
/// response"; the value is handed to the client builder's `timeout(..)` setting, so
/// confirm which phase of the request it actually bounds.
const DEFAULT_HTTP_REQUEST_TIMEOUT_SECS: u64 = 5;

// HTTP/2 performance configuration defaults.

/// Default maximum frame size in bytes (1 MiB), chosen for better throughput.
const DEFAULT_MAX_FRAME_SIZE: u32 = 1 << 20;
/// Default limit on concurrent streams.
const DEFAULT_MAX_CONCURRENT_STREAMS: u32 = 1000;
/// Default number of idle pooled connections kept per host.
const DEFAULT_POOL_MAX_IDLE_PER_HOST: usize = 100;
/// Default time, in seconds, an idle pooled connection is kept alive.
const DEFAULT_POOL_IDLE_TIMEOUT_SECS: u64 = 90;
/// Default interval, in seconds, between keep-alive pings.
const DEFAULT_HTTP2_KEEP_ALIVE_INTERVAL_SECS: u64 = 30;
/// Default time, in seconds, to wait for a keep-alive ping response.
const DEFAULT_HTTP2_KEEP_ALIVE_TIMEOUT_SECS: u64 = 10;
/// Whether adaptive flow control is enabled by default.
const DEFAULT_HTTP2_ADAPTIVE_WINDOW: bool = true;
24
/// HTTP/2 performance configuration for the request-plane HTTP client.
///
/// The type is plain data: it can be cloned, debug-printed and compared for equality.
///
/// NOTE(review): in this file only `pool_max_idle_per_host`, `pool_idle_timeout` and
/// `request_timeout` are passed to the `reqwest` client builder; the remaining fields
/// are stored and exposed through the client's `config()` accessor but are not applied
/// to the connection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Http2Config {
    /// Maximum HTTP/2 frame size, in bytes.
    pub max_frame_size: u32,
    /// Maximum number of concurrent HTTP/2 streams.
    pub max_concurrent_streams: u32,
    /// Maximum number of idle pooled connections kept per host.
    pub pool_max_idle_per_host: usize,
    /// How long an idle pooled connection is kept before being dropped.
    pub pool_idle_timeout: Duration,
    /// Interval between HTTP/2 keep-alive pings.
    pub keep_alive_interval: Duration,
    /// How long to wait for a keep-alive ping response.
    pub keep_alive_timeout: Duration,
    /// Whether adaptive flow control is enabled.
    pub adaptive_window: bool,
    /// Timeout applied to each HTTP request.
    pub request_timeout: Duration,
}
37
38impl Default for Http2Config {
39    fn default() -> Self {
40        Self {
41            max_frame_size: DEFAULT_MAX_FRAME_SIZE,
42            max_concurrent_streams: DEFAULT_MAX_CONCURRENT_STREAMS,
43            pool_max_idle_per_host: DEFAULT_POOL_MAX_IDLE_PER_HOST,
44            pool_idle_timeout: Duration::from_secs(DEFAULT_POOL_IDLE_TIMEOUT_SECS),
45            keep_alive_interval: Duration::from_secs(DEFAULT_HTTP2_KEEP_ALIVE_INTERVAL_SECS),
46            keep_alive_timeout: Duration::from_secs(DEFAULT_HTTP2_KEEP_ALIVE_TIMEOUT_SECS),
47            adaptive_window: DEFAULT_HTTP2_ADAPTIVE_WINDOW,
48            request_timeout: Duration::from_secs(DEFAULT_HTTP_REQUEST_TIMEOUT_SECS),
49        }
50    }
51}
52
53impl Http2Config {
54    /// Create configuration from environment variables
55    pub fn from_env() -> Self {
56        let mut config = Self::default();
57
58        if let Ok(val) = std::env::var("DYN_HTTP2_MAX_FRAME_SIZE")
59            && let Ok(size) = val.parse::<u32>()
60        {
61            config.max_frame_size = size;
62        }
63
64        if let Ok(val) = std::env::var("DYN_HTTP2_MAX_CONCURRENT_STREAMS")
65            && let Ok(streams) = val.parse::<u32>()
66        {
67            config.max_concurrent_streams = streams;
68        }
69
70        if let Ok(val) = std::env::var("DYN_HTTP2_POOL_MAX_IDLE_PER_HOST")
71            && let Ok(pool_size) = val.parse::<usize>()
72        {
73            config.pool_max_idle_per_host = pool_size;
74        }
75
76        if let Ok(val) = std::env::var("DYN_HTTP2_POOL_IDLE_TIMEOUT_SECS")
77            && let Ok(timeout) = val.parse::<u64>()
78        {
79            config.pool_idle_timeout = Duration::from_secs(timeout);
80        }
81
82        if let Ok(val) = std::env::var("DYN_HTTP2_KEEP_ALIVE_INTERVAL_SECS")
83            && let Ok(interval) = val.parse::<u64>()
84        {
85            config.keep_alive_interval = Duration::from_secs(interval);
86        }
87
88        if let Ok(val) = std::env::var("DYN_HTTP2_KEEP_ALIVE_TIMEOUT_SECS")
89            && let Ok(timeout) = val.parse::<u64>()
90        {
91            config.keep_alive_timeout = Duration::from_secs(timeout);
92        }
93
94        if let Ok(val) = std::env::var("DYN_HTTP2_ADAPTIVE_WINDOW") {
95            config.adaptive_window = val.parse().unwrap_or(DEFAULT_HTTP2_ADAPTIVE_WINDOW);
96        }
97
98        if let Ok(val) = std::env::var("DYN_HTTP_REQUEST_TIMEOUT")
99            && let Ok(timeout) = val.parse::<u64>()
100        {
101            config.request_timeout = Duration::from_secs(timeout);
102        }
103
104        config
105    }
106}
107
108/// HTTP/2 request plane client
109pub struct HttpRequestClient {
110    client: reqwest::Client,
111    config: Http2Config,
112}
113
114impl HttpRequestClient {
115    /// Create a new HTTP request client with HTTP/2 and default configuration
116    pub fn new() -> Result<Self> {
117        Self::with_config(Http2Config::default())
118    }
119
120    /// Create a new HTTP request client with custom timeout (legacy method)
121    /// Uses HTTP/2 with prior knowledge to avoid ALPN negotiation overhead
122    pub fn with_timeout(timeout: Duration) -> Result<Self> {
123        let config = Http2Config {
124            request_timeout: timeout,
125            ..Http2Config::default()
126        };
127        Self::with_config(config)
128    }
129
130    /// Create a new HTTP request client with basic configuration
131    ///
132    /// Note: Advanced HTTP/2 configuration methods may not be available in all versions of reqwest.
133    /// This implementation uses only the stable, widely-supported configuration options.
134    pub fn with_config(config: Http2Config) -> Result<Self> {
135        let builder = reqwest::Client::builder()
136            .pool_max_idle_per_host(config.pool_max_idle_per_host)
137            .pool_idle_timeout(config.pool_idle_timeout)
138            .timeout(config.request_timeout);
139        // HTTP/2 is automatically negotiated by reqwest when available
140
141        let client = builder.build()?;
142
143        Ok(Self { client, config })
144    }
145
146    /// Create from environment configuration
147    pub fn from_env() -> Result<Self> {
148        Self::with_config(Http2Config::from_env())
149    }
150
151    /// Get the current HTTP/2 configuration
152    pub fn config(&self) -> &Http2Config {
153        &self.config
154    }
155}
156
157impl Default for HttpRequestClient {
158    fn default() -> Self {
159        Self::new().expect("Failed to create HTTP request client")
160    }
161}
162
163#[async_trait]
164impl RequestPlaneClient for HttpRequestClient {
165    async fn send_request(
166        &self,
167        address: String,
168        payload: Bytes,
169        headers: Headers,
170    ) -> Result<Bytes> {
171        let mut req = self
172            .client
173            .post(&address)
174            .header("Content-Type", "application/octet-stream")
175            .body(payload);
176
177        // Add custom headers
178        for (key, value) in headers {
179            req = req.header(key, value);
180        }
181
182        let response = req.send().await?;
183
184        if !response.status().is_success() {
185            anyhow::bail!(
186                "HTTP request failed with status {}: {}",
187                response.status(),
188                response.text().await.unwrap_or_default()
189            );
190        }
191
192        let body = response.bytes().await?;
193        Ok(body)
194    }
195
196    fn transport_name(&self) -> &'static str {
197        "http2"
198    }
199
200    fn is_healthy(&self) -> bool {
201        // HTTP client is stateless and always healthy if created successfully
202        true
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Router, body::Bytes as AxumBytes, extract::State as AxumState, routing::post};
    use std::sync::Arc;
    use tokio::sync::Mutex as TokioMutex;

    /// Serves `app` on an ephemeral `127.0.0.1` port using hyper-util's `auto`
    /// connection builder.
    ///
    /// Returns the bound address and the handle of the accept loop; abort the handle
    /// to stop accepting new connections. The listener is already bound when this
    /// function returns, so callers can connect immediately without sleeping.
    async fn spawn_test_server(
        app: Router,
    ) -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) {
        use hyper_util::rt::{TokioExecutor, TokioIo};
        use hyper_util::server::conn::auto::Builder as ConnBuilder;
        use hyper_util::service::TowerToHyperService;

        // Bind to a random port
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let accept_loop = tokio::spawn(async move {
            loop {
                let Ok((stream, _)) = listener.accept().await else {
                    break;
                };

                // One task per connection so slow clients do not block the accept loop.
                let app = app.clone();
                tokio::spawn(async move {
                    let conn_builder = ConnBuilder::new(TokioExecutor::new());
                    let io = TokioIo::new(stream);
                    let tower_service = app.into_service();
                    let hyper_service = TowerToHyperService::new(tower_service);

                    let _ = conn_builder.serve_connection(io, hyper_service).await;
                });
            }
        });

        (addr, accept_loop)
    }

    #[test]
    fn test_http_client_creation() {
        let client = HttpRequestClient::new();
        assert!(client.is_ok());
    }

    #[test]
    fn test_http_client_with_custom_timeout() {
        let client = HttpRequestClient::with_timeout(Duration::from_secs(10));
        assert!(client.is_ok());
        assert_eq!(
            client.unwrap().config.request_timeout,
            Duration::from_secs(10)
        );
    }

    #[test]
    fn test_http2_config_from_env() {
        // Set environment variables
        //
        // SAFETY: mutating the environment is only sound while no other thread accesses
        // it concurrently. No other test in this module reads or writes these
        // variables. NOTE(review): the test harness is multi-threaded, so consider
        // serializing this test if other tests start touching the environment.
        unsafe {
            std::env::set_var("DYN_HTTP2_MAX_FRAME_SIZE", "2097152"); // 2MB
            std::env::set_var("DYN_HTTP2_MAX_CONCURRENT_STREAMS", "2000");
            std::env::set_var("DYN_HTTP2_POOL_MAX_IDLE_PER_HOST", "200");
            std::env::set_var("DYN_HTTP2_KEEP_ALIVE_INTERVAL_SECS", "60");
            std::env::set_var("DYN_HTTP2_ADAPTIVE_WINDOW", "false");
        }

        let config = Http2Config::from_env();

        // Clean up before asserting so a failed assertion cannot leak the overrides
        // into the rest of the test process.
        //
        // SAFETY: same reasoning as for the `set_var` calls above.
        unsafe {
            std::env::remove_var("DYN_HTTP2_MAX_FRAME_SIZE");
            std::env::remove_var("DYN_HTTP2_MAX_CONCURRENT_STREAMS");
            std::env::remove_var("DYN_HTTP2_POOL_MAX_IDLE_PER_HOST");
            std::env::remove_var("DYN_HTTP2_KEEP_ALIVE_INTERVAL_SECS");
            std::env::remove_var("DYN_HTTP2_ADAPTIVE_WINDOW");
        }

        assert_eq!(config.max_frame_size, 2097152);
        assert_eq!(config.max_concurrent_streams, 2000);
        assert_eq!(config.pool_max_idle_per_host, 200);
        assert_eq!(config.keep_alive_interval, Duration::from_secs(60));
        assert!(!config.adaptive_window);
    }

    #[test]
    fn test_http_client_with_custom_config() {
        let config = Http2Config {
            max_frame_size: 512 * 1024, // 512KB
            max_concurrent_streams: 500,
            pool_max_idle_per_host: 75,
            pool_idle_timeout: Duration::from_secs(60),
            keep_alive_interval: Duration::from_secs(45),
            keep_alive_timeout: Duration::from_secs(15),
            adaptive_window: false,
            request_timeout: Duration::from_secs(8),
        };

        let client = HttpRequestClient::with_config(config);
        assert!(client.is_ok());

        let client = client.unwrap();
        assert_eq!(client.config.max_frame_size, 512 * 1024);
        assert_eq!(client.config.max_concurrent_streams, 500);
        assert_eq!(client.config.pool_max_idle_per_host, 75);
        assert_eq!(client.config.request_timeout, Duration::from_secs(8));
    }

    #[tokio::test]
    async fn test_http_client_send_request_invalid_url() {
        let client = HttpRequestClient::new().unwrap();
        let result = client
            .send_request(
                "http://invalid-host-that-does-not-exist:9999/test".to_string(),
                Bytes::from("test"),
                std::collections::HashMap::new(),
            )
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_http2_client_server_integration() {
        // Test server state: records every request body it receives.
        #[derive(Clone)]
        struct TestState {
            received: Arc<TokioMutex<Vec<Bytes>>>,
        }

        async fn test_handler(
            AxumState(state): AxumState<TestState>,
            body: AxumBytes,
        ) -> &'static str {
            state.received.lock().await.push(body);
            "OK"
        }

        let state = TestState {
            received: Arc::new(TokioMutex::new(Vec::new())),
        };

        let app = Router::new()
            .route("/test", post(test_handler))
            .with_state(state.clone());

        let (addr, server_handle) = spawn_test_server(app).await;

        // Create client with the default configuration
        let client = HttpRequestClient::new().unwrap();

        // Send request
        let test_data = Bytes::from("test_payload");
        let result = client
            .send_request(
                format!("http://{}/test", addr),
                test_data.clone(),
                std::collections::HashMap::new(),
            )
            .await;

        // Verify request succeeded
        assert!(result.is_ok(), "Request failed: {:?}", result.err());

        // Verify server received the data. The handler records the body before it
        // returns the response, so no extra wait is needed here.
        let received = state.received.lock().await;
        assert_eq!(received.len(), 1);
        assert_eq!(received[0], test_data);

        // Cleanup
        server_handle.abort();
    }

    #[tokio::test]
    async fn test_http2_headers_propagation() {
        // Test server state: captures the headers of every request.
        #[derive(Clone)]
        struct HeaderState {
            headers: Arc<TokioMutex<Vec<(String, String)>>>,
        }

        async fn header_handler(
            AxumState(state): AxumState<HeaderState>,
            headers: axum::http::HeaderMap,
        ) -> &'static str {
            let mut captured = state.headers.lock().await;
            for (name, value) in headers.iter() {
                if let Ok(val_str) = value.to_str() {
                    captured.push((name.to_string(), val_str.to_string()));
                }
            }
            "OK"
        }

        let state = HeaderState {
            headers: Arc::new(TokioMutex::new(Vec::new())),
        };

        let app = Router::new()
            .route("/test", post(header_handler))
            .with_state(state.clone());

        let (addr, server_handle) = spawn_test_server(app).await;

        // Create client with the default configuration
        let client = HttpRequestClient::new().unwrap();

        // Send request with custom headers
        let mut headers = std::collections::HashMap::new();
        headers.insert("x-test-header".to_string(), "test-value".to_string());
        headers.insert("x-request-id".to_string(), "req-123".to_string());

        let result = client
            .send_request(
                format!("http://{}/test", addr),
                Bytes::from("test"),
                headers,
            )
            .await;

        // Verify request succeeded
        assert!(result.is_ok());

        // Verify headers were received. The handler captures them before responding,
        // so they are already recorded once the request has completed.
        let received_headers = state.headers.lock().await;

        let header_map: std::collections::HashMap<_, _> = received_headers
            .iter()
            .map(|(k, v)| (k.as_str(), v.as_str()))
            .collect();

        assert!(header_map.contains_key("x-test-header"));
        assert_eq!(header_map.get("x-test-header"), Some(&"test-value"));
        assert!(header_map.contains_key("x-request-id"));
        assert_eq!(header_map.get("x-request-id"), Some(&"req-123"));

        // Cleanup
        server_handle.abort();
    }

    #[tokio::test]
    async fn test_http2_concurrent_requests() {
        use std::sync::atomic::{AtomicU64, Ordering};

        // Test server state: counts requests.
        #[derive(Clone)]
        struct CounterState {
            count: Arc<AtomicU64>,
        }

        async fn counter_handler(AxumState(state): AxumState<CounterState>) -> String {
            let count = state.count.fetch_add(1, Ordering::SeqCst);
            format!("{}", count)
        }

        let state = CounterState {
            count: Arc::new(AtomicU64::new(0)),
        };

        let app = Router::new()
            .route("/test", post(counter_handler))
            .with_state(state.clone());

        let (addr, server_handle) = spawn_test_server(app).await;

        // Share one client between all request tasks
        let client = Arc::new(HttpRequestClient::new().unwrap());

        // Send multiple concurrent requests
        let mut handles = vec![];
        for _ in 0..10 {
            let client = client.clone();
            let handle = tokio::spawn(async move {
                client
                    .send_request(
                        format!("http://{}/test", addr),
                        Bytes::from("test"),
                        std::collections::HashMap::new(),
                    )
                    .await
            });
            handles.push(handle);
        }

        // Wait for all requests to complete
        let mut success_count = 0;
        for handle in handles {
            if let Ok(Ok(_)) = handle.await {
                success_count += 1;
            }
        }

        // Verify all requests succeeded
        assert_eq!(success_count, 10);

        // Verify server received all requests
        assert_eq!(state.count.load(Ordering::SeqCst), 10);

        // Cleanup
        server_handle.abort();
    }

    #[tokio::test]
    async fn test_http2_performance_benchmark() {
        use std::sync::atomic::{AtomicU64, Ordering};
        use std::time::Instant;

        // Test server state: counts requests and received bytes.
        #[derive(Clone)]
        struct PerfState {
            request_count: Arc<AtomicU64>,
            total_bytes: Arc<AtomicU64>,
        }

        async fn perf_handler(
            AxumState(state): AxumState<PerfState>,
            body: AxumBytes,
        ) -> &'static str {
            state.request_count.fetch_add(1, Ordering::Relaxed);
            state
                .total_bytes
                .fetch_add(body.len() as u64, Ordering::Relaxed);
            "OK"
        }

        let state = PerfState {
            request_count: Arc::new(AtomicU64::new(0)),
            total_bytes: Arc::new(AtomicU64::new(0)),
        };

        let app = Router::new()
            .route("/perf", post(perf_handler))
            .with_state(state.clone());

        let (addr, server_handle) = spawn_test_server(app).await;

        // Client configuration used for the benchmark
        let optimized_config = Http2Config {
            max_frame_size: 1024 * 1024, // 1MB frames
            max_concurrent_streams: 1000,
            pool_max_idle_per_host: 100,
            pool_idle_timeout: Duration::from_secs(90),
            keep_alive_interval: Duration::from_secs(30),
            keep_alive_timeout: Duration::from_secs(10),
            adaptive_window: true,
            request_timeout: Duration::from_secs(30),
        };

        let client = Arc::new(HttpRequestClient::with_config(optimized_config).unwrap());

        // Performance test: Send many concurrent requests
        let num_requests = 100;
        let payload_size = 64 * 1024; // 64KB payload
        let payload = Bytes::from(vec![0u8; payload_size]);

        let start_time = Instant::now();
        let mut handles = vec![];

        for _ in 0..num_requests {
            let client = client.clone();
            let payload = payload.clone();

            let handle = tokio::spawn(async move {
                let headers = std::collections::HashMap::new();
                client
                    .send_request(format!("http://{}/perf", addr), payload, headers)
                    .await
            });
            handles.push(handle);
        }

        // Wait for all requests to complete
        let mut successful_requests = 0;
        for handle in handles {
            if handle.await.unwrap().is_ok() {
                successful_requests += 1;
            }
        }

        let elapsed = start_time.elapsed();
        let requests_per_sec = successful_requests as f64 / elapsed.as_secs_f64();
        let throughput_mbps =
            (successful_requests * payload_size) as f64 / elapsed.as_secs_f64() / (1024.0 * 1024.0);

        println!("Performance Results:");
        println!(
            "  Successful requests: {}/{}",
            successful_requests, num_requests
        );
        println!("  Total time: {:?}", elapsed);
        println!("  Requests/sec: {:.2}", requests_per_sec);
        println!("  Throughput: {:.2} MB/s", throughput_mbps);

        // Verify server received all requests
        let server_count = state.request_count.load(Ordering::Relaxed);
        let server_bytes = state.total_bytes.load(Ordering::Relaxed);

        assert_eq!(server_count, successful_requests as u64);
        assert_eq!(server_bytes, (successful_requests * payload_size) as u64);

        // Performance assertions (adjust based on your requirements).
        // NOTE(review): the two timing thresholds depend on the machine running the
        // test and may be flaky on heavily loaded CI hosts.
        assert!(successful_requests >= num_requests * 95 / 100); // At least 95% success rate
        assert!(requests_per_sec > 50.0); // At least 50 requests per second
        assert!(throughput_mbps > 10.0); // At least 10 MB/s throughput

        // Cleanup
        server_handle.abort();
    }
}