use std::borrow::Cow;
use std::cmp::Ordering;
use std::collections::{BTreeSet, BinaryHeap};
use std::time::Duration;
use roaring::RoaringBitmap;
use selene_core::{CancellationCause, CancellationChecker, DbString, NodeId, Value};
use crate::error::{GraphError, GraphResult};
use crate::graph::SeleneGraph;
use crate::parallel_scan::{should_parallelize_scan, try_reduce_bitmap_chunks};
use crate::shared::SharedGraph;
use crate::store::RowIndex;
/// Number of rows the serial text scan visits between successive cancellation checks.
pub(crate) const TEXT_SEARCH_CANCEL_STRIDE: usize = 1024;
// Rows per work unit handed to `try_reduce_bitmap_chunks` by the parallel scan.
// The test build uses a tiny value, presumably so small fixtures still exercise
// multi-chunk merging.
#[cfg(not(test))]
const TEXT_SEARCH_PARALLEL_CHUNK_ROWS: usize = 2048;
#[cfg(test)]
const TEXT_SEARCH_PARALLEL_CHUNK_ROWS: usize = 4;
// Row-count threshold forwarded to `should_parallelize_scan` (see that helper for the exact
// decision rule); lowered under test for the same reason as the chunk size above.
#[cfg(not(test))]
const TEXT_SEARCH_PARALLEL_MIN_ROWS: u64 = 16_384;
#[cfg(test)]
const TEXT_SEARCH_PARALLEL_MIN_ROWS: u64 = 8;
/// BM25 `k1`: controls how quickly repeated occurrences of a term stop increasing the score.
const BM25_K1: f64 = 1.2;
/// BM25 `b`: strength of document-length normalization (0 = none, 1 = full).
const BM25_B: f64 = 0.75;
/// One ranked result of an exact (full-scan) BM25 text search.
///
/// The search entry points return hits sorted by descending `score`, with ties broken by
/// ascending `node_id`.
#[derive(Clone, Debug, PartialEq)]
pub struct TextSearchHit {
    /// Node whose text property contained at least one query term.
    pub node_id: NodeId,
    /// BM25 relevance score; higher means more relevant. Always positive for returned hits.
    pub score: f64,
}
/// Errors produced by the cancellation-aware text search entry points.
#[derive(Debug, thiserror::Error)]
pub enum TextSearchError {
    /// A graph-level failure, e.g. the label index referencing a row without a node id
    /// or without a property row.
    #[error(transparent)]
    Graph(#[from] GraphError),
    /// The supplied `CancellationChecker` reported `CancellationCause::Cancelled`.
    #[error("text search cancelled")]
    Cancelled,
    /// The supplied `CancellationChecker` reported `CancellationCause::Timeout`.
    #[error("text search timed out after {elapsed:?}")]
    Timeout {
        /// Elapsed time as reported by the checker.
        elapsed: Duration,
    },
}
impl TextSearchError {
    /// Collapses this error into a [`GraphError`] for the non-cancellable entry points.
    ///
    /// Those entry points run with `CancellationChecker::disabled()`, so receiving a
    /// cancellation or timeout here is unexpected. It is therefore reported as
    /// `GraphError::Inconsistent` instead of being discarded.
    fn into_graph_error(self) -> GraphError {
        let unexpected = match self {
            Self::Graph(error) => return error,
            Self::Cancelled | Self::Timeout { .. } => self.to_string(),
        };
        GraphError::Inconsistent {
            reason: format!("disabled text-search checker returned {unexpected}"),
        }
    }
}
impl From<CancellationCause> for TextSearchError {
    /// Maps a checker's cancellation cause onto the corresponding text-search variant.
    fn from(cause: CancellationCause) -> Self {
        match cause {
            CancellationCause::Timeout { elapsed } => Self::Timeout { elapsed },
            CancellationCause::Cancelled => Self::Cancelled,
        }
    }
}
impl SeleneGraph {
    /// Ranks nodes carrying `label` by BM25 relevance of their string `property` to `query`,
    /// returning at most `k` hits.
    ///
    /// This variant cannot be cancelled. It runs the checked search with a disabled checker
    /// and reports any unexpected cancellation as `GraphError::Inconsistent`.
    ///
    /// # Errors
    /// Returns a [`GraphError`] when the label index and node storage disagree.
    pub fn exact_text_search_nodes(
        &self,
        label: &DbString,
        property: &DbString,
        query: &str,
        k: usize,
    ) -> GraphResult<Vec<TextSearchHit>> {
        let checker = CancellationChecker::disabled();
        self.exact_text_search_nodes_checked(label, property, query, k, checker)
            .map_err(|error| error.into_graph_error())
    }

