use std::collections::HashMap;
use std::future::Future;
use std::hash::Hash;
use super::error::RepositoryError;
use super::pagination::{FilterCondition, OrderDirection, Pagination};
/// Result type returned (inside a future) by every repository trait method in this module.
///
/// The error side is always [`RepositoryError`].
pub type RepositoryResult<T> = Result<T, RepositoryError>;
/// Asynchronous CRUD-style access to `Entity` values identified by `Id`.
///
/// Type parameters:
/// - `Id`: key passed by reference to the lookup and mutation methods.
/// - `Entity`: value produced by reads and writes.
/// - `Create`: input consumed by [`Repository::create`].
/// - `Update`: input consumed by [`Repository::update`].
///
/// Every method returns a `Send` future resolving to a [`RepositoryResult`], and
/// implementors must themselves be `Send + Sync`.
pub trait Repository<Id, Entity, Create, Update>: Send + Sync {
    /// Looks up one entity by `id`.
    ///
    /// Resolves to `Ok(None)` when nothing matches (presumably "no entity with
    /// this id"; confirm per implementation).
    fn find_by_id(
        &self,
        id: &Id,
    ) -> impl Future<Output = RepositoryResult<Option<Entity>>> + Send;

    /// Lists entities matching `filters`.
    ///
    /// `order_by` is an optional `(name, direction)` pair. The `&str` is presumably
    /// a field/column name. `pagination` optionally limits the returned window.
    /// How several filter conditions combine is left to the implementation.
    fn find_all(
        &self,
        filters: &[FilterCondition],
        order_by: Option<(&str, OrderDirection)>,
        pagination: Option<Pagination>,
    ) -> impl Future<Output = RepositoryResult<Vec<Entity>>> + Send;

    /// Counts the entities matching `filters`.
    fn count(&self, filters: &[FilterCondition])
        -> impl Future<Output = RepositoryResult<u64>> + Send;

    /// Checks whether an entity with `id` exists.
    fn exists(&self, id: &Id) -> impl Future<Output = RepositoryResult<bool>> + Send;

    /// Creates a new entity from `data` and resolves to the created entity.
    fn create(&self, data: Create) -> impl Future<Output = RepositoryResult<Entity>> + Send;

    /// Applies `data` to the entity identified by `id` and resolves to the result.
    ///
    /// What happens when `id` does not exist is implementation-defined (it may
    /// surface as a [`RepositoryError`]).
    fn update(
        &self,
        id: &Id,
        data: Update,
    ) -> impl Future<Output = RepositoryResult<Entity>> + Send;

    /// Deletes the entity identified by `id`.
    ///
    /// NOTE(review): the `bool` presumably reports whether anything was removed.
    /// Confirm per implementation.
    fn delete(&self, id: &Id) -> impl Future<Output = RepositoryResult<bool>> + Send;
}
/// Extension of [`Repository`] for stores that support soft deletion.
///
/// Soft deletion means marking a record as deleted instead of removing it. The
/// method names imply this contract; implementors define the exact storage.
pub trait SoftDeleteRepository<Id, Entity, Create, Update>: Repository<Id, Entity, Create, Update> {
    /// Marks the entity identified by `id` as deleted without removing it.
    ///
    /// NOTE(review): the `bool` presumably reports whether a record was affected.
    fn soft_delete(&self, id: &Id) -> impl Future<Output = RepositoryResult<bool>> + Send;

    /// Reverses a previous soft delete of the entity identified by `id`.
    fn restore(&self, id: &Id) -> impl Future<Output = RepositoryResult<bool>> + Send;

    /// Same parameters as [`Repository::find_all`], but also includes soft-deleted entities.
    fn find_with_deleted(
        &self,
        filters: &[FilterCondition],
        order_by: Option<(&str, OrderDirection)>,
        pagination: Option<Pagination>,
    ) -> impl Future<Output = RepositoryResult<Vec<Entity>>> + Send;

