kapot_executor/
flight_service.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Implementation of the Apache Arrow Flight protocol that wraps an executor.

20use arrow::ipc::reader::StreamReader;
21use std::convert::TryFrom;
22use std::fs::File;
23use std::pin::Pin;
24
25use arrow::ipc::CompressionType;
26use arrow_flight::encode::FlightDataEncoderBuilder;
27use arrow_flight::error::FlightError;
28use kapot_core::error::KapotError;
29use kapot_core::serde::decode_protobuf;
30use kapot_core::serde::scheduler::Action as kapotAction;
31
32use arrow::ipc::writer::IpcWriteOptions;
33use arrow_flight::{
34    flight_service_server::FlightService, Action, ActionType, Criteria, Empty,
35    FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse,
36    PollInfo, PutResult, SchemaResult, Ticket,
37};
38use datafusion::arrow::{error::ArrowError, record_batch::RecordBatch};
39use futures::{Stream, StreamExt, TryStreamExt};
40use log::{debug, info};
41use std::io::{Read, Seek};
42use tokio::sync::mpsc::channel;
43use tokio::sync::mpsc::error::SendError;
44use tokio::{sync::mpsc::Sender, task};
45use tokio_stream::wrappers::ReceiverStream;
46use tonic::metadata::MetadataValue;
47use tonic::{Request, Response, Status, Streaming};
48use tracing::warn;
49
/// Arrow Flight service exposed by the executor.
///
/// The type carries no state, so cloning it or building it through
/// [`Default`] is free; every instance behaves the same.
#[derive(Clone, Default)]
pub struct KapotFlightService {}

