1use arrow_array::builder::Int64Builder;
5use arrow_array::{Array, Int64Array};
6use arrow_schema::DataType;
7use deepsize::DeepSizeOf;
8use lance_io::encodings::Decoder;
9use lance_io::encodings::plain::PlainDecoder;
10use std::collections::BTreeMap;
11use tokio::io::AsyncWriteExt;
12
13use lance_core::{Error, Result};
14use lance_io::traits::{Reader, Writer};
15
/// Location of a single page, as recorded in a [`PageTable`].
///
/// When a page table is written, `position` and `length` are each stored as an
/// `i64` (see `PageTable::write`).
#[derive(Clone, Debug, PartialEq, DeepSizeOf)]
pub struct PageInfo {
    /// Offset of the page. NOTE(review): presumably a byte offset into the file — confirm.
    pub position: usize,
    /// Size of the page. NOTE(review): presumably in bytes — confirm.
    pub length: usize,
}
21
22impl PageInfo {
23 pub fn new(position: usize, length: usize) -> Self {
24 Self { position, length }
25 }
26}
27
/// Page locations keyed by field id and batch id.
///
/// Serialized by [`PageTable::write`] and read back with [`PageTable::load`].
#[derive(Debug, Default, Clone, PartialEq, DeepSizeOf)]
pub struct PageTable {
    /// Outer key: field id; inner key: batch id.
    pages: BTreeMap<i32, BTreeMap<i32, PageInfo>>,
}
35
36impl PageTable {
37 pub async fn load(
54 reader: &dyn Reader,
55 position: usize,
56 min_field_id: i32,
57 max_field_id: i32,
58 num_batches: i32,
59 ) -> Result<Self> {
60 if max_field_id < min_field_id {
61 return Err(Error::internal(format!(
62 "max_field_id {} is less than min_field_id {}",
63 max_field_id, min_field_id
64 )));
65 }
66
67 let field_ids = min_field_id..=max_field_id;
68 let num_columns = field_ids.clone().count();
69 let length = num_columns * num_batches as usize * 2;
70 let decoder = PlainDecoder::new(reader, &DataType::Int64, position, length)?;
71 let raw_arr = decoder.decode().await?;
72 let arr = raw_arr.as_any().downcast_ref::<Int64Array>().unwrap();
73
74 let mut pages = BTreeMap::default();
75 for (field_pos, field_id) in field_ids.enumerate() {
76 pages.insert(field_id, BTreeMap::default());
77 for batch in 0..num_batches {
78 let idx = field_pos as i32 * num_batches + batch;
79 let batch_position = &arr.value((idx * 2) as usize);
80 let batch_length = &arr.value((idx * 2 + 1) as usize);
81 pages.get_mut(&field_id).unwrap().insert(
82 batch,
83 PageInfo {
84 position: *batch_position as usize,
85 length: *batch_length as usize,
86 },
87 );
88 }
89 }
90
91 Ok(Self { pages })
92 }
93
94 pub async fn write(&self, writer: &mut dyn Writer, min_field_id: i32) -> Result<usize> {
104 if self.pages.is_empty() {
105 return Err(Error::invalid_input_source("empty page table".into()));
106 }
107
108 let observed_min = *self.pages.keys().min().unwrap();
109 if min_field_id > *self.pages.keys().min().unwrap() {
110 return Err(Error::invalid_input(format!(
111 "field_id_offset {} is greater than the minimum field_id {}",
112 min_field_id, observed_min
113 )));
114 }
115 let max_field_id = *self.pages.keys().max().unwrap();
116 let field_ids = min_field_id..=max_field_id;
117
118 let pos = writer.tell().await?;
119 let num_batches = self
120 .pages
121 .values()
122 .flat_map(|c_map| c_map.keys().max())
123 .max()
124 .unwrap()
125 + 1;
126
127 let mut builder =
128 Int64Builder::with_capacity(field_ids.clone().count() * num_batches as usize);
129 for field_id in field_ids {
130 for batch in 0..num_batches {
131 if let Some(page_info) = self.get(field_id, batch) {
132 builder.append_value(page_info.position as i64);
133 builder.append_value(page_info.length as i64);
134 } else {
135 builder.append_slice(&[0, 0]);
136 }
137 }
138 }
139 let arr = builder.finish();
140 writer
141 .write_all(arr.into_data().buffers()[0].as_slice())
142 .await?;
143
144 Ok(pos)
145 }
146
147 pub fn set(&mut self, field_id: i32, batch: i32, page_info: PageInfo) {
149 self.pages
150 .entry(field_id)
151 .or_default()
152 .insert(batch, page_info);
153 }
154
155 pub fn get(&self, field_id: i32, batch: i32) -> Option<&PageInfo> {
156 self.pages
157 .get(&field_id)
158 .and_then(|c_map| c_map.get(&batch))
159 }
160}
161
#[cfg(test)]
mod tests {

