use std::marker::PhantomData;
use std::sync::Arc;
use noxu_db::{
Cursor, Database, DatabaseEntry, Mutex, OperationStatus, Transaction,
};
use crate::entity::{Entity, PrimaryKey};
use crate::entity_serializer::EntitySerializer;
use crate::error::{PersistError, Result};
use crate::evolve::envelope;
use crate::evolve::mutations::Mutations;
/// Typed handle over the database that stores entities of type `E` under
/// their primary key `K`.
///
/// The index keeps a shared, mutex-guarded [`Database`] handle together with
/// the [`Mutations`] registry that is handed to the serializer whenever a
/// stored record is decoded.
pub struct PrimaryIndex<K, E>
where
    K: PrimaryKey,
    E: Entity<PrimaryKey = K>,
{
    /// Shared handle to the underlying database.
    db: Arc<Mutex<Database>>,
    /// Class-evolution rules consulted while decoding stored records.
    mutations: Arc<Mutations>,
    /// Ties the index to its key and entity types without storing either.
    _phantom: PhantomData<(K, E)>,
}
impl<K, E> PrimaryIndex<K, E>
where
K: PrimaryKey + Ord + Send + Sync + 'static,
E: Entity<PrimaryKey = K> + Clone + Send + Sync + 'static,
{
pub fn new(db: Arc<Mutex<Database>>) -> Self {
Self::with_mutations(db, Arc::new(Mutations::new()))
}
pub fn with_mutations(
db: Arc<Mutex<Database>>,
mutations: Arc<Mutations>,
) -> Self {
Self { db, mutations, _phantom: PhantomData }
}
pub fn mutations(&self) -> &Arc<Mutations> {
&self.mutations
}
pub fn database_shared(&self) -> &Arc<Mutex<Database>> {
&self.db
}
pub fn get<S: EntitySerializer<E>>(
&self,
txn: Option<&Transaction>,
serializer: &S,
key: &K,
) -> Result<Option<E>> {
let key_bytes = key.to_bytes();
let db = self.db.lock();
let found = match txn {
Some(t) => db.get_in(t, &key_bytes)?,
None => db.get(&key_bytes)?,
};
match found {
Some(bytes) => {
let entity = self.decode_record(&bytes, serializer)?;
Ok(Some(entity))
}
None => Ok(None),
}
}
fn decode_record<S: EntitySerializer<E>>(
&self,
bytes: &[u8],
serializer: &S,
) -> Result<E> {
let dec = envelope::decode(bytes)?;
let expected_tag = E::entity_name();
if dec.class_tag != expected_tag {
let renamed = self.mutations.renamers().any(|r| {
r.field_name().is_none()
&& r.class_name() == dec.class_tag
&& r.new_name() == expected_tag
});
if !renamed {
return Err(PersistError::SerializationError(format!(
"entity class tag mismatch: on-disk '{}' != \
expected '{}' (no Renamer registered)",
dec.class_tag, expected_tag,
)));
}
}
serializer.deserialize_versioned(
dec.payload,
dec.class_version,
self.mutations.as_ref(),
)
}
pub fn put<S: EntitySerializer<E>>(
&self,
txn: Option<&Transaction>,
serializer: &S,
entity: &E,
) -> Result<()> {
let key_bytes = entity.primary_key().to_bytes();
let payload = serializer.serialize(entity)?;
let envelope_bytes =
envelope::encode(E::class_version(), E::entity_name(), &payload)?;
let db = self.db.lock();
match txn {
Some(t) => db.put_in(t, &key_bytes, &envelope_bytes)?,
None => db.put(&key_bytes, &envelope_bytes)?,
}
Ok(())
}
pub fn put_no_overwrite<S: EntitySerializer<E>>(
&self,
txn: Option<&Transaction>,
serializer: &S,
entity: &E,
) -> Result<bool> {
let key_bytes = entity.primary_key().to_bytes();
let payload = serializer.serialize(entity)?;
let envelope_bytes =
envelope::encode(E::class_version(), E::entity_name(), &payload)?;
let db = self.db.lock();
let inserted = match txn {
Some(t) => {
db.put_no_overwrite_in(t, &key_bytes, &envelope_bytes)?
}
None => db.put_no_overwrite(&key_bytes, &envelope_bytes)?,
};
Ok(inserted)
}
pub fn delete(&self, txn: Option<&Transaction>, key: &K) -> Result<bool> {
let key_bytes = key.to_bytes();
let db = self.db.lock();
let deleted = match txn {
Some(t) => db.delete_in(t, &key_bytes)?,
None => db.delete(&key_bytes)?,
};
Ok(deleted)
}
pub fn delete_with_entity<S: EntitySerializer<E>>(
&self,
txn: Option<&Transaction>,
_serializer: &S,
key: &K,
) -> Result<bool> {
self.delete(txn, key)
}
pub fn contains(&self, txn: Option<&Transaction>, key: &K) -> Result<bool> {
let key_bytes = key.to_bytes();
let db = self.db.lock();
let found = match txn {
Some(t) => db.get_in(t, &key_bytes)?,
None => db.get(&key_bytes)?,
};
Ok(found.is_some())
}
pub fn count(&self) -> Result<u64> {
Ok(self.db.lock().count()?)
}
pub fn entities<'a, S: EntitySerializer<E>>(
&'a self,
txn: Option<&'a Transaction>,
serializer: &'a S,
) -> Result<EntityIterator<'a, K, E, S>> {
let cursor = {
let db = self.db.lock();
match txn {
Some(t) => db.open_cursor_in(t, None)?,
None => db.open_cursor(None)?,
}
};
Ok(EntityIterator {
cursor,
serializer,
mutations: Arc::clone(&self.mutations),
started: false,
done: false,
_phantom: PhantomData,
})
}
pub fn keys<'a>(
&'a self,
txn: Option<&'a Transaction>,
) -> Result<KeyIterator<'a, K>> {
let cursor = {
let db = self.db.lock();
match txn {
Some(t) => db.open_cursor_in(t, None)?,
None => db.open_cursor(None)?,
}
};
Ok(KeyIterator {
cursor,
started: false,
done: false,
_phantom: PhantomData,
})
}
pub fn database(&self) -> Arc<Mutex<Database>> {
Arc::clone(&self.db)
}
}
/// Iterator that walks a database cursor and yields every record decoded as
/// an entity of type `E`.
///
/// Each item is a `Result<E>`. The cursor is closed when the iterator is
/// dropped.
pub struct EntityIterator<'a, K, E, S> {
/// Cursor that is positioned on the first record by the first `next` call
/// and advanced by every later call.
cursor: Cursor<'a>,
/// Serializer used to turn each record payload back into an `E`.
serializer: &'a S,
/// Mutation registry passed to the serializer while decoding.
mutations: Arc<Mutations>,
/// `true` once the first cursor positioning has been requested.
started: bool,
/// `true` once iteration has ended; every later `next` call returns `None`.
done: bool,
/// Binds the key and entity types without storing values of either.
_phantom: PhantomData<(K, E)>,
}
impl<'a, K: PrimaryKey, E: Entity<PrimaryKey = K>, S: EntitySerializer<E>>
    Iterator for EntityIterator<'a, K, E, S>
{
    type Item = Result<E>;

    /// Advances the cursor by one record and decodes it.
    ///
    /// * A record that fails to decode is yielded as `Err`, and iteration
    ///   may continue with the following record.
    /// * A cursor error, or a record without data, is yielded as `Err` and
    ///   ends the iteration.
    /// * Any non-success cursor status ends the iteration with `None`.
    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }

        // The very first call positions the cursor; later calls step forward.
        let direction = if self.started {
            noxu_db::Get::Next
        } else {
            self.started = true;
            noxu_db::Get::First
        };

        let mut key_entry = DatabaseEntry::new();
        let mut data_entry = DatabaseEntry::new();
        let status =
            self.cursor.get(&mut key_entry, &mut data_entry, direction, None);

        match status {
            Err(e) => {
                self.done = true;
                Some(Err(e.into()))
            }
            Ok(OperationStatus::Success) => {
                if let Some(bytes) = data_entry.data_opt() {
                    return Some(decode_iter_record::<E, S>(
                        bytes,
                        self.serializer,
                        self.mutations.as_ref(),
                    ));
                }
                // Success without a data payload: report it and stop.
                self.done = true;
                Some(Err(PersistError::SerializationError(
                    "empty data from cursor".to_string(),
                )))
            }
            Ok(_) => {
                self.done = true;
                None
            }
        }
    }
}
/// Decodes one stored envelope into an entity of type `E`.
///
/// The envelope's class tag must equal `E::entity_name()`, unless
/// `mutations` contains a class-level renamer (one whose `field_name()` is
/// `None`) that maps the stored tag to the expected name. The payload is then
/// passed to `serializer.deserialize_versioned` together with the stored
/// class version and `mutations`.
///
/// # Errors
///
/// Returns whatever `envelope::decode` or the serializer reports, and
/// `PersistError::SerializationError` when the class tag differs and no
/// matching renamer is registered.
fn decode_iter_record<E, S>(
    bytes: &[u8],
    serializer: &S,
    mutations: &Mutations,
) -> Result<E>
where
    E: Entity,
    S: EntitySerializer<E>,
{
    let decoded = envelope::decode(bytes)?;
    let wanted = E::entity_name();

    // Renamers are consulted only when the stored tag differs.
    let tag_accepted = decoded.class_tag == wanted
        || mutations.renamers().any(|r| {
            r.field_name().is_none()
                && r.class_name() == decoded.class_tag
                && r.new_name() == wanted
        });

    if tag_accepted {
        return serializer.deserialize_versioned(
            decoded.payload,
            decoded.class_version,
            mutations,
        );
    }

    Err(PersistError::SerializationError(format!(
        "entity class tag mismatch: on-disk '{}' != expected '{}' \
         (no Renamer registered)",
        decoded.class_tag, wanted,
    )))
}
/// Closes the underlying cursor when the iterator goes out of scope.
impl<K, E, S> Drop for EntityIterator<'_, K, E, S> {
fn drop(&mut self) {
// Whatever `close()` returns is discarded; `drop` has no way to report it.
let _ = self.cursor.close();
}
}
/// Iterator that walks a database cursor and yields each record's key parsed
/// as a `K`.
///
/// Each item is a `Result<K>`. The cursor is closed when the iterator is
/// dropped.
pub struct KeyIterator<'a, K> {
/// Cursor that is positioned on the first record by the first `next` call
/// and advanced by every later call.
cursor: Cursor<'a>,
/// `true` once the first cursor positioning has been requested.
started: bool,
/// `true` once iteration has ended; every later `next` call returns `None`.
done: bool,
/// Binds the key type without storing a value of it.
_phantom: PhantomData<&'a K>,
}
impl<K: PrimaryKey> Iterator for KeyIterator<'_, K> {
type Item = Result<K>;
fn next(&mut self) -> Option<Self::Item> {
if self.done {
return None;
}
let mut key_entry = DatabaseEntry::new();
let mut data_entry = DatabaseEntry::new();
let get_type = if self.started {
noxu_db::Get::Next
} else {
self.started = true;
noxu_db::Get::First
};
match self.cursor.get(&mut key_entry, &mut data_entry, get_type, None) {
Ok(OperationStatus::Success) => {
match key_entry.data_opt() {
Some(key_bytes) => Some(K::from_bytes(key_bytes)),
None => {
self.done = true;
None
}
}
}
Ok(_) => {
self.done = true;
None
}
Err(e) => {
self.done = true;
Some(Err(e.into()))
}
}
}
}
/// Closes the underlying cursor when the iterator goes out of scope.
impl<K> Drop for KeyIterator<'_, K> {
fn drop(&mut self) {
// Whatever `close()` returns is discarded; `drop` has no way to report it.
let _ = self.cursor.close();
}
}
// Tests for `PrimaryIndex` and its iterators, run against a real database
// opened in a temporary directory.
#[cfg(test)]
mod tests {
use super::*;
use crate::entity::Entity;
use crate::entity_serializer::EntitySerializer;
use noxu_db::{DatabaseConfig, Environment, EnvironmentConfig};
use tempfile::TempDir;
// Minimal entity used by every test; `id` is the primary key.
#[derive(Clone, Debug, PartialEq)]
struct User {
id: u64,
name: String,
email: String,
}
impl Entity for User {
type PrimaryKey = u64;
fn primary_key(&self) -> &u64 {
&self.id
}
fn entity_name() -> &'static str {
"User"
}
}
// Hand-rolled binary serializer for `User`.
//
// Layout: 8-byte big-endian id, u32 big-endian name length, name bytes,
// u32 big-endian email length, email bytes.
struct UserSerializer;
impl EntitySerializer<User> for UserSerializer {
fn serialize(&self, entity: &User) -> Result<Vec<u8>> {
let mut buf = Vec::new();
buf.extend_from_slice(&entity.id.to_be_bytes());
let name_bytes = entity.name.as_bytes();
buf.extend_from_slice(&(name_bytes.len() as u32).to_be_bytes());
buf.extend_from_slice(name_bytes);
let email_bytes = entity.email.as_bytes();
buf.extend_from_slice(&(email_bytes.len() as u32).to_be_bytes());
buf.extend_from_slice(email_bytes);
Ok(buf)
}
fn deserialize(&self, bytes: &[u8]) -> Result<User> {
// 12 bytes = 8-byte id + 4-byte name length, the minimum header.
if bytes.len() < 12 {
return Err(PersistError::SerializationError(
"not enough bytes for User".to_string(),
));
}
let id = u64::from_be_bytes([
bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5],
bytes[6], bytes[7],
]);
let name_len =
u32::from_be_bytes([bytes[8], bytes[9], bytes[10], bytes[11]])
as usize;
let name_start = 12;
let name_end = name_start + name_len;
// The name must be followed by the 4-byte email length.
if bytes.len() < name_end + 4 {
return Err(PersistError::SerializationError(
"not enough bytes for User name/email".to_string(),
));
}
let name = String::from_utf8(bytes[name_start..name_end].to_vec())
.map_err(|e| {
PersistError::SerializationError(format!("bad name: {}", e))
})?;
let email_len = u32::from_be_bytes([
bytes[name_end],
bytes[name_end + 1],
bytes[name_end + 2],
bytes[name_end + 3],
]) as usize;
let email_start = name_end + 4;
let email_end = email_start + email_len;
if bytes.len() < email_end {
return Err(PersistError::SerializationError(
"not enough bytes for User email".to_string(),
));
}
let email =
String::from_utf8(bytes[email_start..email_end].to_vec())
.map_err(|e| {
PersistError::SerializationError(format!(
"bad email: {}",
e
))
})?;
Ok(User { id, name, email })
}
}
// Opens an environment in a fresh temporary directory and a transactional
// database named "users" inside it. The `TempDir` and `Environment` are
// returned with the database handle; tests bind them to locals so they stay
// alive until the test ends.
fn setup() -> (TempDir, Environment, Arc<Mutex<Database>>) {
let temp_dir = TempDir::new().unwrap();
let env_config = EnvironmentConfig::new(temp_dir.path().to_path_buf())
.with_allow_create(true);
let env = Environment::open(env_config).unwrap();
let db_config = DatabaseConfig::new()
.with_allow_create(true)
.with_transactional(true);
let db = env.open_database(None, "users", &db_config).unwrap();
(temp_dir, env, Arc::new(Mutex::new(db)))
}
// Builds a user whose name and email are derived from `id`.
fn test_user(id: u64) -> User {
User {
id,
name: format!("User{}", id),
email: format!("user{}@example.com", id),
}
}
// A stored entity can be read back unchanged.
#[test]
fn test_put_and_get() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
let user = test_user(1);
index.put(None, &ser, &user).unwrap();
let found = index.get(None, &ser, &1u64).unwrap();
assert_eq!(found, Some(user));
}
// Looking up a key that was never stored yields `None`.
#[test]
fn test_get_not_found() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
let found = index.get(None, &ser, &999u64).unwrap();
assert_eq!(found, None);
}
// A second `put` under the same key replaces the first record.
#[test]
fn test_put_overwrites() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
let user1 = test_user(1);
index.put(None, &ser, &user1).unwrap();
let user1_updated = User {
id: 1,
name: "Updated".to_string(),
email: "updated@example.com".to_string(),
};
index.put(None, &ser, &user1_updated).unwrap();
let found = index.get(None, &ser, &1u64).unwrap().unwrap();
assert_eq!(found.name, "Updated");
assert_eq!(found.email, "updated@example.com");
}
// `put_no_overwrite` reports `true` when the key is new.
#[test]
fn test_put_no_overwrite_success() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
let user = test_user(1);
let inserted = index.put_no_overwrite(None, &ser, &user).unwrap();
assert!(inserted);
}
// `put_no_overwrite` reports `false` and keeps the existing record when the
// key is already present.
#[test]
fn test_put_no_overwrite_fails_on_existing() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
let user = test_user(1);
index.put(None, &ser, &user).unwrap();
let user2 = User {
id: 1,
name: "Other".to_string(),
email: "other@example.com".to_string(),
};
let inserted = index.put_no_overwrite(None, &ser, &user2).unwrap();
assert!(!inserted);
let found = index.get(None, &ser, &1u64).unwrap().unwrap();
assert_eq!(found.name, "User1");
}
// Deleting an existing key returns `true` and removes the record.
#[test]
fn test_delete() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
let user = test_user(1);
index.put(None, &ser, &user).unwrap();
let deleted = index.delete(None, &1u64).unwrap();
assert!(deleted);
let found = index.get(None, &ser, &1u64).unwrap();
assert_eq!(found, None);
}
// Deleting a missing key returns `false`.
#[test]
fn test_delete_not_found() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let deleted = index.delete(None, &999u64).unwrap();
assert!(!deleted);
}
// `delete_with_entity` removes an existing record and returns `true`.
#[test]
fn test_delete_with_entity() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
let user = test_user(1);
index.put(None, &ser, &user).unwrap();
let deleted = index.delete_with_entity(None, &ser, &1u64).unwrap();
assert!(deleted);
assert_eq!(index.get(None, &ser, &1u64).unwrap(), None);
}
// `delete_with_entity` returns `false` for a missing key.
#[test]
fn test_delete_with_entity_not_found() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
let deleted = index.delete_with_entity(None, &ser, &999u64).unwrap();
assert!(!deleted);
}
// `contains` is `false` before a put and `true` after it.
#[test]
fn test_contains() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
assert!(!index.contains(None, &1u64).unwrap());
let user = test_user(1);
index.put(None, &ser, &user).unwrap();
assert!(index.contains(None, &1u64).unwrap());
}
// A freshly created database reports a count of zero.
#[test]
fn test_count_empty() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
assert_eq!(index.count().unwrap(), 0);
}
// The count matches the number of distinct keys stored.
#[test]
fn test_count_with_entities() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
for i in 1..=5 {
index.put(None, &ser, &test_user(i)).unwrap();
}
assert_eq!(index.count().unwrap(), 5);
}
// The entity iterator yields every stored record, here in ascending id order.
#[test]
fn test_entities_iterator() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
for i in 1..=3 {
index.put(None, &ser, &test_user(i)).unwrap();
}
let entities: Vec<User> = index
.entities(None, &ser)
.unwrap()
.collect::<std::result::Result<Vec<_>, _>>()
.unwrap();
assert_eq!(entities.len(), 3);
assert_eq!(entities[0].id, 1);
assert_eq!(entities[1].id, 2);
assert_eq!(entities[2].id, 3);
}
// Iterating an empty database yields nothing.
#[test]
fn test_entities_iterator_empty() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
let entities: Vec<User> = index
.entities(None, &ser)
.unwrap()
.collect::<std::result::Result<Vec<_>, _>>()
.unwrap();
assert!(entities.is_empty());
}
// A key can be stored again after it has been deleted.
#[test]
fn test_multiple_put_delete_cycles() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
let user = test_user(1);
index.put(None, &ser, &user).unwrap();
index.delete(None, &1u64).unwrap();
assert_eq!(index.count().unwrap(), 0);
let user2 = User {
id: 1,
name: "Reinserted".to_string(),
email: "new@example.com".to_string(),
};
index.put(None, &ser, &user2).unwrap();
let found = index.get(None, &ser, &1u64).unwrap().unwrap();
assert_eq!(found.name, "Reinserted");
}
// `database()` hands back the handle the index was built on.
#[test]
fn test_database_reference() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
assert_eq!(index.database().lock().name(), "users");
}
// On an empty database the entity iterator returns `None` on the first call
// and again on a repeated call.
#[test]
fn test_entity_iterator_done_returns_none() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
let mut iter = index.entities(None, &ser).unwrap();
assert!(iter.next().is_none()); assert!(iter.next().is_none()); }
// After its single record, the entity iterator keeps returning `None`.
#[test]
fn test_entity_iterator_exhausted() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
index.put(None, &ser, &test_user(1)).unwrap();
let mut iter = index.entities(None, &ser).unwrap();
assert!(iter.next().is_some());
assert!(iter.next().is_none());
assert!(iter.next().is_none());
}
// With one record stored, the second `next` call on the key iterator is
// `None`. The first item itself is not inspected.
#[test]
fn test_key_iterator_first_call_done() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
index.put(None, &ser, &test_user(1)).unwrap();
let mut iter = index.keys(None).unwrap();
let first = iter.next();
assert!(iter.next().is_none());
let _ = first;
}
// The key iterator yields nothing for an empty database.
#[test]
fn test_key_iterator_empty_db() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let mut iter = index.keys(None).unwrap();
assert!(iter.next().is_none());
}
// The key iterator yields one item per stored record and then stays at `None`.
#[test]
fn test_key_iterator_done_branch() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
for i in 1u64..=3 {
index.put(None, &ser, &test_user(i)).unwrap();
}
let mut iter = index.keys(None).unwrap();
let k1 = iter.next();
let k2 = iter.next();
let k3 = iter.next();
assert!(k1.is_some(), "expected first key");
assert!(k2.is_some(), "expected second key");
assert!(k3.is_some(), "expected third key");
assert!(iter.next().is_none());
assert!(iter.next().is_none());
}
// `delete_with_entity` removes the record so a later `get` finds nothing.
#[test]
fn test_delete_with_entity_removes_record() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
index.put(None, &ser, &test_user(10)).unwrap();
let deleted = index.delete_with_entity(None, &ser, &10u64).unwrap();
assert!(deleted);
assert_eq!(index.get(None, &ser, &10u64).unwrap(), None);
}
// `delete_with_entity` on a key that was never stored returns `false`.
#[test]
fn test_delete_with_entity_missing_key_returns_false() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
let deleted = index.delete_with_entity(None, &ser, &999u64).unwrap();
assert!(!deleted);
}
// The first no-overwrite put inserts; repeating it is reported as skipped.
#[test]
fn test_put_no_overwrite_insert_then_skip() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
let user = test_user(5);
assert!(index.put_no_overwrite(None, &ser, &user).unwrap());
assert!(!index.put_no_overwrite(None, &ser, &user).unwrap());
}
// `contains` flips back to `false` once the record is deleted.
#[test]
fn test_contains_after_delete() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
index.put(None, &ser, &test_user(7)).unwrap();
assert!(index.contains(None, &7u64).unwrap());
index.delete(None, &7u64).unwrap();
assert!(!index.contains(None, &7u64).unwrap());
}
// Deleting one of four records drops the count to three.
#[test]
fn test_count_after_delete() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
for i in 1u64..=4 {
index.put(None, &ser, &test_user(i)).unwrap();
}
assert_eq!(index.count().unwrap(), 4);
index.delete(None, &2u64).unwrap();
assert_eq!(index.count().unwrap(), 3);
}
// Ten stored records come back from the iterator in ascending id order.
#[test]
fn test_entities_iterator_many() {
let (_td, _env, db) = setup();
let index: PrimaryIndex<u64, User> = PrimaryIndex::new(Arc::clone(&db));
let ser = UserSerializer;
for i in 1u64..=10 {
index.put(None, &ser, &test_user(i)).unwrap();
}
let entities: Vec<User> = index
.entities(None, &ser)
.unwrap()
.collect::<std::result::Result<Vec<_>, _>>()
.unwrap();
assert_eq!(entities.len(), 10);
for (i, user) in entities.iter().enumerate() {
assert_eq!(user.id, (i + 1) as u64);
}
}
}