kapot_executor/
flight_service.rs1use arrow::ipc::reader::StreamReader;
21use std::convert::TryFrom;
22use std::fs::File;
23use std::pin::Pin;
24
25use arrow::ipc::CompressionType;
26use arrow_flight::encode::FlightDataEncoderBuilder;
27use arrow_flight::error::FlightError;
28use kapot_core::error::KapotError;
29use kapot_core::serde::decode_protobuf;
30use kapot_core::serde::scheduler::Action as kapotAction;
31
32use arrow::ipc::writer::IpcWriteOptions;
33use arrow_flight::{
34 flight_service_server::FlightService, Action, ActionType, Criteria, Empty,
35 FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse,
36 PollInfo, PutResult, SchemaResult, Ticket,
37};
38use datafusion::arrow::{error::ArrowError, record_batch::RecordBatch};
39use futures::{Stream, StreamExt, TryStreamExt};
40use log::{debug, info};
41use std::io::{Read, Seek};
42use tokio::sync::mpsc::channel;
43use tokio::sync::mpsc::error::SendError;
44use tokio::{sync::mpsc::Sender, task};
45use tokio_stream::wrappers::ReceiverStream;
46use tonic::metadata::MetadataValue;
47use tonic::{Request, Response, Status, Streaming};
48use tracing::warn;
49
/// Arrow Flight service implementation for a kapot executor.
///
/// The service holds no state. It serves shuffle partition files through
/// `do_get` and hands out tokens through `handshake`; the other Flight RPCs
/// end in an `unimplemented` status.
#[derive(Clone, Debug)]
pub struct KapotFlightService {}

impl KapotFlightService {
    /// Creates a new (stateless) Flight service.
    pub fn new() -> Self {
        Self {}
    }
}

