use crate::dataflow::{stream::StreamId, time::Timestamp, State};
use std::{
collections::{HashMap, HashSet},
sync::{Arc, Mutex},
};
use tokio::time::Duration;
/// Identifier of a deadline (returned by `DeadlineT::id`); an alias for `crate::Uuid`.
pub type DeadlineId = crate::Uuid;
/// Predicate over a set of stream ids, the shared [`ConditionContext`], and a timestamp.
///
/// Deadlines use it for both their start and their end condition. It is blanket-implemented
/// for every `Send + Sync` function or closure with the matching `Fn` signature.
pub trait CondFn: Fn(&[StreamId], &ConditionContext, &Timestamp) -> bool + Send + Sync {}
impl<F> CondFn for F where
    F: Fn(&[StreamId], &ConditionContext, &Timestamp) -> bool + Send + Sync
{
}
/// Computes a deadline's [`Duration`] from a state `S` and a timestamp.
///
/// It is blanket-implemented for every `Send + Sync` closure with the matching `FnMut`
/// signature.
pub trait DeadlineFn<S>: FnMut(&S, &Timestamp) -> Duration + Send + Sync {}
impl<S, F> DeadlineFn<S> for F where F: FnMut(&S, &Timestamp) -> Duration + Send + Sync {}
/// Deadline handler, called with a state `S` and a timestamp.
///
/// It is blanket-implemented for every `Send + Sync` closure with the matching `FnMut`
/// signature.
pub trait HandlerFn<S>: FnMut(&S, &Timestamp) + Send + Sync {}
impl<S, F> HandlerFn<S> for F where F: FnMut(&S, &Timestamp) + Send + Sync {}
/// Common interface of deadlines parameterized over a state type `S`.
pub trait DeadlineT<S>: Send + Sync {
    /// Returns this deadline's identifier.
    fn id(&self) -> DeadlineId;

    /// Returns the set of read stream ids this deadline is constrained on.
    fn get_constrained_read_stream_ids(&self) -> &HashSet<StreamId>;

    /// Returns the set of write stream ids this deadline is constrained on.
    fn get_constrained_write_stream_ids(&self) -> &HashSet<StreamId>;

    /// Evaluates the start condition over `read_stream_ids` for `timestamp`, consulting
    /// `condition_context`, and returns its result.
    fn invoke_start_condition(
        &self,
        read_stream_ids: &[StreamId],
        condition_context: &ConditionContext,
        timestamp: &Timestamp,
    ) -> bool;

    /// Returns a shared handle to the end-condition predicate so callers can evaluate it later.
    fn get_end_condition_fn(&self) -> Arc<dyn CondFn>;

    /// Computes the deadline duration for `timestamp` from the given `state`.
    fn calculate_deadline(&self, state: &S, timestamp: &Timestamp) -> Duration;

