use hashbrown::HashMap;
use serde::{Deserialize, Serialize};
use crate::entity::{EntityId, SINGLETON_ENTITY_ID, Summary};
use crate::traits::{StateCollection, StateEntity};
use crate::{Error, Result};
/// A set of entities of a single type, keyed by [`EntityId`].
///
/// Because of `#[serde(transparent)]`, the wrapper serializes and
/// deserializes exactly as its inner map does.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Collection<T>
where
    T: StateEntity,
{
    /// Backing map from entity id to entity.
    ///
    /// The explicit `bound` replaces the bound serde would otherwise infer
    /// for this field when deriving `Deserialize`.
    #[serde(bound(deserialize = "T: StateEntity"))]
    inner: HashMap<EntityId, T>,
}
impl<T> Default for Collection<T>
where
    T: StateEntity,
{
    /// Builds a collection that holds no entities.
    fn default() -> Self {
        let inner = HashMap::default();
        Self { inner }
    }
}
impl<T> Collection<T>
where
    T: StateEntity,
{
    /// Creates an empty collection (same as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Borrows the underlying id-to-entity map.
    pub fn inner(&self) -> &HashMap<EntityId, T> {
        &self.inner
    }

    /// Looks up an entity by its id.
    pub fn get_by_id(&self, id: &EntityId) -> Option<&T> {
        self.inner.get(id)
    }

    /// Returns the first entity encountered whose `name()` equals `name`,
    /// together with its id.
    ///
    /// This walks the map, so when several entities share a name, which one
    /// is returned depends on the map's iteration order.
    pub fn get_by_name(&self, name: &str) -> Option<(&EntityId, &T)> {
        for (id, candidate) in &self.inner {
            if candidate.name() == name {
                return Some((id, candidate));
            }
        }
        None
    }

    /// Number of entities stored.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// `true` when no entities are stored.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Iterates over `(id, entity)` pairs in the map's iteration order.
    pub fn iter(&self) -> impl Iterator<Item = (&EntityId, &T)> {
        self.inner.iter()
    }
}
impl<T: StateEntity> StateCollection for Collection<T> {
    type Entity = T;

    /// State entry of the collection, forwarded from the entity type.
    const STATE_ENTRY: <T as StateEntity>::Entry = T::STATE_ENTRY;

    /// Builds a collection from `(id, entity)` pairs by collecting them into
    /// the backing map.
    fn load<I>(entities: I) -> Self
    where
        I: IntoIterator<Item = (EntityId, Self::Entity)>,
    {
        Self { inner: entities.into_iter().collect() }
    }

    /// Resolves `id` first as an entity id and, if no key matches, as an
    /// entity name (see [`Collection::get_by_name`]).
    fn get_entity(&self, id: &str) -> Option<(&EntityId, &Self::Entity)> {
        // Hashed lookup keyed by `&str`, exactly as `update` and `remove` do,
        // instead of allocating an `EntityId` and scanning every key.
        self.inner.get_key_value(id).or_else(|| self.get_by_name(id))
    }

    /// Returns every `(id, entity)` pair in the map's iteration order.
    fn get_entities(&self) -> Vec<(&EntityId, &Self::Entity)> {
        self.inner.iter().collect()
    }

    /// Returns the entities whose name or description contains `needle`,
    /// compared after lower-casing both sides.
    ///
    /// An empty `needle` matches every entity.
    fn search_entities(&self, needle: &str) -> Vec<(&EntityId, &Self::Entity)> {
        // Nothing to compare against: skip all per-entity lower-casing.
        if needle.is_empty() {
            return self.get_entities();
        }
        let needle_lower = needle.to_lowercase();
        self.inner
            .iter()
            .filter(|(_, entity)| {
                entity.name().to_lowercase().contains(&needle_lower)
                    || entity
                        .description()
                        .is_some_and(|d| d.to_lowercase().contains(&needle_lower))
            })
            .collect()
    }

    /// Stores `entity` under an id obtained from `EntityId::new()` and
    /// returns that id.
    fn create(&mut self, entity: Self::Entity) -> EntityId {
        let id = EntityId::new();
        // Any value previously stored under the same id is discarded.
        drop(self.inner.insert(id.clone(), entity));
        id
    }

