use std::sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc,
};
use std::time::{Duration, Instant};
use tracing::{debug, error, info, warn};
use uuid::Uuid;
use tasker_shared::messaging::service::MessagingProvider;
use tasker_shared::{system_context::SystemContext, TaskerResult};
use crate::orchestration::channels::OrchestrationCommandSender;
use crate::orchestration::commands::OrchestrationCommand;
/// Configuration for the orchestration fallback poller.
///
/// The fallback poller periodically drains orchestration queues as a safety
/// net for messages that were not delivered via the primary push path.
// PartialEq/Eq derived so configs can be compared directly in tests instead
// of field-by-field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrchestrationPollerConfig {
    /// Whether the fallback poller should run at all.
    pub enabled: bool,
    /// Delay between polling cycles.
    pub polling_interval: Duration,
    /// Maximum number of messages fetched from a queue per cycle.
    pub batch_size: u32,
    /// Minimum message age before fallback processing should pick it up.
    /// NOTE(review): not read by the polling loop in this file — confirm use.
    pub age_threshold: Duration,
    /// Upper bound on message age.
    /// NOTE(review): not read by the polling loop in this file — confirm use.
    pub max_age: Duration,
    /// Queue names this poller is configured to watch.
    /// NOTE(review): the spawned polling loop derives its queue list from
    /// `QueueClassifier`, not from this field — confirm intent.
    pub monitored_queues: Vec<String>,
    /// Namespace label for this poller.
    pub namespace: String,
    /// Visibility timeout requested when receiving messages.
    pub visibility_timeout: Duration,
}

