use crate::{batch::Batch, views::ViewError};
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use std::{
fmt::{Debug, Display},
future::Future,
ops::{
Bound,
Bound::{Excluded, Included, Unbounded},
},
time::{Duration, Instant},
};
#[cfg(test)]
#[path = "unit_tests/common_tests.rs"]
mod common_tests;
/// Type-level output length of the `sha3::Sha3_256` hasher.
#[doc(hidden)]
pub type HasherOutputSize = <sha3::Sha3_256 as sha3::digest::OutputSizeUser>::OutputSize;
/// A byte array sized to hold exactly one `sha3::Sha3_256` digest.
#[doc(hidden)]
pub type HasherOutput = generic_array::GenericArray<u8, HasherOutputSize>;
/// A staged change for a single entry: `Removed` marks the entry as deleted, while
/// `Set(value)` carries its replacement value.
///
/// NOTE(review): semantics inferred from the variant names; the consumers of this type
/// live outside this file — confirm.
#[derive(Debug)]
pub(crate) enum Update<T> {
    /// The entry is to be removed.
    Removed,
    /// The entry is to be set to the contained value.
    Set(T),
}
/// The smallest tag accepted by the tag-based key helpers of `Context`
/// (`base_tag`, `base_tag_index`, `derive_tag_key`), which all assert `tag >= MIN_VIEW_TAG`.
///
/// NOTE(review): presumably tag values below this one are reserved — confirm.
pub const MIN_VIEW_TAG: u8 = 1;
/// Computes an exclusive upper bound for the set of keys that start with `key_prefix`.
///
/// Every key beginning with `key_prefix` compares strictly less than the returned vector.
/// Returns `None` when no finite bound exists, i.e. when `key_prefix` is empty or consists
/// only of `0xFF` bytes.
pub(crate) fn get_upper_bound_option(key_prefix: &[u8]) -> Option<Vec<u8>> {
    // Trailing 0xFF bytes cannot be incremented, so locate the last byte that can be.
    let pivot = key_prefix.iter().rposition(|&byte| byte != u8::MAX)?;
    // Drop everything after the pivot and bump the pivot byte itself.
    let mut bound = key_prefix[..=pivot].to_vec();
    bound[pivot] += 1;
    Some(bound)
}
/// Returns the upper end of the key range covered by `key_prefix`, as a [`Bound`].
///
/// The result is `Excluded(bound)` when [`get_upper_bound_option`] finds a finite bound,
/// and `Unbounded` otherwise.
pub(crate) fn get_upper_bound(key_prefix: &[u8]) -> Bound<Vec<u8>> {
    get_upper_bound_option(key_prefix).map_or(Unbounded, Excluded)
}
/// Returns the `(start, end)` bounds of the range containing exactly the keys that start
/// with `key_prefix`: the prefix itself is included as the start, and the end comes from
/// [`get_upper_bound`].
pub fn get_interval(key_prefix: Vec<u8>) -> (Bound<Vec<u8>>, Bound<Vec<u8>>) {
    // Compute the end while the prefix is still borrowed, then move it into the start.
    let end = get_upper_bound(&key_prefix);
    let start = Included(key_prefix);
    (start, end)
}
/// Decodes an optional BCS-encoded value.
///
/// `None` maps to `Ok(None)`. `Some(bytes)` is deserialized with `bcs::from_bytes`, and
/// any decoding failure is converted into `E`.
pub(crate) fn from_bytes_opt<V: DeserializeOwned, E>(
    key_opt: Option<Vec<u8>>,
) -> Result<Option<V>, E>
where
    E: From<bcs::Error>,
{
    key_opt
        .map(|bytes| bcs::from_bytes(&bytes))
        .transpose()
        .map_err(E::from)
}
/// A collection of keys that can be iterated by reference.
///
/// Items are `Result`s, so an implementation may report an `Error` while iterating.
pub trait KeyIterable<Error> {
    /// The borrowing iterator; the yielded key slices live as long as the borrow of `self`.
    type Iterator<'a>: Iterator<Item = Result<&'a [u8], Error>>
    where
        Self: 'a;
    /// Returns an iterator over the keys.
    fn iterator(&self) -> Self::Iterator<'_>;
}
/// A collection of key/value pairs that can be iterated either by reference or by value.
///
/// Items are `Result`s, so an implementation may report an `Error` while iterating.
pub trait KeyValueIterable<Error> {
    /// The borrowing iterator; the yielded slices live as long as the borrow of `self`.
    type Iterator<'a>: Iterator<Item = Result<(&'a [u8], &'a [u8]), Error>>
    where
        Self: 'a;
    /// The consuming iterator, yielding owned `(key, value)` pairs.
    type IteratorOwned: Iterator<Item = Result<(Vec<u8>, Vec<u8>), Error>>;
    /// Returns an iterator over borrowed `(key, value)` pairs.
    fn iterator(&self) -> Self::Iterator<'_>;
    /// Consumes the collection and returns an iterator over owned `(key, value)` pairs.
    fn into_iterator_owned(self) -> Self::IteratorOwned;
}
#[async_trait]
pub trait KeyValueStoreClient {
const MAX_VALUE_SIZE: usize;
type Error: Debug;
type Keys: KeyIterable<Self::Error>;
type KeyValues: KeyValueIterable<Self::Error>;
fn max_stream_queries(&self) -> usize;
async fn read_key_bytes(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error>;
async fn read_multi_key_bytes(
&self,
keys: Vec<Vec<u8>>,
) -> Result<Vec<Option<Vec<u8>>>, Self::Error>;
async fn find_keys_by_prefix(&self, key_prefix: &[u8]) -> Result<Self::Keys, Self::Error>;
async fn find_key_values_by_prefix(
&self,
key_prefix: &[u8],
) -> Result<Self::KeyValues, Self::Error>;
async fn write_batch(&self, batch: Batch, base_key: &[u8]) -> Result<(), Self::Error>;
async fn clear_journal(&self, base_key: &[u8]) -> Result<(), Self::Error>;
async fn read_key<V: DeserializeOwned>(&self, key: &[u8]) -> Result<Option<V>, Self::Error>
where
Self::Error: From<bcs::Error>,
{
from_bytes_opt(self.read_key_bytes(key).await?)
}
async fn read_multi_key<V: DeserializeOwned + Send>(
&self,
keys: Vec<Vec<u8>>,
) -> Result<Vec<Option<V>>, Self::Error>
where
Self::Error: From<bcs::Error>,
{
let mut values = Vec::with_capacity(keys.len());
for entry in self.read_multi_key_bytes(keys).await? {
values.push(from_bytes_opt(entry)?);
}
Ok(values)
}
}
/// Borrowing iterator over a `Vec<Vec<u8>>` of keys that never yields an error.
#[doc(hidden)]
pub struct SimpleKeyIterator<'a, E> {
    iter: std::slice::Iter<'a, Vec<u8>>,
    _error_type: std::marker::PhantomData<E>,
}

