#![doc = include_str!("../README.md")]
mod error;
pub use error::{Error, Result};
pub use arrow_flight;
use arrow_array::RecordBatch;
use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::flight_service_client::FlightServiceClient;
use arrow_flight::{Action, Criteria, FlightData, FlightDescriptor};
use arrow_schema::Schema;
use futures_util::{stream, Stream, StreamExt, TryStreamExt};
use tonic::transport::Channel;
#[derive(Debug)]
pub struct Client {
client: FlightServiceClient<Channel>,
protocol_version: u8,
}
impl Client {
pub async fn connect(addr: impl Into<String>) -> Result<Self> {
let conn = tonic::transport::Endpoint::new(addr.into())?
.connect()
.await?;
Self::new(FlightServiceClient::new(conn)).await
}
pub async fn new(mut client: FlightServiceClient<Channel>) -> Result<Self> {
let protocol_version = match client.do_action(Action::new("protocol_version", "")).await {
Err(_) => 1,
Ok(response) => *response
.into_inner()
.next()
.await
.ok_or_else(|| Error::Decode("no protocol version".into()))??
.body
.first()
.ok_or_else(|| Error::Decode("invalid protocol version".into()))?,
};
Ok(Self {
client,
protocol_version,
})
}
pub fn protocol_version(&self) -> u8 {
self.protocol_version
}
pub async fn get(&self, name: &str) -> Result<Function> {
let descriptor = FlightDescriptor::new_path(vec![name.into()]);
let response = self.client.clone().get_flight_info(descriptor).await?;
Function::from_flight_info(response.into_inner())
}
pub async fn list(&self) -> Result<Vec<Function>> {
let response = self
.client
.clone()
.list_flights(Criteria::default())
.await?;
let mut functions = vec![];
let mut response = response.into_inner();
while let Some(flight_info) = response.next().await {
let function = Function::from_flight_info(flight_info?)?;
functions.push(function);
}
Ok(functions)
}
pub async fn call(&self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
self.call_internal(name, input).await
}
async fn call_internal(&self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
let input = input.clone();
let mut output_stream = self.call_stream_internal(name, input).await?;
let mut batches = vec![];
while let Some(batch) = output_stream.next().await {
batches.push(batch?);
}
Ok(arrow_select::concat::concat_batches(
output_stream
.schema()
.ok_or_else(|| Error::Decode("no schema".into()))?,
batches.iter(),
)?)
}
pub async fn call_table_function(
&self,
name: &str,
input: &RecordBatch,
) -> Result<impl Stream<Item = Result<RecordBatch>> + Send + 'static> {
let input = input.clone();
Ok(self
.call_stream_internal(name, input)
.await?
.map_err(|e| e.into()))
}
async fn call_stream_internal(
&self,
name: &str,
input: RecordBatch,
) -> Result<FlightRecordBatchStream> {
let descriptor = FlightDescriptor::new_path(vec![name.into()]);
let flight_data_stream = FlightDataEncoderBuilder::new()
.build(stream::once(async { Ok(input) }))
.map(move |res| FlightData {
flight_descriptor: Some(descriptor.clone()),
..res.unwrap()
});
let response = self.client.clone().do_exchange(flight_data_stream).await?;
let stream = response.into_inner();
Ok(FlightRecordBatchStream::new_from_flight_data(
stream.map_err(|e| e.into()),
))
}
}
#[derive(Debug)]
pub struct Function {
pub name: String,
pub args: Schema,
pub returns: Schema,
}
impl Function {
fn from_flight_info(info: arrow_flight::FlightInfo) -> Result<Self> {
let descriptor = info
.flight_descriptor
.as_ref()
.ok_or_else(|| Error::Decode("no descriptor in flight info".into()))?;
let name = descriptor
.path
.first()
.ok_or_else(|| Error::Decode("empty path in flight descriptor".into()))?
.clone();
let input_num = info.total_records as usize;
let schema = Schema::try_from(info)
.map_err(|e| Error::Decode(format!("failed to decode schema: {e}")))?;
if input_num > schema.fields.len() {
return Err(Error::Decode(format!("invalid input_number: {input_num}")));
}
let (input_fields, return_fields) = schema.fields.split_at(input_num);
Ok(Self {
name,
args: Schema::new(input_fields),
returns: Schema::new(return_fields),
})
}
}