use std::mem::size_of;
use std::sync::Arc;
use roaring::RoaringBitmap;
use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use selene_core::{CancellationChecker, DbString, NodeId, Value};
use crate::error::{GraphError, GraphResult};
use crate::graph::SeleneGraph;
use crate::shared::SharedGraph;
use crate::store::RowIndex;
use crate::text_search::{
DocumentStats, TextSearchError, TextSearchHit, TextTopK, bm25_score, tokenize_borrowed,
unique_query_terms,
};
#[path = "text_index/builder.rs"]
mod builder;
#[path = "text_index/candidate.rs"]
mod candidate;
#[path = "text_index/maintenance.rs"]
mod maintenance;
use builder::TextIndexBuilder;
// One document frequency per query term, in query-term order (filled by `search_checked`).
type QueryDocumentFrequencies = SmallVec<[u32; 4]>;
// One postings slice per query term, in query-term order; `None` marks a term with no postings
// in the index.
type QueryPostings<'a> = SmallVec<[Option<&'a [TextPosting]>; 4]>;
pub(crate) use maintenance::{
apply_node_create, apply_node_delete, apply_node_update, rebuild_text_indexes,
};
/// Inverted text index over one string property of the nodes carrying one label.
///
/// Documents are tokenized on insertion; searches score candidates with `bm25_score`.
#[derive(Clone, Debug)]
pub struct TextIndex {
/// Label the index was created for.
label: DbString,
/// Property whose string values are indexed.
property: DbString,
/// Row indices added by `insert_document` and cleared by `remove_document`.
rows: RoaringBitmap,
/// Token count of each indexed document, keyed by node id.
document_lengths: FxHashMap<NodeId, u32>,
/// Distinct terms of each indexed document; `remove_document` walks this list to find the
/// postings it has to delete.
document_terms: FxHashMap<NodeId, Arc<[String]>>,
/// Postings list per term. `insert_document` / `remove_document` locate entries by binary
/// search on node id and mutate the list through `Arc::make_mut`.
postings: FxHashMap<String, Arc<Vec<TextPosting>>>,
/// Running sum of the lengths stored in `document_lengths`.
total_document_len: u64,
/// Running number of postings across all terms.
posting_count: usize,
}
impl TextIndex {
/// Builds an index of `property` for every live node that carries `label`.
///
/// Rows that are not alive, and nodes whose `property` is absent or not a `Value::String`,
/// are skipped. When the graph has no rows for `label`, an empty index is returned.
///
/// # Errors
///
/// Returns `GraphError::Inconsistent` when a live label row has no node id or no property row.
pub fn build(graph: &SeleneGraph, label: DbString, property: DbString) -> GraphResult<Self> {
    let label_rows = match graph.nodes_with_label(&label) {
        Some(found) => found,
        None => return Ok(TextIndexBuilder::empty(label, property).finish()),
    };

    let capacity_hint = usize::try_from(label_rows.len()).unwrap_or(usize::MAX);
    let mut builder =
        TextIndexBuilder::with_document_capacity(label.clone(), property.clone(), capacity_hint);

    // Both lookup failures below are reported through the same error variant.
    let inconsistent = |reason: String| GraphError::Inconsistent { reason };

    for raw_row in label_rows
        .iter()
        .filter(|&candidate| graph.node_store.is_alive(candidate))
    {
        let node_id = graph.node_id_for_row(RowIndex::new(raw_row)).ok_or_else(|| {
            inconsistent(format!(
                "label index row {raw_row} for {} has no node id",
                label.as_str()
            ))
        })?;
        let properties = graph
            .node_store
            .properties
            .get(raw_row as usize)
            .ok_or_else(|| {
                inconsistent(format!(
                    "text index row {raw_row} for {} has no property row",
                    label.as_str()
                ))
            })?;
        // Only string-valued properties contribute a document.
        if let Some(Value::String(text)) = properties.get(&property) {
            builder.insert_document(raw_row, node_id, text.as_str());
        }
    }

    Ok(builder.finish())
}
/// Creates an index for `label` / `property` that contains no documents.
#[must_use]
pub fn empty(label: DbString, property: DbString) -> Self {
    Self {
        // Counters start at zero and every container starts empty.
        total_document_len: 0,
        posting_count: 0,
        rows: RoaringBitmap::new(),
        document_lengths: FxHashMap::default(),
        document_terms: FxHashMap::default(),
        postings: FxHashMap::default(),
        label,
        property,
    }
}
/// Returns the label stored in this index.
#[must_use]
pub const fn label(&self) -> &DbString {
&self.label
}
/// Returns the property stored in this index.
#[must_use]
pub const fn property(&self) -> &DbString {
&self.property
}
/// Returns the bitmap of row indices currently recorded in the index.
#[must_use]
pub const fn rows(&self) -> &RoaringBitmap {
&self.rows
}
/// Returns the number of indexed documents (entries in the document-length table).
#[must_use]
pub fn document_count(&self) -> usize {
self.document_lengths.len()
}
/// Returns the number of distinct terms that have a postings list.
#[must_use]
pub fn term_count(&self) -> usize {
self.postings.len()
}
/// Returns the running count of postings across all terms.
#[must_use]
pub const fn posting_count(&self) -> usize {
self.posting_count
}
/// Returns a snapshot of the index counters: rows, documents, distinct terms, postings, and
/// the summed document length.
#[must_use]
pub fn stats(&self) -> TextIndexStats {
    let indexed_rows = self.rows.len();
    let documents = self.document_count();
    let distinct_terms = self.term_count();
    TextIndexStats {
        indexed_rows,
        documents,
        distinct_terms,
        postings: self.posting_count,
        total_document_len: self.total_document_len,
    }
}
/// Estimates how much memory the index occupies, broken down by component.
///
/// Sizes are derived from container capacities multiplied by element sizes, plus the
/// capacity of every stored string; all arithmetic saturates instead of overflowing.
/// `row_bitmap_serialized_bytes` is reported for reference only and is not part of
/// `estimated_index_bytes`.
#[must_use]
pub fn memory_usage(&self) -> TextIndexMemoryUsage {
    let row_bitmap_bytes = roaring_heap_bytes(&self.rows);
    let row_bitmap_serialized_bytes = self.rows.serialized_size();

    let document_length_bytes = self
        .document_lengths
        .capacity()
        .saturating_mul(size_of::<(NodeId, u32)>());

    // Table slots first, then each document's term slice and the term strings themselves.
    let document_terms_table_bytes = self
        .document_terms
        .capacity()
        .saturating_mul(size_of::<(NodeId, Arc<[String]>)>());
    let document_term_bytes =
        self.document_terms
            .values()
            .fold(document_terms_table_bytes, |running, terms| {
                let slice_bytes = terms.len().saturating_mul(size_of::<String>());
                terms
                    .iter()
                    .fold(running.saturating_add(slice_bytes), |acc, term| {
                        acc.saturating_add(term.capacity())
                    })
            });

    // Term keys and postings vectors are accumulated in one pass over the postings table.
    let (term_bytes, posting_bytes) = self.postings.iter().fold(
        (0usize, 0usize),
        |(term_total, posting_total), (term, postings)| {
            let list_bytes = postings
                .capacity()
                .saturating_mul(size_of::<TextPosting>());
            (
                term_total.saturating_add(term.capacity()),
                posting_total.saturating_add(list_bytes),
            )
        },
    );

    let terms_table_bytes = self
        .postings
        .capacity()
        .saturating_mul(size_of::<(String, Arc<Vec<TextPosting>>)>());

    let components = [
        row_bitmap_bytes,
        document_length_bytes,
        document_term_bytes,
        terms_table_bytes,
        term_bytes,
        posting_bytes,
    ];
    let estimated_index_bytes = components
        .iter()
        .fold(size_of::<Self>(), |total, &part| total.saturating_add(part));

    TextIndexMemoryUsage {
        indexed_rows: self.rows.len(),
        documents: self.document_count(),
        distinct_terms: self.term_count(),
        postings: self.posting_count,
        row_bitmap_bytes,
        row_bitmap_serialized_bytes,
        document_length_bytes,
        document_term_bytes,
        terms_table_bytes,
        term_bytes,
        posting_bytes,
        estimated_index_bytes,
    }
}
/// Searches the index for `query` with cancellation disabled and returns the hits directly.
///
/// # Panics
///
/// Panics if `search_checked` returns an error, which is not expected when the checker is
/// disabled.
#[must_use]
pub fn search(&self, query: &str, k: usize) -> Vec<TextSearchHit> {
    let outcome = self.search_checked(query, k, CancellationChecker::disabled());
    outcome.expect("disabled text-index checker cannot fail")
}
/// Scores every document that contains at least one term of `query` with BM25 and returns
/// the hits collected by `TextTopK::new(k)`.
///
/// Returns an empty list when `k` is zero, the index holds no documents, the query yields no
/// terms, or none of the query terms has postings. Only documents with a strictly positive
/// score are offered to the top-k collector.
///
/// `checker` is consulted once up front, then every `TEXT_SEARCH_CANCEL_STRIDE` postings
/// while gathering candidates, and it is told how many candidate documents were scored.
///
/// # Errors
///
/// Propagates any error returned by `checker.check()` or `checker.note_nodes_scanned()`.
///
/// # Panics
///
/// Panics if a posting refers to a node that has no entry in the document-length table.
pub fn search_checked(
    &self,
    query: &str,
    k: usize,
    checker: CancellationChecker<'_>,
) -> Result<Vec<TextSearchHit>, TextSearchError> {
    checker.check()?;
    if k == 0 || self.document_lengths.is_empty() {
        return Ok(Vec::new());
    }
    let query_terms = unique_query_terms(query);
    if query_terms.is_empty() {
        return Ok(Vec::new());
    }

    // Resolve each query term to its postings once; slot `i` of both vectors belongs to
    // query term `i`.
    let mut document_frequencies = QueryDocumentFrequencies::with_capacity(query_terms.len());
    let mut postings_by_term = QueryPostings::with_capacity(query_terms.len());
    let mut candidate_capacity = 0usize;
    for term in &query_terms {
        match self.postings.get(term) {
            Some(postings) => {
                candidate_capacity = candidate_capacity.saturating_add(postings.len());
                document_frequencies.push(u32::try_from(postings.len()).unwrap_or(u32::MAX));
                postings_by_term.push(Some(postings.as_slice()));
            }
            None => {
                document_frequencies.push(0);
                postings_by_term.push(None);
            }
        }
    }
    // There can never be more candidates than indexed documents.
    let candidate_capacity = candidate_capacity.min(self.document_lengths.len());
    if candidate_capacity == 0 {
        return Ok(Vec::new());
    }

    let mut candidates: FxHashMap<NodeId, DocumentStats> = FxHashMap::default();
    candidates.reserve(candidate_capacity);
    let mut postings_since_check = 0usize;
    for (term_index, postings) in postings_by_term.into_iter().enumerate() {
        let Some(postings) = postings else {
            continue;
        };
        for posting in postings {
            postings_since_check += 1;
            if postings_since_check >= crate::text_search::TEXT_SEARCH_CANCEL_STRIDE {
                checker.check()?;
                postings_since_check = 0;
            }
            // The document length is only needed when a candidate is first created, so the
            // lookup lives inside the insertion closure instead of running for every posting.
            let doc = candidates.entry(posting.node_id).or_insert_with(|| {
                let len = *self
                    .document_lengths
                    .get(&posting.node_id)
                    .expect("posting node must have document length");
                DocumentStats::zero(posting.node_id, len, query_terms.len())
            });
            doc.term_counts[term_index] = posting.term_count;
        }
    }
    if candidates.is_empty() {
        return Ok(Vec::new());
    }

    // `document_lengths` is non-empty here (checked above), so the division is well defined.
    let corpus_len = self.document_lengths.len() as f64;
    let average_document_len = self.total_document_len as f64 / corpus_len;
    let mut top_k = TextTopK::new(k);
    let mut docs_since_check = 0usize;
    for doc in candidates.into_values() {
        docs_since_check += 1;
        if docs_since_check >= crate::text_search::TEXT_SEARCH_CANCEL_STRIDE {
            checker.note_nodes_scanned(docs_since_check)?;
            docs_since_check = 0;
        }
        let score = bm25_score(
            &doc,
            &document_frequencies,
            corpus_len,
            average_document_len,
        );
        if score > 0.0 {
            top_k.push(doc.node_id, score);
        }
    }
    // Report the tail that did not fill a whole stride.
    if docs_since_check > 0 {
        checker.note_nodes_scanned(docs_since_check)?;
    }
    Ok(top_k.into_hits())
}
/// Indexes `text` as the document of `node_id` at `row`, replacing any document previously
/// indexed for that node.
///
/// Text that produces no tokens leaves the node out of the index; the earlier entry, if any,
/// has still been removed.
pub(crate) fn insert_document(&mut self, row: u32, node_id: NodeId, text: &str) {
    // Drop the previous version first so counters and postings are never double-counted.
    self.remove_document(row, node_id);

    let mut token_total = 0_u32;
    let mut frequencies: FxHashMap<String, u32> = FxHashMap::default();
    for token in tokenize_borrowed(text) {
        token_total = token_total.saturating_add(1);
        frequencies
            .entry(token.into_owned())
            .and_modify(|seen| *seen = seen.saturating_add(1))
            .or_insert(1);
    }
    if token_total == 0 {
        return;
    }

    self.rows.insert(row);
    self.document_lengths.insert(node_id, token_total);
    self.total_document_len = self
        .total_document_len
        .saturating_add(u64::from(token_total));

    let mut distinct_terms = Vec::with_capacity(frequencies.len());
    for (term, term_count) in frequencies {
        let shared = self
            .postings
            .entry(term.clone())
            .or_insert_with(|| Arc::new(Vec::new()));
        let list = Arc::make_mut(shared);
        // Binary search yields either the existing posting or the position at which a new
        // posting must be inserted.
        match list.binary_search_by_key(&node_id, |posting| posting.node_id) {
            Ok(at) => list[at].term_count = term_count,
            Err(at) => {
                list.insert(
                    at,
                    TextPosting {
                        node_id,
                        term_count,
                    },
                );
                self.posting_count = self.posting_count.saturating_add(1);
            }
        }
        distinct_terms.push(term);
    }
    self.document_terms
        .insert(node_id, Arc::from(distinct_terms));
}
/// Removes the document of `node_id` (stored at `row`) from the index, if present.
///
/// The row bit is always cleared. Counters and postings are only touched when the node has a
/// recorded document length; postings are only touched when it also has a recorded term list.
/// A term whose postings list becomes empty is dropped from the postings table.
pub(crate) fn remove_document(&mut self, row: u32, node_id: NodeId) {
    self.rows.remove(row);

    let length = match self.document_lengths.remove(&node_id) {
        Some(length) => length,
        None => return,
    };
    self.total_document_len = self.total_document_len.saturating_sub(u64::from(length));

    let terms = match self.document_terms.remove(&node_id) {
        Some(terms) => terms,
        None => return,
    };
    for term in terms.iter() {
        let key = term.as_str();
        let Some(shared) = self.postings.get_mut(key) else {
            continue;
        };
        let list = Arc::make_mut(shared);
        if let Ok(at) = list.binary_search_by_key(&node_id, |posting| posting.node_id) {
            list.remove(at);
            self.posting_count = self.posting_count.saturating_sub(1);
        }
        if list.is_empty() {
            self.postings.remove(key);
        }
    }
}
/// Reports whether `self` and `reference` hold the same rows, document lengths, counters,
/// and postings.
///
/// `label`, `property`, and `document_terms` are not part of the comparison.
pub(crate) fn rows_eq(&self, reference: &Self) -> bool {
    // Scalar counters are compared before the containers.
    let counters_match = self.total_document_len == reference.total_document_len
        && self.posting_count == reference.posting_count;
    counters_match
        && self.rows == reference.rows
        && self.document_lengths == reference.document_lengths
        && self.postings == reference.postings
}
}
/// Counter snapshot returned by `TextIndex::stats`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct TextIndexStats {
/// Number of rows set in the index's row bitmap.
pub indexed_rows: u64,
/// Number of indexed documents.
pub documents: usize,
/// Number of distinct terms that have a postings list.
pub distinct_terms: usize,
/// Number of postings across all terms.
pub postings: usize,
/// Sum of the lengths of all indexed documents.
pub total_document_len: u64,
}
/// Memory estimate returned by `TextIndex::memory_usage`, broken down by component.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct TextIndexMemoryUsage {
/// Number of rows set in the index's row bitmap.
pub indexed_rows: u64,
/// Number of indexed documents.
pub documents: usize,
/// Number of distinct terms that have a postings list.
pub distinct_terms: usize,
/// Number of postings across all terms.
pub postings: usize,
/// Container bytes of the row bitmap as computed by `roaring_heap_bytes`.
pub row_bitmap_bytes: usize,
/// `serialized_size()` of the row bitmap; not included in `estimated_index_bytes`.
pub row_bitmap_serialized_bytes: usize,
/// Capacity of the document-length table times its entry size.
pub document_length_bytes: usize,
/// Document-term table slots plus every term slice and the capacity of each term string.
pub document_term_bytes: usize,
/// Capacity of the postings table times its entry size.
pub terms_table_bytes: usize,
/// Summed string capacity of the postings-table keys.
pub term_bytes: usize,
/// Summed capacity of all postings vectors times the posting size.
pub posting_bytes: usize,
/// Saturating total of the struct size and all components except the serialized size.
pub estimated_index_bytes: usize,
}
impl SeleneGraph {
    /// Builds a [`TextIndex`] over `property` for the nodes of this graph that carry `label`.
    ///
    /// # Errors
    ///
    /// Forwards any error returned by `TextIndex::build`.
    pub fn build_text_index(
        &self,
        label: &DbString,
        property: &DbString,
    ) -> GraphResult<TextIndex> {
        let (owned_label, owned_property) = (label.clone(), property.clone());
        TextIndex::build(self, owned_label, owned_property)
    }

