use crate::errors::StorageError;
use crate::storage::types::DbRecord;
use crate::storage::types::ValueState;
use crate::storage::types::ValueStateRetrievalFlag;
use crate::storage::Storable;
use log::{debug, error, info, trace, warn};
use std::collections::HashMap;
use std::sync::Arc;
/// Lock-protected interior of a [`Transaction`]: the buffered modifications and the
/// "transaction open" flag.
#[derive(Default)]
struct TransactionState {
/// Buffered records, keyed by each record's full binary id
/// (`DbRecord::get_full_binary_id`, as used by `set` / `batch_set`). A later write of
/// a record with the same id replaces the earlier one.
mods: HashMap<Vec<u8>, DbRecord>,
/// Set by `begin_transaction`; cleared by `commit_transaction` / `rollback_transaction`.
active: bool,
}
/// In-memory buffer of `DbRecord` writes.
///
/// `commit_transaction` hands the buffered records back sorted by
/// `DbRecord::transaction_priority` and empties the buffer; `rollback_transaction`
/// discards them. Note that `set` / `batch_set` / `get` do not themselves check the
/// `active` flag.
#[derive(Default)]
pub struct Transaction {
/// Buffered modifications plus the active flag.
state: Arc<tokio::sync::RwLock<TransactionState>>,
/// Reads served from the buffer; only incremented with the `runtime_metrics` feature,
/// reset to zero by `log_metrics`.
num_reads: Arc<tokio::sync::RwLock<u64>>,
/// Writes into the buffer; only incremented with the `runtime_metrics` feature,
/// reset to zero by `log_metrics`.
num_writes: Arc<tokio::sync::RwLock<u64>>,
}
// SAFETY: NOTE(review): every field is an `Arc<tokio::sync::RwLock<_>>`, which is already
// `Send + Sync` whenever the guarded data is `Send + Sync`. If `DbRecord` is `Send + Sync`
// these manual impls are redundant and could be dropped; if it is not, they are likely
// unsound. Confirm `DbRecord: Send + Sync` before relying on them.
unsafe impl Send for Transaction {}
unsafe impl Sync for Transaction {}
impl std::fmt::Debug for Transaction {
    /// Emits a fixed placeholder string; the buffered contents live behind async locks
    /// and are deliberately not inspected here.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("a lone transaction")
    }
}
impl Transaction {
    /// Creates an inactive transaction with an empty buffer and zeroed metrics counters.
    pub fn new() -> Self {
        Self {
            state: Arc::new(tokio::sync::RwLock::new(TransactionState {
                mods: HashMap::new(),
                active: false,
            })),
            num_reads: Arc::new(tokio::sync::RwLock::new(0)),
            num_writes: Arc::new(tokio::sync::RwLock::new(0)),
        }
    }

    /// Logs the accumulated write/read counters at `level`, then resets both to zero.
    ///
    /// The counters are only incremented when the `runtime_metrics` feature is enabled,
    /// so without it this always reports zeros.
    pub async fn log_metrics(&self, level: log::Level) {
        // Snapshot-and-reset both counters while holding both locks (reads first, then
        // writes) so the reported pair is consistent; the locks are released before logging.
        let (reads, writes) = {
            let mut reads = self.num_reads.write().await;
            let mut writes = self.num_writes.write().await;
            (std::mem::take(&mut *reads), std::mem::take(&mut *writes))
        };
        let msg = format!("Transaction writes: {}, Transaction reads: {}", writes, reads);
        match level {
            log::Level::Trace => trace!("{}", msg),
            log::Level::Debug => debug!("{}", msg),
            log::Level::Info => info!("{}", msg),
            log::Level::Warn => warn!("{}", msg),
            // The only remaining variant is `log::Level::Error`.
            _ => error!("{}", msg),
        }
    }

    /// Marks the transaction as active.
    ///
    /// Returns `true` if this call opened the transaction, or `false` if one was already
    /// active (the state is then left unchanged).
    pub async fn begin_transaction(&self) -> bool {
        debug!("BEGIN begin transaction");
        let mut guard = self.state.write().await;
        let started = !guard.active;
        guard.active = true;
        drop(guard);
        debug!("END begin transaction");
        started
    }

    /// Closes the active transaction and returns every buffered record, sorted in
    /// ascending order of `DbRecord::transaction_priority`.
    ///
    /// The buffer is emptied and the transaction is marked inactive.
    ///
    /// # Errors
    /// Returns `StorageError::Transaction` if no transaction is active.
    pub async fn commit_transaction(&self) -> Result<Vec<DbRecord>, StorageError> {
        debug!("BEGIN commit transaction");
        let mut guard = self.state.write().await;
        Self::ensure_active(&guard)?;
        // Move the records out of the map rather than cloning them; `drain` also leaves
        // the map empty (keeping its capacity for the next transaction).
        let mut records: Vec<DbRecord> = guard.mods.drain().map(|(_, record)| record).collect();
        guard.active = false;
        drop(guard);
        records.sort_by_key(|record| record.transaction_priority());
        debug!("END commit transaction");
        Ok(records)
    }

