1use std::fmt::Formatter;
7use std::slice;
8
9use arrow_array::{
10 builder::BooleanBufferBuilder, iterator::ArrayIter, Array, ArrayAccessor, ArrayRef,
11 FixedSizeBinaryArray,
12};
13use arrow_buffer::MutableBuffer;
14use arrow_data::ArrayData;
15use arrow_schema::{ArrowError, DataType, Field as ArrowField};
16use half::bf16;
17
18use crate::{FloatArray, ARROW_EXT_NAME_KEY};
19
/// Extension name that marks a field as holding bfloat16 values.
///
/// [`is_bfloat16_field`] compares this against the field metadata entry
/// stored under `ARROW_EXT_NAME_KEY`.
pub const BFLOAT16_EXT_NAME: &str = "lance.bfloat16";
21
22pub fn is_bfloat16_field(field: &ArrowField) -> bool {
24 field.data_type() == &DataType::FixedSizeBinary(2)
25 && field
26 .metadata()
27 .get(ARROW_EXT_NAME_KEY)
28 .map(|name| name == BFLOAT16_EXT_NAME)
29 .unwrap_or_default()
30}
31
/// Marker type for the bfloat16 float type.
///
/// It carries no data; in this file it is used as the type parameter and
/// associated `FloatType` of the `FloatArray` implementation for
/// [`BFloat16Array`].
#[derive(Debug)]
pub struct BFloat16Type {}
34
/// An array of `bf16` values stored in a [`FixedSizeBinaryArray`].
///
/// Each element occupies two bytes, which the accessors in this file decode
/// as a little-endian `u16` bit pattern.
#[derive(Clone)]
pub struct BFloat16Array {
    // Underlying storage; the `TryFrom` conversion only accepts arrays whose
    // value length is 2.
    inner: FixedSizeBinaryArray,
}
39
40impl std::fmt::Debug for BFloat16Array {
41 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
42 write!(f, "BFloat16Array\n[\n")?;
43 from_arrow::print_long_array(&self.inner, f, |array, i, f| {
44 if array.is_null(i) {
45 write!(f, "null")
46 } else {
47 let binary_values = array.value(i);
48 let value =
49 bf16::from_bits(u16::from_le_bytes([binary_values[0], binary_values[1]]));
50 write!(f, "{:?}", value)
51 }
52 })?;
53 write!(f, "]")
54 }
55}
56
57impl BFloat16Array {
58 pub fn from_iter_values(iter: impl IntoIterator<Item = bf16>) -> Self {
59 let values: Vec<bf16> = iter.into_iter().collect();
60 values.into()
61 }
62
63 pub fn iter(&self) -> BFloat16Iter<'_> {
64 BFloat16Iter::new(self)
65 }
66
67 pub fn value(&self, i: usize) -> bf16 {
68 assert!(
69 i < self.len(),
70 "Trying to access an element at index {} from a BFloat16Array of length {}",
71 i,
72 self.len()
73 );
74 unsafe { self.value_unchecked(i) }
77 }
78
79 pub unsafe fn value_unchecked(&self, i: usize) -> bf16 {
82 let binary_value = self.inner.value_unchecked(i);
83 bf16::from_bits(u16::from_le_bytes([binary_value[0], binary_value[1]]))
84 }
85
86 pub fn into_inner(self) -> FixedSizeBinaryArray {
87 self.inner
88 }
89}
90
91impl ArrayAccessor for &BFloat16Array {
92 type Item = bf16;
93
94 fn value(&self, index: usize) -> Self::Item {
95 BFloat16Array::value(self, index)
96 }
97
98 unsafe fn value_unchecked(&self, index: usize) -> Self::Item {
99 BFloat16Array::value_unchecked(self, index)
100 }
101}
102
103impl Array for BFloat16Array {
104 fn as_any(&self) -> &dyn std::any::Any {
105 self.inner.as_any()
106 }
107
108 fn to_data(&self) -> arrow_data::ArrayData {
109 self.inner.to_data()
110 }
111
112 fn into_data(self) -> arrow_data::ArrayData {
113 self.inner.into_data()
114 }
115
116 fn slice(&self, offset: usize, length: usize) -> ArrayRef {
117 let inner_array: &dyn Array = &self.inner;
118 inner_array.slice(offset, length)
119 }
120
121 fn nulls(&self) -> Option<&arrow_buffer::NullBuffer> {
122 self.inner.nulls()
123 }
124
125 fn data_type(&self) -> &DataType {
126 self.inner.data_type()
127 }
128
129 fn len(&self) -> usize {
130 self.inner.len()
131 }
132
133 fn is_empty(&self) -> bool {
134 self.inner.is_empty()
135 }
136
137 fn offset(&self) -> usize {
138 self.inner.offset()
139 }
140
141 fn get_array_memory_size(&self) -> usize {
142 self.inner.get_array_memory_size()
143 }
144
145 fn get_buffer_memory_size(&self) -> usize {
146 self.inner.get_buffer_memory_size()
147 }
148}
149
150impl FromIterator<Option<bf16>> for BFloat16Array {
151 fn from_iter<I: IntoIterator<Item = Option<bf16>>>(iter: I) -> Self {
152 let mut buffer = MutableBuffer::new(10);
153 let mut nulls = BooleanBufferBuilder::new(10);
155 let mut len = 0;
156
157 for maybe_value in iter {
158 if let Some(value) = maybe_value {
159 let bytes = value.to_le_bytes();
160 buffer.extend(bytes);
161 } else {
162 buffer.extend([0u8, 0u8]);
163 }
164 nulls.append(maybe_value.is_some());
165 len += 1;
166 }
167
168 let null_buffer = nulls.finish();
169 let num_valid = null_buffer.count_set_bits();
170 let null_buffer = if num_valid == len {
171 None
172 } else {
173 Some(null_buffer.into_inner())
174 };
175
176 let array_data = ArrayData::builder(DataType::FixedSizeBinary(2))
177 .len(len)
178 .add_buffer(buffer.into())
179 .null_bit_buffer(null_buffer);
180 let array_data = unsafe { array_data.build_unchecked() };
181 Self {
182 inner: FixedSizeBinaryArray::from(array_data),
183 }
184 }
185}
186
187impl FromIterator<bf16> for BFloat16Array {
188 fn from_iter<I: IntoIterator<Item = bf16>>(iter: I) -> Self {
189 Self::from_iter_values(iter)
190 }
191}
192
193impl From<Vec<bf16>> for BFloat16Array {
194 fn from(data: Vec<bf16>) -> Self {
195 let mut buffer = MutableBuffer::with_capacity(data.len() * 2);
196
197 let bytes = data.iter().flat_map(|val| {
198 let bytes = val.to_bits().to_le_bytes();
199 bytes.to_vec()
200 });
201
202 buffer.extend(bytes);
203 let array_data = ArrayData::builder(DataType::FixedSizeBinary(2))
204 .len(data.len())
205 .add_buffer(buffer.into());
206 let array_data = unsafe { array_data.build_unchecked() };
207 Self {
208 inner: FixedSizeBinaryArray::from(array_data),
209 }
210 }
211}
212
213impl TryFrom<FixedSizeBinaryArray> for BFloat16Array {
214 type Error = ArrowError;
215
216 fn try_from(value: FixedSizeBinaryArray) -> Result<Self, Self::Error> {
217 if value.value_length() == 2 {
218 Ok(Self { inner: value })
219 } else {
220 Err(ArrowError::InvalidArgumentError(
221 "FixedSizeBinaryArray must have a value length of 2".to_string(),
222 ))
223 }
224 }
225}
226
227impl PartialEq<Self> for BFloat16Array {
228 fn eq(&self, other: &Self) -> bool {
229 self.inner.eq(&other.inner)
230 }
231}
232
/// Iterator type returned by `BFloat16Array::iter`: an [`ArrayIter`] over a
/// borrowed [`BFloat16Array`].
type BFloat16Iter<'a> = ArrayIter<&'a BFloat16Array>;
234
235mod from_arrow {
237 use arrow_array::Array;
238
239 pub(super) fn print_long_array<A, F>(
241 array: &A,
242 f: &mut std::fmt::Formatter,
243 print_item: F,
244 ) -> std::fmt::Result
245 where
246 A: Array,
247 F: Fn(&A, usize, &mut std::fmt::Formatter) -> std::fmt::Result,
248 {
249 let head = std::cmp::min(10, array.len());
250
251 for i in 0..head {
252 if array.is_null(i) {
253 writeln!(f, " null,")?;
254 } else {
255 write!(f, " ")?;
256 print_item(array, i, f)?;
257 writeln!(f, ",")?;
258 }
259 }
260 if array.len() > 10 {
261 if array.len() > 20 {
262 writeln!(f, " ...{} elements...,", array.len() - 20)?;
263 }
264
265 let tail = std::cmp::max(head, array.len() - 10);
266
267 for i in tail..array.len() {
268 if array.is_null(i) {
269 writeln!(f, " null,")?;
270 } else {
271 write!(f, " ")?;
272 print_item(array, i, f)?;
273 writeln!(f, ",")?;
274 }
275 }
276 }
277 Ok(())
278 }
279}
280
281impl FloatArray<BFloat16Type> for BFloat16Array {
282 type FloatType = BFloat16Type;
283
284 fn as_slice(&self) -> &[bf16] {
285 unsafe {
286 slice::from_raw_parts(
287 self.inner.value_data().as_ptr() as *const bf16,
288 self.inner.value_data().len() / 2,
289 )
290 }
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Both construction paths for non-null data must agree with each other,
    /// with the expected `Debug` output, and with the source values.
    #[test]
    fn test_basics() {
        let values: Vec<bf16> = [1.0_f32, 2.0, 3.0]
            .iter()
            .copied()
            .map(bf16::from_f32)
            .collect();

        let from_values = BFloat16Array::from_iter_values(values.clone());
        let from_vec = BFloat16Array::from(values.clone());
        assert_eq!(from_values, from_vec);
        assert_eq!(from_values.len(), 3);

        let expected_fmt = "BFloat16Array\n[\n 1.0,\n 2.0,\n 3.0,\n]";
        assert_eq!(expected_fmt, format!("{:?}", from_values));

        // Iterating either array must yield the original values in order.
        for array in [&from_values, &from_vec] {
            for (expected, actual) in values.iter().zip(array.iter()) {
                assert_eq!(Some(*expected), actual);
            }
        }
    }

    /// `None` items must become null slots that are counted, printed as
    /// `null`, and yielded as `None` by the iterator.
    #[test]
    fn test_nulls() {
        let values: Vec<Option<bf16>> = vec![
            Some(bf16::from_f32(1.0)),
            None,
            Some(bf16::from_f32(3.0)),
        ];
        let array = BFloat16Array::from_iter(values.clone());
        assert_eq!(array.len(), 3);
        assert_eq!(array.null_count(), 1);

        let expected_fmt = "BFloat16Array\n[\n 1.0,\n null,\n 3.0,\n]";
        assert_eq!(expected_fmt, format!("{:?}", array));

        for (expected, actual) in values.iter().zip(array.iter()) {
            assert_eq!(*expected, actual);
        }
    }
}
335}