use rkyv::{
    access_unchecked, api::high::to_bytes_with_alloc, deserialize, rancor::Error,
    ser::allocator::Arena, util::AlignedVec,
};
use rkyv_derive::{Archive, Deserialize, Serialize};

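/// A linear model covering the keys in `[start_key, end_key]`: the rank of a
/// key inside this range is approximated by `slope * key + intercept`.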
#[derive(Archive, Deserialize, Serialize, Debug)]
pub struct Segment {
    pub slope: f64,
    pub intercept: f64,
    pub start_key: u64,
    pub end_key: u64,
}

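/// A two-level piecewise-geometric-model (PGM) index: `segments` approximates
/// key ranks, the optional `top_level` routes a lookup to the right segment,
/// and `epsilon` bounds the search window around each prediction.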
#[derive(Archive, Deserialize, Serialize, Debug)]
pub struct PGMIndex {
    pub segments: Vec<Segment>,
    pub top_level: Option<Vec<Segment>>,
    pub epsilon: usize,
    /// Number of keys the index was built over; used to clamp search ranges.
    pub num_keys: usize,
}

#[derive(Debug, thiserror::Error)]
pub enum PGMIndexError {
    #[error("Keys are not sorted")]
    KeysNotSorted,
    #[error("Key slice is empty")]
    EmptyKeys,
}

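/// Returns early with `$err` when `$cond` does not hold.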
macro_rules! ensure {
    ($cond:expr, $err:expr) => {
        if !$cond {
            return Err($err);
        }
    };
}

impl PGMIndex {
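    /// Builds the index over `keys`, validating that the slice is non-empty
    /// and sorted in non-decreasing order.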
    pub fn build(keys: &[u64], epsilon: usize) -> Result<Self, PGMIndexError> {
        ensure!(!keys.is_empty(), PGMIndexError::EmptyKeys);
        ensure!(
            keys.windows(2).all(|w| w[0] <= w[1]),
            PGMIndexError::KeysNotSorted
        );
        PGMIndex::build_unsafe(keys, epsilon)
    }

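    /// Builds the index without validating the input. `keys` must be
    /// non-empty and sorted in non-decreasing order; otherwise the result is
    /// meaningless or construction panics.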
    pub fn build_unsafe(keys: &[u64], epsilon: usize) -> Result<Self, PGMIndexError> {
        let segments = Self::build_segments(keys, epsilon);
        let top_keys: Vec<u64> = segments.iter().map(|s| s.start_key).collect();

        let top_level = if top_keys.len() > 2 {
            Some(Self::build_segments(&top_keys, epsilon))
        } else {
            None
        };

        Ok(Self {
            segments,
            top_level,
            epsilon,
            num_keys: keys.len(),
        })
    }

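    /// Greedily partitions `keys` into segments: a slope "cone" `[s_min, s_max]`
    /// is narrowed as points are added, and a new segment starts once no single
    /// slope can keep every point within `epsilon` of its rank. Each finished
    /// segment is then fit by interpolating its two endpoint keys.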
    fn build_segments(keys: &[u64], epsilon: usize) -> Vec<Segment> {
        let epsilon = epsilon as f64;
        let mut segments = Vec::new();

        let mut start = 0;
        let mut s_min = f64::NEG_INFINITY;
        let mut s_max = f64::INFINITY;

        for i in 1..keys.len() {
            let x0 = keys[start] as f64;
            let y0 = start as f64;
            let xi = keys[i] as f64;
            let yi = i as f64;

            if (xi - x0).abs() < f64::EPSILON {
                continue;
            }

            let new_s_min = ((yi - epsilon) - y0) / (xi - x0);
            let new_s_max = ((yi + epsilon) - y0) / (xi - x0);
            s_min = s_min.max(new_s_min);
            s_max = s_max.min(new_s_max);

            if s_min > s_max {
                let x1 = keys[i - 1] as f64;
                let y1 = (i - 1) as f64;
                let slope = if (x1 - x0).abs() < f64::EPSILON {
                    0.0
                } else {
                    (y1 - y0) / (x1 - x0)
                };
                let intercept = y0 - slope * x0;

                segments.push(Segment {
                    slope,
                    intercept,
                    start_key: keys[start],
                    end_key: keys[i - 1],
                });

                start = i - 1;
                s_min = f64::NEG_INFINITY;
                s_max = f64::INFINITY;
            }
        }

        let x0 = keys[start] as f64;
        let x1 = keys[keys.len() - 1] as f64;
        let y0 = start as f64;
        let y1 = (keys.len() - 1) as f64;
        let slope = if (x1 - x0).abs() < f64::EPSILON {
            0.0
        } else {
            (y1 - y0) / (x1 - x0)
        };
        let intercept = y0 - slope * x0;

        segments.push(Segment {
            slope,
            intercept,
            start_key: keys[start],
            end_key: keys[keys.len() - 1],
        });

        segments
    }

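    /// Returns an inclusive `(lo, hi)` index range, clamped to the original
    /// key array, that is expected to contain `key` if it was present at
    /// build time. The caller binary-searches that range.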
    pub fn search(&self, key: u64) -> (usize, usize) {
        let seg_index = if let Some(top) = &self.top_level {
            let i = match top.binary_search_by_key(&key, |seg| seg.end_key) {
                Ok(i) | Err(i) => i.min(top.len() - 1),
            };

            let top_seg = &top[i];
            let approx_index = (top_seg.slope * key as f64 + top_seg.intercept)
                .max(0.0)
                .round() as usize;
            approx_index.min(self.segments.len() - 1)
        } else {
            match self.segments.binary_search_by_key(&key, |seg| seg.end_key) {
                Ok(i) | Err(i) => i.min(self.segments.len() - 1),
            }
        };

        let seg = &self.segments[seg_index];
        let predicted = seg.slope * key as f64 + seg.intercept;
        let pos = predicted.max(0.0).round() as isize;

        let lo = pos.saturating_sub(self.epsilon as isize).max(0) as usize;
        let hi = (pos + self.epsilon as isize)
            .min(self.total_keys() as isize - 1)
            .max(0) as usize;

        (lo, hi)
    }

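    /// Serializes the index into an rkyv-aligned byte buffer, using an arena
    /// allocator for scratch space.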
    pub fn to_bytes(&self) -> Result<AlignedVec, Error> {
        let mut arena = Arena::new();
        to_bytes_with_alloc::<_, Error>(self, arena.acquire())
    }

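    /// Validates `bytes` and returns a zero-copy view of the archived index.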
    pub fn as_archived(bytes: &[u8]) -> Result<&rkyv::Archived<PGMIndex>, Error> {
        rkyv::access::<rkyv::Archived<PGMIndex>, Error>(bytes)
    }

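    /// Returns a zero-copy view of the archived index without validation.
    ///
    /// # Safety
    ///
    /// `bytes` must contain a valid, properly aligned archive of a
    /// `PGMIndex`, e.g. a buffer produced by [`PGMIndex::to_bytes`].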
    pub unsafe fn as_archived_unchecked(bytes: &[u8]) -> &rkyv::Archived<PGMIndex> {
        unsafe { access_unchecked::<rkyv::Archived<PGMIndex>>(bytes) }
    }

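    /// Validates `bytes` and copies the archive back into an owned `PGMIndex`.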
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
        let archived = rkyv::access::<rkyv::Archived<PGMIndex>, Error>(bytes)?;
        deserialize::<PGMIndex, Error>(archived)
    }

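    /// Number of keys the index was built over; the exclusive upper bound for
    /// any returned search range.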
    fn total_keys(&self) -> usize {
        self.num_keys
    }
}

impl ArchivedPGMIndex {
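    /// Same lookup as [`PGMIndex::search`], but operating directly on the
    /// zero-copy archived representation.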
    pub fn search(&self, key: u64) -> (usize, usize) {
        let segments: &[ArchivedSegment] = &self.segments;
        let epsilon = self.epsilon.to_native() as isize;

        let seg_index = if let Some(top) = self.top_level.as_ref().map(|v| &**v) {
            let i = match top.binary_search_by_key(&key, |seg| seg.end_key.to_native()) {
                Ok(i) | Err(i) => i.min(top.len() - 1),
            };
            let seg = &top[i];
            let estimate = (seg.slope * key as f64 + seg.intercept).max(0.0).round() as usize;
            estimate.min(segments.len().saturating_sub(1))
        } else {
            match segments.binary_search_by_key(&key, |seg| seg.end_key.to_native()) {
                Ok(i) | Err(i) => i.min(segments.len().saturating_sub(1)),
            }
        };

        let seg = &segments[seg_index];
        let predicted = (seg.slope * key as f64 + seg.intercept).max(0.0).round() as isize;

        let lo = predicted.saturating_sub(epsilon).max(0) as usize;
        let hi = (predicted + epsilon)
            .min(self.total_keys() as isize - 1)
            .max(0) as usize;

        (lo, hi)
    }

    fn total_keys(&self) -> usize {
        self.num_keys.to_native() as usize
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_build_and_search() {
        let keys: Vec<u64> = (0..1000).step_by(5).collect();
        let epsilon = 8;
        let pgm = PGMIndex::build(&keys, epsilon).unwrap();

        let key = 500;
        let (lo, hi) = pgm.search(key);
        assert!(
            keys[lo..=hi].binary_search(&key).is_ok(),
            "Key should be found within predicted range"
        );

        let key = 503;
        let (lo, hi) = pgm.search(key);
        assert!(
            keys[lo..=hi].binary_search(&key).is_err(),
            "Non-existent key should not be found, but range should still be valid"
        );
    }

    #[test]
    fn test_unsorted_input_fails() {
        let unsorted_keys = vec![1, 3, 2, 4];
        let result = PGMIndex::build(&unsorted_keys, 4);
        assert!(matches!(result, Err(PGMIndexError::KeysNotSorted)));
    }

    #[test]
    fn test_zero_copy_deserialization() {
        let keys: Vec<u64> = (0..5000).step_by(10).collect();
        let pgm = PGMIndex::build(&keys, 32).unwrap();
        let bytes = pgm.to_bytes().expect("serialize failed");

        let archived = PGMIndex::as_archived(&bytes).expect("zero-copy deserialize failed");
        let key = 1000;
        let (lo, hi) = archived.search(key);

        assert!(
            keys[lo..=hi].binary_search(&key).is_ok(),
            "Key should be in range after zero-copy read"
        );
    }

    #[test]
    fn test_copy_deserialization() {
        let keys: Vec<u64> = (0..10000).step_by(7).collect();
        let pgm = PGMIndex::build(&keys, 64).unwrap();
        let bytes = pgm.to_bytes().expect("serialize failed");

        let restored = PGMIndex::from_bytes(&bytes).expect("full deserialize failed");
        assert_eq!(restored.epsilon, pgm.epsilon);
        assert_eq!(restored.segments.len(), pgm.segments.len());

        // 9877 == 7 * 1411, so this key is present in the input set.
        let key = 9877;
        let (lo, hi) = restored.search(key);
        assert!(
            keys[lo..=hi].binary_search(&key).is_ok(),
            "Key should be found within predicted range after full deserialization"
        );
    }
}