Skip to main content

grapsus_proxy/proxy/
context.rs

//! Request context for the proxy request lifecycle.
//!
//! The `RequestContext` struct maintains state throughout a single request,
//! including timing, routing decisions, and metadata for logging.

6use std::sync::Arc;
7use std::time::Instant;
8
9use grapsus_config::{BodyStreamingMode, Config, RouteConfig, ServiceType};
10
11use crate::inference::StreamingTokenCounter;
12use crate::websocket::WebSocketHandler;
13
/// Why a request was rerouted from its primary upstream to a fallback.
///
/// The [`Display`](std::fmt::Display) form is a compact snake_case label such
/// as `error_code_503` or `latency_threshold_5500ms_exceeded_5000ms`.
/// Equality is derived so callers can compare reasons with `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FallbackReason {
    /// The primary upstream's health check failed.
    HealthCheckFailed,
    /// The token budget for the request was exhausted.
    BudgetExhausted,
    /// Observed response latency exceeded the configured threshold.
    LatencyThreshold { observed_ms: u64, threshold_ms: u64 },
    /// The upstream answered with a status code configured to trigger fallback.
    ErrorCode(u16),
    /// Connecting to the upstream failed; carries the error description.
    ConnectionError(String),
}

impl std::fmt::Display for FallbackReason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::HealthCheckFailed => f.write_str("health_check_failed"),
            Self::BudgetExhausted => f.write_str("budget_exhausted"),
            Self::LatencyThreshold {
                observed_ms,
                threshold_ms,
            } => write!(
                f,
                "latency_threshold_{observed_ms}ms_exceeded_{threshold_ms}ms"
            ),
            Self::ErrorCode(code) => write!(f, "error_code_{code}"),
            // The message is embedded verbatim; it is not escaped or normalized.
            Self::ConnectionError(msg) => write!(f, "connection_error_{msg}"),
        }
    }
}

