Skip to main content

oxigdal_distributed/flight/
server.rs

1//! Arrow Flight server implementation for distributed data transfer.
2//!
3//! This module implements an Arrow Flight server that streams geospatial data
4//! between nodes using zero-copy transfers.
5
6use crate::error::{DistributedError, Result};
7use arrow::record_batch::RecordBatch;
8use arrow_flight::{
9    Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
10    HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket,
11    flight_service_server::{FlightService, FlightServiceServer},
12};
13use bytes::Bytes;
14use futures::{Stream, StreamExt, stream};
15use std::collections::HashMap;
16use std::pin::Pin;
17use std::sync::{Arc, RwLock};
18use tonic::{Request, Response, Streaming};
19use tracing::{debug, info};
20
21/// Flight server for serving geospatial data.
22pub struct FlightServer {
23    /// Stored data partitions (ticket -> RecordBatch).
24    data_store: Arc<RwLock<HashMap<String, Arc<RecordBatch>>>>,
25    /// Authentication tokens.
26    auth_tokens: Arc<RwLock<HashMap<String, String>>>,
27    /// Enable authentication.
28    enable_auth: bool,
29}
30
31impl FlightServer {
32    /// Create a new Flight server.
33    pub fn new() -> Self {
34        Self {
35            data_store: Arc::new(RwLock::new(HashMap::new())),
36            auth_tokens: Arc::new(RwLock::new(HashMap::new())),
37            enable_auth: false,
38        }
39    }
40
41    /// Enable authentication.
42    pub fn with_auth(mut self) -> Self {
43        self.enable_auth = true;
44        self
45    }
46
47    /// Store data with a ticket.
48    pub fn store_data(&self, ticket: String, data: Arc<RecordBatch>) -> Result<()> {
49        let mut store = self
50            .data_store
51            .write()
52            .map_err(|_| DistributedError::flight_rpc("Failed to acquire data store lock"))?;
53
54        store.insert(ticket, data);
55        Ok(())
56    }
57
58    /// Retrieve data by ticket.
59    pub fn get_data(&self, ticket: &str) -> Result<Option<Arc<RecordBatch>>> {
60        let store = self
61            .data_store
62            .read()
63            .map_err(|_| DistributedError::flight_rpc("Failed to acquire data store lock"))?;
64
65        Ok(store.get(ticket).cloned())
66    }
67
68    /// Remove data by ticket.
69    pub fn remove_data(&self, ticket: &str) -> Result<Option<Arc<RecordBatch>>> {
70        let mut store = self
71            .data_store
72            .write()
73            .map_err(|_| DistributedError::flight_rpc("Failed to acquire data store lock"))?;
74
75        Ok(store.remove(ticket))
76    }
77
78    /// List all available tickets.
79    pub fn list_tickets(&self) -> Result<Vec<String>> {
80        let store = self
81            .data_store
82            .read()
83            .map_err(|_| DistributedError::flight_rpc("Failed to acquire data store lock"))?;
84
85        Ok(store.keys().cloned().collect())
86    }
87
88    /// Add authentication token.
89    pub fn add_auth_token(&self, token: String, user: String) -> Result<()> {
90        let mut tokens = self
91            .auth_tokens
92            .write()
93            .map_err(|_| DistributedError::authentication("Failed to acquire auth tokens lock"))?;
94
95        tokens.insert(token, user);
96        Ok(())
97    }
98
99    /// Convert to tonic service.
100    pub fn into_service(self) -> FlightServiceServer<Self> {
101        FlightServiceServer::new(self)
102    }
103}
104
105impl Default for FlightServer {
106    fn default() -> Self {
107        Self::new()
108    }
109}
110
111#[tonic::async_trait]
112impl FlightService for FlightServer {
113    type HandshakeStream =
114        Pin<Box<dyn Stream<Item = std::result::Result<HandshakeResponse, tonic::Status>> + Send>>;
115    type ListFlightsStream =
116        Pin<Box<dyn Stream<Item = std::result::Result<FlightInfo, tonic::Status>> + Send>>;
117    type DoGetStream =
118        Pin<Box<dyn Stream<Item = std::result::Result<FlightData, tonic::Status>> + Send>>;
119    type DoPutStream =
120        Pin<Box<dyn Stream<Item = std::result::Result<PutResult, tonic::Status>> + Send>>;
121    type DoActionStream = Pin<
122        Box<dyn Stream<Item = std::result::Result<arrow_flight::Result, tonic::Status>> + Send>,
123    >;
124    type ListActionsStream =
125        Pin<Box<dyn Stream<Item = std::result::Result<ActionType, tonic::Status>> + Send>>;
126    type DoExchangeStream =
127        Pin<Box<dyn Stream<Item = std::result::Result<FlightData, tonic::Status>> + Send>>;
128
129    async fn handshake(
130        &self,
131        _request: Request<Streaming<HandshakeRequest>>,
132    ) -> std::result::Result<Response<Self::HandshakeStream>, tonic::Status> {
133        debug!("Handshake request received");
134
135        // Simple handshake - just acknowledge
136        let response = HandshakeResponse {
137            protocol_version: 0,
138            payload: Bytes::new(),
139        };
140
141        let stream = stream::once(async { Ok(response) });
142        Ok(Response::new(Box::pin(stream)))
143    }
144
145    async fn list_flights(
146        &self,
147        _request: Request<Criteria>,
148    ) -> std::result::Result<Response<Self::ListFlightsStream>, tonic::Status> {
149        debug!("List flights request received");
150
151        // Return empty stream - we don't support flight listing yet
152        let stream = stream::empty();
153        Ok(Response::new(Box::pin(stream)))
154    }
155
156    async fn get_flight_info(
157        &self,
158        request: Request<FlightDescriptor>,
159    ) -> std::result::Result<Response<FlightInfo>, tonic::Status> {
160        let descriptor = request.into_inner();
161        debug!("Get flight info request: {:?}", descriptor);
162
163        Err(tonic::Status::unimplemented(
164            "get_flight_info not implemented",
165        ))
166    }
167
168    async fn get_schema(
169        &self,
170        _request: Request<FlightDescriptor>,
171    ) -> std::result::Result<Response<SchemaResult>, tonic::Status> {
172        debug!("Get schema request received");
173
174        Err(tonic::Status::unimplemented("get_schema not implemented"))
175    }
176
177    async fn do_get(
178        &self,
179        request: Request<Ticket>,
180    ) -> std::result::Result<Response<Self::DoGetStream>, tonic::Status> {
181        let ticket = request.into_inner();
182        let ticket_str = String::from_utf8(ticket.ticket.to_vec())
183            .map_err(|e| tonic::Status::invalid_argument(format!("Invalid ticket: {}", e)))?;
184
185        info!("DoGet request for ticket: {}", ticket_str);
186
187        // Retrieve data
188        let data = self
189            .get_data(&ticket_str)
190            .map_err(|e| tonic::Status::internal(e.to_string()))?
191            .ok_or_else(|| tonic::Status::not_found(format!("Ticket not found: {}", ticket_str)))?;
192
193        // Convert RecordBatch to FlightData stream
194        let flight_data_vec = arrow_flight::utils::batches_to_flight_data(
195            data.schema().as_ref(),
196            vec![(*data).clone()],
197        )
198        .map_err(|e| tonic::Status::internal(format!("Failed to encode batches: {}", e)))?
199        .into_iter()
200        .map(Ok)
201        .collect::<Vec<_>>();
202
203        let stream = stream::iter(flight_data_vec);
204        Ok(Response::new(Box::pin(stream)))
205    }
206
207    async fn do_put(
208        &self,
209        request: Request<Streaming<FlightData>>,
210    ) -> std::result::Result<Response<Self::DoPutStream>, tonic::Status> {
211        debug!("DoPut request received");
212
213        let mut stream = request.into_inner();
214        let mut flight_data_vec = Vec::new();
215
216        // Collect all FlightData messages
217        while let Some(data_result) = stream.next().await {
218            flight_data_vec.push(data_result?);
219        }
220
221        // Convert FlightData to RecordBatches
222        let batches = arrow_flight::utils::flight_data_to_batches(&flight_data_vec)
223            .map_err(|e| tonic::Status::internal(format!("Failed to decode batches: {}", e)))?;
224
225        info!("DoPut received {} batches", batches.len());
226
227        // Store batches (using a generated ticket)
228        for (i, batch) in batches.into_iter().enumerate() {
229            let ticket = format!("uploaded_{}", i);
230            self.store_data(ticket, Arc::new(batch))
231                .map_err(|e| tonic::Status::internal(e.to_string()))?;
232        }
233
234        // Return success
235        let result = PutResult {
236            app_metadata: Bytes::new(),
237        };
238
239        let stream = stream::once(async { Ok(result) });
240        Ok(Response::new(Box::pin(stream)))
241    }
242
243    async fn do_action(
244        &self,
245        request: Request<Action>,
246    ) -> std::result::Result<Response<Self::DoActionStream>, tonic::Status> {
247        let action = request.into_inner();
248        info!("DoAction request: {}", action.r#type);
249
250        match action.r#type.as_str() {
251            "list_tickets" => {
252                let tickets = self
253                    .list_tickets()
254                    .map_err(|e| tonic::Status::internal(e.to_string()))?;
255
256                let result = arrow_flight::Result {
257                    body: serde_json::to_vec(&tickets)
258                        .map_err(|e| {
259                            tonic::Status::internal(format!("Serialization error: {}", e))
260                        })?
261                        .into(),
262                };
263
264                let stream = stream::once(async { Ok(result) });
265                Ok(Response::new(Box::pin(stream)))
266            }
267            "remove_ticket" => {
268                let ticket = String::from_utf8(action.body.to_vec()).map_err(|e| {
269                    tonic::Status::invalid_argument(format!("Invalid ticket: {}", e))
270                })?;
271
272                self.remove_data(&ticket)
273                    .map_err(|e| tonic::Status::internal(e.to_string()))?;
274
275                let result = arrow_flight::Result {
276                    body: Bytes::from("removed"),
277                };
278
279                let stream = stream::once(async { Ok(result) });
280                Ok(Response::new(Box::pin(stream)))
281            }
282            _ => Err(tonic::Status::unimplemented(format!(
283                "Action not implemented: {}",
284                action.r#type
285            ))),
286        }
287    }
288
289    async fn list_actions(
290        &self,
291        _request: Request<Empty>,
292    ) -> std::result::Result<Response<Self::ListActionsStream>, tonic::Status> {
293        debug!("List actions request received");
294
295        let actions = vec![
296            ActionType {
297                r#type: "list_tickets".to_string(),
298                description: "List all available tickets".to_string(),
299            },
300            ActionType {
301                r#type: "remove_ticket".to_string(),
302                description: "Remove a ticket from the server".to_string(),
303            },
304        ];
305
306        let stream = stream::iter(actions.into_iter().map(Ok));
307        Ok(Response::new(Box::pin(stream)))
308    }
309
310    async fn do_exchange(
311        &self,
312        _request: Request<Streaming<FlightData>>,
313    ) -> std::result::Result<Response<Self::DoExchangeStream>, tonic::Status> {
314        debug!("DoExchange request received");
315
316        Err(tonic::Status::unimplemented("do_exchange not implemented"))
317    }
318
319    async fn poll_flight_info(
320        &self,
321        request: Request<FlightDescriptor>,
322    ) -> std::result::Result<Response<arrow_flight::PollInfo>, tonic::Status> {
323        let _descriptor = request.into_inner();
324        debug!("Poll flight info request received");
325
326        Err(tonic::Status::unimplemented(
327            "poll_flight_info not implemented",
328        ))
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};

    /// Result type shared by the fallible tests in this module.
    type TestResult<T = ()> = std::result::Result<T, Box<dyn std::error::Error>>;

    /// Build a five-row batch with a single non-nullable `Int32` column named `value`.
    fn create_test_batch() -> TestResult<Arc<RecordBatch>> {
        let fields = vec![Field::new("value", DataType::Int32, false)];
        let schema = Arc::new(Schema::new(fields));
        let values = Int32Array::from(vec![1, 2, 3, 4, 5]);

        let batch = RecordBatch::try_new(schema, vec![Arc::new(values)])?;
        Ok(Arc::new(batch))
    }

    /// A freshly created server has authentication switched off.
    #[test]
    fn test_server_creation() {
        let server = FlightServer::new();
        assert!(!server.enable_auth);
    }

    /// A stored batch can be read back under the same ticket.
    #[test]
    fn test_store_and_retrieve_data() -> TestResult {
        let server = FlightServer::new();
        let batch = create_test_batch()?;

        server.store_data("test_ticket".to_string(), batch.clone())?;

        let retrieved = server.get_data("test_ticket")?.ok_or("should exist")?;

        assert_eq!(retrieved.num_rows(), batch.num_rows());
        Ok(())
    }

    /// Removing a ticket returns its batch and leaves nothing behind.
    #[test]
    fn test_remove_data() -> TestResult {
        let server = FlightServer::new();
        server.store_data("test_ticket".to_string(), create_test_batch()?)?;

        let removed = server.remove_data("test_ticket")?.ok_or("should exist")?;
        assert_eq!(removed.num_rows(), 5);

        assert!(server.get_data("test_ticket")?.is_none());
        Ok(())
    }

    /// Every stored ticket shows up in the listing.
    #[test]
    fn test_list_tickets() -> TestResult {
        let server = FlightServer::new();

        for name in ["ticket1", "ticket2"] {
            server.store_data(name.to_string(), create_test_batch()?)?;
        }

        let tickets = server.list_tickets()?;
        assert_eq!(tickets.len(), 2);
        assert!(tickets.contains(&"ticket1".to_string()));
        assert!(tickets.contains(&"ticket2".to_string()));
        Ok(())
    }

    /// `with_auth` sets the flag and added tokens land in the token map.
    #[test]
    fn test_authentication() -> TestResult {
        let server = FlightServer::new().with_auth();
        assert!(server.enable_auth);

        server.add_auth_token("token123".to_string(), "user1".to_string())?;

        // No verification method is exposed, so inspect the token map directly.
        let tokens = server
            .auth_tokens
            .read()
            .map_err(|e| format!("lock poisoned: {}", e))?;
        assert!(tokens.contains_key("token123"));
        assert!(!tokens.contains_key("invalid"));
        Ok(())
    }
}