use crate::data::expr::{eval_bytecode, eval_bytecode_pred, Bytecode};
use crate::data::program::{FtsScoreKind, FtsSearch};
use crate::data::tuple::{decode_tuple_from_key, Tuple, ENCODED_KEY_MIN_LEN};
use crate::data::value::LARGEST_UTF_CHAR;
use crate::fts::ast::{FtsExpr, FtsLiteral, FtsNear};
use crate::fts::tokenizer::TextAnalyzer;
use crate::parse::fts::parse_fts_query;
use crate::runtime::relation::RelationHandle;
use crate::runtime::transact::SessionTx;
use crate::{DataValue, SourceSpan};
use itertools::Itertools;
use miette::{bail, miette, Diagnostic, Result};
use ordered_float::OrderedFloat;
use rustc_hash::{FxHashMap, FxHashSet};
use smartstring::{LazyCompact, SmartString};
use std::cmp::Reverse;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use thiserror::Error;
/// Memoizes statistics that FTS scoring needs, so that repeated searches
/// do not have to recount the same relation.
#[derive(Default)]
pub(crate) struct FtsCache {
/// Relation name -> entry count of that relation as returned by
/// `range_count`; filled lazily by `get_n_for_relation`.
total_n_cache: FxHashMap<SmartString<LazyCompact>, usize>,
}
impl FtsCache {
    /// Returns the number of entries in the key range of `rel`.
    ///
    /// The count is computed with `range_count` the first time a relation
    /// name is seen and served from `total_n_cache` afterwards. If counting
    /// fails the error is propagated and nothing is cached.
    fn get_n_for_relation(&mut self, rel: &RelationHandle, tx: &SessionTx<'_>) -> Result<usize> {
        match self.total_n_cache.entry(rel.name.clone()) {
            Entry::Occupied(hit) => Ok(*hit.get()),
            Entry::Vacant(slot) => {
                let lower = rel.encode_partial_key_for_store(&[]);
                let upper = rel.encode_partial_key_for_store(&[DataValue::Bot]);
                let entry_count = tx.store_tx.range_count(&lower, &upper)?;
                Ok(*slot.insert(entry_count))
            }
        }
    }

