use parking_lot::RwLock;
use std::sync::Arc;
use wae_types::{WaeError, WaeErrorKind, WaeResult as TestingResult};
#[cfg(feature = "mockall")]
pub use mockall::{self, __mock_MockObject, __mock_MockStatic, automock, mock, predicate, sequence};
/// A single recorded invocation of a mock.
#[derive(Debug, Clone)]
pub struct MockCall {
    /// The stringified arguments supplied with this invocation.
    pub args: Vec<String>,
    /// The `Instant` captured at the moment the invocation was recorded.
    pub timestamp: std::time::Instant,
}
/// The canned outcome a mock produces when it is invoked.
///
/// `Clone` is derived: the derive generates exactly the `impl<T: Clone> Clone` that was
/// previously written out by hand.
#[derive(Debug, Clone)]
pub enum MockResult<T> {
    /// Each invocation yields a clone of this value.
    Return(T),
    /// Each invocation fails with this message.
    Error(String),
    /// Invocations yield these values in order, one per call.
    Sequence(Vec<T>),
}

// The constructors only move their argument into a variant, so they do not need `T: Clone`;
// dropping that bound lets them be used with any `T` while staying compatible with old callers.
impl<T> MockResult<T> {
    /// Builds a [`MockResult::Return`] holding `value`.
    pub fn return_value(value: T) -> Self {
        MockResult::Return(value)
    }

    /// Builds a [`MockResult::Error`] carrying `msg`.
    pub fn error(msg: impl Into<String>) -> Self {
        MockResult::Error(msg.into())
    }

    /// Builds a [`MockResult::Sequence`] over `values`.
    pub fn sequence(values: Vec<T>) -> Self {
        MockResult::Sequence(values)
    }
}
/// Constraints a mock is checked against when it is verified.
#[derive(Debug, Default)]
pub struct MockExpectation {
    /// Exact number of calls required; when `None`, the call count is not checked.
    pub expected_calls: Option<usize>,
    /// Optional human-readable label for this expectation.
    pub description: Option<String>,
}

impl MockExpectation {
    /// Creates an expectation with no constraints (equivalent to `Default::default()`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Requires exactly `count` calls.
    ///
    /// Consumes and returns `self`; `#[must_use]` flags call sites that drop the result
    /// and thereby lose the configuration.
    #[must_use]
    pub fn times(mut self, count: usize) -> Self {
        self.expected_calls = Some(count);
        self
    }

    /// Attaches a human-readable label to this expectation.
    #[must_use]
    pub fn description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }
}
/// Inspection and verification interface shared by mock objects.
///
/// The `Send + Sync` supertraits allow a mock to be shared between threads.
pub trait Mock: Send + Sync {
    /// Returns an owned list of the calls recorded so far.
    fn calls(&self) -> Vec<MockCall>;

    /// Returns how many calls have been recorded.
    fn call_count(&self) -> usize;

    /// Checks the mock's expectations against the recorded calls.
    ///
    /// # Errors
    ///
    /// Returns an error when an expectation is not met.
    fn verify(&self) -> TestingResult<()>;

