#![allow(dead_code)]
use std::fmt::Display;
use std::marker::PhantomData;
use std::str::FromStr;
use std::sync::atomic::{AtomicU16, AtomicUsize, Ordering};
use std::sync::mpsc::RecvTimeoutError;
use std::sync::Arc;
use std::time::Duration;
use itertools::{process_results, Itertools};
use noir_compute::config::{HostConfig, RemoteConfig, RuntimeConfig};
use noir_compute::operator::{Data, Operator, StreamElement, Timestamp};
use noir_compute::structure::BlockStructure;
use noir_compute::CoordUInt;
use noir_compute::ExecutionMetadata;
use noir_compute::StreamContext;
/// Value used as `base_port` for every `HostConfig` built by `TestHelper::remote_env`.
const TEST_BASE_PORT: u16 = 17666;
/// Process-wide counter incremented once per simulated host in `TestHelper::remote_env`;
/// its value is folded into the middle octets of that host's `127.x.y.z` address.
static TEST_INDEX: AtomicU16 = AtomicU16::new(0);
/// Pass-through operator that checks timestamps against the watermarks it has seen.
///
/// Every element pulled from `prev` is forwarded unchanged; the checks themselves live in
/// the `Operator::next` implementation for this type.
#[derive(Clone)]
pub struct WatermarkChecker<Out: Data, PreviousOperator: Operator<Out = Out>> {
    /// Timestamp of the most recent watermark, `None` until the first one arrives.
    last_watermark: Option<Timestamp>,
    /// Upstream operator whose output is being checked.
    prev: PreviousOperator,
    /// Shared counter bumped once for every watermark received.
    received_watermarks: Arc<AtomicUsize>,
    /// Ties the otherwise unused `Out` parameter to the struct.
    _out: PhantomData<Out>,
}
/// Displays the operator as the fixed label `WatermarkChecker`, independent of its state.
impl<Out, PreviousOperator> Display for WatermarkChecker<Out, PreviousOperator>
where
    Out: Data,
    PreviousOperator: Operator<Out = Out>,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("WatermarkChecker")
    }
}
impl<Out: Data, PreviousOperator: Operator<Out = Out>> WatermarkChecker<Out, PreviousOperator> {
pub fn new(prev: PreviousOperator, received_watermarks: Arc<AtomicUsize>) -> Self {
Self {
last_watermark: None,
prev,
received_watermarks,
_out: Default::default(),
}
}
}
impl<Out: Data, PreviousOperator: Operator<Out = Out>> Operator
for WatermarkChecker<Out, PreviousOperator>
{
type Out = Out;
fn setup(&mut self, metadata: &mut ExecutionMetadata) {
self.prev.setup(metadata);
}
fn next(&mut self) -> StreamElement<Out> {
let item = self.prev.next();
match &item {
StreamElement::Timestamped(_, ts) => {
if let Some(w) = &self.last_watermark {
assert!(ts > w);
}
}
StreamElement::Watermark(ts) => {
if let Some(w) = &self.last_watermark {
assert!(ts > w);
}
self.last_watermark = Some(*ts);
self.received_watermarks.fetch_add(1, Ordering::Release);
}
_ => {}
}
item
}
fn structure(&self) -> BlockStructure {
Default::default()
}
}
/// Namespace type for the helpers that run a test body against local and remote
/// `RuntimeConfig`s; it carries no state.
pub struct TestHelper;
impl TestHelper {
/// Initialises `env_logger` in test mode, configured from the default environment
/// variables.
///
/// The result of `try_init` is deliberately discarded (presumably because an earlier test
/// in the same process may already have installed the logger).
fn setup() {
    let mut builder = env_logger::Builder::new();
    builder.is_test(true).parse_default_env();
    let _ = builder.try_init();
}
/// Runs `body` on a dedicated thread with a `StreamContext` built from `config`, failing
/// the caller if the body does not finish in time.
///
/// The timeout is read in seconds from `RSTREAM_TEST_TIMEOUT` and defaults to 10.
///
/// # Panics
///
/// Panics if the worker thread cannot be spawned, does not signal completion before the
/// timeout, or terminates without signalling (i.e. it panicked).
pub fn env_with_config(config: RuntimeConfig, body: Arc<dyn Fn(StreamContext) + Send + Sync>) {
    let timeout =
        Duration::from_secs(Self::parse_int_from_env("RSTREAM_TEST_TIMEOUT").unwrap_or(10));
    let (done_tx, done_rx) = std::sync::mpsc::channel();
    let run_body = move || {
        body(StreamContext::new(config));
        // Completion signal; never sent if `body` panics, which drops the sender instead.
        done_tx.send(()).unwrap();
    };
    let worker = std::thread::Builder::new()
        .name("Worker".into())
        .spawn(run_body)
        .unwrap();
    if let Err(failure) = done_rx.recv_timeout(timeout) {
        match failure {
            RecvTimeoutError::Timeout => {
                panic!("Worker thread didn't complete before the timeout of {timeout:?}")
            }
            RecvTimeoutError::Disconnected => panic!("Worker thread has panicked!"),
        }
    }
    worker.join().expect("Worker thread has panicked!");
}
/// Runs `body` once against the configuration returned by `RuntimeConfig::local(num_cores)`.
///
/// Logging is initialised first, and the chosen configuration is logged at debug level.
pub fn local_env(body: Arc<dyn Fn(StreamContext) + Send + Sync>, num_cores: CoordUInt) {
    Self::setup();
    let local_config = RuntimeConfig::local(num_cores);
    log::debug!("Running test with env: {:?}", local_config);
    Self::env_with_config(local_config, body);
}
pub fn remote_env(
body: Arc<dyn Fn(StreamContext) + Send + Sync>,
num_hosts: CoordUInt,
cores_per_host: CoordUInt,
) {
Self::setup();
let mut hosts = vec![];
for host_id in 0..num_hosts {
let test_id = TEST_INDEX.fetch_add(1, Ordering::SeqCst) + 1;
let high_part = (test_id & 0xff00) >> 8;
let low_part = test_id & 0xff;
let address = format!("127.{high_part}.{low_part}.{host_id}");
hosts.push(HostConfig {
address,
base_port: TEST_BASE_PORT,
num_cores: cores_per_host,
ssh: Default::default(),
perf_path: None,
});
}
let mut join_handles = vec![];
for host_id in 0..num_hosts {
let config = RuntimeConfig::Remote(RemoteConfig {
host_id: Some(host_id),
hosts: hosts.clone(),
tracing_dir: None,
cleanup_executable: true,
});
let body = body.clone();
join_handles.push(
std::thread::Builder::new()
.name(format!("Test host{host_id}"))
.spawn(move || Self::env_with_config(config, body))
.unwrap(),
)
}
for (host_id, handle) in join_handles.into_iter().enumerate() {
handle
.join()
.unwrap_or_else(|e| panic!("Remote worker for host {host_id} crashed: {e:?}"));
}
}
/// Runs `body` in every configured local environment, then in every configured remote one.
///
/// The core and host counts come from the comma-separated environment variables
/// `RSTREAM_TEST_LOCAL_CORES`, `RSTREAM_TEST_REMOTE_HOSTS` and `RSTREAM_TEST_REMOTE_CORES`;
/// each falls back to `[4]` when `parse_list_from_env` returns `None`. Remote runs cover
/// every (hosts, cores) combination.
pub fn local_remote_env<F>(body: F)
where
    F: Fn(StreamContext) + Send + Sync + 'static,
{
    let body: Arc<dyn Fn(StreamContext) + Send + Sync> = Arc::new(body);

    let local_core_counts =
        Self::parse_list_from_env("RSTREAM_TEST_LOCAL_CORES").unwrap_or_else(|| vec![4]);
    for num_cores in local_core_counts {
        Self::local_env(Arc::clone(&body), num_cores);
    }

    let host_counts =
        Self::parse_list_from_env("RSTREAM_TEST_REMOTE_HOSTS").unwrap_or_else(|| vec![4]);
    let core_counts =
        Self::parse_list_from_env("RSTREAM_TEST_REMOTE_CORES").unwrap_or_else(|| vec![4]);
    for num_hosts in host_counts {
        for num_cores in core_counts.iter().copied() {
            Self::remote_env(Arc::clone(&body), num_hosts, num_cores);
        }
    }
}
/// Reads `var_name` from the environment as a comma-separated list of integers.
///
/// Returns `None` when the variable is unset, is not valid Unicode, or contains an entry
/// that does not parse. An empty value yields `Some` of an empty list. Whitespace around
/// each entry is ignored, so `"2, 4"` parses the same as `"2,4"`.
fn parse_list_from_env(var_name: &str) -> Option<Vec<CoordUInt>> {
    let content = std::env::var(var_name).ok()?;
    if content.is_empty() {
        return Some(Vec::new());
    }
    // Integer parsing rejects surrounding whitespace, so strip it per entry; otherwise a
    // value such as "2, 4" would be discarded silently and the caller's default used.
    let entries = content
        .split(',')
        .map(|entry| CoordUInt::from_str(entry.trim()));
    process_results(entries, |values| values.collect_vec()).ok()
}
/// Reads `var_name` from the environment as an unsigned integer.
///
/// Returns `None` when the variable is unset, is not valid Unicode, or does not parse as
/// a `u64`. Leading and trailing whitespace is ignored, so a value such as `"30\n"` is
/// accepted instead of being dropped silently.
fn parse_int_from_env(var_name: &str) -> Option<u64> {
    let content = std::env::var(var_name).ok()?;
    // `u64::from_str` rejects any surrounding whitespace, hence the explicit trim.
    u64::from_str(content.trim()).ok()
}
}