1use std::fmt::Formatter;
7use std::slice;
8
9use arrow_array::{
10 builder::BooleanBufferBuilder, iterator::ArrayIter, Array, ArrayAccessor, ArrayRef,
11 FixedSizeBinaryArray,
12};
13use arrow_buffer::MutableBuffer;
14use arrow_data::ArrayData;
15use arrow_schema::{ArrowError, DataType, Field as ArrowField};
16use half::bf16;
17
18use crate::{FloatArray, ARROW_EXT_NAME_KEY};
19
/// Extension-type name that marks a `FixedSizeBinary(2)` field as bfloat16.
///
/// `is_bfloat16_field` compares the field-metadata entry stored under `ARROW_EXT_NAME_KEY`
/// against this value.
pub const BFLOAT16_EXT_NAME: &str = "lance.bfloat16";
22
23pub fn is_bfloat16_field(field: &ArrowField) -> bool {
28 field.data_type() == &DataType::FixedSizeBinary(2)
29 && field
30 .metadata()
31 .get(ARROW_EXT_NAME_KEY)
32 .map(|name| name == BFLOAT16_EXT_NAME)
33 .unwrap_or_default()
34}
35
/// Marker type naming the bfloat16 float kind.
///
/// It holds no data. It serves as the `FloatArray` type parameter (and associated
/// `FloatType`) for `BFloat16Array`.
#[derive(Debug)]
pub struct BFloat16Type {}
41
/// An array of `bf16` values backed by a `FixedSizeBinaryArray` with 2-byte elements.
///
/// Each slot holds the raw bits of one `bf16` in little-endian byte order. `value_unchecked`
/// decodes them with `u16::from_le_bytes`. The wrapper adds typed access only: the `Array`
/// implementation and equality delegate to the inner array.
#[derive(Clone)]
pub struct BFloat16Array {
    // Underlying storage. The constructors in this module build it as `FixedSizeBinary(2)`,
    // and `TryFrom<FixedSizeBinaryArray>` rejects any other value length.
    inner: FixedSizeBinaryArray,
}
51
52impl std::fmt::Debug for BFloat16Array {
53 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
54 write!(f, "BFloat16Array\n[\n")?;
55 from_arrow::print_long_array(&self.inner, f, |array, i, f| {
56 if array.is_null(i) {
57 write!(f, "null")
58 } else {
59 let binary_values = array.value(i);
60 let value =
61 bf16::from_bits(u16::from_le_bytes([binary_values[0], binary_values[1]]));
62 write!(f, "{:?}", value)
63 }
64 })?;
65 write!(f, "]")
66 }
67}
68
69impl BFloat16Array {
70 pub fn from_iter_values(iter: impl IntoIterator<Item = bf16>) -> Self {
71 let values: Vec<bf16> = iter.into_iter().collect();
72 values.into()
73 }
74
75 pub fn iter(&self) -> BFloat16Iter<'_> {
76 BFloat16Iter::new(self)
77 }
78
79 pub fn value(&self, i: usize) -> bf16 {
80 assert!(
81 i < self.len(),
82 "Trying to access an element at index {} from a BFloat16Array of length {}",
83 i,
84 self.len()
85 );
86 unsafe { self.value_unchecked(i) }
89 }
90
91 pub unsafe fn value_unchecked(&self, i: usize) -> bf16 {
94 let binary_value = self.inner.value_unchecked(i);
95 bf16::from_bits(u16::from_le_bytes([binary_value[0], binary_value[1]]))
96 }
97
98 pub fn into_inner(self) -> FixedSizeBinaryArray {
99 self.inner
100 }
101}
102
103impl ArrayAccessor for &BFloat16Array {
104 type Item = bf16;
105
106 fn value(&self, index: usize) -> Self::Item {
107 BFloat16Array::value(self, index)
108 }
109
110 unsafe fn value_unchecked(&self, index: usize) -> Self::Item {
111 BFloat16Array::value_unchecked(self, index)
112 }
113}
114
115impl Array for BFloat16Array {
116 fn as_any(&self) -> &dyn std::any::Any {
117 self.inner.as_any()
118 }
119
120 fn to_data(&self) -> arrow_data::ArrayData {
121 self.inner.to_data()
122 }
123
124 fn into_data(self) -> arrow_data::ArrayData {
125 self.inner.into_data()
126 }
127
128 fn slice(&self, offset: usize, length: usize) -> ArrayRef {
129 let inner_array: &dyn Array = &self.inner;
130 inner_array.slice(offset, length)
131 }
132
133 fn nulls(&self) -> Option<&arrow_buffer::NullBuffer> {
134 self.inner.nulls()
135 }
136
137 fn data_type(&self) -> &DataType {
138 self.inner.data_type()
139 }
140
141 fn len(&self) -> usize {
142 self.inner.len()
143 }
144
145 fn is_empty(&self) -> bool {
146 self.inner.is_empty()
147 }
148
149 fn offset(&self) -> usize {
150 self.inner.offset()
151 }
152
153 fn get_array_memory_size(&self) -> usize {
154 self.inner.get_array_memory_size()
155 }
156
157 fn get_buffer_memory_size(&self) -> usize {
158 self.inner.get_buffer_memory_size()
159 }
160}
161
162impl FromIterator<Option<bf16>> for BFloat16Array {
163 fn from_iter<I: IntoIterator<Item = Option<bf16>>>(iter: I) -> Self {
164 let mut buffer = MutableBuffer::new(10);
165 let mut nulls = BooleanBufferBuilder::new(10);
167 let mut len = 0;
168
169 for maybe_value in iter {
170 if let Some(value) = maybe_value {
171 let bytes = value.to_le_bytes();
172 buffer.extend(bytes);
173 } else {
174 buffer.extend([0u8, 0u8]);
175 }
176 nulls.append(maybe_value.is_some());
177 len += 1;
178 }
179
180 let null_buffer = nulls.finish();
181 let num_valid = null_buffer.count_set_bits();
182 let null_buffer = if num_valid == len {
183 None
184 } else {
185 Some(null_buffer.into_inner())
186 };
187
188 let array_data = ArrayData::builder(DataType::FixedSizeBinary(2))
189 .len(len)
190 .add_buffer(buffer.into())
191 .null_bit_buffer(null_buffer);
192 let array_data = unsafe { array_data.build_unchecked() };
193 Self {
194 inner: FixedSizeBinaryArray::from(array_data),
195 }
196 }
197}
198
199impl FromIterator<bf16> for BFloat16Array {
200 fn from_iter<I: IntoIterator<Item = bf16>>(iter: I) -> Self {
201 Self::from_iter_values(iter)
202 }
203}
204
205impl From<Vec<bf16>> for BFloat16Array {
206 fn from(data: Vec<bf16>) -> Self {
207 let mut buffer = MutableBuffer::with_capacity(data.len() * 2);
208
209 let bytes = data.iter().flat_map(|val| {
210 let bytes = val.to_bits().to_le_bytes();
211 bytes.to_vec()
212 });
213
214 buffer.extend(bytes);
215 let array_data = ArrayData::builder(DataType::FixedSizeBinary(2))
216 .len(data.len())
217 .add_buffer(buffer.into());
218 let array_data = unsafe { array_data.build_unchecked() };
219 Self {
220 inner: FixedSizeBinaryArray::from(array_data),
221 }
222 }
223}
224
225impl TryFrom<FixedSizeBinaryArray> for BFloat16Array {
226 type Error = ArrowError;
227
228 fn try_from(value: FixedSizeBinaryArray) -> Result<Self, Self::Error> {
229 if value.value_length() == 2 {
230 Ok(Self { inner: value })
231 } else {
232 Err(ArrowError::InvalidArgumentError(
233 "FixedSizeBinaryArray must have a value length of 2".to_string(),
234 ))
235 }
236 }
237}
238
239impl PartialEq<Self> for BFloat16Array {
240 fn eq(&self, other: &Self) -> bool {
241 self.inner.eq(&other.inner)
242 }
243}
244
/// Iterator over a `BFloat16Array`, yielding `Option<bf16>` per slot (`None` for null slots).
type BFloat16Iter<'a> = ArrayIter<&'a BFloat16Array>;
246
247mod from_arrow {
249 use arrow_array::Array;
250
251 pub(super) fn print_long_array<A, F>(
253 array: &A,
254 f: &mut std::fmt::Formatter,
255 print_item: F,
256 ) -> std::fmt::Result
257 where
258 A: Array,
259 F: Fn(&A, usize, &mut std::fmt::Formatter) -> std::fmt::Result,
260 {
261 let head = std::cmp::min(10, array.len());
262
263 for i in 0..head {
264 if array.is_null(i) {
265 writeln!(f, " null,")?;
266 } else {
267 write!(f, " ")?;
268 print_item(array, i, f)?;
269 writeln!(f, ",")?;
270 }
271 }
272 if array.len() > 10 {
273 if array.len() > 20 {
274 writeln!(f, " ...{} elements...,", array.len() - 20)?;
275 }
276
277 let tail = std::cmp::max(head, array.len() - 10);
278
279 for i in tail..array.len() {
280 if array.is_null(i) {
281 writeln!(f, " null,")?;
282 } else {
283 write!(f, " ")?;
284 print_item(array, i, f)?;
285 writeln!(f, ",")?;
286 }
287 }
288 }
289 Ok(())
290 }
291}
292
293impl FloatArray<BFloat16Type> for BFloat16Array {
294 type FloatType = BFloat16Type;
295
296 fn as_slice(&self) -> &[bf16] {
297 unsafe {
298 slice::from_raw_parts(
299 self.inner.value_data().as_ptr() as *const bf16,
300 self.inner.value_data().len() / 2,
301 )
302 }
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basics() {
        let values: Vec<bf16> = [1.0f32, 2.0, 3.0]
            .iter()
            .map(|&v| bf16::from_f32(v))
            .collect();

        let from_values = BFloat16Array::from_iter_values(values.clone());
        let from_vec = BFloat16Array::from(values.clone());
        assert_eq!(from_values, from_vec);
        assert_eq!(from_values.len(), 3);

        assert_eq!(
            format!("{:?}", from_values),
            "BFloat16Array\n[\n 1.0,\n 2.0,\n 3.0,\n]"
        );

        // Both construction paths must decode back to the same non-null values.
        let expected: Vec<Option<bf16>> = values.iter().copied().map(Some).collect();
        for array in [&from_values, &from_vec] {
            let decoded: Vec<Option<bf16>> = array.iter().collect();
            assert_eq!(decoded, expected);
        }
    }

    #[test]
    fn test_nulls() {
        let values: Vec<Option<bf16>> =
            vec![Some(bf16::from_f32(1.0)), None, Some(bf16::from_f32(3.0))];
        let array: BFloat16Array = values.iter().copied().collect();
        assert_eq!(array.len(), 3);
        assert_eq!(array.null_count(), 1);

        assert_eq!(
            format!("{:?}", array),
            "BFloat16Array\n[\n 1.0,\n null,\n 3.0,\n]"
        );

        let decoded: Vec<Option<bf16>> = array.iter().collect();
        assert_eq!(decoded, values);
    }
}