    /// Same parameters as [`Repository::find_all`], but yields only soft-deleted entities.
    fn find_deleted(
        &self,
        filters: &[FilterCondition],
        order_by: Option<(&str, OrderDirection)>,
        pagination: Option<Pagination>,
    ) -> impl Future<Output = RepositoryResult<Vec<Entity>>> + Send;

    /// Permanently removes the entity identified by `id`, bypassing soft deletion.
    fn force_delete(&self, id: &Id) -> impl Future<Output = RepositoryResult<bool>> + Send;
}
/// Loads `Related` values associated with an `Entity`.
///
/// `RelatedId` is the key type of [`RelationLoader::batch_load`]'s result map,
/// which is why it must be `Eq + Hash`.
pub trait RelationLoader<Entity, RelatedId: Eq + Hash, Related>: Send + Sync {
    /// Loads at most one related value for `entity` (`Ok(None)` when there is none).
    fn load_one(
        &self,
        entity: &Entity,
    ) -> impl Future<Output = RepositoryResult<Option<Related>>> + Send;

    /// Loads every related value for `entity`.
    fn load_many(
        &self,
        entity: &Entity,
    ) -> impl Future<Output = RepositoryResult<Vec<Related>>> + Send;

    /// Fetches the related values for all of `ids` at once, keyed by id.
    ///
    /// Only callable when both `Related` and `RelatedId` are `Clone`. Whether
    /// unknown ids are omitted from the map or reported as an error is left to
    /// implementors.
    fn batch_load(
        &self,
        ids: &[RelatedId],
    ) -> impl Future<Output = RepositoryResult<HashMap<RelatedId, Related>>> + Send
    where
        Related: Clone,
        RelatedId: Clone;
}
#[cfg(test)]
mod tests {
    // Smoke tests for the repository traits. The mocks show that each trait can
    // be implemented with plain `async fn`s whose futures satisfy the traits'
    // `+ Send` bounds, and can be awaited from a Tokio test runtime.
    //
    // Results are unwrapped with `expect` rather than `assert!(r.is_ok())` +
    // `unwrap()`. That way a failure prints the `RepositoryError` instead of a
    // bare "assertion failed".
    use super::*;

    #[test]
    fn test_repository_result_type() {
        let ok_result: RepositoryResult<i32> = Ok(42);
        assert!(matches!(ok_result, Ok(42)));

        // `RepositoryError` is already in scope through `use super::*`.
        let err_result: RepositoryResult<i32> = Err(RepositoryError::not_found("Test", "123"));
        assert!(err_result.is_err());
    }

    // --- Mock domain types -------------------------------------------------

    struct MockId(String);

    struct MockEntity {
        id: String,
    }

    struct MockCreate {
        name: String,
    }

    struct MockUpdate {
        name: Option<String>,
    }

    /// Stateless repository whose methods return fixed results.
    struct MockRepository;

    impl Repository<MockId, MockEntity, MockCreate, MockUpdate> for MockRepository {
        async fn find_by_id(&self, _id: &MockId) -> RepositoryResult<Option<MockEntity>> {
            Ok(None)
        }

        async fn find_all(
            &self,
            _filters: &[FilterCondition],
            _order_by: Option<(&str, OrderDirection)>,
            _pagination: Option<Pagination>,
        ) -> RepositoryResult<Vec<MockEntity>> {
            Ok(vec![])
        }

        async fn count(&self, _filters: &[FilterCondition]) -> RepositoryResult<u64> {
            Ok(0)
        }

        async fn exists(&self, _id: &MockId) -> RepositoryResult<bool> {
            Ok(false)
        }

        // The created entity's id mirrors the requested name.
        async fn create(&self, data: MockCreate) -> RepositoryResult<MockEntity> {
            Ok(MockEntity { id: data.name })
        }