    /// Returns the average document length (tokens per document) of the FTS
    /// index `idx`, or `0.0` when the index holds no documents.
    ///
    /// Sources are consulted in this order:
    /// 1. the counter persisted in the store (`read_fts_doc_stats`),
    /// 2. the transaction-level `fts_doc_stats_cache`,
    /// 3. a full scan of the index (`scan_fts_doc_stats`), whose result is
    ///    then placed into `fts_doc_stats_cache`.
    ///
    /// # Panics
    /// Panics if the `fts_doc_stats_cache` mutex is poisoned.
    fn get_avgdl_for_index(&mut self, idx: &RelationHandle, tx: &SessionTx<'_>) -> Result<f64> {
        /// Mean of `total` over `n`, defined as zero for an empty index.
        fn mean_len((total, n): (u64, u64)) -> f64 {
            if n == 0 {
                0.0
            } else {
                total as f64 / n as f64
            }
        }

        if let Some(persisted) = tx.read_fts_doc_stats(idx)? {
            return Ok(mean_len(persisted));
        }

        // The guard is a temporary of this statement, so the lock is released
        // before the (potentially slow) scan below.
        let cached = tx
            .fts_doc_stats_cache
            .lock()
            .unwrap()
            .get(&idx.name)
            .copied();
        if let Some(stats) = cached {
            return Ok(mean_len(stats));
        }

        let scanned = tx.scan_fts_doc_stats(idx)?;
        tx.fts_doc_stats_cache
            .lock()
            .unwrap()
            .insert(idx.name.clone(), scanned);
        Ok(mean_len(scanned))
    }
}
/// One occurrence of a term inside a document.
struct PositionInfo {
/// Token position of the occurrence, read from the third list of the
/// stored posting value.
position: u32,
}
/// Everything `fts_search_literal` extracts from a single posting
/// (one term in one document).
struct LiteralStats {
/// Index key with its leading term component removed.
key: Tuple,
/// All occurrences of the term in the document.
position_info: Vec<PositionInfo>,
/// Fourth field of the posting value; `put_fts_index_item` stores the
/// document's total token count there.
doc_len: u32,
}
impl<'a> SessionTx<'a> {
/// Builds the storage key under which the doc-stats counter of `idx` is kept.
///
/// The key is the partial key made of the single value `DataValue::Bot`,
/// i.e. the same bytes that the posting scans in this file pass as their
/// upper bound.
fn fts_stats_key(idx: &RelationHandle) -> Vec<u8> {
    let sentinel = [DataValue::Bot];
    idx.encode_partial_key_for_store(&sentinel)
}
/// Reads the persisted `(total_tokens, doc_count)` counter of the FTS index `idx`.
///
/// Returns `Ok(None)` when no counter has been written yet. A counter field
/// that is missing, not an integer, or negative is read as `0`.
///
/// # Errors
/// Fails if the store lookup fails or the stored bytes cannot be decoded.
pub(crate) fn read_fts_doc_stats(&self, idx: &RelationHandle) -> Result<Option<(u64, u64)>> {
    let stored = self.store_tx.get(&Self::fts_stats_key(idx), false)?;
    let raw = match stored {
        Some(raw) => raw,
        None => return Ok(None),
    };
    let fields: Vec<DataValue> = rmp_serde::from_slice(&raw[ENCODED_KEY_MIN_LEN..])
        .map_err(|e| miette!("corrupt FTS doc-stats counter: {e}"))?;
    // Lenient accessor: anything that is not a non-negative integer counts as zero.
    let counter_at = |i: usize| -> u64 {
        fields
            .get(i)
            .and_then(|d| d.get_int())
            .unwrap_or(0)
            .max(0) as u64
    };
    Ok(Some((counter_at(0), counter_at(1))))
}
/// Persists the `(total, n)` doc-stats counter of the FTS index `idx`.
///
/// Both numbers are stored as integer `DataValue`s (cast to `i64`) under
/// the key produced by `fts_stats_key`.
///
/// # Errors
/// Fails if the value cannot be encoded or the store write fails.
pub(crate) fn write_fts_doc_stats(
    &mut self,
    idx: &RelationHandle,
    total: u64,
    n: u64,
) -> Result<()> {
    let stats_key = Self::fts_stats_key(idx);
    let as_value = |counter: u64| DataValue::from(counter as i64);
    let counters = vec![as_value(total), as_value(n)];
    let encoded = idx.encode_val_only_for_store(&counters, Default::default())?;
    self.store_tx.put(&stats_key, &encoded)
}
/// Recomputes `(total_tokens, doc_count)` for the FTS index `idx` by
/// scanning every posting.
///
/// Each document (identified by the index key without its leading term
/// component) is counted once; its length is taken from the fourth field
/// of the first posting encountered for it. A length field that is
/// missing, not an integer, or negative contributes `0`.
///
/// # Errors
/// Fails if the range scan fails or a posting value cannot be decoded.
pub(crate) fn scan_fts_doc_stats(&self, idx: &RelationHandle) -> Result<(u64, u64)> {
    let start = idx.encode_partial_key_for_store(&[]);
    let end = idx.encode_partial_key_for_store(&[DataValue::Bot]);
    let mut seen: FxHashSet<Tuple> = FxHashSet::default();
    let mut total: u64 = 0;
    for item in self.store_tx.range_scan(&start, &end) {
        let (kvec, vvec) = item?;
        let key_tuple = decode_tuple_from_key(&kvec, idx.metadata.keys.len());
        // Only the first posting of each document contributes its length.
        if seen.insert(key_tuple[1..].to_vec()) {
            let vals: Vec<DataValue> = rmp_serde::from_slice(&vvec[ENCODED_KEY_MIN_LEN..])
                .map_err(|e| miette!("corrupt FTS posting value: {e}"))?;
            // Checked access: a short value is treated like a non-integer
            // length (zero) instead of panicking on an out-of-bounds index.
            let doc_len = vals.get(3).and_then(|d| d.get_int()).unwrap_or(0).max(0) as u64;
            total += doc_len;
        }
    }
    Ok((total, seen.len() as u64))
}
/// Returns the doc-stats counter of `idx`, creating it first if necessary.
///
/// When no counter is persisted yet, the index is scanned, the result is
/// written to the store, and that result is returned.
fn ensure_fts_doc_stats(&mut self, idx: &RelationHandle) -> Result<(u64, u64)> {
    match self.read_fts_doc_stats(idx)? {
        Some(persisted) => Ok(persisted),
        None => {
            let scanned = self.scan_fts_doc_stats(idx)?;
            self.write_fts_doc_stats(idx, scanned.0, scanned.1)?;
            Ok(scanned)
        }
    }
}
/// Unconditionally recomputes the doc-stats counter of `idx` from a full
/// index scan and overwrites whatever counter is currently persisted.
pub(crate) fn rebuild_fts_doc_stats(&mut self, idx: &RelationHandle) -> Result<()> {
    let fresh = self.scan_fts_doc_stats(idx)?;
    self.write_fts_doc_stats(idx, fresh.0, fresh.1)
}
fn fts_search_literal(
&self,
literal: &FtsLiteral,
idx_handle: &RelationHandle,
) -> Result<Vec<LiteralStats>> {
let start_key_str = &literal.value as &str;
let start_key = vec![DataValue::Str(SmartString::from(start_key_str))];
let mut end_key_str = literal.value.clone();
end_key_str.push(LARGEST_UTF_CHAR);
let end_key = vec![DataValue::Str(end_key_str)];
let start_key_bytes = idx_handle.encode_partial_key_for_store(&start_key);
let end_key_bytes = idx_handle.encode_partial_key_for_store(&end_key);
let mut results = vec![];
for item in self.store_tx.range_scan(&start_key_bytes, &end_key_bytes) {
let (kvec, vvec) = item?;
let key_tuple = decode_tuple_from_key(&kvec, idx_handle.metadata.keys.len());
let found_str_key = key_tuple[0].get_str().unwrap();
if literal.is_prefix {
if !found_str_key.starts_with(start_key_str) {
break;
}
} else if found_str_key != start_key_str {
break;
}
let vals: Vec<DataValue> = rmp_serde::from_slice(&vvec[ENCODED_KEY_MIN_LEN..]).unwrap();
let froms = vals[0].get_slice().unwrap();
let tos = vals[1].get_slice().unwrap();
let positions = vals[2].get_slice().unwrap();
let total_length = vals[3].get_int().unwrap();
let position_info = froms
.iter()
.zip(tos.iter())
.zip(positions.iter())
.map(|(_, p)| PositionInfo {
position: p.get_int().unwrap() as u32,
})
.collect_vec();
results.push(LiteralStats {
key: key_tuple[1..].to_vec(),
position_info,
doc_len: total_length as u32,
});
}
Ok(results)
}
/// Recursively evaluates an FTS query expression and scores the matches.
///
/// Returns a map from document key (index key without its leading term
/// component) to score. `n` and `avgdl` are forwarded unchanged to
/// `fts_compute_score`.
///
/// * `Literal` – each posting of the literal is scored individually.
/// * `And` – intersection of all operands; scores are added.
/// * `Or` – union of all operands; when a document matches several
///   operands the scores are added for BM25 and the maximum is kept for
///   the other score kinds.
/// * `Near` – documents containing every literal, where each further
///   literal has an occurrence within `distance` positions of a surviving
///   candidate position.
/// * `Not` – matches of the first operand minus matches of the second.
///
/// # Panics
/// Panics if an `And` or `Near` node has no operands (presumably the
/// parser never produces such nodes — confirm against the parser).
fn fts_search_impl(
    &self,
    ast: &FtsExpr,
    config: &FtsSearch,
    n: usize,
    avgdl: f64,
) -> Result<FxHashMap<Tuple, f64>> {
    Ok(match ast {
        FtsExpr::Literal(l) => {
            let mut res = FxHashMap::default();
            let found_docs = self.fts_search_literal(l, &config.idx_handle)?;
            // Number of postings returned; used as the document frequency.
            let found_docs_len = found_docs.len();
            for el in found_docs {
                let score = Self::fts_compute_score(
                    el.position_info.len(),
                    found_docs_len,
                    n,
                    el.doc_len,
                    avgdl,
                    l.booster.0,
                    config,
                );
                res.insert(el.key, score);
            }
            res
        }
        FtsExpr::And(ls) => {
            let mut l_iter = ls.iter();
            let mut res = self.fts_search_impl(l_iter.next().unwrap(), config, n, avgdl)?;
            for nxt in l_iter {
                let nxt_res = self.fts_search_impl(nxt, config, n, avgdl)?;
                // Keep only documents present in both sides, adding scores.
                res = res
                    .into_iter()
                    .filter_map(|(k, v)| nxt_res.get(&k).map(|nxt_v| (k, v + nxt_v)))
                    .collect();
            }
            res
        }
        FtsExpr::Or(ls) => {
            let sum_terms = config.score_kind == FtsScoreKind::Bm25;
            let mut res: FxHashMap<Tuple, f64> = FxHashMap::default();
            for nxt in ls {
                let nxt_res = self.fts_search_impl(nxt, config, n, avgdl)?;
                for (k, v) in nxt_res {
                    if let Some(old_v) = res.get_mut(&k) {
                        *old_v = if sum_terms { *old_v + v } else { (*old_v).max(v) };
                    } else {
                        res.insert(k, v);
                    }
                }
            }
            res
        }
        FtsExpr::Near(FtsNear { literals, distance }) => {
            let mut l_it = literals.iter();
            // Document key -> candidate anchor positions.
            let mut coll: FxHashMap<_, _> = FxHashMap::default();
            let mut doc_lens: FxHashMap<Tuple, u32> = FxHashMap::default();
            // Seed the candidates with the occurrences of the first literal.
            for first_el in self.fts_search_literal(l_it.next().unwrap(), &config.idx_handle)? {
                doc_lens.insert(first_el.key.clone(), first_el.doc_len);
                coll.insert(
                    first_el.key,
                    first_el
                        .position_info
                        .into_iter()
                        .map(|el| el.position)
                        .collect_vec(),
                );
            }
            // Narrow the candidates with the REMAINING literals only. The
            // first literal was already consumed from `l_it` above;
            // iterating `literals` again would rescan it and intersect it
            // with itself, which can drop valid documents when a prefix
            // literal yields several postings for the same document.
            for lit_nxt in l_it {
                let el_res = self.fts_search_literal(lit_nxt, &config.idx_handle)?;
                coll = el_res
                    .into_iter()
                    .filter_map(|x| match coll.remove(&x.key) {
                        None => None,
                        Some(prev_pos) => {
                            let mut inner_coll = FxHashSet::default();
                            for p in prev_pos {
                                for pi in x.position_info.iter() {
                                    let cur = pi.position;
                                    // Within distance: keep the smaller of
                                    // the two positions as the new anchor.
                                    if cur > p {
                                        if cur - p <= *distance {
                                            inner_coll.insert(p);
                                        }
                                    } else if p - cur <= *distance {
                                        inner_coll.insert(cur);
                                    }
                                }
                            }
                            if inner_coll.is_empty() {
                                None
                            } else {
                                Some((x.key, inner_coll.into_iter().collect_vec()))
                            }
                        }
                    })
                    .collect();
            }
            // The boosters of all literals are summed into a single factor.
            let mut booster = 0.0;
            for lit in literals {
                booster += lit.booster.0;
            }
            let coll_len = coll.len();
            coll.into_iter()
                .map(|(k, cands)| {
                    let doc_len = doc_lens.get(&k).copied().unwrap_or(0);
                    let score = Self::fts_compute_score(
                        cands.len(),
                        coll_len,
                        n,
                        doc_len,
                        avgdl,
                        booster,
                        config,
                    );
                    (k, score)
                })
                .collect()
        }
        FtsExpr::Not(fst, snd) => {
            let mut res = self.fts_search_impl(fst, config, n, avgdl)?;
            for el in self.fts_search_impl(snd, config, n, avgdl)?.keys() {
                res.remove(el);
            }
            res
        }
    })
}
/// Computes the score of one document for one query term (or NEAR group).
///
/// * `tf` – number of matching occurrences in the document.
/// * `n_found_docs` – number of matches for the term, used as document frequency.
/// * `n_total` – corpus size.
/// * `doc_len` / `avgdl` – document length and average document length
///   (only used by BM25; a non-positive `avgdl` is replaced by `1.0`).
/// * `booster` – multiplicative weight applied to the final score.
///
/// The formula is selected by `config.score_kind`:
/// `Tf` = `tf * booster`, `TfIdf` = `tf * idf * booster`,
/// `Bm25` = `idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * doc_len / avgdl)) * booster`,
/// with `idf = ln(1 + (n_total - df + 0.5) / (df + 0.5))`.
fn fts_compute_score(
    tf: usize,
    n_found_docs: usize,
    n_total: usize,
    doc_len: u32,
    avgdl: f64,
    booster: f64,
    config: &FtsSearch,
) -> f64 {
    /// Smoothed inverse document frequency shared by TF-IDF and BM25.
    fn smoothed_idf(n_total: usize, doc_freq: usize) -> f64 {
        let doc_freq = doc_freq as f64;
        (1.0 + (n_total as f64 - doc_freq + 0.5) / (doc_freq + 0.5)).ln()
    }

    let term_freq = tf as f64;
    match config.score_kind {
        FtsScoreKind::Tf => term_freq * booster,
        FtsScoreKind::TfIdf => term_freq * smoothed_idf(n_total, n_found_docs) * booster,
        FtsScoreKind::Bm25 => {
            let idf = smoothed_idf(n_total, n_found_docs);
            // Guard against division by zero for an empty or unknown corpus.
            let effective_avgdl = if avgdl > 0.0 { avgdl } else { 1.0 };
            let length_norm = 1.0 - config.b + config.b * (doc_len as f64) / effective_avgdl;
            let denom = term_freq + config.k1 * length_norm;
            let saturated = if denom > 0.0 {
                term_freq * (config.k1 + 1.0) / denom
            } else {
                0.0
            };
            idf * saturated * booster
        }
    }
}
/// Runs the full-text query `q` and returns up to `config.k` matching rows
/// of the base relation, best score first.
///
/// The query is parsed and tokenized with `tokenizer`; an empty result of
/// that step yields an empty answer. The corpus size is looked up (through
/// `cache`) for TF-IDF and BM25, the average document length for BM25 only.
/// When `config.bind_score` is set the score is appended to each row before
/// the optional `filter_code` predicate is evaluated on it.
///
/// # Errors
/// Fails on parse errors, store errors, filter evaluation errors, or when a
/// key found in the index has no row in the base relation
/// ("corrupted index").
pub(crate) fn fts_search(
    &self,
    q: &str,
    config: &FtsSearch,
    filter_code: &Option<(Vec<Bytecode>, SourceSpan)>,
    tokenizer: &TextAnalyzer,
    stack: &mut Vec<DataValue>,
    cache: &mut FtsCache,
) -> Result<Vec<Tuple>> {
    let ast = parse_fts_query(q)?.tokenize(tokenizer);
    if ast.is_empty() {
        return Ok(Vec::new());
    }

    // Only fetch the statistics the selected scoring function actually uses.
    let (n, avgdl) = match config.score_kind {
        FtsScoreKind::Tf => (0, 0.0),
        FtsScoreKind::TfIdf => (cache.get_n_for_relation(&config.base_handle, self)?, 0.0),
        FtsScoreKind::Bm25 => {
            let corpus_size = cache.get_n_for_relation(&config.base_handle, self)?;
            let mean_len = cache.get_avgdl_for_index(&config.idx_handle, self)?;
            (corpus_size, mean_len)
        }
    };

    let scored = self.fts_search_impl(&ast, config, n, avgdl)?;
    let mut ranked: Vec<(Tuple, f64)> = scored.into_iter().collect();
    ranked.sort_by_key(|entry| Reverse(OrderedFloat(entry.1)));
    // Without a filter every candidate is kept, so the list can be cut
    // before any row is fetched.
    if config.filter.is_none() {
        ranked.truncate(config.k);
    }

    let attach_score = config.bind_score.is_some();
    let mut hits = Vec::with_capacity(config.k);
    for (doc_key, score) in ranked {
        let fetched = config.base_handle.get(self, &doc_key)?;
        let mut row = fetched.ok_or_else(|| miette!("corrupted index"))?;
        if attach_score {
            row.push(DataValue::from(score));
        }
        let accepted = match filter_code {
            Some((code, span)) => eval_bytecode_pred(code, &row, stack, *span)?,
            None => true,
        };
        if !accepted {
            continue;
        }
        hits.push(row);
        if hits.len() >= config.k {
            break;
        }
    }
    Ok(hits)
}
/// Indexes one row of `rel_handle` into the FTS index `idx_handle`.
///
/// `extractor` is evaluated on `tuple` to obtain the text: `Null` means
/// "nothing to index", a string is tokenized, anything else is an error.
/// For every distinct token one posting is written whose key is the token
/// followed by the row's key columns and whose value holds the start
/// offsets, end offsets, positions, and the row's total token count.
/// When at least one token was produced the doc-stats counter is increased
/// by that token count and by one document.
///
/// # Errors
/// Fails if the extractor fails or returns a non-string, non-null value, or
/// if encoding or any store operation fails.
pub(crate) fn put_fts_index_item(
    &mut self,
    tuple: &[DataValue],
    extractor: &[Bytecode],
    stack: &mut Vec<DataValue>,
    tokenizer: &TextAnalyzer,
    rel_handle: &RelationHandle,
    idx_handle: &RelationHandle,
) -> Result<()> {
    let text_to_index = match eval_bytecode(extractor, tuple, stack)? {
        DataValue::Str(s) => s,
        DataValue::Null => return Ok(()),
        other => {
            #[derive(Debug, Diagnostic, Error)]
            #[error("FTS index extractor must return a string, got {0}")]
            #[diagnostic(code(eval::fts::extractor::invalid_return_type))]
            struct FtsExtractError(String);
            bail!(FtsExtractError(other.to_string()))
        }
    };

    // token text -> (start offsets, end offsets, positions)
    let mut postings: HashMap<_, (Vec<_>, Vec<_>, Vec<_>), _> = FxHashMap::default();
    let mut token_count = 0i64;
    let mut tokens = tokenizer.token_stream(&text_to_index);
    while let Some(tok) = tokens.next() {
        let slot = postings
            .entry(SmartString::<LazyCompact>::from(&tok.text))
            .or_default();
        slot.0.push(DataValue::from(tok.offset_from as i64));
        slot.1.push(DataValue::from(tok.offset_to as i64));
        slot.2.push(DataValue::from(tok.position as i64));
        token_count += 1;
    }

    // Slot 0 is a placeholder that is overwritten with each token below.
    let n_keys = rel_handle.metadata.keys.len();
    let mut key = Vec::with_capacity(1 + n_keys);
    key.push(DataValue::Bot);
    key.extend_from_slice(&tuple[..n_keys]);

    if token_count > 0 {
        let (total, n) = self.ensure_fts_doc_stats(idx_handle)?;
        self.write_fts_doc_stats(idx_handle, total + token_count as u64, n + 1)?;
    }

    for (term, (from, to, position)) in postings {
        key[0] = DataValue::Str(term);
        let val = vec![
            DataValue::List(from),
            DataValue::List(to),
            DataValue::List(position),
            DataValue::from(token_count),
        ];
        let key_bytes = idx_handle.encode_key_for_store(&key, Default::default())?;
        let val_bytes = idx_handle.encode_val_only_for_store(&val, Default::default())?;
        self.store_tx.put(&key_bytes, &val_bytes)?;
    }
    Ok(())
}
/// Removes one row of `rel_handle` from the FTS index `idx_handle`.
///
/// `extractor` is evaluated on `tuple` and the resulting text is tokenized
/// exactly as on insertion; the posting of every distinct token for this
/// row is then deleted. Before deleting, one of the postings is probed: only
/// if it exists is the doc-stats counter decreased (saturating at zero)
/// by the row's token count and by one document.
///
/// # Errors
/// Fails if the extractor fails or returns a non-string, non-null value, or
/// if encoding or any store operation fails.
pub(crate) fn del_fts_index_item(
    &mut self,
    tuple: &[DataValue],
    extractor: &[Bytecode],
    stack: &mut Vec<DataValue>,
    tokenizer: &TextAnalyzer,
    rel_handle: &RelationHandle,
    idx_handle: &RelationHandle,
) -> Result<()> {
    let indexed_text = match eval_bytecode(extractor, tuple, stack)? {
        DataValue::Str(s) => s,
        DataValue::Null => return Ok(()),
        other => {
            #[derive(Debug, Diagnostic, Error)]
            #[error("FTS index extractor must return a string, got {0}")]
            #[diagnostic(code(eval::fts::extractor::invalid_return_type))]
            struct FtsExtractError(String);
            bail!(FtsExtractError(other.to_string()))
        }
    };

    let mut terms = FxHashSet::default();
    let mut token_count = 0i64;
    let mut tokens = tokenizer.token_stream(&indexed_text);
    while let Some(tok) = tokens.next() {
        terms.insert(SmartString::<LazyCompact>::from(&tok.text));
        token_count += 1;
    }

    // Slot 0 is a placeholder that is overwritten with each term below.
    let n_keys = rel_handle.metadata.keys.len();
    let mut key = Vec::with_capacity(1 + n_keys);
    key.push(DataValue::Bot);
    key.extend_from_slice(&tuple[..n_keys]);

    // Probe one posting so the counter is only touched for rows that are
    // actually present in the index.
    let probe_term = if token_count > 0 {
        terms.iter().next()
    } else {
        None
    };
    if let Some(term) = probe_term {
        let mut probe = key.clone();
        probe[0] = DataValue::Str(term.clone());
        let probe_bytes = idx_handle.encode_key_for_store(&probe, Default::default())?;
        if self.store_tx.exists(&probe_bytes, false)? {
            let (total, n) = self.ensure_fts_doc_stats(idx_handle)?;
            self.write_fts_doc_stats(
                idx_handle,
                total.saturating_sub(token_count as u64),
                n.saturating_sub(1),
            )?;
        }
    }

    for term in terms {
        key[0] = DataValue::Str(term);
        let key_bytes = idx_handle.encode_key_for_store(&key, Default::default())?;
        self.store_tx.del(&key_bytes)?;
    }
    Ok(())
}
}