    /// Runs the deadline's handler with the given `state` and `timestamp`.
    fn invoke_handler(&self, state: &S, timestamp: &Timestamp);
}
/// A deadline whose start/end conditions and duration are evaluated per [`Timestamp`].
///
/// Construct it with `TimestampDeadline::new`, then refine it with the `on_read_stream`,
/// `on_write_stream`, `with_start_condition` and `with_end_condition` builders.
pub struct TimestampDeadline<S>
where
S: State,
{
// Predicate evaluated by `invoke_start_condition`; `new` sets it to `default_start_condition`.
start_condition_fn: Arc<dyn CondFn>,
// Predicate handed out by `get_end_condition_fn`; `new` sets it to `default_end_condition`.
end_condition_fn: Arc<dyn CondFn>,
// Computes the deadline duration. Wrapped in a `Mutex` because `DeadlineFn` is `FnMut`
// but is called through `&self` in `calculate_deadline`.
deadline_fn: Arc<Mutex<dyn DeadlineFn<S>>>,
// Handler run by `invoke_handler`; `Mutex`-wrapped for the same `FnMut`-through-`&self` reason.
handler_fn: Arc<Mutex<dyn HandlerFn<S>>>,
// Read streams this deadline is constrained on (added via `on_read_stream`).
read_stream_ids: HashSet<StreamId>,
// Write streams this deadline is constrained on (added via `on_write_stream`).
write_stream_ids: HashSet<StreamId>,
// Identifier generated with `DeadlineId::new_deterministic()` in `new`.
id: DeadlineId,
}
// NOTE(review): `dead_code` is allowed for this whole impl, presumably because some of these
// builders/accessors (e.g. `get_start_condition_fn`) are not yet called elsewhere in the
// crate — confirm, and narrow the allow to the specific items if possible.
#[allow(dead_code)]
impl<S> TimestampDeadline<S>
where
    S: State,
{
    /// Builds a deadline whose duration comes from `deadline_fn` and whose handler is
    /// `handler_fn`.
    ///
    /// The result starts with the default start/end conditions (`default_start_condition` and
    /// `default_end_condition`) and no constrained streams. Its id comes from
    /// `DeadlineId::new_deterministic()`.
    pub fn new(
        deadline_fn: impl DeadlineFn<S> + 'static,
        handler_fn: impl HandlerFn<S> + 'static,
    ) -> Self {
        let start_condition_fn: Arc<dyn CondFn> = Arc::new(Self::default_start_condition);
        let end_condition_fn: Arc<dyn CondFn> = Arc::new(Self::default_end_condition);
        let deadline_fn: Arc<Mutex<dyn DeadlineFn<S>>> = Arc::new(Mutex::new(deadline_fn));
        let handler_fn: Arc<Mutex<dyn HandlerFn<S>>> = Arc::new(Mutex::new(handler_fn));
        Self {
            start_condition_fn,
            end_condition_fn,
            deadline_fn,
            handler_fn,
            read_stream_ids: HashSet::new(),
            write_stream_ids: HashSet::new(),
            id: DeadlineId::new_deterministic(),
        }
    }

    /// Adds `read_stream_id` to the read streams this deadline is constrained on.
    pub fn on_read_stream(mut self, read_stream_id: StreamId) -> Self {
        self.read_stream_ids.insert(read_stream_id);
        self
    }

    /// Adds `write_stream_id` to the write streams this deadline is constrained on.
    pub fn on_write_stream(mut self, write_stream_id: StreamId) -> Self {
        self.write_stream_ids.insert(write_stream_id);
        self
    }

    /// Replaces the start condition with `condition`.
    pub fn with_start_condition(mut self, condition: impl 'static + CondFn) -> Self {
        self.start_condition_fn = Arc::new(condition);
        self
    }

    /// Replaces the end condition with `condition`.
    pub fn with_end_condition(mut self, condition: impl 'static + CondFn) -> Self {
        self.end_condition_fn = Arc::new(condition);
        self
    }

    /// Returns a new shared handle to the start-condition predicate.
    pub(crate) fn get_start_condition_fn(&self) -> Arc<dyn CondFn> {
        Arc::clone(&self.start_condition_fn)
    }

    /// Returns a new shared handle to the end-condition predicate.
    pub(crate) fn get_end_condition_fn(&self) -> Arc<dyn CondFn> {
        Arc::clone(&self.end_condition_fn)
    }

    /// Default start condition. It holds as soon as any stream in `stream_ids` has exactly
    /// one message recorded in `condition_context` for `current_timestamp`. Emits a debug log
    /// on every evaluation.
    fn default_start_condition(
        stream_ids: &[StreamId],
        condition_context: &ConditionContext,
        current_timestamp: &Timestamp,
    ) -> bool {
        tracing::debug!(
            "Executed default start condition for the streams {:?} and the \
             timestamp: {:?} with the context: {:?}",
            stream_ids,
            current_timestamp,
            condition_context,
        );
        // Short-circuits on the first stream whose count is exactly one.
        stream_ids.iter().any(|&stream_id| {
            condition_context.get_message_count(stream_id, current_timestamp.clone()) == 1
        })
    }

    /// Default end condition. It holds once every stream in `stream_ids` has a watermark
    /// recorded in `condition_context` for `current_timestamp`, and is vacuously `true` for an
    /// empty slice. Emits a debug log on every evaluation.
    fn default_end_condition(
        stream_ids: &[StreamId],
        condition_context: &ConditionContext,
        current_timestamp: &Timestamp,
    ) -> bool {
        tracing::debug!(
            "Executed default end condition for the streams {:?} and the \
             timestamp: {:?} with the context: {:?}",
            stream_ids,
            current_timestamp,
            condition_context,
        );
        // Short-circuits on the first stream still missing its watermark.
        stream_ids.iter().all(|&stream_id| {
            condition_context.get_watermark_status(stream_id, current_timestamp.clone())
        })
    }
}
/// `DeadlineT` implementation that forwards to the functions stored in a `TimestampDeadline`.
impl<S> DeadlineT<S> for TimestampDeadline<S>
where
    S: State,
{
    fn get_constrained_read_stream_ids(&self) -> &HashSet<StreamId> {
        &self.read_stream_ids
    }

    fn get_constrained_write_stream_ids(&self) -> &HashSet<StreamId> {
        &self.write_stream_ids
    }

    /// Calls the stored start-condition predicate with the given arguments.
    fn invoke_start_condition(
        &self,
        read_stream_ids: &[StreamId],
        condition_context: &ConditionContext,
        timestamp: &Timestamp,
    ) -> bool {
        let start_condition = &self.start_condition_fn;
        start_condition(read_stream_ids, condition_context, timestamp)
    }

    /// Locks the stored deadline function and calls it.
    ///
    /// # Panics
    /// Panics if the mutex guarding the deadline function is poisoned.
    fn calculate_deadline(&self, state: &S, timestamp: &Timestamp) -> Duration {
        let mut deadline_fn = self.deadline_fn.lock().unwrap();
        deadline_fn(state, timestamp)
    }

    fn get_end_condition_fn(&self) -> Arc<dyn CondFn> {
        Arc::clone(&self.end_condition_fn)
    }

    fn id(&self) -> DeadlineId {
        self.id
    }

    /// Locks the stored handler and calls it.
    ///
    /// # Panics
    /// Panics if the mutex guarding the handler is poisoned.
    fn invoke_handler(&self, state: &S, timestamp: &Timestamp) {
        let mut handler = self.handler_fn.lock().unwrap();
        handler(state, timestamp)
    }
}
/// Plain data record describing one deadline occurrence: the constrained read/write streams,
/// the timestamp, the computed duration, the end condition, and the deadline id.
// NOTE(review): presumably created once a deadline's start condition holds, with `duration`
// taken from `DeadlineT::calculate_deadline` — confirm against the callers.
pub struct DeadlineEvent {
pub read_stream_ids: HashSet<StreamId>,
pub write_stream_ids: HashSet<StreamId>,
pub timestamp: Timestamp,
pub duration: Duration,
// Shared end-condition predicate (same `Arc<dyn CondFn>` type as `DeadlineT::get_end_condition_fn`).
pub end_condition: Arc<dyn CondFn>,
pub id: DeadlineId,
}
impl DeadlineEvent {
pub fn new(
read_stream_ids: HashSet<StreamId>,
write_stream_ids: HashSet<StreamId>,
timestamp: Timestamp,
duration: Duration,
end_condition: Arc<dyn CondFn>,
id: DeadlineId,
) -> Self {
Self {
read_stream_ids,
write_stream_ids,
timestamp,
duration,
end_condition,
id,
}
}
}
/// Bookkeeping consulted by deadline conditions ([`CondFn`]). For each `(stream, timestamp)`
/// pair it records how many messages were seen and whether a watermark has arrived.
#[derive(Debug, Clone, Default)]
pub struct ConditionContext {
// Messages counted per `(stream, timestamp)` by `increment_msg_count`.
message_count: HashMap<(StreamId, Timestamp), usize>,
// Set to `true` per `(stream, timestamp)` by `notify_watermark_arrival`.
watermark_status: HashMap<(StreamId, Timestamp), bool>,
}
impl ConditionContext {
pub fn new() -> Self {
ConditionContext {
message_count: HashMap::new(),
watermark_status: HashMap::new(),
}
}
pub fn increment_msg_count(&mut self, stream_id: StreamId, timestamp: Timestamp) {
let count = self
.message_count
.entry((stream_id, timestamp))
.or_insert(0);
*count += 1;
}
pub fn notify_watermark_arrival(&mut self, stream_id: StreamId, timestamp: Timestamp) {
let watermark = self
.watermark_status
.entry((stream_id, timestamp))
.or_insert(false);
*watermark = true;
}
pub fn clear_state(&mut self, stream_id: StreamId, timestamp: Timestamp) {
self.message_count.remove(&(stream_id, timestamp.clone()));
self.watermark_status.remove(&(stream_id, timestamp));
}
pub fn get_message_count(&self, stream_id: StreamId, timestamp: Timestamp) -> usize {
match self.message_count.get(&(stream_id, timestamp)) {
Some(v) => *v,
None => 0,
}
}
pub fn get_watermark_status(&self, stream_id: StreamId, timestamp: Timestamp) -> bool {
match self.watermark_status.get(&(stream_id, timestamp)) {
Some(v) => *v,
None => false,
}
}
pub(crate) fn merge(&self, other: &ConditionContext) -> Self {
ConditionContext {
message_count: self
.message_count
.clone()
.into_iter()
.chain(other.message_count.clone())
.collect(),
watermark_status: self
.watermark_status
.clone()
.into_iter()
.chain(other.watermark_status.clone())
.collect(),
}
}
}