use std::fmt::Formatter;
use std::slice;
use arrow_array::{
    builder::BooleanBufferBuilder, iterator::ArrayIter, Array, ArrayAccessor, ArrayRef,
    FixedSizeBinaryArray,
};
use arrow_buffer::MutableBuffer;
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType, Field as ArrowField};
use half::bf16;
use crate::FloatArray;
/// Field-metadata key that holds the name of an Arrow extension type.
pub const ARROW_EXT_NAME_KEY: &str = "ARROW:extension:name";
/// Field-metadata key that holds the serialized metadata of an Arrow extension type.
pub const ARROW_EXT_META_KEY: &str = "ARROW:extension:metadata";
/// Extension name stored under [`ARROW_EXT_NAME_KEY`] to mark a
/// `FixedSizeBinary(2)` field as holding bfloat16 values.
pub const BFLOAT16_EXT_NAME: &str = "lance.bfloat16";
/// Returns `true` when `field` describes a bfloat16 column.
///
/// Such a field has data type `FixedSizeBinary(2)` and its metadata maps
/// [`ARROW_EXT_NAME_KEY`] to [`BFLOAT16_EXT_NAME`]. A missing metadata entry
/// or any other data type yields `false`.
pub fn is_bfloat16_field(field: &ArrowField) -> bool {
    if field.data_type() != &DataType::FixedSizeBinary(2) {
        return false;
    }
    matches!(
        field.metadata().get(ARROW_EXT_NAME_KEY),
        Some(ext_name) if ext_name == BFLOAT16_EXT_NAME
    )
}
/// Data-less marker type naming bfloat16 as an element type.
///
/// It is used as the type parameter (and `FloatType`) of the
/// `FloatArray` implementation for [`BFloat16Array`].
#[derive(Debug)]
pub struct BFloat16Type {}
/// An array of [`bf16`] values stored in an Arrow `FixedSizeBinaryArray`.
///
/// Each slot holds two bytes: the little-endian bit pattern of the `bf16`
/// value. Nulls are tracked by the inner array's validity information.
#[derive(Clone)]
pub struct BFloat16Array {
    /// Underlying storage. The constructors in this file (`From`,
    /// `FromIterator`, `TryFrom`) only produce or accept a value width of 2.
    inner: FixedSizeBinaryArray,
}
impl std::fmt::Debug for BFloat16Array {
    /// Writes `BFloat16Array`, then a bracketed list with one value per line.
    ///
    /// Nulls are shown as `null`. Long arrays are abbreviated by
    /// `from_arrow::print_long_array`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("BFloat16Array\n[\n")?;
        from_arrow::print_long_array(&self.inner, f, |array, index, out| {
            if array.is_null(index) {
                return out.write_str("null");
            }
            // Each slot is the little-endian bit pattern of a bf16.
            let raw = array.value(index);
            let bits = u16::from_le_bytes([raw[0], raw[1]]);
            write!(out, "{:?}", bf16::from_bits(bits))
        })?;
        f.write_str("]")
    }
}
impl BFloat16Array {
    /// Creates an array with no nulls from any iterable of `bf16` values.
    pub fn from_iter_values(iter: impl IntoIterator<Item = bf16>) -> Self {
        Self::from(iter.into_iter().collect::<Vec<bf16>>())
    }

    /// Returns an iterator over the slots of this array, yielding
    /// `Option<bf16>` (`None` for null slots).
    pub fn iter(&self) -> BFloat16Iter {
        BFloat16Iter::new(self)
    }

    /// Returns the value stored at index `i`.
    ///
    /// Validity is not consulted: for a null slot the stored bytes are
    /// decoded as-is.
    ///
    /// # Panics
    ///
    /// Panics if `i` is not less than `self.len()`.
    pub fn value(&self, i: usize) -> bf16 {
        let length = self.len();
        assert!(
            i < length,
            "Trying to access an element at index {} from a BFloat16Array of length {}",
            i,
            length
        );
        // SAFETY: the assertion above guarantees `i < self.len()`.
        unsafe { self.value_unchecked(i) }
    }

    /// Returns the value stored at index `i` without a bounds check.
    ///
    /// The slot's two bytes are interpreted as the little-endian bit pattern
    /// of a `bf16`.
    ///
    /// # Safety
    ///
    /// `i` must be less than `self.len()`.
    pub unsafe fn value_unchecked(&self, i: usize) -> bf16 {
        let raw = self.inner.value_unchecked(i);
        let bits = u16::from_le_bytes([raw[0], raw[1]]);
        bf16::from_bits(bits)
    }