impl<'a, E> Iterator for SimpleKeyIterator<'a, E> {
    type Item = Result<&'a [u8], E>;

    fn next(&mut self) -> Option<Self::Item> {
        let key = self.iter.next()?;
        Some(Ok(key.as_slice()))
    }
}
impl<E> KeyIterable<E> for Vec<Vec<u8>> {
    type Iterator<'a> = SimpleKeyIterator<'a, E>;

    /// Iterates over the stored keys in vector order; no item is ever an error.
    fn iterator(&self) -> Self::Iterator<'_> {
        let iter = self.as_slice().iter();
        SimpleKeyIterator {
            iter,
            _error_type: Default::default(),
        }
    }
}
/// Borrowing iterator over a `Vec<(Vec<u8>, Vec<u8>)>` of key/value pairs that never yields
/// an error.
#[doc(hidden)]
pub struct SimpleKeyValueIterator<'a, E> {
    iter: std::slice::Iter<'a, (Vec<u8>, Vec<u8>)>,
    _error_type: std::marker::PhantomData<E>,
}

impl<'a, E> Iterator for SimpleKeyValueIterator<'a, E> {
    type Item = Result<(&'a [u8], &'a [u8]), E>;

    fn next(&mut self) -> Option<Self::Item> {
        let (key, value) = self.iter.next()?;
        Some(Ok((key.as_slice(), value.as_slice())))
    }
}
/// Consuming iterator over owned key/value pairs that never yields an error.
#[doc(hidden)]
pub struct SimpleKeyValueIteratorOwned<E> {
    iter: std::vec::IntoIter<(Vec<u8>, Vec<u8>)>,
    _error_type: std::marker::PhantomData<E>,
}