/// Outcome of the response-cache lookup, reported through the RFC 9211
/// `Cache-Status` response header.
#[derive(Debug, Clone)]
pub(crate) enum CacheStatus {
    /// Served from the in-memory tier of the hybrid cache.
    HitMemory,
    /// Served from the disk tier of the hybrid cache and promoted to memory.
    HitDisk,
    /// Served from cache without tier information (non-hybrid cache).
    Hit,
    /// Served from cache, but the entry was stale and needs revalidation.
    HitStale,
    /// Not in cache; the response was fetched from the upstream.
    Miss,
    /// Caching was skipped because the request was not eligible.
    /// Carries a static string (presumably the bypass reason).
    Bypass(&'static str),
}

/// Values surfaced to clients through rate-limit response headers.
#[derive(Debug, Clone)]
pub struct RateLimitHeaderInfo {
    /// Request quota for a single window.
    pub limit: u32,
    /// Requests still permitted in the current window.
    pub remaining: u32,
    /// Moment the current window resets, as a Unix timestamp in seconds.
    pub reset_at: u64,
}

76/// Request context maintained throughout the request lifecycle.
77///
78/// This struct uses a hybrid approach:
79/// - Immutable fields (start_time) are private with getters
80/// - Mutable fields are public(crate) for efficient access within the proxy module
81pub struct RequestContext {
82    /// Request start time (immutable after creation)
83    start_time: Instant,
84
85    // === Tracing ===
86    /// Unique trace ID for request tracing (also used as correlation_id)
87    pub(crate) trace_id: String,
88
89    // === Global config (cached once per request) ===
90    /// Cached global configuration snapshot for this request
91    pub(crate) config: Option<Arc<Config>>,
92
93    // === Routing ===
94    /// Selected route ID
95    pub(crate) route_id: Option<String>,
96    /// Cached route configuration (avoids duplicate route matching)
97    pub(crate) route_config: Option<Arc<RouteConfig>>,
98    /// Selected upstream pool ID
99    pub(crate) upstream: Option<String>,
100    /// Selected upstream peer address (IP:port) for feedback reporting
101    pub(crate) selected_upstream_address: Option<String>,
102    /// Number of upstream attempts
103    pub(crate) upstream_attempts: u32,
104
105    // === Scope (for namespaced configurations) ===
106    /// Namespace for this request (if routed to a namespace scope)
107    pub(crate) namespace: Option<String>,
108    /// Service for this request (if routed to a service scope)
109    pub(crate) service: Option<String>,
110
111    // === Request metadata (cached for logging) ===
112    /// HTTP method
113    pub(crate) method: String,
114    /// Request path
115    pub(crate) path: String,
116    /// Query string
117    pub(crate) query: Option<String>,
118
119    // === Client info ===
120    /// Client IP address
121    pub(crate) client_ip: String,
122    /// User-Agent header
123    pub(crate) user_agent: Option<String>,
124    /// Referer header
125    pub(crate) referer: Option<String>,
126    /// Host header
127    pub(crate) host: Option<String>,
128
129    // === Body tracking ===
130    /// Request body bytes received
131    pub(crate) request_body_bytes: u64,
132    /// Response body bytes (set during response)
133    pub(crate) response_bytes: u64,
134
135    // === Connection tracking ===
136    /// Whether the upstream connection was reused
137    pub(crate) connection_reused: bool,
138    /// Whether this request is a WebSocket upgrade
139    pub(crate) is_websocket_upgrade: bool,
140
141    // === WebSocket Inspection ===
142    /// Whether WebSocket frame inspection is enabled for this connection
143    pub(crate) websocket_inspection_enabled: bool,
144    /// Whether to skip inspection (e.g., due to compression negotiation)
145    pub(crate) websocket_skip_inspection: bool,
146    /// Agent IDs for WebSocket frame inspection
147    pub(crate) websocket_inspection_agents: Vec<String>,
148    /// WebSocket frame handler (created after 101 upgrade)
149    pub(crate) websocket_handler: Option<Arc<WebSocketHandler>>,
150
151    // === Caching ===
152    /// Whether this request is eligible for caching
153    pub(crate) cache_eligible: bool,
154    /// Cache status for Cache-Status response header (RFC 9211)
155    pub(crate) cache_status: Option<CacheStatus>,
156
157    // === Body Inspection ===
158    /// Whether body inspection is enabled for this request
159    pub(crate) body_inspection_enabled: bool,
160    /// Bytes already sent to agent for inspection
161    pub(crate) body_bytes_inspected: u64,
162    /// Accumulated body buffer for agent inspection
163    pub(crate) body_buffer: Vec<u8>,
164    /// Agent IDs to use for body inspection
165    pub(crate) body_inspection_agents: Vec<String>,
166
167    // === Body Decompression ===
168    /// Whether decompression is enabled for body inspection
169    pub(crate) decompression_enabled: bool,
170    /// Content-Encoding of the request body (if compressed)
171    pub(crate) body_content_encoding: Option<String>,
172    /// Maximum decompression ratio allowed
173    pub(crate) max_decompression_ratio: f64,
174    /// Maximum decompressed size allowed
175    pub(crate) max_decompression_bytes: usize,
176    /// Whether decompression was performed
177    pub(crate) body_was_decompressed: bool,
178
179    // === Rate Limiting ===
180    /// Rate limit info for response headers (set during request_filter)
181    pub(crate) rate_limit_info: Option<RateLimitHeaderInfo>,
182
183    // === GeoIP Filtering ===
184    /// Country code from GeoIP lookup (ISO 3166-1 alpha-2)
185    pub(crate) geo_country_code: Option<String>,
186    /// Whether a geo lookup was performed for this request
187    pub(crate) geo_lookup_performed: bool,
188
189    // === Body Streaming ===
190    /// Body streaming mode for request body inspection
191    pub(crate) request_body_streaming_mode: BodyStreamingMode,
192    /// Current chunk index for request body streaming
193    pub(crate) request_body_chunk_index: u32,
194    /// Whether agent needs more data (streaming mode)
195    pub(crate) agent_needs_more: bool,
196    /// Body streaming mode for response body inspection
197    pub(crate) response_body_streaming_mode: BodyStreamingMode,
198    /// Current chunk index for response body streaming
199    pub(crate) response_body_chunk_index: u32,
200    /// Response body bytes inspected
201    pub(crate) response_body_bytes_inspected: u64,
202    /// Response body inspection enabled
203    pub(crate) response_body_inspection_enabled: bool,
204    /// Agent IDs for response body inspection
205    pub(crate) response_body_inspection_agents: Vec<String>,
206
207    // === OpenTelemetry Tracing ===
208    /// OpenTelemetry request span (if tracing enabled)
209    pub(crate) otel_span: Option<crate::otel::RequestSpan>,
210    /// W3C trace context parsed from incoming request
211    pub(crate) trace_context: Option<crate::otel::TraceContext>,
212
213    // === Inference Rate Limiting ===
214    /// Whether inference rate limiting is enabled for this route
215    pub(crate) inference_rate_limit_enabled: bool,
216    /// Estimated tokens for this request (used for rate limiting)
217    pub(crate) inference_estimated_tokens: u64,
218    /// Rate limit key used (client IP, API key, etc.)
219    pub(crate) inference_rate_limit_key: Option<String>,
220    /// Model name detected from request
221    pub(crate) inference_model: Option<String>,
222    /// Provider override from model-based routing (for cross-provider routing)
223    pub(crate) inference_provider_override: Option<grapsus_config::InferenceProvider>,
224    /// Whether model-based routing was used to select the upstream
225    pub(crate) model_routing_used: bool,
226    /// Actual tokens from response (filled in after response)
227    pub(crate) inference_actual_tokens: Option<u64>,
228
229    // === Token Budget Tracking ===
230    /// Whether budget tracking is enabled for this route
231    pub(crate) inference_budget_enabled: bool,
232    /// Budget remaining after this request (set after response)
233    pub(crate) inference_budget_remaining: Option<i64>,
234    /// Period reset timestamp (Unix seconds)
235    pub(crate) inference_budget_period_reset: Option<u64>,
236    /// Whether budget was exhausted (429 sent)
237    pub(crate) inference_budget_exhausted: bool,
238
239    // === Cost Attribution ===
240    /// Whether cost attribution is enabled for this route
241    pub(crate) inference_cost_enabled: bool,
242    /// Calculated cost for this request (set after response)
243    pub(crate) inference_request_cost: Option<f64>,
244    /// Input tokens for cost calculation
245    pub(crate) inference_input_tokens: u64,
246    /// Output tokens for cost calculation
247    pub(crate) inference_output_tokens: u64,
248
249    // === Streaming Token Counting ===
250    /// Whether this is a streaming (SSE) response
251    pub(crate) inference_streaming_response: bool,
252    /// Streaming token counter for SSE responses
253    pub(crate) inference_streaming_counter: Option<StreamingTokenCounter>,
254
255    // === Fallback Routing ===
256    /// Current fallback attempt number (0 = primary, 1+ = fallback)
257    pub(crate) fallback_attempt: u32,
258    /// List of upstream IDs that have been tried
259    pub(crate) tried_upstreams: Vec<String>,
260    /// Reason for triggering fallback (if fallback was used)
261    pub(crate) fallback_reason: Option<FallbackReason>,
262    /// Original upstream ID before fallback (primary)
263    pub(crate) original_upstream: Option<String>,
264    /// Model mapping applied: (original_model, mapped_model)
265    pub(crate) model_mapping_applied: Option<(String, String)>,
266    /// Whether fallback should be retried after response
267    pub(crate) should_retry_with_fallback: bool,
268
269    // === Semantic Guardrails ===
270    /// Whether guardrails are enabled for this route
271    pub(crate) guardrails_enabled: bool,
272    /// Prompt injection detected but allowed (add warning header)
273    pub(crate) guardrail_warning: bool,
274    /// Categories of prompt injection detected (for logging)
275    pub(crate) guardrail_detection_categories: Vec<String>,
276    /// PII categories detected in response (for logging)
277    pub(crate) pii_detection_categories: Vec<String>,
278
279    // === Shadow Traffic ===
280    /// Pending shadow request info (stored for deferred execution after body buffering)
281    pub(crate) shadow_pending: Option<ShadowPendingRequest>,
282    /// Whether shadow request was sent for this request
283    pub(crate) shadow_sent: bool,
284
285    // === Sticky Sessions ===
286    /// Whether a new sticky session assignment was made (needs Set-Cookie header)
287    pub(crate) sticky_session_new_assignment: bool,
288    /// Set-Cookie header value to include in response (full header value)
289    pub(crate) sticky_session_set_cookie: Option<String>,
290    /// Target index for sticky session (for logging)
291    pub(crate) sticky_target_index: Option<usize>,
292
293    // === Listener Overrides ===
294    /// Keepalive timeout from listener config (seconds, for response phase)
295    pub(crate) listener_keepalive_timeout_secs: Option<u64>,
296
297    // === Filter Overrides ===
298    /// Upstream connect timeout override from Timeout filter (seconds)
299    pub(crate) filter_connect_timeout_secs: Option<u64>,
300    /// Upstream read timeout override from Timeout filter (seconds)
301    pub(crate) filter_upstream_timeout_secs: Option<u64>,
302    /// CORS origin matched by a CORS filter (for response headers)
303    pub(crate) cors_origin: Option<String>,
304    /// Whether response compression is enabled by a Compress filter
305    pub(crate) compress_enabled: bool,
306
307    // === Response-Phase Agent Processing ===
308    /// Agent IDs resolved from route filters (saved in request phase for response phase)
309    pub(crate) route_agent_ids: Vec<String>,
310    /// Whether response-phase agent processing is enabled (agent subscribes to response events)
311    pub(crate) response_agent_processing_enabled: bool,
312    /// Accumulated response body buffer for agent processing (when agent needs full body)
313    pub(crate) response_agent_body_buffer: Vec<u8>,
314    /// Whether response body has been fully received by agent
315    pub(crate) response_agent_body_complete: bool,
316}
317
318/// Pending shadow request information stored in context for deferred execution
319#[derive(Clone)]
320pub struct ShadowPendingRequest {
321    /// Cloned request headers for shadow
322    pub headers: pingora::http::RequestHeader,
323    /// Shadow manager (wrapped in Arc for Clone)
324    pub manager: std::sync::Arc<crate::shadow::ShadowManager>,
325    /// Request context for shadow (client IP, path, method, etc.)
326    pub request_ctx: crate::upstream::RequestContext,
327    /// Whether body should be included
328    pub include_body: bool,
329}
330
331impl RequestContext {
332    /// Create a new empty request context with the current timestamp.
333    pub fn new() -> Self {
334        Self {
335            start_time: Instant::now(),
336            trace_id: String::new(),
337            config: None,
338            route_id: None,
339            route_config: None,
340            upstream: None,
341            selected_upstream_address: None,
342            upstream_attempts: 0,
343            namespace: None,
344            service: None,
345            method: String::new(),
346            path: String::new(),
347            query: None,
348            client_ip: String::new(),
349            user_agent: None,
350            referer: None,
351            host: None,
352            request_body_bytes: 0,
353            response_bytes: 0,
354            connection_reused: false,
355            is_websocket_upgrade: false,
356            websocket_inspection_enabled: false,
357            websocket_skip_inspection: false,
358            websocket_inspection_agents: Vec::new(),
359            websocket_handler: None,
360            cache_eligible: false,
361            cache_status: None,
362            body_inspection_enabled: false,
363            body_bytes_inspected: 0,
364            body_buffer: Vec::new(),
365            body_inspection_agents: Vec::new(),
366            decompression_enabled: false,
367            body_content_encoding: None,
368            max_decompression_ratio: 100.0,
369            max_decompression_bytes: 10 * 1024 * 1024, // 10MB
370            body_was_decompressed: false,
371            rate_limit_info: None,
372            geo_country_code: None,
373            geo_lookup_performed: false,
374            request_body_streaming_mode: BodyStreamingMode::Buffer,
375            request_body_chunk_index: 0,
376            agent_needs_more: false,
377            response_body_streaming_mode: BodyStreamingMode::Buffer,
378            response_body_chunk_index: 0,
379            response_body_bytes_inspected: 0,
380            response_body_inspection_enabled: false,
381            response_body_inspection_agents: Vec::new(),
382            otel_span: None,
383            trace_context: None,
384            inference_rate_limit_enabled: false,
385            inference_estimated_tokens: 0,
386            inference_rate_limit_key: None,
387            inference_model: None,
388            inference_provider_override: None,
389            model_routing_used: false,
390            inference_actual_tokens: None,
391            inference_budget_enabled: false,
392            inference_budget_remaining: None,
393            inference_budget_period_reset: None,
394            inference_budget_exhausted: false,
395            inference_cost_enabled: false,
396            inference_request_cost: None,
397            inference_input_tokens: 0,
398            inference_output_tokens: 0,
399            inference_streaming_response: false,
400            inference_streaming_counter: None,
401            fallback_attempt: 0,
402            tried_upstreams: Vec::new(),
403            fallback_reason: None,
404            original_upstream: None,
405            model_mapping_applied: None,
406            should_retry_with_fallback: false,
407            guardrails_enabled: false,
408            guardrail_warning: false,
409            guardrail_detection_categories: Vec::new(),
410            pii_detection_categories: Vec::new(),
411            shadow_pending: None,
412            shadow_sent: false,
413            sticky_session_new_assignment: false,
414            sticky_session_set_cookie: None,
415            sticky_target_index: None,
416            listener_keepalive_timeout_secs: None,
417            filter_connect_timeout_secs: None,
418            filter_upstream_timeout_secs: None,
419            cors_origin: None,
420            compress_enabled: false,
421            route_agent_ids: Vec::new(),
422            response_agent_processing_enabled: false,
423            response_agent_body_buffer: Vec::new(),
424            response_agent_body_complete: false,
425        }
426    }
427
428    // === Immutable field accessors ===
429
430    /// Get the request start time.
431    #[inline]
432    pub fn start_time(&self) -> Instant {
433        self.start_time
434    }
435
436    /// Get elapsed duration since request start.
437    #[inline]
438    pub fn elapsed(&self) -> std::time::Duration {
439        self.start_time.elapsed()
440    }
441
442    // === Read-only accessors ===
443
444    /// Get trace_id (alias for backwards compatibility with correlation_id usage).
445    #[inline]
446    pub fn correlation_id(&self) -> &str {
447        &self.trace_id
448    }
449
450    /// Get the trace ID.
451    #[inline]
452    pub fn trace_id(&self) -> &str {
453        &self.trace_id
454    }
455
456    /// Get the route ID, if set.
457    #[inline]
458    pub fn route_id(&self) -> Option<&str> {
459        self.route_id.as_deref()
460    }
461
462    /// Get the upstream ID, if set.
463    #[inline]
464    pub fn upstream(&self) -> Option<&str> {
465        self.upstream.as_deref()
466    }
467
468    /// Get the selected upstream peer address (IP:port), if set.
469    #[inline]
470    pub fn selected_upstream_address(&self) -> Option<&str> {
471        self.selected_upstream_address.as_deref()
472    }
473
474    /// Get the cached route configuration, if set.
475    #[inline]
476    pub fn route_config(&self) -> Option<&Arc<RouteConfig>> {
477        self.route_config.as_ref()
478    }
479
480    /// Get the cached global configuration, if set.
481    #[inline]
482    pub fn global_config(&self) -> Option<&Arc<Config>> {
483        self.config.as_ref()
484    }
485
486    /// Get the service type from cached route config.
487    #[inline]
488    pub fn service_type(&self) -> Option<ServiceType> {
489        self.route_config.as_ref().map(|c| c.service_type.clone())
490    }
491
492    /// Get the number of upstream attempts.
493    #[inline]
494    pub fn upstream_attempts(&self) -> u32 {
495        self.upstream_attempts
496    }
497
498    /// Get the HTTP method.
499    #[inline]
500    pub fn method(&self) -> &str {
501        &self.method
502    }
503
504    /// Get the request path.
505    #[inline]
506    pub fn path(&self) -> &str {
507        &self.path
508    }
509
510    /// Get the query string, if present.
511    #[inline]
512    pub fn query(&self) -> Option<&str> {
513        self.query.as_deref()
514    }
515
516    /// Get the client IP address.
517    #[inline]
518    pub fn client_ip(&self) -> &str {
519        &self.client_ip
520    }
521
522    /// Get the User-Agent header, if present.
523    #[inline]
524    pub fn user_agent(&self) -> Option<&str> {
525        self.user_agent.as_deref()
526    }
527
528    /// Get the Referer header, if present.
529    #[inline]
530    pub fn referer(&self) -> Option<&str> {
531        self.referer.as_deref()
532    }
533
534    /// Get the Host header, if present.
535    #[inline]
536    pub fn host(&self) -> Option<&str> {
537        self.host.as_deref()
538    }
539
540    /// Get the response body size in bytes.
541    #[inline]
542    pub fn response_bytes(&self) -> u64 {
543        self.response_bytes
544    }
545
546    /// Get the GeoIP country code, if determined.
547    #[inline]
548    pub fn geo_country_code(&self) -> Option<&str> {
549        self.geo_country_code.as_deref()
550    }
551
552    /// Check if a geo lookup was performed for this request.
553    #[inline]
554    pub fn geo_lookup_performed(&self) -> bool {
555        self.geo_lookup_performed
556    }
557
558    /// Get traceparent header value for distributed tracing.
559    ///
560    /// Returns the W3C Trace Context traceparent header value if tracing is enabled.
561    /// Format: `{version}-{trace-id}-{span-id}-{trace-flags}`
562    #[inline]
563    pub fn traceparent(&self) -> Option<String> {
564        self.otel_span.as_ref().map(|span| {
565            let sampled = self
566                .trace_context
567                .as_ref()
568                .map(|c| c.sampled)
569                .unwrap_or(true);
570            crate::otel::create_traceparent(&span.trace_id, &span.span_id, sampled)
571        })
572    }
573
574    // === Mutation helpers ===
575
576    /// Set the trace ID.
577    #[inline]
578    pub fn set_trace_id(&mut self, trace_id: impl Into<String>) {
579        self.trace_id = trace_id.into();
580    }
581
582    /// Set the route ID.
583    #[inline]
584    pub fn set_route_id(&mut self, route_id: impl Into<String>) {
585        self.route_id = Some(route_id.into());
586    }
587
588    /// Set the upstream ID.
589    #[inline]
590    pub fn set_upstream(&mut self, upstream: impl Into<String>) {
591        self.upstream = Some(upstream.into());
592    }
593
594    /// Set the selected upstream peer address (IP:port).
595    #[inline]
596    pub fn set_selected_upstream_address(&mut self, address: impl Into<String>) {
597        self.selected_upstream_address = Some(address.into());
598    }
599
600    /// Increment upstream attempt counter.
601    #[inline]
602    pub fn inc_upstream_attempts(&mut self) {
603        self.upstream_attempts += 1;
604    }
605
606    /// Set response bytes.
607    #[inline]
608    pub fn set_response_bytes(&mut self, bytes: u64) {
609        self.response_bytes = bytes;
610    }
611
612    // === Fallback accessors ===
613
614    /// Get the current fallback attempt number (0 = primary).
615    #[inline]
616    pub fn fallback_attempt(&self) -> u32 {
617        self.fallback_attempt
618    }
619
620    /// Get the list of upstreams that have been tried.
621    #[inline]
622    pub fn tried_upstreams(&self) -> &[String] {
623        &self.tried_upstreams
624    }
625
626    /// Get the fallback reason, if fallback was triggered.
627    #[inline]
628    pub fn fallback_reason(&self) -> Option<&FallbackReason> {
629        self.fallback_reason.as_ref()
630    }
631
632    /// Get the original upstream ID (before fallback).
633    #[inline]
634    pub fn original_upstream(&self) -> Option<&str> {
635        self.original_upstream.as_deref()
636    }
637
638    /// Get the model mapping that was applied: (original, mapped).
639    #[inline]
640    pub fn model_mapping_applied(&self) -> Option<&(String, String)> {
641        self.model_mapping_applied.as_ref()
642    }
643
644    /// Check if fallback was used for this request.
645    #[inline]
646    pub fn used_fallback(&self) -> bool {
647        self.fallback_attempt > 0
648    }
649
650    /// Record that a fallback attempt is being made.
651    #[inline]
652    pub fn record_fallback(&mut self, reason: FallbackReason, new_upstream: &str) {
653        if self.fallback_attempt == 0 {
654            // First fallback - save original upstream
655            self.original_upstream = self.upstream.clone();
656        }
657        self.fallback_attempt += 1;
658        self.fallback_reason = Some(reason);
659        if let Some(current) = &self.upstream {
660            self.tried_upstreams.push(current.clone());
661        }
662        self.upstream = Some(new_upstream.to_string());
663    }
664
665    /// Record model mapping applied during fallback.
666    #[inline]
667    pub fn record_model_mapping(&mut self, original: String, mapped: String) {
668        self.model_mapping_applied = Some((original, mapped));
669    }
670
671    // === Model Routing accessors ===
672
673    /// Check if model-based routing was used to select the upstream.
674    #[inline]
675    pub fn used_model_routing(&self) -> bool {
676        self.model_routing_used
677    }
678
679    /// Get the provider override from model-based routing (if any).
680    #[inline]
681    pub fn inference_provider_override(&self) -> Option<grapsus_config::InferenceProvider> {
682        self.inference_provider_override
683    }
684
685    /// Record model-based routing result.
686    ///
687    /// Called when model-based routing selects an upstream based on the model name.
688    #[inline]
689    pub fn record_model_routing(
690        &mut self,
691        upstream: &str,
692        model: Option<String>,
693        provider_override: Option<grapsus_config::InferenceProvider>,
694    ) {
695        self.upstream = Some(upstream.to_string());
696        self.model_routing_used = true;
697        if model.is_some() {
698            self.inference_model = model;
699        }
700        self.inference_provider_override = provider_override;
701    }
702}
703
704impl Default for RequestContext {
705    fn default() -> Self {
706        Self::new()
707    }
708}
709
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limit_header_info() {
        let info = RateLimitHeaderInfo {
            limit: 100,
            remaining: 42,
            reset_at: 1704067200,
        };

        assert_eq!(info.limit, 100);
        assert_eq!(info.remaining, 42);
        assert_eq!(info.reset_at, 1704067200);
    }

