arrow_udf_runtime/remote/
mod.rs1#![doc = include_str!("README.md")]
16
17mod error;
18
19pub use error::{Error, Result};
20
21pub use arrow_flight;
24
25use arrow_array::RecordBatch;
26use arrow_flight::decode::FlightRecordBatchStream;
27use arrow_flight::encode::FlightDataEncoderBuilder;
28use arrow_flight::flight_service_client::FlightServiceClient;
29use arrow_flight::{Action, Criteria, FlightData, FlightDescriptor};
30use arrow_schema::Schema;
31use futures_util::{stream, Stream, StreamExt, TryStreamExt};
32use tonic::transport::Channel;
33
34#[derive(Debug)]
36pub struct Client {
37 client: FlightServiceClient<Channel>,
38 protocol_version: u8,
39}
40
41impl Client {
42 pub async fn connect(addr: impl Into<String>) -> Result<Self> {
44 let conn = tonic::transport::Endpoint::new(addr.into())?
45 .connect()
46 .await?;
47 Self::new(FlightServiceClient::new(conn)).await
48 }
49
50 pub async fn new(mut client: FlightServiceClient<Channel>) -> Result<Self> {
52 let protocol_version = match client.do_action(Action::new("protocol_version", "")).await {
54 Err(_) => 1,
56 Ok(response) => *response
58 .into_inner()
59 .next()
60 .await
61 .ok_or_else(|| Error::Decode("no protocol version".into()))??
62 .body
63 .first()
64 .ok_or_else(|| Error::Decode("invalid protocol version".into()))?,
65 };
66
67 Ok(Self {
68 client,
69 protocol_version,
70 })
71 }
72
73 pub fn protocol_version(&self) -> u8 {
75 self.protocol_version
76 }
77
78 pub async fn get(&self, name: &str) -> Result<Function> {
80 let descriptor = FlightDescriptor::new_path(vec![name.into()]);
81 let response = self.client.clone().get_flight_info(descriptor).await?;
82 Function::from_flight_info(response.into_inner())
83 }
84
85 pub async fn list(&self) -> Result<Vec<Function>> {
87 let response = self
88 .client
89 .clone()
90 .list_flights(Criteria::default())
91 .await?;
92 let mut functions = vec![];
93 let mut response = response.into_inner();
94 while let Some(flight_info) = response.next().await {
95 let function = Function::from_flight_info(flight_info?)?;
96 functions.push(function);
97 }
98 Ok(functions)
99 }
100
101 pub async fn call(&self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
103 self.call_internal(name, input).await
104 }
105
106 async fn call_internal(&self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
107 let input = input.clone();
108 let mut output_stream = self.call_stream_internal(name, input).await?;
109 let mut batches = vec![];
110 while let Some(batch) = output_stream.next().await {
111 batches.push(batch?);
112 }
113 Ok(arrow_select::concat::concat_batches(
114 output_stream
115 .schema()
116 .ok_or_else(|| Error::Decode("no schema".into()))?,
117 batches.iter(),
118 )?)
119 }
120
121 pub async fn call_table_function(
123 &self,
124 name: &str,
125 input: &RecordBatch,
126 ) -> Result<impl Stream<Item = Result<RecordBatch>> + Send + 'static> {
127 let input = input.clone();
128 Ok(self
129 .call_stream_internal(name, input)
130 .await?
131 .map_err(|e| e.into()))
132 }
133
134 async fn call_stream_internal(
135 &self,
136 name: &str,
137 input: RecordBatch,
138 ) -> Result<FlightRecordBatchStream> {
139 let descriptor = FlightDescriptor::new_path(vec![name.into()]);
140 let flight_data_stream = FlightDataEncoderBuilder::new()
141 .build(stream::once(async { Ok(input) }))
142 .map(move |res| FlightData {
143 flight_descriptor: Some(descriptor.clone()),
144 ..res.unwrap()
145 });
146
147 let response = self.client.clone().do_exchange(flight_data_stream).await?;
149
150 let stream = response.into_inner();
152 Ok(FlightRecordBatchStream::new_from_flight_data(
153 stream.map_err(|e| e.into()),
155 ))
156 }
157}
158
159#[derive(Debug)]
161pub struct Function {
162 pub name: String,
164 pub args: Schema,
166 pub returns: Schema,
168}
169
170impl Function {
171 fn from_flight_info(info: arrow_flight::FlightInfo) -> Result<Self> {
173 let descriptor = info
174 .flight_descriptor
175 .as_ref()
176 .ok_or_else(|| Error::Decode("no descriptor in flight info".into()))?;
177 let name = descriptor
178 .path
179 .first()
180 .ok_or_else(|| Error::Decode("empty path in flight descriptor".into()))?
181 .clone();
182 let input_num = info.total_records as usize;
183 let schema = Schema::try_from(info)
184 .map_err(|e| Error::Decode(format!("failed to decode schema: {e}")))?;
185 if input_num > schema.fields.len() {
186 return Err(Error::Decode(format!("invalid input_number: {input_num}")));
187 }
188 let (input_fields, return_fields) = schema.fields.split_at(input_num);
189 Ok(Self {
190 name,
191 args: Schema::new(input_fields),
192 returns: Schema::new(return_fields),
193 })
194 }
195}