impl<E> Iterator for SimpleKeyValueIteratorOwned<E> {
    type Item = Result<(Vec<u8>, Vec<u8>), E>;

    fn next(&mut self) -> Option<Self::Item> {
        let pair = self.iter.next()?;
        Some(Ok(pair))
    }
}
impl<E> KeyValueIterable<E> for Vec<(Vec<u8>, Vec<u8>)> {
    type Iterator<'a> = SimpleKeyValueIterator<'a, E>;
    type IteratorOwned = SimpleKeyValueIteratorOwned<E>;

    /// Iterates over borrowed pairs in vector order; no item is ever an error.
    fn iterator(&self) -> Self::Iterator<'_> {
        let iter = self.as_slice().iter();
        SimpleKeyValueIterator {
            iter,
            _error_type: Default::default(),
        }
    }

    /// Consumes the vector, yielding owned pairs in order; no item is ever an error.
    fn into_iterator_owned(self) -> Self::IteratorOwned {
        let iter = self.into_iter();
        SimpleKeyValueIteratorOwned {
            iter,
            _error_type: Default::default(),
        }
    }
}
#[async_trait]
pub trait Context {
const MAX_VALUE_SIZE: usize;
type Extra: Clone + Send + Sync;
type Error: std::error::Error + Debug + Send + Sync + 'static + From<bcs::Error>;
type Keys: KeyIterable<Self::Error>;
type KeyValues: KeyValueIterable<Self::Error>;
fn max_stream_queries(&self) -> usize;
async fn read_key_bytes(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error>;
async fn read_multi_key_bytes(
&self,
keys: Vec<Vec<u8>>,
) -> Result<Vec<Option<Vec<u8>>>, Self::Error>;
async fn find_keys_by_prefix(&self, key_prefix: &[u8]) -> Result<Self::Keys, Self::Error>;
async fn find_key_values_by_prefix(
&self,
key_prefix: &[u8],
) -> Result<Self::KeyValues, Self::Error>;
async fn write_batch(&self, batch: Batch) -> Result<(), Self::Error>;
fn extra(&self) -> &Self::Extra;
fn clone_with_base_key(&self, base_key: Vec<u8>) -> Self;
fn base_key(&self) -> Vec<u8>;
fn base_tag(&self, tag: u8) -> Vec<u8> {
assert!(tag >= MIN_VIEW_TAG, "tag should be at least MIN_VIEW_TAG");
let mut key = self.base_key();
key.extend([tag]);
key
}
fn base_tag_index(&self, tag: u8, index: &[u8]) -> Vec<u8> {
assert!(tag >= MIN_VIEW_TAG, "tag should be at least MIN_VIEW_TAG");
let mut key = self.base_key();
key.extend([tag]);
key.extend_from_slice(index);
key
}
fn derive_key<I: Serialize>(&self, index: &I) -> Result<Vec<u8>, Self::Error> {
let mut key = self.base_key();
bcs::serialize_into(&mut key, index)?;
assert!(
key.len() > self.base_key().len(),
"Empty indices are not allowed"
);
Ok(key)
}
fn derive_tag_key<I: Serialize>(&self, tag: u8, index: &I) -> Result<Vec<u8>, Self::Error> {
assert!(tag >= MIN_VIEW_TAG, "tag should be at least MIN_VIEW_TAG");
let mut key = self.base_key();
key.extend([tag]);
bcs::serialize_into(&mut key, index)?;
Ok(key)
}
fn derive_short_key<I: Serialize + ?Sized>(index: &I) -> Result<Vec<u8>, Self::Error> {
Ok(bcs::to_bytes(index)?)
}
fn deserialize_value<Item: DeserializeOwned>(bytes: &[u8]) -> Result<Item, Self::Error> {
let value = bcs::from_bytes(bytes)?;
Ok(value)
}
async fn read_key<Item>(&self, key: &[u8]) -> Result<Option<Item>, Self::Error>
where
Item: DeserializeOwned,
{
from_bytes_opt(self.read_key_bytes(key).await?)
}
async fn read_multi_key<V: DeserializeOwned + Send>(
&self,
keys: Vec<Vec<u8>>,
) -> Result<Vec<Option<V>>, Self::Error>
where
Self::Error: From<bcs::Error>,
{
let mut values = Vec::with_capacity(keys.len());
for entry in self.read_multi_key_bytes(keys).await? {
values.push(from_bytes_opt(entry)?);
}
Ok(values)
}
}
/// A `Context` built from a `KeyValueStoreClient` (`db`), a base key and user data.
#[derive(Debug, Default, Clone)]
pub struct ContextFromDb<E, DB> {
    /// The underlying store. The `Context` impl forwards every read and write to it.
    pub db: DB,
    /// Returned by `Context::base_key` and passed to `db` on every `write_batch`.
    pub base_key: Vec<u8>,
    /// User data returned by `Context::extra`.
    pub extra: E,
}
impl<E, DB> ContextFromDb<E, DB>
where
    E: Clone + Send + Sync,
    DB: KeyValueStoreClient + Clone + Send + Sync,
    DB::Error: From<bcs::Error> + Send + Sync + std::error::Error + 'static,
    ViewError: From<DB::Error>,
{
    /// Builds a context over `db` rooted at `base_key`, carrying `extra`.
    ///
    /// The store's journal for `base_key` is cleared before the context is constructed.
    ///
    /// # Errors
    /// Propagates any error from `clear_journal`. In that case no context is returned.
    pub async fn create(
        db: DB,
        base_key: Vec<u8>,
        extra: E,
    ) -> Result<Self, <ContextFromDb<E, DB> as Context>::Error> {
        db.clear_journal(&base_key).await?;
        let context = Self {
            db,
            base_key,
            extra,
        };
        Ok(context)
    }
}
/// Awaits `f` and returns its output along with the wall-clock time the await took.
async fn time_async<F, O>(f: F) -> (O, Duration)
where
    F: Future<Output = O>,
{
    let started_at = Instant::now();
    let output = f.await;
    (output, started_at.elapsed())
}
/// Awaits `f` and returns its output.
///
/// When the crate is built with the `db_timings` feature, it also prints
/// `|name|=<elapsed nanoseconds>` to stdout.
async fn log_time_async<F, D, O>(f: F, name: D) -> O
where
    F: Future<Output = O>,
    D: Display,
{
    if !cfg!(feature = "db_timings") {
        return f.await;
    }
    let (output, elapsed) = time_async(f).await;
    let nanos = elapsed.as_nanos();
    println!("|{name}|={nanos:?}");
    output
}
/// Forwards every `Context` operation to the wrapped `KeyValueStoreClient`.
///
/// Each async store call goes through `log_time_async`, so it is timed when the
/// `db_timings` feature is enabled.
#[async_trait]
impl<E, DB> Context for ContextFromDb<E, DB>
where
    E: Clone + Send + Sync,
    DB: KeyValueStoreClient + Clone + Send + Sync,
    DB::Error: From<bcs::Error> + Send + Sync + std::error::Error + 'static,
    ViewError: From<DB::Error>,
{
    // Limits and associated types are inherited directly from the store.
    const MAX_VALUE_SIZE: usize = DB::MAX_VALUE_SIZE;
    type Extra = E;
    type Error = DB::Error;
    type Keys = DB::Keys;
    type KeyValues = DB::KeyValues;

    fn max_stream_queries(&self) -> usize {
        self.db.max_stream_queries()
    }

    fn extra(&self) -> &E {
        &self.extra
    }

    fn base_key(&self) -> Vec<u8> {
        self.base_key.to_vec()
    }

    async fn read_key_bytes(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
        let query = self.db.read_key_bytes(key);
        log_time_async(query, "read_key_bytes").await
    }

    async fn read_multi_key_bytes(
        &self,
        keys: Vec<Vec<u8>>,
    ) -> Result<Vec<Option<Vec<u8>>>, Self::Error> {
        let query = self.db.read_multi_key_bytes(keys);
        log_time_async(query, "read_multi_key_bytes").await
    }

    async fn find_keys_by_prefix(&self, key_prefix: &[u8]) -> Result<Self::Keys, Self::Error> {
        let query = self.db.find_keys_by_prefix(key_prefix);
        log_time_async(query, "find_keys_by_prefix").await
    }

    async fn find_key_values_by_prefix(
        &self,
        key_prefix: &[u8],
    ) -> Result<Self::KeyValues, Self::Error> {
        let query = self.db.find_key_values_by_prefix(key_prefix);
        log_time_async(query, "find_key_values_by_prefix").await
    }

    async fn write_batch(&self, batch: Batch) -> Result<(), Self::Error> {
        // The store receives this context's base key alongside the batch.
        let write = self.db.write_batch(batch, &self.base_key);
        log_time_async(write, "write_batch").await
    }

    fn clone_with_base_key(&self, base_key: Vec<u8>) -> Self {
        let db = self.db.clone();
        let extra = self.extra.clone();
        Self {
            db,
            base_key,
            extra,
        }
    }
}
/// Conversion of a value to and from a custom byte encoding.
///
/// NOTE(review): the `u128` implementation in this file is order-preserving, meaning
/// byte-wise comparison of encodings matches numeric comparison. `test_ordering_serialization`
/// checks this. Whether every implementor must preserve order is not stated here — confirm.
pub trait CustomSerialize: Sized {
    /// Encodes `self` into bytes.
    fn to_custom_bytes(&self) -> Result<Vec<u8>, ViewError>;
    /// Decodes a value from bytes produced by `to_custom_bytes`.
    fn from_custom_bytes(short_key: &[u8]) -> Result<Self, ViewError>;
}
impl CustomSerialize for u128 {
    /// Encodes the value as 16 big-endian bytes, so that lexicographic comparison of the
    /// encodings agrees with numeric comparison.
    fn to_custom_bytes(&self) -> Result<Vec<u8>, ViewError> {
        // BCS writes `u128` as 16 little-endian bytes, so reversing its output (as this
        // encoding is defined) is exactly the big-endian representation.
        Ok(self.to_be_bytes().to_vec())
    }