    #[test]
    fn test_request_context_default() {
        let ctx = RequestContext::new();

        assert!(ctx.trace_id.is_empty());
        assert!(ctx.rate_limit_info.is_none());
        assert!(ctx.route_id.is_none());
        assert!(ctx.config.is_none());
    }

    /// `Default::default()` must produce the same empty state as `new()`.
    #[test]
    fn test_default_trait_matches_new() {
        let ctx = RequestContext::default();

        assert!(ctx.trace_id().is_empty());
        assert!(ctx.route_id().is_none());
        assert!(ctx.upstream().is_none());
        assert!(ctx.selected_upstream_address().is_none());
        assert_eq!(ctx.upstream_attempts(), 0);
        assert_eq!(ctx.fallback_attempt(), 0);
        assert!(!ctx.used_fallback());
        assert!(!ctx.used_model_routing());
    }

    /// Pins the decompression limits a fresh context starts with.
    #[test]
    fn test_decompression_limit_defaults() {
        let ctx = RequestContext::new();

        assert!(!ctx.decompression_enabled);
        assert!(!ctx.body_was_decompressed);
        assert!(ctx.body_content_encoding.is_none());
        assert!((ctx.max_decompression_ratio - 100.0).abs() < f64::EPSILON);
        assert_eq!(ctx.max_decompression_bytes, 10 * 1024 * 1024);
    }