impl Default for OrchestrationPollerConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            polling_interval: Duration::from_secs(30),
            batch_size: 50,
            age_threshold: Duration::from_secs(5),
            max_age: Duration::from_secs(24 * 60 * 60),
            monitored_queues: vec![
                "orchestration_step_results".to_string(),
                "orchestration_task_requests".to_string(),
            ],
            namespace: "orchestration".to_string(),
            visibility_timeout: Duration::from_secs(30),
        }
    }
}
/// Fallback poller that periodically drains orchestration queues and feeds
/// each message into the orchestration command loop.
pub struct OrchestrationFallbackPoller {
    // Unique id for this poller instance; included in every log line for
    // correlation.
    poller_id: Uuid,
    // Polling behavior configuration.
    config: OrchestrationPollerConfig,
    // Shared system context; supplies the messaging provider and queue config.
    context: Arc<SystemContext>,
    // Channel into the orchestration command processor.
    command_sender: OrchestrationCommandSender,
    // Toggled by start()/stop(); reported by is_healthy().
    is_running: AtomicBool,
    // Local statistics; snapshotted by stats().
    stats: OrchestrationPollerStats,
}
impl std::fmt::Debug for OrchestrationFallbackPoller {
    /// Manual Debug: prints a summary (id, config, provider name, running
    /// flag) and elides the remaining fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OrchestrationFallbackPoller")
            .field("poller_id", &self.poller_id)
            .field("config", &self.config)
            .field(
                "provider",
                &self.context.messaging_provider().provider_name(),
            )
            // Use the `Ordering` imported at the top of the file, consistent
            // with the rest of the module.
            .field("is_running", &self.is_running.load(Ordering::Relaxed))
            // Some fields are intentionally omitted; signal that with `..`.
            .finish_non_exhaustive()
    }
}
/// Cumulative activity counters for the fallback poller.
///
/// `Default` zeroes every counter and leaves both timestamps `None`.
#[derive(Debug, Default)]
pub struct OrchestrationPollerStats {
    // Number of polling cycles executed.
    pub polling_cycles: AtomicU64,
    // Messages dispatched as orchestration commands.
    pub messages_processed: AtomicU64,
    // Step-result messages handled.
    pub step_results_processed: AtomicU64,
    // Task-request messages handled.
    pub task_requests_processed: AtomicU64,
    // Messages intentionally skipped.
    // NOTE(review): never incremented by the poller itself in this file
    // (only by unit tests) — confirm whether it should be wired up.
    pub messages_skipped: AtomicU64,
    // Errors while receiving from queues or sending commands.
    pub polling_errors: AtomicU64,
    // Time of the most recent polling cycle; shared with the polling loop.
    pub last_poll_at: Arc<tokio::sync::Mutex<Option<Instant>>>,
    // Time start() was last invoked.
    pub started_at: Arc<tokio::sync::Mutex<Option<Instant>>>,
}
impl OrchestrationFallbackPoller {
/// Create a new fallback poller with a freshly generated id.
///
/// Construction is currently infallible; the `TaskerResult` return keeps the
/// signature stable for callers.
pub async fn new(
    config: OrchestrationPollerConfig,
    context: Arc<SystemContext>,
    command_sender: OrchestrationCommandSender,
) -> TaskerResult<Self> {
    // Build the instance first, then log through it so the announcement and
    // the stored state cannot drift apart.
    let poller = Self {
        poller_id: Uuid::new_v4(),
        config,
        context,
        command_sender,
        is_running: AtomicBool::new(false),
        stats: OrchestrationPollerStats::default(),
    };
    info!(
        poller_id = %poller.poller_id,
        namespace = %poller.config.namespace,
        monitored_queues = ?poller.config.monitored_queues,
        polling_interval = ?poller.config.polling_interval,
        provider = %poller.context.messaging_provider().provider_name(),
        "Creating OrchestrationFallbackPoller"
    );
    Ok(poller)
}
/// Start the fallback poller.
///
/// Returns immediately without side effects when disabled by configuration.
/// Otherwise marks the poller as running, records the start time, and spawns
/// the background polling loop.
///
/// # Errors
/// Propagates any error from `start_polling_loop`; in that case the running
/// flag is rolled back so `is_healthy()` stays accurate.
pub async fn start(&self) -> TaskerResult<()> {
    if !self.config.enabled {
        info!(
            poller_id = %self.poller_id,
            "OrchestrationFallbackPoller disabled by configuration"
        );
        return Ok(());
    }
    info!(
        poller_id = %self.poller_id,
        "Starting OrchestrationFallbackPoller"
    );
    self.is_running.store(true, Ordering::Relaxed);
    *self.stats.started_at.lock().await = Some(Instant::now());
    // If the polling loop fails to start, clear the flag again so the poller
    // does not report healthy while no loop is actually running.
    if let Err(e) = self.start_polling_loop().await {
        self.is_running.store(false, Ordering::Relaxed);
        return Err(e);
    }
    info!(
        poller_id = %self.poller_id,
        "OrchestrationFallbackPoller started successfully"
    );
    Ok(())
}
/// Stop the fallback poller by clearing the running flag.
///
/// NOTE(review): the background task spawned by `start_polling_loop` loops
/// forever and never reads `is_running`, so clearing the flag here only
/// changes what `is_healthy()` reports — the spawned polling task keeps
/// running until process exit. Confirm whether a shutdown signal should be
/// plumbed into the loop.
pub async fn stop(&self) -> TaskerResult<()> {
    info!(
        poller_id = %self.poller_id,
        "Stopping OrchestrationFallbackPoller"
    );
    self.is_running.store(false, Ordering::Relaxed);
    info!(
        poller_id = %self.poller_id,
        "OrchestrationFallbackPoller stopped successfully"
    );
    Ok(())
}
/// Report liveness: true between `start()` and `stop()`.
///
/// NOTE(review): this reflects only the `is_running` flag, not the actual
/// state of the spawned polling task.
pub fn is_healthy(&self) -> bool {
    self.is_running.load(Ordering::Relaxed)
}
/// Produce a point-in-time snapshot of the poller statistics.
///
/// Counters are copied with relaxed loads and the timestamp mutexes are
/// locked one at a time, so the snapshot is not a single atomic observation
/// across all fields.
pub async fn stats(&self) -> OrchestrationPollerStats {
    // Read the mutex-guarded timestamps up front, releasing each lock
    // immediately.
    let last_poll_at = *self.stats.last_poll_at.lock().await;
    let started_at = *self.stats.started_at.lock().await;
    // Copy an atomic counter's current value into a fresh atomic.
    let snap = |counter: &AtomicU64| AtomicU64::new(counter.load(Ordering::Relaxed));
    OrchestrationPollerStats {
        polling_cycles: snap(&self.stats.polling_cycles),
        messages_processed: snap(&self.stats.messages_processed),
        step_results_processed: snap(&self.stats.step_results_processed),
        task_requests_processed: snap(&self.stats.task_requests_processed),
        messages_skipped: snap(&self.stats.messages_skipped),
        polling_errors: snap(&self.stats.polling_errors),
        last_poll_at: Arc::new(tokio::sync::Mutex::new(last_poll_at)),
        started_at: Arc::new(tokio::sync::Mutex::new(started_at)),
    }
}
/// Spawn the detached background polling loop and return immediately.
///
/// NOTE(review): the spawned task has no termination condition — it does not
/// observe `is_running`, so `stop()` cannot end it; it runs until the process
/// exits.
async fn start_polling_loop(&self) -> TaskerResult<()> {
    // Clone everything the detached task needs, since it must be 'static.
    let config = self.config.clone();
    let command_sender = self.command_sender.clone();
    let poller_id = self.poller_id;
    let messaging_provider = self.context.messaging_provider().clone();
    // NOTE(review): every counter except `last_poll_at` is freshly created
    // here rather than shared with `self.stats`, so increments made by the
    // loop are invisible to `stats()`. Confirm whether the struct's counters
    // should be Arc-shared with this ref instead.
    let stats = OrchestrationPollerStatsRef {
        polling_cycles: Arc::new(AtomicU64::new(0)), messages_processed: Arc::new(AtomicU64::new(0)), step_results_processed: Arc::new(AtomicU64::new(0)), task_requests_processed: Arc::new(AtomicU64::new(0)), polling_errors: Arc::new(AtomicU64::new(0)), last_poll_at: self.stats.last_poll_at.clone(),
    };
    let queue_config = self.context.tasker_config.common.queues.clone();
    let classifier = tasker_shared::config::QueueClassifier::from_queues_config(&queue_config);
    tokio::spawn(async move {
        info!(
            poller_id = %poller_id,
            interval_ms = %config.polling_interval.as_millis(),
            "Starting fallback polling for orchestration coordination reliability"
        );
        let mut interval = tokio::time::interval(config.polling_interval);
        // NOTE(review): the queue list is derived from the classifier, not
        // from `config.monitored_queues` — the config field is effectively
        // ignored here.
        let monitored_queues = vec![
            classifier.step_results_queue_name().to_string(),
            classifier.task_requests_queue_name().to_string(),
            classifier.task_finalizations_queue_name().to_string(),
        ];
        loop {
            interval.tick().await;
            stats.polling_cycles.fetch_add(1, Ordering::Relaxed);
            // This timestamp IS shared with self.stats (see above).
            *stats.last_poll_at.lock().await = Some(Instant::now());
            for queue_name in &monitored_queues {
                Self::poll_queue_for_messages(
                    &messaging_provider,
                    &command_sender,
                    queue_name,
                    &config,
                    poller_id,
                    &classifier,
                    &stats,
                )
                .await;
            }
        }
    });
    Ok(())
}
/// Poll a single queue and dispatch each received message into the
/// orchestration command loop.
///
/// Dispatch is fire-and-forget: the oneshot response receiver is dropped
/// immediately, so command outcomes are not awaited here. Provider receive
/// failures and failed sends are counted in `stats.polling_errors`.
async fn poll_queue_for_messages(
    messaging_provider: &MessagingProvider,
    command_sender: &OrchestrationCommandSender,
    queue_name: &str,
    config: &OrchestrationPollerConfig,
    poller_id: Uuid,
    classifier: &tasker_shared::config::QueueClassifier,
    stats: &OrchestrationPollerStatsRef,
) {
    debug!(
        poller_id = %poller_id,
        queue = %queue_name,
        "Performing fallback polling check"
    );
    // Receive up to batch_size messages; on provider error, count it and
    // skip this queue for the current cycle.
    let messages = match messaging_provider
        .receive_messages::<serde_json::Value>(
            queue_name,
            config.batch_size as usize,
            config.visibility_timeout,
        )
        .await
    {
        Ok(msgs) => msgs,
        Err(e) => {
            error!(
                poller_id = %poller_id,
                queue = %queue_name,
                error = %e,
                "Failed to read messages from fallback polling"
            );
            stats.polling_errors.fetch_add(1, Ordering::Relaxed);
            return;
        }
    };
    if messages.is_empty() {
        debug!(
            poller_id = %poller_id,
            queue = %queue_name,
            "No messages found in fallback polling"
        );
        return;
    }
    debug!(
        poller_id = %poller_id,
        queue = %queue_name,
        count = messages.len(),
        "Read messages from fallback polling"
    );
    let queue_type = classifier.classify(queue_name);
    for queued_message in messages {
        // Map the queue type to a (send result, optional per-type counter)
        // pair. Per-type counters are incremented only after a successful
        // send — previously they counted messages even when dispatch failed,
        // inconsistent with `messages_processed`.
        let (send_result, type_counter) = match &queue_type {
            tasker_shared::config::QueueType::StepResults => {
                let (resp_tx, _resp_rx) = tokio::sync::oneshot::channel();
                let sent = command_sender
                    .send(OrchestrationCommand::ProcessStepResultFromMessage {
                        message: queued_message,
                        resp: resp_tx,
                    })
                    .await;
                (sent, Some(&stats.step_results_processed))
            }
            tasker_shared::config::QueueType::TaskRequests => {
                let (resp_tx, _resp_rx) = tokio::sync::oneshot::channel();
                let sent = command_sender
                    .send(OrchestrationCommand::InitializeTaskFromMessage {
                        message: queued_message,
                        resp: resp_tx,
                    })
                    .await;
                (sent, Some(&stats.task_requests_processed))
            }
            tasker_shared::config::QueueType::TaskFinalizations => {
                // No dedicated per-type counter exists for finalizations;
                // they are reflected in `messages_processed` only.
                let (resp_tx, _resp_rx) = tokio::sync::oneshot::channel();
                let sent = command_sender
                    .send(OrchestrationCommand::FinalizeTaskFromMessage {
                        message: queued_message,
                        resp: resp_tx,
                    })
                    .await;
                (sent, None)
            }
            tasker_shared::config::QueueType::WorkerNamespace(namespace) => {
                debug!(
                    poller_id = %poller_id,
                    queue = %queue_name,
                    namespace = %namespace,
                    "Worker namespace message received in orchestration fallback polling"
                );
                continue;
            }
            tasker_shared::config::QueueType::Unknown => {
                warn!(
                    poller_id = %poller_id,
                    queue = %queue_name,
                    "Unknown queue type in fallback polling",
                );
                continue;
            }
        };
        match send_result {
            Ok(_) => {
                stats.messages_processed.fetch_add(1, Ordering::Relaxed);
                if let Some(counter) = type_counter {
                    counter.fetch_add(1, Ordering::Relaxed);
                }
            }
            Err(e) => {
                warn!(
                    poller_id = %poller_id,
                    queue = %queue_name,
                    error = %e,
                    "Failed to send command from fallback polling"
                );
                stats.polling_errors.fetch_add(1, Ordering::Relaxed);
            }
        }
    }
}
}
/// Counter handles captured by the spawned polling loop.
///
/// Fields are `Arc`-wrapped so they can be moved into the `tokio::spawn`ed
/// task. `last_poll_at` is shared with `OrchestrationPollerStats`; the
/// atomic counters are fresh values constructed in `start_polling_loop`,
/// distinct from the counters in `OrchestrationPollerStats`.
struct OrchestrationPollerStatsRef {
    // Polling cycles executed by the loop.
    polling_cycles: Arc<AtomicU64>,
    // Messages dispatched as commands.
    messages_processed: Arc<AtomicU64>,
    // Step-result messages handled.
    step_results_processed: Arc<AtomicU64>,
    // Task-request messages handled.
    task_requests_processed: Arc<AtomicU64>,
    // Receive/send failures.
    polling_errors: Arc<AtomicU64>,
    // Shared with OrchestrationPollerStats::last_poll_at.
    last_poll_at: Arc<tokio::sync::Mutex<Option<Instant>>>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config must match the documented baseline values.
    #[test]
    fn test_orchestration_poller_config_default() {
        let cfg = OrchestrationPollerConfig::default();
        assert!(cfg.enabled);
        assert_eq!(cfg.namespace, "orchestration");
        assert_eq!(cfg.polling_interval, Duration::from_secs(30));
        assert_eq!(cfg.visibility_timeout, Duration::from_secs(30));
        assert_eq!(cfg.batch_size, 50);
        assert_eq!(cfg.age_threshold, Duration::from_secs(5));
        assert_eq!(cfg.max_age, Duration::from_secs(86_400));
        assert_eq!(cfg.monitored_queues.len(), 2);
        for expected in ["orchestration_step_results", "orchestration_task_requests"] {
            assert!(cfg.monitored_queues.iter().any(|q| q == expected));
        }
    }

