lance_arrow/
bfloat16.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! bfloat16 support for Apache Arrow.
5
6use std::fmt::Formatter;
7use std::slice;
8
9use arrow_array::{
10    builder::BooleanBufferBuilder, iterator::ArrayIter, Array, ArrayAccessor, ArrayRef,
11    FixedSizeBinaryArray,
12};
13use arrow_buffer::MutableBuffer;
14use arrow_data::ArrayData;
15use arrow_schema::{ArrowError, DataType, Field as ArrowField};
16use half::bf16;
17
18use crate::{FloatArray, ARROW_EXT_NAME_KEY};
19
/// Extension name that marks a field as bfloat16.
///
/// [`is_bfloat16_field`] looks for this value under [`ARROW_EXT_NAME_KEY`] in the field metadata.
pub const BFLOAT16_EXT_NAME: &str = "lance.bfloat16";
22
23/// Check whether the given field is a bfloat16 field
24///
25/// A field is a bfloat16 field if it has a data type of `FixedSizeBinary(2)` and the metadata
26/// contains the bfloat16 extension name.
27pub fn is_bfloat16_field(field: &ArrowField) -> bool {
28    field.data_type() == &DataType::FixedSizeBinary(2)
29        && field
30            .metadata()
31            .get(ARROW_EXT_NAME_KEY)
32            .map(|name| name == BFLOAT16_EXT_NAME)
33            .unwrap_or_default()
34}
35
/// Marker type naming the bfloat16 float type.
///
/// It carries no data. It is meant to be used with the
/// [`ArrowFloatType`](crate::floats::ArrowFloatType) trait and serves as the type parameter of
/// the `FloatArray` implementation for [`BFloat16Array`].
#[derive(Debug)]
pub struct BFloat16Type {}
41
42/// An array of bfloat16 values
43///
44/// This implements the [`Array`](arrow_array::Array) trait for bfloat16 values.  Note that
45/// bfloat16 is not the same thing as fp16 which is supported natively
46/// by arrow-rs.
47#[derive(Clone)]
48pub struct BFloat16Array {
49    inner: FixedSizeBinaryArray,
50}
51
52impl std::fmt::Debug for BFloat16Array {
53    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
54        write!(f, "BFloat16Array\n[\n")?;
55        from_arrow::print_long_array(&self.inner, f, |array, i, f| {
56            if array.is_null(i) {
57                write!(f, "null")
58            } else {
59                let binary_values = array.value(i);
60                let value =
61                    bf16::from_bits(u16::from_le_bytes([binary_values[0], binary_values[1]]));
62                write!(f, "{:?}", value)
63            }
64        })?;
65        write!(f, "]")
66    }
67}
68
69impl BFloat16Array {
70    pub fn from_iter_values(iter: impl IntoIterator<Item = bf16>) -> Self {
71        let values: Vec<bf16> = iter.into_iter().collect();
72        values.into()
73    }
74
75    pub fn iter(&self) -> BFloat16Iter<'_> {
76        BFloat16Iter::new(self)
77    }
78
79    pub fn value(&self, i: usize) -> bf16 {
80        assert!(
81            i < self.len(),
82            "Trying to access an element at index {} from a BFloat16Array of length {}",
83            i,
84            self.len()
85        );
86        // Safety:
87        // `i < self.len()
88        unsafe { self.value_unchecked(i) }
89    }
90
91    /// # Safety
92    /// Caller must ensure that `i < self.len()`
93    pub unsafe fn value_unchecked(&self, i: usize) -> bf16 {
94        let binary_value = self.inner.value_unchecked(i);
95        bf16::from_bits(u16::from_le_bytes([binary_value[0], binary_value[1]]))
96    }
97
98    pub fn into_inner(self) -> FixedSizeBinaryArray {
99        self.inner
100    }
101}
102
103impl ArrayAccessor for &BFloat16Array {
104    type Item = bf16;
105
106    fn value(&self, index: usize) -> Self::Item {
107        BFloat16Array::value(self, index)
108    }
109
110    unsafe fn value_unchecked(&self, index: usize) -> Self::Item {
111        BFloat16Array::value_unchecked(self, index)
112    }
113}
114
115impl Array for BFloat16Array {
116    fn as_any(&self) -> &dyn std::any::Any {
117        self.inner.as_any()
118    }
119
120    fn to_data(&self) -> arrow_data::ArrayData {
121        self.inner.to_data()
122    }
123
124    fn into_data(self) -> arrow_data::ArrayData {
125        self.inner.into_data()
126    }
127
128    fn slice(&self, offset: usize, length: usize) -> ArrayRef {
129        let inner_array: &dyn Array = &self.inner;
130        inner_array.slice(offset, length)
131    }
132
133    fn nulls(&self) -> Option<&arrow_buffer::NullBuffer> {
134        self.inner.nulls()
135    }
136
137    fn data_type(&self) -> &DataType {
138        self.inner.data_type()
139    }
140
141    fn len(&self) -> usize {
142        self.inner.len()
143    }
144
145    fn is_empty(&self) -> bool {
146        self.inner.is_empty()
147    }
148
149    fn offset(&self) -> usize {
150        self.inner.offset()
151    }
152
153    fn get_array_memory_size(&self) -> usize {
154        self.inner.get_array_memory_size()
155    }
156
157    fn get_buffer_memory_size(&self) -> usize {
158        self.inner.get_buffer_memory_size()
159    }
160}
161
162impl FromIterator<Option<bf16>> for BFloat16Array {
163    fn from_iter<I: IntoIterator<Item = Option<bf16>>>(iter: I) -> Self {
164        let mut buffer = MutableBuffer::new(10);
165        // No null buffer builder :(
166        let mut nulls = BooleanBufferBuilder::new(10);
167        let mut len = 0;
168
169        for maybe_value in iter {
170            if let Some(value) = maybe_value {
171                let bytes = value.to_le_bytes();
172                buffer.extend(bytes);
173            } else {
174                buffer.extend([0u8, 0u8]);
175            }
176            nulls.append(maybe_value.is_some());
177            len += 1;
178        }
179
180        let null_buffer = nulls.finish();
181        let num_valid = null_buffer.count_set_bits();
182        let null_buffer = if num_valid == len {
183            None
184        } else {
185            Some(null_buffer.into_inner())
186        };
187
188        let array_data = ArrayData::builder(DataType::FixedSizeBinary(2))
189            .len(len)
190            .add_buffer(buffer.into())
191            .null_bit_buffer(null_buffer);
192        let array_data = unsafe { array_data.build_unchecked() };
193        Self {
194            inner: FixedSizeBinaryArray::from(array_data),
195        }
196    }
197}
198
199impl FromIterator<bf16> for BFloat16Array {
200    fn from_iter<I: IntoIterator<Item = bf16>>(iter: I) -> Self {
201        Self::from_iter_values(iter)
202    }
203}
204
205impl From<Vec<bf16>> for BFloat16Array {
206    fn from(data: Vec<bf16>) -> Self {
207        let mut buffer = MutableBuffer::with_capacity(data.len() * 2);
208
209        let bytes = data.iter().flat_map(|val| {
210            let bytes = val.to_bits().to_le_bytes();
211            bytes.to_vec()
212        });
213
214        buffer.extend(bytes);
215        let array_data = ArrayData::builder(DataType::FixedSizeBinary(2))
216            .len(data.len())
217            .add_buffer(buffer.into());
218        let array_data = unsafe { array_data.build_unchecked() };
219        Self {
220            inner: FixedSizeBinaryArray::from(array_data),
221        }
222    }
223}
224
225impl TryFrom<FixedSizeBinaryArray> for BFloat16Array {
226    type Error = ArrowError;
227
228    fn try_from(value: FixedSizeBinaryArray) -> Result<Self, Self::Error> {
229        if value.value_length() == 2 {
230            Ok(Self { inner: value })
231        } else {
232            Err(ArrowError::InvalidArgumentError(
233                "FixedSizeBinaryArray must have a value length of 2".to_string(),
234            ))
235        }
236    }
237}
238
239impl PartialEq<Self> for BFloat16Array {
240    fn eq(&self, other: &Self) -> bool {
241        self.inner.eq(&other.inner)
242    }
243}
244
245type BFloat16Iter<'a> = ArrayIter<&'a BFloat16Array>;
246
247/// Methods that are lifted from arrow-rs temporarily until they are made public.
248mod from_arrow {
249    use arrow_array::Array;
250
251    /// Helper function for printing potentially long arrays.
252    pub(super) fn print_long_array<A, F>(
253        array: &A,
254        f: &mut std::fmt::Formatter,
255        print_item: F,
256    ) -> std::fmt::Result
257    where
258        A: Array,
259        F: Fn(&A, usize, &mut std::fmt::Formatter) -> std::fmt::Result,
260    {
261        let head = std::cmp::min(10, array.len());
262
263        for i in 0..head {
264            if array.is_null(i) {
265                writeln!(f, "  null,")?;
266            } else {
267                write!(f, "  ")?;
268                print_item(array, i, f)?;
269                writeln!(f, ",")?;
270            }
271        }
272        if array.len() > 10 {
273            if array.len() > 20 {
274                writeln!(f, "  ...{} elements...,", array.len() - 20)?;
275            }
276
277            let tail = std::cmp::max(head, array.len() - 10);
278
279            for i in tail..array.len() {
280                if array.is_null(i) {
281                    writeln!(f, "  null,")?;
282                } else {
283                    write!(f, "  ")?;
284                    print_item(array, i, f)?;
285                    writeln!(f, ",")?;
286                }
287            }
288        }
289        Ok(())
290    }
291}
292
293impl FloatArray<BFloat16Type> for BFloat16Array {
294    type FloatType = BFloat16Type;
295
296    fn as_slice(&self) -> &[bf16] {
297        unsafe {
298            slice::from_raw_parts(
299                self.inner.value_data().as_ptr() as *const bf16,
300                self.inner.value_data().len() / 2,
301            )
302        }
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-null construction, equality, length, `Debug` output and iteration.
    #[test]
    fn test_basics() {
        let values: Vec<bf16> = [1.0f32, 2.0, 3.0]
            .iter()
            .copied()
            .map(bf16::from_f32)
            .collect();

        let array = BFloat16Array::from_iter_values(values.clone());
        let array2 = BFloat16Array::from(values.clone());
        assert_eq!(array, array2);
        assert_eq!(array.len(), 3);

        let expected_fmt = "BFloat16Array\n[\n  1.0,\n  2.0,\n  3.0,\n]";
        assert_eq!(expected_fmt, format!("{:?}", array));

        // Both construction paths must iterate back to the original values.
        for built in [&array, &array2] {
            values
                .iter()
                .zip(built.iter())
                .for_each(|(expected, value)| assert_eq!(Some(*expected), value));
        }
    }

    /// Construction from optional values: null count, `Debug` output and iteration.
    #[test]
    fn test_nulls() {
        let one = bf16::from_f32(1.0);
        let three = bf16::from_f32(3.0);
        let values: Vec<Option<bf16>> = vec![Some(one), None, Some(three)];

        let array: BFloat16Array = values.iter().copied().collect();
        assert_eq!(array.len(), 3);
        assert_eq!(array.null_count(), 1);

        let expected_fmt = "BFloat16Array\n[\n  1.0,\n  null,\n  3.0,\n]";
        assert_eq!(expected_fmt, format!("{:?}", array));

        values
            .iter()
            .zip(array.iter())
            .for_each(|(expected, value)| assert_eq!(*expected, value));
    }
}