    #[test]
    fn test_request_context_rate_limit_info() {
        let mut ctx = RequestContext::new();

        // Initially no rate limit info
        assert!(ctx.rate_limit_info.is_none());

        // Set rate limit info
        ctx.rate_limit_info = Some(RateLimitHeaderInfo {
            limit: 50,
            remaining: 25,
            reset_at: 1704067300,
        });

        let info = ctx
            .rate_limit_info
            .as_ref()
            .expect("rate limit info was just set");
        assert_eq!(info.limit, 50);
        assert_eq!(info.remaining, 25);
        assert_eq!(info.reset_at, 1704067300);
    }

    #[test]
    fn test_request_context_elapsed() {
        let ctx = RequestContext::new();

        // Elapsed time should be very small (less than 1 second)
        let elapsed = ctx.elapsed();
        assert!(elapsed.as_secs() < 1);
        assert!(ctx.start_time() <= std::time::Instant::now());
    }

    #[test]
    fn test_request_context_setters() {
        let mut ctx = RequestContext::new();

        ctx.set_trace_id("trace-123");
        assert_eq!(ctx.trace_id(), "trace-123");
        assert_eq!(ctx.correlation_id(), "trace-123");

        ctx.set_route_id("my-route");
        assert_eq!(ctx.route_id(), Some("my-route"));

        ctx.set_upstream("backend-pool");
        assert_eq!(ctx.upstream(), Some("backend-pool"));

        ctx.set_selected_upstream_address("10.0.0.5:8080");
        assert_eq!(ctx.selected_upstream_address(), Some("10.0.0.5:8080"));

        ctx.inc_upstream_attempts();
        ctx.inc_upstream_attempts();
        assert_eq!(ctx.upstream_attempts(), 2);

        ctx.set_response_bytes(1024);
        assert_eq!(ctx.response_bytes(), 1024);
    }

