use super::forwarding::{error_response, forward_request_with_body, forward_with_recording};
use super::headers::{
RiftHeadersExt, VALUE_ERROR, VALUE_LATENCY, VALUE_TCP, VALUE_TRUE, X_RIFT_BEHAVIOR_COPY,
X_RIFT_BEHAVIOR_DECORATE, X_RIFT_BEHAVIOR_LOOKUP, X_RIFT_BEHAVIOR_SHELL, X_RIFT_BEHAVIOR_WAIT,
X_RIFT_FAULT, X_RIFT_LATENCY_MS, X_RIFT_RULE_ID, X_RIFT_SCRIPT, X_RIFT_TCP_FAULT,
};
use super::response_ext::ResponseExt;
use crate::behaviors::{
apply_copy_behaviors, apply_decorate, apply_lookup_behaviors, apply_shell_transform,
RequestContext,
};
use crate::config::TcpFault;
use crate::extensions::fault::{apply_latency, create_error_response, decide_fault, FaultDecision};
use crate::extensions::matcher::CompiledRule;
use crate::extensions::metrics;
use crate::extensions::routing::Router;
use crate::extensions::template::{has_template_variables, process_template, RequestData};
use crate::proxy::context::{
ForwardingContext, RequestHandlerContext, RequestInfo, ScriptingContext, UpstreamService,
};
use crate::scripting::{CacheKey, FaultDecision as ScriptFaultDecision, ScriptRequest};
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use hyper::body::Bytes;
use hyper::{Request, Response};
use std::collections::HashMap;
use std::convert::Infallible;
use std::sync::Arc;
use tracing::{debug, error, info, warn};
pub async fn handle_request(
ctx: &RequestHandlerContext<'_>,
req: Request<hyper::body::Incoming>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Infallible> {
let start_time = std::time::Instant::now();
let method = req.method().clone();
let uri = req.uri().clone();
let headers = req.headers().clone();
debug!("Received request: {} {}", method, uri);
let upstream = select_upstream(ctx.router, ctx.upstreams, &req)
.map(|(url, name)| UpstreamService {
url: Some(url),
name: Some(name),
})
.unwrap_or_default();
let req = if let (Some(compiled_scripts), Some(script_pool), Some(decision_cache)) =
(ctx.compiled_scripts, ctx.script_pool, ctx.decision_cache)
{
let scripting = ScriptingContext {
compiled_scripts,
script_pool,
decision_cache,
};
match handle_script_rules(ctx, &scripting, req, &upstream, start_time).await {
RuleHandlingResult::Response(response) => return Ok(response),
RuleHandlingResult::NoFault(req) => req,
}
} else {
req
};
let matched_rule_index = ctx
.compiled_rules
.iter()
.enumerate()
.find(|(idx, rule)| {
rule.matches(&method, &uri, &headers)
&& rule_applies_to_upstream(&ctx.rule_upstreams[*idx], upstream.name.as_deref())
})
.map(|(idx, _)| idx);
if let Some(rule_idx) = matched_rule_index {
let rule = &ctx.compiled_rules[rule_idx];
info!("Request matched rule: {}", rule.id);
match handle_yaml_rule(ctx, rule, req, upstream.url.as_deref(), start_time).await {
RuleHandlingResult::Response(response) => return Ok(response),
RuleHandlingResult::NoFault(r) => {
let upstream_url = upstream.url.as_deref().unwrap_or(ctx.upstream_uri);
let response = forward_with_recording(
ctx.http_client,
ctx.recording_store,
ctx.recording_signature_headers,
r,
upstream_url,
)
.await;
let status = response.status().as_u16();
let duration_ms = start_time.elapsed().as_secs_f64() * 1000.0;
metrics::record_proxy_duration(method.as_str(), duration_ms, "none");
metrics::record_request(method.as_str(), status);
return Ok(response);
}
}
}
let upstream_url = upstream.url.as_deref().unwrap_or(ctx.upstream_uri);
let response = forward_with_recording(
ctx.http_client,
ctx.recording_store,
ctx.recording_signature_headers,
req,
upstream_url,
)
.await;
let status = response.status().as_u16();
let duration_ms = start_time.elapsed().as_secs_f64() * 1000.0;
metrics::record_proxy_duration(method.as_str(), duration_ms, "none");
metrics::record_request(method.as_str(), status);
Ok(response)
}
/// Outcome of evaluating a script rule or a YAML rule against a request.
pub enum RuleHandlingResult {
    /// The rule handler produced the final response. This covers an injected
    /// fault, a request the handler forwarded itself, or an error response
    /// (e.g. the request body could not be read).
    Response(Response<BoxBody<Bytes, hyper::Error>>),
    /// No response was produced. The untouched request is handed back so the
    /// caller can continue with the next stage or plain forwarding.
    NoFault(Request<hyper::body::Incoming>),
}
/// Runs the first script rule that matches the request and applies to the
/// selected upstream, and turns its decision into a response.
///
/// When no script rule matches, the untouched request is returned as
/// [`RuleHandlingResult::NoFault`]. Once a script matches, the request body is
/// buffered and a response is always produced, even when the script decides
/// not to inject a fault.
///
/// Decisions are cached only when no flow state is configured
/// (`!ctx.flow_state_configured`). With flow state, the script runs on every
/// request.
async fn handle_script_rules(
    ctx: &RequestHandlerContext<'_>,
    scripting: &ScriptingContext<'_>,
    req: Request<hyper::body::Incoming>,
    upstream: &UpstreamService,
    start_time: std::time::Instant,
) -> RuleHandlingResult {
    let request_info = RequestInfo::from_request(&req);
    let selected_upstream = upstream.name.as_deref();

    let (compiled_script, compiled_rule, _) = match scripting.compiled_scripts.iter().find(
        |(_, candidate, rule_upstream)| {
            candidate.matches(
                &request_info.method,
                &request_info.uri,
                &request_info.headers,
            ) && rule_applies_to_upstream(rule_upstream, selected_upstream)
        },
    ) {
        Some(hit) => hit,
        None => return RuleHandlingResult::NoFault(req),
    };
    info!("Request matched script rule: {}", compiled_rule.id);

    let body_bytes = match req.collect().await {
        Ok(collected) => collected.to_bytes(),
        Err(e) => {
            error!("Failed to collect request body: {}", e);
            return RuleHandlingResult::Response(
                error_response(500, "Failed to read request body").into_boxed(),
            );
        }
    };

    // Headers whose values `to_str` rejects are skipped. For repeated header
    // names, the last value wins.
    let headers_map: HashMap<String, String> = request_info
        .headers
        .iter()
        .filter_map(|(name, value)| {
            value
                .to_str()
                .ok()
                .map(|text| (name.as_str().to_string(), text.to_string()))
        })
        .collect();

    // Bodies that are not valid JSON (including empty ones) reach the script as `null`.
    let body_json: serde_json::Value =
        serde_json::from_slice(&body_bytes).unwrap_or(serde_json::Value::Null);

    let script_request = ScriptRequest {
        method: request_info.method.to_string(),
        path: request_info.uri.path().to_string(),
        headers: headers_map.clone(),
        body: body_json.clone(),
        query: crate::predicate::parse_query_string(request_info.uri.query()),
        path_params: HashMap::new(),
    };

    let cache_key = CacheKey::new(
        request_info.method.to_string(),
        request_info.uri.path().to_string(),
        headers_map.into_iter().collect(),
        &body_json,
        compiled_rule.id.clone(),
    );

    let use_cache = !ctx.flow_state_configured;
    let script_start = std::time::Instant::now();
    let cached_decision = if use_cache {
        scripting.decision_cache.get(&cache_key)
    } else {
        None
    };

    let result = match cached_decision {
        Some(decision) => {
            debug!("Cache hit for rule: {} (stateless)", compiled_rule.id);
            Ok(decision)
        }
        None => {
            if use_cache {
                debug!("Cache miss for rule: {}", compiled_rule.id);
            } else {
                debug!("Executing stateful script (no cache): {}", compiled_rule.id);
            }
            let executed = scripting
                .script_pool
                .execute(
                    compiled_script.clone(),
                    script_request,
                    Arc::clone(ctx.flow_store),
                )
                .await;
            // Only successful decisions from stateless scripts are memoized.
            if use_cache {
                if let Ok(decision) = &executed {
                    let _ = scripting.decision_cache.insert(cache_key, decision.clone());
                }
            }
            executed
        }
    };
    let script_duration = script_start.elapsed().as_secs_f64() * 1000.0;

    let forwarding_ctx = ForwardingContext {
        info: request_info,
        body_bytes,
        start_time,
        upstream_service: upstream.to_owned(),
    };
    let response = handle_script_result(
        ctx,
        result.map_err(|e| e.to_string()),
        compiled_rule,
        &forwarding_ctx,
        script_duration,
    )
    .await;
    RuleHandlingResult::Response(response)
}
async fn handle_script_result(
ctx: &RequestHandlerContext<'_>,
result: Result<ScriptFaultDecision, String>,
compiled_rule: &CompiledRule,
forwarding_ctx: &ForwardingContext,
script_duration: f64,
) -> Response<BoxBody<Bytes, hyper::Error>> {
let request_info = &forwarding_ctx.info;
match result {
Ok(ScriptFaultDecision::Error {
status,
body,
rule_id,
headers: script_headers,
}) => {
warn!(
"Script injecting error fault: status={}, rule={}",
status, rule_id
);
metrics::record_script_execution(&rule_id, script_duration, "inject");
metrics::record_script_fault("error", &rule_id, None);
metrics::record_error_injection(&rule_id, status);
let duration_ms = forwarding_ctx.start_time.elapsed().as_secs_f64() * 1000.0;
metrics::record_proxy_duration(request_info.method.as_str(), duration_ms, "script");
metrics::record_request(request_info.method.as_str(), status);
let fixed_headers = ctx
.compiled_rules
.iter()
.enumerate()
.find(|(idx, rule)| {
rule.matches(
&request_info.method,
&request_info.uri,
&request_info.headers,
) && rule_applies_to_upstream(
&ctx.rule_upstreams[*idx],
forwarding_ctx.upstream_service.name.as_deref(),
) && rule.rule.fault.error.is_some()
})
.and_then(|(_, rule)| rule.rule.fault.error.as_ref().map(|e| e.headers.clone()));
let mut response =
create_error_response(status, body, fixed_headers.as_ref(), Some(&script_headers))
.unwrap();
response.set_header(&X_RIFT_FAULT, &VALUE_ERROR);
response.set_header_value(&X_RIFT_RULE_ID, &rule_id);
response.set_header(&X_RIFT_SCRIPT, &VALUE_TRUE);
response.into_boxed()
}
Ok(ScriptFaultDecision::Latency {
duration_ms,
rule_id,
}) => {
info!(
"Script injecting latency fault: {}ms, rule={}",
duration_ms, rule_id
);
metrics::record_script_execution(&rule_id, script_duration, "inject");
metrics::record_script_fault("latency", &rule_id, Some(duration_ms));
apply_latency(duration_ms).await;
let upstream_url = forwarding_ctx
.upstream_service
.url
.as_deref()
.unwrap_or(ctx.upstream_uri);
let mut response = forward_request_with_body(
ctx.http_client,
request_info.method.clone(),
request_info.uri.clone(),
request_info.headers.clone(),
forwarding_ctx.body_bytes.clone(),
upstream_url,
)
.await;
let status = response.status().as_u16();
let total_duration = forwarding_ctx.start_time.elapsed().as_secs_f64() * 1000.0;
metrics::record_proxy_duration(request_info.method.as_str(), total_duration, "script");
metrics::record_request(request_info.method.as_str(), status);
response.set_header(&X_RIFT_FAULT, &VALUE_LATENCY);
response.set_header_value(&X_RIFT_RULE_ID, &rule_id);
response.set_header(&X_RIFT_SCRIPT, &VALUE_TRUE);
response.set_header_value(&X_RIFT_LATENCY_MS, &duration_ms.to_string());
response.into_boxed()
}
Ok(ScriptFaultDecision::None) => {
debug!(
"Script decided not to inject fault for rule: {}",
compiled_rule.id
);
metrics::record_script_execution(&compiled_rule.id, script_duration, "pass");
let upstream_url = forwarding_ctx
.upstream_service
.url
.as_deref()
.unwrap_or(ctx.upstream_uri);
let response = forward_request_with_body(
ctx.http_client,
request_info.method.clone(),
request_info.uri.clone(),
request_info.headers.clone(),
forwarding_ctx.body_bytes.clone(),
upstream_url,
)
.await;
let status = response.status().as_u16();
let duration_ms = forwarding_ctx.start_time.elapsed().as_secs_f64() * 1000.0;
metrics::record_proxy_duration(request_info.method.as_str(), duration_ms, "none");
metrics::record_request(request_info.method.as_str(), status);
response.into_boxed()
}
Err(e) => {
error!(
"Script execution error for rule {}: {}",
compiled_rule.id, e
);
metrics::record_script_execution(&compiled_rule.id, script_duration, "error");
metrics::record_script_error(&compiled_rule.id, "runtime");
let upstream_url = forwarding_ctx
.upstream_service
.url
.as_deref()
.unwrap_or(ctx.upstream_uri);
let response = forward_request_with_body(
ctx.http_client,
request_info.method.clone(),
request_info.uri.clone(),
request_info.headers.clone(),
forwarding_ctx.body_bytes.clone(),
upstream_url,
)
.await;
let status = response.status().as_u16();
let duration_ms = forwarding_ctx.start_time.elapsed().as_secs_f64() * 1000.0;
metrics::record_proxy_duration(request_info.method.as_str(), duration_ms, "none");
metrics::record_request(request_info.method.as_str(), status);
response.into_boxed()
}
}
}
#[allow(clippy::too_many_arguments)]
async fn handle_yaml_rule(
ctx: &RequestHandlerContext<'_>,
rule: &CompiledRule,
req: Request<hyper::body::Incoming>,
selected_upstream_url: Option<&str>,
start_time: std::time::Instant,
) -> RuleHandlingResult {
let fault_decision = decide_fault(&rule.rule.fault, &rule.id);
let request_info = RequestInfo::from_request(&req);
match fault_decision {
FaultDecision::TcpFault {
fault_type,
rule_id,
} => {
warn!("Injecting TCP fault: {:?}, rule={}", fault_type, rule_id);
metrics::record_error_injection(&rule_id, 0);
let duration_ms = start_time.elapsed().as_secs_f64() * 1000.0;
metrics::record_proxy_duration(request_info.method.as_str(), duration_ms, "tcp_fault");
let (status, body) = match fault_type {
TcpFault::ConnectionResetByPeer => {
(502, r#"{"error": "Connection reset by peer"}"#.to_string())
}
TcpFault::RandomDataThenClose => (
502,
r#"{"error": "Connection closed unexpectedly"}"#.to_string(),
),
};
let mut response = create_error_response(status, body, None, None).unwrap();
response.set_header(&X_RIFT_FAULT, &VALUE_TCP);
response.set_header_value(&X_RIFT_RULE_ID, &rule_id);
response.set_header_value(&X_RIFT_TCP_FAULT, &format!("{fault_type:?}").to_lowercase());
RuleHandlingResult::Response(response.into_boxed())
}
FaultDecision::Error {
status,
body,
rule_id,
headers: fault_headers,
behaviors,
} => {
warn!("Injecting error fault: status={}, rule={}", status, rule_id);
if let Some(ref bhvs) = behaviors {
if let Some(ref wait) = bhvs.wait {
let wait_ms = wait.get_duration_ms();
debug!("Applying wait behavior: {}ms", wait_ms);
apply_latency(wait_ms).await;
}
}
metrics::record_error_injection(&rule_id, status);
let duration_ms = start_time.elapsed().as_secs_f64() * 1000.0;
metrics::record_proxy_duration(request_info.method.as_str(), duration_ms, "error");
metrics::record_request(request_info.method.as_str(), status);
let request_context = RequestContext::from_request(
request_info.method.as_str(),
&request_info.uri,
&request_info.headers,
None, );
let mut processed_body = if has_template_variables(&body) {
let request_data = RequestData::new(
request_info.method.as_str(),
request_info.uri.path(),
request_info.uri.query(),
&request_info.headers,
None,
);
process_template(&body, &request_data)
} else {
body
};
let mut response_headers = fault_headers.clone();
if let Some(ref bhvs) = behaviors {
if !bhvs.copy.is_empty() {
debug!("Applying {} copy behaviors", bhvs.copy.len());
processed_body = apply_copy_behaviors(
&processed_body,
&mut response_headers,
&bhvs.copy,
&request_context,
);
}
}
if let Some(ref bhvs) = behaviors {
if !bhvs.lookup.is_empty() {
debug!("Applying {} lookup behaviors", bhvs.lookup.len());
processed_body = apply_lookup_behaviors(
&processed_body,
&mut response_headers,
&bhvs.lookup,
&request_context,
ctx.csv_cache,
);
}
}
if let Some(ref bhvs) = behaviors {
for cmd in &bhvs.shell_transform {
debug!("Applying shell transform: {}", cmd);
match apply_shell_transform(cmd, &request_context, &processed_body, status) {
Ok(transformed) => {
processed_body = transformed;
}
Err(e) => {
warn!("Shell transform failed: {}", e);
}
}
}
}
let mut final_status = status;
if let Some(ref bhvs) = behaviors {
if let Some(ref script) = bhvs.decorate {
debug!("Applying decorate behavior");
match apply_decorate(
script,
&request_context,
&processed_body,
status,
&mut response_headers,
) {
Ok((new_body, new_status)) => {
processed_body = new_body;
final_status = new_status;
}
Err(e) => {
warn!("Decorate behavior failed: {}", e);
}
}
}
}
let mut response =
create_error_response(final_status, processed_body, Some(&response_headers), None)
.unwrap();
response.set_header(&X_RIFT_FAULT, &VALUE_ERROR);
response.set_header_value(&X_RIFT_RULE_ID, &rule_id);
if let Some(ref bhvs) = behaviors {
if bhvs.wait.is_some() {
response.set_header(&X_RIFT_BEHAVIOR_WAIT, &VALUE_TRUE);
}
if !bhvs.copy.is_empty() {
response.set_header(&X_RIFT_BEHAVIOR_COPY, &VALUE_TRUE);
}
if !bhvs.lookup.is_empty() {
response.set_header(&X_RIFT_BEHAVIOR_LOOKUP, &VALUE_TRUE);
}
if !bhvs.shell_transform.is_empty() {
response.set_header(&X_RIFT_BEHAVIOR_SHELL, &VALUE_TRUE);
}
if bhvs.decorate.is_some() {
response.set_header(&X_RIFT_BEHAVIOR_DECORATE, &VALUE_TRUE);
}
}
RuleHandlingResult::Response(response.into_boxed())
}
FaultDecision::Latency {
duration_ms,
rule_id,
} => {
info!(
"Injecting latency fault: {}ms, rule={}",
duration_ms, rule_id
);
metrics::record_latency_injection(&rule_id, duration_ms);
apply_latency(duration_ms).await;
let body_bytes = match req.collect().await {
Ok(collected) => collected.to_bytes(),
Err(e) => {
error!("Failed to collect request body: {}", e);
let mut response = error_response(500, "Failed to read request body");
response.set_header(&X_RIFT_FAULT, &VALUE_LATENCY);
response.set_header_value(&X_RIFT_RULE_ID, &rule_id);
return RuleHandlingResult::Response(response.into_boxed());
}
};
let upstream_url = selected_upstream_url.unwrap_or(ctx.upstream_uri);
let mut response = forward_request_with_body(
ctx.http_client,
request_info.method.clone(),
request_info.uri.clone(),
request_info.headers.clone(),
body_bytes,
upstream_url,
)
.await;
let status = response.status().as_u16();
let total_duration = start_time.elapsed().as_secs_f64() * 1000.0;
metrics::record_proxy_duration(request_info.method.as_str(), total_duration, "latency");
metrics::record_request(request_info.method.as_str(), status);
response.set_header(&X_RIFT_FAULT, &VALUE_LATENCY);
response.set_header_value(&X_RIFT_RULE_ID, &rule_id);
response.set_header_value(&X_RIFT_LATENCY_MS, &duration_ms.to_string());
RuleHandlingResult::Response(response.into_boxed())
}
FaultDecision::None => {
debug!("No fault injected for matched rule: {}", rule.id);
RuleHandlingResult::NoFault(req)
}
}
}
/// Resolves the upstream for `req` using the optional router.
///
/// Returns `(url, name)` of the configured upstream whose name the router
/// selected. Returns `None` when there is no router, the router matches
/// nothing, or the router's choice is not present in `upstreams`.
fn select_upstream<B>(
    router: Option<&Router>,
    upstreams: &[crate::config::Upstream],
    req: &Request<B>,
) -> Option<(String, String)> {
    let upstream_name = match router {
        Some(active_router) => active_router.match_request(req)?,
        None => return None,
    };
    let target = upstreams
        .iter()
        .find(|candidate| candidate.name == upstream_name)?;
    debug!("Routed to upstream: {} ({})", upstream_name, target.url);
    Some((target.url.clone(), upstream_name.to_string()))
}
/// Reports whether a rule with the given upstream filter applies to the
/// selected upstream.
///
/// The rule is rejected only when both a filter and a selected upstream exist
/// and their names differ. A rule without a filter applies everywhere. A
/// filtered rule also applies when no upstream was selected (sidecar mode).
pub fn rule_applies_to_upstream(
    rule_upstream_filter: &Option<String>,
    selected_upstream_name: Option<&str>,
) -> bool {
    match (rule_upstream_filter.as_deref(), selected_upstream_name) {
        (Some(wanted), Some(actual)) => wanted == actual,
        _ => true,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a rule upstream filter naming a single upstream.
    fn only(name: &str) -> Option<String> {
        Some(name.to_string())
    }

    #[test]
    fn test_rule_applies_to_upstream_no_filter() {
        for selected in [None, Some("backend-a"), Some("backend-b")] {
            assert!(rule_applies_to_upstream(&None, selected));
        }
    }

    #[test]
    fn test_rule_applies_to_upstream_sidecar_mode() {
        // No upstream selected: filtered rules still apply.
        for filter in ["backend-a", "backend-b"] {
            assert!(rule_applies_to_upstream(&only(filter), None));
        }
    }

    #[test]
    fn test_rule_applies_to_upstream_matching() {
        assert!(rule_applies_to_upstream(&only("backend-a"), Some("backend-a")));
    }

    #[test]
    fn test_rule_applies_to_upstream_non_matching() {
        for (filter, selected) in [("backend-a", "backend-b"), ("backend-x", "backend-y")] {
            assert!(!rule_applies_to_upstream(&only(filter), Some(selected)));
        }
    }

    #[test]
    fn test_rule_applies_to_upstream_empty_strings() {
        assert!(rule_applies_to_upstream(&only(""), Some("")));
        assert!(!rule_applies_to_upstream(&only("backend"), Some("")));
    }
}