drasi_lib/reactions/common/
base.rs1use anyhow::Result;
33use log::{debug, error, info, warn};
34use std::sync::Arc;
35use tokio::sync::RwLock;
36
37use crate::channels::priority_queue::PriorityQueue;
38use crate::channels::{
39 ComponentEvent, ComponentEventSender, ComponentStatus, ComponentType, QueryResult,
40};
41use crate::context::ReactionRuntimeContext;
42use crate::reactions::QueryProvider;
43use crate::state_store::StateStoreProvider;
44
/// Construction parameters for [`ReactionBase`].
///
/// Build with [`ReactionBaseParams::new`] and refine with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ReactionBaseParams {
    /// Reaction identifier; becomes `ReactionBase::id`.
    pub id: String,
    /// IDs of the queries the reaction will subscribe to.
    pub queries: Vec<String>,
    /// Capacity of the reaction's priority queue; `None` lets `ReactionBase::new`
    /// pick its built-in default.
    pub priority_queue_capacity: Option<usize>,
    /// Auto-start flag exposed via `ReactionBase::get_auto_start`
    /// (presumably consulted by the host to decide whether to start the reaction
    /// automatically — confirm against callers). Defaults to `true`.
    pub auto_start: bool,
}

impl ReactionBaseParams {
    /// Creates parameters for reaction `id` subscribing to `queries`.
    ///
    /// Defaults: no explicit priority-queue capacity, `auto_start = true`.
    #[must_use]
    pub fn new(id: impl Into<String>, queries: Vec<String>) -> Self {
        Self {
            id: id.into(),
            queries,
            priority_queue_capacity: None,
            auto_start: true,
        }
    }

    /// Sets an explicit priority-queue capacity.
    #[must_use]
    pub fn with_priority_queue_capacity(mut self, capacity: usize) -> Self {
        self.priority_queue_capacity = Some(capacity);
        self
    }