    /// Read-only accessors mirror the underlying `pub(crate)` fields.
    #[test]
    fn test_request_metadata_accessors() {
        let mut ctx = RequestContext::new();

        assert_eq!(ctx.method(), "");
        assert_eq!(ctx.path(), "");
        assert_eq!(ctx.client_ip(), "");
        assert!(ctx.query().is_none());
        assert!(ctx.user_agent().is_none());
        assert!(ctx.referer().is_none());
        assert!(ctx.host().is_none());
        assert!(ctx.geo_country_code().is_none());
        assert!(!ctx.geo_lookup_performed());

        ctx.method = "GET".to_string();
        ctx.path = "/v1/chat".to_string();
        ctx.query = Some("stream=true".to_string());
        ctx.client_ip = "10.0.0.1".to_string();
        ctx.user_agent = Some("curl/8.0".to_string());
        ctx.referer = Some("https://example.com/".to_string());
        ctx.host = Some("api.example.com".to_string());
        ctx.geo_country_code = Some("DE".to_string());
        ctx.geo_lookup_performed = true;

        assert_eq!(ctx.method(), "GET");
        assert_eq!(ctx.path(), "/v1/chat");
        assert_eq!(ctx.query(), Some("stream=true"));
        assert_eq!(ctx.client_ip(), "10.0.0.1");
        assert_eq!(ctx.user_agent(), Some("curl/8.0"));
        assert_eq!(ctx.referer(), Some("https://example.com/"));
        assert_eq!(ctx.host(), Some("api.example.com"));
        assert_eq!(ctx.geo_country_code(), Some("DE"));
        assert!(ctx.geo_lookup_performed());
    }

