// pgm_rs/pgm.rs

1use rkyv::{
2    access_unchecked, api::high::to_bytes_with_alloc, deserialize, rancor::Error,
3    ser::allocator::Arena, util::AlignedVec,
4};
5use rkyv_derive::{Archive, Deserialize, Serialize};
6
7#[derive(Archive, Deserialize, Serialize, Debug)]
8pub struct Segment {
9    pub slope: f64,
10    pub intercept: f64,
11    pub start_key: u64,
12    pub end_key: u64,
13}
14
15#[derive(Archive, Deserialize, Serialize, Debug)]
16pub struct PGMIndex {
17    pub segments: Vec<Segment>,
18    pub top_level: Option<Vec<Segment>>,
19    pub epsilon: usize,
20}
21
22use thiserror::Error;
23
24#[derive(Debug, Error)]
25pub enum PGMIndexError {
26    #[error("Keys are not sorted")]
27    KeysNotSorted,
28}
29
/// Returns early with `Err($err)` from the enclosing function when `$cond` is false.
///
/// `Err` is spelled with its full path so a local item named `Err` at the call
/// site cannot change what the expansion returns. A trailing comma is accepted.
macro_rules! ensure {
    ($cond:expr, $err:expr $(,)?) => {
        if !$cond {
            return ::core::result::Result::Err($err);
        }
    };
}

