use std::collections::HashMap;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
use crate::entity::SqlEntity;
use crate::error::SqlError;
use crate::params::{OdbcParam, PkValue};
use crate::repository::SqlRepository;
use crate::row::OdbcRow;
/// In-memory mock implementation of [`SqlRepository`] for entities of type `T`.
///
/// Entities live in a map keyed by the string form of their primary key
/// (`pk_value().to_string()`). Each mutating operation bumps a dedicated
/// counter that can be read back through the `*_call_count` accessors.
pub struct MockRepository<T: SqlEntity + Clone> {
/// Stored entities, keyed by the stringified primary key.
store: Arc<Mutex<HashMap<String, T>>>,
/// Source of the ids returned by `insert` when `T::PK_IS_IDENTITY` is true; starts at 1.
next_id: Arc<AtomicI64>,
/// Incremented on every `insert` call.
insert_calls: Arc<AtomicI64>,
/// Incremented on every `update` call.
update_calls: Arc<AtomicI64>,
/// Incremented on every `delete` call.
delete_calls: Arc<AtomicI64>,
/// Incremented on every `upsert` call.
upsert_calls: Arc<AtomicI64>,
}
impl<T: SqlEntity + Clone> Default for MockRepository<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: SqlEntity + Clone> MockRepository<T> {
    /// Creates an empty repository with all call counters at zero and the
    /// identity counter starting at 1.
    pub fn new() -> Self {
        Self::from_store(HashMap::new())
    }

    /// Creates a repository pre-populated with `items`.
    ///
    /// Each item is keyed by `pk_value().to_string()`; if several items share
    /// a key, the last one in iteration order wins. Call counters start at
    /// zero and the identity counter at 1, regardless of the seeded keys.
    pub fn with_data(items: impl IntoIterator<Item = T>) -> Self {
        let store: HashMap<String, T> = items
            .into_iter()
            .map(|item| (item.pk_value().to_string(), item))
            .collect();
        Self::from_store(store)
    }

    /// Single place where the fields are initialised, so `new` and
    /// `with_data` cannot drift apart.
    fn from_store(store: HashMap<String, T>) -> Self {
        Self {
            store: Arc::new(Mutex::new(store)),
            next_id: Arc::new(AtomicI64::new(1)),
            insert_calls: Arc::new(AtomicI64::new(0)),
            update_calls: Arc::new(AtomicI64::new(0)),
            delete_calls: Arc::new(AtomicI64::new(0)),
            upsert_calls: Arc::new(AtomicI64::new(0)),
        }
    }

    /// Number of times `insert` has been called (including calls made through
    /// `batch_insert`).
    pub fn insert_call_count(&self) -> i64 {
        self.insert_calls.load(Ordering::Relaxed)
    }

    /// Number of times `update` has been called (including calls made through
    /// `batch_update`).
    pub fn update_call_count(&self) -> i64 {
        self.update_calls.load(Ordering::Relaxed)
    }

    /// Number of times `delete` has been called (including calls made through
    /// `batch_delete`).
    pub fn delete_call_count(&self) -> i64 {
        self.delete_calls.load(Ordering::Relaxed)
    }

    /// Number of times `upsert` has been called.
    pub fn upsert_call_count(&self) -> i64 {
        self.upsert_calls.load(Ordering::Relaxed)
    }

    /// Number of entities currently stored.
    pub async fn len(&self) -> usize {
        self.store.lock().await.len()
    }

    /// Returns `true` when no entities are stored.
    pub async fn is_empty(&self) -> bool {
        self.store.lock().await.is_empty()
    }

    /// Returns clones of all stored entities, in the map's iteration order
    /// (which is unspecified).
    pub async fn all_items(&self) -> Vec<T> {
        self.store.lock().await.values().cloned().collect()
    }

    /// Stores `item` under its primary key without touching any call counter,
    /// replacing an existing entry with the same key.
    pub async fn seed(&self, item: T) {
        let key = item.pk_value().to_string();
        self.store.lock().await.insert(key, item);
    }

