1use std::fmt::Formatter;
7use std::slice;
8
9use arrow_array::{builder::BooleanBufferBuilder, Array, FixedSizeBinaryArray};
10use arrow_buffer::MutableBuffer;
11use arrow_data::ArrayData;
12use arrow_schema::{ArrowError, DataType, Field as ArrowField};
13use half::bf16;
14
15use crate::{FloatArray, ARROW_EXT_NAME_KEY};
16
/// Extension-type name that tags a field as bfloat16.
///
/// `is_bfloat16_field` compares the field metadata value stored under
/// `ARROW_EXT_NAME_KEY` against this string.
pub const BFLOAT16_EXT_NAME: &str = "lance.bfloat16";
19
20pub fn is_bfloat16_field(field: &ArrowField) -> bool {
25 field.data_type() == &DataType::FixedSizeBinary(2)
26 && field
27 .metadata()
28 .get(ARROW_EXT_NAME_KEY)
29 .map(|name| name == BFLOAT16_EXT_NAME)
30 .unwrap_or_default()
31}
32
/// Marker type for bfloat16.
///
/// It carries no data; in this file it is used as the type parameter and the
/// `FloatType` of the `FloatArray` implementation for `FixedSizeBinaryArray`.
#[derive(Debug)]
pub struct BFloat16Type {}
38
/// Array of `bf16` values backed by a `FixedSizeBinaryArray`.
///
/// Each element occupies two bytes, which the accessors decode as a
/// little-endian `u16` bit pattern.
#[derive(Clone)]
pub struct BFloat16Array {
    // Underlying binary storage; the constructors in this file always create it
    // as `FixedSizeBinary(2)`, and `TryFrom` rejects any other value length.
    inner: FixedSizeBinaryArray,
}
46
47impl std::fmt::Debug for BFloat16Array {
48 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
49 write!(f, "BFloat16Array\n[\n")?;
50 from_arrow::print_long_array(&self.inner, f, |array, i, f| {
51 if array.is_null(i) {
52 write!(f, "null")
53 } else {
54 let binary_values = array.value(i);
55 let value =
56 bf16::from_bits(u16::from_le_bytes([binary_values[0], binary_values[1]]));
57 write!(f, "{:?}", value)
58 }
59 })?;
60 write!(f, "]")
61 }
62}
63
64impl BFloat16Array {
65 pub fn from_iter_values(iter: impl IntoIterator<Item = bf16>) -> Self {
66 let values: Vec<bf16> = iter.into_iter().collect();
67 values.into()
68 }
69
70 pub fn len(&self) -> usize {
71 self.inner.len()
72 }
73
74 pub fn is_empty(&self) -> bool {
75 self.inner.is_empty()
76 }
77
78 pub fn is_null(&self, i: usize) -> bool {
79 self.inner.is_null(i)
80 }
81
82 pub fn null_count(&self) -> usize {
83 self.inner.null_count()
84 }
85
86 pub fn iter(&self) -> BFloat16Iter<'_> {
87 BFloat16Iter {
88 array: self,
89 index: 0,
90 }
91 }
92
93 pub fn value(&self, i: usize) -> bf16 {
94 assert!(
95 i < self.len(),
96 "Trying to access an element at index {} from a BFloat16Array of length {}",
97 i,
98 self.len()
99 );
100 unsafe { self.value_unchecked(i) }
103 }
104
105 pub unsafe fn value_unchecked(&self, i: usize) -> bf16 {
108 let binary_value = self.inner.value_unchecked(i);
109 bf16::from_bits(u16::from_le_bytes([binary_value[0], binary_value[1]]))
110 }
111
112 pub fn into_inner(self) -> FixedSizeBinaryArray {
113 self.inner
114 }
115}
116
117impl FromIterator<Option<bf16>> for BFloat16Array {
118 fn from_iter<I: IntoIterator<Item = Option<bf16>>>(iter: I) -> Self {
119 let mut buffer = MutableBuffer::new(10);
120 let mut nulls = BooleanBufferBuilder::new(10);
122 let mut len = 0;
123
124 for maybe_value in iter {
125 if let Some(value) = maybe_value {
126 let bytes = value.to_le_bytes();
127 buffer.extend(bytes);
128 } else {
129 buffer.extend([0u8, 0u8]);
130 }
131 nulls.append(maybe_value.is_some());
132 len += 1;
133 }
134
135 let null_buffer = nulls.finish();
136 let num_valid = null_buffer.count_set_bits();
137 let null_buffer = if num_valid == len {
138 None
139 } else {
140 Some(null_buffer.into_inner())
141 };
142
143 let array_data = ArrayData::builder(DataType::FixedSizeBinary(2))
144 .len(len)
145 .add_buffer(buffer.into())
146 .null_bit_buffer(null_buffer);
147 let array_data = unsafe { array_data.build_unchecked() };
148 Self {
149 inner: FixedSizeBinaryArray::from(array_data),
150 }
151 }
152}
153
154impl FromIterator<bf16> for BFloat16Array {
155 fn from_iter<I: IntoIterator<Item = bf16>>(iter: I) -> Self {
156 Self::from_iter_values(iter)
157 }
158}
159
160impl From<Vec<bf16>> for BFloat16Array {
161 fn from(data: Vec<bf16>) -> Self {
162 let mut buffer = MutableBuffer::with_capacity(data.len() * 2);
163
164 let bytes = data.iter().flat_map(|val| {
165 let bytes = val.to_bits().to_le_bytes();
166 bytes.to_vec()
167 });
168
169 buffer.extend(bytes);
170 let array_data = ArrayData::builder(DataType::FixedSizeBinary(2))
171 .len(data.len())
172 .add_buffer(buffer.into());
173 let array_data = unsafe { array_data.build_unchecked() };
174 Self {
175 inner: FixedSizeBinaryArray::from(array_data),
176 }
177 }
178}
179
180impl TryFrom<FixedSizeBinaryArray> for BFloat16Array {
181 type Error = ArrowError;
182
183 fn try_from(value: FixedSizeBinaryArray) -> Result<Self, Self::Error> {
184 if value.value_length() == 2 {
185 Ok(Self { inner: value })
186 } else {
187 Err(ArrowError::InvalidArgumentError(
188 "FixedSizeBinaryArray must have a value length of 2".to_string(),
189 ))
190 }
191 }
192}
193
194impl PartialEq<Self> for BFloat16Array {
195 fn eq(&self, other: &Self) -> bool {
196 self.inner.eq(&other.inner)
197 }
198}
199
/// Iterator over the elements of a `BFloat16Array`.
///
/// Yields `Some(value)` for valid slots and `None` for null slots.
pub struct BFloat16Iter<'a> {
    // Array being traversed.
    array: &'a BFloat16Array,
    // Position of the next element to yield.
    index: usize,
}
204
205impl<'a> Iterator for BFloat16Iter<'a> {
206 type Item = Option<bf16>;
207
208 fn next(&mut self) -> Option<Self::Item> {
209 if self.index >= self.array.len() {
210 return None;
211 }
212 let i = self.index;
213 self.index += 1;
214 if self.array.is_null(i) {
215 Some(None)
216 } else {
217 Some(Some(self.array.value(i)))
218 }
219 }
220}
221
222mod from_arrow {
224 use arrow_array::Array;
225
226 pub(super) fn print_long_array<A, F>(
228 array: &A,
229 f: &mut std::fmt::Formatter,
230 print_item: F,
231 ) -> std::fmt::Result
232 where
233 A: Array,
234 F: Fn(&A, usize, &mut std::fmt::Formatter) -> std::fmt::Result,
235 {
236 let head = std::cmp::min(10, array.len());
237
238 for i in 0..head {
239 if array.is_null(i) {
240 writeln!(f, " null,")?;
241 } else {
242 write!(f, " ")?;
243 print_item(array, i, f)?;
244 writeln!(f, ",")?;
245 }
246 }
247 if array.len() > 10 {
248 if array.len() > 20 {
249 writeln!(f, " ...{} elements...,", array.len() - 20)?;
250 }
251
252 let tail = std::cmp::max(head, array.len() - 10);
253
254 for i in tail..array.len() {
255 if array.is_null(i) {
256 writeln!(f, " null,")?;
257 } else {
258 write!(f, " ")?;
259 print_item(array, i, f)?;
260 writeln!(f, ",")?;
261 }
262 }
263 }
264 Ok(())
265 }
266}
267
268impl FloatArray<BFloat16Type> for FixedSizeBinaryArray {
269 type FloatType = BFloat16Type;
270
271 fn as_slice(&self) -> &[bf16] {
272 assert_eq!(
273 self.value_length(),
274 2,
275 "BFloat16 arrays must use FixedSizeBinary(2) storage"
276 );
277 unsafe {
278 slice::from_raw_parts(
279 self.value_data().as_ptr() as *const bf16,
280 self.value_data().len() / 2,
281 )
282 }
283 }
284
285 fn from_values(values: Vec<bf16>) -> Self {
286 BFloat16Array::from(values).into_inner()
287 }
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-null construction paths agree with each other, format as expected,
    /// iterate back to the input, and expose the same values through `as_slice`.
    #[test]
    fn test_basics() {
        let values: Vec<bf16> = [1.0f32, 2.0, 3.0]
            .iter()
            .map(|v| bf16::from_f32(*v))
            .collect();

        let array = BFloat16Array::from_iter_values(values.clone());
        let array2 = BFloat16Array::from(values.clone());
        assert_eq!(array, array2);
        assert_eq!(array.len(), 3);

        let expected_fmt = "BFloat16Array\n[\n 1.0,\n 2.0,\n 3.0,\n]";
        assert_eq!(expected_fmt, format!("{:?}", array));

        // Both construction paths must iterate back to the original values.
        for built in [&array, &array2] {
            for (expected, value) in values.iter().zip(built.iter()) {
                assert_eq!(Some(*expected), value);
            }
        }

        let arrow_array = array.into_inner();
        assert_eq!(arrow_array.as_slice(), values.as_slice());
    }

    /// Null slots are counted, printed as `null`, and yielded as `None`.
    #[test]
    fn test_nulls() {
        let one = bf16::from_f32(1.0);
        let three = bf16::from_f32(3.0);
        let values: Vec<Option<bf16>> = vec![Some(one), None, Some(three)];

        let array = BFloat16Array::from_iter(values.clone());
        assert_eq!(array.len(), 3);
        assert_eq!(array.null_count(), 1);

        let expected_fmt = "BFloat16Array\n[\n 1.0,\n null,\n 3.0,\n]";
        assert_eq!(expected_fmt, format!("{:?}", array));

        for (expected, value) in values.iter().zip(array.iter()) {
            assert_eq!(*expected, value);
        }
    }
}
334}