drasi_lib/reactions/common/
base.rs1use anyhow::Result;
33use log::{debug, error, info, warn};
34use std::sync::Arc;
35use tokio::sync::RwLock;
36use tracing::Instrument;
37
38use crate::channels::priority_queue::PriorityQueue;
39use crate::channels::{ComponentStatus, QueryResult};
40use crate::component_graph::ComponentStatusHandle;
41use crate::context::ReactionRuntimeContext;
42use crate::identity::IdentityProvider;
43use crate::state_store::StateStoreProvider;
44
/// Construction parameters consumed by `ReactionBase::new`.
///
/// Create one with [`ReactionBaseParams::new`] and refine it with the chained
/// `with_*` methods.
#[derive(Debug, Clone)]
pub struct ReactionBaseParams {
    /// Identifier of the reaction.
    pub id: String,
    /// Queries associated with the reaction
    /// (presumably the query ids it subscribes to; confirm against callers).
    pub queries: Vec<String>,
    /// Capacity for the reaction's priority queue. When `None`,
    /// `ReactionBase::new` falls back to 10000.
    pub priority_queue_capacity: Option<usize>,
    /// Auto-start flag; `new` sets it to `true`.
    pub auto_start: bool,
}

impl ReactionBaseParams {
    /// Creates parameters with the given `id` and `queries`, no explicit queue
    /// capacity, and `auto_start` enabled.
    pub fn new(id: impl Into<String>, queries: Vec<String>) -> Self {
        let id = id.into();
        Self {
            id,
            queries,
            priority_queue_capacity: None,
            auto_start: true,
        }
    }

    /// Returns the parameters with `priority_queue_capacity` set to `Some(capacity)`.
    pub fn with_priority_queue_capacity(self, capacity: usize) -> Self {
        Self {
            priority_queue_capacity: Some(capacity),
            ..self
        }
    }