    /// An explicitly constructed config keeps every overridden value.
    #[test]
    fn test_orchestration_poller_config_custom() {
        let cfg = OrchestrationPollerConfig {
            enabled: false,
            polling_interval: Duration::from_secs(60),
            batch_size: 100,
            age_threshold: Duration::from_secs(10),
            max_age: Duration::from_secs(3600),
            monitored_queues: vec!["custom_queue".to_string()],
            namespace: "custom".to_string(),
            visibility_timeout: Duration::from_secs(60),
        };
        assert!(!cfg.enabled);
        assert_eq!(cfg.namespace, "custom");
        assert_eq!(cfg.batch_size, 100);
        assert_eq!(cfg.polling_interval, Duration::from_secs(60));
        assert_eq!(cfg.monitored_queues.len(), 1);
    }

    /// Cloning reproduces all fields.
    #[test]
    fn test_orchestration_poller_config_clone() {
        let original = OrchestrationPollerConfig::default();
        let copy = original.clone();
        assert_eq!(copy.enabled, original.enabled);
        assert_eq!(copy.polling_interval, original.polling_interval);
        assert_eq!(copy.batch_size, original.batch_size);
        assert_eq!(copy.monitored_queues, original.monitored_queues);
    }

    /// Debug output names the type and includes the namespace.
    #[test]
    fn test_orchestration_poller_config_debug() {
        let rendered = format!("{:?}", OrchestrationPollerConfig::default());
        assert!(rendered.contains("OrchestrationPollerConfig"));
        assert!(rendered.contains("orchestration"));
    }

