Skip to main content

drasi_lib/reactions/common/
base.rs

1// Copyright 2025 The Drasi Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Base implementation for common reaction functionality.
16//!
17//! This module provides `ReactionBase` which encapsulates common patterns
18//! used across all reaction implementations:
19//! - Query subscription management
20//! - Priority queue handling
21//! - Task lifecycle management
22//! - Component status tracking
23//! - Event reporting
24//!
25//! # Plugin Architecture
26//!
27//! ReactionBase is designed to be used by reaction plugins. Each plugin:
28//! 1. Defines its own typed configuration struct
29//! 2. Creates a ReactionBase with ReactionBaseParams
30//! 3. Implements the Reaction trait delegating to ReactionBase methods
31
32use anyhow::Result;
33use log::{debug, error, info, warn};
34use std::sync::Arc;
35use tokio::sync::RwLock;
36
37use crate::channels::priority_queue::PriorityQueue;
38use crate::channels::{
39    ComponentEvent, ComponentEventSender, ComponentStatus, ComponentType, QueryResult,
40};
41use crate::context::ReactionRuntimeContext;
42use crate::reactions::QueryProvider;
43use crate::state_store::StateStoreProvider;
44
/// Parameters for creating a `ReactionBase` instance.
///
/// This struct contains only the information that `ReactionBase` needs to function.
/// Plugin-specific configuration should remain in the plugin crate.
///
/// Values are assembled builder-style: start from [`ReactionBaseParams::new`] and chain
/// the `with_*` methods, each of which consumes and returns the params.
///
/// # Example
///
/// ```ignore
/// use drasi_lib::reactions::common::base::{ReactionBase, ReactionBaseParams};
///
/// let params = ReactionBaseParams::new("my-reaction", vec!["query1".to_string()])
///     .with_priority_queue_capacity(5000)
///     .with_auto_start(true);
///
/// let base = ReactionBase::new(params);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReactionBaseParams {
    /// Unique identifier for the reaction
    pub id: String,
    /// List of query IDs this reaction subscribes to
    pub queries: Vec<String>,
    /// Priority queue capacity; `None` means "use the default" (10000)
    pub priority_queue_capacity: Option<usize>,
    /// Whether this reaction should auto-start - defaults to true
    pub auto_start: bool,
}

impl ReactionBaseParams {
    /// Create new params with ID and queries, using defaults for everything else.
    ///
    /// Defaults: no explicit priority queue capacity (`None`) and `auto_start = true`.
    #[must_use]
    pub fn new(id: impl Into<String>, queries: Vec<String>) -> Self {
        Self {
            id: id.into(),
            queries,
            priority_queue_capacity: None,
            auto_start: true, // Default to true like queries
        }
    }

    /// Set the priority queue capacity.
    ///
    /// Consumes `self` and returns the updated params; the return value must be used,
    /// otherwise the setting is lost.
    #[must_use]
    pub fn with_priority_queue_capacity(mut self, capacity: usize) -> Self {
        self.priority_queue_capacity = Some(capacity);
        self
    }

