use std::collections::HashMap;
use std::sync::Arc;
use regex::Regex;
use serde::Serialize;
use tokio::sync::RwLock;
/// In-memory HTTP double: dispatches requests to registered handlers and records them.
///
/// Both the handler list and the request log sit behind `Arc`, so every clone of a
/// `MockHttp` observes and mutates the same state.
#[derive(Clone)]
pub struct MockHttp {
    /// Registered handlers, kept in the order they were added.
    mocks: Arc<RwLock<Vec<MockHandler>>>,
    /// Every request passed to `execute`, in arrival order.
    requests: Arc<RwLock<Vec<RecordedRequest>>>,
}
/// One registered mock: its source pattern, the compiled matcher, and the handler.
struct MockHandler {
    // Never read in this module (matching goes through `regex`); presumably kept for
    // debugging / inspection, hence the lint allowance.
    #[allow(dead_code)]
    pattern: String,
    /// Anchored regex compiled from `pattern`.
    regex: Regex,
    /// Produces the response for a request that matches `regex`.
    handler: Arc<dyn Fn(&MockRequest) -> MockResponse + Send + Sync>,
}
/// Snapshot of a request seen by `MockHttp::execute`, kept for later assertions.
#[derive(Debug, Clone)]
pub struct RecordedRequest {
    /// HTTP method as supplied by the caller (e.g. `"GET"`).
    pub method: String,
    /// Full request URL.
    pub url: String,
    /// Request headers as supplied by the caller.
    pub headers: HashMap<String, String>,
    /// Request body as JSON.
    pub body: serde_json::Value,
}
/// A request handed to `MockHttp::execute` and forwarded to the matching handler.
#[derive(Debug, Clone)]
pub struct MockRequest {
    /// HTTP method (e.g. `"POST"`).
    pub method: String,
    /// Path component; mock patterns are matched against this as well as `url`.
    pub path: String,
    /// Full URL; mock patterns are matched against this first.
    pub url: String,
    /// Request headers.
    pub headers: HashMap<String, String>,
    /// Request body as JSON.
    pub body: serde_json::Value,
}
/// Response produced by a mock handler (or by `MockHttp` itself when nothing matches).
#[derive(Debug, Clone)]
pub struct MockResponse {
    /// HTTP status code.
    pub status: u16,
    /// Response headers.
    pub headers: HashMap<String, String>,
    /// Response body as JSON.
    pub body: serde_json::Value,
}
impl MockResponse {
pub fn json<T: Serialize>(body: T) -> Self {
Self {
status: 200,
headers: HashMap::from([("content-type".to_string(), "application/json".to_string())]),
body: serde_json::to_value(body).unwrap_or(serde_json::Value::Null),
}
}
pub fn error(status: u16, message: &str) -> Self {
Self {
status,
headers: HashMap::from([("content-type".to_string(), "application/json".to_string())]),
body: serde_json::json!({ "error": message }),
}
}
pub fn internal_error(message: &str) -> Self {
Self::error(500, message)
}
pub fn not_found(message: &str) -> Self {
Self::error(404, message)
}
pub fn unauthorized(message: &str) -> Self {
Self::error(401, message)
}
pub fn ok() -> Self {
Self::json(serde_json::json!({}))
}
}
impl MockHttp {
pub fn new() -> Self {
Self {
mocks: Arc::new(RwLock::new(Vec::new())),
requests: Arc::new(RwLock::new(Vec::new())),
}
}
pub fn add_mock<F>(&mut self, pattern: &str, handler: F)
where
F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
{
let regex_pattern = pattern
.replace('.', "\\.")
.replace('*', ".*")
.replace('?', ".");
let regex = Regex::new(&format!("^{}$", regex_pattern)).unwrap();
let mocks = self.mocks.clone();
tokio::task::block_in_place(|| {
let rt = tokio::runtime::Handle::try_current();
if let Ok(rt) = rt {
rt.block_on(async {
let mut mocks = mocks.write().await;
mocks.push(MockHandler {
pattern: pattern.to_string(),
regex,
handler: Arc::new(handler),
});
});
}
});
}
#[allow(unused_variables)]
pub fn add_mock_sync<F>(&self, pattern: &str, handler: F)
where
F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
{
let regex_pattern = pattern
.replace('.', "\\.")
.replace('*', ".*")
.replace('?', ".");
let _regex = Regex::new(&format!("^{}$", regex_pattern)).unwrap();
}
pub async fn execute(&self, request: MockRequest) -> MockResponse {
{
let mut requests = self.requests.write().await;
requests.push(RecordedRequest {
method: request.method.clone(),
url: request.url.clone(),
headers: request.headers.clone(),
body: request.body.clone(),
});
}
let mocks = self.mocks.read().await;
for mock in mocks.iter() {
if mock.regex.is_match(&request.url) || mock.regex.is_match(&request.path) {
return (mock.handler)(&request);
}
}
MockResponse::error(500, &format!("No mock found for {}", request.url))
}
pub async fn requests(&self) -> Vec<RecordedRequest> {
self.requests.read().await.clone()
}
pub async fn requests_to(&self, pattern: &str) -> Vec<RecordedRequest> {
let regex_pattern = pattern
.replace('.', "\\.")
.replace('*', ".*")
.replace('?', ".");
let regex = Regex::new(&format!("^{}$", regex_pattern)).unwrap();
self.requests
.read()
.await
.iter()
.filter(|r| regex.is_match(&r.url))
.cloned()
.collect()
}
pub async fn clear_requests(&self) {
self.requests.write().await.clear();
}
pub async fn clear_mocks(&self) {
self.mocks.write().await.clear();
}
}
impl Default for MockHttp {
fn default() -> Self {
Self::new()
}
}
type MockHandlerFn = Box<dyn Fn(&MockRequest) -> MockResponse + Send + Sync>;
pub struct MockHttpBuilder {
mocks: Vec<(String, MockHandlerFn)>,
}
impl MockHttpBuilder {
pub fn new() -> Self {
Self { mocks: Vec::new() }
}
pub fn mock<F>(mut self, pattern: &str, handler: F) -> Self
where
F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
{
self.mocks.push((pattern.to_string(), Box::new(handler)));
self
}
pub fn build(self) -> MockHttp {
MockHttp::new()
}
}
impl Default for MockHttpBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request with no headers and a `null` body.
    fn bare_request(method: &str, path: &str, url: &str) -> MockRequest {
        MockRequest {
            method: method.to_string(),
            path: path.to_string(),
            url: url.to_string(),
            headers: HashMap::new(),
            body: serde_json::Value::Null,
        }
    }

    #[test]
    fn test_mock_response_json() {
        let resp = MockResponse::json(serde_json::json!({"id": 123}));
        assert_eq!(resp.status, 200);
        assert_eq!(resp.body["id"], 123);
    }

    #[test]
    fn test_mock_response_error() {
        let resp = MockResponse::error(404, "Not found");
        assert_eq!(resp.status, 404);
        assert_eq!(resp.body["error"], "Not found");
    }

    #[test]
    fn test_mock_response_internal_error() {
        assert_eq!(MockResponse::internal_error("Server error").status, 500);
    }

    #[test]
    fn test_mock_response_not_found() {
        assert_eq!(MockResponse::not_found("Resource not found").status, 404);
    }

    #[test]
    fn test_mock_response_unauthorized() {
        assert_eq!(MockResponse::unauthorized("Invalid token").status, 401);
    }

    #[tokio::test]
    async fn test_mock_http_no_handler() {
        let http = MockHttp::new();
        let req = bare_request("GET", "/test", "https://example.com/test");
        let resp = http.execute(req).await;
        assert_eq!(resp.status, 500);
    }

    #[tokio::test]
    async fn test_mock_http_records_requests() {
        let http = MockHttp::new();
        let req = MockRequest {
            headers: HashMap::from([("authorization".to_string(), "Bearer token".to_string())]),
            body: serde_json::json!({"name": "Test"}),
            ..bare_request("POST", "/api/users", "https://api.example.com/users")
        };
        let _ = http.execute(req).await;

        let recorded = http.requests().await;
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].method, "POST");
        assert_eq!(recorded[0].body["name"], "Test");
    }

    #[tokio::test]
    async fn test_mock_http_clear_requests() {
        let http = MockHttp::new();
        let _ = http
            .execute(bare_request("GET", "/test", "https://example.com/test"))
            .await;
        assert_eq!(http.requests().await.len(), 1);

        http.clear_requests().await;
        assert_eq!(http.requests().await.len(), 0);
    }
}