//! datafusion-dist 0.1.0
//!
//! A distributed streaming execution library for Apache DataFusion.
//! Documentation
use std::{collections::HashMap, sync::Arc};

use log::{debug, error};
use tokio::sync::{
    Mutex,
    mpsc::{Receiver, Sender},
};
use uuid::Uuid;

use crate::{
    DistResult,
    cluster::{DistCluster, NodeId},
    config::DistConfig,
    network::{DistNetwork, StageInfo},
    planner::StageId,
    runtime::StageState,
};

/// Control-plane events processed by the [`EventHandler`] loop.
#[derive(Debug, Clone)]
pub enum Event {
    /// Check whether the job with this id has completed across all alive
    /// nodes; on completion a `CleanupJob` event is queued for it.
    CheckJobCompleted(Uuid),
    /// Remove all state for the job with this id, both locally and on every
    /// other alive node.
    CleanupJob(Uuid),
    /// Stage-0 task sets were received for these stage ids; schedules a
    /// delayed check that cleans up jobs whose stage 0 is never executed.
    ReceivedStage0Tasks(Vec<StageId>),
}

/// Launches the handler's event-dispatch loop on a detached tokio task.
///
/// The task runs until the handler's event channel is closed (all senders
/// dropped); the spawned `JoinHandle` is intentionally discarded.
pub fn start_event_handler(mut handler: EventHandler) {
    tokio::spawn(async move { handler.start().await });
}

/// Owns the event channel and the shared state needed to react to
/// cluster-level events (job completion checks, job cleanup, stage-0
/// task timeouts).
pub struct EventHandler {
    /// Identity of the node this handler runs on.
    pub local_node: NodeId,
    /// Runtime configuration (e.g. the stage-0 task poll timeout).
    pub config: Arc<DistConfig>,
    /// Cluster membership view used to enumerate alive nodes.
    pub cluster: Arc<dyn DistCluster>,
    /// Network layer for talking to remote nodes.
    pub network: Arc<dyn DistNetwork>,
    /// Stages currently held on this node, shared with other components.
    pub local_stages: Arc<Mutex<HashMap<StageId, Arc<StageState>>>>,
    /// Sender side of the event channel, cloned into spawned tasks so they
    /// can enqueue follow-up events.
    pub sender: Sender<Event>,
    /// Receiver side of the event channel, drained by [`EventHandler::start`].
    pub receiver: Receiver<Event>,
}

impl EventHandler {
    pub async fn start(&mut self) {
        while let Some(event) = self.receiver.recv().await {
            debug!("Received event: {event:?}");
            match event {
                Event::CheckJobCompleted(job_id) => {
                    self.handle_check_job_completed(job_id).await;
                }
                Event::CleanupJob(job_id) => {
                    self.handle_cleanup_job(job_id).await;
                }
                Event::ReceivedStage0Tasks(stage0_ids) => {
                    self.handle_received_stage0_tasks(stage0_ids).await;
                }
            }
        }
    }

    async fn handle_check_job_completed(&mut self, job_id: Uuid) {
        match check_job_completed(&self.cluster, &self.network, &self.local_stages, job_id).await {
            Ok(Some(true)) => {
                debug!("Job {job_id} completed, remove it from cluster");

                if let Err(e) = self.sender.send(Event::CleanupJob(job_id)).await {
                    error!("Failed to send cleanup job event for job {job_id}: {e}");
                }
            }
            Ok(_) => {}
            Err(err) => {
                error!("Failed to check job {job_id} completed: {err}");
            }
        }
    }

    async fn handle_cleanup_job(&mut self, job_id: Uuid) {
        if let Err(e) = cleanup_job(
            &self.local_node,
            &self.cluster,
            &self.network,
            &self.local_stages,
            job_id,
        )
        .await
        {
            error!("Failed to cleanup job {job_id}: {e}");
        }
    }

    async fn handle_received_stage0_tasks(&self, stage0_ids: Vec<StageId>) {
        let stage0_task_poll_timeout = self.config.stage0_task_poll_timeout;
        let local_stages = self.local_stages.clone();
        let sender = self.sender.clone();
        tokio::spawn(async move {
            tokio::time::sleep(stage0_task_poll_timeout).await;
            let stages_guard = local_stages.lock().await;
            for stage_id in stage0_ids {
                if let Some(stage) = stages_guard.get(&stage_id)
                    && stage.never_executed().await
                {
                    debug!("Found stage0 {stage_id} never polled until timeout");

                    if let Err(e) = sender.send(Event::CleanupJob(stage_id.job_id)).await {
                        error!(
                            "Failed to send CleanupJob event for job {}: {e}",
                            stage_id.job_id
                        );
                    }
                    break;
                }
            }
        });
    }
}