    /// Consumes the wrapper and returns the underlying `FixedSizeBinaryArray`.
    pub fn into_inner(self) -> FixedSizeBinaryArray {
        let Self { inner } = self;
        inner
    }
}
/// Element access for a borrowed [`BFloat16Array`], which is what
/// `ArrayIter` (see `BFloat16Iter`) uses to read values.
impl<'a> ArrayAccessor for &'a BFloat16Array {
    type Item = bf16;
    // The fully-qualified `BFloat16Array::...` paths below name the inherent
    // methods explicitly, so the calls do not dispatch back into this impl.
    fn value(&self, index: usize) -> Self::Item {
        BFloat16Array::value(self, index)
    }
    /// # Safety
    ///
    /// `index` must be less than the array length, as required by the inherent
    /// `BFloat16Array::value_unchecked`.
    unsafe fn value_unchecked(&self, index: usize) -> Self::Item {
        BFloat16Array::value_unchecked(self, index)
    }
}
impl Array for BFloat16Array {
    fn as_any(&self) -> &dyn std::any::Any {
        self.inner.as_any()
    }
    fn to_data(&self) -> arrow_data::ArrayData {
        self.inner.to_data()
    }
    fn into_data(self) -> arrow_data::ArrayData {
        self.inner.into_data()
    }
    fn slice(&self, offset: usize, length: usize) -> ArrayRef {
        let inner_array: &dyn Array = &self.inner;
        inner_array.slice(offset, length)
    }
    fn nulls(&self) -> Option<&arrow_buffer::NullBuffer> {
        self.inner.nulls()
    }
    fn data_type(&self) -> &DataType {
        self.inner.data_type()
    }
    fn len(&self) -> usize {
        self.inner.len()
    }
    fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
    fn offset(&self) -> usize {
        self.inner.offset()
    }
    fn get_array_memory_size(&self) -> usize {
        self.inner.get_array_memory_size()
    }
    fn get_buffer_memory_size(&self) -> usize {
        self.inner.get_buffer_memory_size()
    }
}
impl FromIterator<Option<bf16>> for BFloat16Array {
    /// Builds an array from optional values, turning `None` into a null slot.
    ///
    /// Null slots are stored as two zero bytes so that every slot keeps its
    /// fixed 2-byte width. When every item is `Some`, no validity bitmap is
    /// attached.
    fn from_iter<I: IntoIterator<Item = Option<bf16>>>(iter: I) -> Self {
        let iter = iter.into_iter();
        // The lower size bound is only a capacity hint. The buffers still grow
        // if the iterator yields more items than it promised.
        let (capacity, _) = iter.size_hint();
        let mut buffer = MutableBuffer::new(capacity.saturating_mul(2));
        let mut nulls = BooleanBufferBuilder::new(capacity);
        let mut len = 0;
        for maybe_value in iter {
            let bytes = maybe_value.map_or([0u8, 0u8], |value| value.to_le_bytes());
            buffer.extend_from_slice(&bytes);
            nulls.append(maybe_value.is_some());
            len += 1;
        }
        let null_buffer = nulls.finish();
        // Omit the bitmap entirely when there are no nulls.
        let null_buffer = (null_buffer.count_set_bits() != len).then(|| null_buffer.into_inner());
        let array_data = ArrayData::builder(DataType::FixedSizeBinary(2))
            .len(len)
            .add_buffer(buffer.into())
            .null_bit_buffer(null_buffer);
        // SAFETY: the value buffer holds exactly `2 * len` bytes (null slots
        // included), and the validity bitmap, when present, has `len` bits.
        // That is the layout `FixedSizeBinary(2)` requires.
        let array_data = unsafe { array_data.build_unchecked() };
        Self {
            inner: FixedSizeBinaryArray::from(array_data),
        }
    }
}
impl FromIterator<bf16> for BFloat16Array {
    /// Builds an array with no nulls from `bf16` values.
    fn from_iter<I: IntoIterator<Item = bf16>>(iter: I) -> Self {
        let collected: Vec<bf16> = iter.into_iter().collect();
        Self::from(collected)
    }
}
impl From<Vec<bf16>> for BFloat16Array {
    /// Builds an array with no nulls from `bf16` values.
    ///
    /// Each value is stored as its two little-endian bytes.
    fn from(data: Vec<bf16>) -> Self {
        let mut buffer = MutableBuffer::with_capacity(data.len() * 2);
        // Copy each value's bytes straight into the pre-sized buffer, with no
        // temporary allocation per element.
        for value in &data {
            buffer.extend_from_slice(&value.to_le_bytes());
        }
        let array_data = ArrayData::builder(DataType::FixedSizeBinary(2))
            .len(data.len())
            .add_buffer(buffer.into());
        // SAFETY: the buffer holds exactly `2 * data.len()` bytes and there is
        // no validity bitmap, matching the `FixedSizeBinary(2)` layout.
        let array_data = unsafe { array_data.build_unchecked() };
        Self {
            inner: FixedSizeBinaryArray::from(array_data),
        }
    }
}
impl TryFrom<FixedSizeBinaryArray> for BFloat16Array {
    type Error = ArrowError;

