// ora 0.12.7
//
// Part of the Ora scheduler framework.
// Documentation
use std::{
    collections::HashMap,
    sync::{Arc, Mutex},
    time::{Duration, SystemTime},
};

use flume::{Receiver, Sender};
use tokio::{select, spawn, time::timeout};
use tokio_util::sync::CancellationToken;
use tracing::Instrument;
use uuid::Uuid;
use wgroup::WaitGuard;

use crate::{
    execution::ExecutionId,
    executor::implementation::{
        ExecutionContext, ExecutionFailedCb, ExecutorJobQueue, capabilities::send_capabilities,
        heartbeat::heartbeat_loop,
    },
    job::JobId,
    job_type::JobTypeId,
    proto::executors::v1::{
        ExecutionFailed, ExecutionReady, ExecutionSucceeded, executor_message::ExecutorMessageKind,
        server_message::ServerMessageKind,
    },
};

#[tracing::instrument(skip_all, fields(executor_id))]
pub(super) async fn executor_loop(
    incoming: Receiver<ServerMessageKind>,
    server: Sender<ExecutorMessageKind>,
    executor_name: Arc<str>,
    queues: Arc<[ExecutorJobQueue]>,
    server_cancellation_grace_period: Duration,
    on_execution_failed: Option<ExecutionFailedCb>,
    wg: WaitGuard,
) {
    let active_executions = ActiveExecutions::default();

    loop {
        let msg = select! {
                biased;

                _ = wg.waiting() => {
                    tracing::debug!("executor loop is shutting down");
                    if active_executions.is_empty() {
                        tracing::debug!("no active executions, shutting down immediately");
                        break;
                    }

                    active_executions.cancel_all();
                    break;
                }
                msg = incoming.recv_async() => {
                    match msg {
                        Ok(msg) => msg,
                        Err(_) => {
                            tracing::debug!("server channel closed, shutting down executor loop");
                            break;
                        }
                    }
            }
        };

        match msg {
            ServerMessageKind::Properties(executor_properties) => {
                tracing::Span::current().record("executor_id", &executor_properties.executor_id);

                let mut max_heartbeat_interval = Duration::try_from(
                    executor_properties
                        .max_heartbeat_interval
                        .unwrap_or_default(),
                )
                .unwrap_or_default();

                if max_heartbeat_interval.is_zero() {
                    max_heartbeat_interval = Duration::from_secs(1);
                    tracing::warn!("invalid heartbeat interval");
                }

                if let Err(error) = send_capabilities(&executor_name, &queues, &server).await {
                    tracing::error!(?error, "failed to send executor capabilities");
                    break;
                }

                spawn(heartbeat_loop(max_heartbeat_interval, server.clone()));

                tracing::info!("executor initialized");
            }
            ServerMessageKind::ExecutionReady(execution_ready) => {
                let Ok(execution_id) = execution_ready
                    .execution_id
                    .parse::<Uuid>()
                    .map(ExecutionId)
                else {
                    tracing::error!("invalid execution id");
                    _ = server
                        .send_async(ExecutorMessageKind::ExecutionFailed(ExecutionFailed {
                            execution_id: execution_ready.execution_id,
                            timestamp: Some(SystemTime::now().into()),
                            failure_reason: "invalid execution ID".to_string(),
                        }))
                        .await;
                    continue;
                };

                spawn(
                    run_execution(
                        queues.clone(),
                        execution_ready,
                        server.clone(),
                        active_executions.add(execution_id),
                        server_cancellation_grace_period,
                        on_execution_failed.clone(),
                        wg.add_with("execution"),
                    )
                    .in_current_span(),
                );
            }
            ServerMessageKind::ExecutionCancelled(execution_cancelled) => {
                let Ok(execution_id) = execution_cancelled
                    .execution_id
                    .parse::<Uuid>()
                    .map(ExecutionId)
                else {
                    tracing::error!("invalid execution id for cancellation");
                    continue;
                };

                active_executions.cancel(&execution_id);
            }
        }
    }
}

