oxigdal_distributed/flight/
client.rs1use crate::error::{DistributedError, Result};
7use arrow::record_batch::RecordBatch;
8use arrow_flight::{Action, HandshakeRequest, Ticket, flight_service_client::FlightServiceClient};
9use bytes::Bytes;
10use futures::StreamExt;
11use std::time::Duration;
12use tonic::transport::{Channel, Endpoint};
13use tracing::{debug, info, warn};
14
15pub struct FlightClient {
17 client: FlightServiceClient<Channel>,
19 address: String,
21}
22
23impl FlightClient {
24 pub async fn new(address: String) -> Result<Self> {
26 info!("Connecting to Flight server at {}", address);
27
28 let endpoint = Endpoint::from_shared(address.clone())
29 .map_err(|e| DistributedError::worker_connection(format!("Invalid endpoint: {}", e)))?
30 .connect_timeout(Duration::from_secs(10))
31 .timeout(Duration::from_secs(60))
32 .tcp_keepalive(Some(Duration::from_secs(30)))
33 .http2_keep_alive_interval(Duration::from_secs(30))
34 .keep_alive_timeout(Duration::from_secs(10));
35
36 let channel = endpoint.connect().await.map_err(|e| {
37 DistributedError::worker_connection(format!("Connection failed: {}", e))
38 })?;
39
40 let client = FlightServiceClient::new(channel);
41
42 Ok(Self { client, address })
43 }
44
45 pub async fn handshake(&mut self) -> Result<()> {
47 debug!("Performing handshake with {}", self.address);
48
49 let request = tonic::Request::new(futures::stream::once(async {
50 HandshakeRequest {
51 protocol_version: 0,
52 payload: Bytes::new(),
53 }
54 }));
55
56 let mut response_stream = self
57 .client
58 .handshake(request)
59 .await
60 .map_err(|e| DistributedError::flight_rpc(format!("Handshake failed: {}", e)))?
61 .into_inner();
62
63 while let Some(response) = response_stream.next().await {
65 let _handshake_response = response
66 .map_err(|e| DistributedError::flight_rpc(format!("Handshake error: {}", e)))?;
67 debug!("Handshake successful");
68 }
69
70 Ok(())
71 }
72
73 pub async fn get_data(&mut self, ticket: String) -> Result<Vec<RecordBatch>> {
75 info!("Fetching data for ticket: {}", ticket);
76
77 let ticket = Ticket {
78 ticket: Bytes::from(ticket),
79 };
80
81 let request = tonic::Request::new(ticket);
82
83 let mut stream = self
84 .client
85 .do_get(request)
86 .await
87 .map_err(|e| DistributedError::flight_rpc(format!("DoGet failed: {}", e)))?
88 .into_inner();
89
90 let mut flight_data_vec = Vec::new();
91
92 while let Some(data_result) = stream.next().await {
93 flight_data_vec.push(
94 data_result
95 .map_err(|e| DistributedError::flight_rpc(format!("Stream error: {}", e)))?,
96 );
97 }
98
99 let batches = arrow_flight::utils::flight_data_to_batches(&flight_data_vec)
101 .map_err(|e| DistributedError::arrow(format!("Failed to decode batches: {}", e)))?;
102
103 info!("Received {} batches", batches.len());
104 Ok(batches)
105 }
106
107 pub async fn put_data(&mut self, batches: Vec<RecordBatch>) -> Result<()> {
109 info!("Sending {} batches to server", batches.len());
110
111 if batches.is_empty() {
112 return Err(DistributedError::flight_rpc("No batches to send"));
113 }
114
115 let flight_data_vec =
117 arrow_flight::utils::batches_to_flight_data(batches[0].schema().as_ref(), batches)
118 .map_err(|e| DistributedError::arrow(format!("Failed to encode batches: {}", e)))?;
119
120 let request = tonic::Request::new(futures::stream::iter(flight_data_vec));
121
122 let mut response_stream = self
123 .client
124 .do_put(request)
125 .await
126 .map_err(|e| DistributedError::flight_rpc(format!("DoPut failed: {}", e)))?
127 .into_inner();
128
129 while let Some(result) = response_stream.next().await {
131 let _put_result =
132 result.map_err(|e| DistributedError::flight_rpc(format!("Put error: {}", e)))?;
133 }
134
135 info!("Data sent successfully");
136 Ok(())
137 }
138
139 pub async fn do_action(&mut self, action_type: String, body: Bytes) -> Result<Vec<Bytes>> {
141 debug!("Executing action: {}", action_type);
142
143 let action = Action {
144 r#type: action_type.clone(),
145 body,
146 };
147
148 let request = tonic::Request::new(action);
149
150 let mut stream = self
151 .client
152 .do_action(request)
153 .await
154 .map_err(|e| DistributedError::flight_rpc(format!("DoAction failed: {}", e)))?
155 .into_inner();
156
157 let mut results = Vec::new();
158
159 while let Some(result) = stream.next().await {
160 let action_result =
161 result.map_err(|e| DistributedError::flight_rpc(format!("Action error: {}", e)))?;
162 results.push(action_result.body);
163 }
164
165 debug!(
166 "Action {} completed with {} results",
167 action_type,
168 results.len()
169 );
170 Ok(results)
171 }
172
173 pub async fn list_tickets(&mut self) -> Result<Vec<String>> {
175 let results = self
176 .do_action("list_tickets".to_string(), Bytes::new())
177 .await?;
178
179 if results.is_empty() {
180 return Ok(Vec::new());
181 }
182
183 let tickets: Vec<String> = serde_json::from_slice(&results[0]).map_err(|e| {
184 DistributedError::flight_rpc(format!("Failed to deserialize tickets: {}", e))
185 })?;
186
187 Ok(tickets)
188 }
189
190 pub async fn remove_ticket(&mut self, ticket: String) -> Result<()> {
192 let body = Bytes::from(ticket.clone());
193 let _results = self.do_action("remove_ticket".to_string(), body).await?;
194
195 info!("Removed ticket: {}", ticket);
196 Ok(())
197 }
198
199 pub fn address(&self) -> &str {
201 &self.address
202 }
203
204 pub async fn health_check(&mut self) -> Result<bool> {
206 match self.handshake().await {
207 Ok(_) => Ok(true),
208 Err(e) => {
209 warn!("Health check failed: {}", e);
210 Ok(false)
211 }
212 }
213 }
214}
215
216pub struct FlightClientPool {
218 clients: Vec<FlightClient>,
220 max_size: usize,
222}
223
224impl FlightClientPool {
225 pub fn new(max_size: usize) -> Self {
227 Self {
228 clients: Vec::new(),
229 max_size,
230 }
231 }
232
233 pub async fn add_client(&mut self, address: String) -> Result<()> {
235 if self.clients.len() >= self.max_size {
236 return Err(DistributedError::worker_connection(
237 "Pool is at maximum capacity",
238 ));
239 }
240
241 let client = FlightClient::new(address).await?;
242 self.clients.push(client);
243 Ok(())
244 }
245
246 pub fn get_client(&mut self) -> Result<&mut FlightClient> {
248 if self.clients.is_empty() {
249 return Err(DistributedError::worker_connection("No clients available"));
250 }
251
252 self.clients.rotate_left(1);
254 let idx = self.clients.len() - 1;
255 Ok(&mut self.clients[idx])
256 }
257
258 pub fn size(&self) -> usize {
260 self.clients.len()
261 }
262
263 pub async fn health_check_all(&mut self) -> Result<Vec<bool>> {
265 let mut results = Vec::new();
266
267 for client in &mut self.clients {
268 let is_healthy = client.health_check().await.unwrap_or(false);
269 results.push(is_healthy);
270 }
271
272 Ok(results)
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created pool is empty and remembers its capacity.
    #[test]
    fn test_client_pool() {
        let pool = FlightClientPool::new(5);
        assert_eq!(pool.size(), 0);
        assert_eq!(pool.max_size, 5);
    }

    /// Connecting to an address with no reachable server fails.
    #[tokio::test]
    async fn test_client_creation_fails_for_invalid_address() {
        let result = FlightClient::new("invalid://address".to_string()).await;
        assert!(result.is_err());
    }

    /// Asking an empty pool for a client is an error, not a panic.
    #[test]
    fn test_pool_get_client_empty() {
        let mut pool = FlightClientPool::new(5);
        let result = pool.get_client();
        assert!(result.is_err());
    }

    /// A full pool rejects new clients before attempting any connection.
    #[tokio::test]
    async fn test_pool_rejects_client_when_full() {
        let mut pool = FlightClientPool::new(0);
        let result = pool.add_client("http://127.0.0.1:1".to_string()).await;
        assert!(result.is_err());
        assert_eq!(pool.size(), 0);
    }
}