1use std::borrow::Cow;
10use std::cmp::Ordering;
11use std::collections::{BTreeSet, BinaryHeap};
12use std::time::Duration;
13
14use roaring::RoaringBitmap;
15use selene_core::{CancellationCause, CancellationChecker, DbString, NodeId, Value};
16use smallvec::SmallVec;
17
18use crate::error::{GraphError, GraphResult};
19use crate::graph::SeleneGraph;
20use crate::parallel_scan::{should_parallelize_scan, try_reduce_bitmap_chunks};
21use crate::shared::SharedGraph;
22use crate::store::RowIndex;
23
/// Number of rows the serial scan processes between reports to the
/// cancellation checker.
pub(crate) const TEXT_SEARCH_CANCEL_STRIDE: usize = 1024;

/// Chunk size, in rows, handed to the parallel bitmap reducer.
///
/// Test builds use 4 instead of 2048 (presumably so small fixtures still span
/// several chunks).
const TEXT_SEARCH_PARALLEL_CHUNK_ROWS: usize = if cfg!(test) { 4 } else { 2048 };

/// Row threshold forwarded to the shared parallel-scan heuristic.
///
/// Test builds use 8 instead of 16 384.
const TEXT_SEARCH_PARALLEL_MIN_ROWS: u64 = if cfg!(test) { 8 } else { 16_384 };

/// BM25 `k1`: scales the length-normalisation term and the term-count
/// numerator in the per-term score.
const BM25_K1: f64 = 1.2;

