use crate::api::{PreprocessorPlugin, SourcePlugin, TransformPlugin};
use crate::r#async::PluginAsyncRuntimeObj;
use crate::ffi::SafeArrowArray;
use crate::{PluginChannels, PluginError, PluginMsg, SinkPlugin};
use abi_stable::derive_macro_reexports::NonExhaustive;
use abi_stable::std_types::RString;
use arrow::array::RecordBatch;
use async_ffi::FutureExt;
use crossbeam_channel::TryRecvError;
use std::sync::Arc;
use std::time::Duration;
use tracing::error;
/// How the host opened the conversation: the first input message must be either
/// `Init` or `Terminate` (see `wait_for_initialization`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InitOutcome {
    /// `Init` was received; the dispatcher goes on to initialize and run the plugin.
    Init,
    /// `Terminate` arrived before `Init`; the dispatcher terminates the plugin
    /// without ever calling `initialize()`.
    Terminate,
}
/// Waits for the host's first message on `channels.input` and classifies it.
///
/// Only `Init` and `Terminate` are acceptable opening messages.
///
/// NOTE(review): this calls the blocking `recv()`, and every caller visible here is an
/// `async fn`, so the calling thread is blocked until the host sends something.
///
/// # Errors
///
/// Returns `PluginError::Execution` when:
/// - the input channel is disconnected before any message arrives;
/// - `into_enum()` cannot decode the ABI-stable wrapper into a `PluginMsg`;
/// - the first message is any variant other than `Init` or `Terminate`.
fn wait_for_initialization(channels: &PluginChannels) -> Result<InitOutcome, PluginError> {
    let wrapped = channels.input.receiver.recv().map_err(|_| {
        PluginError::Execution("Channel disconnected during initialization".to_string())
    })?;
    let outcome = match wrapped.into_enum() {
        Ok(PluginMsg::Init) => InitOutcome::Init,
        Ok(PluginMsg::Terminate) => InitOutcome::Terminate,
        Ok(_) => {
            return Err(PluginError::Execution(
                "Expected Init message as first message".to_string(),
            ))
        }
        Err(_) => {
            return Err(PluginError::Execution(
                "Malformed message wrapper during initialization".to_string(),
            ))
        }
    };
    Ok(outcome)
}
use crate::ffi::{PluginCheckpointEpoch, PluginMetricsRecorder};
/// Forwards a `CheckpointMarker` for `epoch` on the output channel.
///
/// The send goes through `send_with_retry` with the label `"Checkpoint marker"`. The
/// closure builds a fresh message each time it is called (presumably once per attempt).
///
/// # Errors
///
/// Returns whatever error `send_with_retry` reports.
async fn handle_checkpoint_marker(
    channels: &PluginChannels,
    epoch: PluginCheckpointEpoch,
    runtime: &PluginAsyncRuntimeObj,
) -> Result<(), PluginError> {
    let make_marker = move || NonExhaustive::new(PluginMsg::CheckpointMarker { epoch });
    let output = &channels.output;
    output
        .send_with_retry(runtime, "Checkpoint marker", make_marker)
        .await
}
/// Forwards a `CheckpointFinalizer` for `epoch` on the output channel.
///
/// The send goes through `send_with_retry` with the label `"Checkpoint finalizer"`.
/// The closure builds a fresh message each time it is called (presumably once per
/// attempt).
///
/// # Errors
///
/// Returns whatever error `send_with_retry` reports.
async fn handle_checkpoint_finalizer(
    channels: &PluginChannels,
    epoch: PluginCheckpointEpoch,
    runtime: &PluginAsyncRuntimeObj,
) -> Result<(), PluginError> {
    let make_finalizer = move || NonExhaustive::new(PluginMsg::CheckpointFinalizer { epoch });
    let output = &channels.output;
    output
        .send_with_retry(runtime, "Checkpoint finalizer", make_finalizer)
        .await
}
/// Sends a `CheckpointAck` for `epoch` on the output channel.
///
/// The send goes through `send_with_retry` with the label `"Checkpoint ack"`. The
/// closure builds a fresh message each time it is called (presumably once per attempt).
///
/// # Errors
///
/// Returns whatever error `send_with_retry` reports.
async fn handle_checkpoint_ack(
    channels: &PluginChannels,
    epoch: PluginCheckpointEpoch,
    runtime: &PluginAsyncRuntimeObj,
) -> Result<(), PluginError> {
    let make_ack = move || NonExhaustive::new(PluginMsg::CheckpointAck { epoch });
    let output = &channels.output;
    output
        .send_with_retry(runtime, "Checkpoint ack", make_ack)
        .await
}
/// Drains every control message currently queued on a source plugin's input channel,
/// without blocking.
///
/// - `CheckpointMarker` / `CheckpointFinalizer`: passed to `source_plugin` first, then
///   forwarded on the output channel.
/// - `Terminate`: calls `source_plugin.terminate()`; draining continues afterwards.
/// - Every other variant (e.g. `NextBatch`) and any wrapper that `into_enum()` cannot
///   decode is ignored.
///
/// Returns `Ok(())` once the queue is empty or the input channel is disconnected.
///
/// # Errors
///
/// - `PluginError::Execution` if an `Init` message arrives after initialization.
/// - Any error from the plugin's checkpoint/terminate hooks, or from forwarding a
///   checkpoint message downstream.
async fn handle_control_messages(
    channels: &PluginChannels,
    source_plugin: &Arc<dyn SourcePlugin>,
    runtime: &PluginAsyncRuntimeObj,
) -> Result<(), PluginError> {
    // `try_recv` replaces the former `is_empty()` + blocking `recv()` pair. That
    // check-then-act could block the executor thread indefinitely if the queue was
    // drained between the two calls. `Empty` and `Disconnected` both end the drain,
    // which matches what the `is_empty()` check effectively did.
    while let Ok(wrapped) = channels.input.receiver.try_recv() {
        match wrapped.into_enum() {
            Ok(PluginMsg::Init) => {
                return Err(PluginError::Execution(
                    "Received Init message after plugin was initialized".to_string(),
                ));
            }
            Ok(PluginMsg::CheckpointMarker { epoch }) => {
                source_plugin
                    .process_checkpoint_marker(epoch.into())
                    .await?;
                handle_checkpoint_marker(channels, epoch, runtime).await?;
            }
            Ok(PluginMsg::CheckpointFinalizer { epoch }) => {
                source_plugin
                    .process_checkpoint_finalizer(epoch.into())
                    .await?;
                handle_checkpoint_finalizer(channels, epoch, runtime).await?;
            }
            Ok(PluginMsg::Terminate) => {
                source_plugin.terminate().await?;
            }
            // Other variants and undecodable wrappers are skipped, as before.
            _ => {}
        }
    }
    Ok(())
}
pub struct SourcePluginDispatcher {
channels: PluginChannels,
source_plugin: Arc<dyn SourcePlugin>,
}
impl SourcePluginDispatcher {
pub fn new(channels: PluginChannels, source_plugin: Arc<dyn SourcePlugin>) -> Self {
SourcePluginDispatcher {
channels,
source_plugin,
}
}
pub async fn start(&self, runtime: PluginAsyncRuntimeObj) -> Result<(), PluginError> {
match wait_for_initialization(&self.channels)? {
InitOutcome::Terminate => {
self.source_plugin.terminate().await?;
return Ok(());
}
InitOutcome::Init => {}
}
if !self.source_plugin.is_running() {
return Ok(());
}
self.source_plugin.initialize().await?;
loop {
let source_plugin = self.source_plugin.clone();
if !source_plugin.is_running() {
break;
}
let runtime_clone = runtime.clone();
let channels_clone = self.channels.clone();
let source_plugin_clone = self.source_plugin.clone();
let generate_batch_future = async move {
match source_plugin.generate_batch().await {
Ok(batch) => {
let retry_callback = || -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>> {
let channels = channels_clone.clone();
let source_plugin = source_plugin_clone.clone();
let runtime = runtime_clone.clone();
Box::pin(async move {
let _ = handle_control_messages(&channels, &source_plugin, &runtime).await;
source_plugin.is_running()
})
};
let _ = channels_clone.output.send_with_retry_callback(
&runtime_clone,
"Source plugin",
|| {
let batch_data: SafeArrowArray = batch.clone().into();
NonExhaustive::new(PluginMsg::NextBatch { data: batch_data })
},
Some(retry_callback),
Duration::from_millis(50),
)
.await;
}
Err(e) => {
error!("Error generating batch: {:?}", e);
}
}
}
.into_ffi();
runtime.spawn(generate_batch_future).await;
handle_control_messages(&self.channels, &self.source_plugin, &runtime).await?;
}
Ok(())
}
}
pub struct TransformPluginDispatcher {
channels: PluginChannels,
transform_plugin: Arc<dyn TransformPlugin>,
}
impl TransformPluginDispatcher {
pub fn new(channels: PluginChannels, transform_plugin: Arc<dyn TransformPlugin>) -> Self {
TransformPluginDispatcher {
channels,
transform_plugin,
}
}
pub async fn start(&self, runtime: PluginAsyncRuntimeObj) -> Result<(), PluginError> {
match wait_for_initialization(&self.channels)? {
InitOutcome::Terminate => {
self.transform_plugin.terminate().await?;
return Ok(());
}
InitOutcome::Init => {}
}
if !self.transform_plugin.is_running() {
return Ok(());
}
self.transform_plugin.initialize().await?;
loop {
if !self.transform_plugin.is_running() {
break;
}
match self
.channels
.input
.receiver
.try_recv()
.map(|m| m.into_enum())
{
Ok(Ok(PluginMsg::NextBatch { data })) => {
let batch: RecordBatch = data.into();
let processed_batch = self.transform_plugin.process_batch(batch).await?;
let transform_plugin = self.transform_plugin.clone();
let retry_callback =
|| -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>> {
let plugin = transform_plugin.clone();
Box::pin(async move {
plugin.is_running()
})
};
self.channels
.output
.send_with_retry_callback(
&runtime,
"Transform plugin",
|| {
let batch_data: SafeArrowArray = processed_batch.clone().into();
NonExhaustive::new(PluginMsg::NextBatch { data: batch_data })
},
Some(retry_callback),
Duration::from_millis(50),
)
.await?;
}
Ok(Ok(PluginMsg::CheckpointMarker { epoch })) => {
self.transform_plugin
.process_checkpoint_marker(epoch.into())
.await?;
handle_checkpoint_marker(&self.channels, epoch, &runtime).await?;
}
Ok(Ok(PluginMsg::CheckpointFinalizer { epoch })) => {
self.transform_plugin
.process_checkpoint_finalizer(epoch.into())
.await?;
handle_checkpoint_finalizer(&self.channels, epoch, &runtime).await?;
}
Ok(Ok(PluginMsg::Terminate)) => {
self.transform_plugin.terminate().await?;
}
Err(TryRecvError::Empty) => {
runtime.yield_now().await;
}
Err(TryRecvError::Disconnected) => {
break;
}
_ => {}
}
}
Ok(())
}
}
/// Drives a [`SinkPlugin`]. It performs the `Init` handshake, then consumes batches and
/// checkpoint messages from the input channel until the plugin stops running or the
/// input channel disconnects.
pub struct SinkPluginDispatcher {
    /// Input (batches and control messages), output (checkpoint acks) and metrics
    /// channels.
    channels: PluginChannels,
    /// The plugin being driven.
    sink_plugin: Arc<dyn SinkPlugin>,
    /// Records per-batch row counts and compute latency; built from a clone of
    /// `channels.metrics.sender`.
    plugin_metrics_recorder: PluginMetricsRecorder,
}

