// hamelin_datafusion 0.7.5
//
// Translate Hamelin TypedAST to DataFusion LogicalPlans — see crate documentation.
//! map_from_entries UDF for DataFusion.
//!
//! Converts an array of key-value pairs (structs/tuples) into a map.
//! This is equivalent to Trino's `map_from_entries` function.

use std::any::Any;
use std::sync::{Arc, OnceLock};

use datafusion::arrow::array::{
    Array, ArrayRef, AsArray, GenericListArray, MapArray, OffsetSizeTrait, StructArray,
};
use datafusion::arrow::buffer::OffsetBuffer;
use datafusion::arrow::datatypes::{DataType, Field, Fields};
use datafusion::common::{exec_err, Result, ScalarValue};
use datafusion::logical_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
    Volatility,
};

use crate::struct_expansion::map_data_type;

/// UDF that converts an array of key-value pairs into a map.
///
/// `map_from_entries(array<struct<K,V>>)` -> `map<K,V>`
///
/// Takes an array of structs where each struct has exactly two fields (key and value)
/// and returns a map from keys to values.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct MapFromEntriesUdf {
    // Accepts exactly one argument of any type; the list-of-structs
    // requirement is validated in `return_type` / `invoke_with_args`.
    signature: Signature,
}

impl Default for MapFromEntriesUdf {
    fn default() -> Self {
        Self::new()
    }
}

impl MapFromEntriesUdf {
    /// Create a new `map_from_entries` UDF instance.
    pub fn new() -> Self {
        // Accept a single argument of any type; the list-of-structs
        // shape is enforced later in `return_type` / `invoke_with_args`.
        let signature = Signature::new(TypeSignature::Any(1), Volatility::Immutable);
        Self { signature }
    }
}

impl ScalarUDFImpl for MapFromEntriesUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "hamelin_map_from_entries"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Derive the output `map<K, V>` type from the input `array<struct<K, V>>` type.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        // The single argument must be a (large) list whose element is a
        // two-field struct: field 0 becomes the key, field 1 the value.
        let element = match &arg_types[0] {
            DataType::List(field) | DataType::LargeList(field) => field.data_type(),
            other => {
                return exec_err!("map_from_entries expects array type, got {:?}", other)
            }
        };
        if let DataType::Struct(fields) = element {
            if fields.len() == 2 {
                return Ok(map_data_type(
                    fields[0].data_type().clone(),
                    fields[1].data_type().clone(),
                ));
            }
        }
        exec_err!(
            "map_from_entries expects array of 2-field structs, got array of {:?}",
            element
        )
    }

    /// Evaluate the UDF, handling both scalar and columnar inputs.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let args = args.args;
        match args.as_slice() {
            [ColumnarValue::Scalar(scalar)] => {
                convert_scalar_to_map(scalar).map(ColumnarValue::Scalar)
            }
            [ColumnarValue::Array(array)] => {
                convert_array_to_maps(array).map(ColumnarValue::Array)
            }
            other => exec_err!(
                "map_from_entries expects exactly 1 argument, got {}",
                other.len()
            ),
        }
    }
}

/// Convert a scalar list of key-value pairs into a scalar map.
///
/// An absent or null list element yields a typed null map; anything that
/// is not a (large) list is an execution error.
fn convert_scalar_to_map(scalar: &ScalarValue) -> Result<ScalarValue> {
    match scalar {
        // Empty or null list: produce a null map that still carries the
        // key/value type information.
        ScalarValue::List(arr) if arr.is_empty() || arr.is_null(0) => {
            create_null_map_scalar_from_type(arr.value_type())
        }
        ScalarValue::LargeList(arr) if arr.is_empty() || arr.is_null(0) => {
            create_null_map_scalar_from_type(arr.value_type())
        }
        // Non-null list: the first (only) element holds the entry structs.
        ScalarValue::List(arr) => convert_struct_array_to_map_scalar(arr.value(0).as_struct()),
        ScalarValue::LargeList(arr) => {
            convert_struct_array_to_map_scalar(arr.value(0).as_struct())
        }
        other => exec_err!(
            "map_from_entries expects list type, got {:?}",
            other.data_type()
        ),
    }
}

