use std::{
borrow::Cow,
panic::Location,
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
};
use size_of::SizeOf;
use typedmap::TypedMapKey;
use crate::{
Circuit, Error, NumEntries, Runtime, Scope, Stream,
circuit::{
LocalStoreMarker,
circuit_builder::StreamId,
metadata::{
ALLOCATED_MEMORY_BYTES, BatchSizeStats, INPUT_BATCHES_STATS, MEMORY_ALLOCATIONS_COUNT,
MetaItem, OUTPUT_BATCHES_STATS, OperatorLocation, OperatorMeta, SHARED_MEMORY_BYTES,
STATE_RECORDS_COUNT, USED_MEMORY_BYTES,
},
operator_traits::{Operator, UnaryOperator},
},
circuit_cache_key,
trace::{Batch, BatchReader, Spine, Trace},
};
// Circuit cache key for accumulators: at most one accumulator is built per input `StreamId`.
// The cached value is the accumulator's output stream together with its shared enable counter,
// so repeated requests for the same input stream reuse both.
circuit_cache_key!(AccumulatorId<C, B: Batch>(StreamId => (Stream<C, Option<Spine<B>>>, Arc<AtomicUsize>)));
/// Key identifying an accumulator's enable counter in the runtime's local store.
///
/// `id` is the sequence number handed out by `Runtime::sequence_next` when the accumulator is
/// constructed, so each constructed accumulator looks up its own slot.
#[derive(PartialEq, Eq, Hash)]
struct EnableCountId {
    id: usize,
}

