Skip to main content

lance_arrow/
floats.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
//! Float array abstractions: a common interface over Arrow `f16`, `f32`, `f64` and `bf16` arrays.
5
6use std::fmt::{Debug, Display};
7use std::iter::Sum;
8use std::sync::Arc;
9use std::{
10    fmt::Formatter,
11    ops::{AddAssign, DivAssign},
12};
13
14use arrow_array::{
15    types::{Float16Type, Float32Type, Float64Type},
16    Array, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array,
17};
18use arrow_schema::{DataType, Field};
19use half::{bf16, f16};
20use num_traits::{AsPrimitive, Bounded, Float, FromPrimitive};
21
22use super::bfloat16::{BFloat16Array, BFloat16Type};
23use crate::bfloat16::is_bfloat16_field;
24use crate::Result;
25
/// Float data type.
///
/// This helps differentiate between the different float types,
/// because bf16 is not an officially supported [DataType] in arrow-rs.
///
/// The enum is a plain tag with no payload, so it is `Copy` and can be
/// compared, hashed, and reused after being passed by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FloatType {
    /// 16-bit brain floating point (`bf16`); has no native Arrow [DataType].
    BFloat16,
    /// IEEE 754 half precision, i.e. [`DataType::Float16`].
    Float16,
    /// IEEE 754 single precision, i.e. [`DataType::Float32`].
    Float32,
    /// IEEE 754 double precision, i.e. [`DataType::Float64`].
    Float64,
}
37
38impl std::fmt::Display for FloatType {
39    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
40        match self {
41            Self::BFloat16 => write!(f, "bfloat16"),
42            Self::Float16 => write!(f, "float16"),
43            Self::Float32 => write!(f, "float32"),
44            Self::Float64 => write!(f, "float64"),
45        }
46    }
47}
48
49/// Try to convert a [DataType] to a [FloatType]. To support bfloat16, always
50/// prefer using the `TryFrom<&Field>` implementation.
51impl TryFrom<&DataType> for FloatType {
52    type Error = crate::ArrowError;
53
54    fn try_from(value: &DataType) -> Result<Self> {
55        match *value {
56            DataType::Float16 => Ok(Self::Float16),
57            DataType::Float32 => Ok(Self::Float32),
58            DataType::Float64 => Ok(Self::Float64),
59            _ => Err(crate::ArrowError::InvalidArgumentError(format!(
60                "{:?} is not a floating type",
61                value
62            ))),
63        }
64    }
65}
66
67impl TryFrom<&Field> for FloatType {
68    type Error = crate::ArrowError;
69
70    fn try_from(field: &Field) -> Result<Self> {
71        match field.data_type() {
72            DataType::FixedSizeBinary(2) if is_bfloat16_field(field) => Ok(Self::BFloat16),
73            _ => Self::try_from(field.data_type()),
74        }
75    }
76}
77
78/// Trait for float types used in Lance indexes
79///
80/// This mimics the utilities provided by [`arrow_array::ArrowPrimitiveType`]
81/// but applies to all float types (including bfloat16)
82pub trait ArrowFloatType: Debug {
83    type Native: FromPrimitive
84        + FloatToArrayType<ArrowType = Self>
85        + AsPrimitive<f32>
86        + Debug
87        + Display;
88
89    const FLOAT_TYPE: FloatType;
90    const MIN: Self::Native;
91    const MAX: Self::Native;
92
93    /// Arrow Float Array Type.
94    type ArrayType: FloatArray<Self>;
95
96    /// Returns empty array of this type.
97    fn empty_array() -> Self::ArrayType {
98        <Self::ArrayType as FloatArray<Self>>::from_values(Vec::new())
99    }
100}
101
102/// Trait to be implemented by native types that have a corresponding [`ArrowFloatType`]
103/// implementation.
104///
105/// This helps define what operations are supported by native floats and also helps convert
106/// from a native type back to the corresponding Arrow float type.
107pub trait FloatToArrayType:
108    Float
109    + Bounded
110    + Sum
111    + AddAssign<Self>
112    + AsPrimitive<f64>
113    + AsPrimitive<f32>
114    + DivAssign
115    + Send
116    + Sync
117    + Copy
118{
119    /// The corresponding [`ArrowFloatType`] implementation for this native type
120    type ArrowType: ArrowFloatType<Native = Self>;
121}
122
123impl FloatToArrayType for bf16 {
124    type ArrowType = BFloat16Type;
125}
126
127impl FloatToArrayType for f16 {
128    type ArrowType = Float16Type;
129}
130
131impl FloatToArrayType for f32 {
132    type ArrowType = Float32Type;
133}
134
135impl FloatToArrayType for f64 {
136    type ArrowType = Float64Type;
137}
138
139impl ArrowFloatType for BFloat16Type {
140    type Native = bf16;
141
142    const FLOAT_TYPE: FloatType = FloatType::BFloat16;
143    const MIN: Self::Native = bf16::MIN;
144    const MAX: Self::Native = bf16::MAX;
145
146    type ArrayType = FixedSizeBinaryArray;
147}
148
149impl ArrowFloatType for Float16Type {
150    type Native = f16;
151
152    const FLOAT_TYPE: FloatType = FloatType::Float16;
153    const MIN: Self::Native = f16::MIN;
154    const MAX: Self::Native = f16::MAX;
155
156    type ArrayType = Float16Array;
157}
158
159impl ArrowFloatType for Float32Type {
160    type Native = f32;
161
162    const FLOAT_TYPE: FloatType = FloatType::Float32;
163    const MIN: Self::Native = f32::MIN;
164    const MAX: Self::Native = f32::MAX;
165
166    type ArrayType = Float32Array;
167}
168
169impl ArrowFloatType for Float64Type {
170    type Native = f64;
171
172    const FLOAT_TYPE: FloatType = FloatType::Float64;
173    const MIN: Self::Native = f64::MIN;
174    const MAX: Self::Native = f64::MAX;
175
176    type ArrayType = Float64Array;
177}
178
179/// [FloatArray] is a trait that is implemented by all float type arrays
180///
181/// This is similar to [`arrow_array::PrimitiveArray`] but applies to all float types (including bfloat16)
182/// and is implemented as a trait and not a struct
183pub trait FloatArray<T: ArrowFloatType + ?Sized>: Array + Clone + 'static {
184    type FloatType: ArrowFloatType;
185
186    /// Returns a reference to the underlying data as a slice.
187    fn as_slice(&self) -> &[T::Native];
188
189    /// Construct an array from a vector of values.
190    fn from_values(values: Vec<T::Native>) -> Self;
191
192    /// Construct an array from an iterator of values.
193    fn from_iter_values(values: impl IntoIterator<Item = T::Native>) -> Self
194    where
195        Self: Sized,
196    {
197        Self::from_values(values.into_iter().collect())
198    }
199}
200
201impl FloatArray<Float16Type> for Float16Array {
202    type FloatType = Float16Type;
203
204    fn as_slice(&self) -> &[<Float16Type as ArrowFloatType>::Native] {
205        self.values()
206    }
207
208    fn from_values(values: Vec<<Float16Type as ArrowFloatType>::Native>) -> Self {
209        Self::from(values)
210    }
211}
212
213impl FloatArray<Float32Type> for Float32Array {
214    type FloatType = Float32Type;
215
216    fn as_slice(&self) -> &[<Float32Type as ArrowFloatType>::Native] {
217        self.values()
218    }
219
220    fn from_values(values: Vec<<Float32Type as ArrowFloatType>::Native>) -> Self {
221        Self::from(values)
222    }
223}
224
225impl FloatArray<Float64Type> for Float64Array {
226    type FloatType = Float64Type;
227
228    fn as_slice(&self) -> &[<Float64Type as ArrowFloatType>::Native] {
229        self.values()
230    }
231
232    fn from_values(values: Vec<<Float64Type as ArrowFloatType>::Native>) -> Self {
233        Self::from(values)
234    }
235}
236
237/// Convert a float32 array to another float array
238///
239/// This is used during queries as query vectors are always provided as float32 arrays
240/// and need to be converted to the appropriate float type for the index.
241pub fn coerce_float_vector(input: &Float32Array, float_type: FloatType) -> Result<Arc<dyn Array>> {
242    match float_type {
243        FloatType::BFloat16 => Ok(Arc::new(
244            BFloat16Array::from_iter_values(input.values().iter().map(|v| bf16::from_f32(*v)))
245                .into_inner(),
246        )),
247        FloatType::Float16 => Ok(Arc::new(Float16Array::from_iter_values(
248            input.values().iter().map(|v| f16::from_f32(*v)),
249        ))),
250        FloatType::Float32 => Ok(Arc::new(input.clone())),
251        FloatType::Float64 => Ok(Arc::new(Float64Array::from_iter_values(
252            input.values().iter().map(|v| *v as f64),
253        ))),
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Coercing to bfloat16 must produce a `FixedSizeBinary(2)` array whose
    /// values equal the element-wise `bf16::from_f32` conversion of the input.
    #[test]
    fn test_coerce_float_vector_bfloat16() {
        let source = Float32Array::from(vec![1.0f32, 2.0, 3.0]);
        let coerced = coerce_float_vector(&source, FloatType::BFloat16).unwrap();

        assert_eq!(coerced.data_type(), &DataType::FixedSizeBinary(2));

        let want: Vec<bf16> = source
            .values()
            .iter()
            .copied()
            .map(bf16::from_f32)
            .collect();
        let got = coerced
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(got.as_slice(), want.as_slice());
    }
}
275}