impl SinkPluginDispatcher {
    /// Creates a dispatcher for `sink_plugin`. Metrics are reported through a clone of
    /// `channels.metrics.sender`.
    pub fn new(channels: PluginChannels, sink_plugin: Arc<dyn SinkPlugin>) -> Self {
        let metrics_sender = channels.metrics.sender.clone();
        SinkPluginDispatcher {
            channels,
            sink_plugin,
            plugin_metrics_recorder: PluginMetricsRecorder::new(metrics_sender),
        }
    }

    /// Runs the sink plugin's lifecycle.
    ///
    /// A leading `Terminate` shuts the plugin down without initializing it. Otherwise the
    /// plugin is initialized and, while it reports `is_running()`, the input channel is
    /// polled with `try_recv`:
    ///
    /// - `NextBatch`: written with `process_batch`. On success, `output_rows` and
    ///   `elapsed_compute` metrics are recorded.
    /// - `CheckpointMarker`: handed to the plugin, then acknowledged with
    ///   `CheckpointAck`.
    /// - `CheckpointFinalizer`: handed to the plugin; nothing is sent downstream.
    /// - `Terminate`: calls `terminate()`.
    /// - Empty channel: yields to the runtime.
    /// - Disconnected channel: stops and returns `Ok(())`.
    /// - Anything else, including undecodable wrappers, is ignored.
    ///
    /// # Errors
    ///
    /// Propagates errors from the handshake, from any plugin hook, and from sending a
    /// checkpoint ack.
    pub async fn start(&self, runtime: PluginAsyncRuntimeObj) -> Result<(), PluginError> {
        match wait_for_initialization(&self.channels)? {
            InitOutcome::Terminate => {
                self.sink_plugin.terminate().await?;
                return Ok(());
            }
            InitOutcome::Init => {}
        }
        if !self.sink_plugin.is_running() {
            return Ok(());
        }
        self.sink_plugin.initialize().await?;
        loop {
            if !self.sink_plugin.is_running() {
                break;
            }
            match self
                .channels
                .input
                .receiver
                .try_recv()
                .map(|m| m.into_enum())
            {
                Ok(Ok(PluginMsg::NextBatch { data })) => {
                    let batch: RecordBatch = data.into();
                    let num_rows = batch.num_rows();
                    let started = std::time::Instant::now();
                    // `?` replaces the manual match-and-return; metrics are still only
                    // recorded for batches the plugin accepted.
                    self.sink_plugin.process_batch(batch).await?;
                    let duration = started.elapsed();
                    // `usize` -> `u64` is lossless on every target with pointers <= 64 bits.
                    self.plugin_metrics_recorder
                        .record_count("output_rows", num_rows as u64);
                    self.plugin_metrics_recorder
                        .record_latency("elapsed_compute", duration);
                }
                Ok(Ok(PluginMsg::CheckpointMarker { epoch })) => {
                    self.sink_plugin
                        .process_checkpoint_marker(epoch.into())
                        .await?;
                    handle_checkpoint_ack(&self.channels, epoch, &runtime).await?;
                }
                Ok(Ok(PluginMsg::CheckpointFinalizer { epoch })) => {
                    self.sink_plugin
                        .process_checkpoint_finalizer(epoch.into())
                        .await?;
                }
                Ok(Ok(PluginMsg::Terminate)) => {
                    self.sink_plugin.terminate().await?;
                }
                Err(TryRecvError::Empty) => {
                    runtime.yield_now().await;
                }
                Err(TryRecvError::Disconnected) => {
                    break;
                }
                _ => {}
            }
        }
        Ok(())
    }
}
/// Drives a [`PreprocessorPlugin`] through a single topology request/response exchange
/// with the host.
pub struct PreprocessorPluginDispatcher {
    /// Input (topology request, then a follow-up message) and output (topology response
    /// or error) channels.
    channels: PluginChannels,
    /// The plugin that rewrites the topology configuration.
    preprocessor_plugin: Arc<dyn PreprocessorPlugin>,
}

