drasi_lib/reactions/common/
base.rs1use anyhow::Result;
33use log::{debug, error, info, warn};
34use std::sync::Arc;
35use tokio::sync::RwLock;
36use tracing::Instrument;
37
38use crate::channels::priority_queue::PriorityQueue;
39use crate::channels::{ComponentStatus, QueryResult};
40use crate::component_graph::ComponentStatusHandle;
41use crate::context::ReactionRuntimeContext;
42use crate::identity::IdentityProvider;
43use crate::state_store::StateStoreProvider;
44
/// Construction parameters for [`ReactionBase`].
///
/// Build with [`ReactionBaseParams::new`] and refine with the `with_*` builder
/// methods, then pass to [`ReactionBase::new`].
#[derive(Debug, Clone)]
pub struct ReactionBaseParams {
    /// Identifier of the reaction; used in log messages and as the state-store
    /// namespace cleared by `ReactionBase::deprovision_common`.
    pub id: String,
    /// Query identifiers this reaction is configured with
    /// (presumably the queries it subscribes to — confirm with callers).
    pub queries: Vec<String>,
    /// Capacity of the reaction's priority queue. `None` lets
    /// `ReactionBase::new` pick its default (10000 at the time of writing).
    pub priority_queue_capacity: Option<usize>,
    /// Value reported by `ReactionBase::get_auto_start`; `true` by default.
    pub auto_start: bool,
}
72
73impl ReactionBaseParams {
74 pub fn new(id: impl Into<String>, queries: Vec<String>) -> Self {
76 Self {
77 id: id.into(),
78 queries,
79 priority_queue_capacity: None,
80 auto_start: true, }
82 }
83
84 pub fn with_priority_queue_capacity(mut self, capacity: usize) -> Self {
86 self.priority_queue_capacity = Some(capacity);
87 self
88 }
89
90 pub fn with_auto_start(mut self, auto_start: bool) -> Self {
92 self.auto_start = auto_start;
93 self
94 }
95}
96
/// Shared state and lifecycle plumbing common to reaction implementations.
///
/// Most mutable state lives behind `Arc<RwLock<...>>` so that
/// `ReactionBase::clone_shared` can hand out handles that observe the same
/// context, tasks and shutdown channel.
pub struct ReactionBase {
    /// Identifier of the reaction.
    pub id: String,
    /// Query identifiers this reaction was configured with.
    pub queries: Vec<String>,
    /// Auto-start flag copied from `ReactionBaseParams`.
    pub auto_start: bool,
    /// Status of this component; wired to the runtime's update channel in
    /// `initialize`.
    status_handle: ComponentStatusHandle,
    /// Runtime context; `None` until `initialize` is called.
    context: Arc<RwLock<Option<ReactionRuntimeContext>>>,
    /// State store taken from the runtime context, if it provides one.
    state_store: Arc<RwLock<Option<Arc<dyn StateStoreProvider>>>>,
    /// Buffer of incoming query results (each wrapped in `Arc` on enqueue);
    /// drained and discarded by `stop_common`.
    pub priority_queue: PriorityQueue<QueryResult>,
    /// Handles of subscription tasks; aborted by `stop_common`.
    pub subscription_tasks: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
    /// Handle of the processing task; `stop_common` signals it via
    /// `shutdown_tx`, waits with a timeout, then aborts it if still running.
    pub processing_task: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
    /// Sender half created by `create_shutdown_channel`; fired by `stop_common`.
    pub shutdown_tx: Arc<RwLock<Option<tokio::sync::oneshot::Sender<()>>>>,
    /// Identity provider, set explicitly via `set_identity_provider` or
    /// adopted from the runtime context in `initialize` when still unset.
    identity_provider: Arc<RwLock<Option<Arc<dyn IdentityProvider>>>>,
}
124
125impl ReactionBase {
126 pub fn new(params: ReactionBaseParams) -> Self {
131 Self {
132 priority_queue: PriorityQueue::new(params.priority_queue_capacity.unwrap_or(10000)),
133 id: params.id.clone(),
134 queries: params.queries,
135 auto_start: params.auto_start,
136 status_handle: ComponentStatusHandle::new(¶ms.id),
137 context: Arc::new(RwLock::new(None)), state_store: Arc::new(RwLock::new(None)), subscription_tasks: Arc::new(RwLock::new(Vec::new())),
140 processing_task: Arc::new(RwLock::new(None)),
141 shutdown_tx: Arc::new(RwLock::new(None)),
142 identity_provider: Arc::new(RwLock::new(None)),
143 }
144 }
145
146 pub async fn initialize(&self, context: ReactionRuntimeContext) {
156 *self.context.write().await = Some(context.clone());
158
159 self.status_handle.wire(context.update_tx.clone()).await;
161
162 if let Some(state_store) = context.state_store.as_ref() {
163 *self.state_store.write().await = Some(state_store.clone());
164 }
165
166 if let Some(ip) = context.identity_provider.as_ref() {
168 let mut guard = self.identity_provider.write().await;
169 if guard.is_none() {
170 *guard = Some(ip.clone());
171 }
172 }
173 }
174
175 pub async fn context(&self) -> Option<ReactionRuntimeContext> {
179 self.context.read().await.clone()
180 }
181
182 pub async fn state_store(&self) -> Option<Arc<dyn StateStoreProvider>> {
186 self.state_store.read().await.clone()
187 }
188
189 pub async fn identity_provider(&self) -> Option<Arc<dyn IdentityProvider>> {
195 self.identity_provider.read().await.clone()
196 }
197
198 pub async fn set_identity_provider(&self, provider: Arc<dyn IdentityProvider>) {
204 *self.identity_provider.write().await = Some(provider);
205 }
206
207 pub fn get_auto_start(&self) -> bool {
209 self.auto_start
210 }
211
212 pub fn clone_shared(&self) -> Self {
217 Self {
218 id: self.id.clone(),
219 queries: self.queries.clone(),
220 auto_start: self.auto_start,
221 status_handle: self.status_handle.clone(),
222 context: self.context.clone(),
223 state_store: self.state_store.clone(),
224 priority_queue: self.priority_queue.clone(),
225 subscription_tasks: self.subscription_tasks.clone(),
226 processing_task: self.processing_task.clone(),
227 shutdown_tx: self.shutdown_tx.clone(),
228 identity_provider: self.identity_provider.clone(),
229 }
230 }
231
232 pub async fn create_shutdown_channel(&self) -> tokio::sync::oneshot::Receiver<()> {
239 let (tx, rx) = tokio::sync::oneshot::channel();
240 *self.shutdown_tx.write().await = Some(tx);
241 rx
242 }
243
244 pub fn get_id(&self) -> &str {
246 &self.id
247 }
248
249 pub fn get_queries(&self) -> &[String] {
251 &self.queries
252 }
253
254 pub async fn get_status(&self) -> ComponentStatus {
256 self.status_handle.get_status().await
257 }
258
259 pub fn status_handle(&self) -> ComponentStatusHandle {
264 self.status_handle.clone()
265 }
266
267 pub async fn set_status(&self, status: ComponentStatus, message: Option<String>) {
271 self.status_handle.set_status(status, message).await;
272 }
273
274 pub async fn enqueue_query_result(&self, result: QueryResult) -> anyhow::Result<()> {
279 self.priority_queue.enqueue_wait(Arc::new(result)).await;
280 Ok(())
281 }
282
283 pub async fn stop_common(&self) -> Result<()> {
291 info!("Stopping reaction: {}", self.id);
292
293 if let Some(tx) = self.shutdown_tx.write().await.take() {
295 let _ = tx.send(());
296 }
297
298 let mut subscription_tasks = self.subscription_tasks.write().await;
300 for task in subscription_tasks.drain(..) {
301 task.abort();
302 }
303 drop(subscription_tasks);
304
305 let mut processing_task = self.processing_task.write().await;
307 if let Some(mut task) = processing_task.take() {
308 match tokio::time::timeout(std::time::Duration::from_secs(2), &mut task).await {
310 Ok(Ok(())) => {
311 debug!("[{}] Processing task completed gracefully", self.id);
312 }
313 Ok(Err(e)) => {
314 debug!("[{}] Processing task ended: {}", self.id, e);
316 }
317 Err(_) => {
318 warn!(
320 "[{}] Processing task did not respond to shutdown signal within timeout, aborting",
321 self.id
322 );
323 task.abort();
324 }
325 }
326 }
327 drop(processing_task);
328
329 let drained_events = self.priority_queue.drain().await;
331 if !drained_events.is_empty() {
332 info!(
333 "[{}] Drained {} pending events from priority queue",
334 self.id,
335 drained_events.len()
336 );
337 }
338
339 self.set_status(
340 ComponentStatus::Stopped,
341 Some(format!("Reaction '{}' stopped", self.id)),
342 )
343 .await;
344 info!("Reaction '{}' stopped", self.id);
345
346 Ok(())
347 }
348
349 pub async fn deprovision_common(&self) -> Result<()> {
355 info!("Deprovisioning reaction '{}'", self.id);
356 if let Some(store) = self.state_store().await {
357 let count = store.clear_store(&self.id).await.map_err(|e| {
358 anyhow::anyhow!(
359 "Failed to clear state store for reaction '{}': {}",
360 self.id,
361 e
362 )
363 })?;
364 info!(
365 "Cleared {} keys from state store for reaction '{}'",
366 count, self.id
367 );
368 }
369 Ok(())
370 }
371
372 pub async fn set_processing_task(&self, task: tokio::task::JoinHandle<()>) {
374 *self.processing_task.write().await = Some(task);
375 }
376}
377
#[cfg(test)]
mod tests {
    //! Unit tests for `ReactionBase` construction, status handling and the
    //! shutdown sequence implemented by `stop_common`.

    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    #[tokio::test]
    async fn test_reaction_base_creation() {
        let params = ReactionBaseParams::new("test-reaction", vec!["query1".to_string()])
            .with_priority_queue_capacity(5000);

        let base = ReactionBase::new(params);
        assert_eq!(base.id, "test-reaction");
        assert_eq!(base.get_status().await, ComponentStatus::Stopped);
    }

