Skip to main content

selene_graph/
text_search.rs

1//! Exact BM25 full-text search over graph node properties.
2//!
3//! This module is the full-text correctness oracle: it scans the current graph
4//! snapshot, tokenizes string properties, computes BM25 document statistics for
5//! the requested `(label, property)` surface, and returns a deterministic
6//! top-`k` ranking. Future maintained or postings-backed text indexes should
7//! use this path as their ordering and recall reference.
8
9use std::borrow::Cow;
10use std::cmp::Ordering;
11use std::collections::{BTreeSet, BinaryHeap};
12use std::time::Duration;
13
14use roaring::RoaringBitmap;
15use selene_core::{CancellationCause, CancellationChecker, DbString, NodeId, Value};
16use smallvec::SmallVec;
17
18use crate::error::{GraphError, GraphResult};
19use crate::graph::SeleneGraph;
20use crate::parallel_scan::{should_parallelize_scan, try_reduce_bitmap_chunks};
21use crate::shared::SharedGraph;
22use crate::store::RowIndex;
23
/// Number of rows the serial scan processes between calls to
/// `CancellationChecker::note_nodes_scanned`.
pub(crate) const TEXT_SEARCH_CANCEL_STRIDE: usize = 1024;
/// Chunk-size argument handed to `try_reduce_bitmap_chunks` by the parallel scan.
#[cfg(not(test))]
const TEXT_SEARCH_PARALLEL_CHUNK_ROWS: usize = 2048;
/// Test-build override of the parallel chunk size (presumably tiny so small
/// fixtures still span several chunks; confirm against the tests).
#[cfg(test)]
const TEXT_SEARCH_PARALLEL_CHUNK_ROWS: usize = 4;

/// Row-count threshold passed to `should_parallelize_scan` when choosing between
/// the serial and the parallel scan.
#[cfg(not(test))]
const TEXT_SEARCH_PARALLEL_MIN_ROWS: u64 = 16_384;
/// Test-build override of the parallel row threshold.
#[cfg(test)]
const TEXT_SEARCH_PARALLEL_MIN_ROWS: u64 = 8;
/// BM25 `k1` parameter used by `bm25_score`.
const BM25_K1: f64 = 1.2;
/// BM25 `b` parameter used by `bm25_score` for the document-length term.
const BM25_B: f64 = 0.75;