        // Uses the new name when one is supplied, otherwise keeps the existing id.
        async fn update(&self, id: &MockId, data: MockUpdate) -> RepositoryResult<MockEntity> {
            Ok(MockEntity {
                id: data.name.unwrap_or_else(|| id.0.clone()),
            })
        }

        async fn delete(&self, _id: &MockId) -> RepositoryResult<bool> {
            Ok(true)
        }
    }

    impl SoftDeleteRepository<MockId, MockEntity, MockCreate, MockUpdate> for MockRepository {
        async fn soft_delete(&self, _id: &MockId) -> RepositoryResult<bool> {
            Ok(true)
        }

        async fn restore(&self, _id: &MockId) -> RepositoryResult<bool> {
            Ok(true)
        }

        async fn find_with_deleted(
            &self,
            _filters: &[FilterCondition],
            _order_by: Option<(&str, OrderDirection)>,
            _pagination: Option<Pagination>,
        ) -> RepositoryResult<Vec<MockEntity>> {
            Ok(vec![])
        }

        async fn find_deleted(
            &self,
            _filters: &[FilterCondition],
            _order_by: Option<(&str, OrderDirection)>,
            _pagination: Option<Pagination>,
        ) -> RepositoryResult<Vec<MockEntity>> {
            Ok(vec![])
        }

        async fn force_delete(&self, _id: &MockId) -> RepositoryResult<bool> {
            Ok(true)
        }
    }

    // --- Mock relation types -----------------------------------------------

    #[derive(Clone, Debug, PartialEq, Eq, Hash)]
    struct RelatedId(String);

    #[derive(Clone)]
    struct RelatedEntity {
        id: RelatedId,
    }

    /// Loader whose `batch_load` fabricates one related entity per requested id.
    struct MockRelationLoader;

    impl RelationLoader<MockEntity, RelatedId, RelatedEntity> for MockRelationLoader {
        async fn load_one(&self, _entity: &MockEntity) -> RepositoryResult<Option<RelatedEntity>> {
            Ok(None)
        }

        async fn load_many(&self, _entity: &MockEntity) -> RepositoryResult<Vec<RelatedEntity>> {
            Ok(vec![])
        }

        async fn batch_load(
            &self,
            ids: &[RelatedId],
        ) -> RepositoryResult<HashMap<RelatedId, RelatedEntity>>
        where
            RelatedEntity: Clone,
            RelatedId: Clone,
        {
            Ok(ids
                .iter()
                .map(|id| (id.clone(), RelatedEntity { id: id.clone() }))
                .collect())
        }
    }

    // --- Async smoke tests -------------------------------------------------

    #[tokio::test]
    async fn test_mock_repository_find_by_id() {
        let repo = MockRepository;
        let found = repo
            .find_by_id(&MockId("test".to_string()))
            .await
            .expect("mock find_by_id should not fail");
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn test_mock_repository_create() {
        let repo = MockRepository;
        let created = repo
            .create(MockCreate {
                name: "test".to_string(),
            })
            .await
            .expect("mock create should not fail");
        assert_eq!(created.id, "test");
    }

    #[tokio::test]
    async fn test_mock_soft_delete_repository() {
        let repo = MockRepository;
        let deleted = repo
            .soft_delete(&MockId("test".to_string()))
            .await
            .expect("mock soft_delete should not fail");
        assert!(deleted);
    }

    #[tokio::test]
    async fn test_mock_relation_loader_batch() {
        let loader = MockRelationLoader;
        let ids = [RelatedId("1".to_string()), RelatedId("2".to_string())];
        let map = loader
            .batch_load(&ids)
            .await
            .expect("mock batch_load should not fail");

        // Every requested id must come back exactly once, mapped to an entity
        // carrying that same id.
        assert_eq!(map.len(), ids.len());
        for id in &ids {
            let related = map
                .get(id)
                .unwrap_or_else(|| panic!("missing entry for {id:?}"));
            assert_eq!(&related.id, id);
        }
    }
}