use std::future::Future;
use std::panic::AssertUnwindSafe;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

use futures::FutureExt;
use tokio::task::JoinHandle;
use tracing::{debug, error, info, warn};

use tasker_shared::messaging::client::MessageClient;
use tasker_shared::monitoring::ChannelMonitor;
use tasker_shared::system_context::SystemContext;
use tasker_shared::{TaskerError, TaskerResult};

use crate::actors::ActorRegistry;
use crate::health::caches::HealthStatusCaches;
use crate::orchestration::channels::{
    ChannelFactory, OrchestrationCommandReceiver, OrchestrationCommandSender,
};
use crate::orchestration::commands::CommandProcessingService;
use crate::orchestration::commands::{
    AtomicProcessingStats, CommandResponder, OrchestrationCommand,
};
/// Actor that owns the receiving side of the orchestration command channel.
///
/// Once [`start`](Self::start) is called, a spawned task drains the channel and
/// dispatches each command to a `CommandHandler`.
#[derive(Debug)]
pub struct OrchestrationCommandProcessorActor {
/// Shared system context; handed to the `CommandProcessingService` built in `start`.
context: Arc<SystemContext>,
/// Actor registry; handed to the `CommandProcessingService` built in `start`.
actors: Arc<ActorRegistry>,
/// Messaging client; handed to the `CommandProcessingService` built in `start`.
message_client: Arc<MessageClient>,
/// Health status caches; handed to the `CommandProcessingService` built in `start`.
health_caches: HealthStatusCaches,
/// Command receiver. `Some` until `start` moves it into the processing task;
/// `None` afterwards, which is how a second `start` call is detected and rejected.
command_rx: Option<OrchestrationCommandReceiver>,
/// Handle of the spawned processing task; `None` until `start` succeeds.
task_handle: Option<JoinHandle<()>>,
/// Processing counters shared (same `Arc`) with the command handler; a snapshot is
/// returned for `OrchestrationCommand::GetProcessingStats`.
stats: Arc<AtomicProcessingStats>,
/// Monitor for the command channel; `record_receive` is called for every command
/// taken off the channel, and its channel name is logged on construction.
channel_monitor: ChannelMonitor,
}
impl OrchestrationCommandProcessorActor {
    /// Creates a not-yet-started actor together with the sender for its command channel.
    ///
    /// The channel comes from `ChannelFactory::orchestration_command_channel` with
    /// `buffer_size`. Commands sent before [`start`](Self::start) is called are only
    /// processed once the processing task has been spawned.
    pub fn new(
        context: Arc<SystemContext>,
        actors: Arc<ActorRegistry>,
        message_client: Arc<MessageClient>,
        health_caches: HealthStatusCaches,
        buffer_size: usize,
        channel_monitor: ChannelMonitor,
    ) -> (Self, OrchestrationCommandSender) {
        let (command_tx, command_rx) = ChannelFactory::orchestration_command_channel(buffer_size);
        let stats = Arc::new(AtomicProcessingStats::default());
        info!(
            channel = %channel_monitor.channel_name(),
            buffer_size = buffer_size,
            "Creating OrchestrationCommandProcessorActor with channel monitoring"
        );
        let actor = Self {
            context,
            actors,
            message_client,
            health_caches,
            command_rx: Some(command_rx),
            task_handle: None,
            stats,
            channel_monitor,
        };
        (actor, command_tx)
    }

