use crate::responder::{pending, MapResponder, ResponseStream, StreamResponse};
use crate::*;
use axum::http::header::HeaderMap;
use futures::stream::Stream;
use std::collections::HashMap;
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc,
};
use tokio::sync::RwLock;
pub(crate) type MockList = Arc<RwLock<Vec<Mock>>>;
/// How well a mock's matchers agree with an incoming request or message.
///
/// Variants are declared from weakest to strongest, so the derived `Ord`
/// ranks `Mismatch < Potential < Partial < Full`. The declaration order is
/// therefore load-bearing and must not be changed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) enum MatchStatus {
    /// At least one matcher result ruled the mock out.
    Mismatch,
    /// Nothing ruled the mock out, but no matcher has matched yet.
    Potential,
    /// Some matchers matched while others are still undecided.
    Partial,
    /// Some matchers matched and none are undecided.
    Full,
}
/// A registered mock: a list of matchers plus the responder used when the mock is selected.
///
/// `Clone` is cheap and every clone shares the same responder, expected-call
/// specification and hit counter, because those fields are behind `Arc`s.
#[derive(Clone)]
pub struct Mock {
    /// Matchers evaluated against each request/message. At most 64 are
    /// supported, since their results are tracked as bits of a `u64` mask.
    matcher: Vec<Arc<dyn Match + Sync + Send + 'static>>,
    /// Produces the response stream for this mock.
    pub(crate) responder: Arc<dyn ResponseStream + Sync + Send + 'static>,
    /// Number of hits this mock is expected to receive; consulted by `verify`.
    pub(crate) expected_calls: Arc<Times>,
    /// Hits recorded so far via `register_hit`.
    pub(crate) calls: Arc<AtomicU64>,
    /// Optional human-readable name set through `MockBuilder::named`.
    pub(crate) name: Option<String>,
    /// Mock priority (5 unless overridden, never 0). How it ranks competing
    /// mocks is decided by the caller — NOTE(review): confirm ordering direction.
    pub(crate) priority: u8,
}
impl Mock {
    /// Starts building a mock whose first matcher is `matcher`.
    ///
    /// The builder starts with the `pending()` responder, no name and priority 5.
    pub fn given(matcher: impl Match + Send + Sync + 'static) -> MockBuilder {
        MockBuilder {
            matcher: vec![Arc::new(matcher)],
            responder: Arc::new(pending()),
            name: None,
            priority: 5,
        }
    }

    /// Returns `true` when the number of recorded hits satisfies the expected [`Times`].
    pub fn verify(&self) -> bool {
        let calls = self.calls.load(Ordering::SeqCst);
        debug!("mock hit over {} calls", calls);
        self.expected_calls.contains(calls)
    }

    /// Evaluates every matcher's `request_match` against the request's path,
    /// headers and parameter map, and summarises the results.
    ///
    /// See [`Mock::check_matcher_responses`] for the meaning of the return value.
    pub(crate) fn check_request(
        &self,
        path: &str,
        headers: &HeaderMap,
        params: &HashMap<String, String>,
    ) -> (MatchStatus, u64) {
        let values: Vec<Option<bool>> = self
            .matcher
            .iter()
            .map(|m| m.request_match(path, headers, params))
            .collect();
        self.check_matcher_responses(&values)
    }

    /// Evaluates every matcher's `temporal_match` against the (mutable) match
    /// state for an incoming message, and summarises the results.
    ///
    /// See [`Mock::check_matcher_responses`] for the meaning of the return value.
    pub(crate) fn check_message(&self, state: &mut MatchState) -> (MatchStatus, u64) {
        let values: Vec<Option<bool>> = self
            .matcher
            .iter()
            .map(|m| m.temporal_match(state))
            .collect();
        self.check_matcher_responses(&values)
    }

    /// Combines per-matcher results into a [`MatchStatus`] and a bitmask.
    ///
    /// Bit `i` of the returned mask is set when `values[i] == Some(true)`.
    /// The status is:
    /// - `Mismatch` (mask 0) if any value is rejected by `can_consider`;
    /// - `Potential` (mask 0) if no value is `Some(true)`;
    /// - `Partial` if some value is `Some(true)` and some value is `None`;
    /// - `Full` if some value is `Some(true)` and none is `None`.
    ///
    /// `values` must hold at most 64 entries, because the mask is a `u64`.
    pub(crate) fn check_matcher_responses(&self, values: &[Option<bool>]) -> (MatchStatus, u64) {
        debug_assert!(
            values.len() <= u64::BITS as usize,
            "at most 64 matcher results fit in the u64 mask"
        );

        if !values.iter().copied().all(can_consider) {
            return (MatchStatus::Mismatch, 0);
        }
        if !values.contains(&Some(true)) {
            return (MatchStatus::Potential, 0);
        }

        let mask = values
            .iter()
            .enumerate()
            .filter(|(_, value)| **value == Some(true))
            .fold(0u64, |acc, (index, _)| acc | (1u64 << index));

        let status = if values.contains(&None) {
            MatchStatus::Partial
        } else {
            MatchStatus::Full
        };
        (status, mask)
    }

