Skip to main content

sochdb_query/
candidate_gate.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Unified Candidate Gate Interface (Task 4)
19//!
20//! This module defines the `AllowedSet` abstraction that every retrieval
21//! executor MUST accept. The gate guarantees:
22//!
23//! 1. **Never return a doc outside AllowedSet** - structural enforcement
24//! 2. **Apply constraints during generation** - no post-filtering
25//! 3. **Consistent semantics** across vector/BM25/hybrid/context
26//!
27//! ## The Contract
28//!
29//! Every executor receives an `AllowedSet` and must:
30//! - Check membership BEFORE including any candidate
31//! - Short-circuit if AllowedSet is empty (return empty results)
32//! - Report selectivity for query planning
33//!
34//! ## Representations
35//!
36//! `AllowedSet` supports multiple representations for efficiency:
37//!
//! | Representation | Best For      | Membership | Space                   |
//! |----------------|---------------|------------|-------------------------|
//! | Bitmap         | Dense sets    | O(1)       | O(N/8) bytes, N = max ID |
//! | SortedVec      | Sparse sets   | O(log n)   | O(n), n = set size      |
//! | HashSet        | Random access | O(1) avg   | O(n)                    |
//! | All            | No constraint | O(1)       | O(1)                    |
//! | None           | Empty result  | O(1)       | O(1)                    |
44//!
45//! ## Selectivity
46//!
//! Executors use selectivity `|S|/N` (allowed documents over corpus size) to
//! choose an execution strategy. The default thresholds applied by
//! `CandidateGate::strategy_for_selectivity` are:
//! - `>= 0.1`: standard search, checking the filter during traversal
//! - `>= 0.001` and `< 0.1`: scan only the allowed IDs
//! - `< 0.001`: fall back to a linear scan
51
52use std::collections::HashSet;
53use std::fmt;
54use std::sync::Arc;
55
56// ============================================================================
57// AllowedSet - Core Abstraction
58// ============================================================================
59
60/// The unified gate for candidate filtering
61///
62/// Every executor MUST check `allowed_set.contains(doc_id)` before
63/// including any result. This is the structural enforcement of pushdown.
64#[derive(Clone)]
65pub enum AllowedSet {
66    /// All documents are allowed (no filter constraint)
67    All,
68    
69    /// Bitmap representation (efficient for dense sets)
70    Bitmap(Arc<AllowedBitmap>),
71    
72    /// Sorted vector (efficient for sparse sets with iteration)
73    SortedVec(Arc<Vec<u64>>),
74    
75    /// Hash set (efficient for random access)
76    HashSet(Arc<HashSet<u64>>),
77    
78    /// No documents allowed (empty result shortcut)
79    None,
80}
81
82impl AllowedSet {
83    /// Create an AllowedSet from a bitmap
84    pub fn from_bitmap(bitmap: AllowedBitmap) -> Self {
85        if bitmap.is_empty() {
86            Self::None
87        } else if bitmap.is_all() {
88            Self::All
89        } else {
90            Self::Bitmap(Arc::new(bitmap))
91        }
92    }
93    
94    /// Create an AllowedSet from a sorted vector of doc IDs
95    pub fn from_sorted_vec(mut ids: Vec<u64>) -> Self {
96        if ids.is_empty() {
97            return Self::None;
98        }
99        ids.sort_unstable();
100        ids.dedup();
101        Self::SortedVec(Arc::new(ids))
102    }
103    
104    /// Create an AllowedSet from an iterator of doc IDs
105    pub fn from_iter(ids: impl IntoIterator<Item = u64>) -> Self {
106        let set: HashSet<u64> = ids.into_iter().collect();
107        if set.is_empty() {
108            Self::None
109        } else {
110            Self::HashSet(Arc::new(set))
111        }
112    }
113    
114    /// Check if a document ID is allowed
115    ///
116    /// This is the core operation that executors MUST call.
117    #[inline]
118    pub fn contains(&self, doc_id: u64) -> bool {
119        match self {
120            Self::All => true,
121            Self::Bitmap(bm) => bm.contains(doc_id),
122            Self::SortedVec(vec) => vec.binary_search(&doc_id).is_ok(),
123            Self::HashSet(set) => set.contains(&doc_id),
124            Self::None => false,
125        }
126    }
127    
128    /// Check if this set is empty (no allowed documents)
129    pub fn is_empty(&self) -> bool {
130        matches!(self, Self::None)
131    }
132    
133    /// Check if this set allows all documents
134    pub fn is_all(&self) -> bool {
135        matches!(self, Self::All)
136    }
137    
138    /// Get the cardinality (number of allowed documents)
139    ///
140    /// Returns None for All (unknown without universe size)
141    pub fn cardinality(&self) -> Option<usize> {
142        match self {
143            Self::All => None,
144            Self::Bitmap(bm) => Some(bm.count()),
145            Self::SortedVec(vec) => Some(vec.len()),
146            Self::HashSet(set) => Some(set.len()),
147            Self::None => Some(0),
148        }
149    }
150    
151    /// Compute selectivity against a universe of size N
152    ///
153    /// Returns |S| / N, the fraction of allowed documents
154    pub fn selectivity(&self, universe_size: usize) -> f64 {
155        if universe_size == 0 {
156            return 0.0;
157        }
158        match self {
159            Self::All => 1.0,
160            Self::None => 0.0,
161            other => {
162                other.cardinality()
163                    .map(|c| c as f64 / universe_size as f64)
164                    .unwrap_or(1.0)
165            }
166        }
167    }
168    
169    /// Intersect with another AllowedSet
170    pub fn intersect(&self, other: &AllowedSet) -> AllowedSet {
171        match (self, other) {
172            // Identity cases
173            (Self::All, x) | (x, Self::All) => x.clone(),
174            (Self::None, _) | (_, Self::None) => Self::None,
175            
176            // Both are sets - compute intersection
177            (Self::SortedVec(a), Self::SortedVec(b)) => {
178                let result = sorted_vec_intersect(a, b);
179                Self::from_sorted_vec(result)
180            }
181            (Self::HashSet(a), Self::HashSet(b)) => {
182                let result: HashSet<_> = a.intersection(b).copied().collect();
183                if result.is_empty() {
184                    Self::None
185                } else {
186                    Self::HashSet(Arc::new(result))
187                }
188            }
189            (Self::Bitmap(a), Self::Bitmap(b)) => {
190                let result = a.intersect(b);
191                Self::from_bitmap(result)
192            }
193            
194            // Mixed - convert to hash set
195            (a, b) => {
196                let set_a: HashSet<u64> = a.iter().collect();
197                let set_b: HashSet<u64> = b.iter().collect();
198                let result: HashSet<_> = set_a.intersection(&set_b).copied().collect();
199                if result.is_empty() {
200                    Self::None
201                } else {
202                    Self::HashSet(Arc::new(result))
203                }
204            }
205        }
206    }
207    
208    /// Union with another AllowedSet
209    pub fn union(&self, other: &AllowedSet) -> AllowedSet {
210        match (self, other) {
211            (Self::All, _) | (_, Self::All) => Self::All,
212            (Self::None, x) | (x, Self::None) => x.clone(),
213            
214            (Self::HashSet(a), Self::HashSet(b)) => {
215                let result: HashSet<_> = a.union(b).copied().collect();
216                Self::HashSet(Arc::new(result))
217            }
218            
219            // Mixed - convert to hash set
220            (a, b) => {
221                let mut result: HashSet<u64> = a.iter().collect();
222                result.extend(b.iter());
223                Self::HashSet(Arc::new(result))
224            }
225        }
226    }
227    
228    /// Iterate over allowed document IDs
229    ///
230    /// Note: For All, this returns an empty iterator (unknown universe)
231    pub fn iter(&self) -> AllowedSetIter<'_> {
232        match self {
233            Self::All => AllowedSetIter::Empty,
234            Self::Bitmap(bm) => AllowedSetIter::Bitmap(bm.iter()),
235            Self::SortedVec(vec) => AllowedSetIter::SortedVec(vec.iter()),
236            Self::HashSet(set) => AllowedSetIter::HashSet(set.iter()),
237            Self::None => AllowedSetIter::Empty,
238        }
239    }
240    
241    /// Convert to a Vec (for small sets)
242    pub fn to_vec(&self) -> Vec<u64> {
243        self.iter().collect()
244    }
245}
246
247impl fmt::Debug for AllowedSet {
248    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249        match self {
250            Self::All => write!(f, "AllowedSet::All"),
251            Self::None => write!(f, "AllowedSet::None"),
252            Self::Bitmap(bm) => write!(f, "AllowedSet::Bitmap(count={})", bm.count()),
253            Self::SortedVec(vec) => write!(f, "AllowedSet::SortedVec(len={})", vec.len()),
254            Self::HashSet(set) => write!(f, "AllowedSet::HashSet(len={})", set.len()),
255        }
256    }
257}
258
259impl Default for AllowedSet {
260    fn default() -> Self {
261        Self::All
262    }
263}
264
/// Merge-intersect two ascending, duplicate-free slices.
///
/// Keeps one cursor per slice. It advances whichever cursor points at the
/// smaller value and emits values found in both. Runs in O(|a| + |b|).
fn sorted_vec_intersect(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::cmp::Ordering;

    let mut common = Vec::with_capacity(a.len().min(b.len()));
    let (mut left, mut right) = (0usize, 0usize);

    while let (Some(&x), Some(&y)) = (a.get(left), b.get(right)) {
        match x.cmp(&y) {
            Ordering::Less => left += 1,
            Ordering::Greater => right += 1,
            Ordering::Equal => {
                common.push(x);
                left += 1;
                right += 1;
            }
        }
    }

    common
}
285
286// ============================================================================
287// AllowedSet Iterator
288// ============================================================================
289
290/// Iterator over allowed document IDs
291pub enum AllowedSetIter<'a> {
292    Empty,
293    Bitmap(BitmapIter<'a>),
294    SortedVec(std::slice::Iter<'a, u64>),
295    HashSet(std::collections::hash_set::Iter<'a, u64>),
296}
297
298impl<'a> Iterator for AllowedSetIter<'a> {
299    type Item = u64;
300    
301    fn next(&mut self) -> Option<Self::Item> {
302        match self {
303            Self::Empty => None,
304            Self::Bitmap(iter) => iter.next(),
305            Self::SortedVec(iter) => iter.next().copied(),
306            Self::HashSet(iter) => iter.next().copied(),
307        }
308    }
309    
310    fn size_hint(&self) -> (usize, Option<usize>) {
311        match self {
312            Self::Empty => (0, Some(0)),
313            Self::Bitmap(iter) => iter.size_hint(),
314            Self::SortedVec(iter) => iter.size_hint(),
315            Self::HashSet(iter) => iter.size_hint(),
316        }
317    }
318}
319
320// ============================================================================
321// Bitmap Implementation
322// ============================================================================
323
/// Simple bitmap for allowed document IDs.
///
/// Document `id` is a member when bit `id % 64` of word `id / 64` is set.
/// The number of set bits is cached in `count`, so [`AllowedBitmap::count`] is
/// O(1). Every constructor and mutator keeps that cache exact.
///
/// This is a basic implementation. For production, consider using
/// the `roaring` crate for compressed bitmaps.
#[derive(Clone, Debug)]
pub struct AllowedBitmap {
    /// Bits stored as u64 words
    words: Vec<u64>,
    /// Total number of set bits (cached; always the popcount of `words`)
    count: usize,
    /// Whether this represents "all" (complement mode)
    all: bool,
}

impl AllowedBitmap {
    /// Create a new empty bitmap.
    pub fn new() -> Self {
        Self {
            words: Vec::new(),
            count: 0,
            all: false,
        }
    }

    /// Create a bitmap with every ID in `0..=max_id` set, flagged as "all".
    ///
    /// Bits above `max_id` in the last word are left clear, so `contains`,
    /// `iter` and `count` all describe exactly `max_id + 1` members.
    pub fn all(max_id: u64) -> Self {
        let mut words = vec![u64::MAX; Self::word_index(max_id) + 1];
        // Number of valid bits in the last word, in 1..=64.
        let used_bits = max_id % 64 + 1;
        if used_bits < 64 {
            if let Some(last) = words.last_mut() {
                *last = (1u64 << used_bits) - 1;
            }
        }
        Self {
            words,
            count: max_id as usize + 1,
            all: true,
        }
    }

    /// Create a bitmap from a slice of IDs.
    ///
    /// Duplicate IDs are accepted and counted once.
    pub fn from_ids(ids: &[u64]) -> Self {
        let mut bitmap = Self::new();
        if let Some(&max_id) = ids.iter().max() {
            // Size the storage once instead of growing it inside `set`.
            bitmap.words = vec![0; Self::word_index(max_id) + 1];
            for &id in ids {
                bitmap.set(id);
            }
        }
        bitmap
    }

    /// Set a bit, growing the storage if needed. Setting an existing member
    /// is a no-op.
    pub fn set(&mut self, id: u64) {
        let word_idx = Self::word_index(id);
        let mask = Self::bit_mask(id);
        if word_idx >= self.words.len() {
            self.words.resize(word_idx + 1, 0);
        }
        let word = &mut self.words[word_idx];
        if (*word & mask) == 0 {
            *word |= mask;
            self.count += 1;
        }
    }

    /// Clear a bit. Clearing an ID that is not set is a no-op.
    ///
    /// Removing a member also drops the "all" flag, because the bitmap no
    /// longer covers its whole range.
    pub fn clear(&mut self, id: u64) {
        let mask = Self::bit_mask(id);
        if let Some(word) = self.words.get_mut(Self::word_index(id)) {
            if (*word & mask) != 0 {
                *word &= !mask;
                self.count -= 1;
                self.all = false;
            }
        }
    }

    /// Check if a bit is set.
    #[inline]
    pub fn contains(&self, id: u64) -> bool {
        self.words
            .get(Self::word_index(id))
            .map_or(false, |&word| (word & Self::bit_mask(id)) != 0)
    }

    /// Get the count of set bits.
    pub fn count(&self) -> usize {
        self.count
    }

    /// Check if empty.
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Check whether this bitmap is flagged as "all".
    pub fn is_all(&self) -> bool {
        self.all
    }

    /// Intersect with another bitmap. The result is never flagged as "all".
    pub fn intersect(&self, other: &AllowedBitmap) -> AllowedBitmap {
        // Words beyond the shorter bitmap would intersect to zero, so stop there.
        let words: Vec<u64> = self
            .words
            .iter()
            .zip(&other.words)
            .map(|(a, b)| a & b)
            .collect();
        let count = Self::popcount(&words);
        AllowedBitmap {
            words,
            count,
            all: false,
        }
    }

    /// Union with another bitmap. The result is never flagged as "all".
    pub fn union(&self, other: &AllowedBitmap) -> AllowedBitmap {
        let (longer, shorter) = if self.words.len() >= other.words.len() {
            (&self.words, &other.words)
        } else {
            (&other.words, &self.words)
        };
        let mut words = longer.clone();
        for (word, extra) in words.iter_mut().zip(shorter) {
            *word |= *extra;
        }
        let count = Self::popcount(&words);
        AllowedBitmap {
            words,
            count,
            all: false,
        }
    }

    /// Iterate over set bit positions in ascending order.
    pub fn iter(&self) -> BitmapIter<'_> {
        BitmapIter::new(&self.words, self.count)
    }

    /// Iterate over set bit positions in ascending order.
    ///
    /// This is a straightforward per-bit scan. It yields the same IDs as
    /// [`AllowedBitmap::iter`] and serves as a simple reference implementation.
    pub fn iter_simple(&self) -> impl Iterator<Item = u64> + '_ {
        self.words.iter().enumerate().flat_map(|(word_idx, &word)| {
            (0..64u64)
                .filter(move |&bit| (word & (1u64 << bit)) != 0)
                .map(move |bit| word_idx as u64 * 64 + bit)
        })
    }

    /// Index of the word holding `id`.
    ///
    /// NOTE: the `as usize` cast truncates on 32-bit targets for IDs >= 2^38.
    /// IDs are assumed to fit the address space.
    #[inline]
    fn word_index(id: u64) -> usize {
        (id / 64) as usize
    }

    /// Mask selecting `id`'s bit within its word.
    #[inline]
    fn bit_mask(id: u64) -> u64 {
        1u64 << (id % 64)
    }

    /// Total number of set bits across `words`.
    fn popcount(words: &[u64]) -> usize {
        words.iter().map(|w| w.count_ones() as usize).sum()
    }
}