    /// Replaces the entity stored under `id` with `entity`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::NotFound`] when no entity is stored under `id`; the
    /// collection is left unchanged in that case.
    fn update(&mut self, id: &str, entity: Self::Entity) -> Result<()> {
        let Some(slot) = self.inner.get_mut(id) else {
            return Err(Error::NotFound(format!("Entity not found: {id}")));
        };
        // Plain assignment drops the previous value in place.
        *slot = entity;
        Ok(())
    }

    /// Removes and returns the entity stored under `id`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::NotFound`] when no entity is stored under `id`.
    fn remove(&mut self, id: &str) -> Result<Self::Entity> {
        self.inner.remove(id).ok_or_else(|| Error::NotFound(format!("Entity not found: {id}")))
    }

    /// Produces one [`Summary`] per entity via `entity.summary(id)`.
    fn list(&self) -> Vec<Summary> {
        self.inner.iter().map(|(id, entity)| entity.summary(id.clone())).collect()
    }

    /// `true` when no entities are stored.
    fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}
/// A container that always holds exactly one entity of type `T`.
///
/// Because of `#[serde(transparent)]`, the wrapper serializes and
/// deserializes exactly as the wrapped entity does.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Singleton<T>
where
    T: StateEntity,
{
    /// The single wrapped entity.
    ///
    /// The explicit `bound` replaces the bound serde would otherwise infer
    /// for this field when deriving `Deserialize`.
    #[serde(bound(deserialize = "T: StateEntity"))]
    inner: T,
}
impl<T> Singleton<T>
where
    T: StateEntity,
{
    /// Wraps `entity` as the singleton's value.
    pub fn new(entity: T) -> Self {
        Self { inner: entity }
    }

    /// Borrows the wrapped entity.
    pub fn get(&self) -> &T {
        &self.inner
    }

    /// Mutably borrows the wrapped entity.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.inner
    }

    /// Replaces the wrapped entity with `entity`, dropping the previous one.
    pub fn set(&mut self, entity: T) {
        *self.get_mut() = entity;
    }
}
impl<T> StateCollection for Singleton<T>
where
    T: StateEntity + Default,
{
    type Entity = T;

    /// State entry of the singleton, forwarded from the entity type.
    const STATE_ENTRY: <T as StateEntity>::Entry = T::STATE_ENTRY;

    /// Builds the singleton from the first `(id, entity)` pair, ignoring the
    /// id and any further pairs; falls back to `T::default()` when `entities`
    /// yields nothing.
    fn load<I>(entities: I) -> Self
    where
        I: IntoIterator<Item = (EntityId, Self::Entity)>,
    {
        let first = entities.into_iter().next().map(|(_, entity)| entity);
        Self::new(first.unwrap_or_default())
    }

    /// Always returns the wrapped entity paired with `SINGLETON_ENTITY_ID`;
    /// the requested id is not inspected.
    fn get_entity(&self, _id: &str) -> Option<(&EntityId, &Self::Entity)> {
        Some((&SINGLETON_ENTITY_ID, &self.inner))
    }

    /// Returns a one-element list holding the wrapped entity.
    fn get_entities(&self) -> Vec<(&EntityId, &Self::Entity)> {
        vec![(&SINGLETON_ENTITY_ID, &self.inner)]
    }

    /// Returns the wrapped entity when `needle` is empty or is contained in
    /// its name or description (compared after lower-casing both sides);
    /// otherwise returns an empty list.
    fn search_entities(&self, needle: &str) -> Vec<(&EntityId, &Self::Entity)> {
        let matches = needle.is_empty() || {
            let needle_lower = needle.to_lowercase();
            self.inner.name().to_lowercase().contains(&needle_lower)
                || self
                    .inner
                    .description()
                    .is_some_and(|d| d.to_lowercase().contains(&needle_lower))
        };
        if matches {
            vec![(&SINGLETON_ENTITY_ID, &self.inner)]
        } else {
            Vec::new()
        }
    }

    /// Overwrites the wrapped entity with `entity` and returns
    /// `EntityId::singleton()`.
    fn create(&mut self, entity: Self::Entity) -> EntityId {
        self.set(entity);
        EntityId::singleton()
    }

    /// Overwrites the wrapped entity with `entity`; the id is not inspected
    /// and the call always succeeds.
    fn update(&mut self, _id: &str, entity: Self::Entity) -> Result<()> {
        self.set(entity);
        Ok(())
    }

    /// Always fails: the wrapped entity cannot be removed.
    ///
    /// # Errors
    ///
    /// Returns [`Error::IllegalOperation`] unconditionally.
    fn remove(&mut self, _id: &str) -> Result<Self::Entity> {
        Err(Error::IllegalOperation("Cannot remove singleton entity".to_string()))
    }

    /// Returns a single [`Summary`] built with `EntityId::singleton()`.
    fn list(&self) -> Vec<Summary> {
        vec![self.inner.summary(EntityId::singleton())]
    }

    /// A singleton always holds an entity, so this is always `false`.
    fn is_empty(&self) -> bool {
        false
    }
}
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;

    /// Minimal entity used to exercise the containers.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct TestEntity {
        name: String,
        value: i32,
    }

    /// State-entry type with a single variant for [`TestEntity`].
    #[derive(Debug, Copy, Clone)]
    enum TestStateEntry {
        TestEntity,
    }

    impl AsRef<str> for TestStateEntry {
        fn as_ref(&self) -> &str {
            match self {
                Self::TestEntity => "test_entity",
            }
        }
    }

    impl crate::HasName for TestEntity {
        fn name(&self) -> &str {
            &self.name
        }
    }

    impl StateEntity for TestEntity {
        type Entry = TestStateEntry;
        const STATE_ENTRY: TestStateEntry = TestStateEntry::TestEntity;
    }

    /// A singleton survives a JSON serialize/deserialize round trip.
    #[test]
    fn test_singleton_serde_roundtrip() {
        let original = Singleton::new(TestEntity { name: "test".to_string(), value: 42 });

        let json = serde_json::to_string(&original).unwrap();
        let restored: Singleton<TestEntity> = serde_json::from_str(&json).unwrap();

        assert_eq!(original, restored);
        assert_eq!(restored.get().name, "test");
        assert_eq!(restored.get().value, 42);
    }

    /// A singleton deserializes directly from the bare entity's JSON.
    #[test]
    fn test_singleton_deserialize_from_entity_json() {
        let json = r#"{"name":"direct","value":100}"#;

        let singleton: Singleton<TestEntity> = serde_json::from_str(json).unwrap();

        assert_eq!(singleton.get().name, "direct");
        assert_eq!(singleton.get().value, 100);
    }

    /// A collection survives a JSON serialize/deserialize round trip.
    #[test]
    fn test_collection_serde_roundtrip() {
        let mut original = Collection::<TestEntity>::new();
        let first_id = original.create(TestEntity { name: "entity1".to_string(), value: 10 });
        let second_id = original.create(TestEntity { name: "entity2".to_string(), value: 20 });

        let json = serde_json::to_string(&original).unwrap();
        let restored: Collection<TestEntity> = serde_json::from_str(&json).unwrap();

        assert_eq!(original, restored);
        assert_eq!(restored.len(), 2);
        assert_eq!(restored.get_by_id(&first_id).unwrap().value, 10);
        assert_eq!(restored.get_by_id(&second_id).unwrap().value, 20);
    }

    /// Exercises the collection API through a `Box<Collection<_>>`.
    #[test]
    fn test_box_wrapper() {
        let first_id = EntityId::new();
        let seed = vec![
            (first_id.clone(), TestEntity { name: "entity1".to_string(), value: 10 }),
            (EntityId::new(), TestEntity { name: "entity2".to_string(), value: 20 }),
        ];
        let mut boxed: Box<Collection<TestEntity>> = Box::<Collection<TestEntity>>::load(seed);

        // Every view of the loaded data reports both entities.
        assert_eq!(boxed.iter().count(), 2);
        assert_eq!(boxed.len(), 2);
        assert_eq!(boxed.get_entities().len(), 2);
        assert_eq!(boxed.list().len(), 2);
        assert!(!boxed.is_empty());

        // Create, then remove, a third entity.
        let third_id = boxed.create(TestEntity { name: "entity3".to_string(), value: 10 });
        assert_eq!(boxed.get_by_id(&third_id).unwrap().value, 10);
        let removed = boxed.remove(&third_id).unwrap();
        assert_eq!(removed.value, 10);

        // Updating an id that is present succeeds.
        let replacement = TestEntity { name: "entity4".to_string(), value: 10 };
        let outcome = boxed.update(&first_id, replacement);
        assert!(outcome.is_ok());
    }
}