use arrow_schema::DataType;
use datafusion_common::ScalarValue;
pub fn safe_coerce_scalar(value: &ScalarValue, ty: &DataType) -> Option<ScalarValue> {
match value {
ScalarValue::Int8(val) => match ty {
DataType::Int8 => Some(value.clone()),
DataType::Int16 => val.map(|v| ScalarValue::Int16(Some(i16::from(v)))),
DataType::Int32 => val.map(|v| ScalarValue::Int32(Some(i32::from(v)))),
DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(i64::from(v)))),
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => {
val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
}
DataType::UInt32 => {
val.and_then(|v| u32::try_from(v).map(|v| ScalarValue::UInt32(Some(v))).ok())
}
DataType::UInt64 => {
val.and_then(|v| u64::try_from(v).map(|v| ScalarValue::UInt64(Some(v))).ok())
}
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(f32::from(v)))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(f64::from(v)))),
_ => None,
},
ScalarValue::Int16(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => Some(value.clone()),
DataType::Int32 => val.map(|v| ScalarValue::Int32(Some(i32::from(v)))),
DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(i64::from(v)))),
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => {
val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
}
DataType::UInt32 => {
val.and_then(|v| u32::try_from(v).map(|v| ScalarValue::UInt32(Some(v))).ok())
}
DataType::UInt64 => {
val.and_then(|v| u64::try_from(v).map(|v| ScalarValue::UInt64(Some(v))).ok())
}
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(f32::from(v)))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(f64::from(v)))),
_ => None,
},
ScalarValue::Int32(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => {
val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
}
DataType::Int32 => Some(value.clone()),
DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(i64::from(v)))),
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => {
val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
}
DataType::UInt32 => {
val.and_then(|v| u32::try_from(v).map(|v| ScalarValue::UInt32(Some(v))).ok())
}
DataType::UInt64 => {
val.and_then(|v| u64::try_from(v).map(|v| ScalarValue::UInt64(Some(v))).ok())
}
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(v as f64))),
_ => None,
},
ScalarValue::Int64(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => {
val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
}
DataType::Int32 => {
val.and_then(|v| i32::try_from(v).map(|v| ScalarValue::Int32(Some(v))).ok())
}
DataType::Int64 => Some(value.clone()),
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => {
val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
}
DataType::UInt32 => {
val.and_then(|v| u32::try_from(v).map(|v| ScalarValue::UInt32(Some(v))).ok())
}
DataType::UInt64 => {
val.and_then(|v| u64::try_from(v).map(|v| ScalarValue::UInt64(Some(v))).ok())
}
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(v as f64))),
_ => None,
},
ScalarValue::UInt8(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => {
val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
}
DataType::Int32 => {
val.and_then(|v| i32::try_from(v).map(|v| ScalarValue::Int32(Some(v))).ok())
}
DataType::Int64 => {
val.and_then(|v| i64::try_from(v).map(|v| ScalarValue::Int64(Some(v))).ok())
}
DataType::UInt8 => Some(value.clone()),
DataType::UInt16 => val.map(|v| ScalarValue::UInt16(Some(u16::from(v)))),
DataType::UInt32 => val.map(|v| ScalarValue::UInt32(Some(u32::from(v)))),
DataType::UInt64 => val.map(|v| ScalarValue::UInt64(Some(u64::from(v)))),
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(f32::from(v)))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(f64::from(v)))),
_ => None,
},
ScalarValue::UInt16(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => {
val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
}
DataType::Int32 => {
val.and_then(|v| i32::try_from(v).map(|v| ScalarValue::Int32(Some(v))).ok())
}
DataType::Int64 => {
val.and_then(|v| i64::try_from(v).map(|v| ScalarValue::Int64(Some(v))).ok())
}
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => Some(value.clone()),
DataType::UInt32 => val.map(|v| ScalarValue::UInt32(Some(u32::from(v)))),
DataType::UInt64 => val.map(|v| ScalarValue::UInt64(Some(u64::from(v)))),
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(f32::from(v)))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(f64::from(v)))),
_ => None,
},
ScalarValue::UInt32(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => {
val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
}
DataType::Int32 => {
val.and_then(|v| i32::try_from(v).map(|v| ScalarValue::Int32(Some(v))).ok())
}
DataType::Int64 => {
val.and_then(|v| i64::try_from(v).map(|v| ScalarValue::Int64(Some(v))).ok())
}
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => {
val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
}
DataType::UInt32 => Some(value.clone()),
DataType::UInt64 => val.map(|v| ScalarValue::UInt64(Some(u64::from(v)))),
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(v as f64))),
_ => None,
},
ScalarValue::UInt64(val) => match ty {
DataType::Int8 => {
val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
}
DataType::Int16 => {
val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
}
DataType::Int32 => {
val.and_then(|v| i32::try_from(v).map(|v| ScalarValue::Int32(Some(v))).ok())
}
DataType::Int64 => {
val.and_then(|v| i64::try_from(v).map(|v| ScalarValue::Int64(Some(v))).ok())
}
DataType::UInt8 => {
val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
}
DataType::UInt16 => {
val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
}
DataType::UInt32 => {
val.and_then(|v| u32::try_from(v).map(|v| ScalarValue::UInt32(Some(v))).ok())
}
DataType::UInt64 => Some(value.clone()),
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(v as f64))),
_ => None,
},
ScalarValue::Float32(val) => match ty {
DataType::Float32 => Some(value.clone()),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(f64::from(v)))),
_ => None,
},
ScalarValue::Float64(val) => match ty {
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
DataType::Float64 => Some(value.clone()),
_ => None,
},
ScalarValue::Utf8(val) => match ty {
DataType::Utf8 => Some(value.clone()),
DataType::LargeUtf8 => Some(ScalarValue::LargeUtf8(val.clone())),
_ => None,
},
ScalarValue::Boolean(_) => match ty {
DataType::Boolean => Some(value.clone()),
_ => None,
},
ScalarValue::Null => Some(value.clone()),
ScalarValue::List(vals, _) => {
if let DataType::FixedSizeList(_, size) = ty {
if let Some(vals) = vals {
if vals.len() as i32 != *size {
return None;
}
}
}
let (new_values, field) = match ty {
DataType::List(field)
| DataType::LargeList(field)
| DataType::FixedSizeList(field, _) => {
if let Some(vals) = vals {
let values = vals
.iter()
.map(|val| safe_coerce_scalar(val, field.data_type()))
.collect::<Option<Vec<_>>>();
(values, field)
} else {
(None, field)
}
}
_ => return None,
};
match ty {
DataType::List(_) => Some(ScalarValue::List(new_values, field.clone())),
DataType::FixedSizeList(_, size) => {
Some(ScalarValue::Fixedsizelist(new_values, field.clone(), *size))
}
_ => None,
}
}
_ => None,
}
}