use arrow_array::builder::Int64Builder;
use arrow_array::{Array, Int64Array};
use arrow_schema::DataType;
use deepsize::DeepSizeOf;
use lance_io::encodings::plain::PlainDecoder;
use lance_io::encodings::Decoder;
use snafu::{location, Location};
use std::collections::BTreeMap;
use tokio::io::AsyncWriteExt;
use lance_core::{Error, Result};
use lance_io::traits::{Reader, Writer};
/// Location of one page: where it starts and how long it is.
#[derive(Debug, Clone, PartialEq, DeepSizeOf)]
pub struct PageInfo {
    /// Starting position of the page.
    pub position: usize,
    /// Length of the page.
    pub length: usize,
}

impl PageInfo {
    /// Create a [`PageInfo`] from its starting `position` and its `length`.
    pub fn new(position: usize, length: usize) -> Self {
        PageInfo { length, position }
    }
}
/// Page locations, keyed first by field id and then by batch index.
#[derive(Debug, Default, Clone, PartialEq, DeepSizeOf)]
pub struct PageTable {
    // field_id -> batch index -> page location
    pages: BTreeMap<i32, BTreeMap<i32, PageInfo>>,
}
impl PageTable {
    pub async fn load<'a>(
        reader: &dyn Reader,
        position: usize,
        min_field_id: i32,
        max_field_id: i32,
        num_batches: i32,
    ) -> Result<Self> {
        if max_field_id < min_field_id {
            return Err(Error::Internal {
                message: format!(
                    "max_field_id {} is less than min_field_id {}",
                    max_field_id, min_field_id
                ),
                location: location!(),
            });
        }
        let field_ids = min_field_id..=max_field_id;
        let num_columns = field_ids.clone().count();
        let length = num_columns * num_batches as usize * 2;
        let decoder = PlainDecoder::new(reader, &DataType::Int64, position, length)?;
        let raw_arr = decoder.decode().await?;
        let arr = raw_arr.as_any().downcast_ref::<Int64Array>().unwrap();
        let mut pages = BTreeMap::default();
        for (field_pos, field_id) in field_ids.enumerate() {
            pages.insert(field_id, BTreeMap::default());
            for batch in 0..num_batches {
                let idx = field_pos as i32 * num_batches + batch;
                let batch_position = &arr.value((idx * 2) as usize);
                let batch_length = &arr.value((idx * 2 + 1) as usize);
                pages.get_mut(&field_id).unwrap().insert(
                    batch,
                    PageInfo {
                        position: *batch_position as usize,
                        length: *batch_length as usize,
                    },
                );
            }
        }
        Ok(Self { pages })
    }
    pub async fn write(&self, writer: &mut dyn Writer, min_field_id: i32) -> Result<usize> {
        if self.pages.is_empty() {
            return Err(Error::InvalidInput {
                source: "empty page table".into(),
                location: location!(),
            });
        }
        let observed_min = *self.pages.keys().min().unwrap();
        if min_field_id > *self.pages.keys().min().unwrap() {
            return Err(Error::invalid_input(
                format!(
                    "field_id_offset {} is greater than the minimum field_id {}",
                    min_field_id, observed_min
                ),
                location!(),
            ));
        }
        let max_field_id = *self.pages.keys().max().unwrap();
        let field_ids = min_field_id..=max_field_id;
        let pos = writer.tell().await?;
        let num_batches = self
            .pages
            .values()
            .flat_map(|c_map| c_map.keys().max())
            .max()
            .unwrap()
            + 1;
        let mut builder =
            Int64Builder::with_capacity(field_ids.clone().count() * num_batches as usize);
        for field_id in field_ids {
            for batch in 0..num_batches {
                if let Some(page_info) = self.get(field_id, batch) {
                    builder.append_value(page_info.position as i64);
                    builder.append_value(page_info.length as i64);
                } else {
                    builder.append_slice(&[0, 0]);
                }
            }
        }
        let arr = builder.finish();
        writer
            .write_all(arr.into_data().buffers()[0].as_slice())
            .await?;
        Ok(pos)
    }
    pub fn set(&mut self, field_id: i32, batch: i32, page_info: PageInfo) {
        self.pages
            .entry(field_id)
            .or_default()
            .insert(batch, page_info);
    }
    pub fn get(&self, field_id: i32, batch: i32) -> Option<&PageInfo> {
        self.pages
            .get(&field_id)
            .and_then(|c_map| c_map.get(&batch))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    use lance_io::local::LocalObjectReader;

    /// `set` followed by `get` returns the stored page.
    #[test]
    fn test_set_page_info() {
        let mut page_table = PageTable::default();
        let page_info = PageInfo::new(1, 2);
        page_table.set(10, 20, page_info.clone());

        let actual = page_table.get(10, 20).unwrap();
        assert_eq!(actual, &page_info);
    }

    /// A sparse table survives a write/load round trip. Every slot in the loaded range
    /// that was never set comes back as `PageInfo::new(0, 0)`.
    #[tokio::test]
    async fn test_roundtrip_page_info() {
        let mut page_table = PageTable::default();
        let page_info = PageInfo::new(1, 2);
        page_table.set(10, 2, page_info.clone());
        page_table.set(11, 1, page_info.clone());
        page_table.set(13, 0, page_info.clone());
        page_table.set(13, 1, page_info.clone());
        page_table.set(13, 2, page_info.clone());
        page_table.set(13, 3, page_info.clone());

        let test_dir = tempfile::tempdir().unwrap();
        let path = test_dir.path().join("test");

        // Start below the smallest field id (10) so field 9 must be padded too.
        let starting_field_id = 9;
        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let pos = page_table
            .write(&mut writer, starting_field_id)
            .await
            .unwrap();
        writer.shutdown().await.unwrap();

        let reader = LocalObjectReader::open_local_path(&path, 1024, None)
            .await
            .unwrap();
        let actual = PageTable::load(
            reader.as_ref(),
            pos,
            starting_field_id, // min_field_id
            13,                // max_field_id
            4,                 // num_batches
        )
        .await
        .unwrap();

        // Build the expectation from the table that was written (not from `actual`),
        // so a wrong value for a real page is detected, not just a wrong placeholder.
        let mut expected = page_table.clone();
        let default_page_info = PageInfo::new(0, 0);
        let expected_default_pages = [
            (9, 0),
            (9, 1),
            (9, 2),
            (9, 3),
            (10, 0),
            (10, 1),
            (10, 3),
            (11, 0),
            (11, 2),
            (11, 3),
            (12, 0),
            (12, 1),
            (12, 2),
            (12, 3),
        ];
        for (field_id, batch) in expected_default_pages.iter() {
            expected.set(*field_id, *batch, default_page_info.clone());
        }
        assert_eq!(expected, actual);
    }

    /// `write` rejects an empty table and a `min_field_id` above the smallest field id.
    /// `load` rejects an inverted field id range.
    #[tokio::test]
    async fn test_error_handling() {
        let mut page_table = PageTable::default();
        let test_dir = tempfile::tempdir().unwrap();
        let path = test_dir.path().join("test");

        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let res = page_table.write(&mut writer, 1).await;
        assert!(res.is_err());
        assert!(
            matches!(res.unwrap_err(), Error::InvalidInput { source, .. } if source.to_string().contains("empty page table"))
        );

        page_table.set(0, 0, PageInfo::new(1, 2));
        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let res = page_table.write(&mut writer, 1).await;
        assert!(res.is_err());
        assert!(
            matches!(res.unwrap_err(), Error::InvalidInput { source, .. }
                if source.to_string().contains("field_id_offset 1 is greater than the minimum field_id 0"))
        );

        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let pos = page_table.write(&mut writer, 0).await.unwrap();
        // Flush the write before the file is reopened for reading.
        writer.shutdown().await.unwrap();
        let reader = LocalObjectReader::open_local_path(&path, 1024, None)
            .await
            .unwrap();
        let res = PageTable::load(reader.as_ref(), pos, 1, 0, 1).await;
        assert!(res.is_err());
        assert!(matches!(res.unwrap_err(), Error::Internal { message, .. }
                if message.contains("max_field_id 0 is less than min_field_id 1")));
    }
}