    /// Spawns the command-processing task and stores its handle.
    ///
    /// The task runs until `recv` yields `None` (the channel is closed). Each command
    /// is processed under `catch_unwind`. A panic in one command is therefore logged
    /// and counted in `processing_errors` instead of terminating the loop. This
    /// method does not wait for the task.
    ///
    /// # Errors
    ///
    /// Returns `TaskerError::OrchestrationError` if the processor was already started.
    pub async fn start(&mut self) -> TaskerResult<()> {
        // Claim the receiver first so a repeated call fails before doing any work.
        let mut command_rx = self.command_rx.take().ok_or_else(|| {
            TaskerError::OrchestrationError("Processor already started".to_string())
        })?;
        let context = self.context.clone();
        let actors = self.actors.clone();
        let stats = self.stats.clone();
        let message_client = self.message_client.clone();
        let health_caches = self.health_caches.clone();
        let channel_monitor = self.channel_monitor.clone();
        let handle = tasker_shared::spawn_named!("orchestration_command_processor", async move {
            let handler = CommandHandler::new(
                context,
                actors,
                Arc::clone(&stats),
                message_client,
                health_caches,
            );
            while let Some(command) = command_rx.recv().await {
                channel_monitor.record_receive();
                if let Err(panic_payload) = AssertUnwindSafe(handler.process_command(command))
                    .catch_unwind()
                    .await
                {
                    // The panic unwound past the per-command stats bookkeeping, so
                    // record it here; otherwise the failure would vanish from the stats.
                    stats.processing_errors.fetch_add(1, Ordering::Relaxed);
                    error!(
                        panic_message = %panic_message(&panic_payload),
                        "Command processor caught panic during command processing, continuing"
                    );
                }
            }
        });
        self.task_handle = Some(handle);
        Ok(())
    }
}
/// Dispatcher created inside the task spawned by
/// `OrchestrationCommandProcessorActor::start`; owns the service that performs each
/// command and the stats it updates.
#[derive(Debug)]
struct CommandHandler {
/// Service whose methods implement the individual orchestration commands.
service: CommandProcessingService,
/// Counters shared (same `Arc`) with `OrchestrationCommandProcessorActor`.
stats: Arc<AtomicProcessingStats>,
}
impl CommandHandler {
fn new(
context: Arc<SystemContext>,
actors: Arc<ActorRegistry>,
stats: Arc<AtomicProcessingStats>,
message_client: Arc<MessageClient>,
health_caches: HealthStatusCaches,
) -> Self {
let service = CommandProcessingService::new(context, actors, message_client, health_caches);
Self { service, stats }
}
pub async fn process_command(&self, command: OrchestrationCommand) {
match command {
OrchestrationCommand::InitializeTask { request, resp } => {
self.execute_with_stats(
self.service.initialize_task(request),
|stats| &stats.task_requests_processed,
resp,
)
.await;
}
OrchestrationCommand::ProcessStepResult { result, resp } => {
self.execute_with_stats(
self.service.process_step_result(result),
|stats| &stats.step_results_processed,
resp,
)
.await;
}
OrchestrationCommand::FinalizeTask { task_uuid, resp } => {
self.execute_with_stats(
self.service.finalize_task(task_uuid),
|stats| &stats.tasks_finalized,
resp,
)
.await;
}
OrchestrationCommand::ProcessStepResultFromMessageEvent {
message_event,
resp,
} => {
self.execute_with_stats(
self.service.step_result_from_message_event(message_event),
|stats| &stats.step_results_processed,
resp,
)
.await;
}
OrchestrationCommand::InitializeTaskFromMessageEvent {
message_event,
resp,
} => {
self.execute_with_stats(
self.service
.task_initialize_from_message_event(message_event),
|stats| &stats.task_requests_processed,
resp,
)
.await;
}
OrchestrationCommand::FinalizeTaskFromMessageEvent {
message_event,
resp,
} => {
self.execute_with_stats(
self.service.task_finalize_from_message_event(message_event),
|stats| &stats.tasks_finalized,
resp,
)
.await;
}
OrchestrationCommand::ProcessStepResultFromMessage { message, resp } => {
let queue_name = message.queue_name();
debug!(
handle = ?message.handle,
queue = %queue_name,
"Starting ProcessStepResultFromMessage"
);
self.execute_with_stats(
self.service.step_result_from_message(message),
|stats| &stats.step_results_processed,
resp,
)
.await;
}
OrchestrationCommand::InitializeTaskFromMessage { message, resp } => {
self.execute_with_stats(
self.service.task_initialize_from_message(message),
|stats| &stats.task_requests_processed,
resp,
)
.await;
}
OrchestrationCommand::FinalizeTaskFromMessage { message, resp } => {
self.execute_with_stats(
self.service.task_finalize_from_message(message),
|stats| &stats.tasks_finalized,
resp,
)
.await;
}
OrchestrationCommand::GetProcessingStats { resp } => {
let stats_snapshot = self.stats.snapshot();
if resp.send(Ok(stats_snapshot)).is_err() {
error!("GetProcessingStats response channel closed - receiver dropped");
}
}
OrchestrationCommand::HealthCheck { resp } => {
let result = self.service.health_check().await;
if resp.send(result).is_err() {
error!("HealthCheck response channel closed - receiver dropped");
}
}
OrchestrationCommand::Shutdown { resp } => {
if resp.send(Ok(())).is_err() {
error!("Shutdown response channel closed - receiver dropped");
}
}
}
}
async fn execute_with_stats<T, Fut>(
&self,
handler: Fut,
stat_selector: impl FnOnce(&AtomicProcessingStats) -> &AtomicU64,
resp: CommandResponder<T>,
) where
Fut: Future<Output = TaskerResult<T>>,
T: std::fmt::Debug,
{
let result = handler.await;
let was_success = result.is_ok();
if was_success {
stat_selector(&self.stats).fetch_add(1, Ordering::Relaxed);
} else {
self.stats.processing_errors.fetch_add(1, Ordering::Relaxed);
}
if resp.send(result).is_err() {
debug!(
was_success = was_success,
"Command response channel closed - receiver dropped (fire-and-forget caller)"
);
}
}
}
/// Renders a caught panic payload as human-readable text.
///
/// Panic payloads are commonly a `&'static str` (literal `panic!`) or a `String`
/// (formatted `panic!`). Both are returned verbatim. Any other payload type yields
/// the placeholder `"non-string panic payload"`.
///
/// The parameter is intentionally `&Box<dyn Any + Send>` rather than
/// `&(dyn Any + Send)`. Passing `&Box<..>` where `&dyn Any` is expected would
/// unsize-coerce the box itself into the trait object, and every downcast below
/// would then fail.
fn panic_message(payload: &Box<dyn std::any::Any + Send>) -> String {
    match payload.downcast_ref::<&str>() {
        Some(literal) => (*literal).to_owned(),
        None => payload
            .downcast_ref::<String>()
            .cloned()
            .unwrap_or_else(|| "non-string panic payload".to_string()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn panic_message_extracts_str() {
        let payload: Box<dyn std::any::Any + Send> = Box::new("test panic");
        assert_eq!(panic_message(&payload), "test panic");
    }

    #[test]
    fn panic_message_extracts_string() {
        let payload: Box<dyn std::any::Any + Send> = Box::new("owned panic".to_string());
        assert_eq!(panic_message(&payload), "owned panic");
    }

    #[test]
    fn panic_message_handles_other_types() {
        let payload: Box<dyn std::any::Any + Send> = Box::new(42i32);
        assert_eq!(panic_message(&payload), "non-string panic payload");
    }

    /// Exercises the payload produced by a real literal `panic!`, matching what the
    /// processor loop receives from `catch_unwind`.
    #[test]
    fn panic_message_from_real_literal_panic() {
        let payload = std::panic::catch_unwind::<_, ()>(|| panic!("literal boom")).unwrap_err();
        assert_eq!(panic_message(&payload), "literal boom");
    }

    /// Exercises the payload produced by a real formatted `panic!`.
    #[test]
    fn panic_message_from_real_formatted_panic() {
        let code = 7;
        let payload = std::panic::catch_unwind::<_, ()>(move || panic!("formatted boom {}", code))
            .unwrap_err();
        assert_eq!(panic_message(&payload), "formatted boom 7");
    }
}