/// Convert a struct array (key-value pairs) to a map scalar.
///
/// The struct must have exactly two columns: keys (column 0) and values
/// (column 1). Column names are rewritten to the canonical "key"/"value"
/// names before building the map.
///
/// NOTE(review): the key field is declared non-nullable but key nulls are
/// not checked here; `MapArray::new` may panic if the keys column contains
/// nulls — TODO confirm validation happens upstream.
fn convert_struct_array_to_map_scalar(struct_array: &StructArray) -> Result<ScalarValue> {
    if struct_array.num_columns() != 2 {
        return exec_err!(
            "map_from_entries expects struct with 2 fields, got {}",
            struct_array.num_columns()
        );
    }

    let keys = struct_array.column(0);
    let values = struct_array.column(1);

    let key_type = keys.data_type().clone();
    let value_type = values.data_type().clone();

    if struct_array.is_empty() {
        // Return empty map (not null) — matches Trino's map_from_entries(ARRAY[])
        let entry_fields = Fields::from(vec![
            Field::new("key", key_type, false),
            Field::new("value", value_type, true),
        ]);
        // Zero-length child arrays keep the type information intact.
        let empty_entries = StructArray::try_new(
            entry_fields.clone(),
            vec![
                datafusion::arrow::array::new_empty_array(&entry_fields[0].data_type()),
                datafusion::arrow::array::new_empty_array(&entry_fields[1].data_type()),
            ],
            None,
        )
        .map_err(|e| datafusion::common::DataFusionError::Execution(e.to_string()))?;
        // Offsets from one zero length => a single map row with no entries.
        let map_array = MapArray::new(
            Arc::new(Field::new("entries", DataType::Struct(entry_fields), false)),
            OffsetBuffer::from_lengths([0usize]),
            empty_entries,
            None,
            false,
        );
        return ScalarValue::try_from_array(&map_array, 0);
    }

    // Create new struct array with correct field names for map
    let new_struct = StructArray::new(
        Fields::from(vec![
            Field::new("key", key_type, false),
            Field::new("value", value_type, true),
        ]),
        vec![Arc::clone(keys), Arc::clone(values)],
        struct_array.nulls().cloned(),
    );

    // Create map array with single entry
    // (one map row whose entry list spans the entire struct array).
    let map_array = MapArray::new(
        Arc::new(Field::new("entries", new_struct.data_type().clone(), false)),
        OffsetBuffer::from_lengths([struct_array.len()]),
        new_struct,
        None,
        false,
    );

    ScalarValue::try_from_array(&map_array, 0)
}

/// Create a null map scalar, preserving key/value type information.
///
/// When the element type is not a two-field struct, the map type falls
/// back to `map<utf8, utf8>`.
fn create_null_map_scalar_from_type(element_type: DataType) -> Result<ScalarValue> {
    // Pull the key/value types out of the entry struct, if present.
    let (key_type, value_type) = match element_type {
        DataType::Struct(ref fields) if fields.len() == 2 => (
            fields[0].data_type().clone(),
            fields[1].data_type().clone(),
        ),
        _ => (DataType::Utf8, DataType::Utf8), // fallback
    };

    // Canonical Arrow map layout: map<entries: struct<key, value>>.
    let entry_struct = DataType::Struct(Fields::from(vec![
        Field::new("key", key_type, false),
        Field::new("value", value_type, true),
    ]));
    let entries_field = Field::new("entries", entry_struct, false);
    let map_type = DataType::Map(Arc::new(entries_field), false);
    ScalarValue::try_new_null(&map_type)
}

/// Convert an array of lists (each containing struct pairs) to an array of maps.
///
/// Dispatches on the concrete offset width (`List` vs `LargeList`).
fn convert_array_to_maps(array: &ArrayRef) -> Result<ArrayRef> {
    match array.data_type() {
        DataType::List(_) => convert_list_to_maps(array.as_list::<i32>()),
        DataType::LargeList(_) => convert_list_to_maps(array.as_list::<i64>()),
        other => exec_err!("map_from_entries expects array type, got {:?}", other),
    }
}