    /// Builds a fresh text index for `label` / `property` and searches it for `query`,
    /// passing `k` through to the search.
    ///
    /// The index is rebuilt on every call and discarded afterwards.
    ///
    /// # Errors
    ///
    /// Forwards any error from building the index.
    pub fn indexed_text_search_nodes(
        &self,
        label: &DbString,
        property: &DbString,
        query: &str,
        k: usize,
    ) -> GraphResult<Vec<TextSearchHit>> {
        let index = self.build_text_index(label, property)?;
        Ok(index.search(query, k))
    }
}
impl SharedGraph {
    /// Builds a [`TextIndex`] from the graph obtained through `read()`.
    ///
    /// # Errors
    ///
    /// Forwards any error returned by the underlying `build_text_index`.
    pub fn build_text_index(
        &self,
        label: &DbString,
        property: &DbString,
    ) -> GraphResult<TextIndex> {
        let graph = self.read();
        graph.build_text_index(label, property)
    }

    /// Runs `indexed_text_search_nodes` on the graph obtained through `read()`.
    ///
    /// # Errors
    ///
    /// Forwards any error returned by the underlying search.
    pub fn indexed_text_search_nodes(
        &self,
        label: &DbString,
        property: &DbString,
        query: &str,
        k: usize,
    ) -> GraphResult<Vec<TextSearchHit>> {
        let graph = self.read();
        graph.indexed_text_search_nodes(label, property, query, k)
    }
}
/// One entry of a term's postings list: a document and how often the term occurs in it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct TextPosting {
/// Node whose document contains the term; postings lists are searched by this key.
node_id: NodeId,
/// Number of occurrences of the term in that document.
term_count: u32,
}
/// Sums the byte sizes that `RoaringBitmap::statistics` reports for array, run, and bitset
/// containers, saturating on overflow and clamping to `usize::MAX` if the total does not fit.
fn roaring_heap_bytes(rows: &RoaringBitmap) -> usize {
    let stats = rows.statistics();
    let total = stats.n_bytes_array_containers;
    let total = total.saturating_add(stats.n_bytes_run_containers);
    let total = total.saturating_add(stats.n_bytes_bitset_containers);
    match usize::try_from(total) {
        Ok(bytes) => bytes,
        Err(_) => usize::MAX,
    }
}
#[cfg(test)]
#[path = "text_index/tests.rs"]
mod tests;