/// Determines whether the job identified by `job_id` has completed
/// cluster-wide.
///
/// Merges the local stage status with the status fetched concurrently from
/// every other alive node, then inspects stage 0 (the root stage): the job
/// is complete once every assigned partition of stage 0 has been dropped by
/// some task set.
///
/// Returns `Ok(None)` when stage 0 is unknown on every node,
/// `Ok(Some(false))` when at least one partition is still outstanding, and
/// `Ok(Some(true))` when the job has finished everywhere.
///
/// # Errors
///
/// Propagates cluster-membership lookup failures, remote status RPC
/// failures, and join failures of the spawned fetch tasks.
pub async fn check_job_completed(
    cluster: &Arc<dyn DistCluster>,
    network: &Arc<dyn DistNetwork>,
    local_stages: &Arc<Mutex<HashMap<StageId, Arc<StageState>>>>,
    job_id: Uuid,
) -> DistResult<Option<bool>> {
    // First, get local status
    let mut combined_status = local_job_status(local_stages, Some(job_id)).await;

    // Then, query all other alive nodes concurrently.
    let node_states = cluster.alive_nodes().await?;
    let local_node_id = network.local_node();

    let mut handles = Vec::new();
    for node_id in node_states.keys() {
        if *node_id != local_node_id {
            let network = network.clone();
            let node_id = node_id.clone();
            handles.push(tokio::spawn(async move {
                network.get_job_status(node_id, Some(job_id)).await
            }));
        }
    }

    for handle in handles {
        let remote_status = handle.await??;
        for (stage_id, remote_stage_info) in remote_status {
            // Merge by moving rather than cloning: only the vacant case
            // consumes `remote_stage_info`, so the occupied case can extend
            // `task_set_infos` by value (the and_modify/or_insert pattern
            // would force a clone here).
            if let Some(existing) = combined_status.get_mut(&stage_id) {
                existing
                    .assigned_partitions
                    .extend(&remote_stage_info.assigned_partitions);
                existing
                    .task_set_infos
                    .extend(remote_stage_info.task_set_infos);
            } else {
                combined_status.insert(stage_id, remote_stage_info);
            }
        }
    }

    let stage0 = StageId { job_id, stage: 0 };

    let Some(stage0_info) = combined_status.get(&stage0) else {
        return Ok(None);
    };

    // Complete iff every assigned partition of stage 0 was dropped by at
    // least one task set.
    let all_done = stage0_info.assigned_partitions.iter().all(|partition| {
        stage0_info
            .task_set_infos
            .iter()
            .any(|ts| ts.dropped_partitions.contains_key(partition))
    });

    Ok(Some(all_done))
}

/// Snapshots the status of the stages held on this node.
///
/// When `job_id` is `Some`, only stages belonging to that job are included;
/// when `None`, every local stage is reported. The stages lock is held for
/// the duration of the snapshot.
pub async fn local_job_status(
    stages: &Arc<Mutex<HashMap<StageId, Arc<StageState>>>>,
    job_id: Option<Uuid>,
) -> HashMap<StageId, StageInfo> {
    let guard = stages.lock().await;

    let mut result = HashMap::new();
    for (stage_id, stage_state) in guard.iter() {
        // `None` acts as a wildcard that matches every job.
        if job_id.map_or(true, |id| stage_id.job_id == id) {
            let stage_info = StageInfo::from_stage_state(stage_state).await;
            result.insert(*stage_id, stage_info);
        }
    }

    result
}

/// Removes all state for `job_id` from every alive node in the cluster.
///
/// Local stages belonging to the job are pruned in place; each remote node
/// is asked to clean up over the network in turn.
///
/// # Errors
///
/// Returns the first cluster-lookup or remote-cleanup error encountered;
/// nodes after a failing one are not contacted.
pub async fn cleanup_job(
    local_node: &NodeId,
    cluster: &Arc<dyn DistCluster>,
    network: &Arc<dyn DistNetwork>,
    local_stages: &Arc<Mutex<HashMap<StageId, Arc<StageState>>>>,
    job_id: Uuid,
) -> DistResult<()> {
    for node_id in cluster.alive_nodes().await?.keys() {
        if node_id == local_node {
            // The guard is a temporary: it is released as soon as this
            // statement finishes.
            local_stages
                .lock()
                .await
                .retain(|stage_id, _| stage_id.job_id != job_id);
        } else {
            // Send cleanup request to remote node
            network.cleanup_job(node_id.clone(), job_id).await?;
        }
    }

    Ok(())
}