    /// Set whether this reaction should auto-start.
    ///
    /// Consumes `self` and returns the updated params; the return value must be used,
    /// otherwise the setting is lost.
    #[must_use]
    pub fn with_auto_start(mut self, auto_start: bool) -> Self {
        self.auto_start = auto_start;
        self
    }
}
96
97/// Base implementation for common reaction functionality
98pub struct ReactionBase {
99    /// Reaction identifier
100    pub id: String,
101    /// List of query IDs to subscribe to
102    pub queries: Vec<String>,
103    /// Whether this reaction should auto-start
104    pub auto_start: bool,
105    /// Current component status
106    pub status: Arc<RwLock<ComponentStatus>>,
107    /// Runtime context (set by initialize())
108    context: Arc<RwLock<Option<ReactionRuntimeContext>>>,
109    /// Channel for sending component status events (extracted from context for convenience)
110    status_tx: Arc<RwLock<Option<ComponentEventSender>>>,
111    /// Query provider for accessing queries (extracted from context)
112    query_provider: Arc<RwLock<Option<Arc<dyn QueryProvider>>>>,
113    /// State store provider (extracted from context for convenience)
114    state_store: Arc<RwLock<Option<Arc<dyn StateStoreProvider>>>>,
115    /// Priority queue for timestamp-ordered result processing
116    pub priority_queue: PriorityQueue<QueryResult>,
117    /// Handles to subscription forwarder tasks
118    pub subscription_tasks: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
119    /// Handle to the main processing task
120    pub processing_task: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
121    /// Sender for shutdown signal to processing task
122    pub shutdown_tx: Arc<RwLock<Option<tokio::sync::oneshot::Sender<()>>>>,
123}
124
125impl ReactionBase {
126    /// Create a new ReactionBase with the given parameters
127    ///
128    /// Dependencies (event channel, query subscriber, state store) are not required during
129    /// construction - they will be provided via `initialize()` when the reaction is added to DrasiLib.
130    pub fn new(params: ReactionBaseParams) -> Self {
131        Self {
132            priority_queue: PriorityQueue::new(params.priority_queue_capacity.unwrap_or(10000)),
133            id: params.id,
134            queries: params.queries,
135            auto_start: params.auto_start,
136            status: Arc::new(RwLock::new(ComponentStatus::Stopped)),
137            context: Arc::new(RwLock::new(None)), // Set by initialize()
138            status_tx: Arc::new(RwLock::new(None)), // Extracted from context
139            query_provider: Arc::new(RwLock::new(None)), // Extracted from context
140            state_store: Arc::new(RwLock::new(None)), // Extracted from context
141            subscription_tasks: Arc::new(RwLock::new(Vec::new())),
142            processing_task: Arc::new(RwLock::new(None)),
143            shutdown_tx: Arc::new(RwLock::new(None)),
144        }
145    }
146
147    /// Initialize the reaction with runtime context.
148    ///
149    /// This method is called automatically by DrasiLib's `add_reaction()` method.
150    /// Plugin developers do not need to call this directly.
151    ///
152    /// The context provides access to:
153    /// - `reaction_id`: The reaction's unique identifier
154    /// - `status_tx`: Channel for reporting component status events
155    /// - `state_store`: Optional persistent state storage
156    /// - `query_provider`: Access to query instances for subscription
157    pub async fn initialize(&self, context: ReactionRuntimeContext) {
158        // Store context for later use
159        *self.context.write().await = Some(context.clone());
160
161        // Extract services for convenience
162        *self.status_tx.write().await = Some(context.status_tx.clone());
163        *self.query_provider.write().await = Some(context.query_provider.clone());
164
165        if let Some(state_store) = context.state_store.as_ref() {
166            *self.state_store.write().await = Some(state_store.clone());
167        }
168    }
169
170    /// Get the runtime context if initialized.
171    ///
172    /// Returns `None` if `initialize()` has not been called yet.
173    pub async fn context(&self) -> Option<ReactionRuntimeContext> {
174        self.context.read().await.clone()
175    }
176
177    /// Get the state store if configured.
178    ///
179    /// Returns `None` if no state store was provided in the context.
180    pub async fn state_store(&self) -> Option<Arc<dyn StateStoreProvider>> {
181        self.state_store.read().await.clone()
182    }
183
184    /// Get whether this reaction should auto-start
185    pub fn get_auto_start(&self) -> bool {
186        self.auto_start
187    }
188
189    /// Get the status channel Arc for internal use by spawned tasks
190    ///
191    /// This returns the internal status_tx wrapped in Arc<RwLock<Option<...>>>
192    /// which allows background tasks to send component status events.
193    ///
194    /// Returns a clone of the Arc that can be moved into spawned tasks.
195    pub fn status_tx(&self) -> Arc<RwLock<Option<ComponentEventSender>>> {
196        self.status_tx.clone()
197    }
198
199    /// Clone the ReactionBase with shared Arc references
200    ///
201    /// This creates a new ReactionBase that shares the same underlying
202    /// data through Arc references. Useful for passing to spawned tasks.
203    pub fn clone_shared(&self) -> Self {
204        Self {
205            id: self.id.clone(),
206            queries: self.queries.clone(),
207            auto_start: self.auto_start,
208            status: self.status.clone(),
209            context: self.context.clone(),
210            status_tx: self.status_tx.clone(),
211            query_provider: self.query_provider.clone(),
212            state_store: self.state_store.clone(),
213            priority_queue: self.priority_queue.clone(),
214            subscription_tasks: self.subscription_tasks.clone(),
215            processing_task: self.processing_task.clone(),
216            shutdown_tx: self.shutdown_tx.clone(),
217        }
218    }
219
220    /// Create a shutdown channel and store the sender
221    ///
222    /// Returns the receiver which should be passed to the processing task.
223    /// The sender is stored internally and will be triggered by `stop_common()`.
224    ///
225    /// This should be called before spawning the processing task.
226    pub async fn create_shutdown_channel(&self) -> tokio::sync::oneshot::Receiver<()> {
227        let (tx, rx) = tokio::sync::oneshot::channel();
228        *self.shutdown_tx.write().await = Some(tx);
229        rx
230    }
231
232    /// Get the reaction ID
233    pub fn get_id(&self) -> &str {
234        &self.id
235    }
236
237    /// Get the query IDs
238    pub fn get_queries(&self) -> &[String] {
239        &self.queries
240    }
241
242    /// Get current status
243    pub async fn get_status(&self) -> ComponentStatus {
244        self.status.read().await.clone()
245    }
246
247    /// Send a component lifecycle event
248    ///
249    /// If the event channel has not been injected yet, this method silently
250    /// succeeds without sending anything. This allows reactions to be used
251    /// in a standalone fashion without DrasiLib if needed.
252    pub async fn send_component_event(
253        &self,
254        status: ComponentStatus,
255        message: Option<String>,
256    ) -> Result<()> {
257        let event = ComponentEvent {
258            component_id: self.id.clone(),
259            component_type: ComponentType::Reaction,
260            status,
261            timestamp: chrono::Utc::now(),
262            message,
263        };
264
265        if let Some(ref tx) = *self.status_tx.read().await {
266            if let Err(e) = tx.send(event).await {
267                error!("Failed to send component event: {e}");
268            }
269        }
270        // If status_tx is None, silently skip - initialization happens before start()
271        Ok(())
272    }
273
274    /// Transition to a new status and send event
275    pub async fn set_status_with_event(
276        &self,
277        status: ComponentStatus,
278        message: Option<String>,
279    ) -> Result<()> {
280        *self.status.write().await = status.clone();
281        self.send_component_event(status, message).await
282    }
283
284    /// Subscribe to all configured queries and spawn forwarder tasks
285    ///
286    /// This method handles the common pattern of:
287    /// 1. Getting query instances via the injected QueryProvider
288    /// 2. Subscribing to each configured query
289    /// 3. Spawning forwarder tasks to enqueue results to priority queue
290    ///
291    /// # Prerequisites
292    /// * `inject_query_provider()` must have been called (done automatically by DrasiLib)
293    ///
294    /// # Returns
295    /// * `Ok(())` if all subscriptions succeeded
296    /// * `Err(...)` if QueryProvider not injected or any subscription failed
297    pub async fn subscribe_to_queries(&self) -> Result<()> {
298        // Get the injected query provider (clone the Arc to release the lock)
299        let query_provider = {
300            let qp_guard = self.query_provider.read().await;
301            qp_guard.as_ref().cloned().ok_or_else(|| {
302                anyhow::anyhow!(
303                    "QueryProvider not injected - was reaction '{}' added to DrasiLib?",
304                    self.id
305                )
306            })?
307        };
308
309        // Subscribe to all configured queries and spawn forwarder tasks
310        for query_id in &self.queries {
311            // Get the query instance via QueryProvider
312            let query = query_provider.get_query_instance(query_id).await?;
313
314            // Subscribe to the query
315            let subscription_response = query
316                .subscribe(self.id.clone())
317                .await
318                .map_err(|e| anyhow::anyhow!(e))?;
319            let mut receiver = subscription_response.receiver;
320
321            // Clone necessary data for the forwarder task
322            let priority_queue = self.priority_queue.clone();
323            let query_id_clone = query_id.clone();
324            let reaction_id = self.id.clone();
325
326            // Get query dispatch mode to determine enqueue strategy
327            let query_config = query.get_config();
328            let dispatch_mode = query_config
329                .dispatch_mode
330                .unwrap_or(crate::channels::DispatchMode::Channel);
331            let use_blocking_enqueue =
332                matches!(dispatch_mode, crate::channels::DispatchMode::Channel);
333
334            // Spawn forwarder task to read from receiver and enqueue to priority queue
335            let forwarder_task = tokio::spawn(async move {
336                debug!(
337                    "[{reaction_id}] Started result forwarder for query '{query_id_clone}' (dispatch_mode: {dispatch_mode:?}, blocking_enqueue: {use_blocking_enqueue})"
338                );
339
340                loop {
341                    match receiver.recv().await {
342                        Ok(query_result) => {
343                            // Use appropriate enqueue method based on dispatch mode
344                            if use_blocking_enqueue {
345                                // Channel mode: Use blocking enqueue to prevent message loss
346                                // This creates backpressure when the priority queue is full
347                                priority_queue.enqueue_wait(query_result).await;
348                            } else {
349                                // Broadcast mode: Use non-blocking enqueue to prevent deadlock
350                                // Messages may be dropped when priority queue is full
351                                if !priority_queue.enqueue(query_result).await {
352                                    warn!(
353                                        "[{reaction_id}] Failed to enqueue result from query '{query_id_clone}' - priority queue at capacity (broadcast mode)"
354                                    );
355                                }
356                            }
357                        }
358                        Err(e) => {
359                            // Check if it's a lag error or closed channel
360                            let error_str = e.to_string();
361                            if error_str.contains("lagged") {
362                                warn!(
363                                    "[{reaction_id}] Receiver lagged for query '{query_id_clone}': {error_str}"
364                                );
365                                continue;
366                            } else {
367                                info!(
368                                    "[{reaction_id}] Receiver error for query '{query_id_clone}': {error_str}"
369                                );
370                                break;
371                            }
372                        }
373                    }
374                }
375            });
376
377            // Store the forwarder task handle
378            self.subscription_tasks.write().await.push(forwarder_task);
379        }
380
381        Ok(())
382    }
383
384    /// Perform common cleanup operations
385    ///
386    /// This method handles:
387    /// 1. Sending shutdown signal to processing task (for graceful termination)
388    /// 2. Aborting all subscription forwarder tasks
389    /// 3. Waiting for or aborting the processing task
390    /// 4. Draining the priority queue
391    pub async fn stop_common(&self) -> Result<()> {
392        info!("Stopping reaction: {}", self.id);
393
394        // Send shutdown signal to processing task (if it's using tokio::select!)
395        if let Some(tx) = self.shutdown_tx.write().await.take() {
396            let _ = tx.send(());
397        }
398
399        // Abort all subscription forwarder tasks
400        let mut subscription_tasks = self.subscription_tasks.write().await;
401        for task in subscription_tasks.drain(..) {
402            task.abort();
403        }
404        drop(subscription_tasks);
405
406        // Wait for the processing task to complete (with timeout), or abort it
407        let mut processing_task = self.processing_task.write().await;
408        if let Some(task) = processing_task.take() {
409            // Give the task a short time to respond to the shutdown signal
410            match tokio::time::timeout(std::time::Duration::from_secs(2), task).await {
411                Ok(Ok(())) => {
412                    debug!("[{}] Processing task completed gracefully", self.id);
413                }
414                Ok(Err(e)) => {
415                    // Task was aborted or panicked
416                    debug!("[{}] Processing task ended: {}", self.id, e);
417                }
418                Err(_) => {
419                    // Timeout - task didn't respond to shutdown signal
420                    // This shouldn't happen if the task is using tokio::select! correctly
421                    warn!(
422                        "[{}] Processing task did not respond to shutdown signal within timeout",
423                        self.id
424                    );
425                }
426            }
427        }
428        drop(processing_task);
429
430        // Drain the priority queue
431        let drained_events = self.priority_queue.drain().await;
432        if !drained_events.is_empty() {
433            info!(
434                "[{}] Drained {} pending events from priority queue",
435                self.id,
436                drained_events.len()
437            );
438        }
439
440        Ok(())
441    }
442
443    /// Set the processing task handle
444    pub async fn set_processing_task(&self, task: tokio::task::JoinHandle<()>) {
445        *self.processing_task.write().await = Some(task);
446    }
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;
    use tokio::sync::mpsc;

    /// Builds an uninitialized `ReactionBase` named "test-reaction" with no queries.
    fn bare_base() -> ReactionBase {
        ReactionBase::new(ReactionBaseParams::new("test-reaction", vec![]))
    }

    #[tokio::test]
    async fn test_reaction_base_creation() {
        let params = ReactionBaseParams::new("test-reaction", vec!["query1".to_string()])
            .with_priority_queue_capacity(5000)
            .with_auto_start(false);

        let base = ReactionBase::new(params);
        assert_eq!(base.id, "test-reaction");
        assert_eq!(base.get_id(), "test-reaction");
        assert_eq!(base.get_queries().to_vec(), vec!["query1".to_string()]);
        assert!(!base.get_auto_start());
        assert_eq!(base.get_status().await, ComponentStatus::Stopped);
    }

    #[tokio::test]
    async fn test_status_transitions() {
        use crate::context::ReactionRuntimeContext;
        use crate::queries::Query;

        // Mock QueryProvider for testing
        struct MockQueryProvider;

        #[async_trait::async_trait]
        impl crate::reactions::QueryProvider for MockQueryProvider {
            async fn get_query_instance(
                &self,
                _id: &str,
            ) -> anyhow::Result<std::sync::Arc<dyn Query>> {
                Err(anyhow::anyhow!("MockQueryProvider: query not found"))
            }
        }

        let (status_tx, mut event_rx) = mpsc::channel(100);
        let base = bare_base();

        // Create context and initialize
        let context = ReactionRuntimeContext::new(
            "test-reaction",
            status_tx,
            None,
            std::sync::Arc::new(MockQueryProvider),
        );
        base.initialize(context).await;

        // Initialization must be observable through the accessors
        assert!(base.context().await.is_some());
        assert!(base.state_store().await.is_none());

        // Test status transition
        base.set_status_with_event(ComponentStatus::Starting, Some("Starting test".to_string()))
            .await
            .unwrap();

        assert_eq!(base.get_status().await, ComponentStatus::Starting);

        // Check event was sent
        let event = event_rx.try_recv().unwrap();
        assert_eq!(event.component_id, "test-reaction");
        assert_eq!(event.status, ComponentStatus::Starting);
        assert_eq!(event.message, Some("Starting test".to_string()));
    }

    #[tokio::test]
    async fn test_priority_queue_operations() {
        let params =
            ReactionBaseParams::new("test-reaction", vec![]).with_priority_queue_capacity(10);

        let base = ReactionBase::new(params);

        // Create a test query result
        let query_result = QueryResult::new(
            "test-query".to_string(),
            chrono::Utc::now(),
            vec![],
            Default::default(),
        );

        // Enqueue result
        let enqueued = base.priority_queue.enqueue(Arc::new(query_result)).await;
        assert!(enqueued);

        // Drain queue
        let drained = base.priority_queue.drain().await;
        assert_eq!(drained.len(), 1);
    }

    #[tokio::test]
    async fn test_event_without_initialization() {
        // Test that send_component_event works even without context initialization
        let base = bare_base();

        // This should succeed without panicking (silently does nothing when status_tx is None)
        base.send_component_event(ComponentStatus::Starting, None)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_clone_shared_shares_state() {
        let base = bare_base();
        let shared = base.clone_shared();

        // A status change through one handle is visible through the other
        base.set_status_with_event(ComponentStatus::Starting, None)
            .await
            .unwrap();
        assert_eq!(shared.get_status().await, ComponentStatus::Starting);

        // The shutdown sender slot is shared as well
        let _rx = shared.create_shutdown_channel().await;
        assert!(base.shutdown_tx.read().await.is_some());
    }

    // =============================================================================
    // Shutdown Channel Tests
    // =============================================================================

    #[tokio::test]
    async fn test_create_shutdown_channel() {
        let base = bare_base();

        // Initially no shutdown_tx
        assert!(base.shutdown_tx.read().await.is_none());

        // Create channel
        let rx = base.create_shutdown_channel().await;

        // Verify tx is stored
        assert!(base.shutdown_tx.read().await.is_some());

        // Verify receiver is valid (dropping it should not panic)
        drop(rx);
    }

    #[tokio::test]
    async fn test_shutdown_channel_signal() {
        let base = bare_base();

        let mut rx = base.create_shutdown_channel().await;

        // Send signal; a missing sender is a test failure, not something to skip over
        let tx = base
            .shutdown_tx
            .write()
            .await
            .take()
            .expect("create_shutdown_channel should have stored the sender");
        tx.send(()).unwrap();

        // Verify signal received
        assert!(rx.try_recv().is_ok());
    }

    #[tokio::test]
    async fn test_shutdown_channel_replaced_on_second_create() {
        let base = bare_base();

        // Create first channel
        let mut rx1 = base.create_shutdown_channel().await;

        // Create second channel (should replace the first)
        let mut rx2 = base.create_shutdown_channel().await;

        // Replacing the sender drops the first one, so the first receiver is closed
        assert!(rx1.try_recv().is_err());

        // Send signal - should go to second channel
        let tx = base
            .shutdown_tx
            .write()
            .await
            .take()
            .expect("second create_shutdown_channel should have stored a sender");
        tx.send(()).unwrap();

        // Second receiver should get the signal
        assert!(rx2.try_recv().is_ok());
    }

    #[tokio::test]
    async fn test_stop_common_sends_shutdown_signal() {
        let base = bare_base();

        let rx = base.create_shutdown_channel().await;

        // Spawn a task that waits for shutdown
        let shutdown_received = Arc::new(AtomicBool::new(false));
        let shutdown_flag = shutdown_received.clone();

        let task = tokio::spawn(async move {
            // Only an explicit send counts: a merely dropped sender resolves to `Err`
            if rx.await.is_ok() {
                shutdown_flag.store(true, Ordering::SeqCst);
            }
        });

        base.set_processing_task(task).await;

        // Call stop_common - should send shutdown signal and wait for the task,
        // so no extra sleep is needed before checking the flag
        base.stop_common()
            .await
            .expect("stop_common should succeed");

        assert!(
            shutdown_received.load(Ordering::SeqCst),
            "Processing task should have received shutdown signal"
        );
        assert!(base.shutdown_tx.read().await.is_none());
        assert!(base.processing_task.read().await.is_none());
    }

    #[tokio::test]
    async fn test_graceful_shutdown_timing() {
        let base = bare_base();

        let rx = base.create_shutdown_channel().await;

        // Spawn task that uses select! pattern like real reactions
        let task = tokio::spawn(async move {
            let mut shutdown_rx = rx;
            loop {
                tokio::select! {
                    biased;
                    _ = &mut shutdown_rx => {
                        break;
                    }
                    _ = tokio::time::sleep(Duration::from_secs(10)) => {
                        // Simulates waiting on priority_queue.dequeue()
                    }
                }
            }
        });

        base.set_processing_task(task).await;

        // Measure shutdown time
        let start = std::time::Instant::now();
        base.stop_common()
            .await
            .expect("stop_common should succeed");
        let elapsed = start.elapsed();

        // Should complete quickly (< 500ms), not hit 2s timeout
        assert!(
            elapsed < Duration::from_millis(500),
            "Shutdown took {elapsed:?}, expected < 500ms. Task may not be responding to shutdown signal."
        );
    }

    #[tokio::test]
    async fn test_stop_common_without_shutdown_channel() {
        // Test that stop_common works even if no shutdown channel was created
        let base = bare_base();

        // Don't create shutdown channel - just spawn a short-lived task
        let task = tokio::spawn(async {
            tokio::time::sleep(Duration::from_millis(10)).await;
        });

        base.set_processing_task(task).await;

        // stop_common should still work
        let result = base.stop_common().await;
        assert!(result.is_ok());

        // The handle has been consumed by stop_common
        assert!(base.processing_task.read().await.is_none());
    }
}