use crate::Runtime;
use crate::circuit::circuit_builder::StreamId;
use crate::circuit::metadata::{MEMORY_ALLOCATIONS_COUNT, SIZE_DISTRIBUTION, STATE_RECORDS_COUNT};
use crate::{
Error, NumEntries,
algebra::HasZero,
circuit::checkpointer::Checkpoint,
circuit::{
Circuit, ExportId, ExportStream, FeedbackConnector, GlobalNodeId, OwnershipPreference,
Scope, Stream,
metadata::{
ALLOCATED_MEMORY_BYTES, MetaItem, OperatorMeta, SHARED_MEMORY_BYTES, USED_MEMORY_BYTES,
},
operator_traits::{Operator, StrictOperator, StrictUnaryOperator, UnaryOperator},
},
circuit_cache_key,
storage::file::to_bytes,
};
use feldera_storage::{FileCommitter, StoragePath};
use size_of::{Context, SizeOf};
use std::sync::Arc;
use std::{borrow::Cow, mem::replace};
use super::require_persistent_id;
// Circuit cache key mapping a stream's `StreamId` to its `z^-1`-delayed stream.
// Populated by `Stream::delay`, `Stream::delay_with_initial_value` and
// `DelayedFeedback::connect` (all three share this key).
circuit_cache_key!(DelayedId<C, D>(StreamId => Stream<C, D>));
// Circuit cache key mapping a stream's `StreamId` to the output of its `Z1Nested`
// delay. Populated by `Stream::delay_nested` and `DelayedNestedFeedback::connect`.
circuit_cache_key!(NestedDelayedId<C, D>(StreamId => Stream<C, D>));
/// A feedback loop closed through a [`Z1`] delay, whose delayed output is also
/// exported to the parent circuit.
///
/// Create it with [`DelayedFeedback::new`] or [`DelayedFeedback::with_default`],
/// use [`DelayedFeedback::stream`] as the delayed value while building the loop
/// body, then close the loop with [`DelayedFeedback::connect`].
pub struct DelayedFeedback<C, D>
where
    C: Circuit,
{
    /// Feedback connector; `connect` attaches the loop's input stream to it.
    feedback: FeedbackConnector<C, D, D, Z1<D>>,
    /// The `Z1` operator's output stream inside `C`.
    output: Stream<C, D>,
    /// Export of the delayed stream to `C`'s parent (from `add_feedback_with_export`).
    export: Stream<C::Parent, D>,
}
impl<C, D> DelayedFeedback<C, D>
where
    C: Circuit,
    D: Checkpoint + Eq + SizeOf + NumEntries + Clone + HasZero + 'static,
{
    /// Creates a delayed feedback loop whose [`Z1`] delay starts from (and is
    /// reset to) `D::zero()`.
    ///
    /// Shorthand for `DelayedFeedback::with_default(circuit, D::zero())`.
    pub fn new(circuit: &C) -> Self {
        Self::with_default(circuit, D::zero())
    }
}
impl<C, D> DelayedFeedback<C, D>
where
    C: Circuit,
    D: Checkpoint + Eq + SizeOf + NumEntries + Clone + 'static,
{
    /// Creates a delayed feedback loop whose [`Z1`] delay starts from (and is
    /// reset to) `default` instead of `D::zero()`.
    pub fn with_default(circuit: &C, default: D) -> Self {
        let delay = Z1::new(default);
        let (exported, feedback) = circuit.add_feedback_with_export(delay);
        let ExportStream { local, export } = exported;
        Self {
            feedback,
            output: local,
            export,
        }
    }

    /// The delayed stream, usable inside the circuit before the loop is closed.
    pub fn stream(&self) -> &Stream<C, D> {
        let Self { output, .. } = self;
        output
    }

    /// Closes the loop by feeding `input` into the delay.
    ///
    /// Also caches the delayed stream under `DelayedId` and its export under
    /// `ExportId`, both keyed by `input`'s stream id, so later cache lookups for
    /// `input` (e.g. `Stream::delay`) find this loop's operator.
    pub fn connect(self, input: &Stream<C, D>) {
        let Self {
            feedback: connector,
            output: delayed,
            export: exported,
        } = self;
        let circuit = delayed.circuit().clone();
        connector.connect_with_preference(input, OwnershipPreference::STRONGLY_PREFER_OWNED);
        circuit.cache_insert(DelayedId::new(input.stream_id()), delayed);
        circuit.cache_insert(ExportId::new(input.stream_id()), exported);
    }
}
/// A feedback loop closed through a [`Z1Nested`] delay (no export to the parent).
///
/// Build the loop body on [`DelayedNestedFeedback::stream`], then close it with
/// [`DelayedNestedFeedback::connect`].
pub struct DelayedNestedFeedback<C, D> {
    /// Feedback connector; `connect` attaches the loop's input stream to it.
    feedback: FeedbackConnector<C, D, D, Z1Nested<D>>,
    /// The `Z1Nested` operator's output stream.
    output: Stream<C, D>,
}
impl<C, D> DelayedNestedFeedback<C, D>
where
    C: Circuit,
    D: Checkpoint + Eq + SizeOf + NumEntries + Clone + 'static,
{
    /// Creates a nested delayed feedback loop whose [`Z1Nested`] delay emits
    /// `D::zero()` wherever no earlier value is buffered.
    pub fn new(circuit: &C) -> Self
    where
        D: HasZero,
    {
        let delay = Z1Nested::new(D::zero());
        let (output, feedback) = circuit.add_feedback(delay);
        Self { output, feedback }
    }

    /// The nested-delayed stream, usable before the loop is closed.
    pub fn stream(&self) -> &Stream<C, D> {
        let Self { output, .. } = self;
        output
    }

    /// Closes the loop by feeding `input` into the nested delay, and caches the
    /// delayed stream under `NestedDelayedId` for `input`'s stream id.
    pub fn connect(self, input: &Stream<C, D>) {
        let Self {
            feedback: connector,
            output: delayed,
        } = self;
        let circuit = delayed.circuit().clone();
        connector.connect_with_preference(input, OwnershipPreference::STRONGLY_PREFER_OWNED);
        circuit.cache_insert(NestedDelayedId::new(input.stream_id()), delayed);
    }
}
impl<C, D> Stream<C, D>
where
    C: Circuit,
{
    /// Returns this stream delayed by one clock tick through a [`Z1`] operator
    /// initialized with `D::zero()`.
    ///
    /// The operator is created via `cache_get_or_insert_with` under `DelayedId`
    /// (keyed by this stream's id). If this stream has persistent id `pid`, the
    /// delay operator is given the persistent id `"{pid}.delay"`.
    #[track_caller]
    pub fn delay(&self) -> Stream<C, D>
    where
        D: Checkpoint + Eq + SizeOf + NumEntries + Clone + HasZero + 'static,
    {
        let key = DelayedId::new(self.stream_id());
        self.circuit()
            .cache_get_or_insert_with(key, || {
                let operator_pid = self
                    .get_persistent_id()
                    .map(|pid| format!("{pid}.delay"));
                let operator = Z1::new(D::zero());
                self.circuit()
                    .add_unary_operator(operator, self)
                    .set_persistent_id(operator_pid.as_deref())
            })
            .clone()
    }

    /// Like [`Stream::delay`], but the [`Z1`] operator starts from `initial`.
    ///
    /// NOTE(review): this uses the same `DelayedId` cache key as `delay` and
    /// `DelayedFeedback::connect`; if a delayed stream is already cached for this
    /// stream, presumably it is returned and `initial` is not applied — confirm
    /// this sharing is intended.
    #[track_caller]
    pub fn delay_with_initial_value(&self, initial: D) -> Stream<C, D>
    where
        D: Checkpoint + Eq + SizeOf + NumEntries + Clone + 'static,
    {
        let key = DelayedId::new(self.stream_id());
        self.circuit()
            .cache_get_or_insert_with(key, move || {
                let operator_pid = self
                    .get_persistent_id()
                    .map(|pid| format!("{pid}.delay"));
                let operator = Z1::new(initial.clone());
                self.circuit()
                    .add_unary_operator(operator, self)
                    .set_persistent_id(operator_pid.as_deref())
            })
            .clone()
    }

    /// Returns this stream passed through a [`Z1Nested`] operator initialized with
    /// `D::zero()`, cached under `NestedDelayedId`. Uses the same `"{pid}.delay"`
    /// persistent-id suffix as [`Stream::delay`].
    #[track_caller]
    pub fn delay_nested(&self) -> Stream<C, D>
    where
        D: Eq + Clone + HasZero + SizeOf + NumEntries + 'static,
    {
        let key = NestedDelayedId::new(self.stream_id());
        self.circuit()
            .cache_get_or_insert_with(key, || {
                let operator_pid = self
                    .get_persistent_id()
                    .map(|pid| format!("{pid}.delay"));
                let operator = Z1Nested::new(D::zero());
                self.circuit()
                    .add_unary_operator(operator, self)
                    .set_persistent_id(operator_pid.as_deref())
            })
            .clone()
    }
}
/// The `z^-1` (unit delay) operator: each step outputs the value received on the
/// previous step.
///
/// The first output of a clock epoch, and the output after a reset, is `zero`.
pub struct Z1<T> {
    /// Value emitted at the start of each epoch and restored on reset.
    zero: T,
    /// This operator's id, set in `Operator::init` (`GlobalNodeId::root()` until then).
    global_id: GlobalNodeId,
    /// Set by `StrictOperator::get_output` when the value it returned had no
    /// (shallow) entries; read by `Operator::fixedpoint`.
    empty_output: bool,
    /// The pending value, emitted on the next step.
    values: T,
}
/// Serialized (rkyv-archived) form of a [`Z1`] checkpoint.
#[derive(rkyv::Serialize, rkyv::Deserialize, rkyv::Archive)]
pub struct CommittedZ1 {
    /// The pending value, as encoded by `Checkpoint::checkpoint`.
    values: Vec<u8>,
}
/// Captures a [`Z1`]'s pending value as a [`CommittedZ1`].
///
/// # Errors
/// Propagates any error from checkpointing the pending value.
impl<T> TryFrom<&Z1<T>> for CommittedZ1
where
    T: Checkpoint + Clone,
{
    type Error = Error;

    fn try_from(z1: &Z1<T>) -> Result<CommittedZ1, Error> {
        let encoded = z1.values.checkpoint()?;
        Ok(Self { values: encoded })
    }
}
impl<T> Z1<T>
where
    T: Checkpoint + Clone,
{
    /// Creates a delay whose initial (and reset) output is `zero`.
    ///
    /// The pending value takes `zero` itself; the stored reset value is a clone.
    pub fn new(zero: T) -> Self {
        let reset_value = zero.clone();
        Self {
            zero: reset_value,
            global_id: GlobalNodeId::root(),
            empty_output: false,
            values: zero,
        }
    }

    /// Path of this operator's checkpoint file: `z1-<persistent_id>.dat` under `base`.
    fn checkpoint_file<P: AsRef<str>>(base: &StoragePath, persistent_id: P) -> StoragePath {
        let pid = persistent_id.as_ref();
        base.child(format!("z1-{pid}.dat"))
    }
}
/// Circuit-operator behavior of [`Z1`]: lifecycle hooks, metadata, fixed-point
/// detection, and checkpoint/restore of the pending value.
impl<T> Operator for Z1<T>
where
    T: Checkpoint + SizeOf + NumEntries + Clone + 'static,
{
    fn name(&self) -> Cow<'static, str> {
        Cow::from("Z^-1")
    }
    fn clock_start(&mut self, _scope: Scope) {}
    /// Resets at the end of every clock epoch (any scope), so the next epoch
    /// starts by emitting `zero` again.
    fn clock_end(&mut self, _scope: Scope) {
        self.empty_output = false;
        self.values = self.zero.clone();
    }
    /// Records this operator's global id; it is passed to `require_persistent_id`
    /// during checkpoint/restore.
    fn init(&mut self, global_id: &GlobalNodeId) {
        self.global_id = global_id.clone();
    }
    /// Reports the record count and memory usage of the pending value.
    fn metadata(&self, meta: &mut OperatorMeta) {
        let bytes = self.values.size_of();
        meta.extend(metadata! {
            STATE_RECORDS_COUNT => MetaItem::Count(self.values.num_entries_deep()),
            ALLOCATED_MEMORY_BYTES => MetaItem::bytes(bytes.total_bytes()),
            USED_MEMORY_BYTES => MetaItem::bytes(bytes.used_bytes()),
            MEMORY_ALLOCATIONS_COUNT => MetaItem::Count(bytes.distinct_allocations()),
            SHARED_MEMORY_BYTES => MetaItem::bytes(bytes.shared_bytes()),
        });
    }
    /// At scope 0, reports a fixed point when the pending value has no (shallow)
    /// entries and the last strict output was empty too. `empty_output` is only
    /// set by `StrictOperator::get_output`, so without the strict protocol this
    /// is `false` at scope 0. Outer scopes always report `true`.
    fn fixedpoint(&self, scope: Scope) -> bool {
        if scope == 0 {
            self.values.num_entries_shallow() == 0 && self.empty_output
        } else {
            true
        }
    }
    /// Writes the pending value (via [`CommittedZ1`]) to `z1-<persistent_id>.dat`
    /// under `base` and pushes the resulting file committer onto `files`.
    ///
    /// # Errors
    /// Fails if no persistent id is available, if the value cannot be
    /// checkpointed, or if the storage write fails.
    ///
    /// # Panics
    /// Panics if `Runtime::storage_backend()` yields no backend, or if rkyv
    /// serialization fails.
    fn checkpoint(
        &mut self,
        base: &StoragePath,
        persistent_id: Option<&str>,
        files: &mut Vec<Arc<dyn FileCommitter>>,
    ) -> Result<(), Error> {
        let persistent_id = require_persistent_id(persistent_id, &self.global_id)?;
        let committed: CommittedZ1 = (self as &Self).try_into()?;
        let as_bytes = to_bytes(&committed).expect("Serializing CommittedZ1 should work.");
        files.push(
            Runtime::storage_backend()
                .unwrap()
                .write(&Self::checkpoint_file(base, persistent_id), as_bytes)?,
        );
        Ok(())
    }
    /// Restores the pending value from the file written by `checkpoint`.
    ///
    /// # Errors
    /// Fails if no persistent id is available, if the file cannot be read, or
    /// if `Checkpoint::restore` rejects the stored bytes.
    ///
    /// # Panics
    /// Panics if `Runtime::storage_backend()` yields no backend.
    fn restore(&mut self, base: &StoragePath, persistent_id: Option<&str>) -> Result<(), Error> {
        let persistent_id = require_persistent_id(persistent_id, &self.global_id)?;
        let z1_path = Self::checkpoint_file(base, persistent_id);
        let content = Runtime::storage_backend().unwrap().read(&z1_path)?;
        // SAFETY(review): `archived_root` does no validation. Soundness relies on
        // `content` being an unmodified `CommittedZ1` archive written by
        // `checkpoint` and on the buffer meeting rkyv's alignment requirements.
        // A corrupted or foreign file is undefined behavior. Consider
        // `check_archived_root` if rkyv's validation feature is enabled.
        let committed = unsafe { rkyv::archived_root::<CommittedZ1>(&content) };
        // Start from a clone of `zero` and fill it from the checkpointed bytes.
        let mut values = self.zero.clone();
        values.restore(committed.values.as_slice())?;
        self.empty_output = false;
        self.values = values;
        Ok(())
    }
    /// Discards the pending value, as at the end of an epoch.
    fn clear_state(&mut self) -> Result<(), Error> {
        self.empty_output = false;
        self.values = self.zero.clone();
        Ok(())
    }
}
impl<T> UnaryOperator<T, T> for Z1<T>
where
    T: Checkpoint + SizeOf + NumEntries + Clone + 'static,
{
    /// Stores a clone of `i` and returns the previously pending value.
    async fn eval(&mut self, i: &T) -> T {
        self.eval_owned(i.clone()).await
    }

    /// Stores `i` and returns the previously pending value.
    async fn eval_owned(&mut self, i: T) -> T {
        replace(&mut self.values, i)
    }

    fn input_preference(&self) -> OwnershipPreference {
        // The input is stored as-is, so receiving it owned avoids a clone.
        OwnershipPreference::PREFER_OWNED
    }
}
impl<T> StrictOperator<T> for Z1<T>
where
    T: Checkpoint + SizeOf + NumEntries + Clone + 'static,
{
    /// Emits the pending value, leaving a clone of `zero` in its place, and
    /// records whether the emitted value was empty (used by `fixedpoint`).
    fn get_output(&mut self) -> T {
        let output = replace(&mut self.values, self.zero.clone());
        self.empty_output = output.num_entries_shallow() == 0;
        output
    }

    fn get_final_output(&mut self) -> T {
        // No special end-of-epoch behavior: identical to a regular output.
        self.get_output()
    }
}
impl<T> StrictUnaryOperator<T, T> for Z1<T>
where
    T: Checkpoint + SizeOf + NumEntries + Clone + 'static,
{
    /// Stores a clone of `i` as the pending value.
    async fn eval_strict(&mut self, i: &T) {
        self.eval_strict_owned(i.clone()).await
    }

    /// Stores `i` as the pending value, overwriting whatever is there (under the
    /// strict protocol it has normally already been taken by `get_output`).
    async fn eval_strict_owned(&mut self, i: T) {
        self.values = i;
    }

    fn input_preference(&self) -> OwnershipPreference {
        OwnershipPreference::PREFER_OWNED
    }
}
/// Nested delay: buffers one value per scope-0 timestamp across scope-0 clock epochs.
///
/// The output at step `t` is the input received at step `t` during the previous
/// scope-0 epoch (`zero` if none was buffered). Steps past the end of that epoch
/// repeat its last value.
pub struct Z1Nested<T> {
    /// Output used where no earlier value is buffered.
    zero: T,
    /// Current step within the scope-0 epoch.
    timestamp: usize,
    /// Per-timestamp values, overwritten as the current epoch advances.
    values: Vec<T>,
}