/// Convert a generic list array (each element a list of 2-field structs)
/// into a [`MapArray`], one map per input row.
///
/// Null input rows become null maps; empty lists become empty (non-null)
/// maps. Entries are gathered row by row via `ScalarValue` round-trips,
/// which is simple but allocation-heavy; fine for typical row counts.
///
/// # Errors
/// Returns an execution error when the element type is not a 2-field
/// struct, or when the total entry count exceeds `i32::MAX` (the map's
/// offset width).
fn convert_list_to_maps<O: OffsetSizeTrait>(list_array: &GenericListArray<O>) -> Result<ArrayRef> {
    let struct_type = match list_array.value_type() {
        DataType::Struct(fields) if fields.len() == 2 => fields,
        other => {
            return exec_err!(
                "map_from_entries expects array of 2-field structs, got array of {:?}",
                other
            )
        }
    };

    let key_type = struct_type[0].data_type().clone();
    let value_type = struct_type[1].data_type().clone();

    // Collect all entries from all lists and track offsets
    let mut all_keys: Vec<ScalarValue> = Vec::new();
    let mut all_values: Vec<ScalarValue> = Vec::new();
    let mut offsets: Vec<i32> = vec![0];
    let mut nulls: Vec<bool> = Vec::new();

    for i in 0..list_array.len() {
        if list_array.is_null(i) {
            nulls.push(false); // null bitmap: false = null
            // Null rows contribute no entries; repeat the previous offset.
            offsets.push(offsets.last().copied().unwrap_or(0));
            continue;
        }

        nulls.push(true); // not null

        let entries = list_array.value(i);
        let struct_array = entries.as_struct();

        let keys = struct_array.column(0);
        let values = struct_array.column(1);

        for j in 0..struct_array.len() {
            all_keys.push(ScalarValue::try_from_array(keys, j)?);
            all_values.push(ScalarValue::try_from_array(values, j)?);
        }

        // MapArray offsets are i32; a LargeList (i64 offsets) input can
        // exceed that. The previous `as i32` cast truncated silently,
        // producing corrupt offsets — fail loudly instead.
        let offset = i32::try_from(all_keys.len()).map_err(|_| {
            datafusion::common::DataFusionError::Execution(
                "map_from_entries: total entry count exceeds i32::MAX".to_string(),
            )
        })?;
        offsets.push(offset);
    }

    // Build the combined struct array; an empty entry set still needs
    // correctly-typed zero-length child arrays.
    let keys_array = if all_keys.is_empty() {
        datafusion::arrow::array::new_empty_array(&key_type)
    } else {
        ScalarValue::iter_to_array(all_keys.into_iter())?
    };

    let values_array = if all_values.is_empty() {
        datafusion::arrow::array::new_empty_array(&value_type)
    } else {
        ScalarValue::iter_to_array(all_values.into_iter())?
    };

    let struct_array = StructArray::new(
        Fields::from(vec![
            Field::new("key", key_type.clone(), false),
            Field::new("value", value_type.clone(), true),
        ]),
        vec![keys_array, values_array],
        None,
    );

    // Create null buffer; an all-valid bitmap is represented as None.
    let null_buffer = if nulls.iter().all(|&b| b) {
        None
    } else {
        Some(datafusion::arrow::buffer::NullBuffer::from(nulls))
    };

    // Create map array
    let map_array = MapArray::new(
        Arc::new(Field::new(
            "entries",
            struct_array.data_type().clone(),
            false,
        )),
        OffsetBuffer::new(offsets.into()),
        struct_array,
        null_buffer,
        false,
    );

    Ok(Arc::new(map_array))
}

/// Get the shared `map_from_entries` [`ScalarUDF`], creating it on first use.
pub fn map_from_entries_udf() -> ScalarUDF {
    // Function-local static keeps the cache scoped to its only user.
    static UDF: OnceLock<ScalarUDF> = OnceLock::new();
    UDF.get_or_init(|| ScalarUDF::new_from_impl(MapFromEntriesUdf::new()))
        .clone()
}