arrow_udf_runtime/remote/
mod.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![doc = include_str!("README.md")]
16
17mod error;
18
19pub use error::{Error, Result};
20
21/// Re-export `arrow_flight` so downstream crates can use it without declaring it on their own,
22/// avoiding version conflicts.
23pub use arrow_flight;
24
25use arrow_array::RecordBatch;
26use arrow_flight::decode::FlightRecordBatchStream;
27use arrow_flight::encode::FlightDataEncoderBuilder;
28use arrow_flight::flight_service_client::FlightServiceClient;
29use arrow_flight::{Action, Criteria, FlightData, FlightDescriptor};
30use arrow_schema::Schema;
31use futures_util::{stream, Stream, StreamExt, TryStreamExt};
32use tonic::transport::Channel;
33
34/// Client for a remote Arrow UDF service.
35#[derive(Debug)]
36pub struct Client {
37    client: FlightServiceClient<Channel>,
38    protocol_version: u8,
39}
40
41impl Client {
42    /// Connect to a UDF service.
43    pub async fn connect(addr: impl Into<String>) -> Result<Self> {
44        let conn = tonic::transport::Endpoint::new(addr.into())?
45            .connect()
46            .await?;
47        Self::new(FlightServiceClient::new(conn)).await
48    }
49
50    /// Create a new client.
51    pub async fn new(mut client: FlightServiceClient<Channel>) -> Result<Self> {
52        // get protocol version in server
53        let protocol_version = match client.do_action(Action::new("protocol_version", "")).await {
54            // if `do_action` is not implemented, assume protocol version is 1
55            Err(_) => 1,
56            // >= 2
57            Ok(response) => *response
58                .into_inner()
59                .next()
60                .await
61                .ok_or_else(|| Error::Decode("no protocol version".into()))??
62                .body
63                .first()
64                .ok_or_else(|| Error::Decode("invalid protocol version".into()))?,
65        };
66
67        Ok(Self {
68            client,
69            protocol_version,
70        })
71    }
72
73    /// Get protocol version.
74    pub fn protocol_version(&self) -> u8 {
75        self.protocol_version
76    }
77
78    /// Get function schema.
79    pub async fn get(&self, name: &str) -> Result<Function> {
80        let descriptor = FlightDescriptor::new_path(vec![name.into()]);
81        let response = self.client.clone().get_flight_info(descriptor).await?;
82        Function::from_flight_info(response.into_inner())
83    }
84
85    /// List all available functions.
86    pub async fn list(&self) -> Result<Vec<Function>> {
87        let response = self
88            .client
89            .clone()
90            .list_flights(Criteria::default())
91            .await?;
92        let mut functions = vec![];
93        let mut response = response.into_inner();
94        while let Some(flight_info) = response.next().await {
95            let function = Function::from_flight_info(flight_info?)?;
96            functions.push(function);
97        }
98        Ok(functions)
99    }
100
101    /// Call a function.
102    pub async fn call(&self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
103        self.call_internal(name, input).await
104    }
105
106    async fn call_internal(&self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
107        let input = input.clone();
108        let mut output_stream = self.call_stream_internal(name, input).await?;
109        let mut batches = vec![];
110        while let Some(batch) = output_stream.next().await {
111            batches.push(batch?);
112        }
113        Ok(arrow_select::concat::concat_batches(
114            output_stream
115                .schema()
116                .ok_or_else(|| Error::Decode("no schema".into()))?,
117            batches.iter(),
118        )?)
119    }
120
121    /// Call a table function.
122    pub async fn call_table_function(
123        &self,
124        name: &str,
125        input: &RecordBatch,
126    ) -> Result<impl Stream<Item = Result<RecordBatch>> + Send + 'static> {
127        let input = input.clone();
128        Ok(self
129            .call_stream_internal(name, input)
130            .await?
131            .map_err(|e| e.into()))
132    }
133
134    async fn call_stream_internal(
135        &self,
136        name: &str,
137        input: RecordBatch,
138    ) -> Result<FlightRecordBatchStream> {
139        let descriptor = FlightDescriptor::new_path(vec![name.into()]);
140        let flight_data_stream = FlightDataEncoderBuilder::new()
141            .build(stream::once(async { Ok(input) }))
142            .map(move |res| FlightData {
143                flight_descriptor: Some(descriptor.clone()),
144                ..res.unwrap()
145            });
146
147        // call `do_exchange` on Flight server
148        let response = self.client.clone().do_exchange(flight_data_stream).await?;
149
150        // decode response
151        let stream = response.into_inner();
152        Ok(FlightRecordBatchStream::new_from_flight_data(
153            // convert tonic::Status to FlightError
154            stream.map_err(|e| e.into()),
155        ))
156    }
157}
158
159/// Function signature.
160#[derive(Debug)]
161pub struct Function {
162    /// Function name.
163    pub name: String,
164    /// The schema of function arguments.
165    pub args: Schema,
166    /// The schema of function return values.
167    pub returns: Schema,
168}
169
170impl Function {
171    /// Create a function from a `FlightInfo`.
172    fn from_flight_info(info: arrow_flight::FlightInfo) -> Result<Self> {
173        let descriptor = info
174            .flight_descriptor
175            .as_ref()
176            .ok_or_else(|| Error::Decode("no descriptor in flight info".into()))?;
177        let name = descriptor
178            .path
179            .first()
180            .ok_or_else(|| Error::Decode("empty path in flight descriptor".into()))?
181            .clone();
182        let input_num = info.total_records as usize;
183        let schema = Schema::try_from(info)
184            .map_err(|e| Error::Decode(format!("failed to decode schema: {e}")))?;
185        if input_num > schema.fields.len() {
186            return Err(Error::Decode(format!("invalid input_number: {input_num}")));
187        }
188        let (input_fields, return_fields) = schema.fields.split_at(input_num);
189        Ok(Self {
190            name,
191            args: Schema::new(input_fields),
192            returns: Schema::new(return_fields),
193        })
194    }
195}