// drasi_lib/reactions/common/base.rs

// Copyright 2025 The Drasi Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Base implementation for common reaction functionality.
//!
//! This module provides `ReactionBase`, which encapsulates the patterns
//! shared by all reaction implementations:
//! - Query subscription management
//! - Priority queue handling
//! - Task lifecycle management
//! - Component status tracking
//! - Event reporting
//!
//! # Plugin Architecture
//!
//! `ReactionBase` is designed to be used by reaction plugins. Each plugin:
//! 1. Defines its own typed configuration struct.
//! 2. Creates a `ReactionBase` from a `ReactionBaseParams`.
//! 3. Implements the `Reaction` trait by delegating to `ReactionBase` methods.

32use anyhow::Result;
33use log::{debug, error, info, warn};
34use std::sync::Arc;
35use tokio::sync::RwLock;
36use tracing::Instrument;
37
38use crate::channels::priority_queue::PriorityQueue;
39use crate::channels::{
40    ComponentEvent, ComponentEventSender, ComponentStatus, ComponentType, QueryResult,
41};
42use crate::context::ReactionRuntimeContext;
43use crate::identity::IdentityProvider;
44use crate::state_store::StateStoreProvider;
45
/// Construction parameters for a `ReactionBase`.
///
/// Only the settings that `ReactionBase` itself consumes live here;
/// plugin-specific configuration belongs in the plugin crate.
///
/// # Example
///
/// ```ignore
/// use drasi_lib::reactions::common::base::{ReactionBase, ReactionBaseParams};
///
/// let params = ReactionBaseParams::new("my-reaction", vec!["query1".to_string()])
///     .with_priority_queue_capacity(5000)
///     .with_auto_start(true);
///
/// let base = ReactionBase::new(params);
/// ```
#[derive(Debug, Clone)]
pub struct ReactionBaseParams {
    /// Unique identifier of the reaction.
    pub id: String,
    /// IDs of the queries the reaction subscribes to.
    pub queries: Vec<String>,
    /// Priority queue capacity; `None` lets `ReactionBase::new` apply its
    /// default of 10000.
    pub priority_queue_capacity: Option<usize>,
    /// Whether the reaction starts automatically (defaults to `true`).
    pub auto_start: bool,
}

impl ReactionBaseParams {
    /// Creates params for reaction `id` subscribed to `queries`.
    ///
    /// The queue capacity is left unset, so the `ReactionBase` default applies.
    /// `auto_start` defaults to `true`, mirroring how queries behave.
    pub fn new(id: impl Into<String>, queries: Vec<String>) -> Self {
        let id = id.into();
        Self {
            auto_start: true,
            priority_queue_capacity: None,
            queries,
            id,
        }
    }

    /// Overrides the priority queue capacity.
    pub fn with_priority_queue_capacity(self, capacity: usize) -> Self {
        Self {
            priority_queue_capacity: Some(capacity),
            ..self
        }
    }

