Skip to main content

drasi_lib/reactions/common/
base.rs

1// Copyright 2025 The Drasi Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Base implementation for common reaction functionality.
16//!
17//! This module provides `ReactionBase` which encapsulates common patterns
18//! used across all reaction implementations:
19//! - Query subscription management
20//! - Priority queue handling
21//! - Task lifecycle management
22//! - Component status tracking
23//! - Event reporting
24//!
25//! # Plugin Architecture
26//!
27//! ReactionBase is designed to be used by reaction plugins. Each plugin:
28//! 1. Defines its own typed configuration struct
29//! 2. Creates a ReactionBase with ReactionBaseParams
//! 3. Implements the `Reaction` trait by delegating to `ReactionBase` methods
31
32use anyhow::Result;
33use log::{debug, error, info, warn};
34use std::sync::Arc;
35use tokio::sync::RwLock;
36use tracing::Instrument;
37
38use crate::channels::priority_queue::PriorityQueue;
39use crate::channels::{
40    ComponentEvent, ComponentEventSender, ComponentStatus, ComponentType, QueryResult,
41};
42use crate::context::ReactionRuntimeContext;
43use crate::reactions::QueryProvider;
44use crate::state_store::StateStoreProvider;
45
/// Parameters for creating a ReactionBase instance.
///
/// This struct contains only the information that ReactionBase needs to function.
/// Plugin-specific configuration should remain in the plugin crate.
///
/// # Example
///
/// ```ignore
/// use drasi_lib::reactions::common::base::{ReactionBase, ReactionBaseParams};
///
/// let params = ReactionBaseParams::new("my-reaction", vec!["query1".to_string()])
///     .with_priority_queue_capacity(5000)
///     .with_auto_start(true);
///
/// let base = ReactionBase::new(params);
/// ```
#[derive(Debug, Clone)]
pub struct ReactionBaseParams {
    /// Unique identifier for the reaction
    pub id: String,
    /// List of query IDs this reaction subscribes to
    pub queries: Vec<String>,
    /// Priority queue capacity; `None` means `ReactionBase::new` uses its
    /// default of 10000
    pub priority_queue_capacity: Option<usize>,
    /// Whether this reaction should auto-start - defaults to true
    pub auto_start: bool,
}

impl ReactionBaseParams {
    /// Create new params with the given ID and query list.
    ///
    /// Everything else uses its default: no explicit priority queue capacity
    /// (so `ReactionBase::new` falls back to 10000) and `auto_start = true`.
    #[must_use]
    pub fn new(id: impl Into<String>, queries: Vec<String>) -> Self {
        Self {
            id: id.into(),
            queries,
            priority_queue_capacity: None,
            // Reactions auto-start by default, matching the default for queries.
            auto_start: true,
        }
    }

    /// Set the priority queue capacity.
    ///
    /// Consumes and returns the params, so the result must be used.
    #[must_use]
    pub fn with_priority_queue_capacity(mut self, capacity: usize) -> Self {
        self.priority_queue_capacity = Some(capacity);
        self
    }