impl Default for AllowedBitmap {
    fn default() -> Self {
        Self::new()
    }
}

/// Iterator over set bits in a bitmap, in ascending ID order.
#[derive(Clone, Debug)]
pub struct BitmapIter<'a> {
    /// All words of the bitmap being iterated.
    words: &'a [u64],
    /// Index of the word that `current` was loaded from.
    word_idx: usize,
    /// Bits of word `word_idx` that have not been yielded yet.
    current: u64,
    /// Set bits not yet yielded.
    remaining: usize,
}

impl<'a> BitmapIter<'a> {
    /// Start iterating `words`, which must contain exactly `count` set bits.
    fn new(words: &'a [u64], count: usize) -> Self {
        Self {
            words,
            word_idx: 0,
            current: words.first().copied().unwrap_or(0),
            remaining: count,
        }
    }
}

impl<'a> Iterator for BitmapIter<'a> {
    type Item = u64;

    fn next(&mut self) -> Option<Self::Item> {
        // `remaining` is exact, so trailing all-zero words are never scanned.
        if self.remaining == 0 {
            return None;
        }
        while self.current == 0 {
            self.word_idx += 1;
            self.current = *self.words.get(self.word_idx)?;
        }
        let bit = u64::from(self.current.trailing_zeros());
        // Clear the lowest set bit so the next call finds the following one.
        self.current &= self.current - 1;
        self.remaining -= 1;
        Some(self.word_idx as u64 * 64 + bit)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}
569
570// ============================================================================
571// Candidate Gate Trait
572// ============================================================================
573
574/// The candidate gate trait that all executors must implement
575///
576/// This trait ensures every retrieval path respects the AllowedSet.
577pub trait CandidateGate {
578    /// The query type
579    type Query;
580    
581    /// The result type  
582    type Result;
583    
584    /// The error type
585    type Error;
586    
587    /// Execute with a mandatory allowed set
588    ///
589    /// # Contract
590    ///
591    /// - MUST NOT return any result with doc_id not in allowed_set
592    /// - SHOULD short-circuit if allowed_set is empty
593    /// - SHOULD use selectivity to choose execution strategy
594    fn execute_with_gate(
595        &self,
596        query: &Self::Query,
597        allowed_set: &AllowedSet,
598    ) -> Result<Self::Result, Self::Error>;
599    
600    /// Get the execution strategy for a given selectivity
601    fn strategy_for_selectivity(&self, selectivity: f64) -> ExecutionStrategy {
602        if selectivity >= 0.1 {
603            ExecutionStrategy::FilterDuringSearch
604        } else if selectivity >= 0.001 {
605            ExecutionStrategy::ScanAllowedIds
606        } else {
607            ExecutionStrategy::LinearScan
608        }
609    }
610}
611
/// Plan an executor chooses once it knows how selective its gate is.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExecutionStrategy {
    /// Run the normal search and test each candidate against the gate while
    /// traversing.
    FilterDuringSearch,

