1use bytes::Bytes;
15use pingora::http::RequestHeader;
16use pingora::proxy::Session;
17use rand::Rng;
18use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
19use sentinel_common::errors::{SentinelError, SentinelResult};
20use sentinel_common::observability::RequestMetrics;
21use sentinel_config::routes::ShadowConfig;
22use std::collections::HashMap;
23use std::str::FromStr;
24use std::sync::Arc;
25use std::time::Instant;
26use tokio::time::Duration;
27use tracing::{debug, error, trace, warn};
28
29use crate::{RequestContext, UpstreamPool};
30
#[derive(Clone)]
/// Sends copies ("shadows") of requests for a single route to a secondary
/// upstream, without involving the primary request's response.
///
/// Cloning is derived; the pool map and metrics are held behind `Arc`.
pub struct ShadowManager {
    /// Upstream pools keyed by upstream ID; the shadow upstream named in
    /// `config.upstream` is looked up in this map.
    upstream_pools: Arc<HashMap<String, Arc<UpstreamPool>>>,

    /// Shadow settings for the route (target upstream, sampling percentage,
    /// optional sample header, timeout).
    config: ShadowConfig,

    /// Optional metrics recorder for shadow success / error / timeout events.
    /// When `None`, no shadow metrics are recorded.
    metrics: Option<Arc<RequestMetrics>>,

    /// ID of the route this manager belongs to; passed to every metrics call.
    route_id: String,

