use parking_lot::RwLock;
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
/// An HTTP response: status code, headers and raw body bytes.
///
/// Header names are stored exactly as inserted, so lookups in `headers` are
/// case-sensitive.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct HttpResponse {
    /// Numeric status code, e.g. `200` or `404`.
    pub status: u16,
    /// Response headers keyed by header name.
    pub headers: HashMap<String, String>,
    /// Raw body bytes.
    pub body: Vec<u8>,
}

impl HttpResponse {
    /// Creates a response with the given status and body and no headers.
    pub fn new(status: u16, body: impl Into<Vec<u8>>) -> Self {
        Self {
            status,
            headers: HashMap::new(),
            body: body.into(),
        }
    }

    /// Creates a response whose body is `value` serialized as JSON, with a
    /// `content-type: application/json` header.
    ///
    /// # Panics
    ///
    /// Panics if `serde_json` fails to serialize `value`.
    pub fn json(status: u16, value: &serde_json::Value) -> Self {
        let body = serde_json::to_vec(value).expect("Failed to serialize JSON");
        Self::new(status, body).with_header("content-type", "application/json")
    }

    /// Returns the body decoded as UTF-8; invalid sequences become U+FFFD.
    pub fn body_string(&self) -> String {
        String::from_utf8_lossy(&self.body).into_owned()
    }

    /// Parses the body as JSON.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if the body is not valid JSON.
    pub fn body_json(&self) -> Result<serde_json::Value, serde_json::Error> {
        serde_json::from_slice(&self.body)
    }

    /// Returns the response with header `key` set to `value`, replacing any
    /// existing entry under the exact same key.
    #[must_use]
    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(key.into(), value.into());
        self
    }
}
/// Snapshot of a request handed to an [`HttpProvider`]; [`MockHttp`] records
/// one per call so tests can inspect what was sent.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct HttpRequest {
    /// HTTP method as given by the caller (`MockHttp` does not case-normalize it).
    pub method: String,
    /// Full request URL.
    pub url: String,
    /// Request headers as given by the caller.
    pub headers: HashMap<String, String>,
    /// Request body; `MockHttp` records an empty vector when the caller passed `None`.
    pub body: Vec<u8>,
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum HttpError {
#[error("Connection failed: {0}")]
ConnectionFailed(String),
#[error("Request timed out")]
Timeout,
#[error("No mock rule matched for {method} {url}")]
NoMockMatch { method: String, url: String },
#[error("{0}")]
Other(String),
}
/// Abstraction over an HTTP transport, letting callers swap a real client for
/// [`MockHttp`] in tests.
pub trait HttpProvider: Send + Sync {
    /// Sends a `method` request to `url` with `headers` and an optional `body`.
    ///
    /// The future is returned boxed (rather than as `impl Future`) so the trait
    /// remains usable as `dyn HttpProvider`.
    fn request(
        &self,
        method: &str,
        url: &str,
        headers: HashMap<String, String>,
        body: Option<Vec<u8>>,
    ) -> core::pin::Pin<
        Box<dyn core::future::Future<Output = Result<HttpResponse, HttpError>> + Send + '_>,
    >;

    /// `true` for test doubles such as [`MockHttp`], `false` for real transports.
    fn is_mock(&self) -> bool;
}
/// Placeholder for a network-backed [`HttpProvider`].
///
/// Currently it only stores a timeout; the `HttpProvider` impl in this file is
/// a stub that does not use it.
#[derive(Clone, Debug)]
pub struct RealHttp {
    /// Configured request timeout.
    timeout: Duration,
}

impl RealHttp {
    /// Timeout used by [`RealHttp::new`] and `Default`: 30 seconds.
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

    /// Creates a client with the 30-second default timeout.
    pub fn new() -> Self {
        Self::with_timeout(Self::DEFAULT_TIMEOUT)
    }

    /// Creates a client with a caller-chosen timeout.
    pub fn with_timeout(timeout: Duration) -> Self {
        RealHttp { timeout }
    }

    /// Returns the configured timeout.
    pub fn timeout(&self) -> Duration {
        self.timeout
    }
}