    /// Discards every buffered record and marks the transaction inactive.
    ///
    /// # Errors
    /// Returns `StorageError::Transaction` if no transaction is active.
    pub async fn rollback_transaction(&self) -> Result<(), StorageError> {
        debug!("BEGIN rollback transaction");
        let mut guard = self.state.write().await;
        Self::ensure_active(&guard)?;
        guard.mods.clear();
        guard.active = false;
        drop(guard);
        debug!("END rollback transaction");
        Ok(())
    }

    /// Reports whether a transaction is currently active.
    pub async fn is_transaction_active(&self) -> bool {
        debug!("BEGIN is transaction active");
        let out = self.state.read().await.active;
        debug!("END is transaction active");
        out
    }

    /// Returns a clone of the buffered record stored under `key`'s full binary id, if any.
    ///
    /// Does not check whether a transaction is active. With `runtime_metrics`, a hit
    /// increments the read counter.
    pub async fn get<St: Storable>(&self, key: &St::StorageKey) -> Option<DbRecord> {
        debug!("BEGIN transaction get {:?}", key);
        let bin_id = St::get_full_binary_key_id(key);
        // The read guard is a temporary released at the end of this statement, so the
        // metrics lock below is never awaited while the state lock is held.
        let out = self.state.read().await.mods.get(&bin_id).cloned();
        #[cfg(feature = "runtime_metrics")]
        if out.is_some() {
            *(self.num_reads.write().await) += 1;
        }
        debug!("END transaction get");
        out
    }

    /// Buffers every record in `records`, replacing any earlier record with the same
    /// full binary id.
    ///
    /// Does not check whether a transaction is active. With `runtime_metrics`, the whole
    /// batch counts as a single write.
    pub async fn batch_set(&self, records: &[DbRecord]) {
        debug!("BEGIN transaction set");
        {
            let mut guard = self.state.write().await;
            guard.mods.extend(
                records
                    .iter()
                    .map(|record| (record.get_full_binary_id(), record.clone())),
            );
        }
        #[cfg(feature = "runtime_metrics")]
        {
            *(self.num_writes.write().await) += 1;
        }
        debug!("END transaction set");
    }

    /// Buffers `record`, replacing any earlier record with the same full binary id.
    ///
    /// Does not check whether a transaction is active. With `runtime_metrics`, this
    /// increments the write counter.
    pub async fn set(&self, record: &DbRecord) {
        debug!("BEGIN transaction set");
        let bin_id = record.get_full_binary_id();
        {
            let mut guard = self.state.write().await;
            guard.mods.insert(bin_id, record.clone());
        }
        #[cfg(feature = "runtime_metrics")]
        {
            *(self.num_writes.write().await) += 1;
        }
        debug!("END transaction set");
    }

    /// Collects every buffered `ValueState` whose username appears in `usernames`,
    /// grouped by username, with each group sorted by ascending epoch.
    ///
    /// Usernames with no buffered value states are absent from the returned map;
    /// duplicate entries in `usernames` are harmless.
    pub async fn get_users_data(
        &self,
        usernames: &[crate::AkdLabel],
    ) -> HashMap<crate::AkdLabel, Vec<ValueState>> {
        debug!("BEGIN transaction user version scan");
        // Borrow the requested labels instead of cloning them; the set de-duplicates.
        let wanted: std::collections::HashSet<&crate::AkdLabel> = usernames.iter().collect();
        let mut results: HashMap<crate::AkdLabel, Vec<ValueState>> = HashMap::new();
        {
            let guard = self.state.read().await;
            for record in guard.mods.values() {
                let value_state = match record {
                    DbRecord::ValueState(value_state) => value_state,
                    _ => continue,
                };
                if !wanted.contains(&value_state.username) {
                    continue;
                }
                // Single lookup on the hit path; the key is cloned only on first insert.
                match results.get_mut(&value_state.username) {
                    Some(states) => states.push(value_state.clone()),
                    None => {
                        results.insert(value_state.username.clone(), vec![value_state.clone()]);
                    }
                }
            }
        }
        for states in results.values_mut() {
            states.sort_unstable_by_key(|state| state.epoch);
        }
        debug!("END transaction user version scan");
        results
    }

    /// Returns the buffered value state for `username` selected by `flag`, if any.
    ///
    /// With `runtime_metrics`, a hit increments the read counter.
    pub async fn get_user_state(
        &self,
        username: &crate::AkdLabel,
        flag: ValueStateRetrievalFlag,
    ) -> Option<ValueState> {
        // `from_ref` views the borrowed label as a one-element slice without cloning it.
        let intermediate = self
            .get_users_data(std::slice::from_ref(username))
            .await
            .remove(username)
            .unwrap_or_default();
        let out = Self::find_appropriate_item(intermediate, flag);
        #[cfg(feature = "runtime_metrics")]
        if out.is_some() {
            *(self.num_reads.write().await) += 1;
        }
        out
    }