/// BM25 `b`: weight given to the document-length / average-length ratio.
const BM25_B: f64 = 0.75;
36
37type TermCounts = SmallVec<[u32; 4]>;
38
39#[derive(Clone, Debug, PartialEq)]
41pub struct TextSearchHit {
42 pub node_id: NodeId,
44 pub score: f64,
46}
47
48#[derive(Debug, thiserror::Error)]
50pub enum TextSearchError {
51 #[error(transparent)]
53 Graph(#[from] GraphError),
54 #[error("text search cancelled")]
56 Cancelled,
57 #[error("text search timed out after {elapsed:?}")]
59 Timeout {
60 elapsed: Duration,
62 },
63 #[error("text search node scan budget exceeded ({scanned} > {limit})")]
65 NodeScanBudgetExceeded {
66 limit: usize,
68 scanned: usize,
70 },
71}
72
73impl TextSearchError {
74 fn into_graph_error(self) -> GraphError {
75 match self {
76 Self::Graph(error) => error,
77 Self::Cancelled | Self::Timeout { .. } | Self::NodeScanBudgetExceeded { .. } => {
78 GraphError::Inconsistent {
79 reason: format!("disabled text-search checker returned {self}"),
80 }
81 }
82 }
83 }
84}
85
86impl From<CancellationCause> for TextSearchError {
87 fn from(cause: CancellationCause) -> Self {
88 match cause {
89 CancellationCause::Cancelled => Self::Cancelled,
90 CancellationCause::Timeout { elapsed } => Self::Timeout { elapsed },
91 CancellationCause::NodeScanBudgetExceeded { limit, scanned } => {
92 Self::NodeScanBudgetExceeded { limit, scanned }
93 }
94 }
95 }
96}
97
98impl SeleneGraph {
99 pub fn exact_text_search_nodes(
107 &self,
108 label: &DbString,
109 property: &DbString,
110 query: &str,
111 k: usize,
112 ) -> GraphResult<Vec<TextSearchHit>> {
113 self.exact_text_search_nodes_checked(
114 label,
115 property,
116 query,
117 k,
118 CancellationChecker::disabled(),
119 )
120 .map_err(TextSearchError::into_graph_error)
121 }
122
123 pub fn exact_text_search_nodes_checked(
125 &self,
126 label: &DbString,
127 property: &DbString,
128 query: &str,
129 k: usize,
130 checker: CancellationChecker<'_>,
131 ) -> Result<Vec<TextSearchHit>, TextSearchError> {
132 self.exact_text_search_nodes_filtered_checked(label, property, query, k, None, checker)
133 }
134
135 pub fn exact_text_search_nodes_in_rows_checked(
141 &self,
142 label: &DbString,
143 property: &DbString,
144 query: &str,
145 k: usize,
146 allowed_rows: &RoaringBitmap,
147 checker: CancellationChecker<'_>,
148 ) -> Result<Vec<TextSearchHit>, TextSearchError> {
149 if allowed_rows.is_empty() {
150 return Ok(Vec::new());
151 }
152 self.exact_text_search_nodes_filtered_checked(
153 label,
154 property,
155 query,
156 k,
157 Some(allowed_rows),
158 checker,
159 )
160 }
161
162 fn exact_text_search_nodes_filtered_checked(
163 &self,
164 label: &DbString,
165 property: &DbString,
166 query: &str,
167 k: usize,
168 allowed_rows: Option<&RoaringBitmap>,
169 checker: CancellationChecker<'_>,
170 ) -> Result<Vec<TextSearchHit>, TextSearchError> {
171 checker.check()?;
172 if k == 0 {
173 return Ok(Vec::new());
174 }
175 let query_terms = unique_query_terms(query);
176 if query_terms.is_empty() {
177 return Ok(Vec::new());
178 }
179 let Some(label_rows) = self.nodes_with_label(label) else {
180 return Ok(Vec::new());
181 };
182
183 let scan = TextScan::new(self, label, property, &query_terms, allowed_rows);
184 let chunk = if should_parallelize_text_scan(label_rows, k) {
185 exact_text_scan_parallel(scan, label_rows, checker)?
186 } else {
187 exact_text_scan_serial(scan, label_rows, checker)?
188 };
189 Ok(rank_text_docs(chunk, k))
190 }
191}
192
193impl SharedGraph {
194 pub fn exact_text_search_nodes(
196 &self,
197 label: &DbString,
198 property: &DbString,
199 query: &str,
200 k: usize,
201 ) -> GraphResult<Vec<TextSearchHit>> {
202 self.read()
203 .exact_text_search_nodes(label, property, query, k)
204 }
205
206 pub fn exact_text_search_nodes_checked(
208 &self,
209 label: &DbString,
210 property: &DbString,
211 query: &str,
212 k: usize,
213 checker: CancellationChecker<'_>,
214 ) -> Result<Vec<TextSearchHit>, TextSearchError> {
215 self.read()
216 .exact_text_search_nodes_checked(label, property, query, k, checker)
217 }
218}
219
220#[derive(Clone, Copy)]
221struct TextScan<'a> {
222 graph: &'a SeleneGraph,
223 label: &'a DbString,
224 property: &'a DbString,
225 query_terms: &'a [String],
226 allowed_rows: Option<&'a RoaringBitmap>,
227}
228
229impl<'a> TextScan<'a> {
230 fn new(
231 graph: &'a SeleneGraph,
232 label: &'a DbString,
233 property: &'a DbString,
234 query_terms: &'a [String],
235 allowed_rows: Option<&'a RoaringBitmap>,
236 ) -> Self {
237 Self {
238 graph,
239 label,
240 property,
241 query_terms,
242 allowed_rows,
243 }
244 }
245
246 fn document_for_row(self, raw_row: u32) -> Result<Option<DocumentStats>, TextSearchError> {
247 if !self.graph.node_store.is_alive(raw_row) {
248 return Ok(None);
249 }
250 let row = RowIndex::new(raw_row);
251 let node_id = self
252 .graph
253 .node_id_for_row(row)
254 .ok_or_else(|| GraphError::Inconsistent {
255 reason: format!(
256 "label index row {raw_row} for {} has no node id",
257 self.label.as_str()
258 ),
259 })?;
260 let properties = self
261 .graph
262 .node_store
263 .properties
264 .get(raw_row as usize)
265 .ok_or_else(|| GraphError::Inconsistent {
266 reason: format!(
267 "text search row {raw_row} for {} has no property row",
268 self.label.as_str()
269 ),
270 })?;
271 let Some(Value::String(text)) = properties.get(self.property) else {
272 return Ok(None);
273 };
274 Ok(document_stats(
275 node_id,
276 text.as_str(),
277 self.query_terms,
278 self.allowed_rows
279 .is_none_or(|allowed_rows| allowed_rows.contains(raw_row)),
280 ))
281 }
282}
283
284#[derive(Debug)]
285struct TextScanChunk {
286 docs: Vec<DocumentStats>,
287 document_frequencies: Vec<u32>,
288 total_document_len: u64,
289}
290
291impl TextScanChunk {
292 fn empty(query_term_count: usize) -> Self {
293 Self {
294 docs: Vec::new(),
295 document_frequencies: vec![0; query_term_count],
296 total_document_len: 0,
297 }
298 }
299
300 fn push(&mut self, doc: DocumentStats) {
301 for (frequency, count) in self.document_frequencies.iter_mut().zip(&doc.term_counts) {
302 if *count > 0 {
303 *frequency = frequency.saturating_add(1);
304 }
305 }
306 self.total_document_len = self.total_document_len.saturating_add(u64::from(doc.len));
307 self.docs.push(doc);
308 }
309}
310
311fn should_parallelize_text_scan(rows: &RoaringBitmap, k: usize) -> bool {
312 should_parallelize_scan(rows.len(), k, TEXT_SEARCH_PARALLEL_MIN_ROWS)
313}
314
315fn exact_text_scan_parallel(
316 scan: TextScan<'_>,
317 rows: &RoaringBitmap,
318 checker: CancellationChecker<'_>,
319) -> Result<TextScanChunk, TextSearchError> {
320 try_reduce_bitmap_chunks(
321 rows,
322 TEXT_SEARCH_PARALLEL_CHUNK_ROWS,
323 checker,
324 || TextScanChunk::empty(scan.query_terms.len()),
325 |chunk| exact_text_scan_chunk(scan, chunk),
326 merge_text_scan_chunks,
327 )
328}
329
330fn exact_text_scan_serial(
331 scan: TextScan<'_>,
332 rows: &RoaringBitmap,
333 checker: CancellationChecker<'_>,
334) -> Result<TextScanChunk, TextSearchError> {
335 let mut chunk = TextScanChunk::empty(scan.query_terms.len());
336 let mut rows_since_check = 0usize;
337 for raw_row in rows.iter() {
338 rows_since_check += 1;
339 if rows_since_check >= TEXT_SEARCH_CANCEL_STRIDE {
340 checker.note_nodes_scanned(rows_since_check)?;
341 rows_since_check = 0;
342 }
343 if let Some(doc) = scan.document_for_row(raw_row)? {
344 chunk.push(doc);
345 }
346 }
347 if rows_since_check > 0 {
348 checker.note_nodes_scanned(rows_since_check)?;
349 }
350 Ok(chunk)
351}
352
353fn exact_text_scan_chunk(
354 scan: TextScan<'_>,
355 rows: &[u32],
356) -> Result<TextScanChunk, TextSearchError> {
357 let mut chunk = TextScanChunk::empty(scan.query_terms.len());
358 for &raw_row in rows {
359 if let Some(doc) = scan.document_for_row(raw_row)? {
360 chunk.push(doc);
361 }
362 }
363 Ok(chunk)
364}
365
366fn merge_text_scan_chunks(
367 mut lhs: TextScanChunk,
368 mut rhs: TextScanChunk,
369) -> Result<TextScanChunk, TextSearchError> {
370 for (lhs_frequency, rhs_frequency) in lhs
371 .document_frequencies
372 .iter_mut()
373 .zip(&rhs.document_frequencies)
374 {
375 *lhs_frequency = lhs_frequency.saturating_add(*rhs_frequency);
376 }
377 lhs.total_document_len = lhs
378 .total_document_len
379 .saturating_add(rhs.total_document_len);
380 lhs.docs.append(&mut rhs.docs);
381 Ok(lhs)
382}
383
384fn rank_text_docs(chunk: TextScanChunk, k: usize) -> Vec<TextSearchHit> {
385 if chunk.docs.is_empty() {
386 return Vec::new();
387 }
388 let corpus_len = chunk.docs.len() as f64;
389 let average_document_len = chunk.total_document_len as f64 / corpus_len;
390 let mut top_k = TextTopK::new(k);
391 for doc in chunk.docs {
392 if !doc.admitted {
393 continue;
394 }
395 let score = bm25_score(
396 &doc,
397 &chunk.document_frequencies,
398 corpus_len,
399 average_document_len,
400 );
401 if score > 0.0 {
402 top_k.push(doc.node_id, score);
403 }
404 }
405 top_k.into_hits()
406}
407
408#[derive(Debug)]
409pub(crate) struct DocumentStats {
410 pub(crate) node_id: NodeId,
411 len: u32,
412 pub(crate) term_counts: TermCounts,
413 admitted: bool,
414}
415
416impl DocumentStats {
417 pub(crate) fn zero(node_id: NodeId, len: u32, query_term_count: usize) -> Self {
418 Self {
419 node_id,
420 len,
421 term_counts: TermCounts::from_elem(0, query_term_count),
422 admitted: true,
423 }
424 }
425}
426
427pub(crate) fn unique_query_terms(query: &str) -> Vec<String> {
428 let terms: BTreeSet<_> = tokenize_borrowed(query).map(Cow::into_owned).collect();
429 terms.into_iter().collect()
430}
431
432fn document_stats(
433 node_id: NodeId,
434 text: &str,
435 query_terms: &[String],
436 admitted: bool,
437) -> Option<DocumentStats> {
438 let mut term_counts = TermCounts::from_elem(0, query_terms.len());
439 let mut len = 0_u32;
440 for token in tokenize_borrowed(text) {
441 len = len.saturating_add(1);
442 if let Ok(index) = query_terms.binary_search_by(|term| term.as_str().cmp(token.as_ref())) {
443 term_counts[index] = term_counts[index].saturating_add(1);
444 }
445 }
446 (len > 0).then_some(DocumentStats {
447 node_id,
448 len,
449 term_counts,
450 admitted,
451 })
452}
453
/// Returns an iterator over the lowercase alphanumeric tokens of `text`.
pub(crate) fn tokenize_borrowed(text: &str) -> Tokenizer<'_> {
    Tokenizer { text, offset: 0 }
}