impl Default for RealHttp {
    /// Same as [`RealHttp::new`].
    fn default() -> Self {
        RealHttp::new()
    }
}
impl HttpProvider for RealHttp {
    /// Stub transport: every call resolves immediately to [`HttpError::Other`]
    /// explaining that real HTTP is not implemented. All arguments are ignored.
    fn request(
        &self,
        _method: &str,
        _url: &str,
        _headers: HashMap<String, String>,
        _body: Option<Vec<u8>>,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<HttpResponse, HttpError>> + Send + '_>,
    > {
        let outcome: Result<HttpResponse, HttpError> = Err(HttpError::Other(
            "Real HTTP not implemented - use MockHttp for testing".to_string(),
        ));
        Box::pin(std::future::ready(outcome))
    }

    /// Always `false`: this is not a test double.
    fn is_mock(&self) -> bool {
        false
    }
}
/// One canned response registered with [`MockHttp`], selected by HTTP method
/// and a URL regex.
#[derive(Clone)]
pub struct MockHttpRule {
    /// Required method; `None` matches any method. Compared ASCII
    /// case-insensitively, so the stored case does not matter.
    pub method: Option<String>,
    /// Regex tested against the request URL with `is_match`, i.e. unanchored
    /// unless the pattern itself uses `^`/`$`.
    pub url_pattern: Regex,
    /// Response returned (cloned) for each matching request.
    pub response: HttpResponse,
    /// Optional delay before the response resolves.
    pub latency: Option<Duration>,
    /// Maximum number of requests this rule may answer; `None` is unlimited.
    pub times: Option<usize>,
    /// Requests answered so far; compared against `times`.
    matched_count: usize,
}

impl MockHttpRule {
    /// Creates a rule answering any method whose URL matches `url_pattern`.
    ///
    /// # Panics
    ///
    /// Panics if `url_pattern` is not a valid regex.
    pub fn new(url_pattern: &str, response: HttpResponse) -> Self {
        Self {
            method: None,
            url_pattern: Regex::new(url_pattern).expect("Invalid URL regex pattern"),
            response,
            latency: None,
            times: None,
            matched_count: 0,
        }
    }

    /// Restricts the rule to `method`, stored ASCII-uppercased.
    pub fn with_method(mut self, method: &str) -> Self {
        self.method = Some(method.to_ascii_uppercase());
        self
    }

    /// Delays each matched response by `latency`.
    pub fn with_latency(mut self, latency: Duration) -> Self {
        self.latency = Some(latency);
        self
    }

    /// Limits the rule to answering at most `n` requests.
    pub fn times(mut self, n: usize) -> Self {
        self.times = Some(n);
        self
    }

    /// True when the method matches, the `times` budget is not exhausted, and
    /// the URL matches the pattern.
    fn matches(&self, method: &str, url: &str) -> bool {
        // ASCII case-insensitive comparison avoids an allocation per check and
        // still works when `method` was assigned directly through the public
        // field in lower case (e.g. `Some("get".into())`).
        if let Some(expected) = self.method.as_deref() {
            if !expected.eq_ignore_ascii_case(method) {
                return false;
            }
        }
        if let Some(limit) = self.times {
            if self.matched_count >= limit {
                return false;
            }
        }
        self.url_pattern.is_match(url)
    }
}

impl std::fmt::Debug for MockHttpRule {
    /// Shows the pattern as its source string and the response by status only;
    /// `matched_count` is included because it decides whether a `times`-limited
    /// rule can still match.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MockHttpRule")
            .field("method", &self.method)
            .field("url_pattern", &self.url_pattern.as_str())
            .field("response_status", &self.response.status)
            .field("latency", &self.latency)
            .field("times", &self.times)
            .field("matched_count", &self.matched_count)
            .finish()
    }
}
/// In-memory [`HttpProvider`] for tests: answers requests from registered
/// [`MockHttpRule`]s and records every request it receives.
///
/// Rules and the request log sit behind `RwLock`s so they can be updated
/// through the `&self` receiver of `request`.
pub struct MockHttp {
    /// Registered rules in insertion order; the first matching rule answers.
    rules: RwLock<Vec<MockHttpRule>>,
    /// Every request received, in call order.
    requests: RwLock<Vec<HttpRequest>>,
    /// `true` (the default): unmatched requests fail with
    /// `HttpError::NoMockMatch`; `false`: they get a `404 Not Found` response.
    fail_on_unmatched: bool,
}