#[tracing::instrument(skip_all, fields(execution_id, job_id, job_type_id))]
async fn run_execution(
    queues: Arc<[ExecutorJobQueue]>,
    ready_execution: ExecutionReady,
    server: Sender<ExecutorMessageKind>,
    active_execution: ActiveExecutionGuard,
    server_cancellation_grace_period: Duration,
    on_execution_failed: Option<ExecutionFailedCb>,
    _wg: WaitGuard,
) {
    tracing::Span::current().record("execution_id", &ready_execution.execution_id);
    tracing::Span::current().record("job_id", &ready_execution.job_id);
    tracing::Span::current().record("job_type_id", &ready_execution.job_type_id);

    let Ok(execution_id) = ready_execution
        .execution_id
        .parse::<Uuid>()
        .map(ExecutionId)
    else {
        tracing::warn!("invalid execution id");
        _ = server
            .send_async(ExecutorMessageKind::ExecutionFailed(ExecutionFailed {
                execution_id: ready_execution.execution_id,
                timestamp: Some(SystemTime::now().into()),
                failure_reason: "invalid execution ID".to_string(),
            }))
            .await;
        return;
    };

    let Ok(job_id) = ready_execution.job_id.parse().map(JobId) else {
        tracing::warn!("invalid job id");
        _ = server
            .send_async(ExecutorMessageKind::ExecutionFailed(ExecutionFailed {
                execution_id: ready_execution.execution_id,
                timestamp: Some(SystemTime::now().into()),
                failure_reason: "invalid job ID".to_string(),
            }))
            .await;
        return;
    };

    let Ok(job_type_id) = JobTypeId::new(ready_execution.job_type_id) else {
        tracing::warn!("invalid job type id");
        _ = server
            .send_async(ExecutorMessageKind::ExecutionFailed(ExecutionFailed {
                execution_id: ready_execution.execution_id,
                timestamp: Some(SystemTime::now().into()),
                failure_reason: "invalid job type ID".to_string(),
            }))
            .await;
        return;
    };

    let Some(target_execution_time) = ready_execution
        .target_execution_time
        .and_then(|t| t.try_into().ok())
    else {
        tracing::warn!("invalid target execution time");
        _ = server
            .send_async(ExecutorMessageKind::ExecutionFailed(ExecutionFailed {
                execution_id: ready_execution.execution_id,
                timestamp: Some(SystemTime::now().into()),
                failure_reason: "invalid target execution time".to_string(),
            }))
            .await;
        return;
    };

    let Some(queue) = queues.iter().find(|q| q.job_type_id == job_type_id) else {
        tracing::warn!("job type not supported");
        _ = server
            .send_async(ExecutorMessageKind::ExecutionFailed(ExecutionFailed {
                execution_id: ready_execution.execution_id,
                timestamp: Some(SystemTime::now().into()),
                failure_reason: "job type not supported".to_string(),
            }))
            .await;
        return;
    };

    let ctx = ExecutionContext {
        execution_id,
        job_id,
        job_type_id,
        target_execution_time,
        attempt_number: ready_execution.attempt_number,
        cancellation_token: active_execution.cancellation_token.clone(),
    };

    let mut handler_fut = (queue.handler)(ctx.clone(), ready_execution.input_payload_json);

    tokio::select! {
        handler_result = &mut handler_fut => {
            if active_execution.cancellation_token.is_cancelled() {
                return;
            }

            match handler_result {
                Ok(output_payload_json) => {
                    _ = server
                        .send_async(ExecutorMessageKind::ExecutionSucceeded(
                            ExecutionSucceeded {
                                execution_id: ready_execution.execution_id,
                                timestamp: Some(SystemTime::now().into()),
                                output_payload_json,
                            },
                        ))
                        .await;
                }
                Err(error) => {
                    let failure_reason = format!("{error:?}");

                    if let Some(callback) = on_execution_failed {
                        callback(ctx, &failure_reason);
                    }

                    _ = server
                        .send_async(ExecutorMessageKind::ExecutionFailed(
                            ExecutionFailed {
                                execution_id: ready_execution.execution_id,
                                timestamp: Some(SystemTime::now().into()),
                                failure_reason,
                            },
                        ))
                        .await;
                }
            }
        }
        _ = active_execution.cancellation_token.cancelled() => {
            if timeout(server_cancellation_grace_period, handler_fut).await.is_err() {
                tracing::debug!("dropping cancelled execution");
            }
        },
    }
}

/// Registry mapping the ID of every registered execution to its cancellation
/// token.
///
/// The map is shared with the [`ActiveExecutionGuard`]s handed out by
/// [`ActiveExecutions::add`], each of which removes its own entry on drop.
#[derive(Default)]
struct ActiveExecutions {
    executions: Arc<Mutex<HashMap<ExecutionId, CancellationToken>>>,
}

impl ActiveExecutions {
    /// Locks the registry, recovering the guard if the mutex is poisoned.
    ///
    /// Every critical section performs a single map operation, so a panic in
    /// another lock holder cannot leave the map in an inconsistent state and
    /// it is safe to keep using it.
    fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<ExecutionId, CancellationToken>> {
        self.executions
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Registers `execution_id` with a fresh cancellation token.
    ///
    /// The returned guard carries a clone of that token and unregisters the
    /// execution when dropped. An existing entry for the same ID is replaced.
    #[must_use = "dropping the guard immediately unregisters the execution"]
    fn add(&self, execution_id: ExecutionId) -> ActiveExecutionGuard {
        let cancellation_token = CancellationToken::new();
        self.lock()
            .insert(execution_id, cancellation_token.clone());

        ActiveExecutionGuard {
            executions: Arc::clone(&self.executions),
            this_execution_id: execution_id,
            cancellation_token,
        }
    }

    /// Cancels the token registered for `execution_id`; does nothing when the
    /// ID is not registered.
    fn cancel(&self, execution_id: &ExecutionId) {
        if let Some(token) = self.lock().get(execution_id) {
            token.cancel();
        }
    }

    /// Returns `true` when no execution is currently registered.
    fn is_empty(&self) -> bool {
        self.lock().is_empty()
    }

    /// Cancels the token of every registered execution.
    fn cancel_all(&self) {
        for token in self.lock().values() {
            token.cancel();
        }
    }
}

/// Handle for one registered execution.
///
/// Holds the shared registry map, the ID under which the execution was
/// registered, and the execution's cancellation token. Dropping the guard
/// removes the entry for `this_execution_id` from the map.
struct ActiveExecutionGuard {
    executions: Arc<Mutex<HashMap<ExecutionId, CancellationToken>>>,
    this_execution_id: ExecutionId,
    cancellation_token: CancellationToken,
}

impl Drop for ActiveExecutionGuard {
    /// Unregisters the execution.
    ///
    /// A poisoned lock is recovered instead of unwrapped: panicking inside
    /// `drop` while the thread is already unwinding would abort the process,
    /// and the entry must be removed either way.
    fn drop(&mut self) {
        let mut executions = self
            .executions
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        executions.remove(&self.this_execution_id);
    }
}