Skip to main content

oxigdal_distributed/flight/
client.rs

1//! Arrow Flight client implementation for distributed data transfer.
2//!
3//! This module implements an Arrow Flight client for fetching and sending
4//! geospatial data using zero-copy transfers.
5
6use crate::error::{DistributedError, Result};
7use arrow::record_batch::RecordBatch;
8use arrow_flight::{Action, HandshakeRequest, Ticket, flight_service_client::FlightServiceClient};
9use bytes::Bytes;
10use futures::StreamExt;
11use std::time::Duration;
12use tonic::transport::{Channel, Endpoint};
13use tracing::{debug, info, warn};
14
15/// Flight client for fetching geospatial data.
16pub struct FlightClient {
17    /// gRPC client.
18    client: FlightServiceClient<Channel>,
19    /// Server address.
20    address: String,
21}
22
23impl FlightClient {
24    /// Create a new Flight client.
25    pub async fn new(address: String) -> Result<Self> {
26        info!("Connecting to Flight server at {}", address);
27
28        let endpoint = Endpoint::from_shared(address.clone())
29            .map_err(|e| DistributedError::worker_connection(format!("Invalid endpoint: {}", e)))?
30            .connect_timeout(Duration::from_secs(10))
31            .timeout(Duration::from_secs(60))
32            .tcp_keepalive(Some(Duration::from_secs(30)))
33            .http2_keep_alive_interval(Duration::from_secs(30))
34            .keep_alive_timeout(Duration::from_secs(10));
35
36        let channel = endpoint.connect().await.map_err(|e| {
37            DistributedError::worker_connection(format!("Connection failed: {}", e))
38        })?;
39
40        let client = FlightServiceClient::new(channel);
41
42        Ok(Self { client, address })
43    }
44
45    /// Perform handshake with the server.
46    pub async fn handshake(&mut self) -> Result<()> {
47        debug!("Performing handshake with {}", self.address);
48
49        let request = tonic::Request::new(futures::stream::once(async {
50            HandshakeRequest {
51                protocol_version: 0,
52                payload: Bytes::new(),
53            }
54        }));
55
56        let mut response_stream = self
57            .client
58            .handshake(request)
59            .await
60            .map_err(|e| DistributedError::flight_rpc(format!("Handshake failed: {}", e)))?
61            .into_inner();
62
63        // Read handshake response
64        while let Some(response) = response_stream.next().await {
65            let _handshake_response = response
66                .map_err(|e| DistributedError::flight_rpc(format!("Handshake error: {}", e)))?;
67            debug!("Handshake successful");
68        }
69
70        Ok(())
71    }
72
73    /// Fetch data from the server using a ticket.
74    pub async fn get_data(&mut self, ticket: String) -> Result<Vec<RecordBatch>> {
75        info!("Fetching data for ticket: {}", ticket);
76
77        let ticket = Ticket {
78            ticket: Bytes::from(ticket),
79        };
80
81        let request = tonic::Request::new(ticket);
82
83        let mut stream = self
84            .client
85            .do_get(request)
86            .await
87            .map_err(|e| DistributedError::flight_rpc(format!("DoGet failed: {}", e)))?
88            .into_inner();
89
90        let mut flight_data_vec = Vec::new();
91
92        while let Some(data_result) = stream.next().await {
93            flight_data_vec.push(
94                data_result
95                    .map_err(|e| DistributedError::flight_rpc(format!("Stream error: {}", e)))?,
96            );
97        }
98
99        // Convert FlightData to RecordBatches
100        let batches = arrow_flight::utils::flight_data_to_batches(&flight_data_vec)
101            .map_err(|e| DistributedError::arrow(format!("Failed to decode batches: {}", e)))?;
102
103        info!("Received {} batches", batches.len());
104        Ok(batches)
105    }
106
107    /// Send data to the server.
108    pub async fn put_data(&mut self, batches: Vec<RecordBatch>) -> Result<()> {
109        info!("Sending {} batches to server", batches.len());
110
111        if batches.is_empty() {
112            return Err(DistributedError::flight_rpc("No batches to send"));
113        }
114
115        // Convert batches to FlightData
116        let flight_data_vec =
117            arrow_flight::utils::batches_to_flight_data(batches[0].schema().as_ref(), batches)
118                .map_err(|e| DistributedError::arrow(format!("Failed to encode batches: {}", e)))?;
119
120        let request = tonic::Request::new(futures::stream::iter(flight_data_vec));
121
122        let mut response_stream = self
123            .client
124            .do_put(request)
125            .await
126            .map_err(|e| DistributedError::flight_rpc(format!("DoPut failed: {}", e)))?
127            .into_inner();
128
129        // Read put results
130        while let Some(result) = response_stream.next().await {
131            let _put_result =
132                result.map_err(|e| DistributedError::flight_rpc(format!("Put error: {}", e)))?;
133        }
134
135        info!("Data sent successfully");
136        Ok(())
137    }
138
139    /// Execute an action on the server.
140    pub async fn do_action(&mut self, action_type: String, body: Bytes) -> Result<Vec<Bytes>> {
141        debug!("Executing action: {}", action_type);
142
143        let action = Action {
144            r#type: action_type.clone(),
145            body,
146        };
147
148        let request = tonic::Request::new(action);
149
150        let mut stream = self
151            .client
152            .do_action(request)
153            .await
154            .map_err(|e| DistributedError::flight_rpc(format!("DoAction failed: {}", e)))?
155            .into_inner();
156
157        let mut results = Vec::new();
158
159        while let Some(result) = stream.next().await {
160            let action_result =
161                result.map_err(|e| DistributedError::flight_rpc(format!("Action error: {}", e)))?;
162            results.push(action_result.body);
163        }
164
165        debug!(
166            "Action {} completed with {} results",
167            action_type,
168            results.len()
169        );
170        Ok(results)
171    }
172
173    /// List all available tickets.
174    pub async fn list_tickets(&mut self) -> Result<Vec<String>> {
175        let results = self
176            .do_action("list_tickets".to_string(), Bytes::new())
177            .await?;
178
179        if results.is_empty() {
180            return Ok(Vec::new());
181        }
182
183        let tickets: Vec<String> = serde_json::from_slice(&results[0]).map_err(|e| {
184            DistributedError::flight_rpc(format!("Failed to deserialize tickets: {}", e))
185        })?;
186
187        Ok(tickets)
188    }
189
190    /// Remove a ticket from the server.
191    pub async fn remove_ticket(&mut self, ticket: String) -> Result<()> {
192        let body = Bytes::from(ticket.clone());
193        let _results = self.do_action("remove_ticket".to_string(), body).await?;
194
195        info!("Removed ticket: {}", ticket);
196        Ok(())
197    }
198
199    /// Get the server address.
200    pub fn address(&self) -> &str {
201        &self.address
202    }
203
204    /// Check if the client is connected.
205    pub async fn health_check(&mut self) -> Result<bool> {
206        match self.handshake().await {
207            Ok(_) => Ok(true),
208            Err(e) => {
209                warn!("Health check failed: {}", e);
210                Ok(false)
211            }
212        }
213    }
214}
215
216/// Connection pool for managing multiple Flight clients.
217pub struct FlightClientPool {
218    /// Available clients.
219    clients: Vec<FlightClient>,
220    /// Maximum pool size.
221    max_size: usize,
222}
223
224impl FlightClientPool {
225    /// Create a new client pool.
226    pub fn new(max_size: usize) -> Self {
227        Self {
228            clients: Vec::new(),
229            max_size,
230        }
231    }
232
233    /// Add a client to the pool.
234    pub async fn add_client(&mut self, address: String) -> Result<()> {
235        if self.clients.len() >= self.max_size {
236            return Err(DistributedError::worker_connection(
237                "Pool is at maximum capacity",
238            ));
239        }
240
241        let client = FlightClient::new(address).await?;
242        self.clients.push(client);
243        Ok(())
244    }
245
246    /// Get a client from the pool (round-robin).
247    pub fn get_client(&mut self) -> Result<&mut FlightClient> {
248        if self.clients.is_empty() {
249            return Err(DistributedError::worker_connection("No clients available"));
250        }
251
252        // Simple round-robin: rotate the first client to the back
253        self.clients.rotate_left(1);
254        let idx = self.clients.len() - 1;
255        Ok(&mut self.clients[idx])
256    }
257
258    /// Get the number of clients in the pool.
259    pub fn size(&self) -> usize {
260        self.clients.len()
261    }
262
263    /// Check health of all clients.
264    pub async fn health_check_all(&mut self) -> Result<Vec<bool>> {
265        let mut results = Vec::new();
266
267        for client in &mut self.clients {
268            let is_healthy = client.health_check().await.unwrap_or(false);
269            results.push(is_healthy);
270        }
271
272        Ok(results)
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created pool is empty and remembers its capacity.
    #[test]
    fn test_client_pool() {
        let capacity = 5;
        let pool = FlightClientPool::new(capacity);
        assert_eq!(pool.max_size, capacity);
        assert_eq!(pool.size(), 0);
    }

    /// Creating a client for an address that cannot be connected to fails.
    #[tokio::test]
    async fn test_client_creation_fails_for_invalid_address() {
        let address = String::from("invalid://address");
        assert!(FlightClient::new(address).await.is_err());
    }

    /// Asking an empty pool for a client is an error.
    #[test]
    fn test_pool_get_client_empty() {
        let mut pool = FlightClientPool::new(5);
        assert!(pool.get_client().is_err());
    }
}
299}