use crate::error::{DistributedError, Result};
use arrow::record_batch::RecordBatch;
use arrow_flight::{Action, HandshakeRequest, Ticket, flight_service_client::FlightServiceClient};
use bytes::Bytes;
use futures::StreamExt;
use std::time::Duration;
use tonic::transport::{Channel, Endpoint};
use tracing::{debug, info, warn};
/// Async wrapper around a tonic-generated [`FlightServiceClient`] connected to a
/// single Flight server address.
#[derive(Debug)]
pub struct FlightClient {
    /// Generated gRPC client over the channel established in `FlightClient::new`.
    client: FlightServiceClient<Channel>,
    /// The address the client was created with; used in log messages and
    /// returned by `FlightClient::address`.
    address: String,
}
impl FlightClient {
    /// Connects to the Flight server at `address`.
    ///
    /// The endpoint is configured with a 10 s connect timeout, a 60 s request
    /// timeout, 30 s TCP keepalive, and 30 s HTTP/2 keepalive pings with a 10 s
    /// acknowledgement timeout.
    ///
    /// # Errors
    ///
    /// Returns a worker-connection error if `address` cannot be parsed as an
    /// endpoint URI or if connecting to it fails.
    pub async fn new(address: String) -> Result<Self> {
        info!("Connecting to Flight server at {}", address);
        let endpoint = Endpoint::from_shared(address.clone())
            .map_err(|e| DistributedError::worker_connection(format!("Invalid endpoint: {}", e)))?
            .connect_timeout(Duration::from_secs(10))
            .timeout(Duration::from_secs(60))
            .tcp_keepalive(Some(Duration::from_secs(30)))
            .http2_keep_alive_interval(Duration::from_secs(30))
            .keep_alive_timeout(Duration::from_secs(10));
        let channel = endpoint.connect().await.map_err(|e| {
            DistributedError::worker_connection(format!("Connection failed: {}", e))
        })?;
        let client = FlightServiceClient::new(channel);
        Ok(Self { client, address })
    }

    /// Performs a Flight `Handshake` using protocol version 0 and an empty payload.
    ///
    /// Every response message is drained, but its contents are not inspected.
    ///
    /// # Errors
    ///
    /// Returns a Flight RPC error if the call itself fails or if any message
    /// in the response stream is an error.
    pub async fn handshake(&mut self) -> Result<()> {
        debug!("Performing handshake with {}", self.address);
        let request = tonic::Request::new(futures::stream::once(async {
            HandshakeRequest {
                protocol_version: 0,
                payload: Bytes::new(),
            }
        }));
        let mut response_stream = self
            .client
            .handshake(request)
            .await
            .map_err(|e| DistributedError::flight_rpc(format!("Handshake failed: {}", e)))?
            .into_inner();
        // Drain the stream so that an error status sent mid-stream is surfaced;
        // the response payloads themselves are not used.
        while let Some(response) = response_stream.next().await {
            response
                .map_err(|e| DistributedError::flight_rpc(format!("Handshake error: {}", e)))?;
        }
        // Log exactly once, after the exchange completes. This also covers a
        // server that replies with an empty stream.
        debug!("Handshake successful");
        Ok(())
    }

    /// Fetches the record batches for `ticket` via `DoGet`.
    ///
    /// The ticket string's UTF-8 bytes are sent as the ticket. The whole
    /// response stream is buffered in memory before it is decoded.
    ///
    /// # Errors
    ///
    /// Returns a Flight RPC error if the call or the stream fails. Returns an
    /// Arrow error if the received `FlightData` cannot be decoded into batches.
    pub async fn get_data(&mut self, ticket: String) -> Result<Vec<RecordBatch>> {
        info!("Fetching data for ticket: {}", ticket);
        let ticket = Ticket {
            ticket: Bytes::from(ticket),
        };
        let request = tonic::Request::new(ticket);
        let mut stream = self
            .client
            .do_get(request)
            .await
            .map_err(|e| DistributedError::flight_rpc(format!("DoGet failed: {}", e)))?
            .into_inner();
        let mut flight_data_vec = Vec::new();
        while let Some(data_result) = stream.next().await {
            flight_data_vec.push(
                data_result
                    .map_err(|e| DistributedError::flight_rpc(format!("Stream error: {}", e)))?,
            );
        }
        let batches = arrow_flight::utils::flight_data_to_batches(&flight_data_vec)
            .map_err(|e| DistributedError::arrow(format!("Failed to decode batches: {}", e)))?;
        info!("Received {} batches", batches.len());
        Ok(batches)
    }

