drasi_lib/reactions/common/
base.rs1use anyhow::Result;
33use log::{debug, error, info, warn};
34use std::sync::Arc;
35use tokio::sync::RwLock;
36use tracing::Instrument;
37
38use crate::channels::priority_queue::PriorityQueue;
39use crate::channels::{
40 ComponentEvent, ComponentEventSender, ComponentStatus, ComponentType, QueryResult,
41};
42use crate::context::ReactionRuntimeContext;
43use crate::identity::IdentityProvider;
44use crate::state_store::StateStoreProvider;
45
/// Construction parameters for a `ReactionBase`.
///
/// Create with [`ReactionBaseParams::new`] and refine with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ReactionBaseParams {
    /// Identifier of the reaction.
    pub id: String,
    /// Identifiers of the queries this reaction is associated with.
    pub queries: Vec<String>,
    /// Priority-queue capacity; `None` lets `ReactionBase::new` use its default (10000).
    pub priority_queue_capacity: Option<usize>,
    /// Auto-start flag, later exposed via `ReactionBase::get_auto_start`. Defaults to `true`.
    pub auto_start: bool,
}

impl ReactionBaseParams {
    /// Creates parameters with the given id and queries.
    ///
    /// The priority-queue capacity is left unset (`None`) and `auto_start` is `true`.
    pub fn new(id: impl Into<String>, queries: Vec<String>) -> Self {
        Self {
            id: id.into(),
            queries,
            priority_queue_capacity: None,
            auto_start: true,
        }
    }

    /// Sets an explicit priority-queue capacity, overriding the default.
    #[must_use]
    pub fn with_priority_queue_capacity(mut self, capacity: usize) -> Self {
        self.priority_queue_capacity = Some(capacity);
        self
    }

