use crate::mock::{Mock, MockCall};
use bytes::Bytes;
use http::{HeaderMap, Method, StatusCode};
use parking_lot::RwLock;
use std::{collections::HashMap, sync::Arc};
use wae_types::{WaeError, WaeErrorKind, WaeResult as TestingResult};
/// A single request observed by the mock service.
///
/// `MockExternalService::handle_request` builds one of these for every call
/// and stores a copy in its request log.
#[derive(Debug, Clone)]
pub struct ServiceRequest {
/// HTTP method of the request.
pub method: Method,
/// Request path; the built-in match rules compare it by exact string equality.
pub path: String,
/// Headers supplied with the request.
pub headers: HeaderMap,
/// Request body, or `None` when the caller supplied none.
pub body: Option<Bytes>,
/// Moment the request was recorded (`Instant::now()` when built by `handle_request`).
pub timestamp: std::time::Instant,
}
/// A canned response that the mock service hands back for a matching request.
#[derive(Debug, Clone)]
pub struct ServiceResponse {
/// HTTP status code of the response.
pub status: StatusCode,
/// Response headers.
pub headers: HeaderMap,
/// Response body bytes.
pub body: Bytes,
}
impl ServiceResponse {
    /// Builds a `200 OK` response with no headers and the given body.
    pub fn ok(body: impl Into<Bytes>) -> Self {
        let body = body.into();
        Self { status: StatusCode::OK, headers: HeaderMap::new(), body }
    }

    /// Builds a response with the given status, no headers and the given body.
    pub fn error(status: StatusCode, body: impl Into<Bytes>) -> Self {
        let body = body.into();
        Self { status, headers: HeaderMap::new(), body }
    }

    /// Appends a header to the response and returns the response for chaining.
    ///
    /// The header is added with `HeaderMap::append`, so an existing value for
    /// the same name is kept. If either the name or the value fails to
    /// convert, the response is returned unchanged.
    pub fn header<K, V>(mut self, key: K, value: V) -> Self
    where
        K: TryInto<http::HeaderName>,
        V: TryInto<http::HeaderValue>,
    {
        // Both conversions are always attempted, name first, then value.
        let name = key.try_into();
        let content = value.try_into();
        if let (Ok(name), Ok(content)) = (name, content) {
            self.headers.append(name, content);
        }
        self
    }
}
/// Rule deciding whether a configured response applies to a request.
#[derive(Clone)]
pub enum ServiceMatchRule {
/// Matches when the request path equals this string exactly.
Path(String),
/// Matches when both the request path and the request method equal these values.
PathAndMethod(String, Method),
/// Matches when the predicate returns `true` for the request.
Custom(Arc<dyn Fn(&ServiceRequest) -> bool + Send + Sync>),
}
impl std::fmt::Debug for ServiceMatchRule {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ServiceMatchRule::Path(path) => f.debug_tuple("Path").field(path).finish(),
ServiceMatchRule::PathAndMethod(path, method) => f.debug_tuple("PathAndMethod").field(path).field(method).finish(),
ServiceMatchRule::Custom(_) => f.debug_tuple("Custom").field(&"<function>").finish(),
}
}
}
/// Pairs a match rule with the response to return when that rule matches.
#[derive(Debug, Clone)]
pub struct ServiceResponseConfig {
/// Rule tested against each incoming request.
pub match_rule: ServiceMatchRule,
/// Response cloned and returned when `match_rule` matches.
pub response: ServiceResponse,
}
/// Expected request counts per path, plus an optional description.
#[derive(Debug, Default)]
pub struct ServiceExpectation {
    /// Maps a request path to the exact number of requests expected for it.
    pub expected_requests: HashMap<String, usize>,
    /// Optional free-form description of the expectation.
    pub description: Option<String>,
}

impl ServiceExpectation {
    /// Creates an expectation with no expected requests and no description.
    pub fn new() -> Self {
        Self { expected_requests: HashMap::new(), description: None }
    }

    /// Sets the expected request count for `path`.
    ///
    /// A later call for the same path replaces the earlier count.
    pub fn expect_request(self, path: impl Into<String>, count: usize) -> Self {
        let Self { mut expected_requests, description } = self;
        expected_requests.insert(path.into(), count);
        Self { expected_requests, description }
    }