    /// Uploads `batches` to the server via `DoPut`.
    ///
    /// The schema of the first batch is used to encode the whole stream.
    /// NOTE(review): no `FlightDescriptor` is attached to the first message.
    /// Confirm that the target server does not require one.
    ///
    /// # Errors
    ///
    /// Returns a Flight RPC error if `batches` is empty, or if the call or any
    /// `PutResult` message fails. Returns an Arrow error if encoding fails.
    pub async fn put_data(&mut self, batches: Vec<RecordBatch>) -> Result<()> {
        info!("Sending {} batches to server", batches.len());
        if batches.is_empty() {
            return Err(DistributedError::flight_rpc("No batches to send"));
        }
        let flight_data_vec =
            arrow_flight::utils::batches_to_flight_data(batches[0].schema().as_ref(), batches)
                .map_err(|e| DistributedError::arrow(format!("Failed to encode batches: {}", e)))?;
        let request = tonic::Request::new(futures::stream::iter(flight_data_vec));
        let mut response_stream = self
            .client
            .do_put(request)
            .await
            .map_err(|e| DistributedError::flight_rpc(format!("DoPut failed: {}", e)))?
            .into_inner();
        // Drain the acknowledgements so that server-side errors are reported.
        while let Some(result) = response_stream.next().await {
            result.map_err(|e| DistributedError::flight_rpc(format!("Put error: {}", e)))?;
        }
        info!("Data sent successfully");
        Ok(())
    }

    /// Runs the Flight action `action_type` with the given `body`.
    ///
    /// Returns the body of every result message, in the order received.
    ///
    /// # Errors
    ///
    /// Returns a Flight RPC error if the call or any result message fails.
    pub async fn do_action(&mut self, action_type: String, body: Bytes) -> Result<Vec<Bytes>> {
        debug!("Executing action: {}", action_type);
        let action = Action {
            r#type: action_type.clone(),
            body,
        };
        let request = tonic::Request::new(action);
        let mut stream = self
            .client
            .do_action(request)
            .await
            .map_err(|e| DistributedError::flight_rpc(format!("DoAction failed: {}", e)))?
            .into_inner();
        let mut results = Vec::new();
        while let Some(result) = stream.next().await {
            let action_result =
                result.map_err(|e| DistributedError::flight_rpc(format!("Action error: {}", e)))?;
            results.push(action_result.body);
        }
        debug!(
            "Action {} completed with {} results",
            action_type,
            results.len()
        );
        Ok(results)
    }

    /// Lists ticket names using the custom `list_tickets` action.
    ///
    /// The body of the first result is parsed as a JSON array of strings, and
    /// any further results are ignored. An empty result stream yields an empty
    /// list.
    ///
    /// # Errors
    ///
    /// Returns a Flight RPC error if the action fails or if the first result is
    /// not a JSON array of strings.
    pub async fn list_tickets(&mut self) -> Result<Vec<String>> {
        let results = self
            .do_action("list_tickets".to_string(), Bytes::new())
            .await?;
        match results.first() {
            None => Ok(Vec::new()),
            Some(first) => serde_json::from_slice(first).map_err(|e| {
                DistributedError::flight_rpc(format!("Failed to deserialize tickets: {}", e))
            }),
        }
    }