    /// Removes every stored entity. Call counters and the identity counter
    /// are left unchanged.
    pub async fn clear(&self) {
        self.store.lock().await.clear();
    }
}
#[async_trait]
impl<T: SqlEntity + Clone + Send + Sync + 'static> SqlRepository<T> for MockRepository<T> {
async fn get_by_id(
&self,
id: impl Into<PkValue> + Send,
_token: &CancellationToken,
) -> Result<Option<T>, SqlError> {
let key = id.into().to_string();
Ok(self.store.lock().await.get(&key).cloned())
}
async fn get_all(&self, _token: &CancellationToken) -> Result<Vec<T>, SqlError> {
Ok(self.store.lock().await.values().cloned().collect())
}
async fn get_where(
&self,
_filter: &str,
_params: &[OdbcParam],
_token: &CancellationToken,
) -> Result<Vec<T>, SqlError> {
Ok(self.store.lock().await.values().cloned().collect())
}
async fn get_paged(
&self,
page: i64,
page_size: i64,
_token: &CancellationToken,
) -> Result<Vec<T>, SqlError> {
let all: Vec<T> = self.store.lock().await.values().cloned().collect();
let page = page.max(1) as usize;
let size = page_size.max(1) as usize;
let start = (page - 1) * size;
Ok(all.into_iter().skip(start).take(size).collect())
}
async fn count(&self, _token: &CancellationToken) -> Result<i64, SqlError> {
Ok(self.store.lock().await.len() as i64)
}
async fn exists(
&self,
id: impl Into<PkValue> + Send,
_token: &CancellationToken,
) -> Result<bool, SqlError> {
let key = id.into().to_string();
Ok(self.store.lock().await.contains_key(&key))
}
async fn insert(&self, entity: &T, _token: &CancellationToken) -> Result<i64, SqlError> {
self.insert_calls.fetch_add(1, Ordering::Relaxed);
let id = if T::PK_IS_IDENTITY {
self.next_id.fetch_add(1, Ordering::Relaxed)
} else {
match entity.pk_value() {
PkValue::I32(v) => v as i64,
PkValue::I64(v) => v,
_ => 0,
}
};
self.store.lock().await.insert(entity.pk_value().to_string(), entity.clone());
Ok(id)
}
async fn update(&self, entity: &T, _token: &CancellationToken) -> Result<(), SqlError> {
self.update_calls.fetch_add(1, Ordering::Relaxed);
let key = entity.pk_value().to_string();
let mut store = self.store.lock().await;
if store.contains_key(&key) {
store.insert(key, entity.clone());
Ok(())
} else {
Err(SqlError::NotFound { table: T::TABLE_NAME, pk: key })
}
}
async fn delete(
&self,
id: impl Into<PkValue> + Send,
_token: &CancellationToken,
) -> Result<(), SqlError> {
self.delete_calls.fetch_add(1, Ordering::Relaxed);
let key = id.into().to_string();
self.store.lock().await.remove(&key);
Ok(())
}
async fn upsert(&self, entity: &T, _token: &CancellationToken) -> Result<(), SqlError> {
self.upsert_calls.fetch_add(1, Ordering::Relaxed);
self.store.lock().await.insert(entity.pk_value().to_string(), entity.clone());
Ok(())
}
async fn batch_insert(
&self,
entities: &[T],
token: &CancellationToken,
) -> Result<Vec<i64>, SqlError> {
let mut ids = Vec::with_capacity(entities.len());
for e in entities {
ids.push(self.insert(e, token).await?);
}
Ok(ids)
}
async fn batch_update(
&self,
entities: &[T],
token: &CancellationToken,
) -> Result<(), SqlError> {
for e in entities {
self.update(e, token).await?;
}
Ok(())
}
async fn batch_delete(
&self,
ids: &[PkValue],
token: &CancellationToken,
) -> Result<(), SqlError> {
for id in ids {
self.delete(id.clone(), token).await?;
}
Ok(())
}
async fn query_raw(
&self,
_sql: &str,
_params: &[OdbcParam],
_token: &CancellationToken,
) -> Result<Vec<OdbcRow>, SqlError> {
Ok(Vec::new())
}
async fn execute_raw(
&self,
_sql: &str,
_params: &[OdbcParam],
_token: &CancellationToken,
) -> Result<usize, SqlError> {
Ok(0)
}
async fn scalar<S: TryFrom<String> + Send>(
&self,
_sql: &str,
_params: &[OdbcParam],
_token: &CancellationToken,
) -> Result<S, SqlError>
where
<S as TryFrom<String>>::Error: std::fmt::Display,
{
Err(SqlError::config(
"MockRepository::scalar is not supported — use a custom mock for scalar queries",
))
}
}