    /// Set whether this reaction should auto-start.
    ///
    /// Consumes and returns the params, so the result must be used.
    #[must_use]
    pub fn with_auto_start(mut self, auto_start: bool) -> Self {
        self.auto_start = auto_start;
        self
    }
}
97
/// Base implementation for common reaction functionality.
///
/// Holds the reaction's configuration (ID, queries, auto-start flag), its
/// current status, the services handed over by `initialize()`, the priority
/// queue fed by the query forwarder tasks, and the handles needed to stop
/// those tasks.
///
/// The mutable state lives behind `Arc<RwLock<..>>`, so `clone_shared()`
/// yields a handle that observes and mutates the same state.
pub struct ReactionBase {
    /// Reaction identifier
    pub id: String,
    /// List of query IDs to subscribe to
    pub queries: Vec<String>,
    /// Whether this reaction should auto-start
    pub auto_start: bool,
    /// Current component status (starts as `ComponentStatus::Stopped`)
    pub status: Arc<RwLock<ComponentStatus>>,
    /// Runtime context (set by initialize())
    context: Arc<RwLock<Option<ReactionRuntimeContext>>>,
    /// Channel for sending component status events (extracted from context for convenience)
    status_tx: Arc<RwLock<Option<ComponentEventSender>>>,
    /// Query provider for accessing queries (extracted from context)
    query_provider: Arc<RwLock<Option<Arc<dyn QueryProvider>>>>,
    /// State store provider (extracted from context for convenience)
    state_store: Arc<RwLock<Option<Arc<dyn StateStoreProvider>>>>,
    /// Priority queue for timestamp-ordered result processing; forwarder tasks
    /// enqueue into clones of it (NOTE(review): relies on clones sharing
    /// storage — confirm in `PriorityQueue`)
    pub priority_queue: PriorityQueue<QueryResult>,
    /// Handles to subscription forwarder tasks (aborted by `stop_common()`)
    pub subscription_tasks: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
    /// Handle to the main processing task (set via `set_processing_task()`)
    pub processing_task: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
    /// Sender for shutdown signal to processing task (created by
    /// `create_shutdown_channel()`, consumed by `stop_common()`)
    pub shutdown_tx: Arc<RwLock<Option<tokio::sync::oneshot::Sender<()>>>>,
}
125
126impl ReactionBase {
127    /// Create a new ReactionBase with the given parameters
128    ///
129    /// Dependencies (event channel, query subscriber, state store) are not required during
130    /// construction - they will be provided via `initialize()` when the reaction is added to DrasiLib.
131    pub fn new(params: ReactionBaseParams) -> Self {
132        Self {
133            priority_queue: PriorityQueue::new(params.priority_queue_capacity.unwrap_or(10000)),
134            id: params.id,
135            queries: params.queries,
136            auto_start: params.auto_start,
137            status: Arc::new(RwLock::new(ComponentStatus::Stopped)),
138            context: Arc::new(RwLock::new(None)), // Set by initialize()
139            status_tx: Arc::new(RwLock::new(None)), // Extracted from context
140            query_provider: Arc::new(RwLock::new(None)), // Extracted from context
141            state_store: Arc::new(RwLock::new(None)), // Extracted from context
142            subscription_tasks: Arc::new(RwLock::new(Vec::new())),
143            processing_task: Arc::new(RwLock::new(None)),
144            shutdown_tx: Arc::new(RwLock::new(None)),
145        }
146    }
147
148    /// Initialize the reaction with runtime context.
149    ///
150    /// This method is called automatically by DrasiLib's `add_reaction()` method.
151    /// Plugin developers do not need to call this directly.
152    ///
153    /// The context provides access to:
154    /// - `reaction_id`: The reaction's unique identifier
155    /// - `status_tx`: Channel for reporting component status events
156    /// - `state_store`: Optional persistent state storage
157    /// - `query_provider`: Access to query instances for subscription
158    pub async fn initialize(&self, context: ReactionRuntimeContext) {
159        // Store context for later use
160        *self.context.write().await = Some(context.clone());
161
162        // Extract services for convenience
163        *self.status_tx.write().await = Some(context.status_tx.clone());
164        *self.query_provider.write().await = Some(context.query_provider.clone());
165
166        if let Some(state_store) = context.state_store.as_ref() {
167            *self.state_store.write().await = Some(state_store.clone());
168        }
169    }
170
171    /// Get the runtime context if initialized.
172    ///
173    /// Returns `None` if `initialize()` has not been called yet.
174    pub async fn context(&self) -> Option<ReactionRuntimeContext> {
175        self.context.read().await.clone()
176    }
177
178    /// Get the state store if configured.
179    ///
180    /// Returns `None` if no state store was provided in the context.
181    pub async fn state_store(&self) -> Option<Arc<dyn StateStoreProvider>> {
182        self.state_store.read().await.clone()
183    }
184
185    /// Get whether this reaction should auto-start
186    pub fn get_auto_start(&self) -> bool {
187        self.auto_start
188    }
189
190    /// Get the status channel Arc for internal use by spawned tasks
191    ///
192    /// This returns the internal status_tx wrapped in Arc<RwLock<Option<...>>>
193    /// which allows background tasks to send component status events.
194    ///
195    /// Returns a clone of the Arc that can be moved into spawned tasks.
196    pub fn status_tx(&self) -> Arc<RwLock<Option<ComponentEventSender>>> {
197        self.status_tx.clone()
198    }
199
200    /// Clone the ReactionBase with shared Arc references
201    ///
202    /// This creates a new ReactionBase that shares the same underlying
203    /// data through Arc references. Useful for passing to spawned tasks.
204    pub fn clone_shared(&self) -> Self {
205        Self {
206            id: self.id.clone(),
207            queries: self.queries.clone(),
208            auto_start: self.auto_start,
209            status: self.status.clone(),
210            context: self.context.clone(),
211            status_tx: self.status_tx.clone(),
212            query_provider: self.query_provider.clone(),
213            state_store: self.state_store.clone(),
214            priority_queue: self.priority_queue.clone(),
215            subscription_tasks: self.subscription_tasks.clone(),
216            processing_task: self.processing_task.clone(),
217            shutdown_tx: self.shutdown_tx.clone(),
218        }
219    }
220
221    /// Create a shutdown channel and store the sender
222    ///
223    /// Returns the receiver which should be passed to the processing task.
224    /// The sender is stored internally and will be triggered by `stop_common()`.
225    ///
226    /// This should be called before spawning the processing task.
227    pub async fn create_shutdown_channel(&self) -> tokio::sync::oneshot::Receiver<()> {
228        let (tx, rx) = tokio::sync::oneshot::channel();
229        *self.shutdown_tx.write().await = Some(tx);
230        rx
231    }
232
233    /// Get the reaction ID
234    pub fn get_id(&self) -> &str {
235        &self.id
236    }
237
238    /// Get the query IDs
239    pub fn get_queries(&self) -> &[String] {
240        &self.queries
241    }
242
243    /// Get current status
244    pub async fn get_status(&self) -> ComponentStatus {
245        self.status.read().await.clone()
246    }
247
248    /// Send a component lifecycle event
249    ///
250    /// If the event channel has not been injected yet, this method silently
251    /// succeeds without sending anything. This allows reactions to be used
252    /// in a standalone fashion without DrasiLib if needed.
253    pub async fn send_component_event(
254        &self,
255        status: ComponentStatus,
256        message: Option<String>,
257    ) -> Result<()> {
258        let event = ComponentEvent {
259            component_id: self.id.clone(),
260            component_type: ComponentType::Reaction,
261            status,
262            timestamp: chrono::Utc::now(),
263            message,
264        };
265
266        if let Some(ref tx) = *self.status_tx.read().await {
267            if let Err(e) = tx.send(event).await {
268                error!("Failed to send component event: {e}");
269            }
270        }
271        // If status_tx is None, silently skip - initialization happens before start()
272        Ok(())
273    }
274
275    /// Transition to a new status and send event
276    pub async fn set_status_with_event(
277        &self,
278        status: ComponentStatus,
279        message: Option<String>,
280    ) -> Result<()> {
281        *self.status.write().await = status.clone();
282        self.send_component_event(status, message).await
283    }
284
285    /// Subscribe to all configured queries and spawn forwarder tasks
286    ///
287    /// This method handles the common pattern of:
288    /// 1. Getting query instances via the injected QueryProvider
289    /// 2. Subscribing to each configured query
290    /// 3. Spawning forwarder tasks to enqueue results to priority queue
291    ///
292    /// # Prerequisites
293    /// * `inject_query_provider()` must have been called (done automatically by DrasiLib)
294    ///
295    /// # Returns
296    /// * `Ok(())` if all subscriptions succeeded
297    /// * `Err(...)` if QueryProvider not injected or any subscription failed
298    pub async fn subscribe_to_queries(&self) -> Result<()> {
299        // Get the injected query provider (clone the Arc to release the lock)
300        let query_provider = {
301            let qp_guard = self.query_provider.read().await;
302            qp_guard.as_ref().cloned().ok_or_else(|| {
303                anyhow::anyhow!(
304                    "QueryProvider not injected - was reaction '{}' added to DrasiLib?",
305                    self.id
306                )
307            })?
308        };
309
310        // Subscribe to all configured queries and spawn forwarder tasks
311        for query_id in &self.queries {
312            // Get the query instance via QueryProvider
313            let query = query_provider.get_query_instance(query_id).await?;
314
315            // Subscribe to the query
316            let subscription_response = query
317                .subscribe(self.id.clone())
318                .await
319                .map_err(|e| anyhow::anyhow!(e))?;
320            let mut receiver = subscription_response.receiver;
321
322            // Clone necessary data for the forwarder task
323            let priority_queue = self.priority_queue.clone();
324            let query_id_clone = query_id.clone();
325            let reaction_id = self.id.clone();
326
327            // Get instance_id from context for log routing isolation
328            let instance_id = self
329                .context()
330                .await
331                .map(|c| c.instance_id.clone())
332                .unwrap_or_default();
333
334            // Get query dispatch mode to determine enqueue strategy
335            let query_config = query.get_config();
336            let dispatch_mode = query_config
337                .dispatch_mode
338                .unwrap_or(crate::channels::DispatchMode::Channel);
339            let use_blocking_enqueue =
340                matches!(dispatch_mode, crate::channels::DispatchMode::Channel);
341
342            // Spawn forwarder task with tracing span for proper log routing
343            let span = tracing::info_span!(
344                "reaction_forwarder",
345                instance_id = %instance_id,
346                component_id = %reaction_id,
347                component_type = "reaction"
348            );
349            let forwarder_task = tokio::spawn(
350                async move {
351                    debug!(
352                        "[{reaction_id}] Started result forwarder for query '{query_id_clone}' (dispatch_mode: {dispatch_mode:?}, blocking_enqueue: {use_blocking_enqueue})"
353                    );
354
355                    loop {
356                        match receiver.recv().await {
357                            Ok(query_result) => {
358                                // Use appropriate enqueue method based on dispatch mode
359                                if use_blocking_enqueue {
360                                    // Channel mode: Use blocking enqueue to prevent message loss
361                                    // This creates backpressure when the priority queue is full
362                                    priority_queue.enqueue_wait(query_result).await;
363                                } else {
364                                    // Broadcast mode: Use non-blocking enqueue to prevent deadlock
365                                    // Messages may be dropped when priority queue is full
366                                    if !priority_queue.enqueue(query_result).await {
367                                        warn!(
368                                            "[{reaction_id}] Failed to enqueue result from query '{query_id_clone}' - priority queue at capacity (broadcast mode)"
369                                        );
370                                    }
371                                }
372                            }
373                            Err(e) => {
374                                // Check if it's a lag error or closed channel
375                                let error_str = e.to_string();
376                                if error_str.contains("lagged") {
377                                    warn!(
378                                        "[{reaction_id}] Receiver lagged for query '{query_id_clone}': {error_str}"
379                                    );
380                                    continue;
381                                } else {
382                                    info!(
383                                        "[{reaction_id}] Receiver error for query '{query_id_clone}': {error_str}"
384                                    );
385                                    break;
386                                }
387                            }
388                        }
389                    }
390                }
391                .instrument(span),
392            );
393
394            // Store the forwarder task handle
395            self.subscription_tasks.write().await.push(forwarder_task);
396        }
397
398        Ok(())
399    }
400
401    /// Perform common cleanup operations
402    ///
403    /// This method handles:
404    /// 1. Sending shutdown signal to processing task (for graceful termination)
405    /// 2. Aborting all subscription forwarder tasks
406    /// 3. Waiting for or aborting the processing task
407    /// 4. Draining the priority queue
408    pub async fn stop_common(&self) -> Result<()> {
409        info!("Stopping reaction: {}", self.id);
410
411        // Send shutdown signal to processing task (if it's using tokio::select!)
412        if let Some(tx) = self.shutdown_tx.write().await.take() {
413            let _ = tx.send(());
414        }
415
416        // Abort all subscription forwarder tasks
417        let mut subscription_tasks = self.subscription_tasks.write().await;
418        for task in subscription_tasks.drain(..) {
419            task.abort();
420        }
421        drop(subscription_tasks);
422
423        // Wait for the processing task to complete (with timeout), or abort it
424        let mut processing_task = self.processing_task.write().await;
425        if let Some(task) = processing_task.take() {
426            // Give the task a short time to respond to the shutdown signal
427            match tokio::time::timeout(std::time::Duration::from_secs(2), task).await {
428                Ok(Ok(())) => {
429                    debug!("[{}] Processing task completed gracefully", self.id);
430                }
431                Ok(Err(e)) => {
432                    // Task was aborted or panicked
433                    debug!("[{}] Processing task ended: {}", self.id, e);
434                }
435                Err(_) => {
436                    // Timeout - task didn't respond to shutdown signal
437                    // This shouldn't happen if the task is using tokio::select! correctly
438                    warn!(
439                        "[{}] Processing task did not respond to shutdown signal within timeout",
440                        self.id
441                    );
442                }
443            }
444        }
445        drop(processing_task);
446
447        // Drain the priority queue
448        let drained_events = self.priority_queue.drain().await;
449        if !drained_events.is_empty() {
450            info!(
451                "[{}] Drained {} pending events from priority queue",
452                self.id,
453                drained_events.len()
454            );
455        }
456
457        Ok(())
458    }
459
460    /// Set the processing task handle
461    pub async fn set_processing_task(&self, task: tokio::task::JoinHandle<()>) {
462        *self.processing_task.write().await = Some(task);
463    }
464}
465
// Unit tests for `ReactionBase`: construction, status events, the priority
// queue, and the shutdown-channel / `stop_common()` lifecycle.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;
    use tokio::sync::mpsc;

    // A freshly constructed base carries the configured ID and starts Stopped.
    #[tokio::test]
    async fn test_reaction_base_creation() {
        let params = ReactionBaseParams::new("test-reaction", vec!["query1".to_string()])
            .with_priority_queue_capacity(5000);

        let base = ReactionBase::new(params);
        assert_eq!(base.id, "test-reaction");
        assert_eq!(base.get_status().await, ComponentStatus::Stopped);
    }

    // After initialize(), set_status_with_event() updates the stored status
    // and emits a matching event on the context's status channel.
    #[tokio::test]
    async fn test_status_transitions() {
        use crate::context::ReactionRuntimeContext;
        use crate::queries::Query;

        // Mock QueryProvider for testing
        // (every lookup fails; this test never subscribes to a query)
        struct MockQueryProvider;

        #[async_trait::async_trait]
        impl crate::reactions::QueryProvider for MockQueryProvider {
            async fn get_query_instance(
                &self,
                _id: &str,
            ) -> anyhow::Result<std::sync::Arc<dyn Query>> {
                Err(anyhow::anyhow!("MockQueryProvider: query not found"))
            }
        }

        let (status_tx, mut event_rx) = mpsc::channel(100);
        let params = ReactionBaseParams::new("test-reaction", vec![]);

        let base = ReactionBase::new(params);

        // Create context and initialize
        let context = ReactionRuntimeContext::new(
            "test-instance",
            "test-reaction",
            status_tx,
            None,
            std::sync::Arc::new(MockQueryProvider),
        );
        base.initialize(context).await;

        // Test status transition
        base.set_status_with_event(ComponentStatus::Starting, Some("Starting test".to_string()))
            .await
            .unwrap();

        assert_eq!(base.get_status().await, ComponentStatus::Starting);

        // Check event was sent
        // (try_recv is enough: the send was awaited inside set_status_with_event)
        let event = event_rx.try_recv().unwrap();
        assert_eq!(event.status, ComponentStatus::Starting);
        assert_eq!(event.message, Some("Starting test".to_string()));
    }

    // A single enqueued result is returned by drain().
    #[tokio::test]
    async fn test_priority_queue_operations() {
        let params =
            ReactionBaseParams::new("test-reaction", vec![]).with_priority_queue_capacity(10);

        let base = ReactionBase::new(params);

        // Create a test query result
        let query_result = QueryResult::new(
            "test-query".to_string(),
            chrono::Utc::now(),
            vec![],
            Default::default(),
        );

        // Enqueue result
        let enqueued = base.priority_queue.enqueue(Arc::new(query_result)).await;
        assert!(enqueued);

        // Drain queue
        let drained = base.priority_queue.drain().await;
        assert_eq!(drained.len(), 1);
    }

    #[tokio::test]
    async fn test_event_without_initialization() {
        // Test that send_component_event works even without context initialization
        let params = ReactionBaseParams::new("test-reaction", vec![]);

        let base = ReactionBase::new(params);

        // This should succeed without panicking (silently does nothing when status_tx is None)
        base.send_component_event(ComponentStatus::Starting, None)
            .await
            .unwrap();
    }

    // =============================================================================
    // Shutdown Channel Tests
    // =============================================================================

    // create_shutdown_channel() stores a sender where there was none.
    #[tokio::test]
    async fn test_create_shutdown_channel() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        // Initially no shutdown_tx
        assert!(base.shutdown_tx.read().await.is_none());

        // Create channel
        let rx = base.create_shutdown_channel().await;

        // Verify tx is stored
        assert!(base.shutdown_tx.read().await.is_some());

        // Verify receiver is valid (dropping it should not panic)
        drop(rx);
    }

    // A signal sent through the stored sender reaches the returned receiver.
    #[tokio::test]
    async fn test_shutdown_channel_signal() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        let mut rx = base.create_shutdown_channel().await;

        // Send signal
        if let Some(tx) = base.shutdown_tx.write().await.take() {
            tx.send(()).unwrap();
        }

        // Verify signal received
        let result = rx.try_recv();
        assert!(result.is_ok());
    }

    // A second create_shutdown_channel() call replaces the stored sender.
    #[tokio::test]
    async fn test_shutdown_channel_replaced_on_second_create() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        // Create first channel
        let _rx1 = base.create_shutdown_channel().await;

        // Create second channel (should replace the first)
        let mut rx2 = base.create_shutdown_channel().await;

        // Send signal - should go to second channel
        if let Some(tx) = base.shutdown_tx.write().await.take() {
            tx.send(()).unwrap();
        }

        // Second receiver should get the signal
        let result = rx2.try_recv();
        assert!(result.is_ok());
    }

    // stop_common() delivers the shutdown signal to a task waiting on it.
    #[tokio::test]
    async fn test_stop_common_sends_shutdown_signal() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        let mut rx = base.create_shutdown_channel().await;

        // Spawn a task that waits for shutdown
        let shutdown_received = Arc::new(AtomicBool::new(false));
        let shutdown_flag = shutdown_received.clone();

        let task = tokio::spawn(async move {
            tokio::select! {
                _ = &mut rx => {
                    shutdown_flag.store(true, Ordering::SeqCst);
                }
            }
        });

        base.set_processing_task(task).await;

        // Call stop_common - should send shutdown signal
        let _ = base.stop_common().await;

        // Give task time to process
        // (stop_common already waits on the processing task up to its timeout,
        // so this sleep is extra slack)
        tokio::time::sleep(Duration::from_millis(50)).await;

        assert!(
            shutdown_received.load(Ordering::SeqCst),
            "Processing task should have received shutdown signal"
        );
    }

    // A task using the select!-on-shutdown pattern lets stop_common() return
    // well before its processing-task timeout.
    #[tokio::test]
    async fn test_graceful_shutdown_timing() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        let rx = base.create_shutdown_channel().await;

        // Spawn task that uses select! pattern like real reactions
        // (`biased;` makes select! poll the shutdown branch first on every pass)
        let task = tokio::spawn(async move {
            let mut shutdown_rx = rx;
            loop {
                tokio::select! {
                    biased;
                    _ = &mut shutdown_rx => {
                        break;
                    }
                    _ = tokio::time::sleep(Duration::from_secs(10)) => {
                        // Simulates waiting on priority_queue.dequeue()
                    }
                }
            }
        });

        base.set_processing_task(task).await;

        // Measure shutdown time
        let start = std::time::Instant::now();
        let _ = base.stop_common().await;
        let elapsed = start.elapsed();

        // Should complete quickly (< 500ms), not hit 2s timeout
        assert!(
            elapsed < Duration::from_millis(500),
            "Shutdown took {elapsed:?}, expected < 500ms. Task may not be responding to shutdown signal."
        );
    }

    #[tokio::test]
    async fn test_stop_common_without_shutdown_channel() {
        // Test that stop_common works even if no shutdown channel was created
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        // Don't create shutdown channel - just spawn a short-lived task
        let task = tokio::spawn(async {
            tokio::time::sleep(Duration::from_millis(10)).await;
        });

        base.set_processing_task(task).await;

        // stop_common should still work
        let result = base.stop_common().await;
        assert!(result.is_ok());
    }
}