use arrow_array::builder::Int64Builder;
use arrow_array::{Array, Int64Array};
use arrow_schema::DataType;
use lance_io::encodings::plain::PlainDecoder;
use lance_io::encodings::Decoder;
use snafu::{location, Location};
use std::collections::BTreeMap;
use tokio::io::AsyncWriteExt;
use lance_core::{Error, Result};
use lance_io::traits::{Reader, Writer};
/// Position and length of one page, as recorded in a [`PageTable`].
///
/// Both fields are plain integers, so the type is `Eq` as well as `PartialEq`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PageInfo {
    /// Where the page starts (presumably a byte offset into the file — TODO confirm).
    pub position: usize,
    /// Size of the page (presumably in bytes — TODO confirm).
    pub length: usize,
}

impl PageInfo {
    /// Creates a [`PageInfo`] from a `position` and a `length`.
    pub fn new(position: usize, length: usize) -> Self {
        Self { position, length }
    }
}
#[derive(Debug, Default, Clone, PartialEq)]
pub struct PageTable {
pages: BTreeMap<i32, BTreeMap<i32, PageInfo>>,
}
impl PageTable {
pub async fn load<'a>(
reader: &dyn Reader,
position: usize,
num_columns: i32,
num_batches: i32,
field_id_offset: i32,
) -> Result<Self> {
let length = num_columns * num_batches * 2;
let decoder = PlainDecoder::new(reader, &DataType::Int64, position, length as usize)?;
let raw_arr = decoder.decode().await?;
let arr = raw_arr.as_any().downcast_ref::<Int64Array>().unwrap();
let mut pages = BTreeMap::default();
for col in 0..num_columns {
let field_id = col + field_id_offset;
pages.insert(field_id, BTreeMap::default());
for batch in 0..num_batches {
let idx = col * num_batches + batch;
let batch_position = &arr.value((idx * 2) as usize);
let batch_length = &arr.value((idx * 2 + 1) as usize);
pages.get_mut(&field_id).unwrap().insert(
batch,
PageInfo {
position: *batch_position as usize,
length: *batch_length as usize,
},
);
}
}
Ok(Self { pages })
}
pub async fn write(&self, writer: &mut dyn Writer, field_id_offset: i32) -> Result<usize> {
if self.pages.is_empty() {
return Err(Error::InvalidInput {
source: "empty page table".into(),
location: location!(),
});
}
let pos = writer.tell().await?;
let num_columns = self.pages.keys().max().unwrap() + 1 - field_id_offset;
let num_batches = self
.pages
.values()
.flat_map(|c_map| c_map.keys().max())
.max()
.unwrap()
+ 1;
let mut builder = Int64Builder::with_capacity((num_columns * num_batches) as usize);
for col in 0..num_columns {
for batch in 0..num_batches {
let field_id = col + field_id_offset;
if let Some(page_info) = self.get(field_id, batch) {
builder.append_value(page_info.position as i64);
builder.append_value(page_info.length as i64);
} else {
builder.append_slice(&[0, 0]);
}
}
}
let arr = builder.finish();
writer
.write_all(arr.into_data().buffers()[0].as_slice())
.await?;
Ok(pos)
}
pub fn set(&mut self, field_id: i32, batch: i32, page_info: PageInfo) {
self.pages
.entry(field_id)
.or_default()
.insert(batch, page_info);
}
pub fn get(&self, field_id: i32, batch: i32) -> Option<&PageInfo> {
self.pages
.get(&field_id)
.and_then(|c_map| c_map.get(&batch))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use lance_io::local::LocalObjectReader;

    #[test]
    fn test_set_page_info() {
        let mut page_table = PageTable::default();
        let page_info = PageInfo::new(1, 2);
        page_table.set(10, 20, page_info.clone());
        let actual = page_table.get(10, 20).unwrap();
        assert_eq!(actual, &page_info);
    }

    #[tokio::test]
    async fn test_write_empty_page_table_fails() {
        let test_dir = tempfile::tempdir().unwrap();
        let path = test_dir.path().join("empty");
        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let result = PageTable::default().write(&mut writer, 0).await;
        assert!(matches!(result, Err(Error::InvalidInput { .. })));
    }

    #[tokio::test]
    async fn test_roundtrip_page_info() {
        let mut page_table = PageTable::default();
        let page_info = PageInfo::new(1, 2);
        // Fields 10..=12 with a sparse set of batches. Field 9 (the offset) has no pages.
        page_table.set(10, 2, page_info.clone());
        page_table.set(11, 1, page_info.clone());
        page_table.set(12, 0, page_info.clone());
        page_table.set(12, 1, page_info.clone());
        page_table.set(12, 2, page_info.clone());
        page_table.set(12, 3, page_info.clone());

        let test_dir = tempfile::tempdir().unwrap();
        let path = test_dir.path().join("test");
        // The first field with entries is 10, but the table should start at 9.
        let starting_field_id = 9;

        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let pos = page_table
            .write(&mut writer, starting_field_id)
            .await
            .unwrap();
        writer.shutdown().await.unwrap();

        let reader = LocalObjectReader::open_local_path(&path, 1024)
            .await
            .unwrap();
        let actual = PageTable::load(
            reader.as_ref(),
            pos,
            4, // fields 9..=12
            4, // batches 0..=3
            starting_field_id,
        )
        .await
        .unwrap();

        // Build the expectation from the table that was written, not from `actual`.
        // Every slot that was absent must come back zero-filled. Every present slot
        // must round-trip unchanged.
        let mut expected = page_table.clone();
        let default_page_info = PageInfo::new(0, 0);
        for batch in 0..4 {
            expected.set(9, batch, default_page_info.clone());
        }
        expected.set(10, 0, default_page_info.clone());
        expected.set(10, 1, default_page_info.clone());
        expected.set(10, 3, default_page_info.clone());
        expected.set(11, 0, default_page_info.clone());
        expected.set(11, 2, default_page_info.clone());
        expected.set(11, 3, default_page_info);
        assert_eq!(expected, actual);
    }
}