    /// Decodes the big-endian encoding produced by `to_custom_bytes`.
    ///
    /// # Errors
    /// BCS rejects inputs that do not decode to exactly one `u128`.
    fn from_custom_bytes(bytes: &[u8]) -> Result<Self, ViewError> {
        // Restore BCS's little-endian byte order before decoding.
        let little_endian: Vec<u8> = bytes.iter().rev().copied().collect();
        Ok(bcs::from_bytes(&little_endian)?)
    }
}
#[cfg(test)]
mod tests {
    use linera_views::common::CustomSerialize;
    use rand::{Rng, SeedableRng};
    use std::collections::BTreeSet;

    /// Checks two properties of the `u128` custom encoding: it preserves numeric order
    /// under byte-wise comparison, and `from_custom_bytes` inverts `to_custom_bytes`.
    #[test]
    fn test_ordering_serialization() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(2);
        let sample_count = 1000;
        // Collecting through a `BTreeSet` both deduplicates and sorts the samples.
        let sorted: Vec<u128> = (0..sample_count)
            .map(|_| rng.gen::<u128>())
            .collect::<BTreeSet<_>>()
            .into_iter()
            .collect();
        for pair in sorted.windows(2) {
            let (smaller, larger) = (pair[0], pair[1]);
            assert!(smaller < larger);
            let smaller_bytes = smaller.to_custom_bytes().unwrap();
            let larger_bytes = larger.to_custom_bytes().unwrap();
            assert!(smaller_bytes < larger_bytes);
            let smaller_back = u128::from_custom_bytes(&smaller_bytes).unwrap();
            let larger_back = u128::from_custom_bytes(&larger_bytes).unwrap();
            assert_eq!(smaller, smaller_back);
            assert_eq!(larger, larger_back);
        }
    }
}