impl EnableCountId {
    /// Wraps a sequence number as a local-store key.
    fn new(id: usize) -> Self {
        EnableCountId { id }
    }
}
/// Lets `EnableCountId` be used as a key in the runtime's local store (`LocalStoreMarker`
/// map); the value stored under it is the accumulator's shared enable counter.
impl TypedMapKey<LocalStoreMarker> for EnableCountId {
    type Value = Arc<AtomicUsize>;
}
impl<C, B> Stream<C, B>
where
    C: Circuit,
    B: Batch,
{
    /// Accumulates the batches of this stream and emits them as one [`Spine`] on flush.
    ///
    /// The returned stream yields `None` on every step in which the accumulator has not been
    /// flushed, and `Some(spine)` containing everything accumulated since the previous flush
    /// on the step in which it has.
    ///
    /// Each call increments the accumulator's enable counter by one; the accumulator only
    /// retains input while that counter is nonzero (see
    /// [`Self::dyn_accumulate_with_enable_count`]).
    ///
    /// `#[track_caller]` makes the operator's reported location point at the caller of this
    /// method rather than at this file.
    #[track_caller]
    pub fn dyn_accumulate(&self, factories: &B::Factories) -> Stream<C, Option<Spine<B>>> {
        let (stream, enable_count) = self.dyn_accumulate_with_enable_count(factories);
        enable_count.fetch_add(1, Ordering::AcqRel);
        stream
    }

    /// Returns the accumulator stream for `self` together with its enable counter, without
    /// incrementing the counter.
    ///
    /// The accumulator is built the first time this is called for a given input stream and
    /// cached in the circuit under [`AccumulatorId`]; later calls return the cached stream and
    /// the same counter. The accumulator samples the counter when the first non-empty batch
    /// arrives after construction or after a flush; if it is zero at that point, input is
    /// discarded until the next flush.
    ///
    /// The operator's source location is that of the call that first created it.
    #[track_caller]
    pub fn dyn_accumulate_with_enable_count(
        &self,
        factories: &B::Factories,
    ) -> (Stream<C, Option<Spine<B>>>, Arc<AtomicUsize>) {
        // Capture the caller's location here, outside the closure: the closure is not
        // `#[track_caller]`, so `Location::caller()` inside it would report this file.
        let location = Location::caller();
        self.circuit()
            .cache_get_or_insert_with(AccumulatorId::new(self.stream_id()), || {
                let accumulator = Accumulator::<B>::new(factories, location);
                let enable_count = accumulator.enable_count.clone();
                let stream = self
                    .circuit()
                    .add_unary_operator(accumulator, &self.try_sharded_version());
                stream.mark_sharded_if(self);
                (stream, enable_count)
            })
            .clone()
    }
}
/// Operator that collects its input batches into a [`Spine`] and emits the whole spine when
/// the operator is flushed (yielding `None` on other steps).
pub struct Accumulator<B>
where
    B: Batch,
{
    /// Factories used to create empty spines (at construction, on `clear_state`, and when
    /// the accumulated spine is emitted).
    factories: B::Factories,
    /// Batches accumulated since construction or the last emitted flush.
    state: Spine<B>,
    /// Set by `Operator::flush`; the next `eval` emits `state` and clears this flag.
    flush: bool,
    /// Source location reported through `Operator::location`.
    location: &'static Location<'static>,
    /// Sizes of the non-empty input batches that were retained.
    input_batch_stats: BatchSizeStats,
    /// Sizes of the spines emitted on flush.
    output_batch_stats: BatchSizeStats,
    /// Shared enable counter; `Stream::dyn_accumulate` increments it. Input is retained only
    /// if it is nonzero when sampled (see `enabled_during_current_transaction`).
    enable_count: Arc<AtomicUsize>,
    /// Cached enable decision: `None` until the first non-empty batch after construction or
    /// a flush, then `Some(enable_count > 0)` until the next flush.
    enabled_during_current_transaction: Option<bool>,
}
impl<B> Accumulator<B>
where
    B: Batch,
{
    /// Builds an empty accumulator.
    ///
    /// * `factories` - cloned and kept so that fresh spines can be created later.
    /// * `location` - source location reported through `Operator::location`.
    ///
    /// When a runtime is active, the enable counter is looked up in (or inserted into) the
    /// runtime's local store under a key built from `Runtime::sequence_next`; without a
    /// runtime, a private counter starting at zero is used.
    pub fn new(factories: &B::Factories, location: &'static Location<'static>) -> Self {
        let enable_count = if let Some(runtime) = Runtime::runtime() {
            let slot = EnableCountId::new(runtime.sequence_next());
            runtime
                .local_store()
                .entry(slot)
                .or_insert_with(|| Arc::new(AtomicUsize::new(0)))
                .value()
                .clone()
        } else {
            Arc::new(AtomicUsize::new(0))
        };

        let owned_factories = factories.clone();
        let empty_state = Spine::new(factories);

        Self {
            factories: owned_factories,
            state: empty_state,
            flush: false,
            location,
            input_batch_stats: BatchSizeStats::new(),
            output_batch_stats: BatchSizeStats::new(),
            enable_count,
            enabled_during_current_transaction: None,
        }
    }
}
impl<B> Operator for Accumulator<B>
where
    B: Batch,
{
    fn name(&self) -> Cow<'static, str> {
        Cow::Borrowed("Accumulator")
    }

    fn location(&self) -> OperatorLocation {
        Some(self.location)
    }

    /// Reports the size of the accumulated state and batch statistics, followed by the
    /// spine's own metadata.
    fn metadata(&self, meta: &mut OperatorMeta) {
        let total_size = self.state.num_entries_deep();
        let bytes = self.state.size_of();

        meta.extend(metadata! {
            STATE_RECORDS_COUNT => MetaItem::Count(total_size),
            ALLOCATED_MEMORY_BYTES => MetaItem::bytes(bytes.total_bytes()),
            USED_MEMORY_BYTES => MetaItem::bytes(bytes.used_bytes()),
            MEMORY_ALLOCATIONS_COUNT => MetaItem::Count(bytes.distinct_allocations()),
            SHARED_MEMORY_BYTES => MetaItem::bytes(bytes.shared_bytes()),
            INPUT_BATCHES_STATS => self.input_batch_stats.metadata(),
            OUTPUT_BATCHES_STATS => self.output_batch_stats.metadata(),
        });

        self.state.metadata(meta);
    }

    fn clock_start(&mut self, _scope: Scope) {
        // Accumulated state is expected to be empty at clock boundaries.
        debug_assert!(self.state.is_empty());
    }

    fn clock_end(&mut self, _scope: Scope) {
        debug_assert!(self.state.is_empty());
    }

    fn fixedpoint(&self, _scope: Scope) -> bool {
        self.state.is_empty()
    }

    /// Discards all accumulated state and returns the operator to its freshly constructed
    /// condition (stats and the shared enable counter are left untouched).
    fn clear_state(&mut self) -> Result<(), Error> {
        self.state = Spine::new(&self.factories);
        self.flush = false;
        // Forget the cached enable decision too, so that the next non-empty batch re-reads
        // `enable_count` exactly as a newly constructed accumulator would; otherwise a stale
        // decision from before the reset would keep enabling or disabling accumulation.
        self.enabled_during_current_transaction = None;
        Ok(())
    }

    /// Requests that the next `eval` emit the accumulated spine.
    fn flush(&mut self) {
        self.flush = true;
    }

    /// The flush is complete once `eval` has emitted the spine and cleared the flag.
    fn is_flush_complete(&self) -> bool {
        !self.flush
    }
}
impl<B> UnaryOperator<B, Option<Spine<B>>> for Accumulator<B>
where
    B: Batch,
{
    /// Adds a clone of `batch` to the accumulated state (if enabled and non-empty) and, if a
    /// flush was requested, returns everything accumulated so far.
    async fn eval(&mut self, batch: &B) -> Option<Spine<B>> {
        if self.admit_batch(batch.len()) {
            self.state.insert(batch.clone());
        }
        self.take_flushed_output()
    }

    /// Same as `eval`, but takes ownership of `batch` instead of cloning it.
    async fn eval_owned(&mut self, batch: B) -> Option<Spine<B>> {
        if self.admit_batch(batch.len()) {
            self.state.insert(batch);
        }
        self.take_flushed_output()
    }
}

impl<B> Accumulator<B>
where
    B: Batch,
{
    /// Decides whether a batch of `len` updates should be added to the accumulated state,
    /// recording its size in the input stats if so.
    ///
    /// Empty batches are never admitted. The enable decision is sampled from
    /// `enable_count` at the first non-empty batch after construction or the last flush and
    /// reused until the next flush, so one transaction is either fully retained or dropped.
    fn admit_batch(&mut self, len: usize) -> bool {
        if len == 0 {
            return false;
        }

        let enabled = match self.enabled_during_current_transaction {
            Some(enabled) => enabled,
            None => {
                let enabled = self.enable_count.load(Ordering::Acquire) > 0;
                self.enabled_during_current_transaction = Some(enabled);
                enabled
            }
        };

        if enabled {
            self.input_batch_stats.add_batch(len);
        }
        enabled
    }

    /// If a flush was requested, clears the request and the cached enable decision and
    /// returns the accumulated spine (replacing it with an empty one); otherwise `None`.
    fn take_flushed_output(&mut self) -> Option<Spine<B>> {
        if !self.flush {
            return None;
        }

        self.flush = false;
        self.enabled_during_current_transaction = None;

        let spine = std::mem::replace(&mut self.state, Spine::new(&self.factories));
        self.output_batch_stats.add_batch(spine.len());
        Some(spine)
    }
}