impl<T> Z1Nested<T> {
    /// Creates an empty nested delay whose default output is `zero`.
    ///
    /// `const`, since an empty `Vec` needs no allocation.
    const fn new(zero: T) -> Self {
        let values = Vec::new();
        Self {
            values,
            timestamp: 0,
            zero,
        }
    }

    /// Forgets all buffered values and rewinds to timestamp 0.
    ///
    /// Uses `Vec::clear`, so the buffer's capacity is retained.
    fn reset(&mut self) {
        self.values.clear();
        self.timestamp = 0;
    }
}
/// Circuit-operator behavior of [`Z1Nested`]: epoch bookkeeping, metadata, and
/// fixed-point detection.
impl<T> Operator for Z1Nested<T>
where
    T: Eq + SizeOf + NumEntries + Clone + 'static,
{
    fn name(&self) -> Cow<'static, str> {
        Cow::from("Z^-1 (nested)")
    }

    /// At the start of a scope-0 epoch, drops buffered values past the length of
    /// the epoch that just finished (the current `timestamp`). In every case it
    /// then rewinds to timestamp 0.
    fn clock_start(&mut self, scope: Scope) {
        if scope == 0 {
            self.values.truncate(self.timestamp);
        }
        self.timestamp = 0;
    }

    /// At the end of an outer-scope (`scope > 0`) epoch, discards all history.
    fn clock_end(&mut self, scope: Scope) {
        if scope > 0 {
            self.reset();
        }
    }

    /// Reports the total and per-timestamp record counts and the combined memory
    /// usage of the buffered values.
    fn metadata(&self, meta: &mut OperatorMeta) {
        let total_size: usize = self
            .values
            .iter()
            .map(|batch| batch.num_entries_deep())
            .sum();
        let batch_sizes = self
            .values
            .iter()
            .map(|batch| MetaItem::Int(batch.num_entries_deep()))
            .collect();
        // A shared context lets size_of account for allocations shared between
        // buffered values only once.
        let total_bytes = {
            let mut context = Context::new();
            for value in &self.values {
                value.size_of_with_context(&mut context);
            }
            context.total_size()
        };
        meta.extend(metadata! {
            STATE_RECORDS_COUNT => MetaItem::Count(total_size),
            SIZE_DISTRIBUTION => MetaItem::Array(batch_sizes),
            ALLOCATED_MEMORY_BYTES => MetaItem::bytes(total_bytes.total_bytes()),
            USED_MEMORY_BYTES => MetaItem::bytes(total_bytes.used_bytes()),
            MEMORY_ALLOCATIONS_COUNT => MetaItem::Count(total_bytes.distinct_allocations()),
            SHARED_MEMORY_BYTES => MetaItem::bytes(total_bytes.shared_bytes()),
        });
    }

    /// At scope 0, reports a fixed point when the most recent input
    /// (`values[timestamp - 1]`) and every value still buffered after it equal
    /// `zero`. Outer scopes never report a fixed point.
    fn fixedpoint(&self, scope: Scope) -> bool {
        if scope == 0 {
            // `saturating_sub` avoids a usize underflow when no step has run yet in
            // this epoch (`timestamp == 0`). In that case every buffered value is
            // checked instead of panicking (debug) or wrapping to "skip all"
            // (release).
            self.values
                .iter()
                .skip(self.timestamp.saturating_sub(1))
                .all(|v| *v == self.zero)
        } else {
            false
        }
    }
}
impl<T> UnaryOperator<T, T> for Z1Nested<T>
where
    T: Eq + SizeOf + NumEntries + Clone + 'static,
{
    /// Like `eval_owned`, but stores a clone of `i`.
    async fn eval(&mut self, i: &T) -> T {
        self.eval_owned(i.clone()).await
    }

    /// Stores `i` for the current timestamp and returns the value buffered there
    /// (`zero` if none), then advances the timestamp.
    async fn eval_owned(&mut self, i: T) -> T {
        debug_assert!(self.timestamp <= self.values.len());
        let len = self.values.len();
        if self.timestamp == len {
            // Nothing buffered for this step: the output is `zero`.
            self.values.push(self.zero.clone());
        } else if self.timestamp == len - 1 {
            // Reading the last buffered value: keep a copy for the next step so
            // it is repeated there.
            let last = self.values[len - 1].clone();
            self.values.push(last);
        }
        let previous = replace(&mut self.values[self.timestamp], i);
        self.timestamp += 1;
        previous
    }

    fn input_preference(&self) -> OwnershipPreference {
        OwnershipPreference::PREFER_OWNED
    }
}
impl<T> StrictOperator<T> for Z1Nested<T>
where
    T: Eq + SizeOf + NumEntries + Clone + 'static,
{
    /// Emits the value buffered for the current timestamp (`zero` if none) and
    /// leaves `zero` in its slot for the following `eval_strict*` call to fill.
    ///
    /// # Panics
    /// Panics if `timestamp` is more than one past the end of the buffer.
    fn get_output(&mut self) -> T {
        if self.timestamp >= self.values.len() {
            // Past the buffer: only exactly one past the end is valid.
            assert_eq!(self.timestamp, self.values.len());
            self.values.push(self.zero.clone());
        } else if self.timestamp == self.values.len() - 1 {
            // Reading the last buffered value: keep a copy for the next timestamp
            // so it is repeated there.
            self.values.push(self.values.last().unwrap().clone());
        }
        // Both branches above leave `timestamp < values.len()`, so checked
        // indexing never panics here. No `unsafe` is needed.
        replace(&mut self.values[self.timestamp], self.zero.clone())
    }

    fn get_final_output(&mut self) -> T {
        self.get_output()
    }
}
impl<T> StrictUnaryOperator<T, T> for Z1Nested<T>
where
    T: Eq + SizeOf + NumEntries + Clone + 'static,
{
    /// Like `eval_strict_owned`, but stores a clone of `i`.
    async fn eval_strict(&mut self, i: &T) {
        self.eval_strict_owned(i.clone()).await
    }

    /// Stores `i` in the current timestamp's slot (normally created by the
    /// preceding `get_output` call) and advances the timestamp.
    async fn eval_strict_owned(&mut self, i: T) {
        debug_assert!(self.timestamp < self.values.len());
        let slot = &mut self.values[self.timestamp];
        *slot = i;
        self.timestamp += 1;
    }

    fn input_preference(&self) -> OwnershipPreference {
        OwnershipPreference::PREFER_OWNED
    }
}
#[cfg(test)]
mod test {
    use crate::{
        circuit::operator_traits::{Operator, StrictOperator, StrictUnaryOperator, UnaryOperator},
        operator::{Z1, Z1Nested},
    };
    /// Drives `Z1` directly through two scope-0 clock epochs, first with
    /// `eval`/`eval_owned` and then with the strict `get_output`/`eval_strict`
    /// protocol; both must produce the same sequence.
    ///
    /// Each epoch emits `zero` (0) first and then the previous input. `clock_end`
    /// resets to `zero`, so the last input of each epoch (3 and 6) is never emitted.
    #[tokio::test]
    async fn z1_test() {
        let mut z1 = Z1::new(0);
        let expected_result = vec![0, 1, 2, 0, 4, 5];
        // Non-strict API: each `eval*` returns the previously stored value.
        let mut res = Vec::new();
        z1.clock_start(0);
        res.push(z1.eval(&1).await);
        res.push(z1.eval(&2).await);
        res.push(z1.eval(&3).await);
        z1.clock_end(0);
        z1.clock_start(0);
        res.push(z1.eval_owned(4).await);
        res.push(z1.eval_owned(5).await);
        res.push(z1.eval_owned(6).await);
        z1.clock_end(0);
        assert_eq!(res, expected_result);
        // Strict API: take the output first, then feed this step's input.
        let mut res = Vec::new();
        z1.clock_start(0);
        res.push(z1.get_output());
        z1.eval_strict(&1).await;
        res.push(z1.get_output());
        z1.eval_strict(&2).await;
        res.push(z1.get_output());
        z1.eval_strict(&3).await;
        z1.clock_end(0);
        z1.clock_start(0);
        res.push(z1.get_output());
        z1.eval_strict_owned(4).await;
        res.push(z1.get_output());
        z1.eval_strict_owned(5).await;
        res.push(z1.get_output());
        z1.eval_strict_owned(6).await;
        z1.clock_end(0);
        assert_eq!(res, expected_result);
    }
    /// Drives `Z1Nested` through two scope-1 epochs of three scope-0 epochs each,
    /// once with the non-strict API and once with the strict API.
    ///
    /// Within a scope-1 epoch, the output at step `t` is the input from step `t`
    /// of the previous scope-0 epoch (0 if none). Steps past that epoch's end
    /// repeat its last value. `clock_end(1)` clears all history.
    #[tokio::test]
    async fn z1_nested_test() {
        let mut z1 = Z1Nested::new(0);
        z1.clock_start(1);
        // First scope-0 epoch: nothing buffered yet, so every output is zero.
        let mut res = Vec::new();
        z1.clock_start(0);
        res.push(z1.eval_owned(1).await);
        res.push(z1.eval(&2).await);
        res.push(z1.eval(&3).await);
        z1.clock_end(0);
        assert_eq!(res.as_slice(), &[0, 0, 0]);
        // Second epoch: replays the first epoch's inputs at the same steps.
        let mut res = Vec::new();
        z1.clock_start(0);
        res.push(z1.eval_owned(4).await);
        res.push(z1.eval_owned(5).await);
        z1.clock_end(0);
        assert_eq!(res.as_slice(), &[1, 2]);
        // Third epoch: steps 0-1 replay 4 and 5. Steps 2-3 lie past the second
        // epoch's end (the buffer was truncated to 2), so its last value 5 repeats.
        let mut res = Vec::new();
        z1.clock_start(0);
        res.push(z1.eval_owned(6).await);
        res.push(z1.eval_owned(7).await);
        res.push(z1.eval(&8).await);
        res.push(z1.eval(&9).await);
        z1.clock_end(0);
        assert_eq!(res.as_slice(), &[4, 5, 5, 5]);
        // Ending the scope-1 epoch discards all history; repeat with the strict API.
        z1.clock_end(1);
        z1.clock_start(1);
        let mut res = Vec::new();
        z1.clock_start(0);
        res.push(z1.get_output());
        z1.eval_strict(&1).await;
        res.push(z1.get_output());
        z1.eval_strict(&2).await;
        res.push(z1.get_output());
        z1.eval_strict(&3).await;
        z1.clock_end(0);
        assert_eq!(res.as_slice(), &[0, 0, 0]);
        let mut res = Vec::new();
        z1.clock_start(0);
        res.push(z1.get_output());
        z1.eval_strict_owned(4).await;
        res.push(z1.get_output());
        z1.eval_strict_owned(5).await;
        z1.clock_end(0);
        assert_eq!(res.as_slice(), &[1, 2]);
        let mut res = Vec::new();
        z1.clock_start(0);
        res.push(z1.get_output());
        z1.eval_strict_owned(6).await;
        res.push(z1.get_output());
        z1.eval_strict_owned(7).await;
        res.push(z1.get_output());
        z1.eval_strict_owned(8).await;
        res.push(z1.get_output());
        z1.eval_strict_owned(9).await;
        z1.clock_end(0);
        assert_eq!(res.as_slice(), &[4, 5, 5, 5]);
        z1.clock_end(1);
    }
}