    use super::*;
    use lance_core::utils::tempfile::TempStdFile;
    use pretty_assertions::assert_eq;

    use lance_io::local::LocalObjectReader;

    #[test]
    fn test_set_page_info() {
        let mut page_table = PageTable::default();
        let page_info = PageInfo::new(1, 2);
        page_table.set(10, 20, page_info.clone());

        let actual = page_table.get(10, 20).unwrap();
        assert_eq!(actual, &page_info);
    }

    #[tokio::test]
    async fn test_roundtrip_page_info() {
        let mut page_table = PageTable::default();
        let page_info = PageInfo::new(1, 2);

        // Sparse table: field 10 has only batch 2 and field 11 has only batch 1.
        // Field 12 is absent entirely, and field 13 has all four batches.
        page_table.set(10, 2, page_info.clone());
        page_table.set(11, 1, page_info.clone());
        page_table.set(13, 0, page_info.clone());
        page_table.set(13, 1, page_info.clone());
        page_table.set(13, 2, page_info.clone());
        page_table.set(13, 3, page_info.clone());

        let path = TempStdFile::default();

        // Start below the smallest stored field id, so a fully-missing leading field is covered.
        let starting_field_id = 9;

        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let pos = page_table
            .write(&mut writer, starting_field_id)
            .await
            .unwrap();
        AsyncWriteExt::shutdown(&mut writer).await.unwrap();

        let reader = LocalObjectReader::open_local_path(&path, 1024, None)
            .await
            .unwrap();
        let actual = PageTable::load(
            reader.as_ref(),
            pos,
            starting_field_id,
            13, // largest field id written
            4,  // batches 0..=3
        )
        .await
        .unwrap();

        // Build the expectation from the table that was written, not from `actual`.
        // Otherwise the real page entries would be compared against themselves.
        // Every missing (field, batch) slot should come back as a zeroed page.
        let mut expected = page_table.clone();
        let default_page_info = PageInfo::new(0, 0);
        let expected_default_pages = [
            (9, 0),
            (9, 1),
            (9, 2),
            (9, 3),
            (10, 0),
            (10, 1),
            (10, 3),
            (11, 0),
            (11, 2),
            (11, 3),
            (12, 0),
            (12, 1),
            (12, 2),
            (12, 3),
        ];
        for (field_id, batch) in expected_default_pages.iter() {
            expected.set(*field_id, *batch, default_page_info.clone());
        }

        assert_eq!(expected, actual);
    }

    #[tokio::test]
    async fn test_error_handling() {
        let mut page_table = PageTable::default();

        let path = TempStdFile::default();

        // Writing an empty table is rejected.
        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let res = page_table.write(&mut writer, 1).await;
        assert!(res.is_err());
        assert!(
            matches!(res.unwrap_err(), Error::InvalidInput { source, .. } if source.to_string().contains("empty page table"))
        );

        page_table.set(0, 0, PageInfo::new(1, 2));

        // A starting field id above the smallest stored field id is rejected.
        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let res = page_table.write(&mut writer, 1).await;
        assert!(res.is_err());
        assert!(
            matches!(res.unwrap_err(), Error::InvalidInput { source, .. }
                if source.to_string().contains("field_id_offset 1 is greater than the minimum field_id 0"))
        );

        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let pos = page_table.write(&mut writer, 0).await.unwrap();
        // Flush so the table is on disk before it is read back.
        AsyncWriteExt::shutdown(&mut writer).await.unwrap();

        let reader = LocalObjectReader::open_local_path(&path, 1024, None)
            .await
            .unwrap();

        // An inverted field id range is rejected when loading.
        let res = PageTable::load(reader.as_ref(), pos, 1, 0, 1).await;
        assert!(res.is_err());
        assert!(matches!(res.unwrap_err(), Error::Internal { message, .. }
            if message.contains("max_field_id 0 is less than min_field_id 1")));
    }
}