    #[tokio::test]
    async fn test_status_transitions() {
        use crate::context::ReactionRuntimeContext;

        let (graph, _rx) = crate::component_graph::ComponentGraph::new("test-instance");
        let update_tx = graph.update_sender();
        let graph = Arc::new(RwLock::new(graph));
        let params = ReactionBaseParams::new("test-reaction", vec![]);

        let base = ReactionBase::new(params);

        let context =
            ReactionRuntimeContext::new("test-instance", "test-reaction", None, update_tx, None);
        base.initialize(context).await;

        base.set_status(ComponentStatus::Starting, Some("Starting test".to_string()))
            .await;

        assert_eq!(base.get_status().await, ComponentStatus::Starting);

        // Keep a graph subscriber alive while the status changes.
        let _event_rx = graph.read().await.subscribe();
        base.set_status(ComponentStatus::Running, Some("Running test".to_string()))
            .await;

        assert_eq!(base.get_status().await, ComponentStatus::Running);
    }

    #[tokio::test]
    async fn test_priority_queue_operations() {
        let params =
            ReactionBaseParams::new("test-reaction", vec![]).with_priority_queue_capacity(10);

        let base = ReactionBase::new(params);

        let query_result = QueryResult::new(
            "test-query".to_string(),
            chrono::Utc::now(),
            vec![],
            Default::default(),
        );

        let enqueued = base.priority_queue.enqueue(Arc::new(query_result)).await;
        assert!(enqueued);

        let drained = base.priority_queue.drain().await;
        assert_eq!(drained.len(), 1);
    }