    /// Exact (full-scan) BM25 search over every node carrying `label`, polling `checker`
    /// for cancellation.
    ///
    /// Text is split on non-alphanumeric characters and lowercased char by char. Duplicate
    /// query terms are collapsed. The corpus used for IDF and average document length
    /// consists of the live labelled nodes whose `property` is a string with at least one
    /// token. Hits are ordered by descending score, with ties broken by ascending node id.
    ///
    /// Returns an empty vector when `k == 0`, when `query` has no tokens, or when
    /// `nodes_with_label(label)` yields nothing.
    ///
    /// # Errors
    /// * [`TextSearchError::Cancelled`] / [`TextSearchError::Timeout`] when `checker` fires.
    /// * [`TextSearchError::Graph`] when storage is inconsistent with the label index.
    pub fn exact_text_search_nodes_checked(
        &self,
        label: &DbString,
        property: &DbString,
        query: &str,
        k: usize,
        checker: CancellationChecker<'_>,
    ) -> Result<Vec<TextSearchHit>, TextSearchError> {
        checker.check()?;
        // `k == 0` short-circuits without tokenizing the query at all.
        let query_terms = if k == 0 {
            Vec::new()
        } else {
            unique_query_terms(query)
        };
        if query_terms.is_empty() {
            return Ok(Vec::new());
        }
        let label_rows = match self.nodes_with_label(label) {
            Some(rows) => rows,
            None => return Ok(Vec::new()),
        };
        let scan = TextScan::new(self, label, property, &query_terms);
        let parallel = should_parallelize_text_scan(label_rows, k);
        let chunk = if parallel {
            exact_text_scan_parallel(scan, label_rows, checker)
        } else {
            exact_text_scan_serial(scan, label_rows, checker)
        }?;
        Ok(rank_text_docs(chunk, k))
    }
}
impl SharedGraph {
    /// Runs [`SeleneGraph::exact_text_search_nodes`] on the graph view returned by
    /// `self.read()`. That view is held until the search completes.
    pub fn exact_text_search_nodes(
        &self,
        label: &DbString,
        property: &DbString,
        query: &str,
        k: usize,
    ) -> GraphResult<Vec<TextSearchHit>> {
        let graph = self.read();
        graph.exact_text_search_nodes(label, property, query, k)
    }

    /// Runs [`SeleneGraph::exact_text_search_nodes_checked`] on the graph view returned by
    /// `self.read()`, forwarding `checker` unchanged.
    pub fn exact_text_search_nodes_checked(
        &self,
        label: &DbString,
        property: &DbString,
        query: &str,
        k: usize,
        checker: CancellationChecker<'_>,
    ) -> Result<Vec<TextSearchHit>, TextSearchError> {
        let graph = self.read();
        graph.exact_text_search_nodes_checked(label, property, query, k, checker)
    }
}
#[derive(Clone, Copy)]
struct TextScan<'a> {
graph: &'a SeleneGraph,
label: &'a DbString,
property: &'a DbString,
query_terms: &'a [String],
}
impl<'a> TextScan<'a> {
fn new(
graph: &'a SeleneGraph,
label: &'a DbString,
property: &'a DbString,
query_terms: &'a [String],
) -> Self {
Self {
graph,
label,
property,
query_terms,
}
}
fn document_for_row(self, raw_row: u32) -> Result<Option<DocumentStats>, TextSearchError> {
if !self.graph.node_store.is_alive(raw_row) {
return Ok(None);
}
let row = RowIndex::new(raw_row);
let node_id = self
.graph
.node_id_for_row(row)
.ok_or_else(|| GraphError::Inconsistent {
reason: format!(
"label index row {raw_row} for {} has no node id",
self.label.as_str()
),
})?;
let properties = self
.graph
.node_store
.properties
.get(raw_row as usize)
.ok_or_else(|| GraphError::Inconsistent {
reason: format!(
"text search row {raw_row} for {} has no property row",
self.label.as_str()
),
})?;
let Some(Value::String(text)) = properties.get(self.property) else {
return Ok(None);
};
Ok(document_stats(node_id, text.as_str(), self.query_terms))
}
}
/// Partial corpus statistics accumulated over a subset of label rows.
///
/// Parallel workers each build one of these and they are combined by
/// `merge_text_scan_chunks`.
#[derive(Debug)]
struct TextScanChunk {
    /// Statistics for every visited document that produced at least one token.
    docs: Vec<DocumentStats>,
    /// For each query term (same index as the query-term list), how many documents contain it.
    document_frequencies: Vec<u32>,
    /// Sum of all document lengths in tokens; used to derive the average document length.
    total_document_len: u64,
}