    /// Sets the auto-start flag.
    #[must_use]
    pub fn with_auto_start(mut self, auto_start: bool) -> Self {
        self.auto_start = auto_start;
        self
    }
}
97
/// Shared state and lifecycle plumbing for reaction implementations.
///
/// Mutable state is held behind `Arc<RwLock<..>>`, so handles produced by
/// `clone_shared` observe and mutate the same underlying values.
pub struct ReactionBase {
    /// Identifier of the reaction.
    pub id: String,
    /// Identifiers of the queries this reaction is associated with.
    pub queries: Vec<String>,
    /// Auto-start flag copied from the construction parameters.
    pub auto_start: bool,
    /// Current lifecycle status; starts as `ComponentStatus::Stopped`.
    pub status: Arc<RwLock<ComponentStatus>>,
    /// Runtime context; `None` until `initialize` is called.
    context: Arc<RwLock<Option<ReactionRuntimeContext>>>,
    /// Sender for component lifecycle events; while `None`, events are not sent.
    status_tx: Arc<RwLock<Option<ComponentEventSender>>>,
    /// Optional state store, taken from the runtime context when it provides one.
    state_store: Arc<RwLock<Option<Arc<dyn StateStoreProvider>>>>,
    /// Queue of incoming query results; drained (and discarded) by `stop_common`.
    pub priority_queue: PriorityQueue<QueryResult>,
    /// Handles of subscription tasks; aborted by `stop_common`.
    pub subscription_tasks: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
    /// Handle of the processing task; awaited with a timeout by `stop_common`.
    pub processing_task: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
    /// Sender half of the shutdown signal; fired by `stop_common`.
    pub shutdown_tx: Arc<RwLock<Option<tokio::sync::oneshot::Sender<()>>>>,
    /// Identity provider, set explicitly or inherited from the runtime context.
    identity_provider: Arc<RwLock<Option<Arc<dyn IdentityProvider>>>>,
}
127
128impl ReactionBase {
129 pub fn new(params: ReactionBaseParams) -> Self {
134 Self {
135 priority_queue: PriorityQueue::new(params.priority_queue_capacity.unwrap_or(10000)),
136 id: params.id,
137 queries: params.queries,
138 auto_start: params.auto_start,
139 status: Arc::new(RwLock::new(ComponentStatus::Stopped)),
140 context: Arc::new(RwLock::new(None)),
141 status_tx: Arc::new(RwLock::new(None)),
142 state_store: Arc::new(RwLock::new(None)),
143 subscription_tasks: Arc::new(RwLock::new(Vec::new())),
144 processing_task: Arc::new(RwLock::new(None)),
145 shutdown_tx: Arc::new(RwLock::new(None)),
146 identity_provider: Arc::new(RwLock::new(None)),
147 }
148 }
149
150 pub async fn initialize(&self, context: ReactionRuntimeContext) {
160 *self.context.write().await = Some(context.clone());
162
163 *self.status_tx.write().await = Some(context.status_tx.clone());
165
166 if let Some(state_store) = context.state_store.as_ref() {
167 *self.state_store.write().await = Some(state_store.clone());
168 }
169
170 if let Some(ip) = context.identity_provider.as_ref() {
172 let mut guard = self.identity_provider.write().await;
173 if guard.is_none() {
174 *guard = Some(ip.clone());
175 }
176 }
177 }
178
179 pub async fn context(&self) -> Option<ReactionRuntimeContext> {
183 self.context.read().await.clone()
184 }
185
186 pub async fn state_store(&self) -> Option<Arc<dyn StateStoreProvider>> {
190 self.state_store.read().await.clone()
191 }
192
193 pub async fn identity_provider(&self) -> Option<Arc<dyn IdentityProvider>> {
199 self.identity_provider.read().await.clone()
200 }
201
202 pub async fn set_identity_provider(&self, provider: Arc<dyn IdentityProvider>) {
208 *self.identity_provider.write().await = Some(provider);
209 }
210
211 pub fn get_auto_start(&self) -> bool {
213 self.auto_start
214 }
215
216 pub fn status_tx(&self) -> Arc<RwLock<Option<ComponentEventSender>>> {
223 self.status_tx.clone()
224 }
225
226 pub fn clone_shared(&self) -> Self {
231 Self {
232 id: self.id.clone(),
233 queries: self.queries.clone(),
234 auto_start: self.auto_start,
235 status: self.status.clone(),
236 context: self.context.clone(),
237 status_tx: self.status_tx.clone(),
238 state_store: self.state_store.clone(),
239 priority_queue: self.priority_queue.clone(),
240 subscription_tasks: self.subscription_tasks.clone(),
241 processing_task: self.processing_task.clone(),
242 shutdown_tx: self.shutdown_tx.clone(),
243 identity_provider: self.identity_provider.clone(),
244 }
245 }
246
247 pub async fn create_shutdown_channel(&self) -> tokio::sync::oneshot::Receiver<()> {
254 let (tx, rx) = tokio::sync::oneshot::channel();
255 *self.shutdown_tx.write().await = Some(tx);
256 rx
257 }
258
259 pub fn get_id(&self) -> &str {
261 &self.id
262 }
263
264 pub fn get_queries(&self) -> &[String] {
266 &self.queries
267 }
268
269 pub async fn get_status(&self) -> ComponentStatus {
271 self.status.read().await.clone()
272 }
273
274 pub async fn send_component_event(
280 &self,
281 status: ComponentStatus,
282 message: Option<String>,
283 ) -> Result<()> {
284 let event = ComponentEvent {
285 component_id: self.id.clone(),
286 component_type: ComponentType::Reaction,
287 status,
288 timestamp: chrono::Utc::now(),
289 message,
290 };
291
292 if let Some(ref tx) = *self.status_tx.read().await {
293 if let Err(e) = tx.send(event).await {
294 error!("Failed to send component event: {e}");
295 }
296 }
297 Ok(())
299 }
300
301 pub async fn set_status_with_event(
303 &self,
304 status: ComponentStatus,
305 message: Option<String>,
306 ) -> Result<()> {
307 *self.status.write().await = status.clone();
308 self.send_component_event(status, message).await
309 }
310
311 pub async fn enqueue_query_result(&self, result: QueryResult) -> anyhow::Result<()> {
316 self.priority_queue.enqueue_wait(Arc::new(result)).await;
317 Ok(())
318 }
319
320 pub async fn stop_common(&self) -> Result<()> {
328 info!("Stopping reaction: {}", self.id);
329
330 if let Some(tx) = self.shutdown_tx.write().await.take() {
332 let _ = tx.send(());
333 }
334
335 let mut subscription_tasks = self.subscription_tasks.write().await;
337 for task in subscription_tasks.drain(..) {
338 task.abort();
339 }
340 drop(subscription_tasks);
341
342 let mut processing_task = self.processing_task.write().await;
344 if let Some(task) = processing_task.take() {
345 match tokio::time::timeout(std::time::Duration::from_secs(2), task).await {
347 Ok(Ok(())) => {
348 debug!("[{}] Processing task completed gracefully", self.id);
349 }
350 Ok(Err(e)) => {
351 debug!("[{}] Processing task ended: {}", self.id, e);
353 }
354 Err(_) => {
355 warn!(
358 "[{}] Processing task did not respond to shutdown signal within timeout",
359 self.id
360 );
361 }
362 }
363 }
364 drop(processing_task);
365
366 let drained_events = self.priority_queue.drain().await;
368 if !drained_events.is_empty() {
369 info!(
370 "[{}] Drained {} pending events from priority queue",
371 self.id,
372 drained_events.len()
373 );
374 }
375
376 *self.status.write().await = ComponentStatus::Stopped;
377 info!("Reaction '{}' stopped", self.id);
378
379 Ok(())
380 }
381
382 pub async fn deprovision_common(&self) -> Result<()> {
388 info!("Deprovisioning reaction '{}'", self.id);
389 if let Some(store) = self.state_store().await {
390 let count = store.clear_store(&self.id).await.map_err(|e| {
391 anyhow::anyhow!(
392 "Failed to clear state store for reaction '{}': {}",
393 self.id,
394 e
395 )
396 })?;
397 info!(
398 "Cleared {} keys from state store for reaction '{}'",
399 count, self.id
400 );
401 }
402 Ok(())
403 }
404
405 pub async fn set_processing_task(&self, task: tokio::task::JoinHandle<()>) {
407 *self.processing_task.write().await = Some(task);
408 }
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;
    use tokio::sync::mpsc;

    /// A freshly constructed base keeps its id and starts `Stopped`.
    #[tokio::test]
    async fn test_reaction_base_creation() {
        let params = ReactionBaseParams::new("test-reaction", vec!["query1".to_string()])
            .with_priority_queue_capacity(5000);

        let base = ReactionBase::new(params);
        assert_eq!(base.id, "test-reaction");
        assert_eq!(base.get_status().await, ComponentStatus::Stopped);
    }

    /// `set_status_with_event` updates the stored status and emits a matching event.
    #[tokio::test]
    async fn test_status_transitions() {
        use crate::context::ReactionRuntimeContext;

        let (status_tx, mut event_rx) = mpsc::channel(100);
        let params = ReactionBaseParams::new("test-reaction", vec![]);

        let base = ReactionBase::new(params);

        let context =
            ReactionRuntimeContext::new("test-instance", "test-reaction", status_tx, None);
        base.initialize(context).await;

        base.set_status_with_event(ComponentStatus::Starting, Some("Starting test".to_string()))
            .await
            .unwrap();

        assert_eq!(base.get_status().await, ComponentStatus::Starting);

        let event = event_rx.try_recv().unwrap();
        assert_eq!(event.status, ComponentStatus::Starting);
        assert_eq!(event.message, Some("Starting test".to_string()));
    }

    /// Items enqueued on the priority queue are returned by `drain`.
    #[tokio::test]
    async fn test_priority_queue_operations() {
        let params =
            ReactionBaseParams::new("test-reaction", vec![]).with_priority_queue_capacity(10);

        let base = ReactionBase::new(params);

        let query_result = QueryResult::new(
            "test-query".to_string(),
            chrono::Utc::now(),
            vec![],
            Default::default(),
        );

        let enqueued = base.priority_queue.enqueue(Arc::new(query_result)).await;
        assert!(enqueued);

        let drained = base.priority_queue.drain().await;
        assert_eq!(drained.len(), 1);
    }

    /// Sending an event before `initialize` is a no-op, not an error.
    #[tokio::test]
    async fn test_event_without_initialization() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);

        let base = ReactionBase::new(params);

        base.send_component_event(ComponentStatus::Starting, None)
            .await
            .unwrap();
    }

    /// `create_shutdown_channel` installs a sender.
    #[tokio::test]
    async fn test_create_shutdown_channel() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        assert!(base.shutdown_tx.read().await.is_none());

        let rx = base.create_shutdown_channel().await;

        assert!(base.shutdown_tx.read().await.is_some());

        drop(rx);
    }

    /// Firing the stored sender delivers the signal to the receiver.
    #[tokio::test]
    async fn test_shutdown_channel_signal() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        let mut rx = base.create_shutdown_channel().await;

        if let Some(tx) = base.shutdown_tx.write().await.take() {
            tx.send(()).unwrap();
        }

        assert!(rx.try_recv().is_ok());
    }

    /// A second `create_shutdown_channel` replaces the sender and closes the first receiver.
    #[tokio::test]
    async fn test_shutdown_channel_replaced_on_second_create() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        let mut rx1 = base.create_shutdown_channel().await;

        let mut rx2 = base.create_shutdown_channel().await;

        // The first sender was dropped when it was replaced.
        assert!(matches!(
            rx1.try_recv(),
            Err(tokio::sync::oneshot::error::TryRecvError::Closed)
        ));

        if let Some(tx) = base.shutdown_tx.write().await.take() {
            tx.send(()).unwrap();
        }

        assert!(rx2.try_recv().is_ok());
    }

    /// `stop_common` signals the processing task and waits for it before returning.
    #[tokio::test]
    async fn test_stop_common_sends_shutdown_signal() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        let rx = base.create_shutdown_channel().await;

        let shutdown_received = Arc::new(AtomicBool::new(false));
        let shutdown_flag = Arc::clone(&shutdown_received);

        let task = tokio::spawn(async move {
            if rx.await.is_ok() {
                shutdown_flag.store(true, Ordering::SeqCst);
            }
        });

        base.set_processing_task(task).await;

        // stop_common awaits the processing task, so no extra sleep is needed.
        base.stop_common().await.unwrap();

        assert!(
            shutdown_received.load(Ordering::SeqCst),
            "Processing task should have received shutdown signal"
        );
        assert_eq!(base.get_status().await, ComponentStatus::Stopped);
    }

    /// A task that honours the shutdown signal lets `stop_common` return promptly.
    #[tokio::test]
    async fn test_graceful_shutdown_timing() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        let rx = base.create_shutdown_channel().await;

        let task = tokio::spawn(async move {
            let mut shutdown_rx = rx;
            loop {
                tokio::select! {
                    biased;
                    _ = &mut shutdown_rx => {
                        break;
                    }
                    _ = tokio::time::sleep(Duration::from_secs(10)) => {
                    }
                }
            }
        });

        base.set_processing_task(task).await;

        let start = std::time::Instant::now();
        base.stop_common().await.unwrap();
        let elapsed = start.elapsed();

        assert!(
            elapsed < Duration::from_millis(500),
            "Shutdown took {elapsed:?}, expected < 500ms. Task may not be responding to shutdown signal."
        );
    }

    /// `stop_common` succeeds even when no shutdown channel was created.
    #[tokio::test]
    async fn test_stop_common_without_shutdown_channel() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);

        let task = tokio::spawn(async {
            tokio::time::sleep(Duration::from_millis(10)).await;
        });

        base.set_processing_task(task).await;

        let result = base.stop_common().await;
        assert!(result.is_ok());
        assert_eq!(base.get_status().await, ComponentStatus::Stopped);
    }
}