    #[test]
    fn test_fallback_tracking() {
        let mut ctx = RequestContext::new();

        // Initially no fallback
        assert_eq!(ctx.fallback_attempt(), 0);
        assert!(!ctx.used_fallback());
        assert!(ctx.tried_upstreams().is_empty());
        assert!(ctx.fallback_reason().is_none());
        assert!(ctx.original_upstream().is_none());

        // Set initial upstream
        ctx.set_upstream("openai-primary");

        // Record first fallback
        ctx.record_fallback(FallbackReason::HealthCheckFailed, "anthropic-fallback");

        assert_eq!(ctx.fallback_attempt(), 1);
        assert!(ctx.used_fallback());
        assert_eq!(ctx.tried_upstreams(), &["openai-primary".to_string()]);
        assert!(matches!(
            ctx.fallback_reason(),
            Some(FallbackReason::HealthCheckFailed)
        ));
        assert_eq!(ctx.original_upstream(), Some("openai-primary"));
        assert_eq!(ctx.upstream(), Some("anthropic-fallback"));

        // Record second fallback
        ctx.record_fallback(FallbackReason::ErrorCode(503), "local-gpu");

        assert_eq!(ctx.fallback_attempt(), 2);
        assert_eq!(
            ctx.tried_upstreams(),
            &[
                "openai-primary".to_string(),
                "anthropic-fallback".to_string()
            ]
        );
        assert!(matches!(
            ctx.fallback_reason(),
            Some(FallbackReason::ErrorCode(503))
        ));
        // Original upstream should still be the first one
        assert_eq!(ctx.original_upstream(), Some("openai-primary"));
        assert_eq!(ctx.upstream(), Some("local-gpu"));
    }

