1use bytes::Bytes;
15use pingora::http::RequestHeader;
16use pingora::proxy::Session;
17use rand::RngExt;
18use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
19use std::collections::HashMap;
20use std::str::FromStr;
21use std::sync::Arc;
22use std::time::Instant;
23use tokio::time::Duration;
24use tracing::{debug, error, trace, warn};
25use grapsus_common::errors::{GrapsusError, GrapsusResult};
26use grapsus_common::observability::RequestMetrics;
27use grapsus_config::routes::ShadowConfig;
28
29use crate::{RequestContext, UpstreamPool};
30
/// Per-route manager for traffic shadowing (request mirroring).
///
/// Bundles the state needed to decide whether a request should be mirrored
/// (`should_shadow`) and to send the mirrored copy to the configured shadow
/// upstream (`shadow_request`).
#[derive(Clone)]
pub struct ShadowManager {
    /// Upstream pools keyed by upstream ID; the shadow upstream named in
    /// `config.upstream` is looked up in this map.
    upstream_pools: Arc<HashMap<String, Arc<UpstreamPool>>>,

    /// Shadow settings for this route: target upstream, sampling percentage,
    /// optional sample header, and request timeout.
    config: ShadowConfig,

    /// Optional metrics sink; when present, shadow successes, errors and
    /// timeouts are recorded on it.
    metrics: Option<Arc<RequestMetrics>>,

    /// ID of the route this manager belongs to; passed to the metrics calls.
    route_id: String,

