use std::collections::{HashMap, HashSet, VecDeque};
use crate::processor::erased::{OutputRecord, ProcessorError};
use crate::processor::graph::Graph;
use crate::processor::serde::{Consumed, Produced, Serde, SerdeAssociate};
use crate::topology::BuiltTopology;
/// A record waiting to be piped into the graph, laid out as
/// `(topic, optional key bytes, value bytes, timestamp)`.
type PendingRecord = (String, Option<Vec<u8>>, Vec<u8>, i64);
/// Drives an instantiated topology graph synchronously: records are piped in,
/// every async graph call is blocked on via `pollster`, and emitted records
/// are collected per topic for later inspection.
pub struct TopologyTestDriver {
/// The instantiated processor graph that records are piped through.
graph: Graph,
/// Topics reported by `BuiltTopology::list_source_topics`; output written to
/// one of these is piped back into the graph instead of being collected.
source_topics: HashSet<String>,
/// Collected output records, queued per topic in emission order.
output: HashMap<String, VecDeque<OutputRecord>>,
/// Mock wall-clock time in milliseconds; starts at 0 and only moves when
/// `advance_wall_clock_time` is called.
mock_wall_ms: i64,
}
impl TopologyTestDriver {
/// Creates a driver for `built`.
///
/// The graph and its global stores are instantiated against the in-memory
/// store backend under the fixed name `"app"` (presumably the application
/// id), after which the processors are initialised. The output buffers start
/// empty and the mock wall clock starts at 0.
///
/// # Errors
///
/// Returns the [`ProcessorError`] produced by graph instantiation or by
/// processor initialisation.
pub fn new(built: &BuiltTopology) -> Result<Self, ProcessorError> {
    let app_id = "app";

    let source_topics: HashSet<String> = built.list_source_topics().into_iter().collect();
    let backend = crate::store::backend::StoreBackend::InMemory;

    let mut graph = pollster::block_on(built.instantiate(&backend, app_id))?;

    // Global stores are built separately and attached before processors init.
    let globals_future = crate::runtime::global::GlobalStateManager::build(
        built.global_store_factories(),
        built.global_store_topics(),
        &backend,
        app_id,
    );
    graph.globals = pollster::block_on(globals_future);

    pollster::block_on(graph.init_processors())?;

    let output = HashMap::new();
    Ok(Self {
        graph,
        source_topics,
        output,
        mock_wall_ms: 0,
    })
}
/// Serializes one typed record and pipes it into `topic`.
///
/// The key (when present) and the value are turned into bytes with the serdes
/// carried by `consumed`, then handed to `pipe_bytes` together with
/// `timestamp`.
#[allow(clippy::needless_pass_by_value)]
pub fn pipe_input<KS, VS>(
    &mut self,
    topic: &str,
    consumed: impl Into<Consumed<KS, VS>>,
    key: Option<KS::Target>,
    value: VS::Target,
    timestamp: i64,
) where
    KS: SerdeAssociate + Serde<KS::Target>,
    VS: SerdeAssociate + Serde<VS::Target>,
{
    let serdes: Consumed<KS, VS> = consumed.into();

    // Key first, then value — the same order in which they are passed on.
    let key_bytes = key
        .as_ref()
        .map(|typed_key| serdes.key_serde.serialize(topic, typed_key));
    let value_bytes = serdes.value_serde.serialize(topic, &value);

    let key_slice = key_bytes.as_deref();
    self.pipe_bytes(topic, key_slice, &value_bytes, timestamp);
}
/// Pipes one raw record into the graph and keeps going until no looped-back
/// records remain.
///
/// For every pending record: pipe it, route the graph's output, run
/// stream-time punctuation at the graph's current `stream_time`, and route
/// the output again. Output written to a source topic is appended to the same
/// pending queue by `route_outputs`. Results of `pipe` and
/// `punctuate_stream_time` are discarded.
fn pipe_bytes(&mut self, topic: &str, key: Option<&[u8]>, value: &[u8], timestamp: i64) {
    let mut pending: VecDeque<PendingRecord> = VecDeque::new();
    pending.push_back((
        topic.to_string(),
        key.map(<[u8]>::to_vec),
        value.to_vec(),
        timestamp,
    ));

    while let Some((rec_topic, rec_key, rec_value, rec_ts)) = pending.pop_front() {
        let piped = self
            .graph
            .pipe(&rec_topic, rec_key.as_deref(), &rec_value, rec_ts);
        let _ = pollster::block_on(piped);
        self.route_outputs(&mut pending);

        // Read the stream time only after the record has been processed.
        let stream_time = self.graph.stream_time;
        let _ = pollster::block_on(self.graph.punctuate_stream_time(stream_time));
        self.route_outputs(&mut pending);
    }
}
/// Moves the mock wall clock forward by `by` and runs wall-clock punctuation.
///
/// `by` is taken at millisecond granularity. Both the conversion to `i64` and
/// the addition saturate at `i64::MAX`, so an oversized or repeated advance
/// pins the clock at the maximum instead of overflowing.
///
/// After punctuating, the graph's output is routed; any records written back
/// to a source topic are piped through the graph (followed by stream-time
/// punctuation) until none remain. Results of the graph calls are discarded.
#[allow(clippy::needless_pass_by_value)]
pub fn advance_wall_clock_time(&mut self, by: std::time::Duration) {
    let delta_ms = i64::try_from(by.as_millis()).unwrap_or(i64::MAX);
    // Saturate: a plain `+=` would overflow once the clock is already > 0
    // and `delta_ms` has been clamped to `i64::MAX`.
    self.mock_wall_ms = self.mock_wall_ms.saturating_add(delta_ms);

    let mut queue: VecDeque<PendingRecord> = VecDeque::new();
    let _ = pollster::block_on(self.graph.punctuate_wall_clock(self.mock_wall_ms));
    self.route_outputs(&mut queue);

    while let Some((t, k, v, ts)) = queue.pop_front() {
        let _ = pollster::block_on(self.graph.pipe(&t, k.as_deref(), &v, ts));
        self.route_outputs(&mut queue);
        let _ = pollster::block_on(self.graph.punctuate_stream_time(self.graph.stream_time));
        self.route_outputs(&mut queue);
    }
}
/// Takes everything the graph has emitted and sorts it into two places.
///
/// Records addressed to one of the topology's source topics are appended to
/// `queue` so the caller pipes them back into the graph; all other records
/// are stored in the per-topic output buffers for `read_output`.
fn route_outputs(&mut self, queue: &mut VecDeque<PendingRecord>) {
    for out in self.graph.take_output() {
        if self.source_topics.contains(&out.topic) {
            // `out` is owned, so its fields are moved instead of cloned.
            // A record without a value is re-queued with the default byte
            // buffer because `PendingRecord` has no optional value slot.
            let value = out.value.unwrap_or_default().to_vec();
            let key = out.key.map(|bytes| bytes.to_vec());
            queue.push_back((out.topic, key, value, out.timestamp));
        } else {
            self.output
                .entry(out.topic.clone())
                .or_default()
                .push_back(out);
        }
    }
    // The drained changelog data is not used by the driver; the result is
    // dropped on purpose (presumably to keep it from accumulating — TODO confirm).
    let _ = self
        .graph
        .drain_changelogs(&std::collections::HashSet::new());
}
/// Looks up the key-value store called `name` in the graph's stores, typed as
/// `KeyValueStore<K, V>`.
///
/// Returns whatever `get_kv` yields for that name and type pair; `None`
/// presumably means the store is missing or has different types — confirm
/// against `get_kv`.
pub fn get_key_value_store<K: Send + Sync + 'static, V: Send + 'static>(
&mut self,
name: &str,
) -> Option<&mut dyn crate::store::api::KeyValueStore<K, V>> {
self.graph.stores.get_kv::<K, V>(name)
}
/// Reads the value stored under `key` in the key-value store named `store`.
///
/// Returns `None` when the store lookup yields nothing or when the store has
/// no value for `key`. The store's async `get` is blocked on.
pub fn store_get<K: Send + Sync + 'static, V: Send + 'static>(
    &mut self,
    store: &str,
    key: &K,
) -> Option<V> {
    self.graph
        .stores
        .get_kv::<K, V>(store)
        .and_then(|kv| pollster::block_on(kv.get(key)))
}
/// Puts `key`/`value` into the global store named `store_name` by calling
/// `put` on the graph's global state manager and blocking until it completes.
#[allow(clippy::needless_pass_by_value)] pub fn pipe_global<K, V>(&mut self, store_name: &str, key: K, value: V)
where
K: Send + Sync + 'static,
V: Send + 'static,
{
pollster::block_on(self.graph.globals.put(store_name, key, value));
}
/// Point lookup through the interactive-query handle of `store`.
///
/// `key` is serialized with `ks`; the raw value, if any, is deserialized with
/// `vs`. Returns `None` when there is no query handle for `store` or no value
/// for the key.
///
/// # Panics
///
/// Panics with `"iq deserialize"` if the stored value cannot be deserialized.
pub async fn iq_kv_get<K: 'static, V: 'static>(
    &self,
    store: &str,
    key: &K,
    ks: &dyn Serde<K>,
    vs: &dyn Serde<V>,
) -> Option<V> {
    let query = self.graph.stores.iq_get(store)?;
    let key_bytes = ks.serialize(store, key);
    query
        .iq_kv_get(&key_bytes)
        .await
        .map(|raw| vs.deserialize(store, &raw).expect("iq deserialize"))
}
/// Range query through the interactive-query handle of `store`.
///
/// `lo` and `hi` are serialized with `ks` and passed to `iq_kv_range`; every
/// returned entry is deserialized back into `(K, V)`. Yields an empty vector
/// when there is no query handle for `store`.
///
/// # Panics
///
/// Panics with `"iq deserialize"` if a returned key or value cannot be
/// deserialized.
pub async fn iq_kv_range<K: 'static, V: 'static>(
    &self,
    store: &str,
    lo: &K,
    hi: &K,
    ks: &dyn Serde<K>,
    vs: &dyn Serde<V>,
) -> Vec<(K, V)> {
    let Some(query) = self.graph.stores.iq_get(store) else {
        return Vec::new();
    };
    let lower = ks.serialize(store, lo);
    let upper = ks.serialize(store, hi);

    let mut pairs = Vec::new();
    for (raw_key, raw_value) in query.iq_kv_range(&lower, &upper).await {
        let typed_key = ks.deserialize(store, &raw_key).expect("iq deserialize");
        let typed_value = vs.deserialize(store, &raw_value).expect("iq deserialize");
        pairs.push((typed_key, typed_value));
    }
    pairs
}
/// Returns every entry reported by the interactive-query handle of `store`,
/// deserialized into `(K, V)` pairs.
///
/// Yields an empty vector when there is no query handle for `store`.
///
/// # Panics
///
/// Panics with `"iq deserialize"` if a returned key or value cannot be
/// deserialized.
pub async fn iq_kv_all<K: 'static, V: 'static>(
    &self,
    store: &str,
    ks: &dyn Serde<K>,
    vs: &dyn Serde<V>,
) -> Vec<(K, V)> {
    let Some(query) = self.graph.stores.iq_get(store) else {
        return Vec::new();
    };

    let mut entries = Vec::new();
    for (raw_key, raw_value) in query.iq_kv_all().await {
        let typed_key = ks.deserialize(store, &raw_key).expect("iq deserialize");
        let typed_value = vs.deserialize(store, &raw_value).expect("iq deserialize");
        entries.push((typed_key, typed_value));
    }
    entries
}
/// Returns the approximate entry count reported by the interactive-query
/// handle of `store`, or 0 when there is no handle for that name.
pub async fn iq_kv_count(&self, store: &str) -> u64 {
    let Some(query) = self.graph.stores.iq_get(store) else {
        return 0;
    };
    query.iq_kv_approx_count().await
}
/// Fetches windowed values for `key` from `store`, passing `from` and `to`
/// through to the query handle's `iq_window_fetch`.
///
/// Each result is returned as `(window start, value)` with the value
/// deserialized by `vs`. Yields an empty vector when there is no query handle
/// for `store`.
///
/// # Panics
///
/// Panics with `"iq deserialize"` if a returned value cannot be deserialized.
pub async fn iq_window_fetch<K: 'static, V: 'static>(
    &self,
    store: &str,
    key: &K,
    from: i64,
    to: i64,
    ks: &dyn Serde<K>,
    vs: &dyn Serde<V>,
) -> Vec<(i64, V)> {
    let Some(query) = self.graph.stores.iq_get(store) else {
        return Vec::new();
    };
    let key_bytes = ks.serialize(store, key);

    let mut windows = Vec::new();
    for (window_start, raw_value) in query.iq_window_fetch(&key_bytes, from, to).await {
        let typed_value = vs.deserialize(store, &raw_value).expect("iq deserialize");
        windows.push((window_start, typed_value));
    }
    windows
}
/// Fetches the value for `key` in the window starting at `window_start` from
/// `store`.
///
/// Returns `None` when there is no query handle for `store` or the handle
/// reports no value for that key and window.
///
/// # Panics
///
/// Panics with `"iq deserialize"` if the stored value cannot be deserialized.
pub async fn iq_window_fetch_single<K: 'static, V: 'static>(
    &self,
    store: &str,
    key: &K,
    window_start: i64,
    ks: &dyn Serde<K>,
    vs: &dyn Serde<V>,
) -> Option<V> {
    let query = self.graph.stores.iq_get(store)?;
    let key_bytes = ks.serialize(store, key);
    query
        .iq_window_fetch_single(&key_bytes, window_start)
        .await
        .map(|raw| vs.deserialize(store, &raw).expect("iq deserialize"))
}
/// Fetches the session entries for `key` from `store`.
///
/// Each result pairs the `(i64, i64)` window reported by the query handle
/// with the value deserialized by `vs`. Yields an empty vector when there is
/// no query handle for `store`.
///
/// # Panics
///
/// Panics with `"iq deserialize"` if a returned value cannot be deserialized.
pub async fn iq_session_fetch<K: 'static, V: 'static>(
    &self,
    store: &str,
    key: &K,
    ks: &dyn Serde<K>,
    vs: &dyn Serde<V>,
) -> Vec<((i64, i64), V)> {
    let Some(query) = self.graph.stores.iq_get(store) else {
        return Vec::new();
    };
    let key_bytes = ks.serialize(store, key);

    let mut sessions = Vec::new();
    for (window, raw_value) in query.iq_session_fetch_key(&key_bytes).await {
        let typed_value = vs.deserialize(store, &raw_value).expect("iq deserialize");
        sessions.push((window, typed_value));
    }
    sessions
}
/// Pops the oldest collected record for `topic` and deserializes it with the
/// serdes carried by `produced`.
///
/// Returns `None` when nothing has been collected for `topic` or its queue is
/// empty. A record without a value is deserialized from the default byte
/// buffer.
///
/// # Panics
///
/// Panics if the key or the value cannot be deserialized.
#[allow(clippy::needless_pass_by_value)]
pub fn read_output<KS, VS>(
    &mut self,
    topic: &str,
    produced: impl Into<Produced<KS, VS>>,
) -> Option<(Option<KS::Target>, VS::Target)>
where
    KS: SerdeAssociate + Serde<KS::Target>,
    VS: SerdeAssociate + Serde<VS::Target>,
{
    let serdes: Produced<KS, VS> = produced.into();

    let topic_queue = self.output.get_mut(topic)?;
    let record = topic_queue.pop_front()?;

    let typed_key = record.key.map(|raw_key| {
        serdes
            .key_serde
            .deserialize(topic, &raw_key)
            .expect("test: deserialize output key")
    });

    let raw_value = record.value.unwrap_or_default();
    let typed_value = serdes
        .value_serde
        .deserialize(topic, &raw_value)
        .expect("test: deserialize output value");

    Some((typed_key, typed_value))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::processor::api::{Processor, ProcessorContext};
use crate::processor::record::Record;
use crate::processor::serde::StringSerde;
use crate::topology::{NodeHandle, Topology};
use assert2::check;
use async_trait::async_trait;
/// Test processor that forwards each record with its value upper-cased; key
/// and timestamp are passed through.
struct Upper;
#[async_trait]
impl Processor<String, String, String, String> for Upper {
    async fn process(
        &mut self,
        ctx: &mut ProcessorContext<'_, '_, String, String>,
        record: Record<String, String>,
    ) {
        let shouted = record.value.to_uppercase();
        let forwarded = Record::new(record.key, shouted, record.timestamp);
        ctx.forward(forwarded);
    }
}
/// Test processor that forwards records unchanged unless their value is the
/// empty string, in which case the record is dropped.
struct DropEmpty;
#[async_trait]
impl Processor<String, String, String, String> for DropEmpty {
    async fn process(
        &mut self,
        ctx: &mut ProcessorContext<'_, '_, String, String>,
        record: Record<String, String>,
    ) {
        if record.value.is_empty() {
            // Nothing is forwarded for empty values.
            return;
        }
        ctx.forward(record);
    }
}
/// Test processor that forwards every record exactly as it was received.
struct Identity;
#[async_trait]
impl Processor<String, String, String, String> for Identity {
async fn process(
&mut self,
ctx: &mut ProcessorContext<'_, '_, String, String>,
r: Record<String, String>,
) {
// Pass the record downstream untouched.
ctx.forward(r);
}
}
/// Builds the topology `"in"` -> `Upper` -> `DropEmpty` -> `"out"` used by
/// the map/filter test.
fn map_filter() -> crate::topology::BuiltTopology {
    let mut topology = Topology::new();

    let source: NodeHandle<String, String> = topology.add_source("src", ["in"]);
    let upper = topology.add_processor("up", || Upper, [&source]);
    let non_empty = topology.add_processor("flt", || DropEmpty, [&upper]);
    topology.add_sink("out", "out", [&non_empty]);

    let built = topology.build("app");
    built.unwrap()
}
/// A non-empty value arrives upper-cased on "out"; an empty value produces
/// no output at all.
#[test]
fn map_filter_through() {
    let topology = map_filter();
    let mut driver = TopologyTestDriver::new(&topology).unwrap();

    // Non-empty value: mapped and delivered.
    driver.pipe_input(
        "in",
        Consumed::with(StringSerde, StringSerde),
        Some("k".to_string()),
        "hello".to_string(),
        0,
    );
    let first = driver.read_output("out", Produced::with(StringSerde, StringSerde));
    check!(first == Some((Some("k".to_string()), "HELLO".to_string())));

    // Empty value: filtered out, so nothing is waiting on "out".
    driver.pipe_input(
        "in",
        Consumed::with(StringSerde, StringSerde),
        Some("k2".to_string()),
        String::new(),
        1,
    );
    let second = driver.read_output("out", Produced::with(StringSerde, StringSerde));
    check!(second == None::<(Option<String>, String)>);
}
/// A record written to the repartition topic "rp" is fed back into the graph
/// and reaches "out" through the second sub-topology.
#[test]
fn repartition_loops_through() {
    let mut topology = Topology::new();
    topology.add_repartition_topic("rp");

    // First leg: "in" -> Identity -> "rp".
    let first_source: NodeHandle<String, String> = topology.add_source("s1", ["in"]);
    let identity = topology.add_processor("id", || Identity, [&first_source]);
    topology.add_sink("to_rp", "rp", [&identity]);

    // Second leg: "rp" -> Upper -> "out".
    let second_source: NodeHandle<String, String> = topology.add_source("s2", ["rp"]);
    let upper = topology.add_processor("up", || Upper, [&second_source]);
    topology.add_sink("out", "out", [&upper]);

    let built = topology.build("app").unwrap();
    let mut driver = TopologyTestDriver::new(&built).unwrap();

    driver.pipe_input(
        "in",
        Consumed::with(StringSerde, StringSerde),
        None,
        "hi".to_string(),
        0,
    );
    let delivered = driver.read_output("out", Produced::with(StringSerde, StringSerde));
    check!(delivered == Some((None, "HI".to_string())));
}
/// One processor feeding two sinks delivers the same record to both output
/// topics.
#[test]
fn branch_to_two_sinks() {
    let mut topology = Topology::new();
    let source: NodeHandle<String, String> = topology.add_source("src", ["in"]);
    let upper = topology.add_processor("up", || Upper, [&source]);
    topology.add_sink("a", "out-a", [&upper]);
    topology.add_sink("b", "out-b", [&upper]);

    let built = topology.build("app").unwrap();
    let mut driver = TopologyTestDriver::new(&built).unwrap();

    driver.pipe_input(
        "in",
        Consumed::with(StringSerde, StringSerde),
        None,
        "x".to_string(),
        0,
    );

    let on_a = driver.read_output("out-a", Produced::with(StringSerde, StringSerde));
    check!(on_a == Some((None, "X".to_string())));

    let on_b = driver.read_output("out-b", Produced::with(StringSerde, StringSerde));
    check!(on_b == Some((None, "X".to_string())));
}
/// A processor backed by a state store emits running counts, and the final
/// count is readable from the store through the driver.
#[test]
fn stateful_count_and_store_inspection() {
    use crate::processor::serde::I64Serde;

    /// Counts how often each value has been seen (in the "counts" store) and
    /// forwards `(value, running count)`.
    struct Counter;
    #[async_trait]
    impl Processor<String, String, String, i64> for Counter {
        async fn process(
            &mut self,
            ctx: &mut ProcessorContext<'_, '_, String, i64>,
            record: Record<String, String>,
        ) {
            // The store handle is confined to this block so that `ctx` is
            // usable again for `forward`.
            let count = {
                let store = ctx.get_state_store::<String, i64>("counts").unwrap();
                let previous = store.get(&record.value).await.unwrap_or(0);
                let updated = previous + 1;
                store.put(record.value.clone(), updated).await;
                updated
            };
            ctx.forward(Record::new(Some(record.value), count, record.timestamp));
        }
    }

    let mut topology = Topology::new();
    let source: NodeHandle<String, String> = topology.add_source("src", ["in"]);
    let counter = topology.add_processor("c", || Counter, [&source]);
    topology.add_state_store("counts", StringSerde, I64Serde, [counter.name()]);
    topology.add_sink("out", "out", [&counter]);

    let built = topology.build("app").unwrap();
    let mut driver = TopologyTestDriver::new(&built).unwrap();

    // The same value twice, at timestamps 0 and 1.
    driver.pipe_input(
        "in",
        Consumed::with(StringSerde, StringSerde),
        None,
        "a".to_string(),
        0,
    );
    driver.pipe_input(
        "in",
        Consumed::with(StringSerde, StringSerde),
        None,
        "a".to_string(),
        1,
    );

    let first = driver.read_output("out", Produced::with(StringSerde, I64Serde));
    check!(first == Some((Some("a".to_string()), 1)));

    let second = driver.read_output("out", Produced::with(StringSerde, I64Serde));
    check!(second == Some((Some("a".to_string()), 2)));

    let stored = driver.store_get::<String, i64>("counts", &"a".to_string());
    check!(stored == Some(2));
}
}