    /// Freshly defaulted stats start with all counters at zero.
    #[test]
    fn test_orchestration_poller_stats_default() {
        let stats = OrchestrationPollerStats::default();
        let counters = [
            &stats.polling_cycles,
            &stats.messages_processed,
            &stats.step_results_processed,
            &stats.task_requests_processed,
            &stats.messages_skipped,
            &stats.polling_errors,
        ];
        for counter in counters {
            assert_eq!(counter.load(Ordering::Relaxed), 0);
        }
    }

    /// Each counter accumulates independently of the others.
    #[test]
    fn test_orchestration_poller_stats_increment() {
        let stats = OrchestrationPollerStats::default();
        let bumps: [(&AtomicU64, u64); 6] = [
            (&stats.polling_cycles, 3),
            (&stats.messages_processed, 10),
            (&stats.step_results_processed, 7),
            (&stats.task_requests_processed, 3),
            (&stats.messages_skipped, 2),
            (&stats.polling_errors, 1),
        ];
        for (counter, amount) in bumps {
            counter.fetch_add(amount, Ordering::Relaxed);
        }
        assert_eq!(stats.polling_cycles.load(Ordering::Relaxed), 3);
        assert_eq!(stats.messages_processed.load(Ordering::Relaxed), 10);
        assert_eq!(stats.step_results_processed.load(Ordering::Relaxed), 7);
        assert_eq!(stats.task_requests_processed.load(Ordering::Relaxed), 3);
        assert_eq!(stats.messages_skipped.load(Ordering::Relaxed), 2);
        assert_eq!(stats.polling_errors.load(Ordering::Relaxed), 1);
    }

    /// Stats derive Debug and name the type in their output.
    #[test]
    fn test_orchestration_poller_stats_debug() {
        let rendered = format!("{:?}", OrchestrationPollerStats::default());
        assert!(rendered.contains("OrchestrationPollerStats"));
    }
}