use async_trait::async_trait;
use bamboo_rs_core_ed25519_yasmf::entry::is_lipmaa_required;
use crate::entry::LogId;
use crate::entry::SeqNum;
use crate::hash::Hash;
use crate::identity::Author;
use crate::schema::SchemaId;
use crate::storage_provider::errors::EntryStorageError;
use crate::storage_provider::traits::AsStorageEntry;
#[async_trait]
pub trait EntryStore<StorageEntry: AsStorageEntry> {
async fn insert_entry(&self, value: StorageEntry) -> Result<(), EntryStorageError>;
async fn get_entry_at_seq_num(
&self,
author: &Author,
log_id: &LogId,
seq_num: &SeqNum,
) -> Result<Option<StorageEntry>, EntryStorageError>;
async fn get_entry_by_hash(
&self,
hash: &Hash,
) -> Result<Option<StorageEntry>, EntryStorageError>;
async fn try_get_backlink(
&self,
entry: &StorageEntry,
) -> Result<Option<StorageEntry>, EntryStorageError> {
if entry.seq_num().is_first() {
return Ok(None);
};
let backlink_seq_num = SeqNum::new(entry.seq_num().as_u64() - 1).unwrap();
let expected_backlink = self
.get_entry_at_seq_num(&entry.author(), &entry.log_id(), &backlink_seq_num)
.await?
.ok_or_else(|| EntryStorageError::ExpectedBacklinkMissing(entry.hash()))?;
if expected_backlink.hash() != entry.backlink_hash().unwrap() {
return Err(EntryStorageError::InvalidBacklinkPassed(entry.hash()));
}
Ok(Some(expected_backlink))
}
async fn try_get_skiplink(
&self,
entry: &StorageEntry,
) -> Result<Option<StorageEntry>, EntryStorageError> {
if !is_lipmaa_required(entry.seq_num().as_u64()) && entry.skiplink_hash().is_none() {
return Ok(None);
};
let expected_skiplink = match entry.seq_num().skiplink_seq_num() {
Some(seq_num) => {
let expected_skiplink_entry = self
.get_entry_at_seq_num(&entry.author(), &entry.log_id(), &seq_num)
.await?
.ok_or_else(|| EntryStorageError::ExpectedSkiplinkMissing(entry.hash()))?;
Some(expected_skiplink_entry)
}
None => None,
};
if expected_skiplink.clone().map(|entry| entry.hash()) != entry.skiplink_hash() {
return Err(EntryStorageError::InvalidSkiplinkPassed(entry.hash()));
}
Ok(expected_skiplink)
}
async fn get_latest_entry(
&self,
author: &Author,
log_id: &LogId,
) -> Result<Option<StorageEntry>, EntryStorageError>;
async fn get_entries_by_schema(
&self,
schema: &SchemaId,
) -> Result<Vec<StorageEntry>, EntryStorageError>;
async fn get_paginated_log_entries(
&self,
author: &Author,
log_id: &LogId,
seq_num: &SeqNum,
max_number_of_entries: usize,
) -> Result<Vec<StorageEntry>, EntryStorageError>;
async fn determine_next_skiplink(
&self,
entry: &StorageEntry,
) -> Result<Option<Hash>, EntryStorageError> {
let next_seq_num = entry.seq_num().clone().next().unwrap();
let skiplink_seq_num = next_seq_num.skiplink_seq_num().unwrap();
let entry_skiplink_hash = if is_lipmaa_required(next_seq_num.as_u64()) {
let skiplink_entry = match self
.get_entry_at_seq_num(&entry.author(), &entry.log_id(), &skiplink_seq_num)
.await?
{
Some(entry) => Ok(entry),
None => Err(EntryStorageError::ExpectedNextSkiplinkMissing),
}?;
Ok(Some(skiplink_entry.hash()))
} else {
Ok(None)
};
entry_skiplink_hash
}
async fn get_certificate_pool(
&self,
author_id: &Author,
log_id: &LogId,
seq_num: &SeqNum,
) -> Result<Vec<StorageEntry>, EntryStorageError>;
}
#[cfg(test)]
pub mod tests {
use std::convert::TryFrom;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use lipmaa_link::get_lipmaa_links_back_to;
use rstest::rstest;
use crate::entry::{sign_and_encode, Entry, EntrySigned, LogId, SeqNum};
use crate::hash::Hash;
use crate::identity::{Author, KeyPair};
use crate::operation::{AsOperation, Operation, OperationEncoded};
use crate::schema::SchemaId;
use crate::storage_provider::errors::EntryStorageError;
use crate::storage_provider::traits::test_utils::{
test_db, SimplestStorageProvider, StorageEntry, TestStore, SKIPLINK_ENTRIES,
};
use crate::storage_provider::traits::{AsStorageEntry, EntryStore};
use crate::test_utils::fixtures::{
entry, entry_signed_encoded, key_pair, operation_encoded, random_key_pair, schema,
};
#[async_trait]
impl EntryStore<StorageEntry> for SimplestStorageProvider {
async fn insert_entry(&self, entry: StorageEntry) -> Result<(), EntryStorageError> {
self.db_insert_entry(entry);
Ok(())
}
async fn get_entry_by_hash(
&self,
hash: &Hash,
) -> Result<Option<StorageEntry>, EntryStorageError> {
let entries = self.entries.lock().unwrap();
let entry = entries.iter().find(|entry| entry.hash() == *hash);
Ok(entry.cloned())
}
async fn get_entry_at_seq_num(
&self,
author: &Author,
log_id: &LogId,
seq_num: &SeqNum,
) -> Result<Option<StorageEntry>, EntryStorageError> {
let entries = self.entries.lock().unwrap();
let entry = entries.iter().find(|entry| {
entry.author() == *author
&& entry.log_id() == *log_id
&& entry.seq_num() == *seq_num
});
Ok(entry.cloned())
}
async fn get_latest_entry(
&self,
author: &Author,
log_id: &LogId,
) -> Result<Option<StorageEntry>, EntryStorageError> {
let entries = self.entries.lock().unwrap();
let latest_entry = entries
.iter()
.filter(|entry| entry.author() == *author && entry.log_id() == *log_id)
.max_by_key(|entry| entry.seq_num().as_u64());
Ok(latest_entry.cloned())
}
async fn get_paginated_log_entries(
&self,
author: &Author,
log_id: &LogId,
seq_num: &SeqNum,
max_number_of_entries: usize,
) -> Result<Vec<StorageEntry>, EntryStorageError> {
let mut entries: Vec<StorageEntry> = Vec::new();
let mut seq_num = *seq_num;
while entries.len() < max_number_of_entries {
match self.get_entry_at_seq_num(author, log_id, &seq_num).await? {
Some(next_entry) => entries.push(next_entry),
None => break,
};
match seq_num.next() {
Some(next_seq_num) => seq_num = next_seq_num,
None => break,
};
}
Ok(entries)
}
async fn get_entries_by_schema(
&self,
schema: &SchemaId,
) -> Result<Vec<StorageEntry>, EntryStorageError> {
let entries = self.entries.lock().unwrap();
let entries: Vec<StorageEntry> = entries
.iter()
.filter(|entry| entry.operation().schema() == *schema)
.map(|e| e.to_owned())
.collect();
Ok(entries)
}
async fn get_certificate_pool(
&self,
author: &Author,
log_id: &LogId,
initial_seq_num: &SeqNum,
) -> Result<Vec<StorageEntry>, EntryStorageError> {
let seq_num = initial_seq_num.as_u64();
let cert_pool_seq_nums: Vec<SeqNum> = get_lipmaa_links_back_to(seq_num, 1)
.iter()
.map(|seq_num| SeqNum::new(*seq_num).unwrap())
.collect();
let mut cert_pool: Vec<StorageEntry> = Vec::new();
for seq_num in cert_pool_seq_nums {
let entry = match self.get_entry_at_seq_num(author, log_id, &seq_num).await? {
Some(entry) => Ok(entry),
None => Err(EntryStorageError::CertPoolEntryMissing(seq_num.as_u64())),
}?;
cert_pool.push(entry);
}
Ok(cert_pool)
}
}
// An inserted entry can be read back via its author, log id and sequence number.
#[rstest]
#[async_std::test]
async fn insert_get_entry(
    entry_signed_encoded: EntrySigned,
    operation_encoded: OperationEncoded,
) {
    // Start from a completely empty in-memory store.
    let store = SimplestStorageProvider {
        logs: Arc::new(Mutex::new(Vec::new())),
        entries: Arc::new(Mutex::new(Vec::new())),
        operations: Arc::new(Mutex::new(Vec::new())),
    };
    let inserted = StorageEntry::new(&entry_signed_encoded, &operation_encoded).unwrap();

    let insertion = store.insert_entry(inserted.clone()).await;
    assert!(insertion.is_ok());

    let lookup = store
        .get_entry_at_seq_num(&inserted.author(), &inserted.log_id(), &inserted.seq_num())
        .await;
    assert!(lookup.is_ok());
    assert_eq!(lookup.unwrap().unwrap(), inserted)
}
// The latest entry of a log is `None` while the log is empty and the inserted entry afterwards.
#[rstest]
#[async_std::test]
async fn get_latest_entry(
    entry_signed_encoded: EntrySigned,
    operation_encoded: OperationEncoded,
) {
    let store = SimplestStorageProvider {
        logs: Arc::new(Mutex::new(Vec::new())),
        entries: Arc::new(Mutex::new(Vec::new())),
        operations: Arc::new(Mutex::new(Vec::new())),
    };
    let storage_entry = StorageEntry::new(&entry_signed_encoded, &operation_encoded).unwrap();

    // Nothing has been stored yet.
    let before_insert = store
        .get_latest_entry(&storage_entry.author(), &LogId::default())
        .await
        .unwrap();
    assert!(before_insert.is_none());

    let insertion = store.insert_entry(storage_entry.clone()).await;
    assert!(insertion.is_ok());

    // Now the freshly inserted entry is the latest one.
    let after_insert = store
        .get_latest_entry(&storage_entry.author(), &LogId::default())
        .await
        .unwrap();
    assert_eq!(after_insert.unwrap(), storage_entry);
}
// Entries of two different authors sharing one schema are both found by a schema query.
#[rstest]
#[async_std::test]
async fn get_by_schema(
    #[from(random_key_pair)] key_pair_1: KeyPair,
    #[from(random_key_pair)] key_pair_2: KeyPair,
    entry: Entry,
    operation_encoded: OperationEncoded,
    schema: SchemaId,
) {
    let store = SimplestStorageProvider {
        logs: Arc::new(Mutex::new(Vec::new())),
        entries: Arc::new(Mutex::new(Vec::new())),
        operations: Arc::new(Mutex::new(Vec::new())),
    };

    // Sign the same entry with two different key pairs.
    let signed_by_author_1 = sign_and_encode(&entry, &key_pair_1).unwrap();
    let signed_by_author_2 = sign_and_encode(&entry, &key_pair_2).unwrap();
    let stored_for_author_1 = StorageEntry::new(&signed_by_author_1, &operation_encoded).unwrap();
    let stored_for_author_2 = StorageEntry::new(&signed_by_author_2, &operation_encoded).unwrap();

    let found_before = store.get_entries_by_schema(&schema).await.unwrap();
    assert!(found_before.is_empty());

    store.insert_entry(stored_for_author_1).await.unwrap();
    store.insert_entry(stored_for_author_2).await.unwrap();

    let found_after = store.get_entries_by_schema(&schema).await.unwrap();
    assert_eq!(found_after.len(), 2);
}
// Each of the first three stored entries can be retrieved by its own hash.
#[rstest]
#[async_std::test]
async fn get_entry_by_hash(
    #[from(test_db)]
    #[with(3, 1)]
    #[future]
    db: TestStore,
) {
    let db = db.await;
    // Snapshot the entries; the lock is released again at the end of this statement.
    let entries = db.store.entries.lock().unwrap().clone();

    // Slicing panics when fewer than three entries exist, so no check is silently skipped.
    for stored in &entries[..3] {
        let looked_up = db.store.get_entry_by_hash(&stored.hash()).await.unwrap();
        assert_eq!(Some(stored.clone()), looked_up);
    }
}
// For every sequence number the resolved backlink is the directly preceding entry; the first
// entry has none.
#[rstest]
#[async_std::test]
async fn try_get_backlink(
    #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)] seq_num: usize,
    #[from(test_db)]
    #[with(17, 1)]
    #[future]
    db: TestStore,
) {
    let db = db.await;
    let entries = db.store.entries.lock().unwrap().clone();

    // Entries are stored in order, so the entry at `seq_num` sits at index `seq_num - 1`.
    let expected_backlink = match seq_num {
        1 => None,
        _ => entries.get(seq_num - 2).cloned(),
    };

    let actual_backlink = db
        .store
        .try_get_backlink(&entries[seq_num - 1])
        .await
        .unwrap();

    assert_eq!(expected_backlink, actual_backlink);
}
// Resolving the backlink fails with a descriptive error when the preceding entry is gone.
#[rstest]
#[async_std::test]
async fn try_get_backlink_entry_missing(
    #[from(test_db)]
    #[with(17, 1)]
    #[future]
    db: TestStore,
) {
    let db = db.await;
    let entries = db.store.entries.lock().unwrap().clone();
    let second_entry = entries.get(1).unwrap();

    // Drop the first entry from the store; the lock guard only lives for this statement.
    db.store.entries.lock().unwrap().remove(0);

    let error = db.store.try_get_backlink(second_entry).await.unwrap_err();
    let expected_message = format!(
        "Could not find expected backlink in database for entry with id: {}",
        second_entry.hash()
    );

    assert_eq!(error.to_string(), expected_message);
}
// An entry whose backlink hash does not match the stored preceding entry is rejected.
//
// NOTE(review): despite its name this test exercises an invalid *backlink*.
#[rstest]
#[async_std::test]
async fn try_get_backlink_invalid_skiplink(
    key_pair: KeyPair,
    operation_encoded: OperationEncoded,
    #[from(test_db)]
    #[with(4, 1)]
    #[future]
    db: TestStore,
) {
    let db = db.await;
    let entries = db.store.entries.lock().unwrap().clone();
    let fourth_entry = entries.get(3).unwrap();

    // Rebuild the fourth entry, keeping its skiplink but swapping in a bogus backlink hash.
    let operation = Operation::from(&operation_encoded);
    let skiplink_hash = fourth_entry.skiplink_hash();
    let bogus_backlink_hash = Hash::new_from_bytes(vec![1, 2, 3]).unwrap();
    let tampered_entry = Entry::new(
        &fourth_entry.log_id(),
        Some(&operation),
        skiplink_hash.as_ref(),
        Some(&bogus_backlink_hash),
        &fourth_entry.seq_num(),
    )
    .unwrap();
    let tampered_signed = sign_and_encode(&tampered_entry, &key_pair).unwrap();
    let tampered_stored = StorageEntry::new(&tampered_signed, &operation_encoded).unwrap();

    let error = db
        .store
        .try_get_backlink(&tampered_stored)
        .await
        .unwrap_err();
    let expected_message = format!(
        "The backlink hash encoded in the entry: {} did not match the expected backlink hash",
        tampered_stored.hash()
    );

    assert_eq!(error.to_string(), expected_message);
}
// Each case pairs a sequence number with the sequence number of the skiplink expected to be
// resolved for it (`None` where no skiplink is returned).
#[rstest(
    case(1, None),
    case(2, None),
    case(3, None),
    case(4, Some(1)),
    case(5, None),
    case(6, None),
    case(7, None),
    case(8, Some(4)),
    case(9, None),
    case(10, None),
    case(11, None),
    case(12, Some(8)),
    case(13, Some(4)),
    case(14, None),
    case(15, None),
    case(16, None)
)]
#[async_std::test]
async fn try_get_skiplink(
    #[case] seq_num: usize,
    #[case] expected_skiplink_seq_num: Option<usize>,
    #[from(test_db)]
    #[with(17, 1)]
    #[future]
    db: TestStore,
) {
    let db = db.await;
    let entries = db.store.entries.lock().unwrap().clone();

    // Entries are stored in order, so sequence number `n` lives at index `n - 1`.
    let expected_skiplink = expected_skiplink_seq_num.map(|target| entries[target - 1].clone());

    let actual_skiplink = db
        .store
        .try_get_skiplink(&entries[seq_num - 1])
        .await
        .unwrap();

    assert_eq!(expected_skiplink, actual_skiplink);
}
// Resolving the skiplink fails with a descriptive error when the skiplink target is gone.
#[rstest]
#[async_std::test]
async fn try_get_skiplink_entry_missing(
    #[from(test_db)]
    #[with(4, 1)]
    #[future]
    db: TestStore,
) {
    let db = db.await;
    let entries = db.store.entries.lock().unwrap().clone();
    let fourth_entry = entries.get(3).unwrap();

    // Remove the first entry, which is the skiplink target of the fourth one. The lock guard
    // only lives for this statement.
    db.store.entries.lock().unwrap().remove(0);

    let error = db.store.try_get_skiplink(fourth_entry).await.unwrap_err();
    let expected_message = format!(
        "Could not find expected skiplink in database for entry with id: {}",
        fourth_entry.hash()
    );

    assert_eq!(error.to_string(), expected_message);
}
// An entry whose skiplink hash does not match the stored skiplink target is rejected.
#[rstest]
#[async_std::test]
async fn try_get_skiplink_invalid_skiplink(
    key_pair: KeyPair,
    operation_encoded: OperationEncoded,
    #[from(test_db)]
    #[with(4, 1)]
    #[future]
    db: TestStore,
) {
    let db = db.await;
    let entries = db.store.entries.lock().unwrap().clone();
    let fourth_entry = entries.get(3).unwrap();

    // Rebuild the fourth entry, keeping its backlink but swapping in a bogus skiplink hash.
    let operation = Operation::from(&operation_encoded);
    let bogus_skiplink_hash = Hash::new_from_bytes(vec![1, 2, 3]).unwrap();
    let backlink_hash = fourth_entry.backlink_hash();
    let tampered_entry = Entry::new(
        &fourth_entry.log_id(),
        Some(&operation),
        Some(&bogus_skiplink_hash),
        backlink_hash.as_ref(),
        &fourth_entry.seq_num(),
    )
    .unwrap();
    let tampered_signed = sign_and_encode(&tampered_entry, &key_pair).unwrap();
    let tampered_stored = StorageEntry::new(&tampered_signed, &operation_encoded).unwrap();

    let error = db
        .store
        .try_get_skiplink(&tampered_stored)
        .await
        .unwrap_err();
    let expected_message = format!(
        "The skiplink hash encoded in the entry: {} did not match the known hash of the skiplink target",
        tampered_stored.hash()
    );

    assert_eq!(error.to_string(), expected_message);
}
// For the first nine entries, a next skiplink is reported exactly when the following sequence
// number is listed in `SKIPLINK_ENTRIES`.
#[rstest]
#[async_std::test]
async fn can_determine_next_skiplink(
    #[from(test_db)]
    #[with(17, 1)]
    #[future]
    db: TestStore,
) {
    let db = db.await;
    let entries = db.store.entries.lock().unwrap().clone();

    for (index, current_entry) in entries[..9].iter().enumerate() {
        // `index` is zero-based: the current sequence number is `index + 1`, the next one
        // `index + 2`.
        let next_seq_num = (index + 2) as u64;

        let next_entry_skiplink = db.store.determine_next_skiplink(current_entry).await;
        assert!(next_entry_skiplink.is_ok());

        let skiplink_hash = next_entry_skiplink.unwrap();
        assert_eq!(
            skiplink_hash.is_some(),
            SKIPLINK_ENTRIES.contains(&next_seq_num)
        );
    }
}
#[rstest]
#[async_std::test]
async fn skiplink_does_not_exist(
#[from(test_db)]
#[with(17, 1)]
#[future]
db: TestStore,
) {
let db = db.await;
let entries = db.store.entries.lock().unwrap().clone();
let logs = db.store.logs.lock().unwrap().clone();
let log_entries_with_skiplink_missing = vec![
entries.get(0).unwrap().clone(),
entries.get(1).unwrap().clone(),
entries.get(2).unwrap().clone(),
entries.get(4).unwrap().clone(),
entries.get(5).unwrap().clone(),
];
let new_db = SimplestStorageProvider {
logs: Arc::new(Mutex::new(logs)),
entries: Arc::new(Mutex::new(log_entries_with_skiplink_missing)),
operations: Arc::new(Mutex::new(Vec::new())),
};
let error_response = new_db
.determine_next_skiplink(entries.get(6).unwrap())
.await;
assert_eq!(
format!("{}", error_response.unwrap_err()),
"Could not find expected skiplink entry in database"
)
}
// Pagination honours the requested maximum, stops at the end of the log and yields nothing
// when the starting entry does not exist.
#[rstest]
#[async_std::test]
async fn get_n_entries(
    key_pair: KeyPair,
    #[from(test_db)]
    #[with(16, 1)]
    #[future]
    db: TestStore,
) {
    let db = db.await;
    let author = Author::try_from(*key_pair.public_key()).unwrap();
    let log_id = LogId::default();
    let store = &db.store;

    // A page smaller than the log is filled completely.
    let log_start = SeqNum::new(1).unwrap();
    let five_entries = store
        .get_paginated_log_entries(&author, &log_id, &log_start, 5)
        .await
        .unwrap();
    assert_eq!(five_entries.len(), 5);

    // A page larger than the log ends with the last entry.
    let log_start = SeqNum::new(1).unwrap();
    let end_of_log_reached = store
        .get_paginated_log_entries(&author, &log_id, &log_start, 1000)
        .await
        .unwrap();
    assert_eq!(end_of_log_reached.len(), 16);

    // Starting beyond the end of the log yields an empty page.
    let beyond_log_end = SeqNum::new(10000).unwrap();
    let first_entry_not_found = store
        .get_paginated_log_entries(&author, &log_id, &beyond_log_end, 1)
        .await
        .unwrap();
    assert!(first_entry_not_found.is_empty());
}
// The certificate pool for sequence number 16 consists of the entries 15, 14, 13, 4 and 1.
#[rstest]
#[async_std::test]
async fn get_cert_pool(
    key_pair: KeyPair,
    #[from(test_db)]
    #[with(17, 1)]
    #[future]
    db: TestStore,
) {
    let db = db.await;
    let author = Author::try_from(*key_pair.public_key()).unwrap();
    let log_id = LogId::default();
    let start = SeqNum::new(16).unwrap();

    let cert_pool = db
        .store
        .get_certificate_pool(&author, &log_id, &start)
        .await
        .unwrap();

    // Reduce the pool to plain sequence numbers for an easy comparison.
    let mut seq_nums: Vec<u64> = Vec::new();
    for pool_entry in &cert_pool {
        seq_nums.push(pool_entry.seq_num().as_u64());
    }

    assert_eq!(seq_nums, vec![15, 14, 13, 4, 1]);
}
}