drasi_lib/reactions/common/
base.rs1use anyhow::Result;
33use log::{debug, error, info, warn};
34use std::sync::Arc;
35use tokio::sync::RwLock;
36use tracing::Instrument;
37
38use crate::channels::priority_queue::PriorityQueue;
39use crate::channels::{
40 ComponentEvent, ComponentEventSender, ComponentStatus, ComponentType, QueryResult,
41};
42use crate::context::ReactionRuntimeContext;
43use crate::reactions::QueryProvider;
44use crate::state_store::StateStoreProvider;
45
/// Construction parameters for a `ReactionBase`, built with a small consuming builder API.
#[derive(Debug, Clone)]
pub struct ReactionBaseParams {
    /// Identifier of the reaction.
    pub id: String,
    /// IDs of the queries the reaction subscribes to.
    pub queries: Vec<String>,
    /// Capacity of the reaction's priority queue; `None` means "use the base's default".
    pub priority_queue_capacity: Option<usize>,
    /// Auto-start flag; [`ReactionBaseParams::new`] initialises it to `true`.
    pub auto_start: bool,
}

impl ReactionBaseParams {
    /// Creates parameters with the given id and query list.
    ///
    /// `priority_queue_capacity` starts as `None` and `auto_start` as `true`.
    #[must_use]
    pub fn new(id: impl Into<String>, queries: Vec<String>) -> Self {
        Self {
            id: id.into(),
            queries,
            priority_queue_capacity: None,
            auto_start: true,
        }
    }

    /// Sets an explicit priority queue capacity.
    ///
    /// Consumes and returns `self`, so the result must be used.
    #[must_use]
    pub fn with_priority_queue_capacity(mut self, capacity: usize) -> Self {
        self.priority_queue_capacity = Some(capacity);
        self
    }

