datafusion-dist-network-tonic 0.3.0

A Tonic (gRPC) network implementation for datafusion-dist.
Documentation
use std::{
    collections::HashMap,
    sync::{Arc, OnceLock},
    time::Duration,
};

use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::error::FlightError;
use datafusion_dist::{
    DistError, DistResult, RecordBatchStream,
    cluster::NodeId,
    network::{DistNetwork, ScheduledTasks, StageInfo},
    planner::{StageId, TaskId},
    util::get_local_ip,
};
use datafusion_physical_plan::ExecutionPlan;
use datafusion_proto::{
    physical_plan::{AsExecutionPlan, ComposedPhysicalExtensionCodec, PhysicalExtensionCodec},
    protobuf::PhysicalPlanNode,
};
use futures::{StreamExt, TryStreamExt};
use tonic::transport::{Channel, Endpoint};
use uuid::Uuid;

use crate::{
    codec::DistPhysicalExtensionEncoder,
    protobuf::{self, SendTasksReq, StagePlan, dist_tonic_service_client::DistTonicServiceClient},
    serde::{
        parse_stage_id, parse_stage_info, serialize_stage_id, serialize_task_distribution,
        serialize_task_id,
    },
};

/// A [`DistNetwork`] implementation that talks to peer nodes through Tonic gRPC clients.
#[derive(Debug)]
pub struct DistTonicNetwork {
    /// Port reported for this node by [`DistNetwork::local_node`].
    pub port: u16,
    /// Codec used when serializing stage plans: the application's codec composed
    /// with [`DistPhysicalExtensionEncoder`] (which itself wraps the same app codec).
    pub composed_extension_codec: Arc<dyn PhysicalExtensionCodec>,
}

impl DistTonicNetwork {
    /// Creates a network handle for `port`.
    ///
    /// The composed codec lists `app_extension_codec` first and the dist-specific
    /// encoder second; both share the same underlying application codec `Arc`.
    pub fn new(port: u16, app_extension_codec: Arc<dyn PhysicalExtensionCodec>) -> Self {
        let dist_encoder = DistPhysicalExtensionEncoder {
            app_extension_codec: Arc::clone(&app_extension_codec),
        };
        let codecs: Vec<Arc<dyn PhysicalExtensionCodec>> =
            vec![app_extension_codec, Arc::new(dist_encoder)];
        Self {
            port,
            composed_extension_codec: Arc::new(ComposedPhysicalExtensionCodec::new(codecs)),
        }
    }
}

impl DistTonicNetwork {
    /// Converts `scheduled_tasks` into the protobuf `SendTasksReq` sent to a worker.
    ///
    /// Each stage plan is encoded with [`Self::serialize_stage_plan`]. The first
    /// encoding failure aborts the conversion and is returned to the caller.
    fn serialize_scheduled_tasks(
        &self,
        scheduled_tasks: ScheduledTasks,
    ) -> DistResult<SendTasksReq> {
        let stage_plans = scheduled_tasks
            .stage_plans
            .into_iter()
            .map(|(stage_id, plan)| self.serialize_stage_plan(stage_id, plan))
            .collect::<Result<Vec<_>, _>>()?;
        let tasks: Vec<_> = scheduled_tasks
            .task_ids
            .into_iter()
            .map(serialize_task_id)
            .collect();
        let job_task_distribution =
            serialize_task_distribution(&scheduled_tasks.job_task_distribution);

        Ok(SendTasksReq {
            stage_plans,
            tasks,
            job_task_distribution: Some(job_task_distribution),
        })
    }

    /// Encodes a single stage's physical plan as a protobuf `StagePlan`.
    ///
    /// The plan is first converted to a `PhysicalPlanNode` using the composed
    /// extension codec, then written into a byte buffer. Errors from either step
    /// are propagated.
    fn serialize_stage_plan(
        &self,
        stage_id: StageId,
        plan: Arc<dyn ExecutionPlan>,
    ) -> DistResult<StagePlan> {
        let stage_id = Some(serialize_stage_id(stage_id));
        let codec = self.composed_extension_codec.as_ref();
        let node = PhysicalPlanNode::try_from_physical_plan(plan, codec)?;
        let mut encoded = Vec::<u8>::new();
        node.try_encode(&mut encoded)?;
        Ok(StagePlan {
            stage_id,
            plan: encoded,
        })
    }
}

#[async_trait::async_trait]
impl DistNetwork for DistTonicNetwork {
    /// Returns this node's identity: the local IP address plus `self.port`.
    ///
    /// The IP is resolved once per process via `get_local_ip` and cached in a
    /// function-local static, so every `DistTonicNetwork` in the process shares it.
    fn local_node(&self) -> NodeId {
        static LOCAL_IP: OnceLock<String> = OnceLock::new();
        NodeId {
            host: LOCAL_IP.get_or_init(get_local_ip).clone(),
            port: self.port,
        }
    }

    /// Sends `scheduled_tasks` to `node_id`.
    ///
    /// The request is serialized *before* connecting. A plan that cannot be
    /// encoded therefore fails immediately, without opening a connection to (or
    /// waiting on the connect timeout of) the remote node.
    async fn send_tasks(&self, node_id: NodeId, scheduled_tasks: ScheduledTasks) -> DistResult<()> {
        let send_tasks_req = self.serialize_scheduled_tasks(scheduled_tasks)?;
        let mut tonic_client = build_tonic_client(node_id).await?;
        tonic_client
            .send_tasks(send_tasks_req)
            .await
            .map_err(|e| DistError::network(Box::new(e)))?;
        Ok(())
    }

