// kapot_executor/flight_service.rs
use arrow::ipc::reader::StreamReader;
use std::convert::TryFrom;
use std::fs::File;
use std::pin::Pin;
use arrow::ipc::CompressionType;
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::error::FlightError;
use kapot_core::error::KapotError;
use kapot_core::serde::decode_protobuf;
use kapot_core::serde::scheduler::Action as kapotAction;
use arrow::ipc::writer::IpcWriteOptions;
use arrow_flight::{
flight_service_server::FlightService, Action, ActionType, Criteria, Empty,
FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse,
PollInfo, PutResult, SchemaResult, Ticket,
};
use datafusion::arrow::{error::ArrowError, record_batch::RecordBatch};
use futures::{Stream, StreamExt, TryStreamExt};
use log::{debug, info};
use std::io::{Read, Seek};
use tokio::sync::mpsc::channel;
use tokio::sync::mpsc::error::SendError;
use tokio::{sync::mpsc::Sender, task};
use tokio_stream::wrappers::ReceiverStream;
use tonic::metadata::MetadataValue;
use tonic::{Request, Response, Status, Streaming};
use tracing::warn;
/// Arrow Flight service of the executor.
///
/// The type has no fields, so every instance is interchangeable and cloning
/// it is free. `Debug` is derived so the service can be logged or embedded in
/// other `Debug` types.
#[derive(Debug, Clone)]
pub struct KapotFlightService {}
impl KapotFlightService {
pub fn new() -> Self {
Self {}
}
}
impl Default for KapotFlightService {
fn default() -> Self {
Self::new()
}
}
/// Pinned, boxed, `Send` stream yielding `Result<T, Status>` items.
///
/// Used as the concrete type behind every streaming associated type of the
/// `FlightService` implementation below.
type BoxedFlightStream<T> =
Pin<Box<dyn Stream<Item = Result<T, Status>> + Send + 'static>>;
#[tonic::async_trait]
impl FlightService for KapotFlightService {
    type DoActionStream = BoxedFlightStream<arrow_flight::Result>;
    type DoExchangeStream = BoxedFlightStream<FlightData>;
    type DoGetStream = BoxedFlightStream<FlightData>;
    type DoPutStream = BoxedFlightStream<PutResult>;
    type HandshakeStream = BoxedFlightStream<HandshakeResponse>;
    type ListActionsStream = BoxedFlightStream<ActionType>;
    type ListFlightsStream = BoxedFlightStream<FlightInfo>;

    /// Streams the partition file named by the ticket back to the caller.
    ///
    /// The ticket bytes are decoded as a kapot action. For `FetchPartition`
    /// the file at `path` is opened as an Arrow IPC stream, its batches are
    /// read on a blocking thread, and they are re-encoded as LZ4-frame
    /// compressed `FlightData`.
    ///
    /// # Errors
    ///
    /// Returns an internal `Status` when the ticket cannot be decoded, the
    /// file cannot be opened, the IPC stream header cannot be read, or the
    /// compressed write options cannot be built. Failures while streaming are
    /// delivered as error items of the returned stream.
    async fn do_get(
        &self,
        request: Request<Ticket>,
    ) -> Result<Response<Self::DoGetStream>, Status> {
        let ticket = request.into_inner();
        let action =
            decode_protobuf(&ticket.ticket).map_err(|e| from_kapot_err(&e))?;

        match &action {
            kapotAction::FetchPartition { path, .. } => {
                debug!("FetchPartition reading {}", path);

                // NOTE(review): opening the file and reading the stream header
                // below are synchronous calls made on the async task.
                let file = File::open(path)
                    .map(std::io::BufReader::new)
                    .map_err(|e| {
                        from_kapot_err(&KapotError::General(format!(
                            "Failed to open partition file at {path}: {e:?}"
                        )))
                    })?;
                let reader = StreamReader::try_new(file, None)
                    .map_err(|e| from_arrow_err(&e))?;

                // Finish every fallible setup step *before* spawning the
                // detached reader task. Otherwise a failure here would return
                // an error to the client while the task keeps reading the file
                // and then logs a spurious send failure.
                let write_options: IpcWriteOptions = IpcWriteOptions::default()
                    .try_with_compression(Some(CompressionType::LZ4_FRAME))
                    .map_err(|e| from_arrow_err(&e))?;

                // The schema must be captured before `reader` moves into the task.
                let schema = reader.schema();

                // Small bounded channel: the reader blocks once two items are
                // waiting to be encoded.
                let (tx, rx) = channel(2);
                task::spawn_blocking(move || {
                    if let Err(e) = read_partition(reader, tx) {
                        warn!(error = %e, "error streaming shuffle partition");
                    }
                });

                let flight_data_stream = FlightDataEncoderBuilder::new()
                    .with_schema(schema)
                    .with_options(write_options)
                    .build(ReceiverStream::new(rx))
                    .map_err(|err| Status::from_error(Box::new(err)));

                Ok(Response::new(
                    Box::pin(flight_data_stream) as Self::DoGetStream
                ))
            }
        }
    }