impl MockHttp {
    /// Creates a mock with no rules that fails on unmatched requests.
    pub fn new() -> Self {
        Self {
            rules: RwLock::new(Vec::new()),
            requests: RwLock::new(Vec::new()),
            fail_on_unmatched: true,
        }
    }

    /// Appends a fully built rule; rules registered earlier take precedence.
    pub fn rule(self, rule: MockHttpRule) -> Self {
        self.rules.write().push(rule);
        self
    }

    /// Chooses whether unmatched requests error (`true`) or get a 404 (`false`).
    pub fn fail_on_unmatched(mut self, fail: bool) -> Self {
        self.fail_on_unmatched = fail;
        self
    }

    /// Shared constructor behind the `on_*` methods; `None` matches any method.
    fn on_method(self, method: Option<&str>, url_pattern: &str) -> MockHttpBuilder {
        MockHttpBuilder {
            mock: self,
            method: method.map(String::from),
            url_pattern: url_pattern.to_string(),
            latency: None,
            times: None,
        }
    }

    /// Starts a rule for `GET` requests whose URL matches the regex `url_pattern`.
    pub fn on_get(self, url_pattern: &str) -> MockHttpBuilder {
        self.on_method(Some("GET"), url_pattern)
    }

    /// Starts a rule for `POST` requests whose URL matches the regex `url_pattern`.
    pub fn on_post(self, url_pattern: &str) -> MockHttpBuilder {
        self.on_method(Some("POST"), url_pattern)
    }

    /// Starts a rule for `PUT` requests whose URL matches the regex `url_pattern`.
    pub fn on_put(self, url_pattern: &str) -> MockHttpBuilder {
        self.on_method(Some("PUT"), url_pattern)
    }

    /// Starts a rule for `DELETE` requests whose URL matches the regex `url_pattern`.
    pub fn on_delete(self, url_pattern: &str) -> MockHttpBuilder {
        self.on_method(Some("DELETE"), url_pattern)
    }

    /// Starts a rule for any method whose URL matches the regex `url_pattern`.
    pub fn on_any(self, url_pattern: &str) -> MockHttpBuilder {
        self.on_method(None, url_pattern)
    }

    /// Returns a copy of every recorded request, oldest first.
    pub fn requests(&self) -> Vec<HttpRequest> {
        self.requests.read().clone()
    }

    /// Forgets all recorded requests; rules and their match counts are kept.
    pub fn clear_requests(&self) {
        self.requests.write().clear();
    }

    /// Reports whether any recorded request has `method` (ASCII
    /// case-insensitive) and a URL matching the regex `url_pattern`.
    ///
    /// Despite the name this returns `false` rather than panicking on a miss;
    /// wrap it in `assert!`.
    ///
    /// # Panics
    ///
    /// Panics if `url_pattern` is not a valid regex.
    pub fn assert_request_made(&self, method: &str, url_pattern: &str) -> bool {
        let re = Regex::new(url_pattern).expect("Invalid URL pattern");
        let requests = self.requests.read();
        requests
            .iter()
            .any(|r| r.method.eq_ignore_ascii_case(method) && re.is_match(&r.url))
    }