    /// Asks `node_id` to execute `task_id` and returns its output.
    ///
    /// The response is a stream of Arrow Flight data, which is decoded into
    /// record batches. Transport errors and Flight decoding errors are both
    /// reported as network errors.
    async fn execute_task(
        &self,
        node_id: NodeId,
        task_id: TaskId,
    ) -> DistResult<RecordBatchStream> {
        let mut tonic_client = build_tonic_client(node_id).await?;
        let response = tonic_client
            .execute_task(serialize_task_id(task_id))
            .await
            .map_err(|e| DistError::network(Box::new(e)))?;

        // gRPC status errors are lifted into FlightError so the Flight decoder can consume them.
        let flight_data_stream = response.into_inner().map_err(FlightError::from);
        let record_batch_stream = FlightRecordBatchStream::new_from_flight_data(flight_data_stream)
            .map_err(|e| DistError::network(Box::new(e)))
            .boxed();

        Ok(record_batch_stream)
    }

    /// Fetches per-stage status information from `node_id`.
    ///
    /// `job_id` is sent as an optional string. With `None`, no job id is sent;
    /// NOTE(review): presumably this means "all jobs" on the server side, to be
    /// confirmed there.
    ///
    /// # Errors
    /// Returns a network error if the RPC fails. Returns an internal error if any
    /// returned `StageInfo` lacks a `stage_id`.
    async fn get_job_status(
        &self,
        node_id: NodeId,
        job_id: Option<Uuid>,
    ) -> DistResult<HashMap<StageId, StageInfo>> {
        let mut tonic_client = build_tonic_client(node_id).await?;

        let req = protobuf::GetJobStatusReq {
            job_id: job_id.map(|id| id.to_string()),
        };

        let resp = tonic_client
            .get_job_status(req)
            .await
            .map_err(|e| DistError::network(Box::new(e)))?
            .into_inner();

        let mut result = HashMap::with_capacity(resp.stage_infos.len());
        for proto_stage_info in resp.stage_infos {
            // The id is cloned because `parse_stage_info` consumes the whole message.
            let stage_id = parse_stage_id(
                proto_stage_info
                    .stage_id
                    .clone()
                    .ok_or_else(|| DistError::internal("Missing stage_id in StageInfo"))?,
            );
            let stage_info = parse_stage_info(proto_stage_info);
            result.insert(stage_id, stage_info);
        }

        Ok(result)
    }

    /// Sends a `CleanupJobReq` for `job_id` to `node_id`.
    ///
    /// # Errors
    /// Returns a network error if connecting or the RPC fails.
    async fn cleanup_job(&self, node_id: NodeId, job_id: Uuid) -> DistResult<()> {
        let mut tonic_client = build_tonic_client(node_id).await?;

        let req = protobuf::CleanupJobReq {
            job_id: job_id.to_string(),
        };

        tonic_client
            .cleanup_job(req)
            .await
            .map_err(|e| DistError::network(Box::new(e)))?;

        Ok(())
    }
}

/// Upper bound on how long establishing a connection to a peer node may take.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(20);

/// Opens a plain-HTTP (non-TLS) Tonic channel to `node_id`.
///
/// IPv6 literal hosts are wrapped in brackets so the URI is valid
/// (`http://[::1]:50051` rather than the unparsable `http://::1:50051`).
/// The connection attempt is bounded by [`CONNECT_TIMEOUT`].
///
/// # Errors
/// Returns a network error if the address is not a valid URI or the
/// connection cannot be established.
async fn build_tonic_channel(node_id: NodeId) -> DistResult<Channel> {
    let addr = format_http_addr(&node_id.host, node_id.port);
    let endpoint = Endpoint::from_shared(addr)
        .map_err(|e| DistError::network(Box::new(e)))?
        .connect_timeout(CONNECT_TIMEOUT);
    let channel = endpoint
        .connect()
        .await
        .map_err(|e| DistError::network(Box::new(e)))?;
    Ok(channel)
}

/// Builds the `http://host:port` URI for a node, bracketing IPv6 literals as
/// required by RFC 3986. Hostnames, IPv4 addresses and already-bracketed hosts
/// are used verbatim.
fn format_http_addr(host: &str, port: u16) -> String {
    if host.parse::<std::net::Ipv6Addr>().is_ok() {
        format!("http://[{host}]:{port}")
    } else {
        format!("http://{host}:{port}")
    }
}

/// Connects to `node_id` and wraps the resulting channel in a `DistTonicServiceClient`.
///
/// Both the decoding and encoding message-size limits are set to `usize::MAX`,
/// which effectively disables per-message size caps so that large plans and
/// result batches can be transferred.
async fn build_tonic_client(node_id: NodeId) -> DistResult<DistTonicServiceClient<Channel>> {
    let client = DistTonicServiceClient::new(build_tonic_channel(node_id).await?)
        .max_decoding_message_size(usize::MAX)
        .max_encoding_message_size(usize::MAX);
    Ok(client)
}