    /// Visit only the allowed IDs and compute distances for each of them.
    ScanAllowedIds,

    /// Fall back to a linear scan (very low selectivity).
    LinearScan,

    /// Decline to execute because the plan would be too expensive.
    Reject,
}
627
628// ============================================================================
629// Tests
630// ============================================================================
631
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_allowed_set_contains() {
        // `All` admits any ID; `None` admits nothing.
        let everything = AllowedSet::All;
        for id in [0, 1_000_000] {
            assert!(everything.contains(id));
        }
        let nothing = AllowedSet::None;
        assert!(!nothing.contains(0));

        // Sorted-vector representation.
        let sorted = AllowedSet::from_sorted_vec(vec![1, 3, 5, 7, 9]);
        for id in [1, 5] {
            assert!(sorted.contains(id));
        }
        for id in [2, 10] {
            assert!(!sorted.contains(id));
        }

        // Hash-set representation.
        let hashed = AllowedSet::from_iter([1, 3, 5, 7, 9]);
        for id in [1, 5] {
            assert!(hashed.contains(id));
        }
        assert!(!hashed.contains(2));
    }

    #[test]
    fn test_allowed_set_selectivity() {
        let five = AllowedSet::from_sorted_vec(vec![1, 2, 3, 4, 5]);
        for (universe, expected) in [(100, 0.05), (10, 0.5)] {
            assert_eq!(five.selectivity(universe), expected);
        }

        assert_eq!(AllowedSet::All.selectivity(100), 1.0);
        assert_eq!(AllowedSet::None.selectivity(100), 0.0);
    }

    #[test]
    fn test_allowed_set_intersection() {
        let left = AllowedSet::from_sorted_vec(vec![1, 2, 3, 4, 5]);
        let right = AllowedSet::from_sorted_vec(vec![3, 4, 5, 6, 7]);
        let both = left.intersect(&right);

        assert_eq!(both.cardinality(), Some(3));
        for id in [3, 4, 5] {
            assert!(both.contains(id));
        }
        for id in [1, 7] {
            assert!(!both.contains(id));
        }
    }

    #[test]
    fn test_bitmap_basic() {
        let mut bm = AllowedBitmap::new();
        let members = [0, 5, 64, 100];
        for id in members {
            bm.set(id);
        }

        for id in members {
            assert!(bm.contains(id));
        }
        for id in [1, 63] {
            assert!(!bm.contains(id));
        }

        assert_eq!(bm.count(), 4);
    }

    #[test]
    fn test_bitmap_from_ids() {
        let ids: [u64; 5] = [1, 5, 10, 100, 1000];
        let bm = AllowedBitmap::from_ids(&ids);

        for id in ids {
            assert!(bm.contains(id));
        }
        for id in [0, 50] {
            assert!(!bm.contains(id));
        }
    }

    #[test]
    fn test_bitmap_intersection() {
        let left = AllowedBitmap::from_ids(&[1, 2, 3, 4, 5]);
        let right = AllowedBitmap::from_ids(&[3, 4, 5, 6, 7]);
        let both = left.intersect(&right);

        assert_eq!(both.count(), 3);
        for id in [3, 4, 5] {
            assert!(both.contains(id));
        }
    }

    #[test]
    fn test_execution_strategy() {
        struct DummyGate;

        impl CandidateGate for DummyGate {
            type Query = ();
            type Result = ();
            type Error = ();

            fn execute_with_gate(&self, _query: &(), _allowed: &AllowedSet) -> Result<(), ()> {
                Ok(())
            }
        }

        let gate = DummyGate;
        let expectations = [
            (0.5, ExecutionStrategy::FilterDuringSearch),
            (0.01, ExecutionStrategy::ScanAllowedIds),
            (0.0001, ExecutionStrategy::LinearScan),
        ];
        for (selectivity, expected) in expectations {
            assert_eq!(gate.strategy_for_selectivity(selectivity), expected);
        }
    }
}