    #[tokio::test]
    async fn test_event_without_initialization() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);

        let base = ReactionBase::new(params);

        // Must not panic without a wired update channel, and must still apply.
        base.set_status(ComponentStatus::Starting, None).await;
        assert_eq!(base.get_status().await, ComponentStatus::Starting);
    }

    #[tokio::test]
    async fn test_create_shutdown_channel() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        assert!(base.shutdown_tx.read().await.is_none());

        let rx = base.create_shutdown_channel().await;

        assert!(base.shutdown_tx.read().await.is_some());

        drop(rx);
    }

    #[tokio::test]
    async fn test_shutdown_channel_signal() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        let mut rx = base.create_shutdown_channel().await;

        if let Some(tx) = base.shutdown_tx.write().await.take() {
            tx.send(()).unwrap();
        }

        let result = rx.try_recv();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_shutdown_channel_replaced_on_second_create() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        let _rx1 = base.create_shutdown_channel().await;

        let mut rx2 = base.create_shutdown_channel().await;

        if let Some(tx) = base.shutdown_tx.write().await.take() {
            tx.send(()).unwrap();
        }

        let result = rx2.try_recv();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_stop_common_sends_shutdown_signal() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        let rx = base.create_shutdown_channel().await;

        let shutdown_received = Arc::new(AtomicBool::new(false));
        let shutdown_flag = shutdown_received.clone();

        let task = tokio::spawn(async move {
            // `Ok` only when the sender actually fired, not when it was dropped.
            if rx.await.is_ok() {
                shutdown_flag.store(true, Ordering::SeqCst);
            }
        });

        base.set_processing_task(task).await;

        base.stop_common().await.unwrap();

        assert!(
            shutdown_received.load(Ordering::SeqCst),
            "Processing task should have received shutdown signal"
        );
        assert_eq!(base.get_status().await, ComponentStatus::Stopped);
    }

    #[tokio::test]
    async fn test_graceful_shutdown_timing() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        let rx = base.create_shutdown_channel().await;

        let task = tokio::spawn(async move {
            let mut shutdown_rx = rx;
            loop {
                tokio::select! {
                    biased;
                    _ = &mut shutdown_rx => {
                        break;
                    }
                    _ = tokio::time::sleep(Duration::from_secs(10)) => {
                    }
                }
            }
        });

        base.set_processing_task(task).await;

        let start = std::time::Instant::now();
        let _ = base.stop_common().await;
        let elapsed = start.elapsed();

        assert!(
            elapsed < Duration::from_millis(500),
            "Shutdown took {elapsed:?}, expected < 500ms. Task may not be responding to shutdown signal."
        );
    }

    #[tokio::test]
    async fn test_stop_common_without_shutdown_channel() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        let task = tokio::spawn(async {
            tokio::time::sleep(Duration::from_millis(10)).await;
        });

        base.set_processing_task(task).await;

        let result = base.stop_common().await;
        assert!(result.is_ok());
        assert_eq!(base.get_status().await, ComponentStatus::Stopped);
    }

    #[tokio::test]
    async fn test_stop_common_drains_queue_and_sets_stopped() {
        let params =
            ReactionBaseParams::new("test-reaction", vec![]).with_priority_queue_capacity(10);
        let base = ReactionBase::new(params);

        let query_result = QueryResult::new(
            "test-query".to_string(),
            chrono::Utc::now(),
            vec![],
            Default::default(),
        );
        assert!(base.priority_queue.enqueue(Arc::new(query_result)).await);
        base.set_status(ComponentStatus::Running, None).await;

        base.stop_common().await.unwrap();

        assert!(base.priority_queue.drain().await.is_empty());
        assert!(base.processing_task.read().await.is_none());
        assert_eq!(base.get_status().await, ComponentStatus::Stopped);
    }

    #[tokio::test]
    async fn test_get_id() {
        let params = ReactionBaseParams::new("my-reaction-42", vec![]);
        let base = ReactionBase::new(params);
        assert_eq!(base.get_id(), "my-reaction-42");
    }

    #[tokio::test]
    async fn test_get_queries() {
        let queries = vec!["query-a".to_string(), "query-b".to_string(), "query-c".to_string()];
        let params = ReactionBaseParams::new("r1", queries.clone());
        let base = ReactionBase::new(params);
        assert_eq!(base.get_queries(), &queries[..]);
    }

    #[tokio::test]
    async fn test_get_queries_empty() {
        let params = ReactionBaseParams::new("r1", vec![]);
        let base = ReactionBase::new(params);
        assert!(base.get_queries().is_empty());
    }

    #[tokio::test]
    async fn test_get_auto_start_default_true() {
        let params = ReactionBaseParams::new("r1", vec![]);
        let base = ReactionBase::new(params);
        assert!(base.get_auto_start());
    }

    #[tokio::test]
    async fn test_get_auto_start_override_false() {
        let params = ReactionBaseParams::new("r1", vec![]).with_auto_start(false);
        let base = ReactionBase::new(params);
        assert!(!base.get_auto_start());
    }

    #[tokio::test]
    async fn test_context_none_before_initialize() {
        let params = ReactionBaseParams::new("r1", vec![]);
        let base = ReactionBase::new(params);
        assert!(base.context().await.is_none());
    }

    #[tokio::test]
    async fn test_context_some_after_initialize() {
        let (graph, _rx) = crate::component_graph::ComponentGraph::new("inst");
        let update_tx = graph.update_sender();
        let context = ReactionRuntimeContext::new("inst", "r1", None, update_tx, None);

        let params = ReactionBaseParams::new("r1", vec![]);
        let base = ReactionBase::new(params);
        base.initialize(context).await;

        let ctx = base.context().await;
        assert!(ctx.is_some());
        assert_eq!(ctx.unwrap().reaction_id, "r1");
    }

    #[tokio::test]
    async fn test_state_store_none_when_not_configured() {
        let params = ReactionBaseParams::new("r1", vec![]);
        let base = ReactionBase::new(params);
        assert!(base.state_store().await.is_none());
    }

    #[tokio::test]
    async fn test_state_store_none_after_initialize_without_store() {
        let (graph, _rx) = crate::component_graph::ComponentGraph::new("inst");
        let update_tx = graph.update_sender();
        let context = ReactionRuntimeContext::new("inst", "r1", None, update_tx, None);

        let params = ReactionBaseParams::new("r1", vec![]);
        let base = ReactionBase::new(params);
        base.initialize(context).await;

        assert!(base.state_store().await.is_none());
    }

    #[tokio::test]
    async fn test_identity_provider_none_by_default() {
        let params = ReactionBaseParams::new("r1", vec![]);
        let base = ReactionBase::new(params);
        assert!(base.identity_provider().await.is_none());
    }

    #[tokio::test]
    async fn test_status_handle_returns_handle() {
        let params = ReactionBaseParams::new("r1", vec![]);
        let base = ReactionBase::new(params);

        let handle = base.status_handle();
        assert_eq!(handle.get_status().await, ComponentStatus::Stopped);

        handle.set_status(ComponentStatus::Running, None).await;
        assert_eq!(base.get_status().await, ComponentStatus::Running);
    }

    #[tokio::test]
    async fn test_deprovision_common_noop_without_state_store() {
        let params = ReactionBaseParams::new("r1", vec![]);
        let base = ReactionBase::new(params);
        let result = base.deprovision_common().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_set_processing_task_stores_handle() {
        let params = ReactionBaseParams::new("r1", vec![]);
        let base = ReactionBase::new(params);

        assert!(base.processing_task.read().await.is_none());

        let task = tokio::spawn(async {
            tokio::time::sleep(Duration::from_secs(60)).await;
        });

        base.set_processing_task(task).await;

        assert!(base.processing_task.read().await.is_some());

        let task = base.processing_task.write().await.take();
        if let Some(t) = task {
            t.abort();
        }
    }
}