use std::{borrow::Cow, sync::Arc};
use feldera_storage::{FileCommitter, StoragePath};
use size_of::SizeOf;
use crate::{
Circuit, Error, NumEntries, Runtime, Scope, Stream,
circuit::{
GlobalNodeId, OwnershipPreference,
checkpointer::Checkpoint,
operator_traits::{Operator, UnaryOperator},
},
operator::require_persistent_id,
storage::file::to_bytes,
};
impl<C, D> Stream<C, D>
where
    C: Circuit,
{
    /// Attaches a [`TransactionZ1`] operator to this stream and returns the
    /// operator's output stream.
    ///
    /// `initial` seeds the operator: it is the value the operator outputs
    /// until a latched input replaces it, and the value its state is reset to
    /// at the end of a clock epoch.
    ///
    /// If this stream has a persistent id, the new operator's persistent id is
    /// set to `"<stream id>.transaction_delay"`; otherwise `set_persistent_id`
    /// is called with `None`.
    #[track_caller]
    pub fn transaction_delay_with_initial_value(&self, initial: D) -> Stream<C, D>
    where
        D: Checkpoint + SizeOf + NumEntries + Clone + 'static,
    {
        let delay_pid = self
            .get_persistent_id()
            .map(|pid| format!("{pid}.transaction_delay"));

        // `initial` is owned and not used again, so it is moved into the
        // operator rather than cloned.
        self.circuit()
            .add_unary_operator(TransactionZ1::new(initial), self)
            .set_persistent_id(delay_pid.as_deref())
    }
}
/// Delay operator whose output changes only on the first `eval` that follows a `flush` request.
///
/// On that `eval` the previously latched input becomes the output and the
/// current input is latched; every other `eval` repeats the current output.
pub struct TransactionZ1<T> {
/// Value the operator starts from and is reset to by `clock_end` and `clear_state`.
zero: T,
/// Node id recorded by `init`; passed to `require_persistent_id` in `checkpoint` and `restore`.
global_id: GlobalNodeId,
/// Read by `fixedpoint`.
/// NOTE(review): nothing in the visible code ever sets this to `true` — confirm whether that is intended.
empty_output: bool,
/// Set by `flush`; cleared by the next `eval`, which then shifts the stored values.
flush: bool,
/// Value returned by `eval`.
old_value: T,
/// Input latched by the most recent flushed `eval`; becomes `old_value` on the next flushed `eval`.
new_value: T,
}
/// Checkpoint image of a [`TransactionZ1`]: its two values, each already serialized to bytes.
///
/// Built by the `TryFrom<&TransactionZ1<T>>` impl and read back in `restore`.
#[derive(rkyv::Serialize, rkyv::Deserialize, rkyv::Archive)]
pub struct CommittedTransactionZ1 {
/// Bytes returned by `Checkpoint::checkpoint` for the operator's `old_value`.
old_value: Vec<u8>,
/// Bytes returned by `Checkpoint::checkpoint` for the operator's `new_value`.
new_value: Vec<u8>,
}
impl<T> TryFrom<&TransactionZ1<T>> for CommittedTransactionZ1
where
T: Checkpoint,
{
type Error = Error;
fn try_from(z1: &TransactionZ1<T>) -> Result<CommittedTransactionZ1, Error> {
Ok(CommittedTransactionZ1 {
old_value: z1.old_value.checkpoint()?,
new_value: z1.new_value.checkpoint()?,
})
}
}
impl<T> TransactionZ1<T>
where
    T: Checkpoint + Clone,
{
    /// Creates an operator whose reset value, current output and latched
    /// input all start out equal to `zero`.
    ///
    /// The node id is initialized to `GlobalNodeId::root()` and is replaced
    /// when `init` is called.
    pub fn new(zero: T) -> Self {
        Self {
            empty_output: false,
            flush: false,
            global_id: GlobalNodeId::root(),
            old_value: zero.clone(),
            new_value: zero.clone(),
            // Moved in last, so only two clones of `zero` are required.
            zero,
        }
    }

    /// Returns the child of `base` named `transaction-z1-<persistent_id>.dat`,
    /// the file used by `checkpoint` and `restore`.
    fn checkpoint_file<P: AsRef<str>>(base: &StoragePath, persistent_id: P) -> StoragePath {
        base.child(format!("transaction-z1-{}.dat", persistent_id.as_ref()))
    }
}
impl<T> Operator for TransactionZ1<T>
where
T: Checkpoint + SizeOf + NumEntries + Clone + 'static,
{
fn name(&self) -> Cow<'static, str> {
Cow::from("Transaction Z^-1")
}
fn clock_start(&mut self, _scope: Scope) {}
fn clock_end(&mut self, _scope: Scope) {
self.empty_output = false;
self.new_value = self.zero.clone();
self.old_value = self.zero.clone();
}
fn init(&mut self, global_id: &GlobalNodeId) {
self.global_id = global_id.clone();
}
fn fixedpoint(&self, scope: Scope) -> bool {
if scope == 0 {
self.new_value.num_entries_shallow() == 0
&& self.old_value.num_entries_shallow() == 0
&& self.empty_output
} else {
true
}
}
fn checkpoint(
&mut self,
base: &StoragePath,
persistent_id: Option<&str>,
files: &mut Vec<Arc<dyn FileCommitter>>,
) -> Result<(), Error> {
let persistent_id = require_persistent_id(persistent_id, &self.global_id)?;
let committed: CommittedTransactionZ1 = (self as &Self).try_into()?;
let as_bytes =
to_bytes(&committed).expect("Serializing CommittedTransactionZ1 should work.");
files.push(
Runtime::storage_backend()
.unwrap()
.write(&Self::checkpoint_file(base, persistent_id), as_bytes)?,
);
Ok(())
}
fn restore(&mut self, base: &StoragePath, persistent_id: Option<&str>) -> Result<(), Error> {
let persistent_id = require_persistent_id(persistent_id, &self.global_id)?;
let z1_path = Self::checkpoint_file(base, persistent_id);
let content = Runtime::storage_backend().unwrap().read(&z1_path)?;
let committed = unsafe { rkyv::archived_root::<CommittedTransactionZ1>(&content) };
let mut old_value = self.zero.clone();
let mut new_value = self.zero.clone();
old_value.restore(committed.old_value.as_slice())?;
new_value.restore(committed.new_value.as_slice())?;
self.empty_output = false;
self.old_value = old_value;
self.new_value = new_value;
Ok(())
}
fn clear_state(&mut self) -> Result<(), Error> {
self.empty_output = false;
self.old_value = self.zero.clone();
self.new_value = self.zero.clone();
Ok(())
}
fn flush(&mut self) {
self.flush = true;
}
}
impl<T> UnaryOperator<T, T> for TransactionZ1<T>
where
    T: Checkpoint + SizeOf + NumEntries + Clone + 'static,
{
    /// Returns a clone of the current output value.
    ///
    /// If `flush` was requested since the last shift, the request is consumed
    /// first: the previously latched input becomes the output and `i` is
    /// latched in its place. Without a pending request, `i` is ignored.
    async fn eval(&mut self, i: &T) -> T {
        if self.flush {
            self.flush = false;
            // Move the latched value into `old_value` instead of cloning it
            // just before it is overwritten.
            self.old_value = std::mem::replace(&mut self.new_value, i.clone());
        }
        self.old_value.clone()
    }

    /// Declares a preference for receiving the input by value.
    fn input_preference(&self) -> OwnershipPreference {
        OwnershipPreference::PREFER_OWNED
    }
}
#[cfg(test)]
mod test {
    use crate::{
        circuit::operator_traits::{Operator, UnaryOperator},
        operator::TransactionZ1,
    };

    /// Walks the operator through two flushes and a clock reset, checking the
    /// output after every step.
    #[tokio::test]
    async fn transaction_z1_test() {
        let mut z1 = TransactionZ1::new(0);
        z1.clock_start(0);

        // No flush requested yet: the initial value is repeated.
        for input in 1..=3 {
            assert_eq!(z1.eval(&input).await, 0);
        }

        // First flush: 4 is latched, but the output is still the initial value.
        z1.flush();
        for input in 4..=6 {
            assert_eq!(z1.eval(&input).await, 0);
        }

        // Second flush: the value latched after the first flush (4) is now output.
        z1.flush();
        for input in 7..=9 {
            assert_eq!(z1.eval(&input).await, 4);
        }

        // Ending the clock epoch resets the output to the initial value.
        z1.clock_end(0);
        assert_eq!(z1.eval(&10).await, 0);
    }
}