// datafusion_functions/string/concat_ws.rs
use arrow::array::{as_largestring_array, Array, StringArray};
use std::any::Any;
use std::sync::Arc;
use arrow::datatypes::DataType;
use crate::string::common::*;
use crate::string::concat::simplify_concat;
use crate::string::concat_ws;
use datafusion_common::cast::{as_string_array, as_string_view_array};
use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{lit, ColumnarValue, Expr, Volatility};
use datafusion_expr::{ScalarUDFImpl, Signature};
/// Scalar UDF implementation registered under the name `concat_ws`.
///
/// The first argument is the separator; the remaining arguments are the
/// values to join. See the [`ScalarUDFImpl`] implementation for the details.
#[derive(Debug)]
pub struct ConcatWsFunc {
    /// Signature handed out by [`ScalarUDFImpl::signature`].
    signature: Signature,
}
impl Default for ConcatWsFunc {
    fn default() -> Self {
        ConcatWsFunc::new()
    }
}
impl ConcatWsFunc {
    pub fn new() -> Self {
        use DataType::*;
        Self {
            signature: Signature::variadic(
                vec![Utf8View, Utf8, LargeUtf8],
                Volatility::Immutable,
            ),
        }
    }
}
impl ScalarUDFImpl for ConcatWsFunc {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        "concat_ws"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        use DataType::*;
        Ok(Utf8)
    }
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        if args.len() < 2 {
            return exec_err!(
                "concat_ws was called with {} arguments. It requires at least 2.",
                args.len()
            );
        }
        let array_len = args
            .iter()
            .filter_map(|x| match x {
                ColumnarValue::Array(array) => Some(array.len()),
                _ => None,
            })
            .next();
        if array_len.is_none() {
            let sep = match &args[0] {
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
                | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s)))
                | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => s,
                ColumnarValue::Scalar(ScalarValue::Utf8(None))
                | ColumnarValue::Scalar(ScalarValue::Utf8View(None))
                | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {
                    return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
                }
                _ => unreachable!(),
            };
            let mut result = String::new();
            let iter = &mut args[1..].iter();
            for arg in iter.by_ref() {
                match arg {
                    ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
                    | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s)))
                    | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => {
                        result.push_str(s);
                        break;
                    }
                    ColumnarValue::Scalar(ScalarValue::Utf8(None))
                    | ColumnarValue::Scalar(ScalarValue::Utf8View(None))
                    | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {}
                    _ => unreachable!(),
                }
            }
            for arg in iter.by_ref() {
                match arg {
                    ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
                    | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s)))
                    | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => {
                        result.push_str(sep);
                        result.push_str(s);
                    }
                    ColumnarValue::Scalar(ScalarValue::Utf8(None))
                    | ColumnarValue::Scalar(ScalarValue::Utf8View(None))
                    | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {}
                    _ => unreachable!(),
                }
            }
            return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))));
        }
        let len = array_len.unwrap();
        let mut data_size = 0;
        let sep = match &args[0] {
            ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => {
                data_size += s.len() * len * (args.len() - 2); ColumnarValueRef::Scalar(s.as_bytes())
            }
            ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {
                return Ok(ColumnarValue::Array(Arc::new(StringArray::new_null(len))));
            }
            ColumnarValue::Array(array) => {
                let string_array = as_string_array(array)?;
                data_size += string_array.values().len() * (args.len() - 2); if array.is_nullable() {
                    ColumnarValueRef::NullableArray(string_array)
                } else {
                    ColumnarValueRef::NonNullableArray(string_array)
                }
            }
            _ => unreachable!(),
        };
        let mut columns = Vec::with_capacity(args.len() - 1);
        for arg in &args[1..] {
            match arg {
                ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value))
                | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value))
                | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
                    if let Some(s) = maybe_value {
                        data_size += s.len() * len;
                        columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
                    }
                }
                ColumnarValue::Array(array) => {
                    match array.data_type() {
                        DataType::Utf8 => {
                            let string_array = as_string_array(array)?;
                            data_size += string_array.values().len();
                            let column = if array.is_nullable() {
                                ColumnarValueRef::NullableArray(string_array)
                            } else {
                                ColumnarValueRef::NonNullableArray(string_array)
                            };
                            columns.push(column);
                        },
                        DataType::LargeUtf8 => {
                            let string_array = as_largestring_array(array);
                            data_size += string_array.values().len();
                            let column = if array.is_nullable() {
                                ColumnarValueRef::NullableLargeStringArray(string_array)
                            } else {
                                ColumnarValueRef::NonNullableLargeStringArray(string_array)
                            };
                            columns.push(column);
                        },
                        DataType::Utf8View => {
                            let string_array = as_string_view_array(array)?;
                            data_size += string_array.data_buffers().iter().map(|buf| buf.len()).sum::<usize>();
                            let column = if array.is_nullable() {
                                ColumnarValueRef::NullableStringViewArray(string_array)
                            } else {
                                ColumnarValueRef::NonNullableStringViewArray(string_array)
                            };
                            columns.push(column);
                        },
                        other => {
                            return plan_err!("Input was {other} which is not a supported datatype for concat_ws function.")
                        }
                    };
                }
                _ => unreachable!(),
            }
        }
        let mut builder = StringArrayBuilder::with_capacity(len, data_size);
        for i in 0..len {
            if !sep.is_valid(i) {
                builder.append_offset();
                continue;
            }
            let mut iter = columns.iter();
            for column in iter.by_ref() {
                if column.is_valid(i) {
                    builder.write::<false>(column, i);
                    break;
                }
            }
            for column in iter {
                if column.is_valid(i) {
                    builder.write::<false>(&sep, i);
                    builder.write::<false>(column, i);
                }
            }
            builder.append_offset();
        }
        Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls()))))
    }
    fn simplify(
        &self,
        args: Vec<Expr>,
        _info: &dyn SimplifyInfo,
    ) -> Result<ExprSimplifyResult> {
        match &args[..] {
            [delimiter, vals @ ..] => simplify_concat_ws(delimiter, vals),
            _ => Ok(ExprSimplifyResult::Original(args)),
        }
    }
}
fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result<ExprSimplifyResult> {
    match delimiter {
        Expr::Literal(
            ScalarValue::Utf8(delimiter)
            | ScalarValue::LargeUtf8(delimiter)
            | ScalarValue::Utf8View(delimiter),
        ) => {
            match delimiter {
                Some(delimiter) if delimiter.is_empty() => simplify_concat(args.to_vec()),
                Some(delimiter) => {
                    let mut new_args = Vec::with_capacity(args.len());
                    new_args.push(lit(delimiter));
                    let mut contiguous_scalar = None;
                    for arg in args {
                        match arg {
                            Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None)) => {}
                            Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v))) => {
                                match contiguous_scalar {
                                    None => contiguous_scalar = Some(v.to_string()),
                                    Some(mut pre) => {
                                        pre += delimiter;
                                        pre += v;
                                        contiguous_scalar = Some(pre)
                                    }
                                }
                            }
                            Expr::Literal(s) => return internal_err!("The scalar {s} should be casted to string type during the type coercion."),
                            arg => {
                                if let Some(val) = contiguous_scalar {
                                    new_args.push(lit(val));
                                }
                                new_args.push(arg.clone());
                                contiguous_scalar = None;
                            }
                        }
                    }
                    if let Some(val) = contiguous_scalar {
                        new_args.push(lit(val));
                    }
                    Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
                        ScalarFunction {
                            func: concat_ws(),
                            args: new_args,
                        },
                    )))
                }
                None => Ok(ExprSimplifyResult::Simplified(Expr::Literal(
                    ScalarValue::Utf8(None),
                ))),
            }
        }
        Expr::Literal(d) => internal_err!(
            "The scalar {d} should be casted to string type during the type coercion."
        ),
        _ => {
            let mut args = args
                .iter()
                .filter(|&x| !is_null(x))
                .cloned()
                .collect::<Vec<Expr>>();
            args.insert(0, delimiter.clone());
            Ok(ExprSimplifyResult::Original(args))
        }
    }
}
/// Returns `true` only when `expr` is a literal whose value is null.
fn is_null(expr: &Expr) -> bool {
    matches!(expr, Expr::Literal(value) if value.is_null())
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use arrow::array::{Array, ArrayRef, StringArray};
    use arrow::datatypes::DataType::Utf8;
    use crate::string::concat_ws::ConcatWsFunc;
    use datafusion_common::Result;
    use datafusion_common::ScalarValue;
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
    use crate::utils::test::test_function;

    /// Non-null `Utf8` scalar argument.
    fn utf8(value: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::from(value))
    }

    /// Null `Utf8` scalar argument.
    fn utf8_null() -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(None))
    }

    /// Wraps a string array as an array argument.
    fn array_arg(values: StringArray) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(values))
    }

    /// Asserts that `result` is an array equal to `expected`.
    fn assert_array_result(result: &ColumnarValue, expected: &ArrayRef) {
        match result {
            ColumnarValue::Array(array) => {
                assert_eq!(expected, array);
            }
            _ => panic!(),
        }
    }

    /// Scalar-only invocations.
    #[test]
    fn test_functions() -> Result<()> {
        // All values present.
        test_function!(
            ConcatWsFunc::new(),
            &[utf8("|"), utf8("aa"), utf8("bb"), utf8("cc")],
            Ok(Some("aa|bb|cc")),
            &str,
            Utf8,
            StringArray
        );
        // Only a null value: the result is the empty string.
        test_function!(
            ConcatWsFunc::new(),
            &[utf8("|"), utf8_null()],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // Null separator: the result is null.
        test_function!(
            ConcatWsFunc::new(),
            &[utf8_null(), utf8("aa"), utf8("bb"), utf8("cc")],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // A null value in the middle is skipped.
        test_function!(
            ConcatWsFunc::new(),
            &[utf8("|"), utf8("aa"), utf8_null(), utf8("cc")],
            Ok(Some("aa|cc")),
            &str,
            Utf8,
            StringArray
        );
        Ok(())
    }

    /// Array invocations with a scalar and with an array separator.
    #[test]
    fn concat_ws() -> Result<()> {
        // Scalar separator, nullable second column.
        let separator = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string())));
        let left = array_arg(StringArray::from(vec!["foo", "bar", "baz"]));
        let right = array_arg(StringArray::from(vec![Some("x"), None, Some("z")]));
        let result = ConcatWsFunc::new().invoke(&[separator, left, right])?;
        let expected: ArrayRef =
            Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"]));
        assert_array_result(&result, &expected);

        // Array separator with a null entry: that row becomes null.
        let separator = array_arg(StringArray::from(vec![Some(","), None, Some("+")]));
        let left = array_arg(StringArray::from(vec!["foo", "bar", "baz"]));
        let right = array_arg(StringArray::from(vec![Some("x"), Some("y"), Some("z")]));
        let result = ConcatWsFunc::new().invoke(&[separator, left, right])?;
        let expected: ArrayRef =
            Arc::new(StringArray::from(vec![Some("foo,x"), None, Some("baz+z")]));
        assert_array_result(&result, &expected);

        Ok(())
    }
}