/// Iterator splitting a string into maximal runs of alphanumeric characters,
/// lowercased one character at a time.
pub(crate) struct Tokenizer<'a> {
    /// Text being tokenised.
    text: &'a str,
    /// Byte offset at which the next call to `next` resumes.
    offset: usize,
}

impl<'a> Tokenizer<'a> {
    /// Lowercases `raw` character by character.
    ///
    /// Borrows `raw` when every character is already its own lowercase form,
    /// and allocates otherwise. `str::to_lowercase` is deliberately not used:
    /// it applies context-sensitive rules that per-character mapping does not.
    fn fold_case(raw: &'a str) -> Cow<'a, str> {
        let first_changed = raw.char_indices().find(|&(_, ch)| {
            let mut lowered = ch.to_lowercase();
            lowered.next() != Some(ch) || lowered.next().is_some()
        });
        let Some((split, _)) = first_changed else {
            return Cow::Borrowed(raw);
        };

        // Everything before `split` is unchanged and can be copied verbatim.
        let mut folded = String::with_capacity(raw.len());
        folded.push_str(&raw[..split]);
        folded.extend(raw[split..].chars().flat_map(char::to_lowercase));
        Cow::Owned(folded)
    }
}

impl<'a> Iterator for Tokenizer<'a> {
    type Item = Cow<'a, str>;