impl PreprocessorPluginDispatcher {
    /// Creates a dispatcher for `preprocessor_plugin` that communicates over `channels`.
    pub fn new(channels: PluginChannels, preprocessor_plugin: Arc<dyn PreprocessorPlugin>) -> Self {
        PreprocessorPluginDispatcher {
            channels,
            preprocessor_plugin,
        }
    }

    /// Performs the preprocessing handshake with the host.
    ///
    /// 1. Waits for the first message. `Terminate` ends the handshake successfully
    ///    without calling the plugin. `Topology { config }` is passed to
    ///    `preprocess_topology`.
    /// 2. On success, the rewritten topology is sent back as `PluginMsg::Topology`. On
    ///    failure, the error text is sent (best effort) as `PluginMsg::Error` and the
    ///    plugin's error is returned.
    /// 3. After a successful response, waits for one more message (normally `Terminate`)
    ///    or for the channel to disconnect, then returns `Ok(())`.
    ///
    /// NOTE(review): both waits use the blocking `recv()` inside this `async fn`.
    ///
    /// # Errors
    ///
    /// - `PluginError::Execution` if the first message is neither `Topology` nor
    ///   `Terminate`, cannot be decoded, or the input channel is disconnected.
    /// - `PluginError::Execution` if the topology response cannot be sent.
    /// - The plugin's own error if `preprocess_topology` fails.
    pub async fn start(&self) -> Result<(), PluginError> {
        let config = match self.channels.input.receiver.recv().map(|m| m.into_enum()) {
            Ok(Ok(PluginMsg::Topology { config })) => config,
            Ok(Ok(PluginMsg::Terminate)) => return Ok(()),
            Ok(Ok(other)) => {
                return Err(PluginError::Execution(format!(
                    "Expected Topology message, got: {:?}",
                    other
                )));
            }
            Ok(Err(_)) => {
                return Err(PluginError::Execution(
                    "Malformed message wrapper".to_string(),
                ));
            }
            Err(e) => {
                return Err(PluginError::Execution(format!(
                    "Channel disconnected: {}",
                    e
                )));
            }
        };
        match self
            .preprocessor_plugin
            .preprocess_topology(config.into_string())
            .await
        {
            Ok(result) => {
                self.channels
                    .output
                    .sender
                    .send(NonExhaustive::new(PluginMsg::Topology {
                        config: RString::from(result),
                    }))
                    .map_err(|e| {
                        PluginError::Execution(format!(
                            "Failed to send topology response: {}",
                            e
                        ))
                    })?;
            }
            Err(e) => {
                let error_msg = e.to_string();
                // Best effort: the host may already be gone. The plugin's error is
                // returned regardless.
                if let Err(send_err) =
                    self.channels
                        .output
                        .sender
                        .send(NonExhaustive::new(PluginMsg::Error {
                            message: RString::from(error_msg),
                        }))
                {
                    tracing::error!(
                        "Failed to send error message through plugin channel: {}",
                        send_err
                    );
                }
                return Err(e);
            }
        }
        // The host's follow-up message (normally `Terminate`) and a disconnect both end
        // the handshake. The message itself is not inspected, so the former two-armed
        // match that returned `Ok(())` either way is unnecessary.
        let _ = self.channels.input.receiver.recv();
        Ok(())
    }
}
// Unit tests for the dispatchers' startup handshake: the preprocessor's topology
// round-trip, and Terminate-before-Init handling for the source, transform and sink.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ffi::{PluginChannel, PluginChannels, PluginMetricsChannel, PluginMsg};
    use abi_stable::external_types::crossbeam_channel;
    use async_trait::async_trait;

    /// Builds a fresh set of bounded (capacity 8) input, output and metrics channels.
    fn make_channels() -> PluginChannels {
        PluginChannels {
            input: PluginChannel::new(crossbeam_channel::bounded(8)),
            output: PluginChannel::new(crossbeam_channel::bounded(8)),
            metrics: PluginMetricsChannel::new(crossbeam_channel::bounded(8)),
        }
    }

    /// Preprocessor stub whose `preprocess_topology` always fails with `error_msg`.
    struct FailingPreprocessor {
        error_msg: String,
    }

    #[async_trait]
    impl PreprocessorPlugin for FailingPreprocessor {
        async fn preprocess_topology(&self, _config: String) -> Result<String, PluginError> {
            Err(PluginError::Execution(self.error_msg.clone()))
        }
    }

    /// Preprocessor stub whose `preprocess_topology` always succeeds with `result`.
    struct SuccessPreprocessor {
        result: String,
    }

    #[async_trait]
    impl PreprocessorPlugin for SuccessPreprocessor {
        async fn preprocess_topology(&self, _config: String) -> Result<String, PluginError> {
            Ok(self.result.clone())
        }
    }

    // A preprocessing failure must be returned from `start()` AND reported to the host
    // as a `PluginMsg::Error` carrying the same text.
    #[tokio::test]
    async fn preprocessor_start_sends_error_on_preprocess_failure() {
        let channels = make_channels();
        let error_msg = "transform 'foo' missing required field 'type'";
        let plugin: Arc<dyn PreprocessorPlugin> = Arc::new(FailingPreprocessor {
            error_msg: error_msg.to_string(),
        });
        let dispatcher = PreprocessorPluginDispatcher::new(channels.clone(), plugin);
        channels
            .input
            .sender
            .send(NonExhaustive::new(PluginMsg::Topology {
                config: RString::from("some_config"),
            }))
            .unwrap();
        let result = dispatcher.start().await;
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains(error_msg),
            "start() should propagate the preprocessor error"
        );
        // The error must also have been forwarded to the host.
        let output_msg = channels
            .output
            .receiver
            .try_recv()
            .expect("output channel should contain an Error message");
        match output_msg.into_enum() {
            Ok(PluginMsg::Error { message }) => {
                assert_eq!(message.as_str(), error_msg);
            }
            other => panic!("expected PluginMsg::Error, got: {:?}", other),
        }
    }

    // `Terminate` as the very first message ends the preprocessor handshake successfully;
    // the failing stub proves the plugin is never invoked.
    #[tokio::test]
    async fn preprocessor_start_returns_ok_on_terminate_before_topology() {
        let channels = make_channels();
        let plugin: Arc<dyn PreprocessorPlugin> = Arc::new(FailingPreprocessor {
            error_msg: "should not be called".to_string(),
        });
        let dispatcher = PreprocessorPluginDispatcher::new(channels.clone(), plugin);
        channels
            .input
            .sender
            .send(NonExhaustive::new(PluginMsg::Terminate))
            .unwrap();
        let result = dispatcher.start().await;
        assert!(result.is_ok(), "Terminate before Topology should succeed");
    }

    // Any first message other than `Topology`/`Terminate` (here `Init`) is rejected.
    #[tokio::test]
    async fn preprocessor_start_errors_on_unexpected_message() {
        let channels = make_channels();
        let plugin: Arc<dyn PreprocessorPlugin> = Arc::new(FailingPreprocessor {
            error_msg: "should not be called".to_string(),
        });
        let dispatcher = PreprocessorPluginDispatcher::new(channels.clone(), plugin);
        channels
            .input
            .sender
            .send(NonExhaustive::new(PluginMsg::Init))
            .unwrap();
        let result = dispatcher.start().await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Expected Topology message"),
        );
    }

    // Happy path: the rewritten topology is sent back. `Terminate` is queued up front so
    // the dispatcher's follow-up `recv()` returns instead of blocking the test.
    #[tokio::test]
    async fn preprocessor_start_sends_topology_response_on_success() {
        let channels = make_channels();
        let plugin: Arc<dyn PreprocessorPlugin> = Arc::new(SuccessPreprocessor {
            result: "processed_config".to_string(),
        });
        let dispatcher = PreprocessorPluginDispatcher::new(channels.clone(), plugin);
        channels
            .input
            .sender
            .send(NonExhaustive::new(PluginMsg::Topology {
                config: RString::from("input_config"),
            }))
            .unwrap();
        channels
            .input
            .sender
            .send(NonExhaustive::new(PluginMsg::Terminate))
            .unwrap();
        let result = dispatcher.start().await;
        assert!(result.is_ok());
        let output_msg = channels
            .output
            .receiver
            .try_recv()
            .expect("output channel should contain a Topology response");
        match output_msg.into_enum() {
            Ok(PluginMsg::Topology { config }) => {
                assert_eq!(config.as_str(), "processed_config");
            }
            other => panic!("expected PluginMsg::Topology, got: {:?}", other),
        }
    }

    // Imports used by the lifecycle (Terminate-before-Init) tests below.
    use crate::api::{SinkPlugin, SourcePlugin, SupportsGracefulShutdown, TransformPlugin};
    use crate::r#async::DirectTokioProxy;
    use arrow::datatypes::{Schema, SchemaRef};
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    /// Shared probe that records whether `initialize()` ran and how many times
    /// `terminate()` ran on a stub plugin.
    #[derive(Default)]
    struct LifecycleRecorder {
        initialized: AtomicBool,
        terminated: AtomicUsize,
    }

    impl LifecycleRecorder {
        fn was_initialized(&self) -> bool {
            self.initialized.load(Ordering::SeqCst)
        }
        fn terminate_count(&self) -> usize {
            self.terminated.load(Ordering::SeqCst)
        }
    }

    /// Schema with no fields, used for the stubs' empty batches.
    fn empty_schema() -> SchemaRef {
        Arc::new(Schema::empty())
    }

    /// Source stub: reports lifecycle calls to `recorder`; `terminate()` also clears
    /// `running` so the dispatcher's `is_running()` checks see the shutdown.
    struct RecordingSource {
        recorder: Arc<LifecycleRecorder>,
        running: AtomicBool,
    }

    impl RecordingSource {
        fn new(recorder: Arc<LifecycleRecorder>) -> Self {
            Self {
                recorder,
                running: AtomicBool::new(true),
            }
        }
    }

    #[async_trait]
    impl SupportsGracefulShutdown for RecordingSource {
        fn is_running(&self) -> bool {
            self.running.load(Ordering::SeqCst)
        }
        async fn terminate(&self) -> Result<(), PluginError> {
            self.recorder.terminated.fetch_add(1, Ordering::SeqCst);
            self.running.store(false, Ordering::SeqCst);
            Ok(())
        }
    }

    #[async_trait]
    impl SourcePlugin for RecordingSource {
        async fn initialize(&self) -> Result<(), PluginError> {
            self.recorder.initialized.store(true, Ordering::SeqCst);
            Ok(())
        }
        fn output_schema(&self) -> Result<SchemaRef, PluginError> {
            Ok(empty_schema())
        }
        async fn generate_batch(&self) -> Result<RecordBatch, PluginError> {
            Ok(RecordBatch::new_empty(empty_schema()))
        }
        async fn process_checkpoint_marker(
            &self,
            _epoch: crate::api::CheckpointEpoch,
        ) -> Result<(), PluginError> {
            Ok(())
        }
        async fn process_checkpoint_finalizer(
            &self,
            _epoch: crate::api::CheckpointEpoch,
        ) -> Result<(), PluginError> {
            Ok(())
        }
    }

    /// Transform stub: identity `process_batch`; lifecycle calls are reported to
    /// `recorder`, and `terminate()` clears `running`.
    struct RecordingTransform {
        recorder: Arc<LifecycleRecorder>,
        running: AtomicBool,
    }

    impl RecordingTransform {
        fn new(recorder: Arc<LifecycleRecorder>) -> Self {
            Self {
                recorder,
                running: AtomicBool::new(true),
            }
        }
    }

    #[async_trait]
    impl SupportsGracefulShutdown for RecordingTransform {
        fn is_running(&self) -> bool {
            self.running.load(Ordering::SeqCst)
        }
        async fn terminate(&self) -> Result<(), PluginError> {
            self.recorder.terminated.fetch_add(1, Ordering::SeqCst);
            self.running.store(false, Ordering::SeqCst);
            Ok(())
        }
    }

    #[async_trait]
    impl TransformPlugin for RecordingTransform {
        async fn initialize(&self) -> Result<(), PluginError> {
            self.recorder.initialized.store(true, Ordering::SeqCst);
            Ok(())
        }
        fn output_schema(&self) -> Result<SchemaRef, PluginError> {
            Ok(empty_schema())
        }
        async fn process_batch(&self, data: RecordBatch) -> Result<RecordBatch, PluginError> {
            Ok(data)
        }
        async fn process_checkpoint_marker(
            &self,
            _epoch: crate::api::CheckpointEpoch,
        ) -> Result<(), PluginError> {
            Ok(())
        }
        async fn process_checkpoint_finalizer(
            &self,
            _epoch: crate::api::CheckpointEpoch,
        ) -> Result<(), PluginError> {
            Ok(())
        }
    }

    /// Sink stub: discards batches; lifecycle calls are reported to `recorder`, and
    /// `terminate()` clears `running`.
    struct RecordingSink {
        recorder: Arc<LifecycleRecorder>,
        running: AtomicBool,
    }

    impl RecordingSink {
        fn new(recorder: Arc<LifecycleRecorder>) -> Self {
            Self {
                recorder,
                running: AtomicBool::new(true),
            }
        }
    }

    #[async_trait]
    impl SupportsGracefulShutdown for RecordingSink {
        fn is_running(&self) -> bool {
            self.running.load(Ordering::SeqCst)
        }
        async fn terminate(&self) -> Result<(), PluginError> {
            self.recorder.terminated.fetch_add(1, Ordering::SeqCst);
            self.running.store(false, Ordering::SeqCst);
            Ok(())
        }
    }

    #[async_trait]
    impl SinkPlugin for RecordingSink {
        async fn initialize(&self) -> Result<(), PluginError> {
            self.recorder.initialized.store(true, Ordering::SeqCst);
            Ok(())
        }
        async fn process_batch(&self, _data: RecordBatch) -> Result<(), PluginError> {
            Ok(())
        }
        async fn process_checkpoint_marker(
            &self,
            _epoch: crate::api::CheckpointEpoch,
        ) -> Result<(), PluginError> {
            Ok(())
        }
        async fn process_checkpoint_finalizer(
            &self,
            _epoch: crate::api::CheckpointEpoch,
        ) -> Result<(), PluginError> {
            Ok(())
        }
    }

    // The three tests below share one shape: queue `Terminate` as the first message,
    // run `start()`, and check that it succeeds, never calls `initialize()`, and calls
    // `terminate()` exactly once.
    #[tokio::test]
    async fn source_start_skips_initialize_on_terminate_before_init() {
        let channels = make_channels();
        let recorder = Arc::new(LifecycleRecorder::default());
        let plugin: Arc<dyn SourcePlugin> = Arc::new(RecordingSource::new(recorder.clone()));
        let dispatcher = SourcePluginDispatcher::new(channels.clone(), plugin);
        channels
            .input
            .sender
            .send(NonExhaustive::new(PluginMsg::Terminate))
            .unwrap();
        let runtime = DirectTokioProxy::new().into_async_runtime_obj();
        let result = dispatcher.start(runtime).await;
        assert!(result.is_ok(), "Terminate-before-Init should return Ok");
        assert!(
            !recorder.was_initialized(),
            "initialize() must not run when host terminates first"
        );
        assert_eq!(
            recorder.terminate_count(),
            1,
            "terminate() must be called exactly once on Terminate-before-Init"
        );
    }

    #[tokio::test]
    async fn transform_start_skips_initialize_on_terminate_before_init() {
        let channels = make_channels();
        let recorder = Arc::new(LifecycleRecorder::default());
        let plugin: Arc<dyn TransformPlugin> = Arc::new(RecordingTransform::new(recorder.clone()));
        let dispatcher = TransformPluginDispatcher::new(channels.clone(), plugin);
        channels
            .input
            .sender
            .send(NonExhaustive::new(PluginMsg::Terminate))
            .unwrap();
        let runtime = DirectTokioProxy::new().into_async_runtime_obj();
        let result = dispatcher.start(runtime).await;
        assert!(result.is_ok(), "Terminate-before-Init should return Ok");
        assert!(
            !recorder.was_initialized(),
            "initialize() must not run when host terminates first"
        );
        assert_eq!(
            recorder.terminate_count(),
            1,
            "terminate() must be called exactly once on Terminate-before-Init"
        );
    }

    #[tokio::test]
    async fn sink_start_skips_initialize_on_terminate_before_init() {
        let channels = make_channels();
        let recorder = Arc::new(LifecycleRecorder::default());
        let plugin: Arc<dyn SinkPlugin> = Arc::new(RecordingSink::new(recorder.clone()));
        let dispatcher = SinkPluginDispatcher::new(channels.clone(), plugin);
        channels
            .input
            .sender
            .send(NonExhaustive::new(PluginMsg::Terminate))
            .unwrap();
        let runtime = DirectTokioProxy::new().into_async_runtime_obj();
        let result = dispatcher.start(runtime).await;
        assert!(result.is_ok(), "Terminate-before-Init should return Ok");
        assert!(
            !recorder.was_initialized(),
            "initialize() must not run when host terminates first"
        );
        assert_eq!(
            recorder.terminate_count(),
            1,
            "terminate() must be called exactly once on Terminate-before-Init"
        );
    }
}