/// Per-document occurrence counts, one slot per query term in query-term order.
type TermCounts = SmallVec<[u32; 4]>;
38
/// One BM25-ranked node hit.
///
/// Result lists produced by this module are ordered by descending `score`,
/// with ties broken by ascending `node_id` (see `compare_hit`).
#[derive(Clone, Debug, PartialEq)]
pub struct TextSearchHit {
    /// Matched node id.
    pub node_id: NodeId,
    /// Higher-is-better BM25 score.
    pub score: f64,
}
47
/// Error returned by checked text-search APIs.
///
/// Graph failures convert in through the `#[from]` on [`TextSearchError::Graph`];
/// the remaining variants mirror `CancellationCause` one-to-one via the
/// `From<CancellationCause>` impl in this module.
#[derive(Debug, thiserror::Error)]
pub enum TextSearchError {
    /// Graph storage or consistency failure.
    #[error(transparent)]
    Graph(#[from] GraphError),
    /// Caller requested cooperative cancellation.
    #[error("text search cancelled")]
    Cancelled,
    /// Statement deadline elapsed.
    #[error("text search timed out after {elapsed:?}")]
    Timeout {
        /// Wall-clock duration since the deadline elapsed.
        elapsed: Duration,
    },
    /// Deterministic node-scan budget was exceeded.
    #[error("text search node scan budget exceeded ({scanned} > {limit})")]
    NodeScanBudgetExceeded {
        /// Maximum allowed scanned nodes.
        limit: usize,
        /// Observed scanned nodes after the batch that crossed the limit.
        scanned: usize,
    },
}
72
73impl TextSearchError {
74    fn into_graph_error(self) -> GraphError {
75        match self {
76            Self::Graph(error) => error,
77            Self::Cancelled | Self::Timeout { .. } | Self::NodeScanBudgetExceeded { .. } => {
78                GraphError::Inconsistent {
79                    reason: format!("disabled text-search checker returned {self}"),
80                }
81            }
82        }
83    }
84}
85
86impl From<CancellationCause> for TextSearchError {
87    fn from(cause: CancellationCause) -> Self {
88        match cause {
89            CancellationCause::Cancelled => Self::Cancelled,
90            CancellationCause::Timeout { elapsed } => Self::Timeout { elapsed },
91            CancellationCause::NodeScanBudgetExceeded { limit, scanned } => {
92                Self::NodeScanBudgetExceeded { limit, scanned }
93            }
94        }
95    }
96}
97
98impl SeleneGraph {
99    /// Exhaustively rank string-valued node properties using BM25.
100    ///
101    /// This is the full-text correctness oracle and small-corpus path. It scans
102    /// the row bitmap for `label`, skips nodes where `property` is absent or not
103    /// a string, tokenizes documents with the built-in Unicode-aware tokenizer,
104    /// and ranks matches with Okapi BM25 (`k1 = 1.2`, `b = 0.75`). Query tokens
105    /// are deduplicated so repeated query terms do not overweight a document.
106    pub fn exact_text_search_nodes(
107        &self,
108        label: &DbString,
109        property: &DbString,
110        query: &str,
111        k: usize,
112    ) -> GraphResult<Vec<TextSearchHit>> {
113        self.exact_text_search_nodes_checked(
114            label,
115            property,
116            query,
117            k,
118            CancellationChecker::disabled(),
119        )
120        .map_err(TextSearchError::into_graph_error)
121    }
122
123    /// Exhaustively rank string-valued node properties with cancellation checks.
124    pub fn exact_text_search_nodes_checked(
125        &self,
126        label: &DbString,
127        property: &DbString,
128        query: &str,
129        k: usize,
130        checker: CancellationChecker<'_>,
131    ) -> Result<Vec<TextSearchHit>, TextSearchError> {
132        self.exact_text_search_nodes_filtered_checked(label, property, query, k, None, checker)
133    }
134
135    /// Exhaustively rank text documents while admitting only `allowed_rows`.
136    ///
137    /// BM25 corpus statistics are still computed over every string document for
138    /// `(label, property)`, so scores and ordering match an unfiltered search
139    /// whose full ranking is filtered by this row set before `k` truncation.
140    pub fn exact_text_search_nodes_in_rows_checked(
141        &self,
142        label: &DbString,
143        property: &DbString,
144        query: &str,
145        k: usize,
146        allowed_rows: &RoaringBitmap,
147        checker: CancellationChecker<'_>,
148    ) -> Result<Vec<TextSearchHit>, TextSearchError> {
149        if allowed_rows.is_empty() {
150            return Ok(Vec::new());
151        }
152        self.exact_text_search_nodes_filtered_checked(
153            label,
154            property,
155            query,
156            k,
157            Some(allowed_rows),
158            checker,
159        )
160    }
161
162    fn exact_text_search_nodes_filtered_checked(
163        &self,
164        label: &DbString,
165        property: &DbString,
166        query: &str,
167        k: usize,
168        allowed_rows: Option<&RoaringBitmap>,
169        checker: CancellationChecker<'_>,
170    ) -> Result<Vec<TextSearchHit>, TextSearchError> {
171        checker.check()?;
172        if k == 0 {
173            return Ok(Vec::new());
174        }
175        let query_terms = unique_query_terms(query);
176        if query_terms.is_empty() {
177            return Ok(Vec::new());
178        }
179        let Some(label_rows) = self.nodes_with_label(label) else {
180            return Ok(Vec::new());
181        };
182
183        let scan = TextScan::new(self, label, property, &query_terms, allowed_rows);
184        let chunk = if should_parallelize_text_scan(label_rows, k) {
185            exact_text_scan_parallel(scan, label_rows, checker)?
186        } else {
187            exact_text_scan_serial(scan, label_rows, checker)?
188        };
189        Ok(rank_text_docs(chunk, k))
190    }
191}
192
193impl SharedGraph {
194    /// Exhaustively rank string-valued node properties in the current snapshot.
195    pub fn exact_text_search_nodes(
196        &self,
197        label: &DbString,
198        property: &DbString,
199        query: &str,
200        k: usize,
201    ) -> GraphResult<Vec<TextSearchHit>> {
202        self.read()
203            .exact_text_search_nodes(label, property, query, k)
204    }
205
206    /// Exhaustively rank string-valued node properties with cancellation checks.
207    pub fn exact_text_search_nodes_checked(
208        &self,
209        label: &DbString,
210        property: &DbString,
211        query: &str,
212        k: usize,
213        checker: CancellationChecker<'_>,
214    ) -> Result<Vec<TextSearchHit>, TextSearchError> {
215        self.read()
216            .exact_text_search_nodes_checked(label, property, query, k, checker)
217    }
218}
219
220#[derive(Clone, Copy)]
221struct TextScan<'a> {
222    graph: &'a SeleneGraph,
223    label: &'a DbString,
224    property: &'a DbString,
225    query_terms: &'a [String],
226    allowed_rows: Option<&'a RoaringBitmap>,
227}
228
229impl<'a> TextScan<'a> {
230    fn new(
231        graph: &'a SeleneGraph,
232        label: &'a DbString,
233        property: &'a DbString,
234        query_terms: &'a [String],
235        allowed_rows: Option<&'a RoaringBitmap>,
236    ) -> Self {
237        Self {
238            graph,
239            label,
240            property,
241            query_terms,
242            allowed_rows,
243        }
244    }
245
246    fn document_for_row(self, raw_row: u32) -> Result<Option<DocumentStats>, TextSearchError> {
247        if !self.graph.node_store.is_alive(raw_row) {
248            return Ok(None);
249        }
250        let row = RowIndex::new(raw_row);
251        let node_id = self
252            .graph
253            .node_id_for_row(row)
254            .ok_or_else(|| GraphError::Inconsistent {
255                reason: format!(
256                    "label index row {raw_row} for {} has no node id",
257                    self.label.as_str()
258                ),
259            })?;
260        let properties = self
261            .graph
262            .node_store
263            .properties
264            .get(raw_row as usize)
265            .ok_or_else(|| GraphError::Inconsistent {
266                reason: format!(
267                    "text search row {raw_row} for {} has no property row",
268                    self.label.as_str()
269                ),
270            })?;
271        let Some(Value::String(text)) = properties.get(self.property) else {
272            return Ok(None);
273        };
274        Ok(document_stats(
275            node_id,
276            text.as_str(),
277            self.query_terms,
278            self.allowed_rows
279                .is_none_or(|allowed_rows| allowed_rows.contains(raw_row)),
280        ))
281    }
282}
283
284#[derive(Debug)]
285struct TextScanChunk {
286    docs: Vec<DocumentStats>,
287    document_frequencies: Vec<u32>,
288    total_document_len: u64,
289}
290
291impl TextScanChunk {
292    fn empty(query_term_count: usize) -> Self {
293        Self {
294            docs: Vec::new(),
295            document_frequencies: vec![0; query_term_count],
296            total_document_len: 0,
297        }
298    }
299
300    fn push(&mut self, doc: DocumentStats) {
301        for (frequency, count) in self.document_frequencies.iter_mut().zip(&doc.term_counts) {
302            if *count > 0 {
303                *frequency = frequency.saturating_add(1);
304            }
305        }
306        self.total_document_len = self.total_document_len.saturating_add(u64::from(doc.len));
307        self.docs.push(doc);
308    }
309}
310
311fn should_parallelize_text_scan(rows: &RoaringBitmap, k: usize) -> bool {
312    should_parallelize_scan(rows.len(), k, TEXT_SEARCH_PARALLEL_MIN_ROWS)
313}
314
315fn exact_text_scan_parallel(
316    scan: TextScan<'_>,
317    rows: &RoaringBitmap,
318    checker: CancellationChecker<'_>,
319) -> Result<TextScanChunk, TextSearchError> {
320    try_reduce_bitmap_chunks(
321        rows,
322        TEXT_SEARCH_PARALLEL_CHUNK_ROWS,
323        checker,
324        || TextScanChunk::empty(scan.query_terms.len()),
325        |chunk| exact_text_scan_chunk(scan, chunk),
326        merge_text_scan_chunks,
327    )
328}
329
330fn exact_text_scan_serial(
331    scan: TextScan<'_>,
332    rows: &RoaringBitmap,
333    checker: CancellationChecker<'_>,
334) -> Result<TextScanChunk, TextSearchError> {
335    let mut chunk = TextScanChunk::empty(scan.query_terms.len());
336    let mut rows_since_check = 0usize;
337    for raw_row in rows.iter() {
338        rows_since_check += 1;
339        if rows_since_check >= TEXT_SEARCH_CANCEL_STRIDE {
340            checker.note_nodes_scanned(rows_since_check)?;
341            rows_since_check = 0;
342        }
343        if let Some(doc) = scan.document_for_row(raw_row)? {
344            chunk.push(doc);
345        }
346    }
347    if rows_since_check > 0 {
348        checker.note_nodes_scanned(rows_since_check)?;
349    }
350    Ok(chunk)
351}
352
353fn exact_text_scan_chunk(
354    scan: TextScan<'_>,
355    rows: &[u32],
356) -> Result<TextScanChunk, TextSearchError> {
357    let mut chunk = TextScanChunk::empty(scan.query_terms.len());
358    for &raw_row in rows {
359        if let Some(doc) = scan.document_for_row(raw_row)? {
360            chunk.push(doc);
361        }
362    }
363    Ok(chunk)
364}
365
366fn merge_text_scan_chunks(
367    mut lhs: TextScanChunk,
368    mut rhs: TextScanChunk,
369) -> Result<TextScanChunk, TextSearchError> {
370    for (lhs_frequency, rhs_frequency) in lhs
371        .document_frequencies
372        .iter_mut()
373        .zip(&rhs.document_frequencies)
374    {
375        *lhs_frequency = lhs_frequency.saturating_add(*rhs_frequency);
376    }
377    lhs.total_document_len = lhs
378        .total_document_len
379        .saturating_add(rhs.total_document_len);
380    lhs.docs.append(&mut rhs.docs);
381    Ok(lhs)
382}
383
384fn rank_text_docs(chunk: TextScanChunk, k: usize) -> Vec<TextSearchHit> {
385    if chunk.docs.is_empty() {
386        return Vec::new();
387    }
388    let corpus_len = chunk.docs.len() as f64;
389    let average_document_len = chunk.total_document_len as f64 / corpus_len;
390    let mut top_k = TextTopK::new(k);
391    for doc in chunk.docs {
392        if !doc.admitted {
393            continue;
394        }
395        let score = bm25_score(
396            &doc,
397            &chunk.document_frequencies,
398            corpus_len,
399            average_document_len,
400        );
401        if score > 0.0 {
402            top_k.push(doc.node_id, score);
403        }
404    }
405    top_k.into_hits()
406}
407
408#[derive(Debug)]
409pub(crate) struct DocumentStats {
410    pub(crate) node_id: NodeId,
411    len: u32,
412    pub(crate) term_counts: TermCounts,
413    admitted: bool,
414}
415
416impl DocumentStats {
417    pub(crate) fn zero(node_id: NodeId, len: u32, query_term_count: usize) -> Self {
418        Self {
419            node_id,
420            len,
421            term_counts: TermCounts::from_elem(0, query_term_count),
422            admitted: true,
423        }
424    }
425}
426
427pub(crate) fn unique_query_terms(query: &str) -> Vec<String> {
428    let terms: BTreeSet<_> = tokenize_borrowed(query).map(Cow::into_owned).collect();
429    terms.into_iter().collect()
430}
431
432fn document_stats(
433    node_id: NodeId,
434    text: &str,
435    query_terms: &[String],
436    admitted: bool,
437) -> Option<DocumentStats> {
438    let mut term_counts = TermCounts::from_elem(0, query_terms.len());
439    let mut len = 0_u32;
440    for token in tokenize_borrowed(text) {
441        len = len.saturating_add(1);
442        if let Ok(index) = query_terms.binary_search_by(|term| term.as_str().cmp(token.as_ref())) {
443            term_counts[index] = term_counts[index].saturating_add(1);
444        }
445    }
446    (len > 0).then_some(DocumentStats {
447        node_id,
448        len,
449        term_counts,
450        admitted,
451    })
452}
453
/// Iterate the lowercase alphanumeric tokens of `text`.
///
/// A token is a maximal run of alphanumeric characters. It is returned as a
/// borrowed slice when lowercasing leaves it unchanged, and as an owned
/// string otherwise.
pub(crate) fn tokenize_borrowed(text: &str) -> Tokenizer<'_> {
    Tokenizer { text, offset: 0 }
}