    /// A fallback recorded before any upstream was chosen has nothing to add to
    /// the tried list and no original upstream.
    #[test]
    fn test_fallback_without_initial_upstream() {
        let mut ctx = RequestContext::new();

        ctx.record_fallback(FallbackReason::BudgetExhausted, "backup-pool");

        assert_eq!(ctx.fallback_attempt(), 1);
        assert!(ctx.used_fallback());
        assert!(ctx.original_upstream().is_none());
        assert!(ctx.tried_upstreams().is_empty());
        assert_eq!(ctx.upstream(), Some("backup-pool"));
        assert!(matches!(
            ctx.fallback_reason(),
            Some(FallbackReason::BudgetExhausted)
        ));
    }

    #[test]
    fn test_model_mapping_tracking() {
        let mut ctx = RequestContext::new();

        assert!(ctx.model_mapping_applied().is_none());

        ctx.record_model_mapping("gpt-4".to_string(), "claude-3-opus".to_string());

        let mapping = ctx
            .model_mapping_applied()
            .expect("model mapping was just recorded");
        assert_eq!(mapping.0, "gpt-4");
        assert_eq!(mapping.1, "claude-3-opus");
    }

    /// `record_model_routing` sets the upstream and flag, and only overwrites the
    /// detected model when a new one is supplied.
    #[test]
    fn test_model_routing_tracking() {
        let mut ctx = RequestContext::new();
        assert!(!ctx.used_model_routing());

        ctx.record_model_routing("gpu-pool", Some("llama-3".to_string()), None);
        assert!(ctx.used_model_routing());
        assert_eq!(ctx.upstream(), Some("gpu-pool"));
        assert_eq!(ctx.inference_model.as_deref(), Some("llama-3"));
        assert!(ctx.inference_provider_override().is_none());

        ctx.record_model_routing("cpu-pool", None, None);
        assert_eq!(ctx.upstream(), Some("cpu-pool"));
        assert_eq!(ctx.inference_model.as_deref(), Some("llama-3"));
    }

    #[test]
    fn test_traceparent_without_span() {
        let ctx = RequestContext::new();
        assert!(ctx.traceparent().is_none());
    }

    #[test]
    fn test_fallback_reason_display() {
        assert_eq!(
            FallbackReason::HealthCheckFailed.to_string(),
            "health_check_failed"
        );
        assert_eq!(
            FallbackReason::BudgetExhausted.to_string(),
            "budget_exhausted"
        );
        assert_eq!(
            FallbackReason::LatencyThreshold {
                observed_ms: 5500,
                threshold_ms: 5000
            }
            .to_string(),
            "latency_threshold_5500ms_exceeded_5000ms"
        );
        assert_eq!(FallbackReason::ErrorCode(502).to_string(), "error_code_502");
        assert_eq!(
            FallbackReason::ConnectionError("timeout".to_string()).to_string(),
            "connection_error_timeout"
        );
    }
}