    /// Overrides the auto-start flag.
    #[must_use]
    pub fn with_auto_start(mut self, auto_start: bool) -> Self {
        self.auto_start = auto_start;
        self
    }
}
96
97pub struct ReactionBase {
99 pub id: String,
101 pub queries: Vec<String>,
103 pub auto_start: bool,
105 pub status: Arc<RwLock<ComponentStatus>>,
107 context: Arc<RwLock<Option<ReactionRuntimeContext>>>,
109 status_tx: Arc<RwLock<Option<ComponentEventSender>>>,
111 query_provider: Arc<RwLock<Option<Arc<dyn QueryProvider>>>>,
113 state_store: Arc<RwLock<Option<Arc<dyn StateStoreProvider>>>>,
115 pub priority_queue: PriorityQueue<QueryResult>,
117 pub subscription_tasks: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
119 pub processing_task: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
121 pub shutdown_tx: Arc<RwLock<Option<tokio::sync::oneshot::Sender<()>>>>,
123}
124
125impl ReactionBase {
126 pub fn new(params: ReactionBaseParams) -> Self {
131 Self {
132 priority_queue: PriorityQueue::new(params.priority_queue_capacity.unwrap_or(10000)),
133 id: params.id,
134 queries: params.queries,
135 auto_start: params.auto_start,
136 status: Arc::new(RwLock::new(ComponentStatus::Stopped)),
137 context: Arc::new(RwLock::new(None)), status_tx: Arc::new(RwLock::new(None)), query_provider: Arc::new(RwLock::new(None)), state_store: Arc::new(RwLock::new(None)), subscription_tasks: Arc::new(RwLock::new(Vec::new())),
142 processing_task: Arc::new(RwLock::new(None)),
143 shutdown_tx: Arc::new(RwLock::new(None)),
144 }
145 }
146
147 pub async fn initialize(&self, context: ReactionRuntimeContext) {
158 *self.context.write().await = Some(context.clone());
160
161 *self.status_tx.write().await = Some(context.status_tx.clone());
163 *self.query_provider.write().await = Some(context.query_provider.clone());
164
165 if let Some(state_store) = context.state_store.as_ref() {
166 *self.state_store.write().await = Some(state_store.clone());
167 }
168 }
169
170 pub async fn context(&self) -> Option<ReactionRuntimeContext> {
174 self.context.read().await.clone()
175 }
176
177 pub async fn state_store(&self) -> Option<Arc<dyn StateStoreProvider>> {
181 self.state_store.read().await.clone()
182 }
183
184 pub fn get_auto_start(&self) -> bool {
186 self.auto_start
187 }
188
189 pub fn status_tx(&self) -> Arc<RwLock<Option<ComponentEventSender>>> {
196 self.status_tx.clone()
197 }
198
199 pub fn clone_shared(&self) -> Self {
204 Self {
205 id: self.id.clone(),
206 queries: self.queries.clone(),
207 auto_start: self.auto_start,
208 status: self.status.clone(),
209 context: self.context.clone(),
210 status_tx: self.status_tx.clone(),
211 query_provider: self.query_provider.clone(),
212 state_store: self.state_store.clone(),
213 priority_queue: self.priority_queue.clone(),
214 subscription_tasks: self.subscription_tasks.clone(),
215 processing_task: self.processing_task.clone(),
216 shutdown_tx: self.shutdown_tx.clone(),
217 }
218 }
219
220 pub async fn create_shutdown_channel(&self) -> tokio::sync::oneshot::Receiver<()> {
227 let (tx, rx) = tokio::sync::oneshot::channel();
228 *self.shutdown_tx.write().await = Some(tx);
229 rx
230 }
231
232 pub fn get_id(&self) -> &str {
234 &self.id
235 }
236
237 pub fn get_queries(&self) -> &[String] {
239 &self.queries
240 }
241
242 pub async fn get_status(&self) -> ComponentStatus {
244 self.status.read().await.clone()
245 }
246
247 pub async fn send_component_event(
253 &self,
254 status: ComponentStatus,
255 message: Option<String>,
256 ) -> Result<()> {
257 let event = ComponentEvent {
258 component_id: self.id.clone(),
259 component_type: ComponentType::Reaction,
260 status,
261 timestamp: chrono::Utc::now(),
262 message,
263 };
264
265 if let Some(ref tx) = *self.status_tx.read().await {
266 if let Err(e) = tx.send(event).await {
267 error!("Failed to send component event: {e}");
268 }
269 }
270 Ok(())
272 }
273
274 pub async fn set_status_with_event(
276 &self,
277 status: ComponentStatus,
278 message: Option<String>,
279 ) -> Result<()> {
280 *self.status.write().await = status.clone();
281 self.send_component_event(status, message).await
282 }
283
284 pub async fn subscribe_to_queries(&self) -> Result<()> {
298 let query_provider = {
300 let qp_guard = self.query_provider.read().await;
301 qp_guard.as_ref().cloned().ok_or_else(|| {
302 anyhow::anyhow!(
303 "QueryProvider not injected - was reaction '{}' added to DrasiLib?",
304 self.id
305 )
306 })?
307 };
308
309 for query_id in &self.queries {
311 let query = query_provider.get_query_instance(query_id).await?;
313
314 let subscription_response = query
316 .subscribe(self.id.clone())
317 .await
318 .map_err(|e| anyhow::anyhow!(e))?;
319 let mut receiver = subscription_response.receiver;
320
321 let priority_queue = self.priority_queue.clone();
323 let query_id_clone = query_id.clone();
324 let reaction_id = self.id.clone();
325
326 let query_config = query.get_config();
328 let dispatch_mode = query_config
329 .dispatch_mode
330 .unwrap_or(crate::channels::DispatchMode::Channel);
331 let use_blocking_enqueue =
332 matches!(dispatch_mode, crate::channels::DispatchMode::Channel);
333
334 let forwarder_task = tokio::spawn(async move {
336 debug!(
337 "[{reaction_id}] Started result forwarder for query '{query_id_clone}' (dispatch_mode: {dispatch_mode:?}, blocking_enqueue: {use_blocking_enqueue})"
338 );
339
340 loop {
341 match receiver.recv().await {
342 Ok(query_result) => {
343 if use_blocking_enqueue {
345 priority_queue.enqueue_wait(query_result).await;
348 } else {
349 if !priority_queue.enqueue(query_result).await {
352 warn!(
353 "[{reaction_id}] Failed to enqueue result from query '{query_id_clone}' - priority queue at capacity (broadcast mode)"
354 );
355 }
356 }
357 }
358 Err(e) => {
359 let error_str = e.to_string();
361 if error_str.contains("lagged") {
362 warn!(
363 "[{reaction_id}] Receiver lagged for query '{query_id_clone}': {error_str}"
364 );
365 continue;
366 } else {
367 info!(
368 "[{reaction_id}] Receiver error for query '{query_id_clone}': {error_str}"
369 );
370 break;
371 }
372 }
373 }
374 }
375 });
376
377 self.subscription_tasks.write().await.push(forwarder_task);
379 }
380
381 Ok(())
382 }
383
384 pub async fn stop_common(&self) -> Result<()> {
392 info!("Stopping reaction: {}", self.id);
393
394 if let Some(tx) = self.shutdown_tx.write().await.take() {
396 let _ = tx.send(());
397 }
398
399 let mut subscription_tasks = self.subscription_tasks.write().await;
401 for task in subscription_tasks.drain(..) {
402 task.abort();
403 }
404 drop(subscription_tasks);
405
406 let mut processing_task = self.processing_task.write().await;
408 if let Some(task) = processing_task.take() {
409 match tokio::time::timeout(std::time::Duration::from_secs(2), task).await {
411 Ok(Ok(())) => {
412 debug!("[{}] Processing task completed gracefully", self.id);
413 }
414 Ok(Err(e)) => {
415 debug!("[{}] Processing task ended: {}", self.id, e);
417 }
418 Err(_) => {
419 warn!(
422 "[{}] Processing task did not respond to shutdown signal within timeout",
423 self.id
424 );
425 }
426 }
427 }
428 drop(processing_task);
429
430 let drained_events = self.priority_queue.drain().await;
432 if !drained_events.is_empty() {
433 info!(
434 "[{}] Drained {} pending events from priority queue",
435 self.id,
436 drained_events.len()
437 );
438 }
439
440 Ok(())
441 }
442
443 pub async fn set_processing_task(&self, task: tokio::task::JoinHandle<()>) {
445 *self.processing_task.write().await = Some(task);
446 }
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;
    use tokio::sync::mpsc;

    #[tokio::test]
    async fn test_reaction_base_creation() {
        let params = ReactionBaseParams::new("test-reaction", vec!["query1".to_string()])
            .with_priority_queue_capacity(5000);

        let base = ReactionBase::new(params);
        assert_eq!(base.id, "test-reaction");
        assert_eq!(base.get_status().await, ComponentStatus::Stopped);
    }

    #[tokio::test]
    async fn test_status_transitions() {
        use crate::queries::Query;

        /// Provider that never resolves a query; this test only exercises status events.
        struct MockQueryProvider;

        #[async_trait::async_trait]
        impl crate::reactions::QueryProvider for MockQueryProvider {
            async fn get_query_instance(&self, _id: &str) -> anyhow::Result<Arc<dyn Query>> {
                Err(anyhow::anyhow!("MockQueryProvider: query not found"))
            }
        }

        let (status_tx, mut event_rx) = mpsc::channel(100);
        let base = ReactionBase::new(ReactionBaseParams::new("test-reaction", vec![]));

        let context = ReactionRuntimeContext::new(
            "test-reaction",
            status_tx,
            None,
            Arc::new(MockQueryProvider),
        );
        base.initialize(context).await;

        base.set_status_with_event(ComponentStatus::Starting, Some("Starting test".to_string()))
            .await
            .unwrap();

        assert_eq!(base.get_status().await, ComponentStatus::Starting);

        // The event is sent before `set_status_with_event` returns, so it is already queued.
        let event = event_rx.try_recv().unwrap();
        assert_eq!(event.status, ComponentStatus::Starting);
        assert_eq!(event.message, Some("Starting test".to_string()));
    }

    #[tokio::test]
    async fn test_priority_queue_operations() {
        let params =
            ReactionBaseParams::new("test-reaction", vec![]).with_priority_queue_capacity(10);
        let base = ReactionBase::new(params);

        let query_result = QueryResult::new(
            "test-query".to_string(),
            chrono::Utc::now(),
            vec![],
            Default::default(),
        );

        assert!(base.priority_queue.enqueue(Arc::new(query_result)).await);
        assert_eq!(base.priority_queue.drain().await.len(), 1);
    }

    #[tokio::test]
    async fn test_event_without_initialization() {
        // With no status sender injected, sending an event is a silent no-op.
        let base = ReactionBase::new(ReactionBaseParams::new("test-reaction", vec![]));

        base.send_component_event(ComponentStatus::Starting, None)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_create_shutdown_channel() {
        let base = ReactionBase::new(ReactionBaseParams::new("test-reaction", vec![]));

        assert!(base.shutdown_tx.read().await.is_none());

        let rx = base.create_shutdown_channel().await;

        assert!(base.shutdown_tx.read().await.is_some());

        drop(rx);
    }

    #[tokio::test]
    async fn test_shutdown_channel_signal() {
        let base = ReactionBase::new(ReactionBaseParams::new("test-reaction", vec![]));

        let mut rx = base.create_shutdown_channel().await;

        if let Some(tx) = base.shutdown_tx.write().await.take() {
            tx.send(()).unwrap();
        }

        assert!(rx.try_recv().is_ok());
    }

    #[tokio::test]
    async fn test_shutdown_channel_replaced_on_second_create() {
        let base = ReactionBase::new(ReactionBaseParams::new("test-reaction", vec![]));

        let _rx1 = base.create_shutdown_channel().await;
        let mut rx2 = base.create_shutdown_channel().await;

        // The stored sender must now be paired with the second receiver.
        if let Some(tx) = base.shutdown_tx.write().await.take() {
            tx.send(()).unwrap();
        }

        assert!(rx2.try_recv().is_ok());
    }

    #[tokio::test]
    async fn test_stop_common_sends_shutdown_signal() {
        let base = ReactionBase::new(ReactionBaseParams::new("test-reaction", vec![]));

        let rx = base.create_shutdown_channel().await;

        let shutdown_received = Arc::new(AtomicBool::new(false));
        let shutdown_flag = Arc::clone(&shutdown_received);

        let task = tokio::spawn(async move {
            // Resolves on a send or when the sender is dropped.
            let _ = rx.await;
            shutdown_flag.store(true, Ordering::SeqCst);
        });

        base.set_processing_task(task).await;

        // `stop_common` joins the processing task (bounded by its timeout), so the flag
        // is already settled when it returns; no extra sleep is needed.
        let _ = base.stop_common().await;

        assert!(
            shutdown_received.load(Ordering::SeqCst),
            "Processing task should have received shutdown signal"
        );
    }

    #[tokio::test]
    async fn test_graceful_shutdown_timing() {
        let base = ReactionBase::new(ReactionBaseParams::new("test-reaction", vec![]));

        let mut shutdown_rx = base.create_shutdown_channel().await;

        let task = tokio::spawn(async move {
            loop {
                tokio::select! {
                    biased;
                    _ = &mut shutdown_rx => {
                        break;
                    }
                    _ = tokio::time::sleep(Duration::from_secs(10)) => {}
                }
            }
        });

        base.set_processing_task(task).await;

        let start = std::time::Instant::now();
        let _ = base.stop_common().await;
        let elapsed = start.elapsed();

        assert!(
            elapsed < Duration::from_millis(500),
            "Shutdown took {elapsed:?}, expected < 500ms. Task may not be responding to shutdown signal."
        );
    }

    #[tokio::test]
    async fn test_stop_common_without_shutdown_channel() {
        // No shutdown channel is created; the short-lived task simply finishes on its own.
        let base = ReactionBase::new(ReactionBaseParams::new("test-reaction", vec![]));

        let task = tokio::spawn(async {
            tokio::time::sleep(Duration::from_millis(10)).await;
        });

        base.set_processing_task(task).await;

        assert!(base.stop_common().await.is_ok());
    }
}