use std::sync::Arc;

use arrow::compute::cast;
use arrow_array::{cast::AsArray, Array, FixedSizeListArray};
use arrow_schema::DataType;
use datafusion_common::ScalarValue;
pub fn safe_coerce_scalar(value: &ScalarValue, ty: &DataType) -> Option<ScalarValue> {
match value {
ScalarValue::Int8(val) => match ty {
DataType::Int8 => Some(value.clone()),
DataType::Int16 => val.map(|v| ScalarValue::Int16(Some(i16::from(v)))),
DataType::Int32 => val.map(|v| ScalarValue::Int32(Some(i32::from(v)))),
DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(i64::from(v)))),
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => {
val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
}
DataType::UInt32 => {
val.and_then(|v| u32::try_from(v).map(|v| ScalarValue::UInt32(Some(v))).ok())
}
DataType::UInt64 => {
val.and_then(|v| u64::try_from(v).map(|v| ScalarValue::UInt64(Some(v))).ok())
}
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(f32::from(v)))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(f64::from(v)))),
_ => None,
},
ScalarValue::Int16(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => Some(value.clone()),
DataType::Int32 => val.map(|v| ScalarValue::Int32(Some(i32::from(v)))),
DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(i64::from(v)))),
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => {
val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
}
DataType::UInt32 => {
val.and_then(|v| u32::try_from(v).map(|v| ScalarValue::UInt32(Some(v))).ok())
}
DataType::UInt64 => {
val.and_then(|v| u64::try_from(v).map(|v| ScalarValue::UInt64(Some(v))).ok())
}
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(f32::from(v)))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(f64::from(v)))),
_ => None,
},
ScalarValue::Int32(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => {
val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
}
DataType::Int32 => Some(value.clone()),
DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(i64::from(v)))),
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => {
val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
}
DataType::UInt32 => {
val.and_then(|v| u32::try_from(v).map(|v| ScalarValue::UInt32(Some(v))).ok())
}
DataType::UInt64 => {
val.and_then(|v| u64::try_from(v).map(|v| ScalarValue::UInt64(Some(v))).ok())
}
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(v as f64))),
_ => None,
},
ScalarValue::Int64(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => {
val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
}
DataType::Int32 => {
val.and_then(|v| i32::try_from(v).map(|v| ScalarValue::Int32(Some(v))).ok())
}
DataType::Int64 => Some(value.clone()),
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => {
val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
}
DataType::UInt32 => {
val.and_then(|v| u32::try_from(v).map(|v| ScalarValue::UInt32(Some(v))).ok())
}
DataType::UInt64 => {
val.and_then(|v| u64::try_from(v).map(|v| ScalarValue::UInt64(Some(v))).ok())
}
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(v as f64))),
_ => None,
},
ScalarValue::UInt8(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => val.map(|v| ScalarValue::Int16(Some(v.into()))),
DataType::Int32 => val.map(|v| ScalarValue::Int32(Some(v.into()))),
DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(v.into()))),
DataType::UInt8 => Some(value.clone()),
DataType::UInt16 => val.map(|v| ScalarValue::UInt16(Some(u16::from(v)))),
DataType::UInt32 => val.map(|v| ScalarValue::UInt32(Some(u32::from(v)))),
DataType::UInt64 => val.map(|v| ScalarValue::UInt64(Some(u64::from(v)))),
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(f32::from(v)))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(f64::from(v)))),
_ => None,
},
ScalarValue::UInt16(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => {
val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
}
DataType::Int32 => val.map(|v| ScalarValue::Int32(Some(v.into()))),
DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(v.into()))),
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => Some(value.clone()),
DataType::UInt32 => val.map(|v| ScalarValue::UInt32(Some(u32::from(v)))),
DataType::UInt64 => val.map(|v| ScalarValue::UInt64(Some(u64::from(v)))),
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(f32::from(v)))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(f64::from(v)))),
_ => None,
},
ScalarValue::UInt32(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => {
val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
}
DataType::Int32 => {
val.and_then(|v| i32::try_from(v).map(|v| ScalarValue::Int32(Some(v))).ok())
}
DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(v.into()))),
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => {
val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
}
DataType::UInt32 => Some(value.clone()),
DataType::UInt64 => val.map(|v| ScalarValue::UInt64(Some(u64::from(v)))),
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(v as f64))),
_ => None,
},
ScalarValue::UInt64(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => {
val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
}
DataType::Int32 => {
val.and_then(|v| i32::try_from(v).map(|v| ScalarValue::Int32(Some(v))).ok())
}
DataType::Int64 => {
val.and_then(|v| i64::try_from(v).map(|v| ScalarValue::Int64(Some(v))).ok())
}
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => {
val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
}
DataType::UInt32 => {
val.and_then(|v| u32::try_from(v).map(|v| ScalarValue::UInt32(Some(v))).ok())
}
DataType::UInt64 => Some(value.clone()),
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(v as f64))),
_ => None,
},
ScalarValue::Float32(val) => match ty {
DataType::Float32 => Some(value.clone()),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(f64::from(v)))),
_ => None,
},
ScalarValue::Float64(val) => match ty {
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
DataType::Float64 => Some(value.clone()),
_ => None,
},
ScalarValue::Utf8(val) => match ty {
DataType::Utf8 => Some(value.clone()),
DataType::LargeUtf8 => Some(ScalarValue::LargeUtf8(val.clone())),
_ => None,
},
ScalarValue::Boolean(_) => match ty {
DataType::Boolean => Some(value.clone()),
_ => None,
},
ScalarValue::Null => Some(value.clone()),
ScalarValue::List(values) if matches!(ty, DataType::FixedSizeList(_, _)) => {
let values = values.as_list::<i32>().values();
if let DataType::FixedSizeList(field, size) = ty {
let new_values = cast(values, field.data_type()).ok()?;
Some(ScalarValue::FixedSizeList(Arc::new(
FixedSizeListArray::new(field.clone(), *size, new_values, None),
)))
} else {
unreachable!()
}
}
ScalarValue::List(values) | ScalarValue::FixedSizeList(values) => {
let new_values = cast(values, ty).ok()?;
match ty {
DataType::List(_) => Some(ScalarValue::List(new_values)),
DataType::LargeList(_) => Some(ScalarValue::LargeList(new_values)),
DataType::FixedSizeList(_, _) => Some(ScalarValue::FixedSizeList(new_values)),
_ => None,
}
}
_ => None,
}
}