use std::any::Any;
use std::cell::RefCell;
use std::marker::PhantomData;
use std::rc::Rc;
use std::sync::Arc;
use crate::dsl::builder::InternalStreamsBuilder;
use crate::dsl::config::Materialized;
use crate::dsl::graph::{GraphNodeKind, LowerState, NodeId};
use crate::dsl::kgrouped::{KGroupedStream, RepartitionLowerFn, mint_store_name};
use crate::dsl::ktable::KTable;
use crate::dsl::names;
use crate::dsl::processors::aggregate::KStreamAggregateProcessor;
use crate::dsl::processors::change::Change;
use crate::dsl::processors::cogroup_merge::KStreamPassThrough;
use crate::dsl::processors::tuple_forwarder::TupleForwarder;
use crate::processor::serde::{DefaultSerde, Serde};
use crate::topology::NodeHandle;
/// Windowing mode shared by every input of a cogrouped aggregation.
///
/// Selects which aggregate processor `make_agg_for_input` installs for each
/// input: a plain key/value aggregate, or a time, sliding, or session
/// windowed one.
#[derive(Clone)]
pub(crate) enum CogroupKind {
    /// Unwindowed aggregation; processor output is keyed by `K`.
    NonWindowed,
    /// Time-window aggregation; processor output is keyed by `Windowed<K>`.
    Time(crate::dsl::windows::TimeWindows),
    /// Sliding-window aggregation; processor output is keyed by `Windowed<K>`.
    Sliding(crate::dsl::windows::SlidingWindows),
    /// Session-window aggregation; requires `CogroupSpec::merger` to be set.
    Session(crate::dsl::windows::SessionWindows),
}
/// Aggregation parameters shared by all inputs of one cogroup.
///
/// `lower_cogroup` hands a clone of this spec to each input's `make_agg`
/// builder. As a result, every per-input processor uses the same windowing
/// mode, initializer and (for sessions) merger.
#[allow(dead_code, clippy::type_complexity)]
pub(crate) struct CogroupSpec<K, VOut> {
    /// Windowing mode; selects the aggregate processor type.
    pub kind: CogroupKind,
    /// Produces the initial aggregate value (passed to each processor's `init`).
    pub init: Arc<dyn Fn() -> VOut + Send + Sync>,
    /// Combines two aggregates for the same key. Only the session-window
    /// lowering reads it, and that lowering panics when it is `None`.
    pub merger: Option<Arc<dyn Fn(&K, VOut, VOut) -> VOut + Send + Sync>>,
}
impl<K, VOut> Clone for CogroupSpec<K, VOut> {
fn clone(&self) -> Self {
Self {
kind: self.kind.clone(),
init: self.init.clone(),
merger: self.merger.clone(),
}
}
}
/// Deferred lowering step for one cogroup input.
///
/// It is invoked as `(state, parent_name, proc_name, store_name)`. It adds
/// the input's aggregate processor to `state.topology` and returns the new
/// processor's handle name.
type AggNodeThunk = Box<dyn FnOnce(&mut LowerState, String, String, String) -> String + Send>;
/// Type-erased builder for one input's aggregate processor.
///
/// Given the shared `CogroupSpec`, it yields the `AggNodeThunk` that
/// installs the processor. Because the input's value type is erased here,
/// inputs with different value types can share one
/// `Vec<CogroupInput<K, VOut>>`.
pub(crate) type MakeAggFn<K, VOut> = Box<dyn FnOnce(CogroupSpec<K, VOut>) -> AggNodeThunk + Send>;
/// One grouped stream participating in a cogroup, captured before lowering.
#[allow(dead_code)]
pub(crate) struct CogroupInput<K, VOut> {
    /// Graph node of the grouped stream feeding this input.
    pub parent: NodeId,
    /// Flag forwarded to `KGroupedStream::record_repartition` during
    /// lowering. Presumably it records whether an upstream operation changed
    /// the key — confirm in `kgrouped`.
    pub key_changing_upstream: bool,
    /// Optional repartition lowering, also forwarded to `record_repartition`.
    pub repartition_lower: Option<RepartitionLowerFn>,
    /// Builds this input's aggregate processor once the shared spec is known.
    pub make_agg: MakeAggFn<K, VOut>,
    /// Source topic of this input, if any. `lower_cogroup` registers the
    /// collected topics as a co-partition group when there are two or more.
    pub source_topic: Option<String>,
}
/// A set of grouped streams that are aggregated together into one table.
///
/// It is obtained by calling `cogroup` on a grouped stream. More inputs are
/// added with `cogroup`, and `aggregate` / `aggregate_explicit` turn the set
/// into a [`KTable`]. All inputs share the key type `K` and the aggregate
/// type `VOut`; their value types may differ.
pub struct CogroupedKStream<K, VOut> {
    pub(crate) builder: Rc<RefCell<InternalStreamsBuilder>>,
    /// Inputs accumulated so far, in the order they were added.
    pub(crate) inputs: Vec<CogroupInput<K, VOut>>,
    _pd: PhantomData<fn() -> (K, VOut)>,
}
impl<K, VOut> CogroupedKStream<K, VOut> {
    /// Wraps the shared builder together with the initial list of inputs.
    pub(crate) fn new(
        builder: Rc<RefCell<InternalStreamsBuilder>>,
        inputs: Vec<CogroupInput<K, VOut>>,
    ) -> Self {
        Self {
            _pd: PhantomData,
            inputs,
            builder,
        }
    }
}
/// Builds the type-erased aggregate-processor factory for one cogroup input.
///
/// `agg` folds a value of this input's type `Vn` into the shared aggregate
/// `VOut`. The returned [`MakeAggFn`] hides `Vn`, so inputs with different
/// value types can be stored side by side.
///
/// Calling the returned function with the shared [`CogroupSpec`] yields a
/// thunk. During lowering, the thunk adds one aggregate processor named
/// `proc_name` to `state.topology`, reading from `parent_name` and using
/// `store_name`, and returns the new processor's handle name. The processor
/// type follows `spec.kind`:
///
/// * `NonWindowed` → `KStreamAggregateProcessor`, output key `K`;
/// * `Time` → `KStreamWindowAggregateProcessor`, output key `Windowed<K>`;
/// * `Sliding` → `KStreamSlidingWindowAggregateProcessor`, output key `Windowed<K>`;
/// * `Session` → `KStreamSessionAggregateProcessor`, output key `Windowed<K>`.
///
/// Every variant emits `Change<VOut>` values. The windowed variants use
/// `EmitStrategy::default()`.
///
/// # Panics
///
/// The thunk panics with "session cogroup requires a merger" when
/// `spec.kind` is `Session` and `spec.merger` is `None`.
#[allow(clippy::too_many_lines)]
pub(crate) fn make_agg_for_input<K, Vn, VOut, A>(agg: A) -> MakeAggFn<K, VOut>
where
    K: Any + Send + Sync + Clone,
    Vn: Any + Send + Sync + Clone,
    VOut: Any + Send + Sync + Clone,
    A: Fn(&K, &Vn, VOut) -> VOut + Send + Sync + 'static,
{
    // Shared so each processor-factory closure below can hold its own handle.
    let agg = Arc::new(agg);
    Box::new(move |spec: CogroupSpec<K, VOut>| -> AggNodeThunk {
        Box::new(
            move |state: &mut LowerState,
                  parent_name: String,
                  proc_name: String,
                  store_name: String|
                  -> String {
                // Re-type the upstream handle as carrying this input's `(K, Vn)` records.
                let parent = NodeHandle::<K, Vn>::from_name(parent_name);
                let init = spec.init.clone();
                // In every arm, the factory closure handed to `add_processor`
                // re-wraps the shared `Arc`s in fresh closures each time it
                // runs. NOTE(review): presumably it is called once per
                // processor instance/task — confirm in `topology`.
                match spec.kind {
                    CogroupKind::NonWindowed => {
                        let agg = agg.clone();
                        let store = store_name.clone();
                        let h = state
                            .topology
                            .add_processor::<K, Vn, K, Change<VOut>, _, _, _>(
                                proc_name,
                                move || KStreamAggregateProcessor {
                                    store_name: store.clone(),
                                    init: {
                                        let i = init.clone();
                                        move || i()
                                    },
                                    agg: {
                                        let a = agg.clone();
                                        move |k: &K, v: &Vn, acc: VOut| a(k, v, acc)
                                    },
                                    forwarder: TupleForwarder::default(),
                                    _pd: PhantomData,
                                },
                                [parent],
                            );
                        h.name().to_string()
                    }
                    CogroupKind::Time(w) => {
                        use crate::dsl::processors::window_aggregate::KStreamWindowAggregateProcessor;
                        use crate::dsl::windows::Windowed;
                        let agg = agg.clone();
                        let store = store_name.clone();
                        let h = state
                            .topology
                            .add_processor::<K, Vn, Windowed<K>, Change<VOut>, _, _, _>(
                                proc_name,
                                move || KStreamWindowAggregateProcessor {
                                    store_name: store.clone(),
                                    windows: w,
                                    init: {
                                        let i = init.clone();
                                        move || i()
                                    },
                                    agg: {
                                        let a = agg.clone();
                                        move |k: &K, v: &Vn, acc: VOut| a(k, v, acc)
                                    },
                                    emit: crate::dsl::emit::EmitStrategy::default(),
                                    // Both watermarks start at the smallest possible timestamp.
                                    stream_time: i64::MIN,
                                    last_emitted_close: i64::MIN,
                                    forwarder: TupleForwarder::default(),
                                    _pd: PhantomData,
                                },
                                [parent],
                            );
                        h.name().to_string()
                    }
                    CogroupKind::Sliding(w) => {
                        use crate::dsl::processors::sliding_window_aggregate::KStreamSlidingWindowAggregateProcessor;
                        use crate::dsl::windows::Windowed;
                        let agg = agg.clone();
                        let store = store_name.clone();
                        let h = state
                            .topology
                            .add_processor::<K, Vn, Windowed<K>, Change<VOut>, _, _, _>(
                                proc_name,
                                move || KStreamSlidingWindowAggregateProcessor {
                                    store_name: store.clone(),
                                    windows: w,
                                    init: {
                                        let i = init.clone();
                                        move || i()
                                    },
                                    agg: {
                                        let a = agg.clone();
                                        move |k: &K, v: &Vn, acc: VOut| a(k, v, acc)
                                    },
                                    stream_time: i64::MIN,
                                    emit: crate::dsl::emit::EmitStrategy::default(),
                                    last_emitted_close: i64::MIN,
                                    forwarder: TupleForwarder::default(),
                                    _pd: PhantomData,
                                },
                                [parent],
                            );
                        h.name().to_string()
                    }
                    CogroupKind::Session(w) => {
                        use crate::dsl::processors::session_aggregate::KStreamSessionAggregateProcessor;
                        use crate::dsl::windows::Windowed;
                        let agg = agg.clone();
                        let store = store_name.clone();
                        // Sessions can be combined, so a merger is mandatory here.
                        let merger = spec
                            .merger
                            .clone()
                            .expect("session cogroup requires a merger");
                        let h = state
                            .topology
                            .add_processor::<K, Vn, Windowed<K>, Change<VOut>, _, _, _>(
                                proc_name,
                                move || KStreamSessionAggregateProcessor {
                                    store_name: store.clone(),
                                    gap_ms: w.gap_ms,
                                    init: {
                                        let i = init.clone();
                                        move || i()
                                    },
                                    agg: {
                                        let a = agg.clone();
                                        move |k: &K, v: &Vn, acc: VOut| a(k, v, acc)
                                    },
                                    merger: {
                                        let m = merger.clone();
                                        move |k: &K, a: VOut, b: VOut| m(k, a, b)
                                    },
                                    emit: crate::dsl::emit::EmitStrategy::default(),
                                    grace_ms: w.grace_ms,
                                    stream_time: i64::MIN,
                                    last_emitted_close: i64::MIN,
                                    forwarder: TupleForwarder::default(),
                                    _pd: PhantomData,
                                },
                                [parent],
                            );
                        h.name().to_string()
                    }
                }
            },
        )
    })
}
impl<K, VOut> CogroupedKStream<K, VOut>
where
    K: Any + Send + Sync + Clone,
    VOut: Any + Send + Sync + Clone,
{
    /// Adds another grouped stream to this cogroup.
    ///
    /// `agg` folds each record of `grouped` into the aggregate shared by all
    /// inputs. The new input must use the same key type `K`, but its value
    /// type `Vn` may differ from the other inputs'.
    #[must_use]
    pub fn cogroup<Vn, A>(mut self, grouped: KGroupedStream<K, Vn>, agg: A) -> Self
    where
        Vn: Any + Send + Sync + Clone,
        A: Fn(&K, &Vn, VOut) -> VOut + Send + Sync + 'static,
    {
        let (parent, key_changing_upstream, repartition_lower, source_topic) =
            grouped.into_cogroup_parts();
        let make_agg = make_agg_for_input::<K, Vn, VOut, A>(agg);
        self.inputs.push(CogroupInput {
            parent,
            key_changing_upstream,
            repartition_lower,
            make_agg,
            source_topic,
        });
        self
    }

    /// Finishes the cogroup as a non-windowed aggregation.
    ///
    /// The aggregate lives in a key/value store configured by `materialized`,
    /// and `init` supplies the starting aggregate.
    ///
    /// * If `materialized` enables logging, the store is registered with a
    ///   changelog; otherwise it is registered without one.
    /// * The store's caching flag is taken from `materialized` as well.
    pub fn aggregate_explicit<KS, VS, I>(
        self,
        init: I,
        materialized: impl Into<Materialized<KS, VS>>,
    ) -> KTable<K, VOut, KS, VS>
    where
        KS: Serde<K> + Clone + 'static,
        VS: Serde<VOut> + Clone + 'static,
        I: Fn() -> VOut + Send + Sync + 'static,
    {
        let materialized: Materialized<KS, VS> = materialized.into();
        let store_name = mint_store_name(&self.builder, &materialized, names::AGGREGATE_STORE);
        let Materialized {
            key_serde,
            value_serde,
            logging,
            caching,
            ..
        } = materialized;

        let spec = CogroupSpec::<K, VOut> {
            kind: CogroupKind::NonWindowed,
            init: Arc::new(init),
            merger: None,
        };

        // The store is registered lazily during lowering. By then, the names of
        // the per-input aggregate processors that use it are known.
        let reg_store = store_name.clone();
        let reg_keys = key_serde.clone();
        let reg_values = value_serde.clone();
        let registrar: StoreRegistrarFn = Box::new(move |state, processors| {
            if logging {
                state.topology.add_state_store::<K, VOut, KS, VS>(
                    reg_store.clone(),
                    reg_keys.clone(),
                    reg_values.clone(),
                    processors,
                );
            } else {
                // NOTE(review): this path does not receive `processors`; confirm
                // that `add_state_store_no_changelog` needs no explicit wiring.
                state
                    .topology
                    .add_state_store_no_changelog::<K, VOut, KS, VS>(
                        reg_store.clone(),
                        reg_keys.clone(),
                        reg_values.clone(),
                    );
            }
            state.topology.mark_store_caching(&reg_store, caching);
        });

        let suppress = crate::dsl::ktable::kv_suppress_factory::<K, VOut, KS, VS>(
            key_serde.clone(),
            value_serde.clone(),
        );
        let merge_id = lower_cogroup::<K, VOut, K>(
            &self.builder,
            self.inputs,
            store_name.clone(),
            spec,
            logging,
            registrar,
        );
        let table = KTable::new(
            Rc::clone(&self.builder),
            merge_id,
            Some(store_name),
            None,
            key_serde,
            value_serde,
        );
        table.with_suppress_factory(Some(suppress))
    }

    /// Finishes the cogroup using the default serdes for `K` and `VOut`.
    ///
    /// The aggregate is stored in a store named `store_name`. This is
    /// equivalent to calling [`CogroupedKStream::aggregate_explicit`] with
    /// `Materialized::with(default key serde, default value serde)
    /// .as_store(store_name)`.
    pub fn aggregate<I>(
        self,
        init: I,
        store_name: impl Into<String>,
    ) -> KTable<K, VOut, <K as DefaultSerde>::Serde, <VOut as DefaultSerde>::Serde>
    where
        K: DefaultSerde,
        VOut: DefaultSerde,
        <K as DefaultSerde>::Serde: Serde<K> + Clone,
        <VOut as DefaultSerde>::Serde: Serde<VOut> + Clone,
        I: Fn() -> VOut + Send + Sync + 'static,
    {
        let key_serde = <K as DefaultSerde>::Serde::default();
        let value_serde = <VOut as DefaultSerde>::Serde::default();
        let materialized = Materialized::with(key_serde, value_serde).as_store(store_name);
        self.aggregate_explicit(init, materialized)
    }
}
/// Deferred store registration, run while lowering the cogroup's merge node.
///
/// It receives the lowering state and the handle names of the per-input
/// aggregate processors that were built with the store's name.
pub(crate) type StoreRegistrarFn = Box<dyn FnOnce(&mut LowerState, Vec<String>) + Send>;
/// Lowers a cogroup into the logical graph and returns the merge node's id.
///
/// For each input, this function:
///
/// * records any repartition it needs via `KGroupedStream::record_repartition`;
/// * adds a per-input `Aggregate` node (with `changelog: false`). The node's
///   deferred lowering runs the input's `make_agg` thunk against `store_name`.
///
/// It then adds a merge `Aggregate` node (with `changelog: logging`) that
/// joins all per-input nodes through a pass-through processor. That node's
/// lowering:
///
/// * runs `registrar` with the per-input processor names;
/// * registers the inputs' source topics as a co-partition group when two or
///   more inputs reported one.
///
/// # Panics
///
/// Panics if `builder` is already mutably borrowed. The deferred lowering
/// closures index `state.handle_name`, so they panic if a parent node has
/// not been lowered first.
#[allow(clippy::too_many_lines, clippy::needless_pass_by_value)]
pub(crate) fn lower_cogroup<K, VOut, KOut>(
    builder: &Rc<RefCell<InternalStreamsBuilder>>,
    inputs: Vec<CogroupInput<K, VOut>>,
    store_name: String,
    spec: CogroupSpec<K, VOut>,
    logging: bool,
    registrar: StoreRegistrarFn,
) -> NodeId
where
    K: Any + Send + Sync + Clone,
    VOut: Any + Send + Sync + Clone,
    KOut: Any + Send + Clone,
{
    // Gathered up front because the loop below consumes `inputs`.
    let copartition_sources: Vec<String> = inputs
        .iter()
        .filter_map(|input| input.source_topic.clone())
        .collect();
    let mut agg_ids: Vec<NodeId> = Vec::with_capacity(inputs.len());
    let mut graph_builder = builder.borrow_mut();

    for CogroupInput {
        parent,
        key_changing_upstream,
        repartition_lower,
        make_agg,
        ..
    } in inputs
    {
        let agg_parent = KGroupedStream::<K, ()>::record_repartition(
            &mut graph_builder,
            &store_name,
            parent,
            key_changing_upstream,
            repartition_lower,
        );
        let proc_name = graph_builder.new_processor_name(names::COGROUP_AGGREGATE);
        let agg_id = graph_builder.graph.add(
            proc_name.clone(),
            GraphNodeKind::Aggregate {
                store_name: store_name.clone(),
                changelog: false,
            },
            vec![agg_parent],
        );
        let install_agg = make_agg(spec.clone());
        let node_store = store_name.clone();
        graph_builder.graph.nodes[agg_id].lower = Some(Box::new(move |state: &mut LowerState| {
            let upstream = state.handle_name[&agg_parent].clone();
            let handle = install_agg(state, upstream, proc_name, node_store);
            state.handle_name.insert(agg_id, handle);
        }));
        agg_ids.push(agg_id);
    }

    let merge_name = graph_builder.new_processor_name(names::COGROUP_MERGE);
    let merge_id = graph_builder.graph.add(
        merge_name.clone(),
        GraphNodeKind::Aggregate {
            store_name,
            changelog: logging,
        },
        agg_ids.clone(),
    );
    graph_builder.graph.nodes[merge_id].lower = Some(Box::new(move |state: &mut LowerState| {
        // Handle names of the per-input aggregate processors, in input order.
        // They serve both as the merge node's parents and as the store's users.
        let agg_proc_names: Vec<String> = agg_ids
            .iter()
            .map(|id| state.handle_name[id].clone())
            .collect();
        let parents: Vec<NodeHandle<KOut, Change<VOut>>> = agg_proc_names
            .iter()
            .map(|name| NodeHandle::<KOut, Change<VOut>>::from_name(name.clone()))
            .collect();
        let merged = state
            .topology
            .add_processor::<KOut, Change<VOut>, KOut, Change<VOut>, _, _, _>(
                merge_name.clone(),
                || KStreamPassThrough::<KOut, Change<VOut>> { _pd: PhantomData },
                parents,
            );
        registrar(state, agg_proc_names);
        if copartition_sources.len() >= 2 {
            state
                .topology
                .add_copartition_group(copartition_sources.clone());
        }
        state.handle_name.insert(merge_id, merged.name().to_string());
    }));

    drop(graph_builder);
    merge_id
}
#[cfg(test)]
mod tests {
    use assert2::check;
    use crate::dsl::StreamsBuilder;

