use anyhow::Result;
use log::{debug, error, info, warn};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::Instrument;
use crate::channels::priority_queue::PriorityQueue;
use crate::channels::{ComponentStatus, QueryResult};
use crate::component_graph::ComponentStatusHandle;
use crate::context::ReactionRuntimeContext;
use crate::identity::IdentityProvider;
use crate::state_store::StateStoreProvider;
/// Construction parameters for a [`ReactionBase`].
///
/// Create with [`ReactionBaseParams::new`] and refine with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ReactionBaseParams {
    /// Identifier of the reaction.
    pub id: String,
    /// Query identifiers associated with the reaction.
    pub queries: Vec<String>,
    /// Explicit priority-queue capacity; `None` lets [`ReactionBase::new`] choose its default.
    pub priority_queue_capacity: Option<usize>,
    /// Auto-start flag; defaults to `true`.
    pub auto_start: bool,
}

impl ReactionBaseParams {
    /// Creates parameters for reaction `id` with the given `queries`.
    ///
    /// The priority-queue capacity is left unset and `auto_start` is `true`.
    pub fn new(id: impl Into<String>, queries: Vec<String>) -> Self {
        let id: String = id.into();
        Self {
            id,
            queries,
            priority_queue_capacity: None,
            auto_start: true,
        }
    }

    /// Returns `self` with an explicit priority-queue capacity (replacing any previous one).
    pub fn with_priority_queue_capacity(self, capacity: usize) -> Self {
        Self {
            priority_queue_capacity: Some(capacity),
            ..self
        }
    }

