use std::sync::Arc;

use arrow::compute::{cast, cast_with_options, CastOptions};
use arrow_array::{cast::AsArray, ArrayRef};
use arrow_schema::DataType;
use datafusion_common::ScalarValue;
pub fn safe_coerce_scalar(value: &ScalarValue, ty: &DataType) -> Option<ScalarValue> {
match value {
ScalarValue::Int8(val) => match ty {
DataType::Int8 => Some(value.clone()),
DataType::Int16 => val.map(|v| ScalarValue::Int16(Some(i16::from(v)))),
DataType::Int32 => val.map(|v| ScalarValue::Int32(Some(i32::from(v)))),
DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(i64::from(v)))),
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => {
val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
}
DataType::UInt32 => {
val.and_then(|v| u32::try_from(v).map(|v| ScalarValue::UInt32(Some(v))).ok())
}
DataType::UInt64 => {
val.and_then(|v| u64::try_from(v).map(|v| ScalarValue::UInt64(Some(v))).ok())
}
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(f32::from(v)))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(f64::from(v)))),
_ => None,
},
ScalarValue::Int16(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => Some(value.clone()),
DataType::Int32 => val.map(|v| ScalarValue::Int32(Some(i32::from(v)))),
DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(i64::from(v)))),
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => {
val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
}
DataType::UInt32 => {
val.and_then(|v| u32::try_from(v).map(|v| ScalarValue::UInt32(Some(v))).ok())
}
DataType::UInt64 => {
val.and_then(|v| u64::try_from(v).map(|v| ScalarValue::UInt64(Some(v))).ok())
}
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(f32::from(v)))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(f64::from(v)))),
_ => None,
},
ScalarValue::Int32(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => {
val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
}
DataType::Int32 => Some(value.clone()),
DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(i64::from(v)))),
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => {
val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
}
DataType::UInt32 => {
val.and_then(|v| u32::try_from(v).map(|v| ScalarValue::UInt32(Some(v))).ok())
}
DataType::UInt64 => {
val.and_then(|v| u64::try_from(v).map(|v| ScalarValue::UInt64(Some(v))).ok())
}
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(v as f64))),
_ => None,
},
ScalarValue::Int64(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => {
val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
}
DataType::Int32 => {
val.and_then(|v| i32::try_from(v).map(|v| ScalarValue::Int32(Some(v))).ok())
}
DataType::Int64 => Some(value.clone()),
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => {
val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
}
DataType::UInt32 => {
val.and_then(|v| u32::try_from(v).map(|v| ScalarValue::UInt32(Some(v))).ok())
}
DataType::UInt64 => {
val.and_then(|v| u64::try_from(v).map(|v| ScalarValue::UInt64(Some(v))).ok())
}
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(v as f64))),
_ => None,
},
ScalarValue::UInt8(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => val.map(|v| ScalarValue::Int16(Some(v.into()))),
DataType::Int32 => val.map(|v| ScalarValue::Int32(Some(v.into()))),
DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(v.into()))),
DataType::UInt8 => Some(value.clone()),
DataType::UInt16 => val.map(|v| ScalarValue::UInt16(Some(u16::from(v)))),
DataType::UInt32 => val.map(|v| ScalarValue::UInt32(Some(u32::from(v)))),
DataType::UInt64 => val.map(|v| ScalarValue::UInt64(Some(u64::from(v)))),
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(f32::from(v)))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(f64::from(v)))),
_ => None,
},
ScalarValue::UInt16(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => {
val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
}
DataType::Int32 => val.map(|v| ScalarValue::Int32(Some(v.into()))),
DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(v.into()))),
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => Some(value.clone()),
DataType::UInt32 => val.map(|v| ScalarValue::UInt32(Some(u32::from(v)))),
DataType::UInt64 => val.map(|v| ScalarValue::UInt64(Some(u64::from(v)))),
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(f32::from(v)))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(f64::from(v)))),
_ => None,
},
ScalarValue::UInt32(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => {
val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
}
DataType::Int32 => {
val.and_then(|v| i32::try_from(v).map(|v| ScalarValue::Int32(Some(v))).ok())
}
DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(v.into()))),
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => {
val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
}
DataType::UInt32 => Some(value.clone()),
DataType::UInt64 => val.map(|v| ScalarValue::UInt64(Some(u64::from(v)))),
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(v as f64))),
_ => None,
},
ScalarValue::UInt64(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => {
val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
}
DataType::Int32 => {
val.and_then(|v| i32::try_from(v).map(|v| ScalarValue::Int32(Some(v))).ok())
}
DataType::Int64 => {
val.and_then(|v| i64::try_from(v).map(|v| ScalarValue::Int64(Some(v))).ok())
}
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => {
val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
}
DataType::UInt32 => {
val.and_then(|v| u32::try_from(v).map(|v| ScalarValue::UInt32(Some(v))).ok())
}
DataType::UInt64 => Some(value.clone()),
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(v as f64))),
_ => None,
},
ScalarValue::Float32(val) => match ty {
DataType::Float32 => Some(value.clone()),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(f64::from(v)))),
_ => None,
},
ScalarValue::Float64(val) => match ty {
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
DataType::Float64 => Some(value.clone()),
_ => None,
},
ScalarValue::Utf8(val) => match ty {
DataType::Utf8 => Some(value.clone()),
DataType::LargeUtf8 => Some(ScalarValue::LargeUtf8(val.clone())),
_ => None,
},
ScalarValue::Boolean(_) => match ty {
DataType::Boolean => Some(value.clone()),
_ => None,
},
ScalarValue::Null => Some(value.clone()),
ScalarValue::List(values) => {
let values = values.clone() as ArrayRef;
let new_values = cast(&values, ty).ok()?;
match ty {
DataType::List(_) => {
Some(ScalarValue::List(Arc::new(new_values.as_list().clone())))
}
DataType::LargeList(_) => Some(ScalarValue::LargeList(Arc::new(
new_values.as_list().clone(),
))),
DataType::FixedSizeList(_, _) => Some(ScalarValue::FixedSizeList(Arc::new(
new_values.as_fixed_size_list().clone(),
))),
_ => None,
}
}
ScalarValue::LargeList(values) => {
let values = values.clone() as ArrayRef;
let new_values = cast(&values, ty).ok()?;
match ty {
DataType::List(_) => {
Some(ScalarValue::List(Arc::new(new_values.as_list().clone())))
}
DataType::LargeList(_) => Some(ScalarValue::LargeList(Arc::new(
new_values.as_list().clone(),
))),
DataType::FixedSizeList(_, _) => Some(ScalarValue::FixedSizeList(Arc::new(
new_values.as_fixed_size_list().clone(),
))),
_ => None,
}
}
ScalarValue::FixedSizeList(values) => {
let values = values.clone() as ArrayRef;
let new_values = cast(&values, ty).ok()?;
match ty {
DataType::List(_) => {
Some(ScalarValue::List(Arc::new(new_values.as_list().clone())))
}
DataType::LargeList(_) => Some(ScalarValue::LargeList(Arc::new(
new_values.as_list().clone(),
))),
DataType::FixedSizeList(_, _) => Some(ScalarValue::FixedSizeList(Arc::new(
new_values.as_fixed_size_list().clone(),
))),
_ => None,
}
}
_ => None,
}
}