    /// Sets the description, replacing any previous one.
    pub fn description(self, desc: impl Into<String>) -> Self {
        Self { description: Some(desc.into()), ..self }
    }
}
/// Builder that collects response rules and an expectation, then produces a
/// `MockExternalService` via `build`.
pub struct MockExternalServiceBuilder {
// Response rules in registration order.
responses: Vec<ServiceResponseConfig>,
// Expectation handed to the built service.
expectation: ServiceExpectation,
// Request log; created empty in `new` and moved into the built service.
requests: Arc<RwLock<Vec<ServiceRequest>>>,
}
impl MockExternalServiceBuilder {
pub fn new() -> Self {
Self { responses: Vec::new(), expectation: ServiceExpectation::default(), requests: Arc::new(RwLock::new(Vec::new())) }
}
pub fn respond_to_path(mut self, path: impl Into<String>, response: ServiceResponse) -> Self {
self.responses.push(ServiceResponseConfig { match_rule: ServiceMatchRule::Path(path.into()), response });
self
}
pub fn respond_to_path_and_method(mut self, path: impl Into<String>, method: Method, response: ServiceResponse) -> Self {
self.responses
.push(ServiceResponseConfig { match_rule: ServiceMatchRule::PathAndMethod(path.into(), method), response });
self
}
pub fn respond_with<F>(mut self, matcher: F, response: ServiceResponse) -> Self
where
F: Fn(&ServiceRequest) -> bool + Send + Sync + 'static,
{
self.responses.push(ServiceResponseConfig { match_rule: ServiceMatchRule::Custom(Arc::new(matcher)), response });
self
}
pub fn expect(mut self, expectation: ServiceExpectation) -> Self {
self.expectation = expectation;
self
}
pub fn build(self) -> MockExternalService {
MockExternalService { responses: self.responses, expectation: self.expectation, requests: self.requests }
}
}
/// `Default` yields the same empty builder as `MockExternalServiceBuilder::new`.
impl Default for MockExternalServiceBuilder {
fn default() -> Self {
Self::new()
}
}
/// Mock of an external service: answers requests from a list of configured
/// responses and keeps a log of every request it has handled.
pub struct MockExternalService {
// Response rules, tried in order; the first match wins.
responses: Vec<ServiceResponseConfig>,
// Per-path request counts checked by `Mock::verify`.
expectation: ServiceExpectation,
// Log of handled requests, guarded by a read/write lock.
requests: Arc<RwLock<Vec<ServiceRequest>>>,
}
impl MockExternalService {
    /// Records the request and returns the first configured response whose
    /// rule matches it.
    ///
    /// The request is appended to the request log before any rule is tried,
    /// so requests that match nothing are still logged and counted.
    ///
    /// # Errors
    ///
    /// Returns a `WaeErrorKind::MockError` naming the method and path when no
    /// configured rule matches.
    pub fn handle_request(
        &self,
        method: Method,
        path: String,
        headers: HeaderMap,
        body: Option<Bytes>,
    ) -> TestingResult<ServiceResponse> {
        // The arguments are owned, so they are moved into the record rather
        // than cloned field by field.
        let request = ServiceRequest { method, path, headers, body, timestamp: std::time::Instant::now() };

        // The write guard is a temporary and is released at the end of this
        // statement, so the lock is not held while matchers run.
        self.requests.write().push(request.clone());

        self.responses
            .iter()
            .find(|config| Self::matches(&request, &config.match_rule))
            .map(|config| config.response.clone())
            .ok_or_else(|| {
                WaeError::new(WaeErrorKind::MockError {
                    reason: format!("No mock response configured for {} {}", request.method, request.path),
                })
            })
    }

    /// Async-signature variant of `handle_request`.
    ///
    /// It never awaits; it simply delegates to the synchronous method and
    /// returns its result.
    pub async fn handle_request_async(
        &self,
        method: Method,
        path: String,
        headers: HeaderMap,
        body: Option<Bytes>,
    ) -> TestingResult<ServiceResponse> {
        self.handle_request(method, path, headers, body)
    }

    /// Reports whether `request` satisfies `rule`.
    fn matches(request: &ServiceRequest, rule: &ServiceMatchRule) -> bool {
        match rule {
            ServiceMatchRule::Path(path) => request.path == *path,
            ServiceMatchRule::PathAndMethod(path, method) => request.path == *path && request.method == *method,
            ServiceMatchRule::Custom(matcher) => matcher(request),
        }
    }

    /// Returns a cloned snapshot of the request log, in the order recorded.
    pub fn requests(&self) -> Vec<ServiceRequest> {
        self.requests.read().clone()
    }

    /// Returns the number of requests currently in the log.
    pub fn request_count(&self) -> usize {
        self.requests.read().len()
    }

    /// Returns how many logged requests have a path exactly equal to `path`.
    pub fn request_count_by_path(&self, path: &str) -> usize {
        self.requests.read().iter().filter(|r| r.path == path).count()
    }
}
impl Mock for MockExternalService {
    /// Returns one `MockCall` per logged request, with the method and path as
    /// its arguments and the request's timestamp.
    fn calls(&self) -> Vec<MockCall> {
        self.requests
            .read()
            .iter()
            .map(|r| MockCall { args: vec![r.method.to_string(), r.path.clone()], timestamp: r.timestamp })
            .collect()
    }

    /// Returns the number of requests currently in the log.
    fn call_count(&self) -> usize {
        self.request_count()
    }

    /// Checks every expected path against the number of logged requests.
    ///
    /// # Errors
    ///
    /// Returns a `WaeErrorKind::AssertionFailed` for the first path, in
    /// ascending path order, whose logged count differs from the expected
    /// count.
    fn verify(&self) -> TestingResult<()> {
        // One read guard for the whole check, so every path is counted
        // against the same snapshot of the log.
        let requests = self.requests.read();

        // `HashMap` iteration order is unspecified; sorting makes the reported
        // mismatch deterministic when several paths are off.
        let mut expectations: Vec<(&String, &usize)> = self.expectation.expected_requests.iter().collect();
        expectations.sort_unstable_by(|(left, _), (right, _)| left.cmp(right));

        for (path, expected) in expectations {
            let actual = requests.iter().filter(|r| r.path == *path).count();
            if actual != *expected {
                return Err(WaeError::new(WaeErrorKind::AssertionFailed {
                    message: format!("Expected {} requests for path '{}', but got {}", expected, path, actual),
                }));
            }
        }
        Ok(())
    }

    /// Clears the request log; configured responses and expectations are kept.
    fn reset(&self) {
        self.requests.write().clear();
    }
}