use bincode::config::Configuration;
use bincode::serde::decode_from_slice;
use serde::de::DeserializeOwned;
use crate::errors::DatabaseError;
use crate::traits::{DatabaseEntry, Index, Storage};
use crate::transaction::DatabaseTransaction;
use crate::wrap::{decode_value, empty_wrap, wrap, Subtable, Wrap, WrapPrelude};
use crate::{DeriveKey, Incrementable, Manifest, Manifests, RecordKey};
use std::ops::Range;
/// Item type yielded by the key iterators on [`Database`]: the key type of
/// record `R` on success, or a [`DatabaseError`] parameterised by the error
/// type of storage backend `S` on failure.
type DatabaseIteratorItem<R, S> =
Result<<R as DatabaseEntry>::Key, DatabaseError<<S as Storage>::StoreError>>;
/// A typed database layered on top of a raw [`Storage`] backend `S`, with a
/// manifest `M` that is loaded when the database is constructed.
pub struct Database<S: Storage, M: Manifest> {
// Primary storage backend; every read and write goes through it first.
pub(crate) store: S,
// Optional secondary backend sharing the primary's error type. When set,
// `commit` and `remove` apply their changes to it as well, and `get`
// consults it when the primary store has no entry for a key.
fallback: Option<Box<dyn Storage<StoreError = S::StoreError>>>,
// Manifest instance; populated by `Manifest::load` in `Database::new`.
pub(crate) manifest: M,
// bincode configuration used when encoding keys and decoding values.
serialization_config: Configuration,
}
/// Builds a database over `S::default()`.
///
/// Delegates to [`Database::new`], so it panics under the same conditions
/// (the manifest failing to load).
impl<S: Default + Storage, M: Manifest> Default for Database<S, M> {
fn default() -> Self {
Self::new(S::default())
}
}
impl<S: Storage, M: Manifest> Database<S, M> {
pub fn new(store: S) -> Self {
let mut db = Database {
store,
fallback: None,
manifest: M::default(),
serialization_config: Configuration::default(),
};
let mut manifest = M::default();
manifest.load(&mut db).unwrap();
db.manifest = manifest;
db
}
/// Replaces the bincode configuration used by this database for subsequent
/// key encoding and value decoding.
///
/// Despite the `with_` prefix, this mutates the database in place and
/// returns nothing.
pub fn with_serialization_config(&mut self, config: Configuration) {
self.serialization_config = config;
}
/// Installs `fallback` as the secondary storage backend, replacing any
/// previously installed one.
///
/// Existing contents of the primary store are not copied into it.
pub fn set_fallback(&mut self, fallback: Box<dyn Storage<StoreError = S::StoreError>>) {
self.fallback = Some(fallback);
}
/// Stores `record` through a single-use transaction and commits it
/// immediately.
///
/// Returns the key reported by `DatabaseTransaction::put`.
///
/// # Errors
///
/// Propagates any error from the transaction's `put` or from
/// [`Database::commit`].
pub fn put<R: DatabaseEntry>(
    &mut self,
    record: R,
) -> Result<R::Key, DatabaseError<<S as Storage>::StoreError>>
where
    R::Key: RecordKey<Record = R> + Incrementable + Ord,
    M: Manifests<R>,
{
    let mut txn = DatabaseTransaction::new(self);
    let key = txn.put(record, self)?;
    self.commit(txn).map(|()| key)
}
/// Inserts `record` through a single-use transaction and commits it
/// immediately.
///
/// Returns the key reported by `DatabaseTransaction::insert`.
///
/// # Errors
///
/// Propagates any error from the transaction's `insert` or from
/// [`Database::commit`].
pub fn insert<K: RecordKey<Record = R>, R>(
    &mut self,
    record: R,
) -> Result<K, DatabaseError<<S as Storage>::StoreError>>
where
    R: DeriveKey<Key = K> + DatabaseEntry<Key = K>,
    M: Manifests<R>,
{
    let mut txn = DatabaseTransaction::new(self);
    let key = txn.insert::<K, R>(record)?;
    self.commit(txn).map(|()| key)
}
/// Creates a new [`DatabaseTransaction`] for this database.
///
/// Pass the transaction to [`Database::commit`] to apply its changes.
pub fn create_transaction(&self) -> DatabaseTransaction<M> {
DatabaseTransaction::new(self)
}
/// Applies all writes, then all deletes, recorded in `transaction`.
///
/// When a fallback backend is installed, each operation is applied to the
/// fallback first and then to the primary store.
///
/// # Errors
///
/// Returns [`DatabaseError::Io`] on the first storage failure. Operations
/// applied before the failure are not rolled back.
pub fn commit(
    &mut self,
    transaction: DatabaseTransaction<M>,
) -> Result<(), DatabaseError<S::StoreError>> {
    let (pending_writes, pending_deletes) = transaction.consume();

    for (entry_key, entry_value) in pending_writes {
        if let Some(mirror) = self.fallback.as_mut() {
            mirror
                .insert(entry_key.clone(), entry_value.clone())
                .map_err(DatabaseError::Io)?;
        }
        self.store
            .insert(entry_key, entry_value)
            .map_err(DatabaseError::Io)?;
    }

    for entry_key in pending_deletes {
        if let Some(mirror) = self.fallback.as_mut() {
            mirror.remove(entry_key.clone()).map_err(DatabaseError::Io)?;
        }
        self.store.remove(entry_key).map_err(DatabaseError::Io)?;
    }

    Ok(())
}
/// Looks up the record stored under `key`.
///
/// The primary store is queried first; only on a miss is the fallback
/// backend (if one is installed) consulted. Returns `Ok(None)` when neither
/// holds an entry.
///
/// # Errors
///
/// * [`DatabaseError::Serialization`] if the key cannot be encoded.
/// * [`DatabaseError::Io`] if a storage lookup fails.
/// * [`DatabaseError::Deserialization`] if the stored bytes cannot be decoded.
pub fn get<K: RecordKey>(
    &self,
    key: &K,
) -> Result<Option<K::Record>, DatabaseError<S::StoreError>>
where
    K::Record: DatabaseEntry<Key = K>,
    M: Manifests<K::Record>,
{
    let primary_key = wrap::<K::Record>(key, self.serialization_config)
        .map_err(DatabaseError::Serialization)?;

    let raw = match self.store.get(primary_key).map_err(DatabaseError::Io)? {
        Some(bytes) => bytes,
        None => {
            let Some(fallback) = &self.fallback else {
                return Ok(None);
            };
            // The first encoded key was moved into the primary lookup, so
            // encode it again for the fallback.
            let fallback_key = wrap::<K::Record>(key, self.serialization_config)
                .map_err(DatabaseError::Serialization)?;
            match fallback.get(fallback_key).map_err(DatabaseError::Io)? {
                Some(bytes) => bytes,
                None => return Ok(None),
            }
        }
    };

    decode_value(&raw, self.serialization_config)
        .map(Some)
        .map_err(DatabaseError::Deserialization)
}
/// Removes the entry stored under `key` and returns the decoded record, if
/// there was one.
///
/// When a fallback backend is installed the key is removed from the fallback
/// first and then from the primary store. The returned record is decoded
/// from the fallback's value when it had one, otherwise from the primary
/// store's value, so a record that only existed in the primary store is
/// still reported rather than silently dropped.
///
/// # Errors
///
/// * [`DatabaseError::Serialization`] if the key cannot be encoded.
/// * [`DatabaseError::Io`] if a storage removal fails. A failure on the
///   fallback leaves the primary store untouched.
/// * [`DatabaseError::Deserialization`] if the removed bytes cannot be
///   decoded; the entry has already been removed at that point.
pub fn remove<K: RecordKey>(
    &mut self,
    key: &K,
) -> Result<Option<K::Record>, DatabaseError<S::StoreError>>
where
    K::Record: DatabaseEntry<Key = K>,
    M: Manifests<K::Record>,
{
    let key = wrap::<K::Record>(key, self.serialization_config)
        .map_err(DatabaseError::Serialization)?;

    let removed = if let Some(fallback) = &mut self.fallback {
        let fallback_value = fallback.remove(key.clone()).map_err(DatabaseError::Io)?;
        let store_value = self.store.remove(key).map_err(DatabaseError::Io)?;
        // Keep the historical precedence (fallback first), but do not lose
        // the primary store's value when the fallback never had the key.
        fallback_value.or(store_value)
    } else {
        self.store.remove(key).map_err(DatabaseError::Io)?
    };

    match removed {
        Some(ref bytes) => decode_value(bytes, self.serialization_config)
            .map(Some)
            .map_err(DatabaseError::Deserialization),
        None => Ok(None),
    }
}
/// Iterates over the keys of `K::Record` in the primary store that fall in
/// the half-open `range`.
///
/// Only the primary store is scanned; the fallback backend is not consulted.
///
/// # Errors
///
/// The outer `Result` is [`DatabaseError::Serialization`] if a range bound
/// cannot be encoded, or [`DatabaseError::Io`] if the storage scan cannot be
/// started. Each yielded item is [`DatabaseError::Io`] if the storage
/// iterator fails, or [`DatabaseError::Deserialization`] if a raw key cannot
/// be decoded.
pub fn iter_keys<K: RecordKey + Ord>(
    &self,
    range: Range<K>,
) -> Result<
    impl Iterator<Item = DatabaseIteratorItem<K::Record, S>> + use<'_, K, S, M>,
    DatabaseError<S::StoreError>,
>
where
    K::Record: DatabaseEntry<Key = K>,
    M: Manifests<K::Record>,
{
    let config = self.serialization_config;
    let lower =
        wrap::<K::Record>(&range.start, config).map_err(DatabaseError::Serialization)?;
    let upper = wrap::<K::Record>(&range.end, config).map_err(DatabaseError::Serialization)?;

    let raw_keys = self
        .store
        .iter_keys(lower..upper)
        .map_err(DatabaseError::Io)?;

    Ok(
        raw_keys.map(move |elem: Result<Vec<u8>, <S as Storage>::StoreError>| {
            elem.map_err(DatabaseError::Io).and_then(|bytes| {
                decode_from_slice(&bytes, config)
                    .map(|(wrapped, _): (Wrap<K>, usize)| wrapped.key)
                    .map_err(DatabaseError::Deserialization)
            })
        }),
    )
}
/// Iterates over every key of `K::Record` in the primary store, using the
/// bounds produced by `empty_wrap` for that record type.
///
/// Only the primary store is scanned; the fallback backend is not consulted.
///
/// # Errors
///
/// The outer `Result` is [`DatabaseError::Serialization`] if the bounds
/// cannot be encoded, or [`DatabaseError::Io`] if the storage scan cannot be
/// started. Each yielded item is [`DatabaseError::Io`] if the storage
/// iterator fails, or [`DatabaseError::Deserialization`] if a raw key cannot
/// be decoded.
pub fn iter_all_keys<K: RecordKey + Ord>(
    &self,
) -> Result<
    impl Iterator<Item = DatabaseIteratorItem<K::Record, S>> + use<'_, K, S, M>,
    DatabaseError<S::StoreError>,
>
where
    K::Record: DatabaseEntry<Key = K>,
    M: Manifests<K::Record>,
{
    let config = self.serialization_config;
    let (lower, upper) =
        empty_wrap::<K::Record>(config).map_err(DatabaseError::Serialization)?;

    let raw_keys = self
        .store
        .iter_keys(lower..upper)
        .map_err(DatabaseError::Io)?;

    Ok(
        raw_keys.map(move |elem: Result<Vec<u8>, <S as Storage>::StoreError>| {
            elem.map_err(DatabaseError::Io).and_then(|bytes| {
                decode_from_slice(&bytes, config)
                    .map(|(wrapped, _): (Wrap<K>, usize)| wrapped.key)
                    .map_err(DatabaseError::Deserialization)
            })
        }),
    )
}
/// Returns the first key yielded by [`Database::iter_all_keys`] for
/// `K::Record`, or `K::default()` when there are no keys.
///
/// NOTE(review): despite the name, this takes the *first* item of the key
/// iterator. Whether that is the most recent / highest id depends on the
/// iteration order of the storage backend — confirm against the `Storage`
/// implementations.
///
/// # Errors
///
/// Propagates any error from starting the scan or from the first yielded
/// item.
pub fn last_id<K: RecordKey + Ord + Default>(&self) -> Result<K, DatabaseError<S::StoreError>>
where
    K::Record: DatabaseEntry<Key = K>,
    M: Manifests<K::Record>,
{
    let mut keys = self.iter_all_keys::<K>()?;
    match keys.next() {
        Some(first_key) => first_key,
        None => Ok(K::default()),
    }
}
/// Iterates over the entries of index `I` whose encoded index key falls in
/// the half-open `range`, yielding for each one the value stored under the
/// index entry, decoded as the key type of `I::Record`.
///
/// Takes `&self`: the scan only reads from the primary store, so exclusive
/// access is not required. Existing callers holding a `&mut Database` are
/// unaffected.
///
/// Only the primary store is scanned; the fallback backend is not consulted.
///
/// # Errors
///
/// The outer `Result` is [`DatabaseError::Io`] if the storage scan cannot be
/// started. Each yielded item carries whatever error resolving that index
/// entry produced (storage failure, missing entry, or decode failure).
pub fn iter_by_index<I: Index + Ord>(
    &self,
    range: Range<I>,
) -> Result<
    impl Iterator<Item = DatabaseIteratorItem<I::Record, S>> + use<'_, I, S, M>,
    DatabaseError<S::StoreError>,
> {
    let config = self.serialization_config;
    let index_prelude = WrapPrelude::new::<I::Record>(Subtable::Index(I::INDEX));

    // Both bounds share the index prelude and differ only in the key suffix.
    let mut start = index_prelude.to_bytes(config);
    let mut end = start.clone();
    start.extend(range.start.to_bytes(config));
    end.extend(range.end.to_bytes(config));

    let raw_iter = self
        .store
        .iter_keys(start..end)
        .map_err(DatabaseError::Io)?;
    Ok(raw_iter.map(move |elem| self.process_iter_result(elem)))
}
/// Iterates over the entries of index `I` whose encoded key starts with the
/// encoding of `index_key`, yielding for each one the value stored under the
/// index entry, decoded as the key type of `I::Record`.
///
/// The scan covers the half-open byte range from `prelude ++ key` up to the
/// byte-wise successor of that whole sequence. Incrementing the combined
/// sequence (rather than the key bytes alone) lets the carry move into the
/// prelude when the encoded key consists solely of `0xFF` bytes; previously
/// that case produced an end bound that sorted before the start bound.
///
/// Takes `&self`: the scan only reads from the primary store, so exclusive
/// access is not required. Existing callers holding a `&mut Database` are
/// unaffected.
///
/// Only the primary store is scanned; the fallback backend is not consulted.
///
/// # Errors
///
/// The outer `Result` is [`DatabaseError::Io`] if the storage scan cannot be
/// started. Each yielded item carries whatever error resolving that index
/// entry produced (storage failure, missing entry, or decode failure).
pub fn iter_by_index_exact<I: Index + Ord>(
    &self,
    index_key: &I,
) -> Result<
    impl Iterator<Item = DatabaseIteratorItem<I::Record, S>> + use<'_, I, S, M>,
    DatabaseError<S::StoreError>,
> {
    let config = self.serialization_config;
    let index_prelude = WrapPrelude::new::<I::Record>(Subtable::Index(I::INDEX));

    let mut start = index_prelude.to_bytes(config);
    start.extend(index_key.to_bytes(config));

    // Exclusive upper bound: successor of the full prefixed key. For keys
    // containing any byte below 0xFF this equals `prelude ++ next(key)`.
    let mut end = start.clone();
    bytes_next(&mut end);

    let raw_iter = self
        .store
        .iter_keys(start..end)
        .map_err(DatabaseError::Io)?;
    Ok(raw_iter.map(move |elem| self.process_iter_result(elem)))
}
/// Consumes the database and returns its primary store.
///
/// The fallback backend (if any), the manifest and the serialization
/// configuration are dropped.
pub fn dissolve(self) -> S {
self.store
}
/// Returns a copy of the bincode configuration currently used by this
/// database.
pub fn serialization_config(&self) -> Configuration {
self.serialization_config
}
/// Resolves one item of a raw index scan: takes the key yielded by the
/// storage iterator, fetches the value stored under it in the primary store
/// and decodes that value as `T`.
///
/// # Errors
///
/// * [`DatabaseError::Io`] if the iterator item is an error or the lookup
///   fails.
/// * [`DatabaseError::Internal`] with `MissingIndexEntry` if the store has
///   no value under the yielded key.
/// * [`DatabaseError::Deserialization`] if the value cannot be decoded.
fn process_iter_result<T: DeserializeOwned>(
    &self,
    result: Result<Vec<u8>, S::StoreError>,
) -> Result<T, DatabaseError<S::StoreError>> {
    let entry_key = result.map_err(DatabaseError::Io)?;

    let Some(payload) = self.store.get(entry_key).map_err(DatabaseError::Io)? else {
        return Err(DatabaseError::Internal(
            crate::InternalDatabaseError::MissingIndexEntry,
        ));
    };

    let (decoded, _) = decode_from_slice(&payload, self.serialization_config)
        .map_err(DatabaseError::Deserialization)?;
    Ok(decoded)
}
}
/// Increments `bytes` in place as if it were a big-endian counter.
///
/// The right-most byte that is not `0xFF` is incremented and every byte
/// after it is reset to zero. If every byte is `0xFF` (or the vector is
/// empty) all bytes are reset to zero and one extra zero byte is appended.
///
/// Note that in the all-`0xFF` case the result sorts lexicographically
/// *before* the input, so it is not usable as an upper bound there.
fn bytes_next(bytes: &mut Vec<u8>) {
    match bytes.iter().rposition(|&byte| byte != u8::MAX) {
        Some(pivot) => {
            bytes[pivot] += 1;
            bytes[pivot + 1..].fill(0);
        }
        None => {
            bytes.fill(0);
            bytes.push(0);
        }
    }
}