lance_arrow/
bfloat16.rs

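// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! bfloat16 support for Apache Arrow: values are stored as two little-endian bytes
//! in `FixedSizeBinary(2)` arrays, and fields holding them are tagged with the
//! `lance.bfloat16` extension name.
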
6use std::fmt::Formatter;
7use std::slice;
8
9use arrow_array::{
10    builder::BooleanBufferBuilder, iterator::ArrayIter, Array, ArrayAccessor, ArrayRef,
11    FixedSizeBinaryArray,
12};
13use arrow_buffer::MutableBuffer;
14use arrow_data::ArrayData;
15use arrow_schema::{ArrowError, DataType, Field as ArrowField};
16use half::bf16;
17
18use crate::{FloatArray, ARROW_EXT_NAME_KEY};
19
20pub const BFLOAT16_EXT_NAME: &str = "lance.bfloat16";
21
22/// Check whether the given field is a bfloat16 field.
23pub fn is_bfloat16_field(field: &ArrowField) -> bool {
24    field.data_type() == &DataType::FixedSizeBinary(2)
25        && field
26            .metadata()
27            .get(ARROW_EXT_NAME_KEY)
28            .map(|name| name == BFLOAT16_EXT_NAME)
29            .unwrap_or_default()
30}
31
32#[derive(Debug)]
33pub struct BFloat16Type {}
34
35#[derive(Clone)]
36pub struct BFloat16Array {
37    inner: FixedSizeBinaryArray,
38}
39
40impl std::fmt::Debug for BFloat16Array {
41    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
42        write!(f, "BFloat16Array\n[\n")?;
43        from_arrow::print_long_array(&self.inner, f, |array, i, f| {
44            if array.is_null(i) {
45                write!(f, "null")
46            } else {
47                let binary_values = array.value(i);
48                let value =
49                    bf16::from_bits(u16::from_le_bytes([binary_values[0], binary_values[1]]));
50                write!(f, "{:?}", value)
51            }
52        })?;
53        write!(f, "]")
54    }
55}
56
57impl BFloat16Array {
58    pub fn from_iter_values(iter: impl IntoIterator<Item = bf16>) -> Self {
59        let values: Vec<bf16> = iter.into_iter().collect();
60        values.into()
61    }
62
63    pub fn iter(&self) -> BFloat16Iter<'_> {
64        BFloat16Iter::new(self)
65    }
66
67    pub fn value(&self, i: usize) -> bf16 {
68        assert!(
69            i < self.len(),
70            "Trying to access an element at index {} from a BFloat16Array of length {}",
71            i,
72            self.len()
73        );
74        // Safety:
75        // `i < self.len()
76        unsafe { self.value_unchecked(i) }
77    }
78
79    /// # Safety
80    /// Caller must ensure that `i < self.len()`
81    pub unsafe fn value_unchecked(&self, i: usize) -> bf16 {
82        let binary_value = self.inner.value_unchecked(i);
83        bf16::from_bits(u16::from_le_bytes([binary_value[0], binary_value[1]]))
84    }
85
86    pub fn into_inner(self) -> FixedSizeBinaryArray {
87        self.inner
88    }
89}
90
91impl ArrayAccessor for &BFloat16Array {
92    type Item = bf16;
93
94    fn value(&self, index: usize) -> Self::Item {
95        BFloat16Array::value(self, index)
96    }
97
98    unsafe fn value_unchecked(&self, index: usize) -> Self::Item {
99        BFloat16Array::value_unchecked(self, index)
100    }
101}
102
103impl Array for BFloat16Array {
104    fn as_any(&self) -> &dyn std::any::Any {
105        self.inner.as_any()
106    }
107
108    fn to_data(&self) -> arrow_data::ArrayData {
109        self.inner.to_data()
110    }
111
112    fn into_data(self) -> arrow_data::ArrayData {
113        self.inner.into_data()
114    }
115
116    fn slice(&self, offset: usize, length: usize) -> ArrayRef {
117        let inner_array: &dyn Array = &self.inner;
118        inner_array.slice(offset, length)
119    }
120
121    fn nulls(&self) -> Option<&arrow_buffer::NullBuffer> {
122        self.inner.nulls()
123    }
124
125    fn data_type(&self) -> &DataType {
126        self.inner.data_type()
127    }
128
129    fn len(&self) -> usize {
130        self.inner.len()
131    }
132
133    fn is_empty(&self) -> bool {
134        self.inner.is_empty()
135    }
136
137    fn offset(&self) -> usize {
138        self.inner.offset()
139    }
140
141    fn get_array_memory_size(&self) -> usize {
142        self.inner.get_array_memory_size()
143    }
144
145    fn get_buffer_memory_size(&self) -> usize {
146        self.inner.get_buffer_memory_size()
147    }
148}
149
150impl FromIterator<Option<bf16>> for BFloat16Array {
151    fn from_iter<I: IntoIterator<Item = Option<bf16>>>(iter: I) -> Self {
152        let mut buffer = MutableBuffer::new(10);
153        // No null buffer builder :(
154        let mut nulls = BooleanBufferBuilder::new(10);
155        let mut len = 0;
156
157        for maybe_value in iter {
158            if let Some(value) = maybe_value {
159                let bytes = value.to_le_bytes();
160                buffer.extend(bytes);
161            } else {
162                buffer.extend([0u8, 0u8]);
163            }
164            nulls.append(maybe_value.is_some());
165            len += 1;
166        }
167
168        let null_buffer = nulls.finish();
169        let num_valid = null_buffer.count_set_bits();
170        let null_buffer = if num_valid == len {
171            None
172        } else {
173            Some(null_buffer.into_inner())
174        };
175
176        let array_data = ArrayData::builder(DataType::FixedSizeBinary(2))
177            .len(len)
178            .add_buffer(buffer.into())
179            .null_bit_buffer(null_buffer);
180        let array_data = unsafe { array_data.build_unchecked() };
181        Self {
182            inner: FixedSizeBinaryArray::from(array_data),
183        }
184    }
185}
186
187impl FromIterator<bf16> for BFloat16Array {
188    fn from_iter<I: IntoIterator<Item = bf16>>(iter: I) -> Self {
189        Self::from_iter_values(iter)
190    }
191}
192
193impl From<Vec<bf16>> for BFloat16Array {
194    fn from(data: Vec<bf16>) -> Self {
195        let mut buffer = MutableBuffer::with_capacity(data.len() * 2);
196
197        let bytes = data.iter().flat_map(|val| {
198            let bytes = val.to_bits().to_le_bytes();
199            bytes.to_vec()
200        });
201
202        buffer.extend(bytes);
203        let array_data = ArrayData::builder(DataType::FixedSizeBinary(2))
204            .len(data.len())
205            .add_buffer(buffer.into());
206        let array_data = unsafe { array_data.build_unchecked() };
207        Self {
208            inner: FixedSizeBinaryArray::from(array_data),
209        }
210    }
211}
212
213impl TryFrom<FixedSizeBinaryArray> for BFloat16Array {
214    type Error = ArrowError;
215
216    fn try_from(value: FixedSizeBinaryArray) -> Result<Self, Self::Error> {
217        if value.value_length() == 2 {
218            Ok(Self { inner: value })
219        } else {
220            Err(ArrowError::InvalidArgumentError(
221                "FixedSizeBinaryArray must have a value length of 2".to_string(),
222            ))
223        }
224    }
225}
226
227impl PartialEq<Self> for BFloat16Array {
228    fn eq(&self, other: &Self) -> bool {
229        self.inner.eq(&other.inner)
230    }
231}
232
233type BFloat16Iter<'a> = ArrayIter<&'a BFloat16Array>;
234
235/// Methods that are lifted from arrow-rs temporarily until they are made public.
236mod from_arrow {
237    use arrow_array::Array;
238
239    /// Helper function for printing potentially long arrays.
240    pub(super) fn print_long_array<A, F>(
241        array: &A,
242        f: &mut std::fmt::Formatter,
243        print_item: F,
244    ) -> std::fmt::Result
245    where
246        A: Array,
247        F: Fn(&A, usize, &mut std::fmt::Formatter) -> std::fmt::Result,
248    {
249        let head = std::cmp::min(10, array.len());
250
251        for i in 0..head {
252            if array.is_null(i) {
253                writeln!(f, "  null,")?;
254            } else {
255                write!(f, "  ")?;
256                print_item(array, i, f)?;
257                writeln!(f, ",")?;
258            }
259        }
260        if array.len() > 10 {
261            if array.len() > 20 {
262                writeln!(f, "  ...{} elements...,", array.len() - 20)?;
263            }
264
265            let tail = std::cmp::max(head, array.len() - 10);
266
267            for i in tail..array.len() {
268                if array.is_null(i) {
269                    writeln!(f, "  null,")?;
270                } else {
271                    write!(f, "  ")?;
272                    print_item(array, i, f)?;
273                    writeln!(f, ",")?;
274                }
275            }
276        }
277        Ok(())
278    }
279}
280
281impl FloatArray<BFloat16Type> for BFloat16Array {
282    type FloatType = BFloat16Type;
283
284    fn as_slice(&self) -> &[bf16] {
285        unsafe {
286            slice::from_raw_parts(
287                self.inner.value_data().as_ptr() as *const bf16,
288                self.inner.value_data().len() / 2,
289            )
290        }
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basics() {
        let values: Vec<f32> = vec![1.0, 2.0, 3.0];
        let values: Vec<bf16> = values.iter().map(|v| bf16::from_f32(*v)).collect();

        let array = BFloat16Array::from_iter_values(values.clone());
        let array2 = BFloat16Array::from(values.clone());
        assert_eq!(array, array2);
        assert_eq!(array.len(), 3);

        let expected_fmt = "BFloat16Array\n[\n  1.0,\n  2.0,\n  3.0,\n]";
        assert_eq!(expected_fmt, format!("{:?}", array));

        for (expected, value) in values.iter().zip(array.iter()) {
            assert_eq!(Some(*expected), value);
        }

        for (expected, value) in values.as_slice().iter().zip(array2.iter()) {
            assert_eq!(Some(*expected), value);
        }
    }

    #[test]
    fn test_nulls() {
        let values: Vec<Option<bf16>> =
            vec![Some(bf16::from_f32(1.0)), None, Some(bf16::from_f32(3.0))];
        let array = BFloat16Array::from_iter(values.clone());
        assert_eq!(array.len(), 3);
        assert_eq!(array.null_count(), 1);

        let expected_fmt = "BFloat16Array\n[\n  1.0,\n  null,\n  3.0,\n]";
        assert_eq!(expected_fmt, format!("{:?}", array));

        for (expected, value) in values.iter().zip(array.iter()) {
            assert_eq!(*expected, value);
        }
    }

    /// An empty input yields an empty, null-free array that formats as `[]`.
    #[test]
    fn test_empty() {
        let array = BFloat16Array::from(Vec::<bf16>::new());
        assert!(array.is_empty());
        assert_eq!(array.null_count(), 0);
        assert_eq!(array.iter().count(), 0);
        assert_eq!("BFloat16Array\n[\n]", format!("{:?}", array));
    }

    /// Collecting only `Some` values must not attach a validity bitmap and must
    /// match the array built from the plain values.
    #[test]
    fn test_all_valid_options() {
        let values: Vec<bf16> = [1.0f32, -2.5, 0.0]
            .iter()
            .map(|v| bf16::from_f32(*v))
            .collect();
        let from_options: BFloat16Array = values.iter().copied().map(Some).collect();
        let from_values = BFloat16Array::from(values);
        assert!(from_options.nulls().is_none());
        assert_eq!(from_options, from_values);
    }

    /// `as_slice` reinterprets little-endian bytes, so it only round-trips on
    /// little-endian targets.
    #[cfg(target_endian = "little")]
    #[test]
    fn test_as_slice() {
        let values: Vec<bf16> = [0.5f32, 4.0, -1.0]
            .iter()
            .map(|v| bf16::from_f32(*v))
            .collect();
        let array = BFloat16Array::from(values.clone());
        assert_eq!(array.as_slice(), values.as_slice());
    }
}