    /// Clears recorded state so the mock can be exercised again.
    fn reset(&self);
}
/// A [`Mock`] whose verification can be awaited.
// `async fn` in a public trait triggers the `async_fn_in_trait` lint because callers cannot
// require the returned future to be `Send`; the lint is allowed here, so implementors may
// return non-`Send` futures.
#[allow(async_fn_in_trait)]
pub trait AsyncMock: Mock {
    /// Asynchronously checks the mock's expectations.
    ///
    /// # Errors
    ///
    /// Returns an error when an expectation is not met.
    async fn verify_async(&self) -> TestingResult<()>;
}
/// Fluent builder that configures and then produces a [`MockFn`].
pub struct MockBuilder<T> {
    /// Outcome the built mock will produce; `None` until one is configured.
    result: Option<MockResult<T>>,
    /// Expectations handed to the built mock.
    expectation: MockExpectation,
    /// Call log handed to the built mock.
    calls: Arc<RwLock<Vec<MockCall>>>,
}
impl<T: Clone + Send + Sync + 'static> MockBuilder<T> {
    /// Creates a builder with no result, default expectations and an empty call log.
    #[must_use]
    pub fn new() -> Self {
        Self { result: None, expectation: MockExpectation::default(), calls: Arc::new(RwLock::new(Vec::new())) }
    }

    /// Makes every call return a clone of `value`, replacing any previously configured result.
    ///
    /// Every configuration method consumes and returns `self`; `#[must_use]` flags call sites
    /// that drop the returned builder and so silently discard the configuration.
    #[must_use]
    pub fn return_value(mut self, value: T) -> Self {
        self.result = Some(MockResult::return_value(value));
        self
    }

    /// Makes every call fail with `msg`, replacing any previously configured result.
    #[must_use]
    pub fn error(mut self, msg: impl Into<String>) -> Self {
        self.result = Some(MockResult::error(msg));
        self
    }

    /// Makes successive calls return `values` in order, replacing any previously configured result.
    #[must_use]
    pub fn sequence(mut self, values: Vec<T>) -> Self {
        self.result = Some(MockResult::sequence(values));
        self
    }

    /// Replaces the expectations the built mock will be verified against.
    #[must_use]
    pub fn expect(mut self, expectation: MockExpectation) -> Self {
        self.expectation = expectation;
        self
    }

    /// Finalizes the configuration into a [`MockFn`] whose sequence cursor starts at 0.
    #[must_use]
    pub fn build(self) -> MockFn<T> {
        MockFn {
            result: self.result,
            expectation: self.expectation,
            calls: self.calls,
            sequence_index: Arc::new(RwLock::new(0)),
        }
    }
}
impl<T: Clone + Send + Sync + 'static> Default for MockBuilder<T> {
fn default() -> Self {
Self::new()
}
}
/// A configured mock function that records each call and replays a canned [`MockResult`].
pub struct MockFn<T> {
    /// Outcome produced by `call`; when `None`, every call fails.
    result: Option<MockResult<T>>,
    /// Expectations checked by `verify`.
    expectation: MockExpectation,
    /// Log of every call made through `call`.
    calls: Arc<RwLock<Vec<MockCall>>>,
    /// Position of the next value handed out for a `Sequence` result; `reset` sets it back to 0.
    sequence_index: Arc<RwLock<usize>>,
}
impl<T: Clone + Send + Sync + 'static> MockFn<T> {
pub fn call(&self, args: Vec<String>) -> TestingResult<T> {
{
let mut calls = self.calls.write();
calls.push(MockCall { args, timestamp: std::time::Instant::now() });
}
match &self.result {
Some(MockResult::Return(v)) => Ok(v.clone()),
Some(MockResult::Error(e)) => Err(WaeError::new(WaeErrorKind::MockError { reason: e.clone() })),
Some(MockResult::Sequence(values)) => {
let mut idx = self.sequence_index.write();
if *idx < values.len() {
let value = values[*idx].clone();
*idx += 1;
Ok(value)
}
else {
Err(WaeError::new(WaeErrorKind::MockError { reason: "Mock sequence exhausted".to_string() }))
}
}
None => Err(WaeError::new(WaeErrorKind::MockError { reason: "No mock result configured".to_string() })),
}
}
pub async fn call_async(&self, args: Vec<String>) -> TestingResult<T> {
self.call(args)
}
}
impl<T: Clone + Send + Sync + 'static> Mock for MockFn<T> {
    /// Returns a copy of the call log.
    fn calls(&self) -> Vec<MockCall> {
        self.calls.read().clone()
    }

    /// Returns the number of logged calls.
    fn call_count(&self) -> usize {
        self.calls.read().len()
    }

    /// Checks the exact-call-count expectation, if one was configured.
    ///
    /// # Errors
    ///
    /// Returns `AssertionFailed` when the recorded call count differs from the expected one.
    /// If the expectation has a description, it is prefixed to the message so the failing
    /// mock can be identified; without one, the message is unchanged
    /// (`"Expected N calls, but got M"`).
    fn verify(&self) -> TestingResult<()> {
        let Some(expected) = self.expectation.expected_calls else {
            return Ok(());
        };
        let actual_calls = self.call_count();
        if actual_calls == expected {
            return Ok(());
        }
        let counts = format!("Expected {} calls, but got {}", expected, actual_calls);
        // `description` is stored in the private `expectation` field and was previously never
        // read, so the label was silently lost; surface it in the failure message.
        let message = match &self.expectation.description {
            Some(desc) => format!("{desc}: {counts}"),
            None => counts,
        };
        Err(WaeError::new(WaeErrorKind::AssertionFailed { message }))
    }

    /// Clears the call log and rewinds the sequence cursor to the first value.
    fn reset(&self) {
        // Same lock order as before: `calls` is held while the cursor is reset.
        let mut calls = self.calls.write();
        calls.clear();
        *self.sequence_index.write() = 0;
    }
}
impl<T: Clone + Send + Sync + 'static> AsyncMock for MockFn<T> {
async fn verify_async(&self) -> TestingResult<()> {
self.verify()
}
}
/// Runs [`Mock::verify`] on `mock` and returns its result unchanged.
///
/// # Errors
///
/// Propagates whatever error the mock's `verify` implementation reports.
pub fn verify<M: Mock>(mock: &M) -> TestingResult<()> {
    // Fully-qualified so it is unmistakably the trait method, not this free function.
    <M as Mock>::verify(mock)
}
/// Awaits [`AsyncMock::verify_async`] on `mock` and returns its result unchanged.
///
/// # Errors
///
/// Propagates whatever error the mock's `verify_async` implementation reports.
pub async fn verify_async<M: AsyncMock>(mock: &M) -> TestingResult<()> {
    let pending = <M as AsyncMock>::verify_async(mock);
    pending.await
}