1use arrow_array::builder::Int64Builder;
5use arrow_array::{Array, Int64Array};
6use arrow_schema::DataType;
7use deepsize::DeepSizeOf;
8use lance_io::encodings::plain::PlainDecoder;
9use lance_io::encodings::Decoder;
10use snafu::location;
11use std::collections::BTreeMap;
12use tokio::io::AsyncWriteExt;
13
14use lance_core::{Error, Result};
15use lance_io::traits::{Reader, Writer};
16
17#[derive(Clone, Debug, PartialEq, DeepSizeOf)]
18pub struct PageInfo {
19 pub position: usize,
20 pub length: usize,
21}
22
23impl PageInfo {
24 pub fn new(position: usize, length: usize) -> Self {
25 Self { position, length }
26 }
27}
28
29#[derive(Debug, Default, Clone, PartialEq, DeepSizeOf)]
32pub struct PageTable {
33 pages: BTreeMap<i32, BTreeMap<i32, PageInfo>>,
35}
36
37impl PageTable {
38 pub async fn load(
55 reader: &dyn Reader,
56 position: usize,
57 min_field_id: i32,
58 max_field_id: i32,
59 num_batches: i32,
60 ) -> Result<Self> {
61 if max_field_id < min_field_id {
62 return Err(Error::Internal {
63 message: format!(
64 "max_field_id {} is less than min_field_id {}",
65 max_field_id, min_field_id
66 ),
67 location: location!(),
68 });
69 }
70
71 let field_ids = min_field_id..=max_field_id;
72 let num_columns = field_ids.clone().count();
73 let length = num_columns * num_batches as usize * 2;
74 let decoder = PlainDecoder::new(reader, &DataType::Int64, position, length)?;
75 let raw_arr = decoder.decode().await?;
76 let arr = raw_arr.as_any().downcast_ref::<Int64Array>().unwrap();
77
78 let mut pages = BTreeMap::default();
79 for (field_pos, field_id) in field_ids.enumerate() {
80 pages.insert(field_id, BTreeMap::default());
81 for batch in 0..num_batches {
82 let idx = field_pos as i32 * num_batches + batch;
83 let batch_position = &arr.value((idx * 2) as usize);
84 let batch_length = &arr.value((idx * 2 + 1) as usize);
85 pages.get_mut(&field_id).unwrap().insert(
86 batch,
87 PageInfo {
88 position: *batch_position as usize,
89 length: *batch_length as usize,
90 },
91 );
92 }
93 }
94
95 Ok(Self { pages })
96 }
97
98 pub async fn write(&self, writer: &mut dyn Writer, min_field_id: i32) -> Result<usize> {
108 if self.pages.is_empty() {
109 return Err(Error::InvalidInput {
110 source: "empty page table".into(),
111 location: location!(),
112 });
113 }
114
115 let observed_min = *self.pages.keys().min().unwrap();
116 if min_field_id > *self.pages.keys().min().unwrap() {
117 return Err(Error::invalid_input(
118 format!(
119 "field_id_offset {} is greater than the minimum field_id {}",
120 min_field_id, observed_min
121 ),
122 location!(),
123 ));
124 }
125 let max_field_id = *self.pages.keys().max().unwrap();
126 let field_ids = min_field_id..=max_field_id;
127
128 let pos = writer.tell().await?;
129 let num_batches = self
130 .pages
131 .values()
132 .flat_map(|c_map| c_map.keys().max())
133 .max()
134 .unwrap()
135 + 1;
136
137 let mut builder =
138 Int64Builder::with_capacity(field_ids.clone().count() * num_batches as usize);
139 for field_id in field_ids {
140 for batch in 0..num_batches {
141 if let Some(page_info) = self.get(field_id, batch) {
142 builder.append_value(page_info.position as i64);
143 builder.append_value(page_info.length as i64);
144 } else {
145 builder.append_slice(&[0, 0]);
146 }
147 }
148 }
149 let arr = builder.finish();
150 writer
151 .write_all(arr.into_data().buffers()[0].as_slice())
152 .await?;
153
154 Ok(pos)
155 }
156
157 pub fn set(&mut self, field_id: i32, batch: i32, page_info: PageInfo) {
159 self.pages
160 .entry(field_id)
161 .or_default()
162 .insert(batch, page_info);
163 }
164
165 pub fn get(&self, field_id: i32, batch: i32) -> Option<&PageInfo> {
166 self.pages
167 .get(&field_id)
168 .and_then(|c_map| c_map.get(&batch))
169 }
170}
171
#[cfg(test)]
mod tests {