    /// Number of requests recorded since creation or the last `clear_requests`.
    pub fn request_count(&self) -> usize {
        self.requests.read().len()
    }
}
impl Default for MockHttp {
fn default() -> Self {
Self::new()
}
}
impl HttpProvider for MockHttp {
    /// Records the request, then answers it from the first matching rule.
    ///
    /// A matching rule's counter is bumped and its response is cloned and
    /// returned after the rule's optional latency (via `tokio::time::sleep`).
    /// With no match, the result is `HttpError::NoMockMatch` or a `404 Not
    /// Found` response, depending on `fail_on_unmatched`.
    fn request(
        &self,
        method: &str,
        url: &str,
        headers: HashMap<String, String>,
        body: Option<Vec<u8>>,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<HttpResponse, HttpError>> + Send + '_>,
    > {
        // `headers` and `body` are owned and unused afterwards: move them into
        // the log instead of cloning.
        self.requests.write().push(HttpRequest {
            method: method.to_string(),
            url: url.to_string(),
            headers,
            body: body.unwrap_or_default(),
        });

        // Find the rule and consume one of its uses under a single write lock;
        // the guard is released at the end of this block, before any future is built.
        let matched = {
            let mut rules = self.rules.write();
            rules
                .iter_mut()
                .find(|rule| rule.matches(method, url))
                .map(|rule| {
                    rule.matched_count += 1;
                    (rule.response.clone(), rule.latency)
                })
        };

        match matched {
            Some((response, latency)) => Box::pin(async move {
                if let Some(delay) = latency {
                    tokio::time::sleep(delay).await;
                }
                Ok(response)
            }),
            None if self.fail_on_unmatched => {
                let method = method.to_string();
                let url = url.to_string();
                Box::pin(async move { Err(HttpError::NoMockMatch { method, url }) })
            }
            None => Box::pin(async move { Ok(HttpResponse::new(404, b"Not Found".to_vec())) }),
        }
    }

    /// Always `true`: this is a test double.
    fn is_mock(&self) -> bool {
        true
    }
}
/// Fluent builder for a single [`MockHttpRule`], returned by the `MockHttp::on_*`
/// methods.
///
/// Nothing is registered until a `respond*` method is called; those hand back
/// the [`MockHttp`] with the new rule appended.
#[must_use = "the rule is only registered when a `respond*` method is called"]
pub struct MockHttpBuilder {
    mock: MockHttp,
    method: Option<String>,
    url_pattern: String,
    latency: Option<Duration>,
    times: Option<usize>,
}

impl MockHttpBuilder {
    /// Delays each matched response by `latency`.
    pub fn with_latency(mut self, latency: Duration) -> Self {
        self.latency = Some(latency);
        self
    }

    /// Limits the rule to answering at most `n` requests.
    pub fn times(mut self, n: usize) -> Self {
        self.times = Some(n);
        self
    }

    /// Finishes the rule with `response` and returns the mock with it registered.
    ///
    /// # Panics
    ///
    /// Panics if the URL pattern given to `on_*` is not a valid regex.
    #[must_use = "dropping the returned MockHttp discards every registered rule"]
    pub fn respond(self, response: HttpResponse) -> MockHttp {
        let mut rule = MockHttpRule::new(&self.url_pattern, response);
        rule.method = self.method;
        rule.latency = self.latency;
        rule.times = self.times;
        self.mock.rule(rule)
    }

    /// Responds with `value` as a JSON body (`content-type: application/json`).
    #[must_use = "dropping the returned MockHttp discards every registered rule"]
    pub fn respond_json(self, status: u16, value: serde_json::Value) -> MockHttp {
        self.respond(HttpResponse::json(status, &value))
    }

    /// Responds with `text` as a plain-text body (`content-type: text/plain`).
    #[must_use = "dropping the returned MockHttp discards every registered rule"]
    pub fn respond_text(self, status: u16, text: &str) -> MockHttp {
        let response = HttpResponse::new(status, text).with_header("content-type", "text/plain");
        self.respond(response)
    }