    /// Yields the next token, or `None` once only separators (or nothing)
    /// remain.
    fn next(&mut self) -> Option<Self::Item> {
        let text = self.text;

        // Skip leading separators to find where the token starts.
        let remaining = &text[self.offset..];
        let Some(lead) = remaining.find(char::is_alphanumeric) else {
            self.offset = text.len();
            return None;
        };
        let start = self.offset + lead;

        // The token ends at the first separator, which is consumed as well.
        let terminator = text[start..]
            .char_indices()
            .find(|&(_, ch)| !ch.is_alphanumeric());
        let (end, resume) = match terminator {
            Some((at, separator)) => (start + at, start + at + separator.len_utf8()),
            None => (text.len(), text.len()),
        };
        self.offset = resume;

        Some(Self::fold_case(&text[start..end]))
    }
}
525
526pub(crate) fn bm25_score(
527 doc: &DocumentStats,
528 document_frequencies: &[u32],
529 corpus_len: f64,
530 average_document_len: f64,
531) -> f64 {
532 let document_len = f64::from(doc.len);
533 doc.term_counts
534 .iter()
535 .zip(document_frequencies)
536 .filter(|(term_count, _)| **term_count > 0)
537 .map(|(term_count, document_frequency)| {
538 let term_count = f64::from(*term_count);
539 let document_frequency = f64::from(*document_frequency);
540 let idf =
541 (1.0 + (corpus_len - document_frequency + 0.5) / (document_frequency + 0.5)).ln();
542 let normalization = term_count
543 + BM25_K1 * (1.0 - BM25_B + BM25_B * document_len / average_document_len);
544 idf * (term_count * (BM25_K1 + 1.0)) / normalization
545 })
546 .sum()
547}
548
549#[derive(Debug)]
550pub(crate) struct TextTopK {
551 k: usize,
552 heap: BinaryHeap<TextHeapEntry>,
553}
554
555impl TextTopK {
556 pub(crate) fn new(k: usize) -> Self {
557 Self {
558 k,
559 heap: BinaryHeap::new(),
560 }
561 }
562
563 pub(crate) fn push(&mut self, node_id: NodeId, score: f64) {
564 debug_assert!(score.is_finite(), "BM25 scores must be finite");
565 if self.k == 0 {
566 return;
567 }
568 let entry = TextHeapEntry { score, node_id };
569 if self.heap.len() < self.k {
570 self.heap.push(entry);
571 return;
572 }
573 let Some(worst) = self.heap.peek() else {
574 return;
575 };
576 if entry.cmp(worst).is_lt() {
577 self.heap.pop();
578 self.heap.push(entry);
579 }
580 }
581
582 pub(crate) fn into_hits(self) -> Vec<TextSearchHit> {
583 let mut hits: Vec<_> = self
584 .heap
585 .into_iter()
586 .map(|entry| TextSearchHit {
587 node_id: entry.node_id,
588 score: entry.score,
589 })
590 .collect();
591 hits.sort_by(compare_hit);
592 hits
593 }
594}
595
596#[derive(Debug)]
597struct TextHeapEntry {
598 score: f64,
599 node_id: NodeId,
600}
601
602impl Eq for TextHeapEntry {}
603
604impl PartialEq for TextHeapEntry {
605 fn eq(&self, rhs: &Self) -> bool {
606 self.score.to_bits() == rhs.score.to_bits() && self.node_id == rhs.node_id
607 }
608}
609
610impl Ord for TextHeapEntry {
611 fn cmp(&self, rhs: &Self) -> Ordering {
612 rhs.score
613 .total_cmp(&self.score)
614 .then_with(|| self.node_id.cmp(&rhs.node_id))
615 }
616}
617
618impl PartialOrd for TextHeapEntry {
619 fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
620 Some(self.cmp(rhs))
621 }
622}
623
624fn compare_hit(lhs: &TextSearchHit, rhs: &TextSearchHit) -> Ordering {
625 rhs.score
626 .total_cmp(&lhs.score)
627 .then_with(|| lhs.node_id.cmp(&rhs.node_id))
628}
629
630#[cfg(test)]
631#[path = "text_search/tests.rs"]
632mod tests;