/// Borrowing tokenizer for BM25 query/document processing.
///
/// `offset` is the byte position in `text` where the next call resumes.
pub(crate) struct Tokenizer<'a> {
    text: &'a str,
    offset: usize,
}

impl<'a> Iterator for Tokenizer<'a> {
    type Item = Cow<'a, str>;

    /// Skip separators, then return the next alphanumeric run in lowercase.
    fn next(&mut self) -> Option<Self::Item> {
        /// True when lowercasing `ch` yields exactly `ch` again.
        fn lowercases_to_itself(ch: char) -> bool {
            let mut mapped = ch.to_lowercase();
            mapped.next() == Some(ch) && mapped.next().is_none()
        }

        let text = self.text;

        // Locate the first alphanumeric character at or after the cursor.
        let Some(skipped) = text[self.offset..].find(char::is_alphanumeric) else {
            self.offset = text.len();
            return None;
        };
        let start = self.offset + skipped;

        // The token runs up to the next separator, which is consumed as well,
        // or to the end of the text.
        let boundary = text[start..]
            .char_indices()
            .find(|&(_, ch)| !ch.is_alphanumeric());
        let (end, resume) = match boundary {
            Some((token_len, separator)) => {
                let end = start + token_len;
                (end, end + separator.len_utf8())
            }
            None => (text.len(), text.len()),
        };
        self.offset = resume;

        // Allocate only from the first character that lowercasing changes.
        let token = &text[start..end];
        let first_change = token
            .char_indices()
            .find(|&(_, ch)| !lowercases_to_itself(ch));
        Some(match first_change {
            None => Cow::Borrowed(token),
            Some((split, _)) => {
                let (unchanged, remainder) = token.split_at(split);
                let mut lowered = String::with_capacity(token.len());
                lowered.push_str(unchanged);
                lowered.extend(remainder.chars().flat_map(char::to_lowercase));
                Cow::Owned(lowered)
            }
        })
    }
}
525
526pub(crate) fn bm25_score(
527    doc: &DocumentStats,
528    document_frequencies: &[u32],
529    corpus_len: f64,
530    average_document_len: f64,
531) -> f64 {
532    let document_len = f64::from(doc.len);
533    doc.term_counts
534        .iter()
535        .zip(document_frequencies)
536        .filter(|(term_count, _)| **term_count > 0)
537        .map(|(term_count, document_frequency)| {
538            let term_count = f64::from(*term_count);
539            let document_frequency = f64::from(*document_frequency);
540            let idf =
541                (1.0 + (corpus_len - document_frequency + 0.5) / (document_frequency + 0.5)).ln();
542            let normalization = term_count
543                + BM25_K1 * (1.0 - BM25_B + BM25_B * document_len / average_document_len);
544            idf * (term_count * (BM25_K1 + 1.0)) / normalization
545        })
546        .sum()
547}
548
549#[derive(Debug)]
550pub(crate) struct TextTopK {
551    k: usize,
552    heap: BinaryHeap<TextHeapEntry>,
553}
554
555impl TextTopK {
556    pub(crate) fn new(k: usize) -> Self {
557        Self {
558            k,
559            heap: BinaryHeap::new(),
560        }
561    }
562
563    pub(crate) fn push(&mut self, node_id: NodeId, score: f64) {
564        debug_assert!(score.is_finite(), "BM25 scores must be finite");
565        if self.k == 0 {
566            return;
567        }
568        let entry = TextHeapEntry { score, node_id };
569        if self.heap.len() < self.k {
570            self.heap.push(entry);
571            return;
572        }
573        let Some(worst) = self.heap.peek() else {
574            return;
575        };
576        if entry.cmp(worst).is_lt() {
577            self.heap.pop();
578            self.heap.push(entry);
579        }
580    }
581
582    pub(crate) fn into_hits(self) -> Vec<TextSearchHit> {
583        let mut hits: Vec<_> = self
584            .heap
585            .into_iter()
586            .map(|entry| TextSearchHit {
587                node_id: entry.node_id,
588                score: entry.score,
589            })
590            .collect();
591        hits.sort_by(compare_hit);
592        hits
593    }
594}
595
596#[derive(Debug)]
597struct TextHeapEntry {
598    score: f64,
599    node_id: NodeId,
600}
601
602impl Eq for TextHeapEntry {}
603
604impl PartialEq for TextHeapEntry {
605    fn eq(&self, rhs: &Self) -> bool {
606        self.score.to_bits() == rhs.score.to_bits() && self.node_id == rhs.node_id
607    }
608}
609
610impl Ord for TextHeapEntry {
611    fn cmp(&self, rhs: &Self) -> Ordering {
612        rhs.score
613            .total_cmp(&self.score)
614            .then_with(|| self.node_id.cmp(&rhs.node_id))
615    }
616}
617
618impl PartialOrd for TextHeapEntry {
619    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
620        Some(self.cmp(rhs))
621    }
622}
623
624fn compare_hit(lhs: &TextSearchHit, rhs: &TextSearchHit) -> Ordering {
625    rhs.score
626        .total_cmp(&lhs.score)
627        .then_with(|| lhs.node_id.cmp(&rhs.node_id))
628}
629
// Unit tests are kept in the separate `text_search/tests.rs` file and are only
// compiled for test builds.
#[cfg(test)]
#[path = "text_search/tests.rs"]
mod tests;