    /// Responds with the JSON body `{"error": message}`.
    #[must_use = "dropping the returned MockHttp discards every registered rule"]
    pub fn respond_error(self, status: u16, message: &str) -> MockHttp {
        self.respond_json(status, serde_json::json!({"error": message}))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn mock_http_matches_get() {
        let mock = MockHttp::new()
            .on_get(r"^https://api\.example\.com/users/\d+$")
            .respond_json(200, json!({"name": "Alice"}));
        let response = mock
            .request(
                "GET",
                "https://api.example.com/users/123",
                HashMap::new(),
                None,
            )
            .await
            .unwrap();
        assert_eq!(response.status, 200);
        let body: serde_json::Value = response.body_json().unwrap();
        assert_eq!(body["name"], "Alice");
    }

    #[tokio::test]
    async fn mock_http_matches_post() {
        let mock = MockHttp::new()
            .on_post(r"^https://api\.example\.com/users$")
            .respond_json(201, json!({"id": 42}));
        let response = mock
            .request(
                "POST",
                "https://api.example.com/users",
                HashMap::new(),
                Some(b"{}".to_vec()),
            )
            .await
            .unwrap();
        assert_eq!(response.status, 201);
    }

    #[tokio::test]
    async fn mock_http_fails_on_unmatched() {
        let mock = MockHttp::new()
            .on_get(r"^https://api\.example\.com/users$")
            .respond_json(200, json!([]));
        let result = mock
            .request("GET", "https://api.example.com/other", HashMap::new(), None)
            .await;
        assert!(matches!(result, Err(HttpError::NoMockMatch { .. })));
    }

    #[tokio::test]
    async fn mock_http_unmatched_returns_404_when_not_failing() {
        let mock = MockHttp::new().fail_on_unmatched(false);
        let response = mock
            .request("GET", "https://api.example.com/missing", HashMap::new(), None)
            .await
            .unwrap();
        assert_eq!(response.status, 404);
        assert_eq!(response.body_string(), "Not Found");
    }

    #[tokio::test]
    async fn mock_http_method_match_is_case_insensitive() {
        let mock = MockHttp::new()
            .on_get(r"^https://api\.example\.com/users$")
            .respond_json(200, json!([]));
        let response = mock
            .request("get", "https://api.example.com/users", HashMap::new(), None)
            .await
            .unwrap();
        assert_eq!(response.status, 200);
        assert!(mock.assert_request_made("GET", r"/users$"));
    }

    #[tokio::test]
    async fn mock_http_respond_text_sets_content_type() {
        let mock = MockHttp::new().on_any(r"/health$").respond_text(200, "ok");
        let response = mock
            .request("HEAD", "https://example.com/health", HashMap::new(), None)
            .await
            .unwrap();
        assert_eq!(response.body_string(), "ok");
        assert_eq!(
            response.headers.get("content-type").map(String::as_str),
            Some("text/plain")
        );
    }

    #[tokio::test]
    async fn mock_http_records_requests() {
        let mock = MockHttp::new().on_get(r".*").respond_json(200, json!({}));
        mock.request("GET", "https://example.com/a", HashMap::new(), None)
            .await
            .unwrap();
        mock.request("GET", "https://example.com/b", HashMap::new(), None)
            .await
            .unwrap();
        assert_eq!(mock.request_count(), 2);
        assert!(mock.assert_request_made("GET", r"example\.com/a"));
        assert!(mock.assert_request_made("GET", r"example\.com/b"));
    }

    #[tokio::test]
    async fn mock_http_records_headers_and_body() {
        let mock = MockHttp::new().on_post(r".*").respond_json(201, json!({}));
        let mut headers = HashMap::new();
        headers.insert("x-test".to_string(), "1".to_string());
        mock.request("POST", "https://example.com", headers, Some(b"payload".to_vec()))
            .await
            .unwrap();
        let recorded = mock.requests();
        assert_eq!(recorded.len(), 1);
        assert_eq!(
            recorded[0].headers.get("x-test").map(String::as_str),
            Some("1")
        );
        assert_eq!(recorded[0].body, b"payload");
    }

    #[tokio::test]
    async fn mock_http_clear_requests() {
        let mock = MockHttp::new().fail_on_unmatched(false);
        mock.request("GET", "https://example.com", HashMap::new(), None)
            .await
            .unwrap();
        assert_eq!(mock.request_count(), 1);
        mock.clear_requests();
        assert_eq!(mock.request_count(), 0);
        assert!(mock.requests().is_empty());
    }

    #[tokio::test]
    async fn mock_http_times_limit() {
        let mock = MockHttp::new()
            .on_get(r"^https://api\.example\.com/users$")
            .times(2)
            .respond_json(200, json!([]));
        mock.request("GET", "https://api.example.com/users", HashMap::new(), None)
            .await
            .unwrap();
        mock.request("GET", "https://api.example.com/users", HashMap::new(), None)
            .await
            .unwrap();
        let result = mock
            .request("GET", "https://api.example.com/users", HashMap::new(), None)
            .await;
        assert!(matches!(result, Err(HttpError::NoMockMatch { .. })));
    }

    #[tokio::test]
    async fn real_http_is_a_stub() {
        let real = RealHttp::default();
        assert!(!real.is_mock());
        assert_eq!(real.timeout(), Duration::from_secs(30));
        let result = real
            .request("GET", "https://example.com", HashMap::new(), None)
            .await;
        assert!(matches!(result, Err(HttpError::Other(_))));
    }
}