    /// Returns `self` with the auto-start flag set to `auto_start`.
    pub fn with_auto_start(self, auto_start: bool) -> Self {
        Self { auto_start, ..self }
    }
}
/// Shared state and lifecycle helpers used by reaction implementations.
///
/// Most mutable state sits behind `Arc<RwLock<..>>`, so a handle obtained from
/// `ReactionBase::clone_shared` reads and writes the same context, state store, task handles,
/// shutdown sender and identity provider as the original.
pub struct ReactionBase {
/// Reaction identifier (taken from `ReactionBaseParams::id`).
pub id: String,
/// Query identifiers associated with this reaction (taken from `ReactionBaseParams::queries`).
pub queries: Vec<String>,
/// Auto-start flag (taken from `ReactionBaseParams::auto_start`).
pub auto_start: bool,
/// Holds the component status; wired to the runtime's update sender in `initialize`.
status_handle: ComponentStatusHandle,
/// Runtime context; `None` until `initialize` is called.
context: Arc<RwLock<Option<ReactionRuntimeContext>>>,
/// State store adopted from the runtime context in `initialize`, if the context carried one.
state_store: Arc<RwLock<Option<Arc<dyn StateStoreProvider>>>>,
/// Queue that `enqueue_query_result` pushes into; any leftovers are drained by `stop_common`.
pub priority_queue: PriorityQueue<QueryResult>,
/// Handles of subscription tasks; `stop_common` aborts all of them.
pub subscription_tasks: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
/// Handle of the processing task; `stop_common` waits briefly for it to exit, then aborts it.
pub processing_task: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
/// Sender half of the channel made by `create_shutdown_channel`; fired by `stop_common`.
pub shutdown_tx: Arc<RwLock<Option<tokio::sync::oneshot::Sender<()>>>>,
/// Identity provider, set explicitly or (only if still unset) adopted from the context.
identity_provider: Arc<RwLock<Option<Arc<dyn IdentityProvider>>>>,
}
impl ReactionBase {
pub fn new(params: ReactionBaseParams) -> Self {
Self {
priority_queue: PriorityQueue::new(params.priority_queue_capacity.unwrap_or(10000)),
id: params.id.clone(),
queries: params.queries,
auto_start: params.auto_start,
status_handle: ComponentStatusHandle::new(¶ms.id),
context: Arc::new(RwLock::new(None)), state_store: Arc::new(RwLock::new(None)), subscription_tasks: Arc::new(RwLock::new(Vec::new())),
processing_task: Arc::new(RwLock::new(None)),
shutdown_tx: Arc::new(RwLock::new(None)),
identity_provider: Arc::new(RwLock::new(None)),
}
}
pub async fn initialize(&self, context: ReactionRuntimeContext) {
*self.context.write().await = Some(context.clone());
self.status_handle.wire(context.update_tx.clone()).await;
if let Some(state_store) = context.state_store.as_ref() {
*self.state_store.write().await = Some(state_store.clone());
}
if let Some(ip) = context.identity_provider.as_ref() {
let mut guard = self.identity_provider.write().await;
if guard.is_none() {
*guard = Some(ip.clone());
}
}
}
pub async fn context(&self) -> Option<ReactionRuntimeContext> {
self.context.read().await.clone()
}
pub async fn state_store(&self) -> Option<Arc<dyn StateStoreProvider>> {
self.state_store.read().await.clone()
}
pub async fn identity_provider(&self) -> Option<Arc<dyn IdentityProvider>> {
self.identity_provider.read().await.clone()
}
pub async fn set_identity_provider(&self, provider: Arc<dyn IdentityProvider>) {
*self.identity_provider.write().await = Some(provider);
}
pub fn get_auto_start(&self) -> bool {
self.auto_start
}
pub fn clone_shared(&self) -> Self {
Self {
id: self.id.clone(),
queries: self.queries.clone(),
auto_start: self.auto_start,
status_handle: self.status_handle.clone(),
context: self.context.clone(),
state_store: self.state_store.clone(),
priority_queue: self.priority_queue.clone(),
subscription_tasks: self.subscription_tasks.clone(),
processing_task: self.processing_task.clone(),
shutdown_tx: self.shutdown_tx.clone(),
identity_provider: self.identity_provider.clone(),
}
}
pub async fn create_shutdown_channel(&self) -> tokio::sync::oneshot::Receiver<()> {
let (tx, rx) = tokio::sync::oneshot::channel();
*self.shutdown_tx.write().await = Some(tx);
rx
}
pub fn get_id(&self) -> &str {
&self.id
}
pub fn get_queries(&self) -> &[String] {
&self.queries
}
pub async fn get_status(&self) -> ComponentStatus {
self.status_handle.get_status().await
}
pub fn status_handle(&self) -> ComponentStatusHandle {
self.status_handle.clone()
}
pub async fn set_status(&self, status: ComponentStatus, message: Option<String>) {
self.status_handle.set_status(status, message).await;
}
pub async fn enqueue_query_result(&self, result: QueryResult) -> anyhow::Result<()> {
self.priority_queue.enqueue_wait(Arc::new(result)).await;
Ok(())
}
pub async fn stop_common(&self) -> Result<()> {
info!("Stopping reaction: {}", self.id);
if let Some(tx) = self.shutdown_tx.write().await.take() {
let _ = tx.send(());
}
let mut subscription_tasks = self.subscription_tasks.write().await;
for task in subscription_tasks.drain(..) {
task.abort();
}
drop(subscription_tasks);
let mut processing_task = self.processing_task.write().await;
if let Some(mut task) = processing_task.take() {
match tokio::time::timeout(std::time::Duration::from_secs(2), &mut task).await {
Ok(Ok(())) => {
debug!("[{}] Processing task completed gracefully", self.id);
}
Ok(Err(e)) => {
debug!("[{}] Processing task ended: {}", self.id, e);
}
Err(_) => {
warn!(
"[{}] Processing task did not respond to shutdown signal within timeout, aborting",
self.id
);
task.abort();
}
}
}
drop(processing_task);
let drained_events = self.priority_queue.drain().await;
if !drained_events.is_empty() {
info!(
"[{}] Drained {} pending events from priority queue",
self.id,
drained_events.len()
);
}
self.set_status(
ComponentStatus::Stopped,
Some(format!("Reaction '{}' stopped", self.id)),
)
.await;
info!("Reaction '{}' stopped", self.id);
Ok(())
}
pub async fn deprovision_common(&self) -> Result<()> {
info!("Deprovisioning reaction '{}'", self.id);
if let Some(store) = self.state_store().await {
let count = store.clear_store(&self.id).await.map_err(|e| {
anyhow::anyhow!(
"Failed to clear state store for reaction '{}': {}",
self.id,
e
)
})?;
info!(
"Cleared {} keys from state store for reaction '{}'",
count, self.id
);
}
Ok(())
}
pub async fn set_processing_task(&self, task: tokio::task::JoinHandle<()>) {
*self.processing_task.write().await = Some(task);
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for `ReactionBase` construction, status handling, shutdown signalling and
    //! lifecycle helpers.

    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    #[tokio::test]
    async fn test_reaction_base_creation() {
        let params = ReactionBaseParams::new("test-reaction", vec!["query1".to_string()])
            .with_priority_queue_capacity(5000);
        let base = ReactionBase::new(params);
        assert_eq!(base.id, "test-reaction");
        assert_eq!(base.get_status().await, ComponentStatus::Stopped);
    }

    #[tokio::test]
    async fn test_status_transitions() {
        let (graph, _rx) = crate::component_graph::ComponentGraph::new("test-instance");
        let update_tx = graph.update_sender();
        let graph = Arc::new(RwLock::new(graph));
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);
        let context =
            ReactionRuntimeContext::new("test-instance", "test-reaction", None, update_tx, None);
        base.initialize(context).await;

        base.set_status(ComponentStatus::Starting, Some("Starting test".to_string()))
            .await;
        assert_eq!(base.get_status().await, ComponentStatus::Starting);

        // Keep a graph subscription alive while the status changes (its events are not inspected).
        let _event_rx = graph.read().await.subscribe();
        base.set_status(ComponentStatus::Running, Some("Running test".to_string()))
            .await;
        assert_eq!(base.get_status().await, ComponentStatus::Running);
    }