38impl PGMIndex {
39    pub fn build(keys: &[u64], epsilon: usize) -> Result<Self, PGMIndexError> {
40        ensure!(
41            keys.windows(2).all(|w| w[0] <= w[1]),
42            PGMIndexError::KeysNotSorted
43        );
44        PGMIndex::build_unsafe(keys, epsilon)
45    }
46
47    /// Build the index without safety checks for invariants the algorithm
48    /// relies for the algorithm to be accurate.
49    pub fn build_unsafe(keys: &[u64], epsilon: usize) -> Result<Self, PGMIndexError> {
50        let segments = Self::build_segments(keys, epsilon);
51        // Top-level input: start keys of each segment
52
53        let top_keys: Vec<u64> = segments.iter().map(|s| s.start_key).collect();
54
55        let top_level = if top_keys.len() > 2 {
56            Some(Self::build_segments(&top_keys, epsilon))
57        } else {
58            None
59        };
60
61        Ok(Self {
62            segments,
63            top_level,
64            epsilon,
65        })
66    }
67
68    fn build_segments(keys: &[u64], epsilon: usize) -> Vec<Segment> {
69        let epsilon = epsilon as f64;
70        let mut segments = Vec::new();
71
72        let mut start = 0;
73        let mut s_min = f64::NEG_INFINITY;
74        let mut s_max = f64::INFINITY;
75
76        for i in 1..keys.len() {
77            let x0 = keys[start] as f64;
78            let y0 = start as f64;
79            let xi = keys[i] as f64;
80            let yi = i as f64;
81
82            if (xi - x0).abs() < f64::EPSILON {
83                continue;
84            }
85
86            let new_s_min = ((yi - epsilon) - y0) / (xi - x0);
87            let new_s_max = ((yi + epsilon) - y0) / (xi - x0);
88            s_min = s_min.max(new_s_min);
89            s_max = s_max.min(new_s_max);
90
91            if s_min > s_max {
92                let x1 = keys[i - 1] as f64;
93                let y1 = (i - 1) as f64;
94                let slope = if (x1 - x0).abs() < f64::EPSILON {
95                    0.0
96                } else {
97                    (y1 - y0) / (x1 - x0)
98                };
99                let intercept = y0 - slope * x0;
100
101                segments.push(Segment {
102                    slope,
103                    intercept,
104                    start_key: keys[start],
105                    end_key: keys[i - 1],
106                });
107
108                start = i - 1;
109                s_min = f64::NEG_INFINITY;
110                s_max = f64::INFINITY;
111            }
112        }
113
114        let x0 = keys[start] as f64;
115        let x1 = keys[keys.len() - 1] as f64;
116        let y0 = start as f64;
117        let y1 = (keys.len() - 1) as f64;
118        let slope = if (x1 - x0).abs() < f64::EPSILON {
119            0.0
120        } else {
121            (y1 - y0) / (x1 - x0)
122        };
123        let intercept = y0 - slope * x0;
124
125        segments.push(Segment {
126            slope,
127            intercept,
128            start_key: keys[start],
129            end_key: keys[keys.len() - 1],
130        });
131
132        segments
133    }
134
135    /// Returns the index range [lo, hi] where `key` may appear.
136    /// This range is guaranteed to contain the key "if" it exists.
137    pub fn search(&self, key: u64) -> (usize, usize) {
138        let seg_index = if let Some(top) = &self.top_level {
139            let i = match top.binary_search_by_key(&key, |seg| seg.end_key) {
140                Ok(i) | Err(i) => i.min(top.len() - 1),
141            };
142
143            let top_seg = &top[i];
144            let approx_index = (top_seg.slope * key as f64 + top_seg.intercept)
145                .max(0.0)
146                .round() as usize;
147            approx_index.min(self.segments.len() - 1)
148        } else {
149            match self.segments.binary_search_by_key(&key, |seg| seg.end_key) {
150                Ok(i) | Err(i) => i.min(self.segments.len() - 1),
151            }
152        };
153
154        let seg = &self.segments[seg_index];
155        let predicted = seg.slope * key as f64 + seg.intercept;
156        let pos = predicted.max(0.0).round() as isize;
157
158        let lo = pos.saturating_sub(self.epsilon as isize).max(0) as usize;
159        let hi = (pos + self.epsilon as isize)
160            .min(self.total_keys() as isize - 1)
161            .max(0) as usize;
162
163        (lo, hi)
164    }
165
166    pub fn to_bytes(&self) -> Result<AlignedVec, Error> {
167        let mut arena = Arena::new();
168        to_bytes_with_alloc::<_, Error>(self, arena.acquire())
169    }
170
171    /// Provides zero-copy access to the archived form.
172    /// Lifetime is tied to `bytes`.
173    pub fn as_archived(bytes: &[u8]) -> Result<&rkyv::Archived<PGMIndex>, Error> {
174        rkyv::access::<rkyv::Archived<PGMIndex>, Error>(bytes)
175    }
176
177    /// Unsafely access the archived index without bounds or validation.
178    /// Use only when buffer is known to be valid.
179    pub unsafe fn as_archived_unchecked(bytes: &[u8]) -> &rkyv::Archived<PGMIndex> {
180        unsafe { access_unchecked::<rkyv::Archived<PGMIndex>>(bytes) }
181    }
182
183    /// Deserialize from archived bytes back into a heap-owned PGMIndex.
184    pub fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
185        let archived = rkyv::access::<rkyv::Archived<PGMIndex>, Error>(bytes)?;
186        deserialize::<PGMIndex, Error>(archived)
187    }
188
189    fn total_keys(&self) -> usize {
190        self.segments.last().map(|s| s.end_key).unwrap_or(0) as usize + 1
191    }
192}
193
194impl ArchivedPGMIndex {
195    /// Returns the index range [lo, hi] where `key` may appear.
196    /// This range is guaranteed to contain the key "if" it exists.
197    pub fn search(&self, key: u64) -> (usize, usize) {
198        let segments: &[ArchivedSegment] = &self.segments;
199        let epsilon = self.epsilon.to_native() as isize;
200
201        // Handle Archived<Option<Vec<T>>> as Option<&[T]>
202        let seg_index = if let Some(top) = self.top_level.as_ref().map(|v| &**v) {
203            let i = match top.binary_search_by_key(&key, |seg| seg.end_key.to_native()) {
204                Ok(i) | Err(i) => i.min(top.len() - 1),
205            };
206            let seg = &top[i];
207            let estimate = (seg.slope * key as f64 + seg.intercept).max(0.0).round() as usize;
208            estimate.min(segments.len().saturating_sub(1))
209        } else {
210            match segments.binary_search_by_key(&key, |seg| seg.end_key.to_native()) {
211                Ok(i) | Err(i) => i.min(segments.len().saturating_sub(1)),
212            }
213        };
214
215        let seg = &segments[seg_index];
216        let predicted = (seg.slope * key as f64 + seg.intercept).max(0.0).round() as isize;
217
218        // TODO: safely support conversion from little endian types to native
219        let lo = predicted.saturating_sub(epsilon).max(0) as usize;
220        let hi = (predicted + epsilon)
221            .min(self.total_keys() as isize - 1)
222            .max(0) as usize;
223
224        (lo, hi)
225    }
226
227    fn total_keys(&self) -> usize {
228        self.segments
229            .last()
230            .map(|s| s.end_key.to_native())
231            .unwrap_or(0) as usize
232            + 1
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Keys with growing gaps, so the build emits many segments and a top level.
    fn quadratic_keys() -> Vec<u64> {
        (0..2000u64).map(|i| i * i).collect()
    }

    /// Asserts that `search` returns an in-bounds range containing every key of `keys`.
    fn assert_every_key_found(keys: &[u64], search: impl Fn(u64) -> (usize, usize)) {
        for (pos, &key) in keys.iter().enumerate() {
            let (lo, hi) = search(key);
            assert!(
                lo <= hi && hi < keys.len(),
                "range [{lo}, {hi}] is out of bounds for key {key} ({} keys)",
                keys.len()
            );
            assert!(
                keys[lo..=hi].contains(&key),
                "key {key} (position {pos}) is not in predicted range [{lo}, {hi}]"
            );
        }
    }

    #[test]
    fn test_build_and_search() {
        let keys: Vec<u64> = (0..1000).step_by(5).collect();
        let epsilon = 8;
        let pgm = PGMIndex::build(&keys, epsilon).unwrap();

        let key = 500;
        let (lo, hi) = pgm.search(key);
        assert!(
            keys[lo..=hi].binary_search(&key).is_ok(),
            "Key should be found within predicted range"
        );

        let key = 503;
        let (lo, hi) = pgm.search(key);
        assert!(
            keys[lo..=hi].binary_search(&key).is_err(),
            "Non-existent key should not be found, but range should still be valid"
        );
    }

    #[test]
    fn test_every_key_found_linear() {
        let keys: Vec<u64> = (0..1000).step_by(5).collect();
        let pgm = PGMIndex::build(&keys, 8).unwrap();
        assert_every_key_found(&keys, |k| pgm.search(k));
    }

    #[test]
    fn test_every_key_found_with_top_level() {
        let keys = quadratic_keys();
        let pgm = PGMIndex::build(&keys, 4).unwrap();
        assert!(pgm.segments.len() > 2, "expected several segments");
        assert!(pgm.top_level.is_some(), "expected a top level");
        assert_every_key_found(&keys, |k| pgm.search(k));
    }

    #[test]
    fn test_duplicate_keys() {
        let keys: Vec<u64> = vec![1, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 10, 11];
        let pgm = PGMIndex::build(&keys, 2).unwrap();
        assert_every_key_found(&keys, |k| pgm.search(k));
    }

    #[test]
    fn test_single_key() {
        let pgm = PGMIndex::build(&[42], 4).unwrap();
        assert_eq!(pgm.search(42), (0, 0));
    }

    #[test]
    fn test_out_of_range_keys_stay_in_bounds() {
        let keys = quadratic_keys();
        let pgm = PGMIndex::build(&keys, 4).unwrap();
        for key in [keys[keys.len() - 1] + 1, u64::MAX] {
            let (lo, hi) = pgm.search(key);
            assert!(lo <= hi && hi < keys.len(), "range [{lo}, {hi}] out of bounds");
            assert!(!keys[lo..=hi].contains(&key));
        }
    }

    #[test]
    fn test_unsorted_input_fails() {
        let unsorted_keys = vec![1, 3, 2, 4];
        let result = PGMIndex::build(&unsorted_keys, 4);
        assert!(matches!(result, Err(PGMIndexError::KeysNotSorted)));
    }

    #[test]
    fn test_zero_copy_deserialization() {
        let keys: Vec<u64> = (0..5000).step_by(10).collect();
        let pgm = PGMIndex::build(&keys, 32).unwrap();
        let bytes = pgm.to_bytes().expect("serialize failed");

        let archived = PGMIndex::as_archived(&bytes).expect("zero-copy deserialize failed");
        let key = 1000;
        let (lo, hi) = archived.search(key);

        assert!(
            keys[lo..=hi].binary_search(&key).is_ok(),
            "Key should be in range after zero-copy read"
        );
    }

    #[test]
    fn test_zero_copy_matches_owned_search() {
        let keys = quadratic_keys();
        let pgm = PGMIndex::build(&keys, 4).unwrap();
        let bytes = pgm.to_bytes().expect("serialize failed");
        let archived = PGMIndex::as_archived(&bytes).expect("zero-copy access failed");

        assert_every_key_found(&keys, |k| archived.search(k));
        for key in [0, 5, keys[keys.len() - 1] + 1, u64::MAX] {
            assert_eq!(archived.search(key), pgm.search(key), "mismatch for key {key}");
        }
    }

    #[test]
    fn test_copy_deserialization() {
        let keys: Vec<u64> = (0..10000).step_by(7).collect();
        let pgm = PGMIndex::build(&keys, 64).unwrap();
        let bytes = pgm.to_bytes().expect("serialize failed");

        let restored = PGMIndex::from_bytes(&bytes).expect("full deserialize failed");
        assert_eq!(restored.epsilon, pgm.epsilon);
        assert_eq!(restored.segments.len(), pgm.segments.len());

        // 9876 is not a multiple of 7, so it is absent from `keys`.
        let key = 9876;
        let (lo, hi) = restored.search(key);
        assert_eq!((lo, hi), pgm.search(key), "restored index must answer like the original");
        assert!(hi < keys.len(), "range [{lo}, {hi}] out of bounds");
        assert_eq!(
            keys[lo..=hi].binary_search(&key).is_ok(),
            keys.binary_search(&key).is_ok(),
            "range must contain the key exactly when the key exists"
        );
    }
}