Skip to main content

streamling_plugin/
dispatch.rs

1//! This module provides optional dispatching logic that connects the channel-based FFI
2//! functionality with the clean Rust API.
3//! Users may choose to implement this dispatching logic in their plugins if needed.
4
5use crate::api::{PreprocessorPlugin, SourcePlugin, TransformPlugin};
6use crate::r#async::PluginAsyncRuntimeObj;
7use crate::ffi::SafeArrowArray;
8use crate::{PluginChannels, PluginError, PluginMsg, SinkPlugin};
9use abi_stable::derive_macro_reexports::NonExhaustive;
10use abi_stable::std_types::RString;
11use arrow::array::RecordBatch;
12use async_ffi::FutureExt;
13use crossbeam_channel::TryRecvError;
14use std::sync::Arc;
15use std::time::Duration;
16use tracing::error;
17
/// First control message observed on the input channel, as reported by
/// [`wait_for_initialization`]. Only two messages are legitimate before any
/// data flows, and each maps to one variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InitOutcome {
    /// `Init` arrived: the host wants the plugin to start its main loop, so
    /// the dispatcher should call the plugin's `initialize()` and then begin
    /// processing data.
    Init,
    /// `Terminate` arrived before any `Init`. The pipeline was torn down
    /// before a source's `ExecutionPlan::execute` ever ran — for example under
    /// `--validate` / `--dry-run`, or when topology setup failed after the
    /// plugins were constructed.
    ///
    /// The plugin's `initialize()` must NOT be called in this case. It is the
    /// only place plugins may open runtime resources (network sockets,
    /// ClickHouse connections, etc.); running it only to tear down right away
    /// would perform real I/O against environments that may not exist (e.g. a
    /// validator pod).
    Terminate,
}
35
36/// Block until the host sends the first control message. Returns whether it
37/// was `Init` (proceed to initialize) or `Terminate` (skip initialize and
38/// shut down cleanly). Any other message, malformed wrapper, or channel
39/// disconnect is treated as an error.
40fn wait_for_initialization(channels: &PluginChannels) -> Result<InitOutcome, PluginError> {
41    match channels.input.receiver.recv().map(|m| m.into_enum()) {
42        Ok(Ok(PluginMsg::Init)) => Ok(InitOutcome::Init),
43        Ok(Ok(PluginMsg::Terminate)) => Ok(InitOutcome::Terminate),
44        Ok(Ok(_other)) => Err(PluginError::Execution(
45            "Expected Init message as first message".to_string(),
46        )),
47        Ok(Err(_unwrap_err)) => Err(PluginError::Execution(
48            "Malformed message wrapper during initialization".to_string(),
49        )),
50        Err(_recv_err) => Err(PluginError::Execution(
51            "Channel disconnected during initialization".to_string(),
52        )),
53    }
54}
55
56use crate::ffi::{PluginCheckpointEpoch, PluginMetricsRecorder};
57
58/// Handle checkpoint marker message for any plugin type
59async fn handle_checkpoint_marker(
60    channels: &PluginChannels,
61    epoch: PluginCheckpointEpoch,
62    runtime: &PluginAsyncRuntimeObj,
63) -> Result<(), PluginError> {
64    channels
65        .output
66        .send_with_retry(runtime, "Checkpoint marker", || {
67            NonExhaustive::new(PluginMsg::CheckpointMarker { epoch })
68        })
69        .await
70}
71
72/// Handle checkpoint finalizer message for any plugin type
73async fn handle_checkpoint_finalizer(
74    channels: &PluginChannels,
75    epoch: PluginCheckpointEpoch,
76    runtime: &PluginAsyncRuntimeObj,
77) -> Result<(), PluginError> {
78    channels
79        .output
80        .send_with_retry(runtime, "Checkpoint finalizer", || {
81            NonExhaustive::new(PluginMsg::CheckpointFinalizer { epoch })
82        })
83        .await
84}
85
86/// Handle checkpoint ack message for sink plugins
87async fn handle_checkpoint_ack(
88    channels: &PluginChannels,
89    epoch: PluginCheckpointEpoch,
90    runtime: &PluginAsyncRuntimeObj,
91) -> Result<(), PluginError> {
92    channels
93        .output
94        .send_with_retry(runtime, "Checkpoint ack", || {
95            NonExhaustive::new(PluginMsg::CheckpointAck { epoch })
96        })
97        .await
98}
99
100async fn handle_control_messages(
101    channels: &PluginChannels,
102    source_plugin: &Arc<dyn SourcePlugin>,
103    runtime: &PluginAsyncRuntimeObj,
104) -> Result<(), PluginError> {
105    while !channels.input.receiver.is_empty() {
106        match channels.input.receiver.recv().map(|m| m.into_enum()) {
107            Ok(Ok(PluginMsg::Init)) => {
108                return Err(PluginError::Execution(
109                    "Received Init message after plugin was initialized".to_string(),
110                ));
111            }
112            Ok(Ok(PluginMsg::CheckpointMarker { epoch })) => {
113                source_plugin
114                    .process_checkpoint_marker(epoch.into())
115                    .await?;
116                handle_checkpoint_marker(channels, epoch, runtime).await?;
117            }
118            Ok(Ok(PluginMsg::CheckpointFinalizer { epoch })) => {
119                source_plugin
120                    .process_checkpoint_finalizer(epoch.into())
121                    .await?;
122                handle_checkpoint_finalizer(channels, epoch, runtime).await?;
123            }
124            Ok(Ok(PluginMsg::Terminate)) => {
125                source_plugin.terminate().await?;
126            }
127            Err(e) => {
128                return Err(PluginError::Execution(format!(
129                    "Error receiving message from input channel: {e}"
130                )));
131            }
132            _ => {}
133        }
134    }
135    Ok(())
136}
137
138pub struct SourcePluginDispatcher {
139    channels: PluginChannels,
140    source_plugin: Arc<dyn SourcePlugin>,
141}
142
143impl SourcePluginDispatcher {
144    pub fn new(channels: PluginChannels, source_plugin: Arc<dyn SourcePlugin>) -> Self {
145        SourcePluginDispatcher {
146            channels,
147            source_plugin,
148        }
149    }
150
151    pub async fn start(&self, runtime: PluginAsyncRuntimeObj) -> Result<(), PluginError> {
152        // If the host sends `Terminate` before `Init` (validation / early-teardown
153        // path), short-circuit before `initialize()` runs. This is the contract
154        // plugin authors rely on when keeping runtime I/O out of `new()`: the
155        // host guarantees `initialize()` does not run when termination comes
156        // first, so plugins can open network connections, DB clients, etc. there
157        // without worrying about hermetic-validation environments.
158        match wait_for_initialization(&self.channels)? {
159            InitOutcome::Terminate => {
160                self.source_plugin.terminate().await?;
161                return Ok(());
162            }
163            InitOutcome::Init => {}
164        }
165        if !self.source_plugin.is_running() {
166            return Ok(());
167        }
168        self.source_plugin.initialize().await?;
169
170        loop {
171            // Generation loop
172            // The idea is to continuously generate batches from the source plugin
173            // and send them to the output channel, BUT it needs to occasionally check
174            // the input channel for control messages (checkpoint markers, etc.). So there is
175            // a timeout that's used to exit the generation loop and check the input channel.
176            let source_plugin = self.source_plugin.clone();
177
178            if !source_plugin.is_running() {
179                break;
180            }
181
182            let runtime_clone = runtime.clone();
183            let channels_clone = self.channels.clone();
184            let source_plugin_clone = self.source_plugin.clone();
185            let generate_batch_future = async move {
186                match source_plugin.generate_batch().await {
187                    Ok(batch) => {
188                        let retry_callback = || -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>> {
189                            let channels = channels_clone.clone();
190                            let source_plugin = source_plugin_clone.clone();
191                            let runtime = runtime_clone.clone();
192                            Box::pin(async move {
193                                // Handle control messages in case the output channel is full
194                                let _ = handle_control_messages(&channels, &source_plugin, &runtime).await;
195                                // Check if plugin is still running
196                                source_plugin.is_running()
197                            })
198                        };
199
200                        let _ = channels_clone.output.send_with_retry_callback(
201                            &runtime_clone,
202                            "Source plugin",
203                            || {
204                                let batch_data: SafeArrowArray = batch.clone().into();
205                                NonExhaustive::new(PluginMsg::NextBatch { data: batch_data })
206                            },
207                            Some(retry_callback),
208                            Duration::from_millis(50),
209                        )
210                        .await;
211                        // Ignore errors - source plugin doesn't propagate them
212                    }
213                    Err(e) => {
214                        error!("Error generating batch: {:?}", e);
215                    }
216                }
217            }
218            .into_ffi();
219
220            runtime.spawn(generate_batch_future).await;
221
222            handle_control_messages(&self.channels, &self.source_plugin, &runtime).await?;
223        }
224
225        Ok(())
226    }
227}
228
229pub struct TransformPluginDispatcher {
230    channels: PluginChannels,
231    transform_plugin: Arc<dyn TransformPlugin>,
232}
233
234impl TransformPluginDispatcher {
235    pub fn new(channels: PluginChannels, transform_plugin: Arc<dyn TransformPlugin>) -> Self {
236        TransformPluginDispatcher {
237            channels,
238            transform_plugin,
239        }
240    }
241
242    pub async fn start(&self, runtime: PluginAsyncRuntimeObj) -> Result<(), PluginError> {
243        // See SourcePluginDispatcher::start for the rationale behind short-
244        // circuiting on Terminate before calling `initialize()`.
245        match wait_for_initialization(&self.channels)? {
246            InitOutcome::Terminate => {
247                self.transform_plugin.terminate().await?;
248                return Ok(());
249            }
250            InitOutcome::Init => {}
251        }
252        if !self.transform_plugin.is_running() {
253            return Ok(());
254        }
255        self.transform_plugin.initialize().await?;
256
257        loop {
258            if !self.transform_plugin.is_running() {
259                break;
260            }
261
262            match self
263                .channels
264                .input
265                .receiver
266                .try_recv()
267                .map(|m| m.into_enum())
268            {
269                Ok(Ok(PluginMsg::NextBatch { data })) => {
270                    let batch: RecordBatch = data.into();
271
272                    let processed_batch = self.transform_plugin.process_batch(batch).await?;
273
274                    let transform_plugin = self.transform_plugin.clone();
275                    let retry_callback =
276                        || -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>> {
277                            let plugin = transform_plugin.clone();
278                            Box::pin(async move {
279                                // Check if plugin is still running
280                                plugin.is_running()
281                            })
282                        };
283
284                    self.channels
285                        .output
286                        .send_with_retry_callback(
287                            &runtime,
288                            "Transform plugin",
289                            || {
290                                let batch_data: SafeArrowArray = processed_batch.clone().into();
291                                NonExhaustive::new(PluginMsg::NextBatch { data: batch_data })
292                            },
293                            Some(retry_callback),
294                            Duration::from_millis(50),
295                        )
296                        .await?;
297                }
298                Ok(Ok(PluginMsg::CheckpointMarker { epoch })) => {
299                    self.transform_plugin
300                        .process_checkpoint_marker(epoch.into())
301                        .await?;
302                    handle_checkpoint_marker(&self.channels, epoch, &runtime).await?;
303                }
304                Ok(Ok(PluginMsg::CheckpointFinalizer { epoch })) => {
305                    self.transform_plugin
306                        .process_checkpoint_finalizer(epoch.into())
307                        .await?;
308                    handle_checkpoint_finalizer(&self.channels, epoch, &runtime).await?;
309                }
310                Ok(Ok(PluginMsg::Terminate)) => {
311                    self.transform_plugin.terminate().await?;
312                }
313                Err(TryRecvError::Empty) => {
314                    runtime.yield_now().await;
315                }
316                Err(TryRecvError::Disconnected) => {
317                    break;
318                }
319                _ => {}
320            }
321        }
322
323        Ok(())
324    }
325}
326
327pub struct SinkPluginDispatcher {
328    channels: PluginChannels,
329    sink_plugin: Arc<dyn SinkPlugin>,
330    plugin_metrics_recorder: PluginMetricsRecorder,
331}
332
333impl SinkPluginDispatcher {
334    pub fn new(channels: PluginChannels, sink_plugin: Arc<dyn SinkPlugin>) -> Self {
335        let metrics_sender = channels.metrics.sender.clone();
336        SinkPluginDispatcher {
337            channels,
338            sink_plugin,
339            plugin_metrics_recorder: PluginMetricsRecorder::new(metrics_sender),
340        }
341    }
342
343    pub async fn start(&self, runtime: PluginAsyncRuntimeObj) -> Result<(), PluginError> {
344        // See SourcePluginDispatcher::start for the rationale behind short-
345        // circuiting on Terminate before calling `initialize()`.
346        match wait_for_initialization(&self.channels)? {
347            InitOutcome::Terminate => {
348                self.sink_plugin.terminate().await?;
349                return Ok(());
350            }
351            InitOutcome::Init => {}
352        }
353        if !self.sink_plugin.is_running() {
354            return Ok(());
355        }
356        self.sink_plugin.initialize().await?;
357
358        loop {
359            if !self.sink_plugin.is_running() {
360                break;
361            }
362
363            match self
364                .channels
365                .input
366                .receiver
367                .try_recv()
368                .map(|m| m.into_enum())
369            {
370                Ok(Ok(PluginMsg::NextBatch { data })) => {
371                    let batch: RecordBatch = data.into();
372                    let num_rows = batch.num_rows();
373                    let plugin_process_batch = std::time::Instant::now();
374                    let result = self.sink_plugin.process_batch(batch).await;
375                    let duration = plugin_process_batch.elapsed();
376                    match result {
377                        Ok(()) => {
378                            self.plugin_metrics_recorder
379                                .record_count("output_rows", num_rows as u64);
380                            self.plugin_metrics_recorder
381                                .record_latency("elapsed_compute", duration);
382                        }
383                        Err(e) => {
384                            // Propagate error to cause pipeline failure
385                            // Any retry mechanism should be handled by the plugin itself
386                            return Err(e);
387                        }
388                    }
389                }
390                Ok(Ok(PluginMsg::CheckpointMarker { epoch })) => {
391                    self.sink_plugin
392                        .process_checkpoint_marker(epoch.into())
393                        .await?;
394                    handle_checkpoint_ack(&self.channels, epoch, &runtime).await?;
395                }
396                Ok(Ok(PluginMsg::CheckpointFinalizer { epoch })) => {
397                    self.sink_plugin
398                        .process_checkpoint_finalizer(epoch.into())
399                        .await?
400                }
401                Ok(Ok(PluginMsg::Terminate)) => {
402                    self.sink_plugin.terminate().await?;
403                }
404                Err(TryRecvError::Empty) => {
405                    runtime.yield_now().await;
406                }
407                Err(TryRecvError::Disconnected) => {
408                    break;
409                }
410                _ => {}
411            }
412        }
413
414        Ok(())
415    }
416}
417
418pub struct PreprocessorPluginDispatcher {
419    channels: PluginChannels,
420    preprocessor_plugin: Arc<dyn PreprocessorPlugin>,
421}
422
423impl PreprocessorPluginDispatcher {
424    pub fn new(channels: PluginChannels, preprocessor_plugin: Arc<dyn PreprocessorPlugin>) -> Self {
425        PreprocessorPluginDispatcher {
426            channels,
427            preprocessor_plugin,
428        }
429    }
430
431    pub async fn start(&self) -> Result<(), PluginError> {
432        match self.channels.input.receiver.recv().map(|m| m.into_enum()) {
433            Ok(Ok(PluginMsg::Topology { config })) => {
434                match self
435                    .preprocessor_plugin
436                    .preprocess_topology(config.into_string())
437                    .await
438                {
439                    Ok(result) => {
440                        self.channels
441                            .output
442                            .sender
443                            .send(NonExhaustive::new(PluginMsg::Topology {
444                                config: RString::from(result),
445                            }))
446                            .map_err(|e| {
447                                PluginError::Execution(format!(
448                                    "Failed to send topology response: {}",
449                                    e
450                                ))
451                            })?;
452                    }
453                    Err(e) => {
454                        let error_msg = e.to_string();
455                        if let Err(send_err) =
456                            self.channels
457                                .output
458                                .sender
459                                .send(NonExhaustive::new(PluginMsg::Error {
460                                    message: RString::from(error_msg),
461                                }))
462                        {
463                            tracing::error!(
464                                "Failed to send error message through plugin channel: {}",
465                                send_err
466                            );
467                        }
468                        return Err(e);
469                    }
470                }
471            }
472            Ok(Ok(PluginMsg::Terminate)) => return Ok(()),
473            Ok(Ok(other)) => {
474                return Err(PluginError::Execution(format!(
475                    "Expected Topology message, got: {:?}",
476                    other
477                )));
478            }
479            Ok(Err(_)) => {
480                return Err(PluginError::Execution(
481                    "Malformed message wrapper".to_string(),
482                ));
483            }
484            Err(e) => {
485                return Err(PluginError::Execution(format!(
486                    "Channel disconnected: {}",
487                    e
488                )));
489            }
490        }
491
492        // Wait for Terminate
493        match self.channels.input.receiver.recv().map(|m| m.into_enum()) {
494            Ok(Ok(PluginMsg::Terminate)) => Ok(()),
495            _ => Ok(()),
496        }
497    }
498}
499
500#[cfg(test)]
501mod tests {
502    use super::*;
503    use crate::ffi::{PluginChannel, PluginChannels, PluginMetricsChannel, PluginMsg};
504    use abi_stable::external_types::crossbeam_channel;
505    use async_trait::async_trait;
506
507    fn make_channels() -> PluginChannels {
508        PluginChannels {
509            input: PluginChannel::new(crossbeam_channel::bounded(8)),
510            output: PluginChannel::new(crossbeam_channel::bounded(8)),
511            metrics: PluginMetricsChannel::new(crossbeam_channel::bounded(8)),
512        }
513    }
514
515    struct FailingPreprocessor {
516        error_msg: String,
517    }
518
519    #[async_trait]
520    impl PreprocessorPlugin for FailingPreprocessor {
521        async fn preprocess_topology(&self, _config: String) -> Result<String, PluginError> {
522            Err(PluginError::Execution(self.error_msg.clone()))
523        }
524    }
525
526    struct SuccessPreprocessor {
527        result: String,
528    }
529
530    #[async_trait]
531    impl PreprocessorPlugin for SuccessPreprocessor {
532        async fn preprocess_topology(&self, _config: String) -> Result<String, PluginError> {
533            Ok(self.result.clone())
534        }
535    }
536
537    #[tokio::test]
538    async fn preprocessor_start_sends_error_on_preprocess_failure() {
539        let channels = make_channels();
540        let error_msg = "transform 'foo' missing required field 'type'";
541        let plugin: Arc<dyn PreprocessorPlugin> = Arc::new(FailingPreprocessor {
542            error_msg: error_msg.to_string(),
543        });
544        let dispatcher = PreprocessorPluginDispatcher::new(channels.clone(), plugin);
545
546        channels
547            .input
548            .sender
549            .send(NonExhaustive::new(PluginMsg::Topology {
550                config: RString::from("some_config"),
551            }))
552            .unwrap();
553
554        let result = dispatcher.start().await;
555        assert!(result.is_err());
556        assert!(
557            result.unwrap_err().to_string().contains(error_msg),
558            "start() should propagate the preprocessor error"
559        );
560
561        let output_msg = channels
562            .output
563            .receiver
564            .try_recv()
565            .expect("output channel should contain an Error message");
566        match output_msg.into_enum() {
567            Ok(PluginMsg::Error { message }) => {
568                assert_eq!(message.as_str(), error_msg);
569            }
570            other => panic!("expected PluginMsg::Error, got: {:?}", other),
571        }
572    }
573
574    #[tokio::test]
575    async fn preprocessor_start_returns_ok_on_terminate_before_topology() {
576        let channels = make_channels();
577        let plugin: Arc<dyn PreprocessorPlugin> = Arc::new(FailingPreprocessor {
578            error_msg: "should not be called".to_string(),
579        });
580        let dispatcher = PreprocessorPluginDispatcher::new(channels.clone(), plugin);
581
582        channels
583            .input
584            .sender
585            .send(NonExhaustive::new(PluginMsg::Terminate))
586            .unwrap();
587
588        let result = dispatcher.start().await;
589        assert!(result.is_ok(), "Terminate before Topology should succeed");
590    }
591
592    #[tokio::test]
593    async fn preprocessor_start_errors_on_unexpected_message() {
594        let channels = make_channels();
595        let plugin: Arc<dyn PreprocessorPlugin> = Arc::new(FailingPreprocessor {
596            error_msg: "should not be called".to_string(),
597        });
598        let dispatcher = PreprocessorPluginDispatcher::new(channels.clone(), plugin);
599
600        channels
601            .input
602            .sender
603            .send(NonExhaustive::new(PluginMsg::Init))
604            .unwrap();
605
606        let result = dispatcher.start().await;
607        assert!(result.is_err());
608        assert!(
609            result
610                .unwrap_err()
611                .to_string()
612                .contains("Expected Topology message"),
613        );
614    }
615
616    #[tokio::test]
617    async fn preprocessor_start_sends_topology_response_on_success() {
618        let channels = make_channels();
619        let plugin: Arc<dyn PreprocessorPlugin> = Arc::new(SuccessPreprocessor {
620            result: "processed_config".to_string(),
621        });
622        let dispatcher = PreprocessorPluginDispatcher::new(channels.clone(), plugin);
623
624        channels
625            .input
626            .sender
627            .send(NonExhaustive::new(PluginMsg::Topology {
628                config: RString::from("input_config"),
629            }))
630            .unwrap();
631
632        channels
633            .input
634            .sender
635            .send(NonExhaustive::new(PluginMsg::Terminate))
636            .unwrap();
637
638        let result = dispatcher.start().await;
639        assert!(result.is_ok());
640
641        let output_msg = channels
642            .output
643            .receiver
644            .try_recv()
645            .expect("output channel should contain a Topology response");
646        match output_msg.into_enum() {
647            Ok(PluginMsg::Topology { config }) => {
648                assert_eq!(config.as_str(), "processed_config");
649            }
650            other => panic!("expected PluginMsg::Topology, got: {:?}", other),
651        }
652    }
653
654    // ------------------------------------------------------------------
655    // Source / Transform / Sink: Terminate-before-Init short-circuit.
656    //
657    // Validation/early-teardown paths cause `terminate_all_plugins` to send
658    // `Terminate` to plugins whose `Init` was never sent. The dispatcher
659    // contract is that `initialize()` is NOT called in that case, and that
660    // the plugin's `terminate()` IS called. Plugin authors rely on this to
661    // keep runtime I/O (network sockets, DB clients, etc.) inside
662    // `initialize()` without breaking hermetic validation environments.
663    // ------------------------------------------------------------------
664
665    use crate::api::{SinkPlugin, SourcePlugin, SupportsGracefulShutdown, TransformPlugin};
666    use crate::r#async::DirectTokioProxy;
667    use arrow::datatypes::{Schema, SchemaRef};
668    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
669
670    /// Records whether `initialize()` and `terminate()` were called. Used to
671    /// assert short-circuit behavior across all three streaming dispatchers.
672    #[derive(Default)]
673    struct LifecycleRecorder {
674        initialized: AtomicBool,
675        terminated: AtomicUsize,
676    }
677
678    impl LifecycleRecorder {
679        fn was_initialized(&self) -> bool {
680            self.initialized.load(Ordering::SeqCst)
681        }
682        fn terminate_count(&self) -> usize {
683            self.terminated.load(Ordering::SeqCst)
684        }
685    }
686
687    fn empty_schema() -> SchemaRef {
688        Arc::new(Schema::empty())
689    }
690
691    struct RecordingSource {
692        recorder: Arc<LifecycleRecorder>,
693        running: AtomicBool,
694    }
695
696    impl RecordingSource {
697        fn new(recorder: Arc<LifecycleRecorder>) -> Self {
698            Self {
699                recorder,
700                running: AtomicBool::new(true),
701            }
702        }
703    }
704
705    #[async_trait]
706    impl SupportsGracefulShutdown for RecordingSource {
707        fn is_running(&self) -> bool {
708            self.running.load(Ordering::SeqCst)
709        }
710        async fn terminate(&self) -> Result<(), PluginError> {
711            self.recorder.terminated.fetch_add(1, Ordering::SeqCst);
712            self.running.store(false, Ordering::SeqCst);
713            Ok(())
714        }
715    }
716
717    #[async_trait]
718    impl SourcePlugin for RecordingSource {
719        async fn initialize(&self) -> Result<(), PluginError> {
720            self.recorder.initialized.store(true, Ordering::SeqCst);
721            Ok(())
722        }
723        fn output_schema(&self) -> Result<SchemaRef, PluginError> {
724            Ok(empty_schema())
725        }
726        async fn generate_batch(&self) -> Result<RecordBatch, PluginError> {
727            Ok(RecordBatch::new_empty(empty_schema()))
728        }
729        async fn process_checkpoint_marker(
730            &self,
731            _epoch: crate::api::CheckpointEpoch,
732        ) -> Result<(), PluginError> {
733            Ok(())
734        }
735        async fn process_checkpoint_finalizer(
736            &self,
737            _epoch: crate::api::CheckpointEpoch,
738        ) -> Result<(), PluginError> {
739            Ok(())
740        }
741    }
742
743    struct RecordingTransform {
744        recorder: Arc<LifecycleRecorder>,
745        running: AtomicBool,
746    }
747
748    impl RecordingTransform {
749        fn new(recorder: Arc<LifecycleRecorder>) -> Self {
750            Self {
751                recorder,
752                running: AtomicBool::new(true),
753            }
754        }
755    }
756
757    #[async_trait]
758    impl SupportsGracefulShutdown for RecordingTransform {
759        fn is_running(&self) -> bool {
760            self.running.load(Ordering::SeqCst)
761        }
762        async fn terminate(&self) -> Result<(), PluginError> {
763            self.recorder.terminated.fetch_add(1, Ordering::SeqCst);
764            self.running.store(false, Ordering::SeqCst);
765            Ok(())
766        }
767    }
768
769    #[async_trait]
770    impl TransformPlugin for RecordingTransform {
771        async fn initialize(&self) -> Result<(), PluginError> {
772            self.recorder.initialized.store(true, Ordering::SeqCst);
773            Ok(())
774        }
775        fn output_schema(&self) -> Result<SchemaRef, PluginError> {
776            Ok(empty_schema())
777        }
778        async fn process_batch(&self, data: RecordBatch) -> Result<RecordBatch, PluginError> {
779            Ok(data)
780        }
781        async fn process_checkpoint_marker(
782            &self,
783            _epoch: crate::api::CheckpointEpoch,
784        ) -> Result<(), PluginError> {
785            Ok(())
786        }
787        async fn process_checkpoint_finalizer(
788            &self,
789            _epoch: crate::api::CheckpointEpoch,
790        ) -> Result<(), PluginError> {
791            Ok(())
792        }
793    }
794
795    struct RecordingSink {
796        recorder: Arc<LifecycleRecorder>,
797        running: AtomicBool,
798    }
799
800    impl RecordingSink {
801        fn new(recorder: Arc<LifecycleRecorder>) -> Self {
802            Self {
803                recorder,
804                running: AtomicBool::new(true),
805            }
806        }
807    }
808
809    #[async_trait]
810    impl SupportsGracefulShutdown for RecordingSink {
811        fn is_running(&self) -> bool {
812            self.running.load(Ordering::SeqCst)
813        }
814        async fn terminate(&self) -> Result<(), PluginError> {
815            self.recorder.terminated.fetch_add(1, Ordering::SeqCst);
816            self.running.store(false, Ordering::SeqCst);
817            Ok(())
818        }
819    }
820
821    #[async_trait]
822    impl SinkPlugin for RecordingSink {
823        async fn initialize(&self) -> Result<(), PluginError> {
824            self.recorder.initialized.store(true, Ordering::SeqCst);
825            Ok(())
826        }
827        async fn process_batch(&self, _data: RecordBatch) -> Result<(), PluginError> {
828            Ok(())
829        }
830        async fn process_checkpoint_marker(
831            &self,
832            _epoch: crate::api::CheckpointEpoch,
833        ) -> Result<(), PluginError> {
834            Ok(())
835        }
836        async fn process_checkpoint_finalizer(
837            &self,
838            _epoch: crate::api::CheckpointEpoch,
839        ) -> Result<(), PluginError> {
840            Ok(())
841        }
842    }
843
844    #[tokio::test]
845    async fn source_start_skips_initialize_on_terminate_before_init() {
846        let channels = make_channels();
847        let recorder = Arc::new(LifecycleRecorder::default());
848        let plugin: Arc<dyn SourcePlugin> = Arc::new(RecordingSource::new(recorder.clone()));
849        let dispatcher = SourcePluginDispatcher::new(channels.clone(), plugin);
850
851        channels
852            .input
853            .sender
854            .send(NonExhaustive::new(PluginMsg::Terminate))
855            .unwrap();
856
857        let runtime = DirectTokioProxy::new().into_async_runtime_obj();
858        let result = dispatcher.start(runtime).await;
859
860        assert!(result.is_ok(), "Terminate-before-Init should return Ok");
861        assert!(
862            !recorder.was_initialized(),
863            "initialize() must not run when host terminates first"
864        );
865        assert_eq!(
866            recorder.terminate_count(),
867            1,
868            "terminate() must be called exactly once on Terminate-before-Init"
869        );
870    }
871
872    #[tokio::test]
873    async fn transform_start_skips_initialize_on_terminate_before_init() {
874        let channels = make_channels();
875        let recorder = Arc::new(LifecycleRecorder::default());
876        let plugin: Arc<dyn TransformPlugin> = Arc::new(RecordingTransform::new(recorder.clone()));
877        let dispatcher = TransformPluginDispatcher::new(channels.clone(), plugin);
878
879        channels
880            .input
881            .sender
882            .send(NonExhaustive::new(PluginMsg::Terminate))
883            .unwrap();
884
885        let runtime = DirectTokioProxy::new().into_async_runtime_obj();
886        let result = dispatcher.start(runtime).await;
887
888        assert!(result.is_ok(), "Terminate-before-Init should return Ok");
889        assert!(
890            !recorder.was_initialized(),
891            "initialize() must not run when host terminates first"
892        );
893        assert_eq!(
894            recorder.terminate_count(),
895            1,
896            "terminate() must be called exactly once on Terminate-before-Init"
897        );
898    }
899
900    #[tokio::test]
901    async fn sink_start_skips_initialize_on_terminate_before_init() {
902        let channels = make_channels();
903        let recorder = Arc::new(LifecycleRecorder::default());
904        let plugin: Arc<dyn SinkPlugin> = Arc::new(RecordingSink::new(recorder.clone()));
905        let dispatcher = SinkPluginDispatcher::new(channels.clone(), plugin);
906
907        channels
908            .input
909            .sender
910            .send(NonExhaustive::new(PluginMsg::Terminate))
911            .unwrap();
912
913        let runtime = DirectTokioProxy::new().into_async_runtime_obj();
914        let result = dispatcher.start(runtime).await;
915
916        assert!(result.is_ok(), "Terminate-before-Init should return Ok");
917        assert!(
918            !recorder.was_initialized(),
919            "initialize() must not run when host terminates first"
920        );
921        assert_eq!(
922            recorder.terminate_count(),
923            1,
924            "terminate() must be called exactly once on Terminate-before-Init"
925        );
926    }
927}