    /// Sets whether the reaction starts automatically.
    pub fn with_auto_start(self, auto_start: bool) -> Self {
        Self { auto_start, ..self }
    }
}
97
98/// Base implementation for common reaction functionality
99pub struct ReactionBase {
100    /// Reaction identifier
101    pub id: String,
102    /// List of query IDs to subscribe to
103    pub queries: Vec<String>,
104    /// Whether this reaction should auto-start
105    pub auto_start: bool,
106    /// Current component status
107    pub status: Arc<RwLock<ComponentStatus>>,
108    /// Runtime context (set by initialize())
109    context: Arc<RwLock<Option<ReactionRuntimeContext>>>,
110    /// Channel for sending component status events (extracted from context for convenience)
111    status_tx: Arc<RwLock<Option<ComponentEventSender>>>,
112    /// State store provider (extracted from context for convenience)
113    state_store: Arc<RwLock<Option<Arc<dyn StateStoreProvider>>>>,
114    /// Priority queue for timestamp-ordered result processing
115    pub priority_queue: PriorityQueue<QueryResult>,
116    /// Handles to subscription forwarder tasks
117    pub subscription_tasks: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
118    /// Handle to the main processing task
119    pub processing_task: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
120    /// Sender for shutdown signal to processing task
121    pub shutdown_tx: Arc<RwLock<Option<tokio::sync::oneshot::Sender<()>>>>,
122    /// Optional identity provider for credential management.
123    /// Set either programmatically (via `set_identity_provider`) or automatically
124    /// from the runtime context during `initialize()`.
125    identity_provider: Arc<RwLock<Option<Arc<dyn IdentityProvider>>>>,
126}
127
128impl ReactionBase {
129    /// Create a new ReactionBase with the given parameters
130    ///
131    /// Dependencies (event channel, query subscriber, state store) are not required during
132    /// construction - they will be provided via `initialize()` when the reaction is added to DrasiLib.
133    pub fn new(params: ReactionBaseParams) -> Self {
134        Self {
135            priority_queue: PriorityQueue::new(params.priority_queue_capacity.unwrap_or(10000)),
136            id: params.id,
137            queries: params.queries,
138            auto_start: params.auto_start,
139            status: Arc::new(RwLock::new(ComponentStatus::Stopped)),
140            context: Arc::new(RwLock::new(None)),
141            status_tx: Arc::new(RwLock::new(None)),
142            state_store: Arc::new(RwLock::new(None)),
143            subscription_tasks: Arc::new(RwLock::new(Vec::new())),
144            processing_task: Arc::new(RwLock::new(None)),
145            shutdown_tx: Arc::new(RwLock::new(None)),
146            identity_provider: Arc::new(RwLock::new(None)),
147        }
148    }
149
150    /// Initialize the reaction with runtime context.
151    ///
152    /// This method is called automatically by DrasiLib's `add_reaction()` method.
153    /// Plugin developers do not need to call this directly.
154    ///
155    /// The context provides access to:
156    /// - `reaction_id`: The reaction's unique identifier
157    /// - `status_tx`: Channel for reporting component status events
158    /// - `state_store`: Optional persistent state storage
159    pub async fn initialize(&self, context: ReactionRuntimeContext) {
160        // Store context for later use
161        *self.context.write().await = Some(context.clone());
162
163        // Extract services for convenience
164        *self.status_tx.write().await = Some(context.status_tx.clone());
165
166        if let Some(state_store) = context.state_store.as_ref() {
167            *self.state_store.write().await = Some(state_store.clone());
168        }
169
170        // Store identity provider from context if not already set programmatically
171        if let Some(ip) = context.identity_provider.as_ref() {
172            let mut guard = self.identity_provider.write().await;
173            if guard.is_none() {
174                *guard = Some(ip.clone());
175            }
176        }
177    }
178
179    /// Get the runtime context if initialized.
180    ///
181    /// Returns `None` if `initialize()` has not been called yet.
182    pub async fn context(&self) -> Option<ReactionRuntimeContext> {
183        self.context.read().await.clone()
184    }
185
186    /// Get the state store if configured.
187    ///
188    /// Returns `None` if no state store was provided in the context.
189    pub async fn state_store(&self) -> Option<Arc<dyn StateStoreProvider>> {
190        self.state_store.read().await.clone()
191    }
192
193    /// Get the identity provider if set.
194    ///
195    /// Returns the identity provider set either programmatically via
196    /// `set_identity_provider()` or from the runtime context during `initialize()`.
197    /// Programmatically-set providers take precedence over context providers.
198    pub async fn identity_provider(&self) -> Option<Arc<dyn IdentityProvider>> {
199        self.identity_provider.read().await.clone()
200    }
201
202    /// Set the identity provider programmatically.
203    ///
204    /// This is typically called during reaction construction when the provider
205    /// is available from configuration (e.g., `with_identity_provider()` builder).
206    /// Providers set this way take precedence over context-injected providers.
207    pub async fn set_identity_provider(&self, provider: Arc<dyn IdentityProvider>) {
208        *self.identity_provider.write().await = Some(provider);
209    }
210
211    /// Get whether this reaction should auto-start
212    pub fn get_auto_start(&self) -> bool {
213        self.auto_start
214    }
215
216    /// Get the status channel Arc for internal use by spawned tasks
217    ///
218    /// This returns the internal status_tx wrapped in Arc<RwLock<Option<...>>>
219    /// which allows background tasks to send component status events.
220    ///
221    /// Returns a clone of the Arc that can be moved into spawned tasks.
222    pub fn status_tx(&self) -> Arc<RwLock<Option<ComponentEventSender>>> {
223        self.status_tx.clone()
224    }
225
226    /// Clone the ReactionBase with shared Arc references
227    ///
228    /// This creates a new ReactionBase that shares the same underlying
229    /// data through Arc references. Useful for passing to spawned tasks.
230    pub fn clone_shared(&self) -> Self {
231        Self {
232            id: self.id.clone(),
233            queries: self.queries.clone(),
234            auto_start: self.auto_start,
235            status: self.status.clone(),
236            context: self.context.clone(),
237            status_tx: self.status_tx.clone(),
238            state_store: self.state_store.clone(),
239            priority_queue: self.priority_queue.clone(),
240            subscription_tasks: self.subscription_tasks.clone(),
241            processing_task: self.processing_task.clone(),
242            shutdown_tx: self.shutdown_tx.clone(),
243            identity_provider: self.identity_provider.clone(),
244        }
245    }
246
247    /// Create a shutdown channel and store the sender
248    ///
249    /// Returns the receiver which should be passed to the processing task.
250    /// The sender is stored internally and will be triggered by `stop_common()`.
251    ///
252    /// This should be called before spawning the processing task.
253    pub async fn create_shutdown_channel(&self) -> tokio::sync::oneshot::Receiver<()> {
254        let (tx, rx) = tokio::sync::oneshot::channel();
255        *self.shutdown_tx.write().await = Some(tx);
256        rx
257    }
258
259    /// Get the reaction ID
260    pub fn get_id(&self) -> &str {
261        &self.id
262    }
263
264    /// Get the query IDs
265    pub fn get_queries(&self) -> &[String] {
266        &self.queries
267    }
268
269    /// Get current status
270    pub async fn get_status(&self) -> ComponentStatus {
271        self.status.read().await.clone()
272    }
273
274    /// Send a component lifecycle event
275    ///
276    /// If the event channel has not been injected yet, this method silently
277    /// succeeds without sending anything. This allows reactions to be used
278    /// in a standalone fashion without DrasiLib if needed.
279    pub async fn send_component_event(
280        &self,
281        status: ComponentStatus,
282        message: Option<String>,
283    ) -> Result<()> {
284        let event = ComponentEvent {
285            component_id: self.id.clone(),
286            component_type: ComponentType::Reaction,
287            status,
288            timestamp: chrono::Utc::now(),
289            message,
290        };
291
292        if let Some(ref tx) = *self.status_tx.read().await {
293            if let Err(e) = tx.send(event).await {
294                error!("Failed to send component event: {e}");
295            }
296        }
297        // If status_tx is None, silently skip - initialization happens before start()
298        Ok(())
299    }
300
301    /// Transition to a new status and send event
302    pub async fn set_status_with_event(
303        &self,
304        status: ComponentStatus,
305        message: Option<String>,
306    ) -> Result<()> {
307        *self.status.write().await = status.clone();
308        self.send_component_event(status, message).await
309    }
310
311    /// Enqueue a query result for processing.
312    ///
313    /// The host calls this to forward query results to the reaction's priority queue.
314    /// Results are processed in timestamp order by the reaction's processing task.
315    pub async fn enqueue_query_result(&self, result: QueryResult) -> anyhow::Result<()> {
316        self.priority_queue.enqueue_wait(Arc::new(result)).await;
317        Ok(())
318    }
319
320    /// Perform common cleanup operations
321    ///
322    /// This method handles:
323    /// 1. Sending shutdown signal to processing task (for graceful termination)
324    /// 2. Aborting all subscription forwarder tasks
325    /// 3. Waiting for or aborting the processing task
326    /// 4. Draining the priority queue
327    pub async fn stop_common(&self) -> Result<()> {
328        info!("Stopping reaction: {}", self.id);
329
330        // Send shutdown signal to processing task (if it's using tokio::select!)
331        if let Some(tx) = self.shutdown_tx.write().await.take() {
332            let _ = tx.send(());
333        }
334
335        // Abort all subscription forwarder tasks
336        let mut subscription_tasks = self.subscription_tasks.write().await;
337        for task in subscription_tasks.drain(..) {
338            task.abort();
339        }
340        drop(subscription_tasks);
341
342        // Wait for the processing task to complete (with timeout), or abort it
343        let mut processing_task = self.processing_task.write().await;
344        if let Some(task) = processing_task.take() {
345            // Give the task a short time to respond to the shutdown signal
346            match tokio::time::timeout(std::time::Duration::from_secs(2), task).await {
347                Ok(Ok(())) => {
348                    debug!("[{}] Processing task completed gracefully", self.id);
349                }
350                Ok(Err(e)) => {
351                    // Task was aborted or panicked
352                    debug!("[{}] Processing task ended: {}", self.id, e);
353                }
354                Err(_) => {
355                    // Timeout - task didn't respond to shutdown signal
356                    // This shouldn't happen if the task is using tokio::select! correctly
357                    warn!(
358                        "[{}] Processing task did not respond to shutdown signal within timeout",
359                        self.id
360                    );
361                }
362            }
363        }
364        drop(processing_task);
365
366        // Drain the priority queue
367        let drained_events = self.priority_queue.drain().await;
368        if !drained_events.is_empty() {
369            info!(
370                "[{}] Drained {} pending events from priority queue",
371                self.id,
372                drained_events.len()
373            );
374        }
375
376        *self.status.write().await = ComponentStatus::Stopped;
377        info!("Reaction '{}' stopped", self.id);
378
379        Ok(())
380    }
381
382    /// Clear the reaction's state store partition.
383    ///
384    /// This is called during deprovision to remove all persisted state
385    /// associated with this reaction. Reactions that override `deprovision()`
386    /// can call this to clean up their state store.
387    pub async fn deprovision_common(&self) -> Result<()> {
388        info!("Deprovisioning reaction '{}'", self.id);
389        if let Some(store) = self.state_store().await {
390            let count = store.clear_store(&self.id).await.map_err(|e| {
391                anyhow::anyhow!(
392                    "Failed to clear state store for reaction '{}': {}",
393                    self.id,
394                    e
395                )
396            })?;
397            info!(
398                "Cleared {} keys from state store for reaction '{}'",
399                count, self.id
400            );
401        }
402        Ok(())
403    }
404
405    /// Set the processing task handle
406    pub async fn set_processing_task(&self, task: tokio::task::JoinHandle<()>) {
407        *self.processing_task.write().await = Some(task);
408    }
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::{Duration, Instant};
    use tokio::sync::mpsc;

    /// Reaction id shared by every test in this module.
    const REACTION_ID: &str = "test-reaction";

    /// Builds a base with no query subscriptions and default settings.
    fn unsubscribed_base() -> ReactionBase {
        ReactionBase::new(ReactionBaseParams::new(REACTION_ID, vec![]))
    }

    #[tokio::test]
    async fn test_reaction_base_creation() {
        let base = ReactionBase::new(
            ReactionBaseParams::new(REACTION_ID, vec!["query1".to_string()])
                .with_priority_queue_capacity(5000),
        );

        assert_eq!(base.id, REACTION_ID);
        assert_eq!(base.get_status().await, ComponentStatus::Stopped);
    }

    #[tokio::test]
    async fn test_status_transitions() {
        let (status_tx, mut event_rx) = mpsc::channel(100);
        let base = unsubscribed_base();

        base.initialize(ReactionRuntimeContext::new(
            "test-instance",
            REACTION_ID,
            status_tx,
            None,
        ))
        .await;

        base.set_status_with_event(ComponentStatus::Starting, Some("Starting test".to_string()))
            .await
            .unwrap();
        assert_eq!(base.get_status().await, ComponentStatus::Starting);

        // The transition must also have been published on the event channel.
        let event = event_rx
            .try_recv()
            .expect("a status event should have been emitted");
        assert_eq!(event.status, ComponentStatus::Starting);
        assert_eq!(event.message.as_deref(), Some("Starting test"));
    }

    #[tokio::test]
    async fn test_priority_queue_operations() {
        let base = ReactionBase::new(
            ReactionBaseParams::new(REACTION_ID, vec![]).with_priority_queue_capacity(10),
        );

        let result = Arc::new(QueryResult::new(
            "test-query".to_string(),
            chrono::Utc::now(),
            vec![],
            Default::default(),
        ));
        assert!(base.priority_queue.enqueue(result).await);

        assert_eq!(base.priority_queue.drain().await.len(), 1);
    }

    #[tokio::test]
    async fn test_event_without_initialization() {
        // No context injected: the event is silently dropped instead of failing.
        let base = unsubscribed_base();

        let outcome = base
            .send_component_event(ComponentStatus::Starting, None)
            .await;
        assert!(outcome.is_ok());
    }

    // =============================================================================
    // Shutdown Channel Tests
    // =============================================================================

    #[tokio::test]
    async fn test_create_shutdown_channel() {
        let base = unsubscribed_base();
        assert!(base.shutdown_tx.read().await.is_none());

        let receiver = base.create_shutdown_channel().await;
        assert!(base.shutdown_tx.read().await.is_some());

        // Dropping the receiver must be harmless.
        drop(receiver);
    }

    #[tokio::test]
    async fn test_shutdown_channel_signal() {
        let base = unsubscribed_base();
        let mut receiver = base.create_shutdown_channel().await;

        let sender = base.shutdown_tx.write().await.take();
        if let Some(sender) = sender {
            sender.send(()).unwrap();
        }

        assert!(receiver.try_recv().is_ok());
    }

    #[tokio::test]
    async fn test_shutdown_channel_replaced_on_second_create() {
        let base = unsubscribed_base();

        // Keep the first receiver alive while a second channel replaces it.
        let _first_receiver = base.create_shutdown_channel().await;
        let mut second_receiver = base.create_shutdown_channel().await;

        let sender = base.shutdown_tx.write().await.take();
        if let Some(sender) = sender {
            sender.send(()).unwrap();
        }

        // The signal must reach the most recently created receiver.
        assert!(second_receiver.try_recv().is_ok());
    }

    #[tokio::test]
    async fn test_stop_common_sends_shutdown_signal() {
        let base = unsubscribed_base();
        let shutdown_rx = base.create_shutdown_channel().await;

        let received = Arc::new(AtomicBool::new(false));
        let task = tokio::spawn({
            let received = Arc::clone(&received);
            async move {
                // Resolves on the signal or on the sender being dropped.
                let _ = shutdown_rx.await;
                received.store(true, Ordering::SeqCst);
            }
        });
        base.set_processing_task(task).await;

        let _ = base.stop_common().await;
        tokio::time::sleep(Duration::from_millis(50)).await;

        assert!(
            received.load(Ordering::SeqCst),
            "Processing task should have received shutdown signal"
        );
    }

    #[tokio::test]
    async fn test_graceful_shutdown_timing() {
        let base = unsubscribed_base();
        let mut shutdown_rx = base.create_shutdown_channel().await;

        // Mirrors the select! loop real reactions run around priority_queue.dequeue().
        let task = tokio::spawn(async move {
            loop {
                tokio::select! {
                    biased;
                    _ = &mut shutdown_rx => {
                        break;
                    }
                    _ = tokio::time::sleep(Duration::from_secs(10)) => {}
                }
            }
        });
        base.set_processing_task(task).await;

        let started = Instant::now();
        let _ = base.stop_common().await;
        let elapsed = started.elapsed();

        // A responsive task finishes well before the 2s shutdown timeout.
        assert!(
            elapsed < Duration::from_millis(500),
            "Shutdown took {elapsed:?}, expected < 500ms. Task may not be responding to shutdown signal."
        );
    }

    #[tokio::test]
    async fn test_stop_common_without_shutdown_channel() {
        // No shutdown channel: stop_common must still cope with a short-lived task.
        let base = unsubscribed_base();

        base.set_processing_task(tokio::spawn(async {
            tokio::time::sleep(Duration::from_millis(10)).await;
        }))
        .await;

        assert!(base.stop_common().await.is_ok());
    }
}