use crate::mock::{Mock, MockCall};
use parking_lot::RwLock;
use std::{collections::HashMap, sync::Arc};
use wae_types::{WaeError, WaeErrorKind, WaeResult as TestingResult};
/// One query captured by the mock database.
///
/// Holds the name the query was issued under, the parameters passed with it, and the
/// instant at which it was appended to the query log.
#[derive(Clone, Debug)]
pub struct DatabaseQuery {
    /// Name the query was issued under.
    pub query_name: String,
    /// Parameters supplied with the query, in the order they were passed.
    pub params: Vec<String>,
    /// Moment the query was recorded.
    pub timestamp: ::std::time::Instant,
}
/// Canned outcome that the mock database uses to answer queries.
///
/// As interpreted by `MockDatabase::query`:
/// - `Return` answers every query with a clone of the value;
/// - `Error` fails every query with a mock error carrying the message;
/// - `NamedReturns` answers each query with the value stored under its name and fails
///   for names that are not present.
#[derive(Debug, Clone)]
pub enum DatabaseResult<T> {
    /// Same value for every query.
    Return(T),
    /// Same failure message for every query.
    Error(String),
    /// Per-query-name values.
    NamedReturns(HashMap<String, T>),
}

impl<T: Clone> DatabaseResult<T> {
    /// Wraps `value` in [`DatabaseResult::Return`].
    pub fn return_value(value: T) -> Self {
        Self::Return(value)
    }

    /// Wraps the converted message in [`DatabaseResult::Error`].
    pub fn error(msg: impl Into<String>) -> Self {
        let message: String = msg.into();
        Self::Error(message)
    }

    /// Wraps a name-to-value table in [`DatabaseResult::NamedReturns`].
    pub fn named_returns(map: HashMap<String, T>) -> Self {
        Self::NamedReturns(map)
    }
}
/// Expected call counts per query name, plus an optional free-form description.
#[derive(Debug, Default)]
pub struct DatabaseExpectation {
    /// Query name mapped to the exact number of times it is expected to be issued.
    pub expected_queries: HashMap<String, usize>,
    /// Optional human-readable note attached to the expectation.
    pub description: Option<String>,
}

impl DatabaseExpectation {
    /// Creates an expectation with no query counts and no description.
    pub fn new() -> Self {
        Self { expected_queries: HashMap::new(), description: None }
    }

    /// Expects `query_name` to be issued exactly `count` times.
    ///
    /// Calling this again for the same name replaces the earlier count.
    pub fn expect_query(mut self, query_name: impl Into<String>, count: usize) -> Self {
        let name = query_name.into();
        self.expected_queries.insert(name, count);
        self
    }

    /// Sets (or replaces) the description.
    pub fn description(self, desc: impl Into<String>) -> Self {
        Self { description: Some(desc.into()), ..self }
    }
}
pub struct MockDatabaseBuilder<T> {
result: Option<DatabaseResult<T>>,
expectation: DatabaseExpectation,
queries: Arc<RwLock<Vec<DatabaseQuery>>>,
}
impl<T: Clone + Send + Sync + 'static> MockDatabaseBuilder<T> {
pub fn new() -> Self {
Self { result: None, expectation: DatabaseExpectation::default(), queries: Arc::new(RwLock::new(Vec::new())) }
}
pub fn return_value(mut self, value: T) -> Self {
self.result = Some(DatabaseResult::return_value(value));
self
}
pub fn error(mut self, msg: impl Into<String>) -> Self {
self.result = Some(DatabaseResult::error(msg));
self
}
pub fn named_returns(mut self, map: HashMap<String, T>) -> Self {
self.result = Some(DatabaseResult::named_returns(map));
self
}
pub fn expect(mut self, expectation: DatabaseExpectation) -> Self {
self.expectation = expectation;
self
}
pub fn build(self) -> MockDatabase<T> {
MockDatabase { result: self.result, expectation: self.expectation, queries: self.queries }
}
}
impl<T: Clone + Send + Sync + 'static> Default for MockDatabaseBuilder<T> {
fn default() -> Self {
Self::new()
}
}
pub struct MockDatabase<T> {
result: Option<DatabaseResult<T>>,
expectation: DatabaseExpectation,
queries: Arc<RwLock<Vec<DatabaseQuery>>>,
}
impl<T: Clone> MockDatabase<T> {
pub fn query(&self, query_name: impl Into<String>, params: Vec<String>) -> TestingResult<T> {
let query_name = query_name.into();
{
let mut queries = self.queries.write();
queries.push(DatabaseQuery { query_name: query_name.clone(), params, timestamp: std::time::Instant::now() });
}
match &self.result {
Some(DatabaseResult::Return(v)) => Ok(v.clone()),
Some(DatabaseResult::Error(e)) => Err(WaeError::new(WaeErrorKind::MockError { reason: e.clone() })),
Some(DatabaseResult::NamedReturns(map)) => {
if let Some(value) = map.get(&query_name) {
Ok(value.clone())
}
else {
Err(WaeError::new(WaeErrorKind::MockError { reason: format!("No mock result for query: {}", query_name) }))
}
}
None => Err(WaeError::new(WaeErrorKind::MockError { reason: "No mock result configured".to_string() })),
}
}
pub async fn query_async(&self, query_name: impl Into<String>, params: Vec<String>) -> TestingResult<T> {
self.query(query_name, params)
}
pub fn queries(&self) -> Vec<DatabaseQuery> {
self.queries.read().clone()
}
pub fn query_count(&self) -> usize {
self.queries.read().len()
}
pub fn query_count_by_name(&self, query_name: &str) -> usize {
self.queries.read().iter().filter(|q| q.query_name == query_name).count()
}
}
impl<T: Clone + Send + Sync + 'static> Mock for MockDatabase<T> {
    /// Converts each recorded query into a [`MockCall`].
    ///
    /// Each call's `args` are the query name followed by its parameters, and its
    /// `timestamp` is the moment the query was recorded.
    fn calls(&self) -> Vec<MockCall> {
        self.queries
            .read()
            .iter()
            .map(|q| MockCall {
                // Build the argument list directly from iterators. This avoids a throwaway
                // one-element Vec and a full clone of the params Vec.
                args: std::iter::once(q.query_name.clone()).chain(q.params.iter().cloned()).collect(),
                timestamp: q.timestamp,
            })
            .collect()
    }

    /// Total number of recorded queries.
    fn call_count(&self) -> usize {
        self.query_count()
    }

    /// Checks that every query named in the expectation was issued exactly the expected
    /// number of times.
    ///
    /// # Errors
    ///
    /// Returns an `AssertionFailed` error describing the first mismatch found. Expectations
    /// are checked in ascending order of query name.
    fn verify(&self) -> TestingResult<()> {
        // HashMap iteration order is randomized per process. Sort by name so that, when
        // several expectations fail, the reported mismatch is the same on every run.
        let mut expectations: Vec<(&String, &usize)> = self.expectation.expected_queries.iter().collect();
        expectations.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

        // Count against a single snapshot of the log rather than re-locking per expectation.
        let recorded = self.queries.read();
        for (query_name, &expected) in expectations {
            let actual = recorded.iter().filter(|q| q.query_name == *query_name).count();
            if actual != expected {
                return Err(WaeError::new(WaeErrorKind::AssertionFailed {
                    message: format!("Expected {} calls for query '{}', but got {}", expected, query_name, actual),
                }));
            }
        }
        Ok(())
    }

    /// Clears the query log. The configured result and expectation are left untouched.
    fn reset(&self) {
        self.queries.write().clear();
    }
}