    #[tokio::test]
    async fn test_priority_queue_operations() {
        let params =
            ReactionBaseParams::new("test-reaction", vec![]).with_priority_queue_capacity(10);
        let base = ReactionBase::new(params);
        let query_result = QueryResult::new(
            "test-query".to_string(),
            chrono::Utc::now(),
            vec![],
            Default::default(),
        );
        let enqueued = base.priority_queue.enqueue(Arc::new(query_result)).await;
        assert!(enqueued);
        let drained = base.priority_queue.drain().await;
        assert_eq!(drained.len(), 1);
    }

    #[tokio::test]
    async fn test_event_without_initialization() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);
        // Setting status before `initialize` must not panic and must still update local status.
        base.set_status(ComponentStatus::Starting, None).await;
        assert_eq!(base.get_status().await, ComponentStatus::Starting);
    }

    #[tokio::test]
    async fn test_create_shutdown_channel() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);
        assert!(base.shutdown_tx.read().await.is_none());
        let rx = base.create_shutdown_channel().await;
        assert!(base.shutdown_tx.read().await.is_some());
        drop(rx);
    }

    #[tokio::test]
    async fn test_shutdown_channel_signal() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);
        let mut rx = base.create_shutdown_channel().await;
        if let Some(tx) = base.shutdown_tx.write().await.take() {
            tx.send(()).unwrap();
        }
        let result = rx.try_recv();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_shutdown_channel_replaced_on_second_create() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);
        let mut rx1 = base.create_shutdown_channel().await;
        let mut rx2 = base.create_shutdown_channel().await;
        // The first sender was dropped when replaced, so the first receiver is closed.
        assert!(rx1.try_recv().is_err());
        if let Some(tx) = base.shutdown_tx.write().await.take() {
            tx.send(()).unwrap();
        }
        let result = rx2.try_recv();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_stop_common_sends_shutdown_signal() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);
        let mut rx = base.create_shutdown_channel().await;
        let shutdown_received = Arc::new(AtomicBool::new(false));
        let shutdown_flag = shutdown_received.clone();
        let task = tokio::spawn(async move {
            tokio::select! {
                _ = &mut rx => {
                    shutdown_flag.store(true, Ordering::SeqCst);
                }
            }
        });
        base.set_processing_task(task).await;
        let _ = base.stop_common().await;
        assert!(
            shutdown_received.load(Ordering::SeqCst),
            "Processing task should have received shutdown signal"
        );
    }

    #[tokio::test]
    async fn test_graceful_shutdown_timing() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);
        let rx = base.create_shutdown_channel().await;
        let task = tokio::spawn(async move {
            let mut shutdown_rx = rx;
            loop {
                tokio::select! {
                    biased;
                    _ = &mut shutdown_rx => {
                        break;
                    }
                    _ = tokio::time::sleep(Duration::from_secs(10)) => {
                    }
                }
            }
        });
        base.set_processing_task(task).await;
        let start = std::time::Instant::now();
        let _ = base.stop_common().await;
        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_millis(500),
            "Shutdown took {elapsed:?}, expected < 500ms. Task may not be responding to shutdown signal."
        );
    }

    #[tokio::test]
    async fn test_stop_common_without_shutdown_channel() {
        let params = ReactionBaseParams::new("test-reaction", vec![]);
        let base = ReactionBase::new(params);
        let task = tokio::spawn(async {
            tokio::time::sleep(Duration::from_millis(10)).await;
        });
        base.set_processing_task(task).await;
        let result = base.stop_common().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_stop_common_clears_tasks_queue_and_sets_stopped() {
        let params = ReactionBaseParams::new("r1", vec![]).with_priority_queue_capacity(10);
        let base = ReactionBase::new(params);
        let query_result = QueryResult::new(
            "test-query".to_string(),
            chrono::Utc::now(),
            vec![],
            Default::default(),
        );
        assert!(base.priority_queue.enqueue(Arc::new(query_result)).await);
        base.subscription_tasks.write().await.push(tokio::spawn(async {
            tokio::time::sleep(Duration::from_secs(60)).await;
        }));
        base.set_status(ComponentStatus::Running, None).await;

        base.stop_common().await.unwrap();

        assert!(base.subscription_tasks.read().await.is_empty());
        assert!(base.processing_task.read().await.is_none());
        assert!(base.priority_queue.drain().await.is_empty());
        assert_eq!(base.get_status().await, ComponentStatus::Stopped);
    }

    #[tokio::test]
    async fn test_get_id() {
        let params = ReactionBaseParams::new("my-reaction-42", vec![]);
        let base = ReactionBase::new(params);
        assert_eq!(base.get_id(), "my-reaction-42");
    }

    #[tokio::test]
    async fn test_get_queries() {
        let queries = vec![
            "query-a".to_string(),
            "query-b".to_string(),
            "query-c".to_string(),
        ];
        let params = ReactionBaseParams::new("r1", queries.clone());
        let base = ReactionBase::new(params);
        assert_eq!(base.get_queries(), &queries[..]);
    }

    #[tokio::test]
    async fn test_get_queries_empty() {
        let params = ReactionBaseParams::new("r1", vec![]);
        let base = ReactionBase::new(params);
        assert!(base.get_queries().is_empty());
    }

    #[tokio::test]
    async fn test_get_auto_start_default_true() {
        let params = ReactionBaseParams::new("r1", vec![]);
        let base = ReactionBase::new(params);
        assert!(base.get_auto_start());
    }

    #[tokio::test]
    async fn test_get_auto_start_override_false() {
        let params = ReactionBaseParams::new("r1", vec![]).with_auto_start(false);
        let base = ReactionBase::new(params);
        assert!(!base.get_auto_start());
    }

    #[tokio::test]
    async fn test_context_none_before_initialize() {
        let params = ReactionBaseParams::new("r1", vec![]);
        let base = ReactionBase::new(params);
        assert!(base.context().await.is_none());
    }

    #[tokio::test]
    async fn test_context_some_after_initialize() {
        let (graph, _rx) = crate::component_graph::ComponentGraph::new("inst");
        let update_tx = graph.update_sender();
        let context = ReactionRuntimeContext::new("inst", "r1", None, update_tx, None);
        let params = ReactionBaseParams::new("r1", vec![]);
        let base = ReactionBase::new(params);
        base.initialize(context).await;
        let ctx = base.context().await;
        assert!(ctx.is_some());
        assert_eq!(ctx.unwrap().reaction_id, "r1");
    }

    #[tokio::test]
    async fn test_state_store_none_when_not_configured() {
        let params = ReactionBaseParams::new("r1", vec![]);
        let base = ReactionBase::new(params);
        assert!(base.state_store().await.is_none());
    }

    #[tokio::test]
    async fn test_state_store_none_after_initialize_without_store() {
        let (graph, _rx) = crate::component_graph::ComponentGraph::new("inst");
        let update_tx = graph.update_sender();
        let context = ReactionRuntimeContext::new("inst", "r1", None, update_tx, None);
        let params = ReactionBaseParams::new("r1", vec![]);
        let base = ReactionBase::new(params);
        base.initialize(context).await;
        assert!(base.state_store().await.is_none());
    }

    #[tokio::test]
    async fn test_identity_provider_none_by_default() {
        let params = ReactionBaseParams::new("r1", vec![]);
        let base = ReactionBase::new(params);
        assert!(base.identity_provider().await.is_none());
    }

    #[tokio::test]
    async fn test_status_handle_returns_handle() {
        let params = ReactionBaseParams::new("r1", vec![]);
        let base = ReactionBase::new(params);
        let handle = base.status_handle();
        assert_eq!(handle.get_status().await, ComponentStatus::Stopped);
        handle.set_status(ComponentStatus::Running, None).await;
        assert_eq!(base.get_status().await, ComponentStatus::Running);
    }

    #[tokio::test]
    async fn test_deprovision_common_noop_without_state_store() {
        let params = ReactionBaseParams::new("r1", vec![]);
        let base = ReactionBase::new(params);
        let result = base.deprovision_common().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_set_processing_task_stores_handle() {
        let params = ReactionBaseParams::new("r1", vec![]);
        let base = ReactionBase::new(params);
        assert!(base.processing_task.read().await.is_none());
        let task = tokio::spawn(async {
            tokio::time::sleep(Duration::from_secs(60)).await;
        });
        base.set_processing_task(task).await;
        assert!(base.processing_task.read().await.is_some());
        let task = base.processing_task.write().await.take();
        if let Some(t) = task {
            t.abort();
        }
    }
}