    /// Wraps `value` as a `BFloat16Array` without copying.
    ///
    /// # Errors
    ///
    /// Returns `ArrowError::InvalidArgumentError` if the array's value length
    /// is not 2 bytes.
    fn try_from(value: FixedSizeBinaryArray) -> Result<Self, Self::Error> {
        if value.value_length() != 2 {
            return Err(ArrowError::InvalidArgumentError(
                "FixedSizeBinaryArray must have a value length of 2".to_string(),
            ));
        }
        Ok(Self { inner: value })
    }
}
impl PartialEq<Self> for BFloat16Array {
    /// Two arrays are equal exactly when their underlying fixed-size binary
    /// arrays compare equal.
    fn eq(&self, other: &Self) -> bool {
        self.inner == other.inner
    }
}
/// Iterator over a [`BFloat16Array`], yielding `Option<bf16>` for each slot.
type BFloat16Iter<'a> = ArrayIter<&'a BFloat16Array>;
mod from_arrow {
    use arrow_array::Array;

    /// Writes the elements of `array` to `f`, one per line, as `  <item>,`.
    ///
    /// Null slots are written as `  null,` without calling `print_item`.
    /// Arrays longer than 20 elements are abbreviated: the first 10 and the
    /// last 10 elements are written, separated by a `  ...N elements...,`
    /// line, where `N` is the number of omitted elements.
    pub(super) fn print_long_array<A, F>(
        array: &A,
        f: &mut std::fmt::Formatter,
        print_item: F,
    ) -> std::fmt::Result
    where
        A: Array,
        F: Fn(&A, usize, &mut std::fmt::Formatter) -> std::fmt::Result,
    {
        // Number of elements shown at each end of an abbreviated array.
        const EDGE: usize = 10;

        let len = array.len();
        let head_end = len.min(EDGE);
        // Never start the tail inside the head, so no element is written twice.
        let tail_start = head_end.max(len.saturating_sub(EDGE));
        let omitted = len.saturating_sub(2 * EDGE);

        for i in (0..head_end).chain(tail_start..len) {
            if i == tail_start && omitted > 0 {
                writeln!(f, "  ...{} elements...,", omitted)?;
            }
            if array.is_null(i) {
                writeln!(f, "  null,")?;
            } else {
                write!(f, "  ")?;
                print_item(array, i, f)?;
                writeln!(f, ",")?;
            }
        }
        Ok(())
    }
}
impl FloatArray<BFloat16Type> for BFloat16Array {
    type FloatType = BFloat16Type;

