use std::sync::atomic::Ordering;
use std::sync::Arc;
use synapse_pingora::admin_server::{
register_evaluate_callback, register_profiles_getter, register_schemas_getter, EvaluationResult,
};
use synapse_pingora::{MetricsRegistry, ProfilingMetrics};
/// Recording request bandwidth moves the inbound counters away from zero.
#[test]
fn test_bandwidth_api_returns_nonzero_after_request_recording() {
    let registry = MetricsRegistry::new();

    // A fresh registry reports empty bandwidth counters.
    let initial = registry.get_bandwidth_stats();
    assert_eq!(initial.total_bytes, 0);
    assert_eq!(initial.total_bytes_in, 0);
    assert_eq!(initial.request_count, 0);

    // Two requests: 1024 + 2048 = 3072 inbound bytes.
    for size in [1024, 2048] {
        registry.record_request_bandwidth(size);
    }

    let recorded = registry.get_bandwidth_stats();
    assert!(recorded.total_bytes > 0, "total_bytes should be non-zero");
    assert!(
        recorded.total_bytes_in > 0,
        "total_bytes_in should be non-zero"
    );
    assert_eq!(recorded.total_bytes_in, 3072);
    assert_eq!(recorded.request_count, 2);
}
/// Recording response bandwidth moves the outbound counters away from zero.
#[test]
fn test_bandwidth_api_returns_nonzero_after_response_recording() {
    let registry = MetricsRegistry::new();

    let initial = registry.get_bandwidth_stats();
    assert_eq!(initial.total_bytes_out, 0);

    // Two responses: 4096 + 8192 = 12288 outbound bytes.
    for size in [4096, 8192] {
        registry.record_response_bandwidth(size);
    }

    let recorded = registry.get_bandwidth_stats();
    assert!(
        recorded.total_bytes_out > 0,
        "total_bytes_out should be non-zero"
    );
    assert_eq!(recorded.total_bytes_out, 12288);
    assert!(recorded.total_bytes > 0, "total_bytes should be non-zero");
}
/// Interleaved request/response recordings are summed into the aggregate stats.
#[test]
fn test_bandwidth_api_aggregation() {
    let registry = MetricsRegistry::new();

    // (request bytes, response bytes) for three request/response pairs.
    for (inbound, outbound) in [(100, 500), (200, 1000), (300, 1500)] {
        registry.record_request_bandwidth(inbound);
        registry.record_response_bandwidth(outbound);
    }

    let stats = registry.get_bandwidth_stats();
    assert_eq!(stats.total_bytes_in, 600);
    assert_eq!(stats.total_bytes_out, 3000);
    assert_eq!(stats.total_bytes, 3600);
    assert_eq!(stats.request_count, 3);
    // 3600 total bytes over 3 requests.
    assert_eq!(stats.avg_bytes_per_request, 1200);
}
/// The largest request and response sizes seen are reported as the maxima,
/// even when the largest value is neither first nor last.
#[test]
fn test_bandwidth_api_max_size_tracking() {
    let registry = MetricsRegistry::new();

    for request_size in [100, 500, 250] {
        registry.record_request_bandwidth(request_size);
    }
    for response_size in [1000, 3000, 2000] {
        registry.record_response_bandwidth(response_size);
    }

    let stats = registry.get_bandwidth_stats();
    assert_eq!(stats.max_request_size, 500);
    assert_eq!(stats.max_response_size, 3000);
}
/// After `reset()`, the registry reports no endpoint statistics.
///
/// NOTE(review): despite the test name, only endpoint stats are checked after
/// the reset; whether `reset()` also clears the bandwidth counters is not
/// asserted here — confirm the intended semantics before tightening this.
#[test]
fn test_bandwidth_api_after_reset() {
    let registry = MetricsRegistry::new();
    registry.record_request_bandwidth(1000);
    registry.record_response_bandwidth(2000);

    // Sanity check: there is something recorded before the reset.
    let before_reset = registry.get_bandwidth_stats();
    assert!(before_reset.total_bytes > 0);

    registry.reset();

    assert!(registry.get_endpoint_stats().is_empty());
}
/// Ten threads each record 100 request/response pairs against one shared
/// registry; every request must be counted.
#[test]
fn test_bandwidth_api_concurrent_recording() {
    use std::thread;

    let registry = Arc::new(MetricsRegistry::new());
    let mut handles = vec![];

    for i in 0..10u64 {
        // Each worker gets its own handle to the shared registry.
        let reg = Arc::clone(&registry);
        handles.push(thread::spawn(move || {
            for _ in 0..100 {
                reg.record_request_bandwidth(i * 10 + 1);
                reg.record_response_bandwidth(i * 10 + 2);
            }
        }));
    }

    for handle in handles {
        handle.join().unwrap();
    }

    let stats = registry.get_bandwidth_stats();
    // 10 threads x 100 requests each.
    assert_eq!(stats.request_count, 1000);
    // NOTE(review): if the counters are guaranteed lossless under contention,
    // the exact totals (46_000 in / 47_000 out) could be asserted — confirm.
    assert!(stats.total_bytes_in > 0);
    assert!(stats.total_bytes_out > 0);
}
/// Recording the same path with different methods yields a single endpoint
/// entry whose hit count and method set cover both recordings.
#[test]
fn test_endpoint_recording_integration() {
    let registry = MetricsRegistry::new();
    registry.record_endpoint("/api/users", "GET");
    registry.record_endpoint("/api/users", "POST");
    registry.record_endpoint("/api/products", "GET");
    registry.record_endpoint("/api/users/123", "GET");

    // Four recordings, but "/api/users" appears twice: three distinct paths.
    let stats = registry.get_endpoint_stats();
    assert_eq!(stats.len(), 3);

    // A missing entry fails the test with a descriptive message.
    let (_, users) = stats
        .iter()
        .find(|(path, _)| path == "/api/users")
        .expect("endpoint stats should contain an entry for /api/users");
    assert_eq!(users.hit_count, 2);
    assert_eq!(users.methods.len(), 2);
}
/// `reset()` empties previously recorded endpoint statistics.
#[test]
fn test_endpoint_stats_cleared_on_reset() {
    let registry = MetricsRegistry::new();

    registry.record_endpoint("/test/path", "GET");
    let before_reset = registry.get_endpoint_stats();
    assert!(!before_reset.is_empty());

    registry.reset();
    let after_reset = registry.get_endpoint_stats();
    assert!(after_reset.is_empty());
}
/// Registers an evaluation callback that flags URIs containing simple SQL
/// injection markers. The test makes no assertions; it passes as long as the
/// registration call itself does not panic.
#[test]
fn test_evaluate_callback_registration() {
    register_evaluate_callback(|_method, uri, _headers, _body, _client_ip| {
        // Naive marker check: a quote, the keyword "OR", or a SQL comment.
        let is_sqli = ["'", "OR", "--"]
            .iter()
            .any(|marker| uri.contains(*marker));

        if is_sqli {
            EvaluationResult {
                blocked: true,
                risk_score: 85,
                matched_rules: vec![942100],
                block_reason: Some("SQL Injection detected".to_string()),
                detection_time_us: 100,
            }
        } else {
            EvaluationResult {
                blocked: false,
                risk_score: 10,
                matched_rules: vec![],
                block_reason: None,
                detection_time_us: 100,
            }
        }
    });
}
/// Registers an evaluation callback twice in a row. The test makes no
/// assertions; it passes as long as neither registration call panics.
#[test]
fn test_evaluate_callback_reregistration() {
    // First registration: a callback that never blocks.
    register_evaluate_callback(|_method, _uri, _headers, _body, _client_ip| {
        EvaluationResult {
            blocked: false,
            risk_score: 0,
            matched_rules: vec![],
            block_reason: None,
            detection_time_us: 50,
        }
    });

    // Second registration: a callback that always blocks.
    register_evaluate_callback(|_method, _uri, _headers, _body, _client_ip| {
        EvaluationResult {
            blocked: true,
            risk_score: 100,
            matched_rules: vec![999],
            block_reason: Some("Always block".to_string()),
            detection_time_us: 10,
        }
    });
}
/// Registers a profiles getter that returns an empty list. The test makes no
/// assertions; it passes as long as the registration call does not panic.
#[test]
fn test_profiles_getter_registration() {
    register_profiles_getter(Vec::new);
}
/// Registers a schemas getter that returns an empty list. The test makes no
/// assertions; it passes as long as the registration call does not panic.
#[test]
fn test_schemas_getter_registration() {
    register_schemas_getter(Vec::new);
}
/// Records endpoint hits and bandwidth for a small batch of requests, then
/// checks the endpoint count and the aggregated bandwidth figures.
#[test]
fn test_full_profiling_workflow() {
    let registry = MetricsRegistry::new();

    // (path, method, request bytes, response bytes)
    let traffic = [
        ("/api/users", "GET", 50, 1500),
        ("/api/users", "POST", 200, 50),
        ("/api/products", "GET", 30, 5000),
        ("/api/products/123", "GET", 20, 2000),
        ("/api/orders", "POST", 500, 100),
    ];

    for (path, method, inbound, outbound) in traffic {
        registry.record_endpoint(path, method);
        registry.record_request_bandwidth(inbound);
        registry.record_response_bandwidth(outbound);
    }

    // "/api/users" is hit twice, so five requests map to four paths.
    let endpoint_stats = registry.get_endpoint_stats();
    assert_eq!(endpoint_stats.len(), 4);

    let bw_stats = registry.get_bandwidth_stats();
    assert_eq!(bw_stats.request_count, 5);
    // 50 + 200 + 30 + 20 + 500
    assert_eq!(bw_stats.total_bytes_in, 800);
    // 1500 + 50 + 5000 + 2000 + 100
    assert_eq!(bw_stats.total_bytes_out, 8650);
    assert_eq!(bw_stats.max_request_size, 500);
    assert_eq!(bw_stats.max_response_size, 5000);
}
/// Records 1000 requests cycling over five paths and two methods, then checks
/// the endpoint count and that bandwidth totals were accumulated.
#[test]
fn test_profiling_under_load() {
    // Paths are picked round-robin by request index.
    const PATHS: [&str; 5] = [
        "/api/users",
        "/api/products",
        "/api/orders",
        "/api/auth/login",
        "/api/health",
    ];

    let registry = MetricsRegistry::new();

    for i in 0..1000usize {
        let path = PATHS[i % PATHS.len()];
        // Every third request is a POST; the rest are GETs.
        let method = if i % 3 == 0 { "POST" } else { "GET" };

        registry.record_endpoint(path, method);
        registry.record_request_bandwidth((i % 100 + 1) as u64);
        registry.record_response_bandwidth((i % 500 + 1) as u64);
    }

    let endpoint_stats = registry.get_endpoint_stats();
    assert_eq!(endpoint_stats.len(), 5);

    let bw_stats = registry.get_bandwidth_stats();
    assert_eq!(bw_stats.request_count, 1000);
    assert!(bw_stats.total_bytes > 0);
    assert!(bw_stats.avg_bytes_per_request > 0);
}
/// Exercises a shared `ProfilingMetrics` from four threads at once: three
/// writers (request bytes, response bytes, endpoints) and one reader.
#[test]
fn test_profiling_metrics_thread_safety() {
    use std::thread;

    let metrics = Arc::new(ProfilingMetrics::default());

    // Writer: 500 request-size recordings.
    let request_writer = {
        let shared = Arc::clone(&metrics);
        thread::spawn(move || {
            for i in 0..500 {
                shared.record_request_bytes(i * 10);
            }
        })
    };

    // Writer: 500 response-size recordings.
    let response_writer = {
        let shared = Arc::clone(&metrics);
        thread::spawn(move || {
            for i in 0..500 {
                shared.record_response_bytes(i * 20);
            }
        })
    };

    // Writer: 500 endpoint hits spread over ten paths.
    let endpoint_writer = {
        let shared = Arc::clone(&metrics);
        thread::spawn(move || {
            for i in 0..500 {
                let path = format!("/api/endpoint{}", i % 10);
                shared.record_endpoint(&path, "GET");
            }
        })
    };

    // Reader: takes snapshots while the writers are running.
    let reader = {
        let shared = Arc::clone(&metrics);
        thread::spawn(move || {
            for _ in 0..100 {
                let _ = shared.get_bandwidth_stats();
                let _ = shared.get_endpoint_stats();
            }
        })
    };

    for worker in [request_writer, response_writer, endpoint_writer, reader] {
        worker.join().unwrap();
    }

    assert_eq!(metrics.bandwidth_request_count.load(Ordering::Relaxed), 500);
    assert!(metrics.total_bytes_in.load(Ordering::Relaxed) > 0);
    assert!(metrics.total_bytes_out.load(Ordering::Relaxed) > 0);
}
/// Records profile metrics (active profile count plus a list of anomalies) and
/// checks that the rendered Prometheus output mentions both.
#[test]
fn test_anomaly_recording_integration() {
    let registry = MetricsRegistry::new();

    // (anomaly type, score); "sql_injection" appears twice on purpose.
    let raw_anomalies = [
        ("sql_injection", 8.5),
        ("xss_attempt", 6.0),
        ("rate_limit_exceeded", 4.0),
        ("sql_injection", 9.0),
    ];

    // Convert the type names to owned strings before passing them on.
    let owned_anomalies: Vec<_> = raw_anomalies
        .iter()
        .map(|&(kind, score)| (kind.to_string(), score))
        .collect();
    registry.record_profile_metrics(10, &owned_anomalies);

    let rendered = registry.render_prometheus();
    assert!(rendered.contains("synapse_profiles_active_count 10"));
    assert!(rendered.contains("synapse_anomalies_detected_total"));
}
/// A zero-byte request/response pair still counts as one request, with zero
/// total and zero average bytes.
#[test]
fn test_zero_byte_requests() {
    let registry = MetricsRegistry::new();

    registry.record_request_bandwidth(0);
    registry.record_response_bandwidth(0);

    let snapshot = registry.get_bandwidth_stats();
    assert_eq!(snapshot.total_bytes, 0);
    assert_eq!(snapshot.request_count, 1);
    assert_eq!(snapshot.avg_bytes_per_request, 0);
}
/// Gigabyte-sized request and response values are recorded and summed exactly.
#[test]
fn test_large_bandwidth_values() {
    // 2^30 bytes = 1 GiB.
    let one_gb: u64 = 1 << 30;

    let registry = MetricsRegistry::new();
    registry.record_request_bandwidth(one_gb);
    registry.record_response_bandwidth(one_gb);

    let snapshot = registry.get_bandwidth_stats();
    assert_eq!(snapshot.total_bytes_in, one_gb);
    assert_eq!(snapshot.total_bytes_out, one_gb);
    assert_eq!(snapshot.total_bytes, 2 * one_gb);
}
/// With exactly one request/response pair, the average bytes per request
/// equals the total bytes.
#[test]
fn test_single_request_stats() {
    let registry = MetricsRegistry::new();

    registry.record_request_bandwidth(100);
    registry.record_response_bandwidth(500);

    let snapshot = registry.get_bandwidth_stats();
    assert_eq!(snapshot.request_count, 1);
    // 100 in + 500 out.
    assert_eq!(snapshot.total_bytes, 600);
    assert_eq!(snapshot.avg_bytes_per_request, 600);
}
/// Five textually different paths produce five endpoint entries, so a trailing
/// slash ("/api" vs "/api/") and a query string are expected to stay distinct.
#[test]
fn test_endpoint_path_handling() {
    let registry = MetricsRegistry::new();

    let paths = [
        "/",
        "/api",
        "/api/",
        "/api/users/123/profile",
        "/api/users?id=1",
    ];
    for path in paths {
        registry.record_endpoint(path, "GET");
    }

    let stats = registry.get_endpoint_stats();
    assert_eq!(stats.len(), 5);
}