    /// HTTP client used to send the shadow requests.
    client: reqwest::Client,
}
52
53impl ShadowManager {
54 pub fn new(
56 upstream_pools: Arc<HashMap<String, Arc<UpstreamPool>>>,
57 config: ShadowConfig,
58 metrics: Option<Arc<RequestMetrics>>,
59 route_id: String,
60 ) -> Self {
61 let client = reqwest::Client::builder()
63 .timeout(Duration::from_millis(config.timeout_ms))
64 .pool_max_idle_per_host(10)
65 .pool_idle_timeout(Duration::from_secs(30))
66 .danger_accept_invalid_certs(true)
68 .build()
69 .unwrap_or_else(|_| reqwest::Client::new());
70
71 Self {
72 upstream_pools,
73 config,
74 metrics,
75 route_id,
76 client,
77 }
78 }
79
80 pub fn should_shadow(&self, headers: &RequestHeader) -> bool {
82 if let Some((header_name, header_value)) = &self.config.sample_header {
84 if let Some(actual_value) = headers.headers.get(header_name) {
85 if actual_value.to_str().ok() != Some(header_value.as_str()) {
86 trace!("Shadow skipped: sample-header mismatch");
87 return false;
88 }
89 } else {
90 trace!("Shadow skipped: sample-header not present");
91 return false;
92 }
93 }
94
95 if self.config.percentage < 100.0 {
97 let mut rng = rand::thread_rng();
98 let roll: f64 = rng.gen_range(0.0..100.0);
99 if roll > self.config.percentage {
100 trace!(
101 roll = roll,
102 threshold = self.config.percentage,
103 "Shadow skipped: sampling"
104 );
105 return false;
106 }
107 }
108
109 true
110 }
111
112 pub fn shadow_request(
124 &self,
125 original_headers: RequestHeader,
126 body: Option<Vec<u8>>,
127 ctx: RequestContext,
128 ) {
129 if !self.upstream_pools.contains_key(&self.config.upstream) {
131 warn!(
132 upstream = %self.config.upstream,
133 "Shadow upstream not found in pools"
134 );
135 if let Some(ref metrics) = self.metrics {
137 metrics.record_shadow_error(&self.route_id, &self.config.upstream, "upstream_not_found");
138 }
139 return;
140 }
141
142 let config = self.config.clone();
143 let upstream_id = self.config.upstream.clone();
144 let upstream_pools = Arc::clone(&self.upstream_pools);
145 let metrics = self.metrics.clone();
146 let route_id = self.route_id.clone();
147 let client = self.client.clone();
148
149 tokio::spawn(async move {
151 let start = Instant::now();
152
153 let upstream_pool = match upstream_pools.get(&upstream_id) {
155 Some(pool) => pool,
156 None => {
157 warn!(upstream = %upstream_id, "Shadow upstream disappeared");
159 return;
160 }
161 };
162
163 let result = tokio::time::timeout(
165 Duration::from_millis(config.timeout_ms),
166 Self::execute_shadow_request(
167 &client,
168 upstream_pool,
169 original_headers,
170 body,
171 ctx.clone(),
172 ),
173 )
174 .await;
175
176 let latency = start.elapsed();
177
178 match result {
179 Ok(Ok(())) => {
180 debug!(
181 upstream = %upstream_id,
182 latency_ms = latency.as_millis(),
183 path = %ctx.path,
184 method = %ctx.method,
185 "Shadow request completed successfully"
186 );
187 if let Some(ref metrics) = metrics {
189 metrics.record_shadow_success(&route_id, &upstream_id, latency);
190 }
191 }
192 Ok(Err(e)) => {
193 error!(
194 upstream = %upstream_id,
195 error = %e,
196 latency_ms = latency.as_millis(),
197 path = %ctx.path,
198 method = %ctx.method,
199 "Shadow request failed"
200 );
201 if let Some(ref metrics) = metrics {
203 metrics.record_shadow_error(&route_id, &upstream_id, "request_failed");
204 }
205 }
206 Err(_) => {
207 warn!(
208 upstream = %upstream_id,
209 timeout_ms = config.timeout_ms,
210 path = %ctx.path,
211 method = %ctx.method,
212 "Shadow request timed out"
213 );
214 if let Some(ref metrics) = metrics {
216 metrics.record_shadow_timeout(&route_id, &upstream_id, latency);
217 }
218 }
219 }
220 });
221 }
222
223 async fn execute_shadow_request(
228 client: &reqwest::Client,
229 upstream_pool: &UpstreamPool,
230 headers: RequestHeader,
231 body: Option<Vec<u8>>,
232 ctx: RequestContext,
233 ) -> SentinelResult<()> {
234 let target = upstream_pool.select_shadow_target(Some(&ctx)).await?;
236
237 let url = target.build_url(&ctx.path);
239
240 trace!(
241 url = %url,
242 method = %ctx.method,
243 body_size = body.as_ref().map(|b| b.len()).unwrap_or(0),
244 "Executing shadow request"
245 );
246
247 let mut reqwest_headers = HeaderMap::new();
249 for (name, value) in headers.headers.iter() {
250 let name_str = name.as_str().to_lowercase();
252 if matches!(
253 name_str.as_str(),
254 "connection"
255 | "keep-alive"
256 | "proxy-authenticate"
257 | "proxy-authorization"
258 | "te"
259 | "trailers"
260 | "transfer-encoding"
261 | "upgrade"
262 ) {
263 continue;
264 }
265
266 if let (Ok(header_name), Ok(header_value)) = (
267 HeaderName::from_str(name.as_str()),
268 HeaderValue::from_bytes(value.as_bytes()),
269 ) {
270 reqwest_headers.insert(header_name, header_value);
271 }
272 }
273
274 reqwest_headers.insert("x-shadow-request", HeaderValue::from_static("true"));
276
277 if let Ok(host_value) = HeaderValue::from_str(&target.host) {
279 reqwest_headers.insert("host", host_value);
280 }
281
282 let method = reqwest::Method::from_bytes(ctx.method.as_bytes())
284 .unwrap_or(reqwest::Method::GET);
285
286 let mut request_builder = client.request(method, &url).headers(reqwest_headers);
287
288 if let Some(body_bytes) = body {
290 request_builder = request_builder.body(body_bytes);
291 }
292
293 let response = request_builder.send().await.map_err(|e| {
295 SentinelError::upstream(
296 upstream_pool.id().to_string(),
297 format!("Shadow request failed: {}", e),
298 )
299 })?;
300
301 let status = response.status();
302 trace!(
303 url = %url,
304 status = %status,
305 "Shadow request completed"
306 );
307
308 drop(response);
311
312 Ok(())
313 }
314}
315
/// Returns `true` if `method` is an HTTP method whose request body should be
/// buffered for shadowing: `POST`, `PUT`, or `PATCH`.
///
/// The comparison ignores ASCII case only and does not allocate. Non-ASCII
/// input never matches, since HTTP method names are ASCII tokens.
pub fn should_buffer_method(method: &str) -> bool {
    ["POST", "PUT", "PATCH"]
        .iter()
        .any(|candidate| method.eq_ignore_ascii_case(candidate))
}
320
321pub async fn buffer_request_body(
332 session: &mut Session,
333 max_bytes: usize,
334) -> SentinelResult<Vec<u8>> {
335 if max_bytes == 0 {
336 return Err(SentinelError::LimitExceeded {
337 limit_type: sentinel_common::errors::LimitType::BodySize,
338 message: "max_body_bytes must be > 0".to_string(),
339 current_value: 0,
340 limit: 0,
341 });
342 }
343
344 let mut buffer = Vec::with_capacity(max_bytes.min(65536)); let mut total_read = 0;
346
347 loop {
348 let chunk = session.read_request_body().await.map_err(|e| {
350 SentinelError::Internal {
351 message: format!("Failed to read request body for shadow: {}", e),
352 correlation_id: None,
353 source: None,
354 }
355 })?;
356
357 match chunk {
358 Some(data) => {
359 let chunk_len: usize = data.len();
360
361 if total_read + chunk_len > max_bytes {
363 return Err(SentinelError::LimitExceeded {
364 limit_type: sentinel_common::errors::LimitType::BodySize,
365 message: format!(
366 "Request body exceeds maximum shadow buffer size of {} bytes",
367 max_bytes
368 ),
369 current_value: total_read + chunk_len,
370 limit: max_bytes,
371 });
372 }
373
374 buffer.extend_from_slice(&data);
375 total_read += chunk_len;
376
377 trace!(
378 chunk_size = chunk_len,
379 total_buffered = total_read,
380 max_bytes = max_bytes,
381 "Buffered request body chunk for shadow"
382 );
383 }
384 None => {
385 break;
387 }
388 }
389 }
390
391 debug!(
392 total_bytes = total_read,
393 "Finished buffering request body for shadow"
394 );
395
396 Ok(buffer)
397}
398
399pub fn clone_body_for_shadow(body: &Option<Bytes>) -> Option<Vec<u8>> {
404 body.as_ref().map(|b| b.to_vec())
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use pingora::http::RequestHeader as PingoraRequestHeader;

    /// Builds a manager with no pools and no metrics, varying only the
    /// sampling settings under test.
    fn manager_with(percentage: f64, sample_header: Option<(String, String)>) -> ShadowManager {
        let config = ShadowConfig {
            upstream: "shadow".to_string(),
            percentage,
            sample_header,
            timeout_ms: 5000,
            buffer_body: false,
            max_body_bytes: 1048576,
        };

        ShadowManager::new(
            Arc::new(HashMap::new()),
            config,
            None,
            "test-route".to_string(),
        )
    }

    /// A bare `GET /` request header.
    fn get_root() -> PingoraRequestHeader {
        PingoraRequestHeader::build("GET", b"/", None).unwrap()
    }

    #[test]
    fn test_should_buffer_method() {
        for buffered in ["POST", "PUT", "PATCH", "post"] {
            assert!(should_buffer_method(buffered));
        }
        for unbuffered in ["GET", "HEAD", "DELETE"] {
            assert!(!should_buffer_method(unbuffered));
        }
    }

    #[test]
    fn test_shadow_sampling_percentage() {
        // 0% sampling must never shadow.
        let manager = manager_with(0.0, None);
        let headers = get_root();

        for _ in 0..100 {
            assert!(!manager.should_shadow(&headers));
        }
    }

    #[test]
    fn test_shadow_sampling_always() {
        // 100% sampling must always shadow.
        let manager = manager_with(100.0, None);
        let headers = get_root();

        for _ in 0..100 {
            assert!(manager.should_shadow(&headers));
        }
    }

    #[test]
    fn test_shadow_sample_header_match() {
        let manager = manager_with(
            100.0,
            Some(("X-Shadow".to_string(), "true".to_string())),
        );

        // Header present with the expected value.
        let mut matching = get_root();
        matching.insert_header("X-Shadow", "true").unwrap();
        assert!(manager.should_shadow(&matching));

        // Header absent.
        let absent = get_root();
        assert!(!manager.should_shadow(&absent));

        // Header present with a different value.
        let mut mismatched = get_root();
        mismatched.insert_header("X-Shadow", "false").unwrap();
        assert!(!manager.should_shadow(&mismatched));
    }

    #[test]
    fn test_clone_body_for_shadow() {
        let body = Some(Bytes::from("test body content"));
        let cloned = clone_body_for_shadow(&body);
        assert!(cloned.is_some());
        assert_eq!(cloned.unwrap(), b"test body content");

        let no_body: Option<Bytes> = None;
        let cloned_none = clone_body_for_shadow(&no_body);
        assert!(cloned_none.is_none());
    }
}