1use std::fmt::{Debug, Display};
7use std::iter::Sum;
8use std::sync::Arc;
9use std::{
10 fmt::Formatter,
11 ops::{AddAssign, DivAssign},
12};
13
14use arrow_array::{
15 types::{Float16Type, Float32Type, Float64Type},
16 Array, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array,
17};
18use arrow_schema::{DataType, Field};
19use half::{bf16, f16};
20use num_traits::{AsPrimitive, Bounded, Float, FromPrimitive};
21
22use super::bfloat16::{BFloat16Array, BFloat16Type};
23use crate::bfloat16::is_bfloat16_field;
24use crate::Result;
25
/// Floating-point element types handled by this module.
///
/// `BFloat16` is stored in Arrow as `FixedSizeBinary(2)`, so it can only be
/// recognised from a `Field` (see the `TryFrom<&Field>` conversion), never
/// from a bare `DataType`.
///
/// The type is a plain fieldless tag, so it is `Copy` and can be compared and
/// hashed freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FloatType {
    /// bfloat16 (`half::bf16`).
    BFloat16,
    /// Half precision (`half::f16`).
    Float16,
    /// Single precision (`f32`).
    Float32,
    /// Double precision (`f64`).
    Float64,
}
37
38impl std::fmt::Display for FloatType {
39 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
40 match self {
41 Self::BFloat16 => write!(f, "bfloat16"),
42 Self::Float16 => write!(f, "float16"),
43 Self::Float32 => write!(f, "float32"),
44 Self::Float64 => write!(f, "float64"),
45 }
46 }
47}
48
49impl TryFrom<&DataType> for FloatType {
52 type Error = crate::ArrowError;
53
54 fn try_from(value: &DataType) -> Result<Self> {
55 match *value {
56 DataType::Float16 => Ok(Self::Float16),
57 DataType::Float32 => Ok(Self::Float32),
58 DataType::Float64 => Ok(Self::Float64),
59 _ => Err(crate::ArrowError::InvalidArgumentError(format!(
60 "{:?} is not a floating type",
61 value
62 ))),
63 }
64 }
65}
66
67impl TryFrom<&Field> for FloatType {
68 type Error = crate::ArrowError;
69
70 fn try_from(field: &Field) -> Result<Self> {
71 match field.data_type() {
72 DataType::FixedSizeBinary(2) if is_bfloat16_field(field) => Ok(Self::BFloat16),
73 _ => Self::try_from(field.data_type()),
74 }
75 }
76}
77
78pub trait ArrowFloatType: Debug {
83 type Native: FromPrimitive
84 + FloatToArrayType<ArrowType = Self>
85 + AsPrimitive<f32>
86 + Debug
87 + Display;
88
89 const FLOAT_TYPE: FloatType;
90 const MIN: Self::Native;
91 const MAX: Self::Native;
92
93 type ArrayType: FloatArray<Self>;
95
96 fn empty_array() -> Self::ArrayType {
98 <Self::ArrayType as FloatArray<Self>>::from_values(Vec::new())
99 }
100}
101
102pub trait FloatToArrayType:
108 Float
109 + Bounded
110 + Sum
111 + AddAssign<Self>
112 + AsPrimitive<f64>
113 + AsPrimitive<f32>
114 + DivAssign
115 + Send
116 + Sync
117 + Copy
118{
119 type ArrowType: ArrowFloatType<Native = Self>;
121}
122
123impl FloatToArrayType for bf16 {
124 type ArrowType = BFloat16Type;
125}
126
127impl FloatToArrayType for f16 {
128 type ArrowType = Float16Type;
129}
130
131impl FloatToArrayType for f32 {
132 type ArrowType = Float32Type;
133}
134
135impl FloatToArrayType for f64 {
136 type ArrowType = Float64Type;
137}
138
139impl ArrowFloatType for BFloat16Type {
140 type Native = bf16;
141
142 const FLOAT_TYPE: FloatType = FloatType::BFloat16;
143 const MIN: Self::Native = bf16::MIN;
144 const MAX: Self::Native = bf16::MAX;
145
146 type ArrayType = FixedSizeBinaryArray;
147}
148
149impl ArrowFloatType for Float16Type {
150 type Native = f16;
151
152 const FLOAT_TYPE: FloatType = FloatType::Float16;
153 const MIN: Self::Native = f16::MIN;
154 const MAX: Self::Native = f16::MAX;
155
156 type ArrayType = Float16Array;
157}
158
159impl ArrowFloatType for Float32Type {
160 type Native = f32;
161
162 const FLOAT_TYPE: FloatType = FloatType::Float32;
163 const MIN: Self::Native = f32::MIN;
164 const MAX: Self::Native = f32::MAX;
165
166 type ArrayType = Float32Array;
167}
168
169impl ArrowFloatType for Float64Type {
170 type Native = f64;
171
172 const FLOAT_TYPE: FloatType = FloatType::Float64;
173 const MIN: Self::Native = f64::MIN;
174 const MAX: Self::Native = f64::MAX;
175
176 type ArrayType = Float64Array;
177}
178
179pub trait FloatArray<T: ArrowFloatType + ?Sized>: Array + Clone + 'static {
184 type FloatType: ArrowFloatType;
185
186 fn as_slice(&self) -> &[T::Native];
188
189 fn from_values(values: Vec<T::Native>) -> Self;
191
192 fn from_iter_values(values: impl IntoIterator<Item = T::Native>) -> Self
194 where
195 Self: Sized,
196 {
197 Self::from_values(values.into_iter().collect())
198 }
199}
200
201impl FloatArray<Float16Type> for Float16Array {
202 type FloatType = Float16Type;
203
204 fn as_slice(&self) -> &[<Float16Type as ArrowFloatType>::Native] {
205 self.values()
206 }
207
208 fn from_values(values: Vec<<Float16Type as ArrowFloatType>::Native>) -> Self {
209 Self::from(values)
210 }
211}
212
213impl FloatArray<Float32Type> for Float32Array {
214 type FloatType = Float32Type;
215
216 fn as_slice(&self) -> &[<Float32Type as ArrowFloatType>::Native] {
217 self.values()
218 }
219
220 fn from_values(values: Vec<<Float32Type as ArrowFloatType>::Native>) -> Self {
221 Self::from(values)
222 }
223}
224
225impl FloatArray<Float64Type> for Float64Array {
226 type FloatType = Float64Type;
227
228 fn as_slice(&self) -> &[<Float64Type as ArrowFloatType>::Native] {
229 self.values()
230 }
231
232 fn from_values(values: Vec<<Float64Type as ArrowFloatType>::Native>) -> Self {
233 Self::from(values)
234 }
235}
236
237pub fn coerce_float_vector(input: &Float32Array, float_type: FloatType) -> Result<Arc<dyn Array>> {
242 match float_type {
243 FloatType::BFloat16 => Ok(Arc::new(
244 BFloat16Array::from_iter_values(input.values().iter().map(|v| bf16::from_f32(*v)))
245 .into_inner(),
246 )),
247 FloatType::Float16 => Ok(Arc::new(Float16Array::from_iter_values(
248 input.values().iter().map(|v| f16::from_f32(*v)),
249 ))),
250 FloatType::Float32 => Ok(Arc::new(input.clone())),
251 FloatType::Float64 => Ok(Arc::new(Float64Array::from_iter_values(
252 input.values().iter().map(|v| *v as f64),
253 ))),
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_coerce_float_vector_bfloat16() {
        let input = Float32Array::from(vec![1.0f32, 2.0, 3.0]);
        let array = coerce_float_vector(&input, FloatType::BFloat16).unwrap();

        assert_eq!(array.data_type(), &DataType::FixedSizeBinary(2));

        let fixed = array
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let expected: Vec<bf16> = input.values().iter().map(|v| bf16::from_f32(*v)).collect();
        assert_eq!(fixed.as_slice(), expected.as_slice());
    }

    #[test]
    fn test_coerce_float_vector_float16() {
        let input = Float32Array::from(vec![1.0f32, 2.5, -3.0]);
        let array = coerce_float_vector(&input, FloatType::Float16).unwrap();

        assert_eq!(array.data_type(), &DataType::Float16);

        let out = array.as_any().downcast_ref::<Float16Array>().unwrap();
        let expected: Vec<f16> = input.values().iter().map(|v| f16::from_f32(*v)).collect();
        assert_eq!(&out.values()[..], expected.as_slice());
    }

    #[test]
    fn test_coerce_float_vector_float32_is_identity() {
        let input = Float32Array::from(vec![1.0f32, 2.0, 3.0]);
        let array = coerce_float_vector(&input, FloatType::Float32).unwrap();

        assert_eq!(array.data_type(), &DataType::Float32);

        let out = array.as_any().downcast_ref::<Float32Array>().unwrap();
        assert_eq!(out, &input);
    }

    #[test]
    fn test_coerce_float_vector_float64() {
        let input = Float32Array::from(vec![1.5f32, -2.0, 0.25]);
        let array = coerce_float_vector(&input, FloatType::Float64).unwrap();

        assert_eq!(array.data_type(), &DataType::Float64);

        let out = array.as_any().downcast_ref::<Float64Array>().unwrap();
        assert_eq!(&out.values()[..], &[1.5f64, -2.0, 0.25][..]);
    }

    #[test]
    fn test_float_type_from_data_type() {
        assert!(matches!(
            FloatType::try_from(&DataType::Float16),
            Ok(FloatType::Float16)
        ));
        assert!(matches!(
            FloatType::try_from(&DataType::Float32),
            Ok(FloatType::Float32)
        ));
        assert!(matches!(
            FloatType::try_from(&DataType::Float64),
            Ok(FloatType::Float64)
        ));
        assert!(FloatType::try_from(&DataType::Int32).is_err());
    }

    #[test]
    fn test_float_type_from_field() {
        let field = Field::new("v", DataType::Float64, true);
        assert!(matches!(
            FloatType::try_from(&field),
            Ok(FloatType::Float64)
        ));
    }

    #[test]
    fn test_float_type_display() {
        assert_eq!(FloatType::BFloat16.to_string(), "bfloat16");
        assert_eq!(FloatType::Float16.to_string(), "float16");
        assert_eq!(FloatType::Float32.to_string(), "float32");
        assert_eq!(FloatType::Float64.to_string(), "float64");
    }

    #[test]
    fn test_empty_array() {
        assert!(Float16Type::empty_array().is_empty());
        assert!(Float32Type::empty_array().is_empty());
        assert!(Float64Type::empty_array().is_empty());
    }

    #[test]
    fn test_from_iter_values() {
        let arr = <Float64Array as FloatArray<Float64Type>>::from_iter_values(vec![1.0, 2.0]);
        assert_eq!(
            <Float64Array as FloatArray<Float64Type>>::as_slice(&arr),
            &[1.0f64, 2.0][..]
        );
    }
}