impl Default for KapotFlightService {
    /// Equivalent to [`KapotFlightService::new`].
    fn default() -> Self {
        Self::new()
    }
}
65
66type BoxedFlightStream<T> =
67 Pin<Box<dyn Stream<Item = Result<T, Status>> + Send + 'static>>;
68
69#[tonic::async_trait]
70impl FlightService for KapotFlightService {
71 type DoActionStream = BoxedFlightStream<arrow_flight::Result>;
72 type DoExchangeStream = BoxedFlightStream<FlightData>;
73 type DoGetStream = BoxedFlightStream<FlightData>;
74 type DoPutStream = BoxedFlightStream<PutResult>;
75 type HandshakeStream = BoxedFlightStream<HandshakeResponse>;
76 type ListActionsStream = BoxedFlightStream<ActionType>;
77 type ListFlightsStream = BoxedFlightStream<FlightInfo>;
78
79 async fn do_get(
80 &self,
81 request: Request<Ticket>,
82 ) -> Result<Response<Self::DoGetStream>, Status> {
83 let ticket = request.into_inner();
84
85 let action =
86 decode_protobuf(&ticket.ticket).map_err(|e| from_kapot_err(&e))?;
87
88 match &action {
89 kapotAction::FetchPartition { path, .. } => {
90 debug!("FetchPartition reading {}", path);
91 let file = File::open(path)
92 .map_err(|e| {
93 KapotError::General(format!(
94 "Failed to open partition file at {path}: {e:?}"
95 ))
96 })
97 .map(|file| std::io::BufReader::new(file))
98 .map_err(|e| from_kapot_err(&e))?;
99 let reader =
100 StreamReader::try_new(file, None).map_err(|e| from_arrow_err(&e))?;
101
102 let (tx, rx) = channel(2);
103 let schema = reader.schema();
104 task::spawn_blocking(move || {
105 if let Err(e) = read_partition(reader, tx) {
106 warn!(error = %e, "error streaming shuffle partition");
107 }
108 });
109
110 let write_options: IpcWriteOptions = IpcWriteOptions::default()
111 .try_with_compression(Some(CompressionType::LZ4_FRAME))
112 .map_err(|e| from_arrow_err(&e))?;
113 let flight_data_stream = FlightDataEncoderBuilder::new()
114 .with_schema(schema)
115 .with_options(write_options)
116 .build(ReceiverStream::new(rx))
117 .map_err(|err| Status::from_error(Box::new(err)));
118
119 Ok(Response::new(
120 Box::pin(flight_data_stream) as Self::DoGetStream
121 ))
122 }
123 }
124 }
125
126 async fn get_schema(
127 &self,
128 _request: Request<FlightDescriptor>,
129 ) -> Result<Response<SchemaResult>, Status> {
130 Err(Status::unimplemented("get_schema"))
131 }
132
133 async fn get_flight_info(
134 &self,
135 _request: Request<FlightDescriptor>,
136 ) -> Result<Response<FlightInfo>, Status> {
137 Err(Status::unimplemented("get_flight_info"))
138 }
139
140 async fn handshake(
141 &self,
142 _request: Request<Streaming<HandshakeRequest>>,
143 ) -> Result<Response<Self::HandshakeStream>, Status> {
144 let token = uuid::Uuid::new_v4();
145 info!("do_handshake token={}", token);
146
147 let result = HandshakeResponse {
148 protocol_version: 0,
149 payload: token.as_bytes().to_vec().into(),
150 };
151 let result = Ok(result);
152 let output = futures::stream::iter(vec![result]);
153 let str = format!("Bearer {token}");
154 let mut resp: Response<
155 Pin<Box<dyn Stream<Item = Result<_, Status>> + Send + 'static>>,
156 > = Response::new(Box::pin(output));
157 let md = MetadataValue::try_from(str)
158 .map_err(|_| Status::invalid_argument("authorization not parsable"))?;
159 resp.metadata_mut().insert("authorization", md);
160 Ok(resp)
161 }
162
163 async fn list_flights(
164 &self,
165 _request: Request<Criteria>,
166 ) -> Result<Response<Self::ListFlightsStream>, Status> {
167 Err(Status::unimplemented("list_flights"))
168 }
169
170 async fn do_put(
171 &self,
172 request: Request<Streaming<FlightData>>,
173 ) -> Result<Response<Self::DoPutStream>, Status> {
174 let mut request = request.into_inner();
175
176 while let Some(data) = request.next().await {
177 let _data = data?;
178 }
179
180 Err(Status::unimplemented("do_put"))
181 }
182
183 async fn do_action(
184 &self,
185 request: Request<Action>,
186 ) -> Result<Response<Self::DoActionStream>, Status> {
187 let action = request.into_inner();
188
189 let _action = decode_protobuf(&action.body).map_err(|e| from_kapot_err(&e))?;
190
191 Err(Status::unimplemented("do_action"))
192 }
193
194 async fn list_actions(
195 &self,
196 _request: Request<Empty>,
197 ) -> Result<Response<Self::ListActionsStream>, Status> {
198 Err(Status::unimplemented("list_actions"))
199 }
200
201 async fn do_exchange(
202 &self,
203 _request: Request<Streaming<FlightData>>,
204 ) -> Result<Response<Self::DoExchangeStream>, Status> {
205 Err(Status::unimplemented("do_exchange"))
206 }
207
208 async fn poll_flight_info(
209 &self,
210 _request: Request<FlightDescriptor>,
211 ) -> Result<Response<PollInfo>, Status> {
212 Err(Status::unimplemented("poll_flight_info"))
213 }
214}
215
216fn read_partition<T>(
217 reader: StreamReader<std::io::BufReader<T>>,
218 tx: Sender<Result<RecordBatch, FlightError>>,
219) -> Result<(), FlightError>
220where
221 T: Read + Seek,
222{
223 if tx.is_closed() {
224 return Err(FlightError::Tonic(Status::internal(
225 "Can't send a batch, channel is closed",
226 )));
227 }
228
229 for batch in reader {
230 tx.blocking_send(batch.map_err(|err| err.into()))
231 .map_err(|err| {
232 if let SendError(Err(err)) = err {
233 err
234 } else {
235 FlightError::Tonic(Status::internal(
236 "Can't send a batch, something went wrong",
237 ))
238 }
239 })?
240 }
241 Ok(())
242}
243
244fn from_arrow_err(e: &ArrowError) -> Status {
245 Status::internal(format!("ArrowError: {e:?}"))
246}
247
248fn from_kapot_err(e: &kapot_core::error::KapotError) -> Status {
249 Status::internal(format!("kapot Error: {e:?}"))
250}