impl TextScanChunk {
    /// Creates an empty chunk with one zeroed document-frequency slot per query term.
    fn empty(query_term_count: usize) -> Self {
        let document_frequencies = vec![0; query_term_count];
        Self {
            docs: Vec::new(),
            document_frequencies,
            total_document_len: 0,
        }
    }

    /// Adds one document. Every query term it contains has its document frequency bumped,
    /// and its length is added to the total. All counters saturate instead of overflowing.
    fn push(&mut self, doc: DocumentStats) {
        let contains_term = doc.term_counts.iter().map(|&count| count > 0);
        for (frequency, present) in self.document_frequencies.iter_mut().zip(contains_term) {
            if present {
                *frequency = frequency.saturating_add(1);
            }
        }
        self.total_document_len = self.total_document_len.saturating_add(u64::from(doc.len));
        self.docs.push(doc);
    }
}
/// Decides whether scanning `rows` should take the parallel chunked path.
///
/// Delegates to `should_parallelize_scan` with the text-search row threshold.
fn should_parallelize_text_scan(rows: &RoaringBitmap, k: usize) -> bool {
    let row_count = rows.len();
    should_parallelize_scan(row_count, k, TEXT_SEARCH_PARALLEL_MIN_ROWS)
}
/// Scans `rows` in work units of `TEXT_SEARCH_PARALLEL_CHUNK_ROWS` rows via
/// `try_reduce_bitmap_chunks`, merging partial results with `merge_text_scan_chunks`.
///
/// `checker` is forwarded to `try_reduce_bitmap_chunks`; the per-chunk worker itself does
/// not poll it.
fn exact_text_scan_parallel(
    scan: TextScan<'_>,
    rows: &RoaringBitmap,
    checker: CancellationChecker<'_>,
) -> Result<TextScanChunk, TextSearchError> {
    let term_count = scan.query_terms.len();
    try_reduce_bitmap_chunks(
        rows,
        TEXT_SEARCH_PARALLEL_CHUNK_ROWS,
        checker,
        move || TextScanChunk::empty(term_count),
        move |row_slice| exact_text_scan_chunk(scan, row_slice),
        merge_text_scan_chunks,
    )
}
/// Scans `rows` on the current thread and polls `checker` once every
/// `TEXT_SEARCH_CANCEL_STRIDE` rows. The poll happens before processing rows
/// STRIDE, 2·STRIDE, … (1-based); the initial check is done by the caller.
fn exact_text_scan_serial(
    scan: TextScan<'_>,
    rows: &RoaringBitmap,
    checker: CancellationChecker<'_>,
) -> Result<TextScanChunk, TextSearchError> {
    let mut chunk = TextScanChunk::empty(scan.query_terms.len());
    let mut rows_until_check = TEXT_SEARCH_CANCEL_STRIDE;
    for raw_row in rows.iter() {
        rows_until_check -= 1;
        if rows_until_check == 0 {
            checker.check()?;
            rows_until_check = TEXT_SEARCH_CANCEL_STRIDE;
        }
        if let Some(doc) = scan.document_for_row(raw_row)? {
            chunk.push(doc);
        }
    }
    Ok(chunk)
}
/// Builds the statistics for one parallel work unit. `rows` holds raw row ids.
///
/// Stops at the first storage inconsistency. Cancellation is not polled here.
fn exact_text_scan_chunk(
    scan: TextScan<'_>,
    rows: &[u32],
) -> Result<TextScanChunk, TextSearchError> {
    rows.iter().try_fold(
        TextScanChunk::empty(scan.query_terms.len()),
        |mut chunk, &raw_row| -> Result<TextScanChunk, TextSearchError> {
            if let Some(doc) = scan.document_for_row(raw_row)? {
                chunk.push(doc);
            }
            Ok(chunk)
        },
    )
}
fn merge_text_scan_chunks(
mut lhs: TextScanChunk,
mut rhs: TextScanChunk,
) -> Result<TextScanChunk, TextSearchError> {
for (lhs_frequency, rhs_frequency) in lhs
.document_frequencies
.iter_mut()
.zip(&rhs.document_frequencies)
{
*lhs_frequency = lhs_frequency.saturating_add(*rhs_frequency);
}
lhs.total_document_len = lhs
.total_document_len
.saturating_add(rhs.total_document_len);
lhs.docs.append(&mut rhs.docs);
Ok(lhs)
}
fn rank_text_docs(chunk: TextScanChunk, k: usize) -> Vec<TextSearchHit> {
if chunk.docs.is_empty() {
return Vec::new();
}
let corpus_len = chunk.docs.len() as f64;
let average_document_len = chunk.total_document_len as f64 / corpus_len;
let mut top_k = TextTopK::new(k);
for doc in chunk.docs {
let score = bm25_score(
&doc,
&chunk.document_frequencies,
corpus_len,
average_document_len,
);
if score > 0.0 {
top_k.push(doc.node_id, score);
}
}
top_k.into_hits()
}
/// Token statistics for a single document (one node's text property).
#[derive(Debug)]
pub(crate) struct DocumentStats {
    pub(crate) node_id: NodeId,
    /// Document length in tokens.
    len: u32,
    /// Occurrences of each query term in the document, one slot per query term.
    pub(crate) term_counts: Vec<u32>,
}