    /// Not supported; always returns `Status::unimplemented`.
    async fn get_schema(
        &self,
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<SchemaResult>, Status> {
        Err(Status::unimplemented("get_schema"))
    }

    /// Not supported; always returns `Status::unimplemented`.
    async fn get_flight_info(
        &self,
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        Err(Status::unimplemented("get_flight_info"))
    }

    /// Answers a handshake with a freshly generated UUID token.
    ///
    /// The incoming handshake stream is not read. The response stream holds a
    /// single `HandshakeResponse` whose payload is the raw token bytes, and
    /// the response metadata carries `authorization: Bearer <token>`.
    ///
    /// # Errors
    ///
    /// Returns `Status::invalid_argument` if the bearer string cannot be
    /// turned into a metadata value.
    async fn handshake(
        &self,
        _request: Request<Streaming<HandshakeRequest>>,
    ) -> Result<Response<Self::HandshakeStream>, Status> {
        let token = uuid::Uuid::new_v4();
        info!("do_handshake token={}", token);

        let handshake = HandshakeResponse {
            protocol_version: 0,
            payload: token.as_bytes().to_vec().into(),
        };
        let output = futures::stream::iter(vec![Ok(handshake)]);
        let mut resp: Response<Self::HandshakeStream> =
            Response::new(Box::pin(output));

        let bearer = format!("Bearer {token}");
        let md = MetadataValue::try_from(bearer)
            .map_err(|_| Status::invalid_argument("authorization not parsable"))?;
        resp.metadata_mut().insert("authorization", md);

        Ok(resp)
    }

    /// Not supported; always returns `Status::unimplemented`.
    async fn list_flights(
        &self,
        _request: Request<Criteria>,
    ) -> Result<Response<Self::ListFlightsStream>, Status> {
        Err(Status::unimplemented("list_flights"))
    }

    /// Not supported.
    ///
    /// The uploaded stream is consumed to its end first. The first transport
    /// error met while doing so is returned; otherwise the call fails with
    /// `Status::unimplemented`.
    async fn do_put(
        &self,
        request: Request<Streaming<FlightData>>,
    ) -> Result<Response<Self::DoPutStream>, Status> {
        let mut request = request.into_inner();
        while let Some(data) = request.next().await {
            let _data = data?;
        }
        Err(Status::unimplemented("do_put"))
    }

    /// Not supported.
    ///
    /// The action body is decoded, so a malformed body yields an internal
    /// decoding error, but the decoded action is then discarded and the call
    /// fails with `Status::unimplemented`.
    async fn do_action(
        &self,
        request: Request<Action>,
    ) -> Result<Response<Self::DoActionStream>, Status> {
        let action = request.into_inner();
        let _action =
            decode_protobuf(&action.body).map_err(|e| from_kapot_err(&e))?;
        Err(Status::unimplemented("do_action"))
    }

    /// Not supported; always returns `Status::unimplemented`.
    async fn list_actions(
        &self,
        _request: Request<Empty>,
    ) -> Result<Response<Self::ListActionsStream>, Status> {
        Err(Status::unimplemented("list_actions"))
    }

    /// Not supported; always returns `Status::unimplemented`.
    async fn do_exchange(
        &self,
        _request: Request<Streaming<FlightData>>,
    ) -> Result<Response<Self::DoExchangeStream>, Status> {
        Err(Status::unimplemented("do_exchange"))
    }

    /// Not supported; always returns `Status::unimplemented`.
    async fn poll_flight_info(
        &self,
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<PollInfo>, Status> {
        Err(Status::unimplemented("poll_flight_info"))
    }
}
/// Forwards every item produced by `reader` into the channel `tx`.
///
/// Read errors are converted to `FlightError` and sent through the channel
/// like ordinary batches, so the receiving side sees them as stream items.
/// The function uses `Sender::blocking_send` and is called by `do_get` through
/// `task::spawn_blocking`.
///
/// # Errors
///
/// * The channel is already closed on entry.
/// * A send fails: if the unsent item was itself an error, that error is
///   returned; otherwise a generic internal status is returned.
fn read_partition<T>(
    reader: StreamReader<std::io::BufReader<T>>,
    tx: Sender<Result<RecordBatch, FlightError>>,
) -> Result<(), FlightError>
where
    T: Read + Seek,
{
    // Do not start reading if nobody is listening.
    if tx.is_closed() {
        return Err(FlightError::Tonic(Status::internal(
            "Can't send a batch, channel is closed",
        )));
    }

    for item in reader {
        let message = item.map_err(FlightError::from);
        match tx.blocking_send(message) {
            Ok(()) => {}
            // The undelivered item was a read error: surface it to the caller.
            Err(SendError(Err(unsent))) => return Err(unsent),
            // A good batch could not be delivered.
            Err(SendError(Ok(_))) => {
                return Err(FlightError::Tonic(Status::internal(
                    "Can't send a batch, something went wrong",
                )));
            }
        }
    }

    Ok(())
}
/// Wraps an Arrow error in an internal `Status`.
///
/// The status message is `ArrowError: ` followed by the error's `Debug`
/// rendering.
fn from_arrow_err(e: &ArrowError) -> Status {
    let message = format!("ArrowError: {e:?}");
    Status::internal(message)
}
/// Wraps a kapot error in an internal `Status`.
///
/// The status message is `kapot Error: ` followed by the error's `Debug`
/// rendering.
fn from_kapot_err(e: &kapot_core::error::KapotError) -> Status {
    let message = format!("kapot Error: {e:?}");
    Status::internal(message)
}