    use super::*;
    use lance_core::utils::tempfile::TempStdFile;
    use pretty_assertions::assert_eq;

    use lance_io::local::LocalObjectReader;

    /// A value stored with `set` is returned unchanged by `get`.
    #[test]
    fn test_set_page_info() {
        let mut page_table = PageTable::default();
        let page_info = PageInfo::new(1, 2);
        page_table.set(10, 20, page_info.clone());

        let actual = page_table.get(10, 20).unwrap();
        assert_eq!(actual, &page_info);
    }

    /// Writing a sparse table and loading it back yields the original pages
    /// plus `(0, 0)` entries for every cell that was never set.
    #[tokio::test]
    async fn test_roundtrip_page_info() {
        let mut page_table = PageTable::default();
        let page_info = PageInfo::new(1, 2);

        // Fields 10 and 11 each have one page, at different batches.
        page_table.set(10, 2, page_info.clone());
        page_table.set(11, 1, page_info.clone());
        // Field 12 is skipped entirely; field 13 has every batch populated.
        page_table.set(13, 0, page_info.clone());
        page_table.set(13, 1, page_info.clone());
        page_table.set(13, 2, page_info.clone());
        page_table.set(13, 3, page_info.clone());

        let path = TempStdFile::default();

        // Start below the smallest field id that was set, so field 9 is
        // written with no pages at all.
        let starting_field_id = 9;

        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let pos = page_table
            .write(&mut writer, starting_field_id)
            .await
            .unwrap();
        writer.shutdown().await.unwrap();

        let reader = LocalObjectReader::open_local_path(&path, 1024, None)
            .await
            .unwrap();
        let actual = PageTable::load(reader.as_ref(), pos, starting_field_id, 13, 4)
            .await
            .unwrap();

        // The expectation must be built from the table that was written, not
        // from the loaded result; otherwise the comparison cannot detect pages
        // that were lost or corrupted in the round trip.
        let mut expected = page_table.clone();
        let default_page_info = PageInfo::new(0, 0);
        let expected_default_pages = [
            (9, 0),
            (9, 1),
            (9, 2),
            (9, 3),
            (10, 0),
            (10, 1),
            (10, 3),
            (11, 0),
            (11, 2),
            (11, 3),
            (12, 0),
            (12, 1),
            (12, 2),
            (12, 3),
        ];
        for (field_id, batch) in expected_default_pages.iter() {
            expected.set(*field_id, *batch, default_page_info.clone());
        }

        assert_eq!(expected, actual);
    }

    /// `write` and `load` report invalid arguments as errors.
    #[tokio::test]
    async fn test_error_handling() {
        let mut page_table = PageTable::default();

        let path = TempStdFile::default();

        // Writing an empty page table is rejected.
        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let res = page_table.write(&mut writer, 1).await;
        assert!(res.is_err());
        assert!(
            matches!(res.unwrap_err(), Error::InvalidInput { source, .. } if source.to_string().contains("empty page table"))
        );

        let page_info = PageInfo::new(1, 2);
        page_table.set(0, 0, page_info.clone());

        // A min_field_id above the smallest stored field id is rejected.
        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let res = page_table.write(&mut writer, 1).await;
        assert!(res.is_err());
        assert!(
            matches!(res.unwrap_err(), Error::InvalidInput { source, .. }
                if source.to_string().contains("field_id_offset 1 is greater than the minimum field_id 0"))
        );

        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let res = page_table.write(&mut writer, 0).await.unwrap();
        // Flush pending writes before the file is reopened for reading.
        writer.shutdown().await.unwrap();

        let reader = LocalObjectReader::open_local_path(&path, 1024, None)
            .await
            .unwrap();

        // Loading with max_field_id below min_field_id is rejected.
        let res = PageTable::load(reader.as_ref(), res, 1, 0, 1).await;
        assert!(res.is_err());
        assert!(matches!(res.unwrap_err(), Error::Internal { message, .. }
                if message.contains("max_field_id 0 is less than min_field_id 1")));
    }
}