use crate::request_logger::RequestLogEntry;
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub struct VerificationRequest {
pub method: Option<String>,
pub path: Option<String>,
pub query_params: HashMap<String, String>,
pub headers: HashMap<String, String>,
pub body_pattern: Option<String>,
}
/// How many matching requests a verification expects.
///
/// Serialized adjacently tagged: `{"type":"exactly","value":3}` for counted variants and
/// `{"type":"never"}` / `{"type":"at_least_once"}` for unit variants. The count must live
/// under its own `value` key because serde's internally tagged form (`tag` without
/// `content`) cannot serialize a newtype variant wrapping a primitive such as `usize`.
/// With that form, `Exactly`/`AtLeast`/`AtMost` failed to serialize at runtime.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", content = "value", rename_all = "snake_case")]
pub enum VerificationCount {
    /// The number of matches must equal the given count.
    Exactly(usize),
    /// The number of matches must be at least the given count.
    AtLeast(usize),
    /// The number of matches must be at most the given count.
    AtMost(usize),
    /// No request may match.
    Never,
    /// At least one request must match.
    AtLeastOnce,
}
/// Outcome of a verification run, including the log entries that matched.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VerificationResult {
    /// `true` when the expectation was satisfied.
    pub matched: bool,
    /// Number of matching entries collected (the length of `matches`).
    pub count: usize,
    /// The expectation that was checked.
    pub expected: VerificationCount,
    /// The matching log entries.
    pub matches: Vec<RequestLogEntry>,
    /// Human-readable reason for failure; `None` when `matched` is `true`.
    pub error_message: Option<String>,
}
impl VerificationResult {
pub fn success(
count: usize,
expected: VerificationCount,
matches: Vec<RequestLogEntry>,
) -> Self {
Self {
matched: true,
count,
expected,
matches,
error_message: None,
}
}
pub fn failure(
count: usize,
expected: VerificationCount,
matches: Vec<RequestLogEntry>,
error_message: String,
) -> Self {
Self {
matched: false,
count,
expected,
matches,
error_message: Some(error_message),
}
}
}
pub fn matches_verification_pattern(
entry: &RequestLogEntry,
pattern: &VerificationRequest,
) -> bool {
if let Some(ref expected_method) = pattern.method {
if entry.method.to_uppercase() != expected_method.to_uppercase() {
return false;
}
}
if let Some(ref expected_path) = pattern.path {
if !matches_path_pattern(&entry.path, expected_path) {
return false;
}
}
if !pattern.query_params.is_empty() {
for (key, expected_value) in &pattern.query_params {
let found_value = entry.query_params.get(key);
if found_value != Some(expected_value) {
return false;
}
}
}
for (key, expected_value) in &pattern.headers {
let header_key_lower = key.to_lowercase();
let found = entry
.headers
.iter()
.any(|(k, v)| k.to_lowercase() == header_key_lower && v == expected_value);
if !found {
return false;
}
}
if let Some(ref body_pattern) = pattern.body_pattern {
if let Some(body_str) = entry.metadata.get("request_body") {
if !matches_body_pattern(body_str, body_pattern) {
return false;
}
} else {
}
}
true
}
/// Returns `true` when `path` satisfies `pattern`.
///
/// Patterns are interpreted in this order:
/// 1. the exact literal path, or `"*"`, which matches every path;
/// 2. an explicit regular expression when the pattern starts with `^` or ends with `$`.
///    It is matched as written, so the author's anchors decide what is required;
/// 3. a segment wildcard when it contains `*`: `*` matches exactly one segment and `**`
///    matches zero or more;
/// 4. otherwise a regular expression that must match the *whole* path.
///
/// An invalid regular expression never matches.
fn matches_path_pattern(path: &str, pattern: &str) -> bool {
    if pattern == path || pattern == "*" {
        return true;
    }

    // Explicitly anchored patterns are regexes even when they contain `*` (e.g. `^/api/.*$`).
    // Routing them to the segment matcher would make them impossible to match.
    if pattern.starts_with('^') || pattern.ends_with('$') {
        return match Regex::new(pattern) {
            Ok(re) => re.is_match(path),
            Err(_) => false,
        };
    }

    if pattern.contains('*') {
        return matches_wildcard_pattern(path, pattern);
    }

    // Anchor unanchored patterns to the full path. An unanchored search would let a literal
    // like `/api/users` also match `/api/users/42` or `/v2/api/users`, inflating counts.
    match Regex::new(&format!("^(?:{})$", pattern)) {
        Ok(re) => re.is_match(path),
        Err(_) => false,
    }
}
/// Matches `path` against a `/`-separated wildcard `pattern` segment by segment.
///
/// Empty segments (from leading, trailing or doubled slashes) are ignored on both sides.
fn matches_wildcard_pattern(path: &str, pattern: &str) -> bool {
    fn segments(text: &str) -> Vec<&str> {
        text.split('/').filter(|segment| !segment.is_empty()).collect()
    }

    let wanted = segments(pattern);
    let actual = segments(path);
    match_wildcard_segments(&wanted, &actual, 0, 0)
}
/// Recursively matches `pattern_parts[pattern_idx..]` against `path_parts[path_idx..]`.
///
/// `*` consumes exactly one path segment. `**` consumes zero or more segments. Any other
/// pattern segment must equal the corresponding path segment literally. The match succeeds
/// only when both sequences are fully consumed.
fn match_wildcard_segments(
    pattern_parts: &[&str],
    path_parts: &[&str],
    pattern_idx: usize,
    path_idx: usize,
) -> bool {
    // Pattern exhausted: succeed only if every path segment was consumed as well.
    let segment = match pattern_parts.get(pattern_idx) {
        Some(&segment) => segment,
        None => return path_idx == path_parts.len(),
    };
    let next_path = path_parts.get(path_idx);

    match segment {
        "*" => {
            next_path.is_some()
                && match_wildcard_segments(pattern_parts, path_parts, pattern_idx + 1, path_idx + 1)
        }
        // Either `**` matches nothing (advance the pattern), or it swallows one more path
        // segment and stays put so it can keep swallowing.
        "**" => {
            match_wildcard_segments(pattern_parts, path_parts, pattern_idx + 1, path_idx)
                || (next_path.is_some()
                    && match_wildcard_segments(
                        pattern_parts,
                        path_parts,
                        pattern_idx,
                        path_idx + 1,
                    ))
        }
        literal => {
            next_path == Some(&literal)
                && match_wildcard_segments(pattern_parts, path_parts, pattern_idx + 1, path_idx + 1)
        }
    }
}
/// Returns `true` when `body` satisfies `pattern`.
///
/// The pattern matches when it occurs literally anywhere in the body, or when it compiles as
/// a regular expression that matches somewhere in the body (unanchored search). Checking the
/// literal first gives plain-text and JSON snippets consistent "contains" semantics. This
/// holds even when the snippet has regex metacharacters or is not a valid regex at all.
/// Previously an invalid regex had to equal the entire body, while a valid one only had to
/// occur in it.
fn matches_body_pattern(body: &str, pattern: &str) -> bool {
    if body.contains(pattern) {
        return true;
    }
    match Regex::new(pattern) {
        Ok(re) => re.is_match(body),
        Err(_) => false,
    }
}
/// Counts logged requests matching `pattern` and checks the count against `expected`.
///
/// The returned result always carries the matching entries and their count. On failure it
/// also carries a message naming the expectation and the observed count.
pub async fn verify_requests(
    logger: &crate::request_logger::CentralizedRequestLogger,
    pattern: &VerificationRequest,
    expected: VerificationCount,
) -> VerificationResult {
    let mut hits = logger.get_recent_logs(None).await;
    hits.retain(|entry| matches_verification_pattern(entry, pattern));
    let observed = hits.len();

    let satisfied = match expected {
        VerificationCount::Exactly(n) => observed == n,
        VerificationCount::AtLeast(n) => observed >= n,
        VerificationCount::AtMost(n) => observed <= n,
        VerificationCount::Never => observed == 0,
        VerificationCount::AtLeastOnce => observed > 0,
    };

    if satisfied {
        return VerificationResult::success(observed, expected, hits);
    }

    let error_message = format!(
        "Verification failed: expected {:?}, but found {} matching requests",
        expected, observed
    );
    VerificationResult::failure(observed, expected, hits, error_message)
}
/// Verifies that no logged request matches `pattern`.
pub async fn verify_never(
    logger: &crate::request_logger::CentralizedRequestLogger,
    pattern: &VerificationRequest,
) -> VerificationResult {
    let expectation = VerificationCount::Never;
    verify_requests(logger, pattern, expectation).await
}
/// Verifies that at least `min` logged requests match `pattern`.
pub async fn verify_at_least(
    logger: &crate::request_logger::CentralizedRequestLogger,
    pattern: &VerificationRequest,
    min: usize,
) -> VerificationResult {
    let expectation = VerificationCount::AtLeast(min);
    verify_requests(logger, pattern, expectation).await
}
/// Verifies that requests matching `patterns` occurred in the given order.
///
/// Other requests may be interleaved between matches. Each pattern is matched greedily
/// against the earliest remaining log entry after the previous pattern's match. On success
/// `count` equals `patterns.len()`. On failure `count` is the number of patterns matched
/// before the first one that could not be found.
///
/// NOTE(review): the logs are reversed before scanning. This presumably turns a
/// newest-first listing from `get_recent_logs` into chronological order — confirm against
/// `CentralizedRequestLogger`.
pub async fn verify_sequence(
    logger: &crate::request_logger::CentralizedRequestLogger,
    patterns: &[VerificationRequest],
) -> VerificationResult {
    let mut timeline = logger.get_recent_logs(None).await;
    timeline.reverse();

    let mut cursor = 0;
    let mut sequence = Vec::new();
    for pattern in patterns {
        let hit = timeline[cursor..]
            .iter()
            .position(|entry| matches_verification_pattern(entry, pattern));
        match hit {
            Some(offset) => {
                let found_at = cursor + offset;
                sequence.push(timeline[found_at].clone());
                cursor = found_at + 1;
            }
            None => {
                let error_message = format!(
                    "Sequence verification failed: pattern {:?} not found in sequence",
                    pattern
                );
                return VerificationResult::failure(
                    sequence.len(),
                    VerificationCount::Exactly(patterns.len()),
                    sequence,
                    error_message,
                );
            }
        }
    }

    VerificationResult::success(
        sequence.len(),
        VerificationCount::Exactly(patterns.len()),
        sequence,
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request_logger::{create_http_log_entry, CentralizedRequestLogger};
    use std::collections::HashMap;

    /// Builds a logged HTTP entry with fixed status, timing, client and size fields.
    fn create_test_entry(method: &str, path: &str) -> RequestLogEntry {
        create_http_log_entry(
            method,
            path,
            200,
            100,
            Some("127.0.0.1".to_string()),
            Some("test-agent".to_string()),
            HashMap::new(),
            1024,
            None,
        )
    }

    /// Builds a verification pattern constraining only the method and/or path.
    fn request_pattern(method: Option<&str>, path: Option<&str>) -> VerificationRequest {
        VerificationRequest {
            method: method.map(String::from),
            path: path.map(String::from),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_verify_exactly() {
        let logger = CentralizedRequestLogger::new(100);
        for _ in 0..3 {
            logger.log_request(create_test_entry("GET", "/api/users")).await;
        }
        let pattern = request_pattern(Some("GET"), Some("/api/users"));

        let result = verify_requests(&logger, &pattern, VerificationCount::Exactly(3)).await;
        assert!(result.matched);
        assert_eq!(result.count, 3);
    }

    #[tokio::test]
    async fn test_verify_at_least() {
        let logger = CentralizedRequestLogger::new(100);
        for _ in 0..2 {
            logger.log_request(create_test_entry("POST", "/api/orders")).await;
        }
        let pattern = request_pattern(Some("POST"), Some("/api/orders"));

        let exact_minimum = verify_at_least(&logger, &pattern, 2).await;
        assert!(exact_minimum.matched);
        assert_eq!(exact_minimum.count, 2);

        let lower_minimum = verify_at_least(&logger, &pattern, 1).await;
        assert!(lower_minimum.matched);

        let higher_minimum = verify_at_least(&logger, &pattern, 3).await;
        assert!(!higher_minimum.matched);
    }

    #[tokio::test]
    async fn test_verify_never() {
        let logger = CentralizedRequestLogger::new(100);
        logger.log_request(create_test_entry("GET", "/api/users")).await;
        let pattern = request_pattern(Some("DELETE"), Some("/api/users"));

        let result = verify_never(&logger, &pattern).await;
        assert!(result.matched);
        assert_eq!(result.count, 0);
    }

    #[tokio::test]
    async fn test_verify_sequence() {
        let logger = CentralizedRequestLogger::new(100);
        for &(method, path) in &[
            ("POST", "/api/users"),
            ("GET", "/api/users/1"),
            ("PUT", "/api/users/1"),
        ] {
            logger.log_request(create_test_entry(method, path)).await;
        }
        let patterns = vec![
            request_pattern(Some("POST"), Some("/api/users")),
            request_pattern(Some("GET"), Some("/api/users/1")),
        ];

        let result = verify_sequence(&logger, &patterns).await;
        assert!(result.matched);
        assert_eq!(result.count, 2);
    }

    #[test]
    fn test_matches_path_pattern_exact() {
        assert!(matches_path_pattern("/api/users", "/api/users"));
        assert!(!matches_path_pattern("/api/users", "/api/posts"));
    }

    #[test]
    fn test_matches_path_pattern_wildcard() {
        for path in &["/api/users/1", "/api/users/123"] {
            assert!(matches_path_pattern(path, "/api/users/*"));
        }
        assert!(!matches_path_pattern("/api/users/1/posts", "/api/users/*"));
    }

    #[test]
    fn test_matches_path_pattern_double_wildcard() {
        for path in &["/api/users/1", "/api/users/1/posts", "/api/users"] {
            assert!(matches_path_pattern(path, "/api/**"));
        }
    }

    #[test]
    fn test_matches_path_pattern_regex() {
        let numeric_id = r"^/api/users/\d+$";
        assert!(matches_path_pattern("/api/users/123", numeric_id));
        assert!(!matches_path_pattern("/api/users/abc", numeric_id));
    }

    #[test]
    fn test_matches_verification_pattern_method() {
        let entry = create_test_entry("GET", "/api/users");
        assert!(matches_verification_pattern(
            &entry,
            &request_pattern(Some("GET"), None)
        ));
        assert!(!matches_verification_pattern(
            &entry,
            &request_pattern(Some("POST"), None)
        ));
    }

    #[test]
    fn test_matches_verification_pattern_path() {
        let entry = create_test_entry("GET", "/api/users");
        assert!(matches_verification_pattern(
            &entry,
            &request_pattern(None, Some("/api/users"))
        ));
        assert!(!matches_verification_pattern(
            &entry,
            &request_pattern(None, Some("/api/posts"))
        ));
    }
}