use std::sync::Arc;
use arrow::{
array::{Array, RecordBatch},
datatypes::DataType,
};
use crate::{
core::DataChunkHandle,
vtab::arrow::{WritableVector, data_chunk_to_arrow, to_duckdb_logical_type, write_arrow_array_to_vector},
};
use super::{ScalarFunctionSignature, ScalarParams, VScalar};
pub enum ArrowScalarParams {
Exact(Vec<DataType>),
Variadic(DataType),
}
impl AsRef<[DataType]> for ArrowScalarParams {
fn as_ref(&self) -> &[DataType] {
match self {
Self::Exact(params) => params.as_ref(),
Self::Variadic(param) => std::slice::from_ref(param),
}
}
}
impl From<ArrowScalarParams> for ScalarParams {
fn from(params: ArrowScalarParams) -> Self {
match params {
ArrowScalarParams::Exact(params) => Self::Exact(
params
.into_iter()
.map(|v| to_duckdb_logical_type(&v).expect("type should be converted"))
.collect(),
),
ArrowScalarParams::Variadic(param) => {
Self::Variadic(to_duckdb_logical_type(¶m).expect("type should be converted"))
}
}
}
}
pub struct ArrowFunctionSignature {
pub parameters: Option<ArrowScalarParams>,
pub return_type: DataType,
}
impl ArrowFunctionSignature {
pub fn exact(params: Vec<DataType>, return_type: DataType) -> Self {
Self {
parameters: Some(ArrowScalarParams::Exact(params)),
return_type,
}
}
pub fn variadic(param: DataType, return_type: DataType) -> Self {
Self {
parameters: Some(ArrowScalarParams::Variadic(param)),
return_type,
}
}
}
pub trait VArrowScalar: Sized {
type State: Default + Sized + Send + Sync + 'static;
fn invoke(state: &Self::State, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>>;
fn signatures() -> Vec<ArrowFunctionSignature>;
fn volatile() -> bool {
false
}
}
impl<T> VScalar for T
where
T: VArrowScalar,
{
type State = T::State;
unsafe fn invoke(
state: &Self::State,
input: &mut DataChunkHandle,
out: &mut dyn WritableVector,
) -> Result<(), Box<dyn std::error::Error>> {
let array = T::invoke(state, data_chunk_to_arrow(input)?)?;
write_arrow_array_to_vector(&array, out)
}
fn signatures() -> Vec<ScalarFunctionSignature> {
T::signatures()
.into_iter()
.map(|sig| ScalarFunctionSignature {
parameters: sig.parameters.map(Into::into),
return_type: to_duckdb_logical_type(&sig.return_type).expect("type should be converted"),
})
.collect()
}
}
#[cfg(test)]
mod test {
use std::{error::Error, sync::Arc};
use arrow::{
array::{Array, RecordBatch, StringArray},
datatypes::DataType,
};
use crate::{Connection, vscalar::arrow::ArrowFunctionSignature};
use super::VArrowScalar;
struct HelloScalarArrow {}
impl VArrowScalar for HelloScalarArrow {
type State = ();
fn invoke(_: &Self::State, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>> {
let name = input.column(0).as_any().downcast_ref::<StringArray>().unwrap();
let result = name.iter().map(|v| format!("Hello {}", v.unwrap())).collect::<Vec<_>>();
Ok(Arc::new(StringArray::from(result)))
}
fn signatures() -> Vec<ArrowFunctionSignature> {
vec![ArrowFunctionSignature::exact(vec![DataType::Utf8], DataType::Utf8)]
}
}
#[derive(Debug)]
struct MockState {
info: String,
}
impl Default for MockState {
fn default() -> Self {
Self {
info: "some meta".to_string(),
}
}
}
impl Drop for MockState {
fn drop(&mut self) {
println!("dropped meta");
}
}
struct ArrowMultiplyScalar {}
impl VArrowScalar for ArrowMultiplyScalar {
type State = MockState;
fn invoke(_: &Self::State, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>> {
let a = input
.column(0)
.as_any()
.downcast_ref::<::arrow::array::Float32Array>()
.unwrap();
let b = input
.column(1)
.as_any()
.downcast_ref::<::arrow::array::Float32Array>()
.unwrap();
let result = a
.iter()
.zip(b.iter())
.map(|(a, b)| a.unwrap() * b.unwrap())
.collect::<Vec<_>>();
Ok(Arc::new(::arrow::array::Float32Array::from(result)))
}
fn signatures() -> Vec<ArrowFunctionSignature> {
vec![ArrowFunctionSignature::exact(
vec![DataType::Float32, DataType::Float32],
DataType::Float32,
)]
}
}
struct ArrowOverloaded {}
impl VArrowScalar for ArrowOverloaded {
type State = MockState;
fn invoke(state: &Self::State, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>> {
assert_eq!("some meta", state.info);
let a = input.column(0);
let b = input.column(1);
let result = match a.data_type() {
DataType::Utf8 => {
let a = a
.as_any()
.downcast_ref::<::arrow::array::StringArray>()
.unwrap()
.iter()
.map(|v| v.unwrap().parse::<f32>().unwrap())
.collect::<Vec<_>>();
let b = b
.as_any()
.downcast_ref::<::arrow::array::Float32Array>()
.unwrap()
.iter()
.map(|v| v.unwrap())
.collect::<Vec<_>>();
a.iter().zip(b.iter()).map(|(a, b)| a * b).collect::<Vec<_>>()
}
DataType::Float32 => {
let a = a
.as_any()
.downcast_ref::<::arrow::array::Float32Array>()
.unwrap()
.iter()
.map(|v| v.unwrap())
.collect::<Vec<_>>();
let b = b
.as_any()
.downcast_ref::<::arrow::array::Float32Array>()
.unwrap()
.iter()
.map(|v| v.unwrap())
.collect::<Vec<_>>();
a.iter().zip(b.iter()).map(|(a, b)| a * b).collect::<Vec<_>>()
}
_ => panic!("unsupported type"),
};
Ok(Arc::new(::arrow::array::Float32Array::from(result)))
}
fn signatures() -> Vec<ArrowFunctionSignature> {
vec![
ArrowFunctionSignature::exact(vec![DataType::Utf8, DataType::Float32], DataType::Float32),
ArrowFunctionSignature::exact(vec![DataType::Float32, DataType::Float32], DataType::Float32),
]
}
}
#[test]
fn test_arrow_scalar() -> Result<(), Box<dyn Error>> {
let conn = Connection::open_in_memory()?;
conn.register_scalar_function::<HelloScalarArrow>("hello")?;
let batches = conn
.prepare("select hello('foo') as hello from range(10)")?
.query_arrow([])?
.collect::<Vec<_>>();
for batch in batches.iter() {
let array = batch.column(0);
let array = array.as_any().downcast_ref::<::arrow::array::StringArray>().unwrap();
for i in 0..array.len() {
assert_eq!(array.value(i), format!("Hello foo"));
}
}
Ok(())
}
#[test]
fn test_arrow_scalar_multiply() -> Result<(), Box<dyn Error>> {
let conn = Connection::open_in_memory()?;
conn.register_scalar_function::<ArrowMultiplyScalar>("multiply_udf")?;
let batches = conn
.prepare("select multiply_udf(3.0, 2.0) as mult_result from range(10)")?
.query_arrow([])?
.collect::<Vec<_>>();
for batch in batches.iter() {
let array = batch.column(0);
let array = array.as_any().downcast_ref::<::arrow::array::Float32Array>().unwrap();
for i in 0..array.len() {
assert_eq!(array.value(i), 6.0);
}
}
Ok(())
}
#[test]
fn test_multiple_signatures_scalar() -> Result<(), Box<dyn Error>> {
let conn = Connection::open_in_memory()?;
conn.register_scalar_function::<ArrowOverloaded>("multi_sig_udf")?;
let batches = conn
.prepare("select multi_sig_udf('3', 5) as message from range(2)")?
.query_arrow([])?
.collect::<Vec<_>>();
for batch in batches.iter() {
let array = batch.column(0);
let array = array.as_any().downcast_ref::<::arrow::array::Float32Array>().unwrap();
for i in 0..array.len() {
assert_eq!(array.value(i), 15.0);
}
}
let batches = conn
.prepare("select multi_sig_udf(12, 10) as message from range(2)")?
.query_arrow([])?
.collect::<Vec<_>>();
for batch in batches.iter() {
let array = batch.column(0);
let array = array.as_any().downcast_ref::<::arrow::array::Float32Array>().unwrap();
for i in 0..array.len() {
assert_eq!(array.value(i), 120.0);
}
}
Ok(())
}
#[test]
fn test_split_function() -> Result<(), Box<dyn Error>> {
struct SplitFunction {}
impl VArrowScalar for SplitFunction {
type State = ();
fn invoke(_: &Self::State, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>> {
let strings = input.column(0).as_any().downcast_ref::<StringArray>().unwrap();
let mut builder = arrow::array::ListBuilder::new(arrow::array::StringBuilder::with_capacity(
strings.len(),
strings.len() * 10,
));
for s in strings.iter() {
let s = s.unwrap();
for split_value in s.split(' ').collect::<Vec<_>>() {
builder.values().append_value(split_value);
}
builder.append(true);
}
Ok(Arc::new(builder.finish()))
}
fn signatures() -> Vec<ArrowFunctionSignature> {
vec![ArrowFunctionSignature::exact(
vec![DataType::Utf8],
DataType::List(Arc::new(arrow::datatypes::Field::new("item", DataType::Utf8, true))),
)]
}
}
let conn = Connection::open_in_memory()?;
conn.register_scalar_function::<SplitFunction>("split_string")?;
let batches = conn
.prepare("select split_string('hello world') as result")?
.query_arrow([])?
.collect::<Vec<_>>();
let array = batches[0].column(0);
let list_array = array.as_any().downcast_ref::<arrow::array::ListArray>().unwrap();
let values = list_array.value(0);
let string_values = values.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(string_values.value(0), "hello");
assert_eq!(string_values.value(1), "world");
Ok(())
}
}