    /// With a non-zero cache budget and the default `aggregate` settings, the
    /// cogroup's store must appear in the instantiated graph's `cache_owner`.
    #[test]
    fn cogroup_store_is_cached_with_positive_budget() {
        let builder = StreamsBuilder::new();
        let first = builder.stream::<String, String>(["in1"]).group_by_key();
        let second = builder.stream::<String, String>(["in2"]).group_by_key();
        first
            .cogroup::<i64, _>(|_k, v: &String, acc| {
                acc + i64::try_from(v.len()).unwrap_or(i64::MAX)
            })
            .cogroup(second, |_k, _v: &String, acc| acc + 1)
            .aggregate(|| 0i64, "co");

        let built = builder.build("app").unwrap();
        let cache_budget = 10_485_760;
        let graph = pollster::block_on(built.instantiate(
            &crate::store::backend::StoreBackend::InMemory,
            "app",
            cache_budget,
        ))
        .unwrap();
        check!(
            graph.cache_owner.contains_key("co"),
            "cogroup store must be cached when budget > 0 and caching enabled, \
             cache_owner = {:?}",
            graph.cache_owner
        );
    }
}
// Caching behaviour of a cogroup's backing store, end to end.
#[cfg(test)]
mod cogroup_caching_tests {
    use assert2::check;
    use crate::dsl::StreamsBuilder;
    use crate::store::backend::StoreBackend;
    use crate::{I64Serde, Materialized, Produced, StringSerde};
    // With caching enabled, updates for the same key from both inputs are
    // held back and collapse into one downstream record on flush.
    #[test]
    fn cogroup_caches_marks_and_dedups_cross_input() {
        let b = StreamsBuilder::new();
        let g1 = b.stream::<String, String>(["in1"]).group_by_key();
        let g2 = b.stream::<String, String>(["in2"]).group_by_key();
        // in1 contributes the value length; in2 counts records.
        g1.cogroup::<i64, _>(|_k, v: &String, acc| {
            acc + i64::try_from(v.len()).unwrap_or(i64::MAX)
        })
        .cogroup(g2, |_k, _v: &String, acc| acc + 1)
        .aggregate_explicit(
            || 0i64,
            Materialized::with(StringSerde, I64Serde).as_store("cg"),
        )
        .to_stream()
        .to_explicit("out", Produced::with(StringSerde, I64Serde));
        let built = b.build("app").unwrap();
        let mut g =
            pollster::block_on(built.instantiate(&StoreBackend::InMemory, "app", 1024)).unwrap();
        check!(g.cache_owner.contains_key("cg"));
        pollster::block_on(g.init_processors()).unwrap();
        // Key "a": "xx" on in1 adds 2, then "z" on in2 adds 1.
        pollster::block_on(g.pipe("in1", Some(b"a"), b"xx", 0)).unwrap();
        pollster::block_on(g.pipe("in2", Some(b"a"), b"z", 1)).unwrap();
        // Nothing is forwarded to "out" until the cache is flushed.
        check!(g.take_output().is_empty());
        pollster::block_on(g.flush_caches()).unwrap();
        let out = g.take_output();
        // Only the final aggregate 0 + 2 + 1 = 3 is emitted (compared as big-endian bytes).
        check!(out.len() == 1);
        check!(out[0].topic == "out");
        check!(out[0].value.as_ref().unwrap().as_ref() == 3i64.to_be_bytes());
    }
    // Turning caching off on the `Materialized` keeps the store out of
    // `cache_owner`, even with a non-zero cache budget.
    #[test]
    fn cogroup_uncached_when_caching_off() {
        let b = StreamsBuilder::new();
        let g1 = b.stream::<String, String>(["in1"]).group_by_key();
        let g2 = b.stream::<String, String>(["in2"]).group_by_key();
        g1.cogroup::<i64, _>(|_k, v: &String, acc| {
            acc + i64::try_from(v.len()).unwrap_or(i64::MAX)
        })
        .cogroup(g2, |_k, _v: &String, acc| acc + 1)
        .aggregate_explicit(
            || 0i64,
            Materialized::with(StringSerde, I64Serde)
                .as_store("cg")
                .with_caching(false),
        )
        .to_stream()
        .to_explicit("out", Produced::with(StringSerde, I64Serde));
        let built = b.build("app").unwrap();
        let g =
            pollster::block_on(built.instantiate(&StoreBackend::InMemory, "app", 1024)).unwrap();
        check!(!g.cache_owner.contains_key("cg"));
    }
}