    /// Returns, for each username in `usernames`, the buffered value state selected by
    /// `flag`; usernames with no match are omitted.
    ///
    /// With `runtime_metrics`, each call increments the read counter once, whether or
    /// not anything matched.
    pub async fn get_users_states(
        &self,
        usernames: &[crate::AkdLabel],
        flag: ValueStateRetrievalFlag,
    ) -> HashMap<crate::AkdLabel, ValueState> {
        let result_map: HashMap<crate::AkdLabel, ValueState> = self
            .get_users_data(usernames)
            .await
            .into_iter()
            .filter_map(|(username, states)| {
                Self::find_appropriate_item(states, flag).map(|state| (username, state))
            })
            .collect();
        #[cfg(feature = "runtime_metrics")]
        {
            *(self.num_reads.write().await) += 1;
        }
        result_map
    }

    /// Returns `Ok(())` if a transaction is active, otherwise the standard
    /// "not currently active" `StorageError::Transaction`.
    fn ensure_active(state: &TransactionState) -> Result<(), StorageError> {
        if state.active {
            Ok(())
        } else {
            Err(StorageError::Transaction(
                "Transaction not currently active".to_string(),
            ))
        }
    }

    /// Picks the entry of `intermediate` selected by `flag`.
    ///
    /// The epoch-relative flags (`LeqEpoch`, `MaxEpoch`, `MinEpoch`) assume `intermediate`
    /// is sorted by ascending epoch, as `get_users_data` returns it.
    fn find_appropriate_item(
        mut intermediate: Vec<ValueState>,
        flag: ValueStateRetrievalFlag,
    ) -> Option<ValueState> {
        match flag {
            ValueStateRetrievalFlag::SpecificVersion(version) => {
                intermediate.into_iter().find(|item| item.version == version)
            }
            ValueStateRetrievalFlag::SpecificEpoch(epoch) => {
                intermediate.into_iter().find(|item| item.epoch == epoch)
            }
            // Scanning newest-first, the first hit is the latest epoch not after `epoch`.
            ValueStateRetrievalFlag::LeqEpoch(epoch) => intermediate
                .into_iter()
                .rev()
                .find(|item| item.epoch <= epoch),
            ValueStateRetrievalFlag::MaxEpoch => intermediate.pop(),
            ValueStateRetrievalFlag::MinEpoch => intermediate.into_iter().next(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::append_only_zks::*;
    use crate::node_label::*;
    use crate::storage::types::*;
    use crate::tree_node::*;
    use rand::{rngs::OsRng, seq::SliceRandom};

    /// Wraps a childless, zero-hashed tree node at epoch 1 whose parent label is
    /// `NodeLabel::new(byte_arr_from_u64(0), 0)`.
    fn tree_node_record(label: NodeLabel, node_type: NodeType) -> DbRecord {
        DbRecord::TreeNode(TreeNodeWithPreviousValue::from_tree_node(TreeNode {
            label,
            last_epoch: 1,
            least_descendant_ep: 1,
            parent: NodeLabel::new(byte_arr_from_u64(0), 0),
            node_type,
            left_child: None,
            right_child: None,
            hash: [0u8; 32],
        }))
    }

    /// Whatever order records are buffered in, a commit must hand them back in
    /// non-decreasing `transaction_priority` order.
    #[tokio::test]
    async fn test_commit_order() -> Result<(), StorageError> {
        let azks_record = DbRecord::Azks(Azks {
            num_nodes: 0,
            latest_epoch: 0,
        });
        let root_record = tree_node_record(NodeLabel::new(byte_arr_from_u64(0), 0), NodeType::Root);
        let leaf_record = tree_node_record(NodeLabel::new(byte_arr_from_u64(1), 1), NodeType::Leaf);
        let first_value = DbRecord::ValueState(ValueState {
            username: AkdLabel::from_utf8_str("test"),
            epoch: 1,
            label: NodeLabel::new(byte_arr_from_u64(1), 1),
            version: 1,
            plaintext_val: AkdValue::from_utf8_str("abc123"),
        });
        let second_value = DbRecord::ValueState(ValueState {
            username: AkdLabel::from_utf8_str("test"),
            epoch: 2,
            label: NodeLabel::new(byte_arr_from_u64(1), 1),
            version: 2,
            plaintext_val: AkdValue::from_utf8_str("abc1234"),
        });
        let all_records = vec![azks_record, root_record, leaf_record, first_value, second_value];

        let mut rng = OsRng;
        // Nine independent shuffles, each in a fresh transaction.
        for _ in 0..9 {
            let txn = Transaction::new();
            txn.begin_transaction().await;

            let mut order = all_records.clone();
            order.shuffle(&mut rng);
            for record in &order {
                txn.set(record).await;
            }

            let mut highest_seen = 0;
            for committed in txn.commit_transaction().await? {
                let priority = committed.transaction_priority();
                assert!(
                    priority >= highest_seen,
                    "Transaction did not obey record priority when committing"
                );
                highest_seen = priority;
            }
        }
        Ok(())
    }
}