    /// Sets the auto-start flag.
    ///
    /// Consumes and returns `self`, so the result must be used.
    #[must_use]
    pub fn with_auto_start(mut self, auto_start: bool) -> Self {
        self.auto_start = auto_start;
        self
    }
}
97
98pub struct ReactionBase {
100 pub id: String,
102 pub queries: Vec<String>,
104 pub auto_start: bool,
106 pub status: Arc<RwLock<ComponentStatus>>,
108 context: Arc<RwLock<Option<ReactionRuntimeContext>>>,
110 status_tx: Arc<RwLock<Option<ComponentEventSender>>>,
112 query_provider: Arc<RwLock<Option<Arc<dyn QueryProvider>>>>,
114 state_store: Arc<RwLock<Option<Arc<dyn StateStoreProvider>>>>,
116 pub priority_queue: PriorityQueue<QueryResult>,
118 pub subscription_tasks: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
120 pub processing_task: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
122 pub shutdown_tx: Arc<RwLock<Option<tokio::sync::oneshot::Sender<()>>>>,
124}
125
126impl ReactionBase {
127 pub fn new(params: ReactionBaseParams) -> Self {
132 Self {
133 priority_queue: PriorityQueue::new(params.priority_queue_capacity.unwrap_or(10000)),
134 id: params.id,
135 queries: params.queries,
136 auto_start: params.auto_start,
137 status: Arc::new(RwLock::new(ComponentStatus::Stopped)),
138 context: Arc::new(RwLock::new(None)), status_tx: Arc::new(RwLock::new(None)), query_provider: Arc::new(RwLock::new(None)), state_store: Arc::new(RwLock::new(None)), subscription_tasks: Arc::new(RwLock::new(Vec::new())),
143 processing_task: Arc::new(RwLock::new(None)),
144 shutdown_tx: Arc::new(RwLock::new(None)),
145 }
146 }
147
148 pub async fn initialize(&self, context: ReactionRuntimeContext) {
159 *self.context.write().await = Some(context.clone());
161
162 *self.status_tx.write().await = Some(context.status_tx.clone());
164 *self.query_provider.write().await = Some(context.query_provider.clone());
165
166 if let Some(state_store) = context.state_store.as_ref() {
167 *self.state_store.write().await = Some(state_store.clone());
168 }
169 }
170
171 pub async fn context(&self) -> Option<ReactionRuntimeContext> {
175 self.context.read().await.clone()
176 }
177
178 pub async fn state_store(&self) -> Option<Arc<dyn StateStoreProvider>> {
182 self.state_store.read().await.clone()
183 }
184
185 pub fn get_auto_start(&self) -> bool {
187 self.auto_start
188 }
189
190 pub fn status_tx(&self) -> Arc<RwLock<Option<ComponentEventSender>>> {
197 self.status_tx.clone()
198 }
199
200 pub fn clone_shared(&self) -> Self {
205 Self {
206 id: self.id.clone(),
207 queries: self.queries.clone(),
208 auto_start: self.auto_start,
209 status: self.status.clone(),
210 context: self.context.clone(),
211 status_tx: self.status_tx.clone(),
212 query_provider: self.query_provider.clone(),
213 state_store: self.state_store.clone(),
214 priority_queue: self.priority_queue.clone(),
215 subscription_tasks: self.subscription_tasks.clone(),
216 processing_task: self.processing_task.clone(),
217 shutdown_tx: self.shutdown_tx.clone(),
218 }
219 }
220
221 pub async fn create_shutdown_channel(&self) -> tokio::sync::oneshot::Receiver<()> {
228 let (tx, rx) = tokio::sync::oneshot::channel();
229 *self.shutdown_tx.write().await = Some(tx);
230 rx
231 }
232
233 pub fn get_id(&self) -> &str {
235 &self.id
236 }
237
238 pub fn get_queries(&self) -> &[String] {
240 &self.queries
241 }
242
243 pub async fn get_status(&self) -> ComponentStatus {
245 self.status.read().await.clone()
246 }
247
248 pub async fn send_component_event(
254 &self,
255 status: ComponentStatus,
256 message: Option<String>,
257 ) -> Result<()> {
258 let event = ComponentEvent {
259 component_id: self.id.clone(),
260 component_type: ComponentType::Reaction,
261 status,
262 timestamp: chrono::Utc::now(),
263 message,
264 };
265
266 if let Some(ref tx) = *self.status_tx.read().await {
267 if let Err(e) = tx.send(event).await {
268 error!("Failed to send component event: {e}");
269 }
270 }
271 Ok(())
273 }
274
275 pub async fn set_status_with_event(
277 &self,
278 status: ComponentStatus,
279 message: Option<String>,
280 ) -> Result<()> {
281 *self.status.write().await = status.clone();
282 self.send_component_event(status, message).await
283 }
284
285 pub async fn subscribe_to_queries(&self) -> Result<()> {
299 let query_provider = {
301 let qp_guard = self.query_provider.read().await;
302 qp_guard.as_ref().cloned().ok_or_else(|| {
303 anyhow::anyhow!(
304 "QueryProvider not injected - was reaction '{}' added to DrasiLib?",
305 self.id
306 )
307 })?
308 };
309
310 for query_id in &self.queries {
312 let query = query_provider.get_query_instance(query_id).await?;
314
315 let subscription_response = query
317 .subscribe(self.id.clone())
318 .await
319 .map_err(|e| anyhow::anyhow!(e))?;
320 let mut receiver = subscription_response.receiver;
321
322 let priority_queue = self.priority_queue.clone();
324 let query_id_clone = query_id.clone();
325 let reaction_id = self.id.clone();
326
327 let instance_id = self
329 .context()
330 .await
331 .map(|c| c.instance_id.clone())
332 .unwrap_or_default();
333
334 let query_config = query.get_config();
336 let dispatch_mode = query_config
337 .dispatch_mode
338 .unwrap_or(crate::channels::DispatchMode::Channel);
339 let use_blocking_enqueue =
340 matches!(dispatch_mode, crate::channels::DispatchMode::Channel);
341
342 let span = tracing::info_span!(
344 "reaction_forwarder",
345 instance_id = %instance_id,
346 component_id = %reaction_id,
347 component_type = "reaction"
348 );
349 let forwarder_task = tokio::spawn(
350 async move {
351 debug!(
352 "[{reaction_id}] Started result forwarder for query '{query_id_clone}' (dispatch_mode: {dispatch_mode:?}, blocking_enqueue: {use_blocking_enqueue})"
353 );
354
355 loop {
356 match receiver.recv().await {
357 Ok(query_result) => {
358 if use_blocking_enqueue {
360 priority_queue.enqueue_wait(query_result).await;
363 } else {
364 if !priority_queue.enqueue(query_result).await {
367 warn!(
368 "[{reaction_id}] Failed to enqueue result from query '{query_id_clone}' - priority queue at capacity (broadcast mode)"
369 );
370 }
371 }
372 }
373 Err(e) => {
374 let error_str = e.to_string();
376 if error_str.contains("lagged") {
377 warn!(
378 "[{reaction_id}] Receiver lagged for query '{query_id_clone}': {error_str}"
379 );
380 continue;
381 } else {
382 info!(
383 "[{reaction_id}] Receiver error for query '{query_id_clone}': {error_str}"
384 );
385 break;
386 }
387 }
388 }
389 }
390 }
391 .instrument(span),
392 );
393
394 self.subscription_tasks.write().await.push(forwarder_task);
396 }
397
398 Ok(())
399 }
400
401 pub async fn stop_common(&self) -> Result<()> {
409 info!("Stopping reaction: {}", self.id);
410
411 if let Some(tx) = self.shutdown_tx.write().await.take() {
413 let _ = tx.send(());
414 }
415
416 let mut subscription_tasks = self.subscription_tasks.write().await;
418 for task in subscription_tasks.drain(..) {
419 task.abort();
420 }
421 drop(subscription_tasks);
422
423 let mut processing_task = self.processing_task.write().await;
425 if let Some(task) = processing_task.take() {
426 match tokio::time::timeout(std::time::Duration::from_secs(2), task).await {
428 Ok(Ok(())) => {
429 debug!("[{}] Processing task completed gracefully", self.id);
430 }
431 Ok(Err(e)) => {
432 debug!("[{}] Processing task ended: {}", self.id, e);
434 }
435 Err(_) => {
436 warn!(
439 "[{}] Processing task did not respond to shutdown signal within timeout",
440 self.id
441 );
442 }
443 }
444 }
445 drop(processing_task);
446
447 let drained_events = self.priority_queue.drain().await;
449 if !drained_events.is_empty() {
450 info!(
451 "[{}] Drained {} pending events from priority queue",
452 self.id,
453 drained_events.len()
454 );
455 }
456
457 Ok(())
458 }
459
460 pub async fn set_processing_task(&self, task: tokio::task::JoinHandle<()>) {
462 *self.processing_task.write().await = Some(task);
463 }
464}
465
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;
    use tokio::sync::mpsc;

    /// A new base carries its id and starts in `Stopped`.
    #[tokio::test]
    async fn test_reaction_base_creation() {
        let params = ReactionBaseParams::new("test-reaction", vec!["query1".to_string()])
            .with_priority_queue_capacity(5000);

        let base = ReactionBase::new(params);
        assert_eq!(base.id, "test-reaction");
        assert_eq!(base.get_status().await, ComponentStatus::Stopped);
    }

    /// `set_status_with_event` updates the stored status and emits a matching event.
    #[tokio::test]
    async fn test_status_transitions() {
        use crate::context::ReactionRuntimeContext;
        use crate::queries::Query;

        // Provider that never resolves a query; this test does not subscribe to any.
        struct MockQueryProvider;

        #[async_trait::async_trait]
        impl crate::reactions::QueryProvider for MockQueryProvider {
            async fn get_query_instance(
                &self,
                _id: &str,
            ) -> anyhow::Result<std::sync::Arc<dyn Query>> {
                Err(anyhow::anyhow!("MockQueryProvider: query not found"))
            }
        }

        let (status_tx, mut event_rx) = mpsc::channel(100);
        let params = ReactionBaseParams::new("test-reaction", vec![]);

        let base = ReactionBase::new(params);

        let context = ReactionRuntimeContext::new(
            "test-instance",
            "test-reaction",
            status_tx,
            None,
            std::sync::Arc::new(MockQueryProvider),
        );
        base.initialize(context).await;

        base.set_status_with_event(ComponentStatus::Starting, Some("Starting test".to_string()))
            .await
            .unwrap();

        assert_eq!(base.get_status().await, ComponentStatus::Starting);

        let event = event_rx.try_recv().unwrap();
        assert_eq!(event.status, ComponentStatus::Starting);
        assert_eq!(event.message, Some("Starting test".to_string()));
    }

    /// Results can be enqueued into and drained from the priority queue.
    #[tokio::test]
    async fn test_priority_queue_operations() {
        let params =
            ReactionBaseParams::new("test-reaction", vec![]).with_priority_queue_capacity(10);

        let base = ReactionBase::new(params);

        let query_result = QueryResult::new(
            "test-query".to_string(),
            chrono::Utc::now(),
            vec![],
            Default::default(),
        );

        let enqueued = base.priority_queue.enqueue(Arc::new(query_result)).await;
        assert!(enqueued);

        let drained = base.priority_queue.drain().await;
        assert_eq!(drained.len(), 1);
    }

    /// Sending an event before `initialize` is a successful no-op.
    #[tokio::test]
    async fn test_event_without_initialization() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);

        let base = ReactionBase::new(params);

        base.send_component_event(ComponentStatus::Starting, None)
            .await
            .unwrap();
    }

    /// Creating a shutdown channel stores its sender.
    #[tokio::test]
    async fn test_create_shutdown_channel() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        assert!(base.shutdown_tx.read().await.is_none());

        let rx = base.create_shutdown_channel().await;

        assert!(base.shutdown_tx.read().await.is_some());

        drop(rx);
    }

    /// The stored sender delivers the signal to the returned receiver.
    #[tokio::test]
    async fn test_shutdown_channel_signal() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        let mut rx = base.create_shutdown_channel().await;

        if let Some(tx) = base.shutdown_tx.write().await.take() {
            tx.send(()).unwrap();
        }

        let result = rx.try_recv();
        assert!(result.is_ok());
    }

    /// A second channel replaces the first: the old receiver sees `Closed`, and the new
    /// one receives the signal.
    #[tokio::test]
    async fn test_shutdown_channel_replaced_on_second_create() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        let mut rx1 = base.create_shutdown_channel().await;

        let mut rx2 = base.create_shutdown_channel().await;

        // Replacing the stored sender dropped the first one.
        assert!(matches!(
            rx1.try_recv(),
            Err(tokio::sync::oneshot::error::TryRecvError::Closed)
        ));

        if let Some(tx) = base.shutdown_tx.write().await.take() {
            tx.send(()).unwrap();
        }

        let result = rx2.try_recv();
        assert!(result.is_ok());
    }

    /// `stop_common` delivers the shutdown signal to the processing task.
    #[tokio::test]
    async fn test_stop_common_sends_shutdown_signal() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        let rx = base.create_shutdown_channel().await;

        let shutdown_received = Arc::new(AtomicBool::new(false));
        let shutdown_flag = shutdown_received.clone();

        let task = tokio::spawn(async move {
            // Only an actual `()` signal counts, not a dropped sender.
            if rx.await.is_ok() {
                shutdown_flag.store(true, Ordering::SeqCst);
            }
        });

        base.set_processing_task(task).await;

        // `stop_common` awaits the processing task, so no extra sleep is needed
        // before checking the flag.
        let result = base.stop_common().await;
        assert!(result.is_ok());

        assert!(
            shutdown_received.load(Ordering::SeqCst),
            "Processing task should have received shutdown signal"
        );
    }

    /// A task that listens for shutdown stops promptly, well before the stop timeout.
    #[tokio::test]
    async fn test_graceful_shutdown_timing() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        let rx = base.create_shutdown_channel().await;

        let task = tokio::spawn(async move {
            let mut shutdown_rx = rx;
            loop {
                tokio::select! {
                    biased;
                    _ = &mut shutdown_rx => {
                        break;
                    }
                    _ = tokio::time::sleep(Duration::from_secs(10)) => {
                    }
                }
            }
        });

        base.set_processing_task(task).await;

        let start = std::time::Instant::now();
        let result = base.stop_common().await;
        let elapsed = start.elapsed();
        assert!(result.is_ok());

        assert!(
            elapsed < Duration::from_millis(500),
            "Shutdown took {elapsed:?}, expected < 500ms. Task may not be responding to shutdown signal."
        );
    }

    /// `stop_common` succeeds even when no shutdown channel was ever created.
    #[tokio::test]
    async fn test_stop_common_without_shutdown_channel() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        let task = tokio::spawn(async {
            tokio::time::sleep(Duration::from_millis(10)).await;
        });

        base.set_processing_task(task).await;

        let result = base.stop_common().await;
        assert!(result.is_ok());
    }
}