    /// HTTP client used to send the mirrored requests.
    client: reqwest::Client,
}
52
53impl ShadowManager {
54 pub fn new(
56 upstream_pools: Arc<HashMap<String, Arc<UpstreamPool>>>,
57 config: ShadowConfig,
58 metrics: Option<Arc<RequestMetrics>>,
59 route_id: String,
60 ) -> Self {
61 let client = reqwest::Client::builder()
63 .timeout(Duration::from_millis(config.timeout_ms))
64 .pool_max_idle_per_host(10)
65 .pool_idle_timeout(Duration::from_secs(30))
66 .danger_accept_invalid_certs(true)
68 .build()
69 .unwrap_or_else(|_| reqwest::Client::new());
70
71 Self {
72 upstream_pools,
73 config,
74 metrics,
75 route_id,
76 client,
77 }
78 }
79
80 pub fn should_shadow(&self, headers: &RequestHeader) -> bool {
82 if let Some((header_name, header_value)) = &self.config.sample_header {
84 if let Some(actual_value) = headers.headers.get(header_name) {
85 if actual_value.to_str().ok() != Some(header_value.as_str()) {
86 trace!("Shadow skipped: sample-header mismatch");
87 return false;
88 }
89 } else {
90 trace!("Shadow skipped: sample-header not present");
91 return false;
92 }
93 }
94
95 if self.config.percentage < 100.0 {
97 let mut rng = rand::rng();
98 let roll: f64 = rng.random_range(0.0..100.0);
99 if roll > self.config.percentage {
100 trace!(
101 roll = roll,
102 threshold = self.config.percentage,
103 "Shadow skipped: sampling"
104 );
105 return false;
106 }
107 }
108
109 true
110 }
111
112 pub fn shadow_request(
124 &self,
125 original_headers: RequestHeader,
126 body: Option<Vec<u8>>,
127 ctx: RequestContext,
128 ) {
129 if !self.upstream_pools.contains_key(&self.config.upstream) {
131 warn!(
132 upstream = %self.config.upstream,
133 "Shadow upstream not found in pools"
134 );
135 if let Some(ref metrics) = self.metrics {
137 metrics.record_shadow_error(
138 &self.route_id,
139 &self.config.upstream,
140 "upstream_not_found",
141 );
142 }
143 return;
144 }
145
146 let config = self.config.clone();
147 let upstream_id = self.config.upstream.clone();
148 let upstream_pools = Arc::clone(&self.upstream_pools);
149 let metrics = self.metrics.clone();
150 let route_id = self.route_id.clone();
151 let client = self.client.clone();
152
153 tokio::spawn(async move {
155 let start = Instant::now();
156
157 let upstream_pool = match upstream_pools.get(&upstream_id) {
159 Some(pool) => pool,
160 None => {
161 warn!(upstream = %upstream_id, "Shadow upstream disappeared");
163 return;
164 }
165 };
166
167 let result = tokio::time::timeout(
169 Duration::from_millis(config.timeout_ms),
170 Self::execute_shadow_request(
171 &client,
172 upstream_pool,
173 original_headers,
174 body,
175 ctx.clone(),
176 ),
177 )
178 .await;
179
180 let latency = start.elapsed();
181
182 match result {
183 Ok(Ok(())) => {
184 debug!(
185 upstream = %upstream_id,
186 latency_ms = latency.as_millis(),
187 path = %ctx.path,
188 method = %ctx.method,
189 "Shadow request completed successfully"
190 );
191 if let Some(ref metrics) = metrics {
193 metrics.record_shadow_success(&route_id, &upstream_id, latency);
194 }
195 }
196 Ok(Err(e)) => {
197 error!(
198 upstream = %upstream_id,
199 error = %e,
200 latency_ms = latency.as_millis(),
201 path = %ctx.path,
202 method = %ctx.method,
203 "Shadow request failed"
204 );
205 if let Some(ref metrics) = metrics {
207 metrics.record_shadow_error(&route_id, &upstream_id, "request_failed");
208 }
209 }
210 Err(_) => {
211 warn!(
212 upstream = %upstream_id,
213 timeout_ms = config.timeout_ms,
214 path = %ctx.path,
215 method = %ctx.method,
216 "Shadow request timed out"
217 );
218 if let Some(ref metrics) = metrics {
220 metrics.record_shadow_timeout(&route_id, &upstream_id, latency);
221 }
222 }
223 }
224 });
225 }
226
227 async fn execute_shadow_request(
232 client: &reqwest::Client,
233 upstream_pool: &UpstreamPool,
234 headers: RequestHeader,
235 body: Option<Vec<u8>>,
236 ctx: RequestContext,
237 ) -> GrapsusResult<()> {
238 let target = upstream_pool.select_shadow_target(Some(&ctx)).await?;
240
241 let url = target.build_url(&ctx.path);
243
244 trace!(
245 url = %url,
246 method = %ctx.method,
247 body_size = body.as_ref().map(|b| b.len()).unwrap_or(0),
248 "Executing shadow request"
249 );
250
251 let mut reqwest_headers = HeaderMap::new();
253 for (name, value) in headers.headers.iter() {
254 let name_str = name.as_str().to_lowercase();
256 if matches!(
257 name_str.as_str(),
258 "connection"
259 | "keep-alive"
260 | "proxy-authenticate"
261 | "proxy-authorization"
262 | "te"
263 | "trailers"
264 | "transfer-encoding"
265 | "upgrade"
266 ) {
267 continue;
268 }
269
270 if let (Ok(header_name), Ok(header_value)) = (
271 HeaderName::from_str(name.as_str()),
272 HeaderValue::from_bytes(value.as_bytes()),
273 ) {
274 reqwest_headers.insert(header_name, header_value);
275 }
276 }
277
278 reqwest_headers.insert("x-shadow-request", HeaderValue::from_static("true"));
280
281 if let Ok(host_value) = HeaderValue::from_str(&target.host) {
283 reqwest_headers.insert("host", host_value);
284 }
285
286 let method =
288 reqwest::Method::from_bytes(ctx.method.as_bytes()).unwrap_or(reqwest::Method::GET);
289
290 let mut request_builder = client.request(method, &url).headers(reqwest_headers);
291
292 if let Some(body_bytes) = body {
294 request_builder = request_builder.body(body_bytes);
295 }
296
297 let response = request_builder.send().await.map_err(|e| {
299 GrapsusError::upstream(
300 upstream_pool.id().to_string(),
301 format!("Shadow request failed: {}", e),
302 )
303 })?;
304
305 let status = response.status();
306 trace!(
307 url = %url,
308 status = %status,
309 "Shadow request completed"
310 );
311
312 drop(response);
315
316 Ok(())
317 }
318}
319
/// Returns `true` when `method` is an HTTP method whose request body should be
/// buffered for shadowing (`POST`, `PUT` or `PATCH`).
///
/// The comparison is ASCII case-insensitive and does not allocate. Only ASCII
/// letters are folded, so non-ASCII look-alikes never match.
pub fn should_buffer_method(method: &str) -> bool {
    ["POST", "PUT", "PATCH"]
        .iter()
        .any(|candidate| method.eq_ignore_ascii_case(candidate))
}
324
325pub async fn buffer_request_body(
336 session: &mut Session,
337 max_bytes: usize,
338) -> GrapsusResult<Vec<u8>> {
339 if max_bytes == 0 {
340 return Err(GrapsusError::LimitExceeded {
341 limit_type: grapsus_common::errors::LimitType::BodySize,
342 message: "max_body_bytes must be > 0".to_string(),
343 current_value: 0,
344 limit: 0,
345 });
346 }
347
348 let mut buffer = Vec::with_capacity(max_bytes.min(65536)); let mut total_read = 0;
350
351 loop {
352 let chunk = session
354 .read_request_body()
355 .await
356 .map_err(|e| GrapsusError::Internal {
357 message: format!("Failed to read request body for shadow: {}", e),
358 correlation_id: None,
359 source: None,
360 })?;
361
362 match chunk {
363 Some(data) => {
364 let chunk_len: usize = data.len();
365
366 if total_read + chunk_len > max_bytes {
368 return Err(GrapsusError::LimitExceeded {
369 limit_type: grapsus_common::errors::LimitType::BodySize,
370 message: format!(
371 "Request body exceeds maximum shadow buffer size of {} bytes",
372 max_bytes
373 ),
374 current_value: total_read + chunk_len,
375 limit: max_bytes,
376 });
377 }
378
379 buffer.extend_from_slice(&data);
380 total_read += chunk_len;
381
382 trace!(
383 chunk_size = chunk_len,
384 total_buffered = total_read,
385 max_bytes = max_bytes,
386 "Buffered request body chunk for shadow"
387 );
388 }
389 None => {
390 break;
392 }
393 }
394 }
395
396 debug!(
397 total_bytes = total_read,
398 "Finished buffering request body for shadow"
399 );
400
401 Ok(buffer)
402}
403
404pub fn clone_body_for_shadow(body: &Option<Bytes>) -> Option<Vec<u8>> {
409 body.as_ref().map(|b| b.to_vec())
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use pingora::http::RequestHeader as PingoraRequestHeader;

    /// Builds a manager with no pools and no metrics; only the sampling knobs vary.
    fn manager_with(percentage: f64, sample_header: Option<(String, String)>) -> ShadowManager {
        let config = ShadowConfig {
            upstream: "shadow".to_string(),
            percentage,
            sample_header,
            timeout_ms: 5000,
            buffer_body: false,
            max_body_bytes: 1048576,
        };

        ShadowManager::new(
            Arc::new(HashMap::new()),
            config,
            None,
            "test-route".to_string(),
        )
    }

    /// Builds a bare `GET /` request header.
    fn get_root() -> PingoraRequestHeader {
        PingoraRequestHeader::build("GET", b"/", None).unwrap()
    }

    #[test]
    fn test_should_buffer_method() {
        // Methods that carry a body, including a lower-case spelling.
        for method in ["POST", "PUT", "PATCH", "post"] {
            assert!(should_buffer_method(method));
        }

        // Methods that are not buffered.
        for method in ["GET", "HEAD", "DELETE"] {
            assert!(!should_buffer_method(method));
        }
    }

    #[test]
    fn test_shadow_sampling_percentage() {
        // 0% sampling must never select a request.
        let manager = manager_with(0.0, None);
        let headers = get_root();

        for _ in 0..100 {
            assert!(!manager.should_shadow(&headers));
        }
    }

    #[test]
    fn test_shadow_sampling_always() {
        // 100% sampling must always select a request.
        let manager = manager_with(100.0, None);
        let headers = get_root();

        for _ in 0..100 {
            assert!(manager.should_shadow(&headers));
        }
    }

    #[test]
    fn test_shadow_sample_header_match() {
        let manager = manager_with(
            100.0,
            Some(("X-Shadow".to_string(), "true".to_string())),
        );

        // Header present with the expected value.
        let mut matching = get_root();
        matching.insert_header("X-Shadow", "true").unwrap();
        assert!(manager.should_shadow(&matching));

        // Header absent.
        let absent = get_root();
        assert!(!manager.should_shadow(&absent));

        // Header present with a different value.
        let mut mismatching = get_root();
        mismatching.insert_header("X-Shadow", "false").unwrap();
        assert!(!manager.should_shadow(&mismatching));
    }

    #[test]
    fn test_clone_body_for_shadow() {
        // A present body is copied byte for byte.
        let present = Some(Bytes::from("test body content"));
        let copied = clone_body_for_shadow(&present);
        assert!(copied.is_some());
        assert_eq!(copied.unwrap(), b"test body content");

        // An absent body stays absent.
        let absent: Option<Bytes> = None;
        assert!(clone_body_for_shadow(&absent).is_none());
    }
}