    /// Removes `ticket` using the custom `remove_ticket` action.
    ///
    /// The ticket's UTF-8 bytes are sent as the action body, and any result
    /// bodies are discarded.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`FlightClient::do_action`].
    pub async fn remove_ticket(&mut self, ticket: String) -> Result<()> {
        let body = Bytes::from(ticket.clone());
        self.do_action("remove_ticket".to_string(), body).await?;
        info!("Removed ticket: {}", ticket);
        Ok(())
    }

    /// Returns the address this client was created with.
    pub fn address(&self) -> &str {
        &self.address
    }

    /// Reports whether a handshake with the server succeeds.
    ///
    /// A failed handshake is logged at `warn` level and reported as
    /// `Ok(false)`. This method itself never returns `Err`.
    pub async fn health_check(&mut self) -> Result<bool> {
        match self.handshake().await {
            Ok(()) => Ok(true),
            Err(e) => {
                warn!("Health check failed: {}", e);
                Ok(false)
            }
        }
    }
}
/// A bounded pool of [`FlightClient`]s that hands them out in round-robin order.
pub struct FlightClientPool {
    /// Connected clients, kept in the order they were added.
    clients: Vec<FlightClient>,
    /// Maximum number of clients `add_client` will accept.
    max_size: usize,
    /// Index of the client that `get_client` will return next.
    next: usize,
}

impl FlightClientPool {
    /// Creates an empty pool that holds at most `max_size` clients.
    pub fn new(max_size: usize) -> Self {
        Self {
            clients: Vec::new(),
            max_size,
            next: 0,
        }
    }

    /// Connects a new client to `address` and appends it to the pool.
    ///
    /// # Errors
    ///
    /// Returns a worker-connection error if the pool is already full. Capacity
    /// is checked before any connection attempt. Otherwise propagates any error
    /// from [`FlightClient::new`].
    pub async fn add_client(&mut self, address: String) -> Result<()> {
        if self.clients.len() >= self.max_size {
            return Err(DistributedError::worker_connection(
                "Pool is at maximum capacity",
            ));
        }
        let client = FlightClient::new(address).await?;
        self.clients.push(client);
        Ok(())
    }

    /// Returns the next client in round-robin order.
    ///
    /// # Errors
    ///
    /// Returns a worker-connection error if the pool is empty.
    pub fn get_client(&mut self) -> Result<&mut FlightClient> {
        if self.clients.is_empty() {
            return Err(DistributedError::worker_connection("No clients available"));
        }
        // Advance a cursor instead of rotating the vector. This is O(1) per call,
        // and the clients keep their insertion order, so the results of
        // `health_check_all` stay aligned with the order addresses were added.
        let idx = self.next % self.clients.len();
        self.next = (idx + 1) % self.clients.len();
        Ok(&mut self.clients[idx])
    }

    /// Returns the number of clients currently in the pool.
    pub fn size(&self) -> usize {
        self.clients.len()
    }

    /// Health-checks every client sequentially.
    ///
    /// Returns one flag per client, in the order the clients were added. An
    /// error from an individual check is treated as unhealthy.
    pub async fn health_check_all(&mut self) -> Result<Vec<bool>> {
        let mut results = Vec::with_capacity(self.clients.len());
        for client in &mut self.clients {
            let is_healthy = client.health_check().await.unwrap_or(false);
            results.push(is_healthy);
        }
        Ok(results)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A newly built pool holds no clients and records the requested capacity.
    #[test]
    fn test_client_pool() {
        let fresh = FlightClientPool::new(5);
        assert_eq!(fresh.max_size, 5);
        assert_eq!(fresh.size(), 0);
    }

    /// Constructing a client for this address is expected to fail.
    #[tokio::test]
    async fn test_client_creation_fails_for_invalid_address() {
        let outcome = FlightClient::new(String::from("invalid://address")).await;
        assert!(matches!(outcome, Err(_)));
    }

    /// Requesting a client from a pool with no members is an error.
    #[test]
    fn test_pool_get_client_empty() {
        let mut empty_pool = FlightClientPool::new(5);
        assert!(empty_pool.get_client().is_err());
    }
}