    /// Mask with one low bit set per matcher, i.e. the mask reached once every
    /// matcher has reported `Some(true)`.
    ///
    /// Boundaries are handled explicitly. The previous
    /// `u64::MAX >> (64 - len)` overflowed the shift for 0 matchers and
    /// underflowed the subtraction for more than 64.
    pub(crate) fn expected_mask(&self) -> u64 {
        match self.matcher.len() {
            0 => 0,
            n if n >= u64::BITS as usize => u64::MAX,
            n => (1u64 << n) - 1,
        }
    }

    /// Records one hit on this mock.
    ///
    /// Uses `SeqCst` to match the load in [`Mock::verify`].
    pub(crate) fn register_hit(&self) {
        self.calls.fetch_add(1, Ordering::SeqCst);
    }
}
/// Incrementally configures a [`Mock`]; created by `Mock::given` and finished with `expect`.
pub struct MockBuilder {
    /// Matchers collected so far (the one passed to `Mock::given` comes first).
    matcher: Vec<Arc<dyn Match + Sync + Send + 'static>>,
    /// Responder the finished mock will use.
    pub(crate) responder: Arc<dyn ResponseStream + Sync + Send + 'static>,
    /// Optional name copied into the finished mock.
    pub(crate) name: Option<String>,
    /// Priority copied into the finished mock (5 by default, never 0).
    pub(crate) priority: u8,
}
impl MockBuilder {
    /// Maximum number of matchers a single mock may hold. Matcher results are
    /// tracked as one bit each in a `u64` mask (see
    /// `Mock::check_matcher_responses` and `Mock::expected_mask`).
    const MAX_MATCHERS: usize = u64::BITS as usize;

    /// Gives the mock a human-readable name.
    pub fn named<T: Into<String>>(mut self, mock_name: T) -> Self {
        self.name = Some(mock_name.into());
        self
    }

    /// Finishes the builder, producing a [`Mock`] expected to be hit `times` times.
    ///
    /// The hit counter starts at zero.
    pub fn expect(self, times: impl Into<Times>) -> Mock {
        Mock {
            matcher: self.matcher,
            responder: self.responder,
            expected_calls: Arc::new(times.into()),
            calls: Default::default(),
            name: self.name,
            priority: self.priority,
        }
    }

    /// Appends another matcher to the mock.
    ///
    /// # Panics
    ///
    /// Panics if the mock already holds 64 matchers. The previous check
    /// allowed a 65th, which overflowed the `u64` match mask.
    pub fn add_matcher(mut self, matcher: impl Match + Send + Sync + 'static) -> Self {
        assert!(
            self.matcher.len() < Self::MAX_MATCHERS,
            "Cannot have more than {} matchers",
            Self::MAX_MATCHERS
        );
        self.matcher.push(Arc::new(matcher));
        self
    }

    /// Overrides the default priority of 5.
    ///
    /// # Panics
    ///
    /// Panics if `priority` is 0.
    pub fn with_priority(mut self, priority: u8) -> Self {
        assert!(priority > 0, "priority must be strictly greater than 0!");
        self.priority = priority;
        self
    }

    /// Replaces the responder with an arbitrary [`ResponseStream`] implementation.
    pub fn set_responder(mut self, responder: impl ResponseStream + Send + Sync + 'static) -> Self {
        self.responder = Arc::new(responder);
        self
    }

    /// Responds with a stream of messages built by calling `ctor`.
    ///
    /// `ctor` is a `Fn`, so the responder can presumably create a fresh stream
    /// per use — NOTE(review): confirm against `StreamResponse`.
    pub fn response_stream<F, S>(mut self, ctor: F) -> Self
    where
        F: Fn() -> S + Send + Sync + 'static,
        S: Stream<Item = Message> + Send + Sync + 'static,
    {
        self.responder = Arc::new(StreamResponse::new(ctor));
        self
    }

    /// Responds to each incoming message with the message produced by `map_fn`
    /// (wrapped in `MapResponder`).
    pub fn response_map<F>(mut self, map_fn: F) -> Self
    where
        F: Fn(Message) -> Message + Send + Sync + 'static,
    {
        self.responder = Arc::new(MapResponder::new(map_fn));
        self
    }
}