    /// Returns the values as a `&[bf16]` view over the inner value buffer,
    /// without copying.
    ///
    /// Null slots are included and expose whatever bytes are stored for them.
    /// Arrays built from `Option<bf16>` items store zeros there.
    ///
    /// # Panics
    ///
    /// Panics if a non-empty value buffer is not aligned for `bf16`, which
    /// may happen if the buffer wraps externally provided memory.
    /// Reinterpreting a misaligned pointer as `&[bf16]` would be undefined
    /// behavior.
    fn as_slice(&self) -> &[bf16] {
        let bytes = self.inner.value_data();
        if bytes.is_empty() {
            // An empty buffer's pointer may be dangling with only 1-byte
            // alignment. Return an empty slice rather than reinterpreting it.
            return &[];
        }
        assert_eq!(
            bytes.as_ptr() as usize % std::mem::align_of::<bf16>(),
            0,
            "BFloat16Array value buffer is not aligned to {} bytes",
            std::mem::align_of::<bf16>()
        );
        // SAFETY: the pointer is aligned for `bf16` (checked above). The slice
        // covers `bytes.len() / 2` whole 2-byte elements inside `bytes`, and
        // every bit pattern is a valid `bf16`. The returned slice borrows
        // `self`, which keeps the buffer alive.
        unsafe { slice::from_raw_parts(bytes.as_ptr().cast::<bf16>(), bytes.len() / 2) }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction, equality, length, short Debug output and iteration.
    #[test]
    fn test_basics() {
        let values: Vec<f32> = vec![1.0, 2.0, 3.0];
        let values: Vec<bf16> = values.iter().map(|v| bf16::from_f32(*v)).collect();
        let array = BFloat16Array::from_iter_values(values.clone());
        let array2 = BFloat16Array::from(values.clone());
        assert_eq!(array, array2);
        assert_eq!(array.len(), 3);
        let expected_fmt = "BFloat16Array\n[\n  1.0,\n  2.0,\n  3.0,\n]";
        assert_eq!(expected_fmt, format!("{:?}", array));
        for (expected, value) in values.iter().zip(array.iter()) {
            assert_eq!(Some(*expected), value);
        }
        for (expected, value) in values.as_slice().iter().zip(array2.iter()) {
            assert_eq!(Some(*expected), value);
        }
    }

    /// Null handling in construction, null counting, Debug output and iteration.
    #[test]
    fn test_nulls() {
        let values: Vec<Option<bf16>> =
            vec![Some(bf16::from_f32(1.0)), None, Some(bf16::from_f32(3.0))];
        let array = BFloat16Array::from_iter(values.clone());
        assert_eq!(array.len(), 3);
        assert_eq!(array.null_count(), 1);
        let expected_fmt = "BFloat16Array\n[\n  1.0,\n  null,\n  3.0,\n]";
        assert_eq!(expected_fmt, format!("{:?}", array));
        for (expected, value) in values.iter().zip(array.iter()) {
            assert_eq!(*expected, value);
        }
    }

    /// Arrays longer than 20 elements show 10 head and 10 tail items around
    /// an "omitted elements" line.
    #[test]
    fn test_long_debug_format_is_abbreviated() {
        let array: BFloat16Array = (0..25).map(|v| bf16::from_f32(v as f32)).collect();
        let mut expected = String::from("BFloat16Array\n[\n");
        for v in 0..10 {
            expected.push_str(&format!("  {:?},\n", bf16::from_f32(v as f32)));
        }
        expected.push_str("  ...5 elements...,\n");
        for v in 15..25 {
            expected.push_str(&format!("  {:?},\n", bf16::from_f32(v as f32)));
        }
        expected.push(']');
        assert_eq!(expected, format!("{:?}", array));
    }

    /// `as_slice` exposes exactly the stored values.
    #[test]
    fn test_as_slice_matches_values() {
        let values: Vec<bf16> = [1.5f32, -2.0, 0.25]
            .iter()
            .map(|v| bf16::from_f32(*v))
            .collect();
        let array = BFloat16Array::from(values.clone());
        assert_eq!(array.as_slice(), values.as_slice());
    }

    /// `TryFrom` accepts width-2 arrays and rejects other widths.
    #[test]
    fn test_try_from_checks_value_length() {
        let wide =
            FixedSizeBinaryArray::try_from_iter(vec![vec![1u8, 2, 3]].into_iter()).unwrap();
        assert!(BFloat16Array::try_from(wide).is_err());

        let narrow =
            FixedSizeBinaryArray::try_from_iter(vec![vec![0u8, 0x3f]].into_iter()).unwrap();
        let array = BFloat16Array::try_from(narrow).unwrap();
        assert_eq!(array.value(0), bf16::from_bits(0x3f00));
    }

    /// `value` panics on an out-of-bounds index.
    #[test]
    #[should_panic(expected = "Trying to access an element at index 3")]
    fn test_value_out_of_bounds_panics() {
        let array = BFloat16Array::from(vec![bf16::from_f32(1.0); 3]);
        let _ = array.value(3);
    }
}