impl DocumentStats {
    /// Creates statistics for `node_id` with length `len` and a zero count for each of the
    /// `query_term_count` query terms.
    pub(crate) fn zero(node_id: NodeId, len: u32, query_term_count: usize) -> Self {
        let term_counts = vec![0; query_term_count];
        Self {
            node_id,
            len,
            term_counts,
        }
    }
}
pub(crate) fn unique_query_terms(query: &str) -> Vec<String> {
let terms: BTreeSet<_> = tokenize_borrowed(query).map(Cow::into_owned).collect();
terms.into_iter().collect()
}
/// Counts the tokens of `text` and the occurrences of each query term.
///
/// `query_terms` must be sorted (it is binary-searched). Returns `None` when `text` has no
/// tokens. Both the length and the per-term counts saturate at `u32::MAX`.
fn document_stats(node_id: NodeId, text: &str, query_terms: &[String]) -> Option<DocumentStats> {
    let mut term_counts = vec![0_u32; query_terms.len()];
    let mut token_count = 0_u32;
    for token in tokenize_borrowed(text) {
        token_count = token_count.saturating_add(1);
        let Ok(slot) = query_terms.binary_search_by(|term| term.as_str().cmp(&*token)) else {
            continue;
        };
        let count = &mut term_counts[slot];
        *count = count.saturating_add(1);
    }
    if token_count == 0 {
        return None;
    }
    Some(DocumentStats {
        node_id,
        len: token_count,
        term_counts,
    })
}
/// Splits `text` into lowercase tokens, treating every non-alphanumeric char as a separator.
///
/// Tokens that are already lowercase are borrowed from `text`; only tokens that need case
/// folding allocate. Folding is applied per `char` with `char::to_lowercase` rather than
/// `str::to_lowercase`, so it is context-free. For example, a word-final `Σ` becomes `σ`,
/// not `ς`, and `İ` expands to `i` followed by U+0307.
pub(crate) fn tokenize_borrowed(text: &str) -> Tokenizer<'_> {
    Tokenizer { text, offset: 0 }
}

/// Iterator returned by [`tokenize_borrowed`].
pub(crate) struct Tokenizer<'a> {
    /// The complete input being tokenized.
    text: &'a str,
    /// Byte offset (always on a char boundary) where the search for the next token begins.
    offset: usize,
}

impl<'a> Iterator for Tokenizer<'a> {
    type Item = Cow<'a, str>;

