use std::iter::Sum;
use std::{
fmt::Formatter,
ops::{AddAssign, DivAssign},
};
use arrow_array::{
types::{Float16Type, Float32Type, Float64Type},
Array, Float16Array, Float32Array, Float64Array,
};
use half::{bf16, f16};
use num_traits::{AsPrimitive, Bounded, Float, FromPrimitive};
use super::bfloat16::{BFloat16Array, BFloat16Type};
/// Runtime tag identifying which floating-point element type is in use.
///
/// Its `Display` form is the lowercase type name: `"bfloat16"`, `"float16"`,
/// `"float32"` or `"float64"`.
// Unit-only enum: deriving `Copy`/`Eq`/`Hash` costs nothing and lets callers
// compare tags (e.g. `T::FLOAT_TYPE == FloatType::Float32`), pass them by value
// and use them as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FloatType {
    /// 16-bit brain float; native type `bf16`.
    BFloat16,
    /// 16-bit half-precision float; native type `f16`.
    Float16,
    /// 32-bit float; native type `f32`.
    Float32,
    /// 64-bit float; native type `f64`.
    Float64,
}

impl std::fmt::Display for FloatType {
    /// Writes the lowercase type name (e.g. `float32`).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::BFloat16 => "bfloat16",
            Self::Float16 => "float16",
            Self::Float32 => "float32",
            Self::Float64 => "float64",
        };
        f.write_str(name)
    }
}
/// An Arrow floating-point data type, linked to its native scalar and its
/// concrete array type.
pub trait ArrowFloatType {
    /// Native Rust scalar stored in arrays of this type. Its own
    /// `FloatToArrayType::ArrowType` must point back at `Self`.
    type Native: FromPrimitive + FloatToArrayType<ArrowType = Self>;

    /// Runtime tag naming this float type.
    const FLOAT_TYPE: FloatType;

    /// Concrete Arrow array that holds values of this type.
    type ArrayType: FloatArray<Self>;

    /// Builds an array that contains no elements.
    fn empty_array() -> Self::ArrayType {
        // `FloatArray<Self>` has `From<Vec<Self::Native>>` as a supertrait,
        // so an empty native vector converts directly.
        <Self::ArrayType as From<Vec<Self::Native>>>::from(Vec::new())
    }
}
/// A native float scalar that knows its Arrow counterpart.
///
/// The bounds collect the arithmetic needed by generic float code: summing,
/// in-place add/divide, and conversion to `f64`. The value must also be
/// shareable across threads. This module implements it for `bf16`, `f16`,
/// `f32` and `f64`.
pub trait FloatToArrayType
where
    Self: Float
        + Bounded
        + Sum
        + AddAssign
        + DivAssign
        + AsPrimitive<f64>
        + Send
        + Sync
        + Copy,
{
    /// The Arrow type whose `Native` scalar is `Self`.
    type ArrowType: ArrowFloatType<Native = Self>;
}
impl FloatToArrayType for bf16 {
type ArrowType = BFloat16Type;
}
impl FloatToArrayType for f16 {
type ArrowType = Float16Type;
}
impl FloatToArrayType for f32 {
type ArrowType = Float32Type;
}
impl FloatToArrayType for f64 {
type ArrowType = Float64Type;
}
impl ArrowFloatType for BFloat16Type {
type Native = bf16;
const FLOAT_TYPE: FloatType = FloatType::BFloat16;
type ArrayType = BFloat16Array;
}
impl ArrowFloatType for Float16Type {
type Native = f16;
const FLOAT_TYPE: FloatType = FloatType::Float16;
type ArrayType = Float16Array;
}
impl ArrowFloatType for Float32Type {
type Native = f32;
const FLOAT_TYPE: FloatType = FloatType::Float32;
type ArrayType = Float32Array;
}
impl ArrowFloatType for Float64Type {
type Native = f64;
const FLOAT_TYPE: FloatType = FloatType::Float64;
type ArrayType = Float64Array;
}
/// An Arrow array of floats that can be built from, and viewed as, native
/// values of `T::Native`.
pub trait FloatArray<T>: Array + Clone + From<Vec<<T as ArrowFloatType>::Native>> + 'static
where
    T: ArrowFloatType + ?Sized,
{
    /// The float type stored in this array. Every implementation in this
    /// module sets it to `T`, though the trait itself does not require that.
    type FloatType: ArrowFloatType;

    /// Borrows the array's values as a slice of native floats.
    fn as_slice(&self) -> &[T::Native];
}
/// `FloatArray` implementation for [`BFloat16Array`].
///
/// # Panics
///
/// `as_slice` is not implemented yet: calling it panics via `todo!()`.
impl FloatArray<BFloat16Type> for BFloat16Array {
    type FloatType = BFloat16Type;
    // NOTE(review): stub. Generic code that reaches `as_slice` through
    // `BFloat16Type::ArrayType` will panic at runtime. A real implementation
    // presumably needs the raw value buffer of `BFloat16Array` (defined in
    // `super::bfloat16`), which is not visible here. TODO implement or
    // document the restriction for callers.
    fn as_slice(&self) -> &[<BFloat16Type as ArrowFloatType>::Native] {
        todo!()
    }
}
impl FloatArray<Float16Type> for Float16Array {
type FloatType = Float16Type;
fn as_slice(&self) -> &[<Float16Type as ArrowFloatType>::Native] {
self.values()
}
}
impl FloatArray<Float32Type> for Float32Array {
type FloatType = Float32Type;
fn as_slice(&self) -> &[<Float32Type as ArrowFloatType>::Native] {
self.values()
}
}
impl FloatArray<Float64Type> for Float64Array {
type FloatType = Float64Type;
fn as_slice(&self) -> &[<Float64Type as ArrowFloatType>::Native] {
self.values()
}
}