oxigdal_distributed/flight/
server.rs1use crate::error::{DistributedError, Result};
7use arrow::record_batch::RecordBatch;
8use arrow_flight::{
9 Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
10 HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket,
11 flight_service_server::{FlightService, FlightServiceServer},
12};
13use bytes::Bytes;
14use futures::{Stream, StreamExt, stream};
15use std::collections::HashMap;
16use std::pin::Pin;
17use std::sync::{Arc, RwLock};
18use tonic::{Request, Response, Streaming};
19use tracing::{debug, info};
20
21pub struct FlightServer {
23 data_store: Arc<RwLock<HashMap<String, Arc<RecordBatch>>>>,
25 auth_tokens: Arc<RwLock<HashMap<String, String>>>,
27 enable_auth: bool,
29}
30
31impl FlightServer {
32 pub fn new() -> Self {
34 Self {
35 data_store: Arc::new(RwLock::new(HashMap::new())),
36 auth_tokens: Arc::new(RwLock::new(HashMap::new())),
37 enable_auth: false,
38 }
39 }
40
41 pub fn with_auth(mut self) -> Self {
43 self.enable_auth = true;
44 self
45 }
46
47 pub fn store_data(&self, ticket: String, data: Arc<RecordBatch>) -> Result<()> {
49 let mut store = self
50 .data_store
51 .write()
52 .map_err(|_| DistributedError::flight_rpc("Failed to acquire data store lock"))?;
53
54 store.insert(ticket, data);
55 Ok(())
56 }
57
58 pub fn get_data(&self, ticket: &str) -> Result<Option<Arc<RecordBatch>>> {
60 let store = self
61 .data_store
62 .read()
63 .map_err(|_| DistributedError::flight_rpc("Failed to acquire data store lock"))?;
64
65 Ok(store.get(ticket).cloned())
66 }
67
68 pub fn remove_data(&self, ticket: &str) -> Result<Option<Arc<RecordBatch>>> {
70 let mut store = self
71 .data_store
72 .write()
73 .map_err(|_| DistributedError::flight_rpc("Failed to acquire data store lock"))?;
74
75 Ok(store.remove(ticket))
76 }
77
78 pub fn list_tickets(&self) -> Result<Vec<String>> {
80 let store = self
81 .data_store
82 .read()
83 .map_err(|_| DistributedError::flight_rpc("Failed to acquire data store lock"))?;
84
85 Ok(store.keys().cloned().collect())
86 }
87
88 pub fn add_auth_token(&self, token: String, user: String) -> Result<()> {
90 let mut tokens = self
91 .auth_tokens
92 .write()
93 .map_err(|_| DistributedError::authentication("Failed to acquire auth tokens lock"))?;
94
95 tokens.insert(token, user);
96 Ok(())
97 }
98
99 pub fn into_service(self) -> FlightServiceServer<Self> {
101 FlightServiceServer::new(self)
102 }
103}
104
105impl Default for FlightServer {
106 fn default() -> Self {
107 Self::new()
108 }
109}
110
111#[tonic::async_trait]
112impl FlightService for FlightServer {
113 type HandshakeStream =
114 Pin<Box<dyn Stream<Item = std::result::Result<HandshakeResponse, tonic::Status>> + Send>>;
115 type ListFlightsStream =
116 Pin<Box<dyn Stream<Item = std::result::Result<FlightInfo, tonic::Status>> + Send>>;
117 type DoGetStream =
118 Pin<Box<dyn Stream<Item = std::result::Result<FlightData, tonic::Status>> + Send>>;
119 type DoPutStream =
120 Pin<Box<dyn Stream<Item = std::result::Result<PutResult, tonic::Status>> + Send>>;
121 type DoActionStream = Pin<
122 Box<dyn Stream<Item = std::result::Result<arrow_flight::Result, tonic::Status>> + Send>,
123 >;
124 type ListActionsStream =
125 Pin<Box<dyn Stream<Item = std::result::Result<ActionType, tonic::Status>> + Send>>;
126 type DoExchangeStream =
127 Pin<Box<dyn Stream<Item = std::result::Result<FlightData, tonic::Status>> + Send>>;
128
129 async fn handshake(
130 &self,
131 _request: Request<Streaming<HandshakeRequest>>,
132 ) -> std::result::Result<Response<Self::HandshakeStream>, tonic::Status> {
133 debug!("Handshake request received");
134
135 let response = HandshakeResponse {
137 protocol_version: 0,
138 payload: Bytes::new(),
139 };
140
141 let stream = stream::once(async { Ok(response) });
142 Ok(Response::new(Box::pin(stream)))
143 }
144
145 async fn list_flights(
146 &self,
147 _request: Request<Criteria>,
148 ) -> std::result::Result<Response<Self::ListFlightsStream>, tonic::Status> {
149 debug!("List flights request received");
150
151 let stream = stream::empty();
153 Ok(Response::new(Box::pin(stream)))
154 }
155
156 async fn get_flight_info(
157 &self,
158 request: Request<FlightDescriptor>,
159 ) -> std::result::Result<Response<FlightInfo>, tonic::Status> {
160 let descriptor = request.into_inner();
161 debug!("Get flight info request: {:?}", descriptor);
162
163 Err(tonic::Status::unimplemented(
164 "get_flight_info not implemented",
165 ))
166 }
167
168 async fn get_schema(
169 &self,
170 _request: Request<FlightDescriptor>,
171 ) -> std::result::Result<Response<SchemaResult>, tonic::Status> {
172 debug!("Get schema request received");
173
174 Err(tonic::Status::unimplemented("get_schema not implemented"))
175 }
176
177 async fn do_get(
178 &self,
179 request: Request<Ticket>,
180 ) -> std::result::Result<Response<Self::DoGetStream>, tonic::Status> {
181 let ticket = request.into_inner();
182 let ticket_str = String::from_utf8(ticket.ticket.to_vec())
183 .map_err(|e| tonic::Status::invalid_argument(format!("Invalid ticket: {}", e)))?;
184
185 info!("DoGet request for ticket: {}", ticket_str);
186
187 let data = self
189 .get_data(&ticket_str)
190 .map_err(|e| tonic::Status::internal(e.to_string()))?
191 .ok_or_else(|| tonic::Status::not_found(format!("Ticket not found: {}", ticket_str)))?;
192
193 let flight_data_vec = arrow_flight::utils::batches_to_flight_data(
195 data.schema().as_ref(),
196 vec![(*data).clone()],
197 )
198 .map_err(|e| tonic::Status::internal(format!("Failed to encode batches: {}", e)))?
199 .into_iter()
200 .map(Ok)
201 .collect::<Vec<_>>();
202
203 let stream = stream::iter(flight_data_vec);
204 Ok(Response::new(Box::pin(stream)))
205 }
206
207 async fn do_put(
208 &self,
209 request: Request<Streaming<FlightData>>,
210 ) -> std::result::Result<Response<Self::DoPutStream>, tonic::Status> {
211 debug!("DoPut request received");
212
213 let mut stream = request.into_inner();
214 let mut flight_data_vec = Vec::new();
215
216 while let Some(data_result) = stream.next().await {
218 flight_data_vec.push(data_result?);
219 }
220
221 let batches = arrow_flight::utils::flight_data_to_batches(&flight_data_vec)
223 .map_err(|e| tonic::Status::internal(format!("Failed to decode batches: {}", e)))?;
224
225 info!("DoPut received {} batches", batches.len());
226
227 for (i, batch) in batches.into_iter().enumerate() {
229 let ticket = format!("uploaded_{}", i);
230 self.store_data(ticket, Arc::new(batch))
231 .map_err(|e| tonic::Status::internal(e.to_string()))?;
232 }
233
234 let result = PutResult {
236 app_metadata: Bytes::new(),
237 };
238
239 let stream = stream::once(async { Ok(result) });
240 Ok(Response::new(Box::pin(stream)))
241 }
242
243 async fn do_action(
244 &self,
245 request: Request<Action>,
246 ) -> std::result::Result<Response<Self::DoActionStream>, tonic::Status> {
247 let action = request.into_inner();
248 info!("DoAction request: {}", action.r#type);
249
250 match action.r#type.as_str() {
251 "list_tickets" => {
252 let tickets = self
253 .list_tickets()
254 .map_err(|e| tonic::Status::internal(e.to_string()))?;
255
256 let result = arrow_flight::Result {
257 body: serde_json::to_vec(&tickets)
258 .map_err(|e| {
259 tonic::Status::internal(format!("Serialization error: {}", e))
260 })?
261 .into(),
262 };
263
264 let stream = stream::once(async { Ok(result) });
265 Ok(Response::new(Box::pin(stream)))
266 }
267 "remove_ticket" => {
268 let ticket = String::from_utf8(action.body.to_vec()).map_err(|e| {
269 tonic::Status::invalid_argument(format!("Invalid ticket: {}", e))
270 })?;
271
272 self.remove_data(&ticket)
273 .map_err(|e| tonic::Status::internal(e.to_string()))?;
274
275 let result = arrow_flight::Result {
276 body: Bytes::from("removed"),
277 };
278
279 let stream = stream::once(async { Ok(result) });
280 Ok(Response::new(Box::pin(stream)))
281 }
282 _ => Err(tonic::Status::unimplemented(format!(
283 "Action not implemented: {}",
284 action.r#type
285 ))),
286 }
287 }
288
289 async fn list_actions(
290 &self,
291 _request: Request<Empty>,
292 ) -> std::result::Result<Response<Self::ListActionsStream>, tonic::Status> {
293 debug!("List actions request received");
294
295 let actions = vec![
296 ActionType {
297 r#type: "list_tickets".to_string(),
298 description: "List all available tickets".to_string(),
299 },
300 ActionType {
301 r#type: "remove_ticket".to_string(),
302 description: "Remove a ticket from the server".to_string(),
303 },
304 ];
305
306 let stream = stream::iter(actions.into_iter().map(Ok));
307 Ok(Response::new(Box::pin(stream)))
308 }
309
310 async fn do_exchange(
311 &self,
312 _request: Request<Streaming<FlightData>>,
313 ) -> std::result::Result<Response<Self::DoExchangeStream>, tonic::Status> {
314 debug!("DoExchange request received");
315
316 Err(tonic::Status::unimplemented("do_exchange not implemented"))
317 }
318
319 async fn poll_flight_info(
320 &self,
321 request: Request<FlightDescriptor>,
322 ) -> std::result::Result<Response<arrow_flight::PollInfo>, tonic::Status> {
323 let _descriptor = request.into_inner();
324 debug!("Poll flight info request received");
325
326 Err(tonic::Status::unimplemented(
327 "poll_flight_info not implemented",
328 ))
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};

    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    /// Builds a single non-nullable `Int32` column batch named `value` holding `values`.
    fn batch_of(
        values: Vec<i32>,
    ) -> std::result::Result<Arc<RecordBatch>, Box<dyn std::error::Error>> {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Int32,
            false,
        )]));

        let array = Int32Array::from(values);

        Ok(Arc::new(RecordBatch::try_new(
            schema,
            vec![Arc::new(array)],
        )?))
    }

    /// The standard five-row fixture used by most tests.
    fn create_test_batch() -> std::result::Result<Arc<RecordBatch>, Box<dyn std::error::Error>> {
        batch_of(vec![1, 2, 3, 4, 5])
    }

    #[test]
    fn test_server_creation() {
        let server = FlightServer::new();
        assert!(!server.enable_auth);
    }

    #[test]
    fn test_default_matches_new() -> TestResult {
        let server = FlightServer::default();
        assert!(!server.enable_auth);
        assert!(server.list_tickets()?.is_empty());
        Ok(())
    }

    #[test]
    fn test_store_and_retrieve_data() -> TestResult {
        let server = FlightServer::new();
        let batch = create_test_batch()?;

        server.store_data("test_ticket".to_string(), batch.clone())?;

        let retrieved = server
            .get_data("test_ticket")?
            .ok_or_else(|| Box::<dyn std::error::Error>::from("should exist"))?;

        assert_eq!(retrieved.num_rows(), batch.num_rows());
        Ok(())
    }

    #[test]
    fn test_get_missing_ticket_returns_none() -> TestResult {
        let server = FlightServer::new();
        assert!(server.get_data("missing")?.is_none());
        Ok(())
    }

    #[test]
    fn test_store_overwrites_existing_ticket() -> TestResult {
        let server = FlightServer::new();

        server.store_data("ticket".to_string(), create_test_batch()?)?;
        server.store_data("ticket".to_string(), batch_of(vec![7, 8])?)?;

        let retrieved = server
            .get_data("ticket")?
            .ok_or_else(|| Box::<dyn std::error::Error>::from("should exist"))?;
        assert_eq!(retrieved.num_rows(), 2);
        assert_eq!(server.list_tickets()?.len(), 1);
        Ok(())
    }

    #[test]
    fn test_remove_data() -> TestResult {
        let server = FlightServer::new();
        let batch = create_test_batch()?;

        server.store_data("test_ticket".to_string(), batch)?;

        let removed = server
            .remove_data("test_ticket")?
            .ok_or_else(|| Box::<dyn std::error::Error>::from("should exist"))?;

        assert_eq!(removed.num_rows(), 5);
        assert!(server.get_data("test_ticket")?.is_none());
        Ok(())
    }

    #[test]
    fn test_remove_missing_ticket_returns_none() -> TestResult {
        let server = FlightServer::new();
        assert!(server.remove_data("missing")?.is_none());
        Ok(())
    }

    #[test]
    fn test_list_tickets() -> TestResult {
        let server = FlightServer::new();

        server.store_data("ticket1".to_string(), create_test_batch()?)?;
        server.store_data("ticket2".to_string(), create_test_batch()?)?;

        let tickets = server.list_tickets()?;
        assert_eq!(tickets.len(), 2);
        assert!(tickets.contains(&"ticket1".to_string()));
        assert!(tickets.contains(&"ticket2".to_string()));
        Ok(())
    }

    #[test]
    fn test_authentication() -> TestResult {
        let server = FlightServer::new().with_auth();
        assert!(server.enable_auth);

        server.add_auth_token("token123".to_string(), "user1".to_string())?;

        let tokens = server
            .auth_tokens
            .read()
            .map_err(|e| Box::<dyn std::error::Error>::from(format!("lock poisoned: {}", e)))?;
        assert_eq!(tokens.get("token123").map(String::as_str), Some("user1"));
        assert!(!tokens.contains_key("invalid"));
        Ok(())
    }
}