impl KapotFlightService {
    /// Creates a new, stateless flight service.
    pub fn new() -> Self {
        Self::default()
    }
}
65
66type BoxedFlightStream<T> =
67    Pin<Box<dyn Stream<Item = Result<T, Status>> + Send + 'static>>;
68
69#[tonic::async_trait]
70impl FlightService for KapotFlightService {
71    type DoActionStream = BoxedFlightStream<arrow_flight::Result>;
72    type DoExchangeStream = BoxedFlightStream<FlightData>;
73    type DoGetStream = BoxedFlightStream<FlightData>;
74    type DoPutStream = BoxedFlightStream<PutResult>;
75    type HandshakeStream = BoxedFlightStream<HandshakeResponse>;
76    type ListActionsStream = BoxedFlightStream<ActionType>;
77    type ListFlightsStream = BoxedFlightStream<FlightInfo>;
78
79    async fn do_get(
80        &self,
81        request: Request<Ticket>,
82    ) -> Result<Response<Self::DoGetStream>, Status> {
83        let ticket = request.into_inner();
84
85        let action =
86            decode_protobuf(&ticket.ticket).map_err(|e| from_kapot_err(&e))?;
87
88        match &action {
89            kapotAction::FetchPartition { path, .. } => {
90                debug!("FetchPartition reading {}", path);
91                let file = File::open(path)
92                    .map_err(|e| {
93                        KapotError::General(format!(
94                            "Failed to open partition file at {path}: {e:?}"
95                        ))
96                    })
97                    .map(|file| std::io::BufReader::new(file))
98                    .map_err(|e| from_kapot_err(&e))?;
99                let reader =
100                    StreamReader::try_new(file, None).map_err(|e| from_arrow_err(&e))?;
101
102                let (tx, rx) = channel(2);
103                let schema = reader.schema();
104                task::spawn_blocking(move || {
105                    if let Err(e) = read_partition(reader, tx) {
106                        warn!(error = %e, "error streaming shuffle partition");
107                    }
108                });
109
110                let write_options: IpcWriteOptions = IpcWriteOptions::default()
111                    .try_with_compression(Some(CompressionType::LZ4_FRAME))
112                    .map_err(|e| from_arrow_err(&e))?;
113                let flight_data_stream = FlightDataEncoderBuilder::new()
114                    .with_schema(schema)
115                    .with_options(write_options)
116                    .build(ReceiverStream::new(rx))
117                    .map_err(|err| Status::from_error(Box::new(err)));
118
119                Ok(Response::new(
120                    Box::pin(flight_data_stream) as Self::DoGetStream
121                ))
122            }
123        }
124    }
125
126    async fn get_schema(
127        &self,
128        _request: Request<FlightDescriptor>,
129    ) -> Result<Response<SchemaResult>, Status> {
130        Err(Status::unimplemented("get_schema"))
131    }
132
133    async fn get_flight_info(
134        &self,
135        _request: Request<FlightDescriptor>,
136    ) -> Result<Response<FlightInfo>, Status> {
137        Err(Status::unimplemented("get_flight_info"))
138    }
139
140    async fn handshake(
141        &self,
142        _request: Request<Streaming<HandshakeRequest>>,
143    ) -> Result<Response<Self::HandshakeStream>, Status> {
144        let token = uuid::Uuid::new_v4();
145        info!("do_handshake token={}", token);
146
147        let result = HandshakeResponse {
148            protocol_version: 0,
149            payload: token.as_bytes().to_vec().into(),
150        };
151        let result = Ok(result);
152        let output = futures::stream::iter(vec![result]);
153        let str = format!("Bearer {token}");
154        let mut resp: Response<
155            Pin<Box<dyn Stream<Item = Result<_, Status>> + Send + 'static>>,
156        > = Response::new(Box::pin(output));
157        let md = MetadataValue::try_from(str)
158            .map_err(|_| Status::invalid_argument("authorization not parsable"))?;
159        resp.metadata_mut().insert("authorization", md);
160        Ok(resp)
161    }
162
163    async fn list_flights(
164        &self,
165        _request: Request<Criteria>,
166    ) -> Result<Response<Self::ListFlightsStream>, Status> {
167        Err(Status::unimplemented("list_flights"))
168    }
169
170    async fn do_put(
171        &self,
172        request: Request<Streaming<FlightData>>,
173    ) -> Result<Response<Self::DoPutStream>, Status> {
174        let mut request = request.into_inner();
175
176        while let Some(data) = request.next().await {
177            let _data = data?;
178        }
179
180        Err(Status::unimplemented("do_put"))
181    }
182
183    async fn do_action(
184        &self,
185        request: Request<Action>,
186    ) -> Result<Response<Self::DoActionStream>, Status> {
187        let action = request.into_inner();
188
189        let _action = decode_protobuf(&action.body).map_err(|e| from_kapot_err(&e))?;
190
191        Err(Status::unimplemented("do_action"))
192    }
193
194    async fn list_actions(
195        &self,
196        _request: Request<Empty>,
197    ) -> Result<Response<Self::ListActionsStream>, Status> {
198        Err(Status::unimplemented("list_actions"))
199    }
200
201    async fn do_exchange(
202        &self,
203        _request: Request<Streaming<FlightData>>,
204    ) -> Result<Response<Self::DoExchangeStream>, Status> {
205        Err(Status::unimplemented("do_exchange"))
206    }
207
208    async fn poll_flight_info(
209        &self,
210        _request: Request<FlightDescriptor>,
211    ) -> Result<Response<PollInfo>, Status> {
212        Err(Status::unimplemented("poll_flight_info"))
213    }
214}
215
216fn read_partition<T>(
217    reader: StreamReader<std::io::BufReader<T>>,
218    tx: Sender<Result<RecordBatch, FlightError>>,
219) -> Result<(), FlightError>
220where
221    T: Read + Seek,
222{
223    if tx.is_closed() {
224        return Err(FlightError::Tonic(Status::internal(
225            "Can't send a batch, channel is closed",
226        )));
227    }
228
229    for batch in reader {
230        tx.blocking_send(batch.map_err(|err| err.into()))
231            .map_err(|err| {
232                if let SendError(Err(err)) = err {
233                    err
234                } else {
235                    FlightError::Tonic(Status::internal(
236                        "Can't send a batch, something went wrong",
237                    ))
238                }
239            })?
240    }
241    Ok(())
242}
243
244fn from_arrow_err(e: &ArrowError) -> Status {
245    Status::internal(format!("ArrowError: {e:?}"))
246}
247
248fn from_kapot_err(e: &kapot_core::error::KapotError) -> Status {
249    Status::internal(format!("kapot Error: {e:?}"))
250}