    /// Returns the parameters with `auto_start` replaced by the given value.
    pub fn with_auto_start(self, auto_start: bool) -> Self {
        Self { auto_start, ..self }
    }
}
96
97pub struct ReactionBase {
99 pub id: String,
101 pub queries: Vec<String>,
103 pub auto_start: bool,
105 status_handle: ComponentStatusHandle,
107 context: Arc<RwLock<Option<ReactionRuntimeContext>>>,
109 state_store: Arc<RwLock<Option<Arc<dyn StateStoreProvider>>>>,
111 pub priority_queue: PriorityQueue<QueryResult>,
113 pub subscription_tasks: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
115 pub processing_task: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
117 pub shutdown_tx: Arc<RwLock<Option<tokio::sync::oneshot::Sender<()>>>>,
119 identity_provider: Arc<RwLock<Option<Arc<dyn IdentityProvider>>>>,
123 raw_config: Option<serde_json::Value>,
126}
127
128impl ReactionBase {
129 pub fn new(params: ReactionBaseParams) -> Self {
134 Self {
135 priority_queue: PriorityQueue::new(params.priority_queue_capacity.unwrap_or(10000)),
136 id: params.id.clone(),
137 queries: params.queries,
138 auto_start: params.auto_start,
139 status_handle: ComponentStatusHandle::new(¶ms.id),
140 context: Arc::new(RwLock::new(None)), state_store: Arc::new(RwLock::new(None)), subscription_tasks: Arc::new(RwLock::new(Vec::new())),
143 processing_task: Arc::new(RwLock::new(None)),
144 shutdown_tx: Arc::new(RwLock::new(None)),
145 identity_provider: Arc::new(RwLock::new(None)),
146 raw_config: None,
147 }
148 }
149
150 pub async fn initialize(&self, context: ReactionRuntimeContext) {
160 *self.context.write().await = Some(context.clone());
162
163 self.status_handle.wire(context.update_tx.clone()).await;
165
166 if let Some(state_store) = context.state_store.as_ref() {
167 *self.state_store.write().await = Some(state_store.clone());
168 }
169
170 if let Some(ip) = context.identity_provider.as_ref() {
172 let mut guard = self.identity_provider.write().await;
173 if guard.is_none() {
174 *guard = Some(ip.clone());
175 }
176 }
177 }
178
179 pub async fn context(&self) -> Option<ReactionRuntimeContext> {
183 self.context.read().await.clone()
184 }
185
186 pub async fn state_store(&self) -> Option<Arc<dyn StateStoreProvider>> {
190 self.state_store.read().await.clone()
191 }
192
193 pub async fn identity_provider(&self) -> Option<Arc<dyn IdentityProvider>> {
199 self.identity_provider.read().await.clone()
200 }
201
202 pub async fn set_identity_provider(&self, provider: Arc<dyn IdentityProvider>) {
208 *self.identity_provider.write().await = Some(provider);
209 }
210
211 pub fn get_auto_start(&self) -> bool {
213 self.auto_start
214 }
215
216 pub fn set_raw_config(&mut self, config: serde_json::Value) {
218 self.raw_config = Some(config);
219 }
220
221 pub fn raw_config(&self) -> Option<&serde_json::Value> {
223 self.raw_config.as_ref()
224 }
225
226 pub fn properties_or_serialize<D: serde::Serialize>(
234 &self,
235 fallback_dto: &D,
236 ) -> std::collections::HashMap<String, serde_json::Value> {
237 if let Some(serde_json::Value::Object(map)) = self.raw_config.as_ref() {
238 return map.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
239 }
240
241 match serde_json::to_value(fallback_dto) {
242 Ok(serde_json::Value::Object(map)) => map.into_iter().collect(),
243 _ => std::collections::HashMap::new(),
244 }
245 }
246
247 pub fn clone_shared(&self) -> Self {
252 Self {
253 id: self.id.clone(),
254 queries: self.queries.clone(),
255 auto_start: self.auto_start,
256 status_handle: self.status_handle.clone(),
257 context: self.context.clone(),
258 state_store: self.state_store.clone(),
259 priority_queue: self.priority_queue.clone(),
260 subscription_tasks: self.subscription_tasks.clone(),
261 processing_task: self.processing_task.clone(),
262 shutdown_tx: self.shutdown_tx.clone(),
263 identity_provider: self.identity_provider.clone(),
264 raw_config: self.raw_config.clone(),
265 }
266 }
267
268 pub async fn create_shutdown_channel(&self) -> tokio::sync::oneshot::Receiver<()> {
275 let (tx, rx) = tokio::sync::oneshot::channel();
276 *self.shutdown_tx.write().await = Some(tx);
277 rx
278 }
279
280 pub fn get_id(&self) -> &str {
282 &self.id
283 }
284
285 pub fn get_queries(&self) -> &[String] {
287 &self.queries
288 }
289
290 pub async fn get_status(&self) -> ComponentStatus {
292 self.status_handle.get_status().await
293 }
294
295 pub fn status_handle(&self) -> ComponentStatusHandle {
300 self.status_handle.clone()
301 }
302
303 pub async fn set_status(&self, status: ComponentStatus, message: Option<String>) {
307 self.status_handle.set_status(status, message).await;
308 }
309
310 pub async fn enqueue_query_result(&self, result: QueryResult) -> anyhow::Result<()> {
315 self.priority_queue.enqueue_wait(Arc::new(result)).await;
316 Ok(())
317 }
318
319 pub async fn stop_common(&self) -> Result<()> {
327 info!("Stopping reaction: {}", self.id);
328
329 if let Some(tx) = self.shutdown_tx.write().await.take() {
331 let _ = tx.send(());
332 }
333
334 let mut subscription_tasks = self.subscription_tasks.write().await;
336 for task in subscription_tasks.drain(..) {
337 task.abort();
338 }
339 drop(subscription_tasks);
340
341 let mut processing_task = self.processing_task.write().await;
343 if let Some(mut task) = processing_task.take() {
344 match tokio::time::timeout(std::time::Duration::from_secs(2), &mut task).await {
346 Ok(Ok(())) => {
347 debug!("[{}] Processing task completed gracefully", self.id);
348 }
349 Ok(Err(e)) => {
350 debug!("[{}] Processing task ended: {}", self.id, e);
352 }
353 Err(_) => {
354 warn!(
356 "[{}] Processing task did not respond to shutdown signal within timeout, aborting",
357 self.id
358 );
359 task.abort();
360 }
361 }
362 }
363 drop(processing_task);
364
365 let drained_events = self.priority_queue.drain().await;
367 if !drained_events.is_empty() {
368 info!(
369 "[{}] Drained {} pending events from priority queue",
370 self.id,
371 drained_events.len()
372 );
373 }
374
375 self.set_status(
376 ComponentStatus::Stopped,
377 Some(format!("Reaction '{}' stopped", self.id)),
378 )
379 .await;
380 info!("Reaction '{}' stopped", self.id);
381
382 Ok(())
383 }
384
385 pub async fn deprovision_common(&self) -> Result<()> {
391 info!("Deprovisioning reaction '{}'", self.id);
392 if let Some(store) = self.state_store().await {
393 let count = store.clear_store(&self.id).await.map_err(|e| {
394 anyhow::anyhow!(
395 "Failed to clear state store for reaction '{}': {}",
396 self.id,
397 e
398 )
399 })?;
400 info!(
401 "Cleared {} keys from state store for reaction '{}'",
402 count, self.id
403 );
404 }
405 Ok(())
406 }
407
408 pub async fn set_processing_task(&self, task: tokio::task::JoinHandle<()>) {
410 *self.processing_task.write().await = Some(task);
411 }
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    /// Builds a `ReactionBase` with the given id, no queries and default parameters.
    fn base_named(id: &str) -> ReactionBase {
        ReactionBase::new(ReactionBaseParams::new(id, vec![]))
    }

    // ---- construction and status ----

    #[tokio::test]
    async fn test_reaction_base_creation() {
        let params = ReactionBaseParams::new("test-reaction", vec!["query1".to_string()])
            .with_priority_queue_capacity(5000);

        let base = ReactionBase::new(params);

        assert_eq!(base.id, "test-reaction");
        assert_eq!(base.get_status().await, ComponentStatus::Stopped);
    }

    #[tokio::test]
    async fn test_status_transitions() {
        let (graph, _rx) = crate::component_graph::ComponentGraph::new("test-instance");
        let update_tx = graph.update_sender();
        let graph = Arc::new(RwLock::new(graph));

        let base = base_named("test-reaction");
        let context =
            ReactionRuntimeContext::new("test-instance", "test-reaction", None, update_tx, None);
        base.initialize(context).await;

        base.set_status(ComponentStatus::Starting, Some("Starting test".to_string()))
            .await;
        assert_eq!(base.get_status().await, ComponentStatus::Starting);

        // Keep a subscriber alive across the next transition; nothing is read
        // from it, hence the underscore-prefixed binding.
        let _event_rx = graph.read().await.subscribe();
        base.set_status(ComponentStatus::Running, Some("Running test".to_string()))
            .await;
        assert_eq!(base.get_status().await, ComponentStatus::Running);
    }

    #[tokio::test]
    async fn test_priority_queue_operations() {
        let params =
            ReactionBaseParams::new("test-reaction", vec![]).with_priority_queue_capacity(10);
        let base = ReactionBase::new(params);

        let query_result = QueryResult::new(
            "test-query".to_string(),
            0,
            chrono::Utc::now(),
            vec![],
            Default::default(),
        );

        let enqueued = base.priority_queue.enqueue(Arc::new(query_result)).await;
        assert!(enqueued);

        let drained = base.priority_queue.drain().await;
        assert_eq!(drained.len(), 1);
    }

    #[tokio::test]
    async fn test_event_without_initialization() {
        let base = base_named("test-reaction");

        // Must not panic even though `initialize` was never called.
        base.set_status(ComponentStatus::Starting, None).await;
    }

    // ---- shutdown channel ----

    #[tokio::test]
    async fn test_create_shutdown_channel() {
        let base = base_named("test-reaction");
        assert!(base.shutdown_tx.read().await.is_none());

        let rx = base.create_shutdown_channel().await;
        assert!(base.shutdown_tx.read().await.is_some());

        drop(rx);
    }

    #[tokio::test]
    async fn test_shutdown_channel_signal() {
        let base = base_named("test-reaction");
        let mut rx = base.create_shutdown_channel().await;

        if let Some(tx) = base.shutdown_tx.write().await.take() {
            tx.send(()).unwrap();
        }

        assert!(rx.try_recv().is_ok());
    }

    #[tokio::test]
    async fn test_shutdown_channel_replaced_on_second_create() {
        let base = base_named("test-reaction");

        let _rx1 = base.create_shutdown_channel().await;
        let mut rx2 = base.create_shutdown_channel().await;

        // Only the most recently created sender is stored.
        if let Some(tx) = base.shutdown_tx.write().await.take() {
            tx.send(()).unwrap();
        }

        assert!(rx2.try_recv().is_ok());
    }

    // ---- stop_common ----

    #[tokio::test]
    async fn test_stop_common_sends_shutdown_signal() {
        let base = base_named("test-reaction");
        let rx = base.create_shutdown_channel().await;

        let shutdown_received = Arc::new(AtomicBool::new(false));
        let shutdown_flag = shutdown_received.clone();

        let task = tokio::spawn(async move {
            // Completes once the sender fires (or is dropped).
            let _ = rx.await;
            shutdown_flag.store(true, Ordering::SeqCst);
        });
        base.set_processing_task(task).await;

        let _ = base.stop_common().await;

        assert!(
            shutdown_received.load(Ordering::SeqCst),
            "Processing task should have received shutdown signal"
        );
    }

    #[tokio::test]
    async fn test_graceful_shutdown_timing() {
        let base = base_named("test-reaction");
        let rx = base.create_shutdown_channel().await;

        let task = tokio::spawn(async move {
            let mut shutdown_rx = rx;
            loop {
                tokio::select! {
                    biased;
                    _ = &mut shutdown_rx => {
                        break;
                    }
                    _ = tokio::time::sleep(Duration::from_secs(10)) => {
                    }
                }
            }
        });
        base.set_processing_task(task).await;

        let start = std::time::Instant::now();
        let _ = base.stop_common().await;
        let elapsed = start.elapsed();

        assert!(
            elapsed < Duration::from_millis(500),
            "Shutdown took {elapsed:?}, expected < 500ms. Task may not be responding to shutdown signal."
        );
    }

    #[tokio::test]
    async fn test_stop_common_without_shutdown_channel() {
        let base = base_named("test-reaction");

        let task = tokio::spawn(async {
            tokio::time::sleep(Duration::from_millis(10)).await;
        });
        base.set_processing_task(task).await;

        let result = base.stop_common().await;
        assert!(result.is_ok());
    }

    // ---- simple accessors ----

    #[tokio::test]
    async fn test_get_id() {
        let base = base_named("my-reaction-42");
        assert_eq!(base.get_id(), "my-reaction-42");
    }

    #[tokio::test]
    async fn test_get_queries() {
        let queries = vec![
            "query-a".to_string(),
            "query-b".to_string(),
            "query-c".to_string(),
        ];
        let base = ReactionBase::new(ReactionBaseParams::new("r1", queries.clone()));
        assert_eq!(base.get_queries(), &queries[..]);
    }

    #[tokio::test]
    async fn test_get_queries_empty() {
        let base = base_named("r1");
        assert!(base.get_queries().is_empty());
    }

    #[tokio::test]
    async fn test_get_auto_start_default_true() {
        let base = base_named("r1");
        assert!(base.get_auto_start());
    }

    #[tokio::test]
    async fn test_get_auto_start_override_false() {
        let params = ReactionBaseParams::new("r1", vec![]).with_auto_start(false);
        let base = ReactionBase::new(params);
        assert!(!base.get_auto_start());
    }

    // ---- context, state store, identity provider ----

    #[tokio::test]
    async fn test_context_none_before_initialize() {
        let base = base_named("r1");
        assert!(base.context().await.is_none());
    }

    #[tokio::test]
    async fn test_context_some_after_initialize() {
        let (graph, _rx) = crate::component_graph::ComponentGraph::new("inst");
        let update_tx = graph.update_sender();
        let context = ReactionRuntimeContext::new("inst", "r1", None, update_tx, None);

        let base = base_named("r1");
        base.initialize(context).await;

        let ctx = base.context().await;
        assert!(ctx.is_some());
        assert_eq!(ctx.unwrap().reaction_id, "r1");
    }

    #[tokio::test]
    async fn test_state_store_none_when_not_configured() {
        let base = base_named("r1");
        assert!(base.state_store().await.is_none());
    }

    #[tokio::test]
    async fn test_state_store_none_after_initialize_without_store() {
        let (graph, _rx) = crate::component_graph::ComponentGraph::new("inst");
        let update_tx = graph.update_sender();
        let context = ReactionRuntimeContext::new("inst", "r1", None, update_tx, None);

        let base = base_named("r1");
        base.initialize(context).await;

        assert!(base.state_store().await.is_none());
    }

    #[tokio::test]
    async fn test_identity_provider_none_by_default() {
        let base = base_named("r1");
        assert!(base.identity_provider().await.is_none());
    }

    // ---- status handle ----

    #[tokio::test]
    async fn test_status_handle_returns_handle() {
        let base = base_named("r1");

        let handle = base.status_handle();
        assert_eq!(handle.get_status().await, ComponentStatus::Stopped);

        // A change made through the returned handle is visible on the base.
        handle.set_status(ComponentStatus::Running, None).await;
        assert_eq!(base.get_status().await, ComponentStatus::Running);
    }

    // ---- deprovision and processing task ----

    #[tokio::test]
    async fn test_deprovision_common_noop_without_state_store() {
        let base = base_named("r1");
        let result = base.deprovision_common().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_set_processing_task_stores_handle() {
        let base = base_named("r1");
        assert!(base.processing_task.read().await.is_none());

        let task = tokio::spawn(async {
            tokio::time::sleep(Duration::from_secs(60)).await;
        });
        base.set_processing_task(task).await;
        assert!(base.processing_task.read().await.is_some());

        // Clean up the long-running task so it does not outlive the test.
        let task = base.processing_task.write().await.take();
        if let Some(t) = task {
            t.abort();
        }
    }
}