use crate::behaviors::ResponseBehaviors;
use crate::config::{FaultConfig, TcpFault};
use crate::response::builder::ErrorResponseBuilder;
use http_body_util::Full;
use hyper::body::Bytes;
use hyper::header::{HeaderName, CONTENT_LENGTH, CONTENT_TYPE, TRANSFER_ENCODING};
use hyper::http::HeaderValue;
use hyper::{HeaderMap, Response, StatusCode};
use rand::Rng;
use std::collections::HashMap;
use std::time::Duration;
#[derive(Debug, Clone)]
pub enum FaultDecision {
None,
Latency {
duration_ms: u64,
rule_id: String,
},
Error {
status: u16,
body: String,
rule_id: String,
headers: HashMap<String, String>,
behaviors: Option<ResponseBehaviors>,
},
TcpFault {
fault_type: TcpFault,
rule_id: String,
},
}
pub fn decide_fault(fault_config: &FaultConfig, rule_id: &str) -> FaultDecision {
let mut rng = rand::thread_rng();
if let Some(tcp_fault) = &fault_config.tcp_fault {
return FaultDecision::TcpFault {
fault_type: *tcp_fault,
rule_id: rule_id.to_string(),
};
}
if let Some(error_fault) = &fault_config.error {
if should_inject(error_fault.probability, &mut rng) {
return FaultDecision::Error {
status: error_fault.status,
body: error_fault.body.clone(),
rule_id: rule_id.to_string(),
headers: error_fault.headers.clone(),
behaviors: error_fault.behaviors.clone(),
};
}
}
if let Some(latency_fault) = &fault_config.latency {
if should_inject(latency_fault.probability, &mut rng) {
let duration_ms = rng.gen_range(latency_fault.min_ms..=latency_fault.max_ms);
return FaultDecision::Latency {
duration_ms,
rule_id: rule_id.to_string(),
};
}
}
FaultDecision::None
}
fn should_inject(probability: f64, rng: &mut impl Rng) -> bool {
rng.gen::<f64>() < probability
}
pub async fn apply_latency(duration_ms: u64) {
tokio::time::sleep(Duration::from_millis(duration_ms)).await;
}
pub fn create_error_response(
status: u16,
body: String,
fixed_headers: Option<&HashMap<String, String>>,
dynamic_headers: Option<&HashMap<String, String>>,
) -> Result<Response<Full<Bytes>>, hyper::Error> {
let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let content_length = body.len();
let mut merged = HeaderMap::new();
[fixed_headers, dynamic_headers]
.iter()
.filter_map(|&opt| opt)
.flat_map(|map| map.iter())
.for_each(|(key, value)| {
if let (Ok(name), Ok(val)) = (
HeaderName::try_from(key.as_str()),
HeaderValue::from_str(value),
) {
merged.insert(name, val);
}
});
merged.remove(TRANSFER_ENCODING);
merged
.entry(CONTENT_TYPE)
.or_insert(HeaderValue::from_static("application/json"));
merged.insert(
CONTENT_LENGTH,
HeaderValue::from_str(&content_length.to_string()).unwrap(),
);
let response = ErrorResponseBuilder::new(status_code)
.merge_headers(&merged)
.body(body)
.build_full();
Ok(response)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{ErrorFault, LatencyFault};
#[test]
fn test_should_inject_always() {
let mut rng = rand::thread_rng();
let mut count = 0;
for _ in 0..100 {
if should_inject(1.0, &mut rng) {
count += 1;
}
}
assert_eq!(count, 100);
}
#[test]
fn test_should_inject_never() {
let mut rng = rand::thread_rng();
let mut count = 0;
for _ in 0..100 {
if should_inject(0.0, &mut rng) {
count += 1;
}
}
assert_eq!(count, 0);
}
#[test]
fn test_should_inject_probability() {
let mut rng = rand::thread_rng();
let mut count = 0;
let iterations = 10000;
let target_probability = 0.3;
for _ in 0..iterations {
if should_inject(target_probability, &mut rng) {
count += 1;
}
}
let actual_probability = count as f64 / iterations as f64;
assert!(
(actual_probability - target_probability).abs() < 0.05,
"Expected ~{target_probability}, got {actual_probability}"
);
}
#[test]
fn test_decide_fault_with_error() {
let fault_config = FaultConfig {
latency: None,
error: Some(ErrorFault {
probability: 1.0,
status: 502,
body: "error".to_string(),
headers: HashMap::new(),
behaviors: None,
}),
tcp_fault: None,
};
let decision = decide_fault(&fault_config, "test-rule");
match decision {
FaultDecision::Error {
status,
body,
rule_id,
headers,
..
} => {
assert_eq!(status, 502);
assert_eq!(body, "error");
assert_eq!(rule_id, "test-rule");
assert!(headers.is_empty()); }
_ => panic!("Expected Error decision"),
}
}
#[test]
fn test_decide_fault_with_latency() {
let fault_config = FaultConfig {
latency: Some(LatencyFault {
probability: 1.0,
min_ms: 100,
max_ms: 200,
}),
error: None,
tcp_fault: None,
};
let decision = decide_fault(&fault_config, "test-rule");
match decision {
FaultDecision::Latency {
duration_ms,
rule_id,
} => {
assert!((100..=200).contains(&duration_ms));
assert_eq!(rule_id, "test-rule");
}
_ => panic!("Expected Latency decision"),
}
}
#[test]
fn test_create_error_response() {
let response =
create_error_response(502, r#"{"error": "test"}"#.to_string(), None, None).unwrap();
assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
assert_eq!(
response.headers().get("content-type").unwrap(),
"application/json"
);
}
#[test]
fn test_create_error_response_with_headers() {
let mut fixed_headers = HashMap::new();
fixed_headers.insert("Server".to_string(), "openresty".to_string());
fixed_headers.insert("X-Custom".to_string(), "fixed-value".to_string());
let mut dynamic_headers = HashMap::new();
dynamic_headers.insert("X-Custom".to_string(), "dynamic-value".to_string());
dynamic_headers.insert("X-Dynamic".to_string(), "new-header".to_string());
let response = create_error_response(
502,
r#"{"error": "test"}"#.to_string(),
Some(&fixed_headers),
Some(&dynamic_headers),
)
.unwrap();
assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
assert_eq!(response.headers().get("x-custom").unwrap(), "dynamic-value");
assert_eq!(response.headers().get("server").unwrap(), "openresty");
assert_eq!(response.headers().get("x-dynamic").unwrap(), "new-header");
assert!(response.headers().get("content-length").is_some());
}
#[test]
fn test_dynamic_headers_override_fixed_headers() {
let mut fixed_headers = HashMap::new();
fixed_headers.insert("X-Override-Me".to_string(), "fixed-value".to_string());
let mut dynamic_headers = HashMap::new();
dynamic_headers.insert("X-Override-Me".to_string(), "dynamic-value".to_string());
let response = create_error_response(
500,
"test body".to_string(),
Some(&fixed_headers),
Some(&dynamic_headers),
)
.unwrap();
let header_value = response
.headers()
.get("x-override-me")
.expect("Header should exist");
assert_eq!(
header_value, "dynamic-value",
"Dynamic header should override fixed header with the same key"
);
assert_ne!(
header_value, "fixed-value",
"Fixed header value should have been overwritten"
);
}
}