    fn next(&mut self) -> Option<Self::Item> {
        let text = self.text;
        let Some(skipped) = text[self.offset..].find(char::is_alphanumeric) else {
            // Only separators (or nothing) remain; park at the end of the input.
            self.offset = text.len();
            return None;
        };
        let start = self.offset + skipped;
        let token_len = text[start..]
            .find(|ch: char| !ch.is_alphanumeric())
            .unwrap_or(text.len() - start);
        let end = start + token_len;
        // Resume just past the separator that terminated this token, if there is one.
        self.offset = text[end..]
            .chars()
            .next()
            .map_or(text.len(), |separator| end + separator.len_utf8());

        let token = &text[start..end];
        let already_lowercase = token.chars().all(|ch| {
            let mut folded = ch.to_lowercase();
            folded.next() == Some(ch) && folded.next().is_none()
        });
        Some(if already_lowercase {
            Cow::Borrowed(token)
        } else {
            Cow::Owned(token.chars().flat_map(char::to_lowercase).collect())
        })
    }
}
/// Okapi BM25 score of `doc` for the query terms described by `document_frequencies`.
///
/// IDF is `ln(1 + (N - df + 0.5) / (df + 0.5))`, which is positive whenever `df <= N`.
/// Term frequency is saturated with `BM25_K1` and length-normalized with `BM25_B`.
/// `document_frequencies[i]` must describe the same term as `doc.term_counts[i]`. Terms
/// absent from `doc` contribute nothing. `average_document_len` should be positive.
pub(crate) fn bm25_score(
    doc: &DocumentStats,
    document_frequencies: &[u32],
    corpus_len: f64,
    average_document_len: f64,
) -> f64 {
    let document_len = f64::from(doc.len);
    // The length-normalization term depends only on the document, so compute it once.
    let length_penalty =
        BM25_K1 * (1.0 - BM25_B + BM25_B * document_len / average_document_len);
    doc.term_counts
        .iter()
        .zip(document_frequencies)
        .filter_map(|(&term_count, &document_frequency)| {
            if term_count == 0 {
                return None;
            }
            let tf = f64::from(term_count);
            let df = f64::from(document_frequency);
            let idf = (1.0 + (corpus_len - df + 0.5) / (df + 0.5)).ln();
            Some(idf * (tf * (BM25_K1 + 1.0)) / (tf + length_penalty))
        })
        .sum()
}
/// Bounded collector that retains the `k` best-ranked hits.
///
/// `TextHeapEntry` reverses the score order, so the heap's maximum is always the *worst*
/// retained hit. That entry is the one evicted when a better candidate arrives.
#[derive(Debug)]
pub(crate) struct TextTopK {
    k: usize,
    heap: BinaryHeap<TextHeapEntry>,
}

impl TextTopK {
    /// Creates an empty collector that keeps at most `k` hits (`k == 0` keeps none).
    pub(crate) fn new(k: usize) -> Self {
        Self {
            k,
            heap: BinaryHeap::new(),
        }
    }

    /// Offers a hit and keeps it only if it ranks among the best `k` seen so far.
    ///
    /// Equal scores are broken in favour of the smaller node id.
    pub(crate) fn push(&mut self, node_id: NodeId, score: f64) {
        debug_assert!(score.is_finite(), "BM25 scores must be finite");
        let candidate = TextHeapEntry { score, node_id };
        if self.heap.len() < self.k {
            self.heap.push(candidate);
            return;
        }
        // Full heap (or k == 0, where the heap stays empty and nothing is peeked).
        let Some(mut worst) = self.heap.peek_mut() else {
            return;
        };
        if candidate < *worst {
            // Replacing through `PeekMut` restores the heap invariant when `worst` drops.
            *worst = candidate;
        }
    }

    /// Consumes the collector, returning hits by descending score then ascending node id.
    pub(crate) fn into_hits(self) -> Vec<TextSearchHit> {
        let mut hits = Vec::with_capacity(self.heap.len());
        hits.extend(
            self.heap
                .into_iter()
                .map(|TextHeapEntry { score, node_id }| TextSearchHit { node_id, score }),
        );
        hits.sort_by(compare_hit);
        hits
    }
}
/// Heap entry for [`TextTopK`] with a deliberately *reversed* score ordering.
///
/// `a < b` means `a` ranks higher: it has the larger score (by `f64::total_cmp`), or an
/// equal score and the smaller node id. Equality compares score bit patterns, which is
/// consistent with `total_cmp`.
#[derive(Debug)]
struct TextHeapEntry {
    score: f64,
    node_id: NodeId,
}

impl PartialEq for TextHeapEntry {
    fn eq(&self, rhs: &Self) -> bool {
        self.node_id == rhs.node_id && self.score.to_bits() == rhs.score.to_bits()
    }
}

impl Eq for TextHeapEntry {}

impl PartialOrd for TextHeapEntry {
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        Some(self.cmp(rhs))
    }
}

impl Ord for TextHeapEntry {
    fn cmp(&self, rhs: &Self) -> Ordering {
        match rhs.score.total_cmp(&self.score) {
            Ordering::Equal => self.node_id.cmp(&rhs.node_id),
            decided => decided,
        }
    }
}
/// Final result order: descending score (by `f64::total_cmp`), then ascending node id.
fn compare_hit(lhs: &TextSearchHit, rhs: &TextSearchHit) -> Ordering {
    let by_score_desc = rhs.score.total_cmp(&lhs.score);
    by_score_desc.then(lhs.node_id.cmp(&rhs.node_id))
}
// Unit tests for this module are loaded from `text_search/tests.rs` via `#[path]`.
#[cfg(test)]
#[path = "text_search/tests.rs"]
mod tests;