Skip to main content

lance_arrow/
bfloat16.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! bfloat16 support for Apache Arrow.
5
6use std::fmt::Formatter;
7use std::slice;
8
9use arrow_array::{builder::BooleanBufferBuilder, Array, FixedSizeBinaryArray};
10use arrow_buffer::MutableBuffer;
11use arrow_data::ArrayData;
12use arrow_schema::{ArrowError, DataType, Field as ArrowField};
13use half::bf16;
14
15use crate::{FloatArray, ARROW_EXT_NAME_KEY};
16
/// Extension name that marks a field as bfloat16 in Arrow field metadata.
pub const BFLOAT16_EXT_NAME: &str = "lance.bfloat16";
19
20/// Check whether the given field is a bfloat16 field
21///
22/// A field is a bfloat16 field if it has a data type of `FixedSizeBinary(2)` and the metadata
23/// contains the bfloat16 extension name.
24pub fn is_bfloat16_field(field: &ArrowField) -> bool {
25    field.data_type() == &DataType::FixedSizeBinary(2)
26        && field
27            .metadata()
28            .get(ARROW_EXT_NAME_KEY)
29            .map(|name| name == BFLOAT16_EXT_NAME)
30            .unwrap_or_default()
31}
32
/// Marker type naming the bfloat16 float type.
///
/// It is the type through which bfloat16 values implement the
/// [`ArrowFloatType`](crate::floats::ArrowFloatType) trait.
#[derive(Debug)]
pub struct BFloat16Type {}
38
39/// An array of bfloat16 values
40///
41/// Note that bfloat16 is not the same thing as fp16 which is supported natively by arrow-rs.
42#[derive(Clone)]
43pub struct BFloat16Array {
44    inner: FixedSizeBinaryArray,
45}
46
47impl std::fmt::Debug for BFloat16Array {
48    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
49        write!(f, "BFloat16Array\n[\n")?;
50        from_arrow::print_long_array(&self.inner, f, |array, i, f| {
51            if array.is_null(i) {
52                write!(f, "null")
53            } else {
54                let binary_values = array.value(i);
55                let value =
56                    bf16::from_bits(u16::from_le_bytes([binary_values[0], binary_values[1]]));
57                write!(f, "{:?}", value)
58            }
59        })?;
60        write!(f, "]")
61    }
62}
63
64impl BFloat16Array {
65    pub fn from_iter_values(iter: impl IntoIterator<Item = bf16>) -> Self {
66        let values: Vec<bf16> = iter.into_iter().collect();
67        values.into()
68    }
69
70    pub fn len(&self) -> usize {
71        self.inner.len()
72    }
73
74    pub fn is_empty(&self) -> bool {
75        self.inner.is_empty()
76    }
77
78    pub fn is_null(&self, i: usize) -> bool {
79        self.inner.is_null(i)
80    }
81
82    pub fn null_count(&self) -> usize {
83        self.inner.null_count()
84    }
85
86    pub fn iter(&self) -> BFloat16Iter<'_> {
87        BFloat16Iter {
88            array: self,
89            index: 0,
90        }
91    }
92
93    pub fn value(&self, i: usize) -> bf16 {
94        assert!(
95            i < self.len(),
96            "Trying to access an element at index {} from a BFloat16Array of length {}",
97            i,
98            self.len()
99        );
100        // Safety:
101        // `i < self.len()
102        unsafe { self.value_unchecked(i) }
103    }
104
105    /// # Safety
106    /// Caller must ensure that `i < self.len()`
107    pub unsafe fn value_unchecked(&self, i: usize) -> bf16 {
108        let binary_value = self.inner.value_unchecked(i);
109        bf16::from_bits(u16::from_le_bytes([binary_value[0], binary_value[1]]))
110    }
111
112    pub fn into_inner(self) -> FixedSizeBinaryArray {
113        self.inner
114    }
115}
116
117impl FromIterator<Option<bf16>> for BFloat16Array {
118    fn from_iter<I: IntoIterator<Item = Option<bf16>>>(iter: I) -> Self {
119        let mut buffer = MutableBuffer::new(10);
120        // No null buffer builder :(
121        let mut nulls = BooleanBufferBuilder::new(10);
122        let mut len = 0;
123
124        for maybe_value in iter {
125            if let Some(value) = maybe_value {
126                let bytes = value.to_le_bytes();
127                buffer.extend(bytes);
128            } else {
129                buffer.extend([0u8, 0u8]);
130            }
131            nulls.append(maybe_value.is_some());
132            len += 1;
133        }
134
135        let null_buffer = nulls.finish();
136        let num_valid = null_buffer.count_set_bits();
137        let null_buffer = if num_valid == len {
138            None
139        } else {
140            Some(null_buffer.into_inner())
141        };
142
143        let array_data = ArrayData::builder(DataType::FixedSizeBinary(2))
144            .len(len)
145            .add_buffer(buffer.into())
146            .null_bit_buffer(null_buffer);
147        let array_data = unsafe { array_data.build_unchecked() };
148        Self {
149            inner: FixedSizeBinaryArray::from(array_data),
150        }
151    }
152}
153
154impl FromIterator<bf16> for BFloat16Array {
155    fn from_iter<I: IntoIterator<Item = bf16>>(iter: I) -> Self {
156        Self::from_iter_values(iter)
157    }
158}
159
160impl From<Vec<bf16>> for BFloat16Array {
161    fn from(data: Vec<bf16>) -> Self {
162        let mut buffer = MutableBuffer::with_capacity(data.len() * 2);
163
164        let bytes = data.iter().flat_map(|val| {
165            let bytes = val.to_bits().to_le_bytes();
166            bytes.to_vec()
167        });
168
169        buffer.extend(bytes);
170        let array_data = ArrayData::builder(DataType::FixedSizeBinary(2))
171            .len(data.len())
172            .add_buffer(buffer.into());
173        let array_data = unsafe { array_data.build_unchecked() };
174        Self {
175            inner: FixedSizeBinaryArray::from(array_data),
176        }
177    }
178}
179
180impl TryFrom<FixedSizeBinaryArray> for BFloat16Array {
181    type Error = ArrowError;
182
183    fn try_from(value: FixedSizeBinaryArray) -> Result<Self, Self::Error> {
184        if value.value_length() == 2 {
185            Ok(Self { inner: value })
186        } else {
187            Err(ArrowError::InvalidArgumentError(
188                "FixedSizeBinaryArray must have a value length of 2".to_string(),
189            ))
190        }
191    }
192}
193
194impl PartialEq<Self> for BFloat16Array {
195    fn eq(&self, other: &Self) -> bool {
196        self.inner.eq(&other.inner)
197    }
198}
199
200pub struct BFloat16Iter<'a> {
201    array: &'a BFloat16Array,
202    index: usize,
203}
204
205impl<'a> Iterator for BFloat16Iter<'a> {
206    type Item = Option<bf16>;
207
208    fn next(&mut self) -> Option<Self::Item> {
209        if self.index >= self.array.len() {
210            return None;
211        }
212        let i = self.index;
213        self.index += 1;
214        if self.array.is_null(i) {
215            Some(None)
216        } else {
217            Some(Some(self.array.value(i)))
218        }
219    }
220}
221
222/// Methods that are lifted from arrow-rs temporarily until they are made public.
223mod from_arrow {
224    use arrow_array::Array;
225
226    /// Helper function for printing potentially long arrays.
227    pub(super) fn print_long_array<A, F>(
228        array: &A,
229        f: &mut std::fmt::Formatter,
230        print_item: F,
231    ) -> std::fmt::Result
232    where
233        A: Array,
234        F: Fn(&A, usize, &mut std::fmt::Formatter) -> std::fmt::Result,
235    {
236        let head = std::cmp::min(10, array.len());
237
238        for i in 0..head {
239            if array.is_null(i) {
240                writeln!(f, "  null,")?;
241            } else {
242                write!(f, "  ")?;
243                print_item(array, i, f)?;
244                writeln!(f, ",")?;
245            }
246        }
247        if array.len() > 10 {
248            if array.len() > 20 {
249                writeln!(f, "  ...{} elements...,", array.len() - 20)?;
250            }
251
252            let tail = std::cmp::max(head, array.len() - 10);
253
254            for i in tail..array.len() {
255                if array.is_null(i) {
256                    writeln!(f, "  null,")?;
257                } else {
258                    write!(f, "  ")?;
259                    print_item(array, i, f)?;
260                    writeln!(f, ",")?;
261                }
262            }
263        }
264        Ok(())
265    }
266}
267
268impl FloatArray<BFloat16Type> for FixedSizeBinaryArray {
269    type FloatType = BFloat16Type;
270
271    fn as_slice(&self) -> &[bf16] {
272        assert_eq!(
273            self.value_length(),
274            2,
275            "BFloat16 arrays must use FixedSizeBinary(2) storage"
276        );
277        unsafe {
278            slice::from_raw_parts(
279                self.value_data().as_ptr() as *const bf16,
280                self.value_data().len() / 2,
281            )
282        }
283    }
284
285    fn from_values(values: Vec<bf16>) -> Self {
286        BFloat16Array::from(values).into_inner()
287    }
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    /// Arrays built from plain values: both constructors agree, and the debug
    /// output, iteration, and raw slice view all match the input.
    #[test]
    fn test_basics() {
        let values: Vec<bf16> = [1.0f32, 2.0, 3.0]
            .iter()
            .copied()
            .map(bf16::from_f32)
            .collect();

        let array = BFloat16Array::from_iter_values(values.clone());
        let array2 = BFloat16Array::from(values.clone());
        assert_eq!(array, array2);
        assert_eq!(array.len(), 3);

        let expected_fmt = "BFloat16Array\n[\n  1.0,\n  2.0,\n  3.0,\n]";
        assert_eq!(expected_fmt, format!("{:?}", array));

        values
            .iter()
            .zip(array.iter())
            .for_each(|(expected, actual)| assert_eq!(Some(*expected), actual));

        values
            .as_slice()
            .iter()
            .zip(array2.iter())
            .for_each(|(expected, actual)| assert_eq!(Some(*expected), actual));

        let arrow_array = array.into_inner();
        assert_eq!(arrow_array.as_slice(), values.as_slice());
    }

    /// Arrays built from optional values keep track of nulls in the length,
    /// null count, debug output, and iteration.
    #[test]
    fn test_nulls() {
        let one = bf16::from_f32(1.0);
        let three = bf16::from_f32(3.0);
        let values: Vec<Option<bf16>> = vec![Some(one), None, Some(three)];

        let array = BFloat16Array::from_iter(values.clone());
        assert_eq!(array.len(), 3);
        assert_eq!(array.null_count(), 1);

        let expected_fmt = "BFloat16Array\n[\n  1.0,\n  null,\n  3.0,\n]";
        assert_eq!(expected_fmt, format!("{:?}", array));

        values
            .iter()
            .zip(array.iter())
            .for_each(|(expected, actual)| assert_eq!(*expected, actual));
    }
}