use arrow::ipc::{reader::StreamReader, writer::StreamWriter};
use arrow::record_batch::RecordBatch;
use arrow_flight::{
flight_service_client::FlightServiceClient, FlightData, FlightDescriptor, Ticket,
};
use futures::StreamExt;
use std::io::Cursor;
use tonic::transport::Channel;
use crate::core::error::{Error, Result};
use crate::dataframe::DataFrame;
use super::conversion::{dataframe_to_record_batch, record_batch_to_dataframe};
/// Async Arrow Flight client that uploads and downloads [`DataFrame`]s.
///
/// Create it with [`PandRsFlightClient::new`], then call
/// [`PandRsFlightClient::connect`] before issuing any RPC; every RPC method
/// returns an error while the client is unconnected.
///
/// Cloning a connected client clones its tonic `Channel` handle (which tonic
/// documents as cheap) so several tasks can issue calls concurrently.
#[derive(Debug, Clone)]
pub struct PandRsFlightClient {
    /// Server URI; parsed by `Channel::from_shared` when `connect` runs.
    endpoint: String,
    /// Flight gRPC client; `None` until `connect` succeeds.
    inner: Option<FlightServiceClient<Channel>>,
}
impl PandRsFlightClient {
pub fn new(endpoint: impl Into<String>) -> Self {
Self {
endpoint: endpoint.into(),
inner: None,
}
}
pub async fn connect(&mut self) -> Result<()> {
let channel = Channel::from_shared(self.endpoint.clone())
.map_err(|e| Error::InvalidOperation(format!("Invalid endpoint URI: {e}")))?
.connect()
.await
.map_err(|e| Error::InvalidOperation(format!("gRPC connect failed: {e}")))?;
self.inner = Some(FlightServiceClient::new(channel));
Ok(())
}
fn client(&mut self) -> Result<&mut FlightServiceClient<Channel>> {
self.inner
.as_mut()
.ok_or_else(|| Error::InvalidOperation("Not connected – call connect() first".into()))
}
pub async fn list_datasets(&mut self) -> Result<Vec<String>> {
let request = arrow_flight::Criteria {
expression: bytes::Bytes::new(),
};
let mut stream = self
.client()?
.list_flights(request)
.await
.map_err(|e| Error::InvalidOperation(format!("ListFlights RPC failed: {e}")))?
.into_inner();
let mut names = Vec::new();
while let Some(info) = stream
.next()
.await
.transpose()
.map_err(|e| Error::InvalidOperation(format!("ListFlights stream error: {e}")))?
{
if let Some(descriptor) = info.flight_descriptor {
if let Some(first) = descriptor.path.into_iter().next() {
names.push(first);
}
}
}
Ok(names)
}
pub async fn get_dataframe(&mut self, name: &str) -> Result<DataFrame> {
let ticket = Ticket {
ticket: bytes::Bytes::from(name.as_bytes().to_vec()),
};
let mut stream = self
.client()?
.do_get(ticket)
.await
.map_err(|e| Error::InvalidOperation(format!("DoGet RPC failed: {e}")))?
.into_inner();
let mut flight_data_msgs: Vec<FlightData> = Vec::new();
while let Some(msg) = stream
.next()
.await
.transpose()
.map_err(|e| Error::InvalidOperation(format!("DoGet stream error: {e}")))?
{
flight_data_msgs.push(msg);
}
if flight_data_msgs.is_empty() {
return Err(Error::InvalidOperation(format!(
"Server returned no data for dataset '{name}'"
)));
}
let batch = flight_data_to_record_batch(&flight_data_msgs)?;
record_batch_to_dataframe(&batch)
}
pub async fn put_dataframe(&mut self, name: &str, df: &DataFrame) -> Result<()> {
let batch = dataframe_to_record_batch(df)?;
let flight_data_msgs = record_batch_to_flight_data_with_descriptor(name, &batch)?;
let request_stream = futures::stream::iter(flight_data_msgs.into_iter());
self.client()?
.do_put(request_stream)
.await
.map_err(|e| Error::InvalidOperation(format!("DoPut RPC failed: {e}")))?;
Ok(())
}
}
fn record_batch_to_flight_data_with_descriptor(
name: &str,
batch: &RecordBatch,
) -> Result<Vec<FlightData>> {
use arrow::ipc::writer::IpcWriteOptions;
let mut buf = Vec::new();
let options = IpcWriteOptions::default();
{
let mut writer = StreamWriter::try_new_with_options(&mut buf, batch.schema_ref(), options)
.map_err(|e| Error::InvalidOperation(format!("IPC writer init failed: {e}")))?;
writer
.write(batch)
.map_err(|e| Error::InvalidOperation(format!("IPC write failed: {e}")))?;
writer
.finish()
.map_err(|e| Error::InvalidOperation(format!("IPC finish failed: {e}")))?;
}
let descriptor = FlightDescriptor::new_path(vec![name.to_string()]);
let flight_data = FlightData {
flight_descriptor: Some(descriptor),
data_header: bytes::Bytes::new(),
app_metadata: bytes::Bytes::new(),
data_body: bytes::Bytes::from(buf),
};
Ok(vec![flight_data])
}
fn flight_data_to_record_batch(msgs: &[FlightData]) -> Result<RecordBatch> {
let mut combined: Vec<u8> = Vec::new();
for msg in msgs {
combined.extend_from_slice(&msg.data_body);
}
let cursor = Cursor::new(combined);
let mut reader = StreamReader::try_new(cursor, None)
.map_err(|e| Error::InvalidOperation(format!("IPC reader init failed: {e}")))?;
let batch = reader
.next()
.ok_or_else(|| Error::InvalidOperation("No RecordBatch in IPC stream".into()))?
.map_err(|e| Error::InvalidOperation(format!("IPC read failed: {e}")))?;
Ok(batch)
}