use std::collections::HashSet;
use triblespace_core::query::{Binding, Constraint, Variable, VariableId, VariableSet};
use triblespace_core::inline::encodings::genid::GenId;
use triblespace_core::inline::encodings::hash::Handle;
use triblespace_core::inline::{RawInline, Inline};
use crate::bm25::BM25Index;
use crate::schemas::Embedding;
/// Untyped access to the posting list of a single term.
///
/// The method works on raw inline values and returns a boxed iterator, which
/// lets helpers such as `aggregate_above` be written once for every index
/// type that implements this trait.
pub trait BM25Queryable {
/// Returns an iterator of `(document, score)` pairs for `term`.
///
/// `term` is the raw value of the query term; each yielded pair carries the
/// raw value of a document together with an `f32` score for that term.
/// The iterator may borrow from `self` for the lifetime `'a`.
fn query_term_boxed<'a>(
&'a self,
term: &RawInline,
) -> Box<dyn Iterator<Item = (RawInline, f32)> + 'a>;
}
/// Raw-value adapter over [`BM25Index::query_term`].
impl<D, T> BM25Queryable for BM25Index<D, T>
where
    D: triblespace_core::inline::InlineEncoding,
    T: triblespace_core::inline::InlineEncoding,
{
    /// Wraps the raw term in the index's term encoding, runs the typed
    /// lookup, and strips the document encoding from every result.
    fn query_term_boxed<'a>(
        &'a self,
        term: &RawInline,
    ) -> Box<dyn Iterator<Item = (RawInline, f32)> + 'a> {
        // Re-attach the term encoding so the typed lookup can be reused.
        let typed_term = Inline::<T>::new(*term);
        let postings = self
            .query_term(&typed_term)
            .map(|(doc, score)| (doc.raw, score));
        Box::new(postings)
    }
}
#[cfg(feature = "succinct")]
/// Raw-value adapter over `SuccinctBM25Index::query_term`.
impl<D, T> BM25Queryable for crate::succinct::SuccinctBM25Index<D, T>
where
    D: triblespace_core::inline::InlineEncoding,
    T: triblespace_core::inline::InlineEncoding,
{
    /// Wraps the raw term in the index's term encoding, runs the typed
    /// lookup, and strips the document encoding from every result.
    fn query_term_boxed<'a>(
        &'a self,
        term: &RawInline,
    ) -> Box<dyn Iterator<Item = (RawInline, f32)> + 'a> {
        // Re-attach the term encoding so the typed lookup can be reused.
        let typed_term = Inline::<T>::new(*term);
        let postings = self
            .query_term(&typed_term)
            .map(|(doc, score)| (doc.raw, score));
        Box::new(postings)
    }
}
/// Query constraint that limits one variable to a precomputed list of raw
/// values.
///
/// `S` is the encoding of the constrained variable; it defaults to `GenId`.
pub struct BM25Filter<S = GenId>
where
S: triblespace_core::inline::InlineEncoding,
{
/// Variable this constraint applies to.
doc: Variable<S>,
/// Raw values the variable is allowed to take.
entries: Vec<RawInline>,
}
impl<S: triblespace_core::inline::InlineEncoding> BM25Filter<S> {
    /// Builds a filter for `doc` from any iterable of allowed raw values.
    ///
    /// The values are collected in iteration order and stored as-is; no
    /// deduplication or sorting is performed.
    pub fn from_entries<I>(doc: Variable<S>, entries: I) -> Self
    where
        I: IntoIterator<Item = RawInline>,
    {
        let entries: Vec<RawInline> = entries.into_iter().collect();
        Self { doc, entries }
    }
}
/// Sums the per-term scores of every document hit by `terms` and returns the
/// documents whose total is at least `score_floor`.
///
/// A term that appears more than once in `terms` contributes once per
/// occurrence. The order of the returned documents follows the iteration
/// order of the internal hash map and is therefore unspecified.
fn aggregate_above<I: BM25Queryable + ?Sized>(
    index: &I,
    terms: &[RawInline],
    score_floor: f32,
) -> Vec<RawInline> {
    use std::collections::HashMap;

    // Accumulate in term order, then posting order, so the floating-point
    // summation order is fixed for a given index and query.
    let mut totals: HashMap<RawInline, f32> = HashMap::new();
    terms
        .iter()
        .flat_map(|term| index.query_term_boxed(term))
        .for_each(|(doc, score)| *totals.entry(doc).or_insert(0.0) += score);

    let mut kept = Vec::new();
    for (doc, total) in totals {
        if total >= score_floor {
            kept.push(doc);
        }
    }
    kept
}
impl<D, T> BM25Index<D, T>
where
    D: triblespace_core::inline::InlineEncoding,
    T: triblespace_core::inline::InlineEncoding,
{
    /// Builds a [`BM25Filter`] that restricts `doc` to the documents whose
    /// summed score over `terms` is at least `score_floor`.
    ///
    /// The matching set is computed eagerly, when this method is called.
    pub fn matches(
        &self,
        doc: Variable<D>,
        terms: &[Inline<T>],
        score_floor: f32,
    ) -> BM25Filter<D> {
        let mut raw_terms: Vec<RawInline> = Vec::with_capacity(terms.len());
        raw_terms.extend(terms.iter().map(|term| term.raw));
        let survivors = aggregate_above(self, &raw_terms, score_floor);
        BM25Filter::from_entries(doc, survivors)
    }

    /// Returns the summed score of `doc` over `terms`.
    ///
    /// For each term only the first posting for `doc` is counted; terms
    /// whose posting list does not contain `doc` contribute nothing, so a
    /// document matching no term scores `0.0`.
    pub fn score(&self, doc: &Inline<D>, terms: &[Inline<T>]) -> f32 {
        let mut total = 0.0;
        for term in terms {
            let hit = self
                .query_term(term)
                .find(|(candidate, _)| candidate.raw == doc.raw);
            if let Some((_, contribution)) = hit {
                total += contribution;
            }
        }
        total
    }
}
/// Text-query conveniences for indexes whose terms are `WordHash` values.
impl<D> BM25Index<D, crate::tokens::WordHash>
where
    D: triblespace_core::inline::InlineEncoding,
{
    /// Same as `matches`, with the terms obtained by passing `text` through
    /// `crate::tokens::hash_tokens`.
    pub fn matches_text(
        &self,
        doc: Variable<D>,
        text: &str,
        score_floor: f32,
    ) -> BM25Filter<D> {
        let tokens = crate::tokens::hash_tokens(text);
        self.matches(doc, &tokens, score_floor)
    }

    /// Same as `score`, with the terms obtained by passing `text` through
    /// `crate::tokens::hash_tokens`.
    pub fn score_text(&self, doc: &Inline<D>, text: &str) -> f32 {
        let tokens = crate::tokens::hash_tokens(text);
        self.score(doc, &tokens)
    }
}
#[cfg(feature = "succinct")]
impl<D, T> crate::succinct::SuccinctBM25Index<D, T>
where
    D: triblespace_core::inline::InlineEncoding,
    T: triblespace_core::inline::InlineEncoding,
{
    /// Builds a [`BM25Filter`] that restricts `doc` to the documents whose
    /// summed score over `terms` is at least `score_floor`.
    ///
    /// The matching set is computed eagerly, when this method is called.
    pub fn matches(
        &self,
        doc: Variable<D>,
        terms: &[Inline<T>],
        score_floor: f32,
    ) -> BM25Filter<D> {
        let mut raw_terms: Vec<RawInline> = Vec::with_capacity(terms.len());
        raw_terms.extend(terms.iter().map(|term| term.raw));
        let survivors = aggregate_above(self, &raw_terms, score_floor);
        BM25Filter::from_entries(doc, survivors)
    }

    /// Returns the summed score of `doc` over `terms`.
    ///
    /// For each term only the first posting for `doc` is counted; terms
    /// whose posting list does not contain `doc` contribute nothing, so a
    /// document matching no term scores `0.0`.
    pub fn score(&self, doc: &Inline<D>, terms: &[Inline<T>]) -> f32 {
        let mut total = 0.0;
        for term in terms {
            let hit = self
                .query_term(term)
                .find(|(candidate, _)| candidate.raw == doc.raw);
            if let Some((_, contribution)) = hit {
                total += contribution;
            }
        }
        total
    }
}
#[cfg(feature = "succinct")]
/// Text-query conveniences for succinct indexes whose terms are `WordHash`
/// values.
impl<D> crate::succinct::SuccinctBM25Index<D, crate::tokens::WordHash>
where
    D: triblespace_core::inline::InlineEncoding,
{
    /// Same as `matches`, with the terms obtained by passing `text` through
    /// `crate::tokens::hash_tokens`.
    pub fn matches_text(
        &self,
        doc: Variable<D>,
        text: &str,
        score_floor: f32,
    ) -> BM25Filter<D> {
        let tokens = crate::tokens::hash_tokens(text);
        self.matches(doc, &tokens, score_floor)
    }

    /// Same as `score`, with the terms obtained by passing `text` through
    /// `crate::tokens::hash_tokens`.
    pub fn score_text(&self, doc: &Inline<D>, text: &str) -> f32 {
        let tokens = crate::tokens::hash_tokens(text);
        self.score(doc, &tokens)
    }
}
/// Constraint behaviour of [`BM25Filter`]: the filter only ever speaks about
/// its `doc` variable and answers every question from the stored entry list.
impl<'a, S: triblespace_core::inline::InlineEncoding + 'a> Constraint<'a> for BM25Filter<S> {
    /// The filter constrains exactly one variable: `doc`.
    fn variables(&self) -> VariableSet {
        VariableSet::new_singleton(self.doc.index)
    }

    /// Number of stored entries for `doc`; `None` for any other variable.
    fn estimate(&self, variable: VariableId, _binding: &Binding) -> Option<usize> {
        (variable == self.doc.index).then(|| self.entries.len())
    }

    /// Appends every stored entry to `proposals` when asked about `doc`.
    fn propose(&self, variable: VariableId, _binding: &Binding, proposals: &mut Vec<RawInline>) {
        if variable == self.doc.index {
            proposals.extend(self.entries.iter().copied());
        }
    }

    /// Drops every proposal that is not a stored entry when asked about
    /// `doc`; leaves `proposals` untouched for other variables.
    fn confirm(&self, variable: VariableId, _binding: &Binding, proposals: &mut Vec<RawInline>) {
        if variable == self.doc.index {
            // Hash the entries once so each proposal is a constant-time lookup.
            let allowed: HashSet<RawInline> = self.entries.iter().copied().collect();
            proposals.retain(|candidate| allowed.contains(candidate));
        }
    }

    /// `true` while `doc` is unbound, otherwise whether the bound value is
    /// one of the stored entries.
    fn satisfied(&self, binding: &Binding) -> bool {
        binding
            .get(self.doc.index)
            .map_or(true, |bound| self.entries.contains(bound))
    }
}
/// Lookup operations an embedding index must provide to back the `Similar`
/// constraint.
pub trait SimilaritySearch {
/// Returns the embedding handles selected for `from` under `score_floor`.
///
/// `Similar` uses the result both as the set of proposals and as the
/// allow-list when confirming proposals. How the floor is compared against
/// the similarity is up to the implementor.
fn neighbours_above(
&self,
from: Inline<Handle<Embedding>>,
score_floor: f32,
) -> Vec<Inline<Handle<Embedding>>>;
/// Returns the similarity between `a` and `b`, or `None` when it cannot be
/// produced.
///
/// `Similar::satisfied` treats `None` as "not similar".
// NOTE(review): presumably `None` means one of the embeddings could not be
// loaded — confirm against the implementors.
fn cosine_between(
&self,
a: Inline<Handle<Embedding>>,
b: Inline<Handle<Embedding>>,
) -> Option<f32>;
}
/// Query constraint relating two embedding-handle variables through a
/// similarity index.
///
/// Once one side is bound, the other side is proposed from
/// `SimilaritySearch::neighbours_above`; once both are bound, the pair is
/// checked with `SimilaritySearch::cosine_between` against `score_floor`.
pub struct Similar<'a, I: SimilaritySearch + ?Sized> {
/// Index used for neighbour lookups and pairwise similarity.
index: &'a I,
/// First variable of the pair.
a: Variable<Handle<Embedding>>,
/// Second variable of the pair.
b: Variable<Handle<Embedding>>,
/// Minimum similarity a bound pair must reach (`>=`) to be accepted.
score_floor: f32,
}
impl<'a, I: SimilaritySearch + ?Sized> Similar<'a, I> {
pub fn new(
index: &'a I,
a: Variable<Handle<Embedding>>,
b: Variable<Handle<Embedding>>,
score_floor: f32,
) -> Self {
Self {
index,
a,
b,
score_floor,
}
}
}
impl<'a, I: SimilaritySearch + ?Sized> Similar<'a, I> {
    /// Returns the variable on the opposite side of the pair, or `None` when
    /// `variable` is neither `a` nor `b`.
    fn counterpart(&self, variable: VariableId) -> Option<VariableId> {
        if variable == self.a.index {
            Some(self.b.index)
        } else if variable == self.b.index {
            Some(self.a.index)
        } else {
            None
        }
    }
}

/// Constraint behaviour of [`Similar`]: each side is driven by the value
/// bound to the other side.
impl<'a, I: SimilaritySearch + ?Sized + 'a> Constraint<'a> for Similar<'a, I> {
    /// The constraint touches both `a` and `b`.
    fn variables(&self) -> VariableSet {
        let lhs = VariableSet::new_singleton(self.a.index);
        let rhs = VariableSet::new_singleton(self.b.index);
        lhs.union(rhs)
    }

    /// `None` for foreign variables. For `a` or `b`: the neighbour count of
    /// the value bound to the other side, or `usize::MAX` while the other
    /// side is still unbound.
    fn estimate(&self, variable: VariableId, binding: &Binding) -> Option<usize> {
        let other = self.counterpart(variable)?;
        let count = match binding.get(other) {
            Some(anchor) => self
                .index
                .neighbours_above(Inline::new(*anchor), self.score_floor)
                .len(),
            // Nothing to search from yet: report the largest possible size.
            None => usize::MAX,
        };
        Some(count)
    }

    /// Appends the neighbours of the value bound to the other side. Does
    /// nothing for foreign variables or while the other side is unbound.
    fn propose(&self, variable: VariableId, binding: &Binding, proposals: &mut Vec<RawInline>) {
        let anchor = self
            .counterpart(variable)
            .and_then(|other| binding.get(other));
        if let Some(anchor) = anchor {
            let neighbours = self
                .index
                .neighbours_above(Inline::new(*anchor), self.score_floor);
            proposals.extend(neighbours.into_iter().map(|handle| handle.raw));
        }
    }

    /// Keeps only proposals that are neighbours of the value bound to the
    /// other side. Does nothing for foreign variables or while the other
    /// side is unbound.
    fn confirm(&self, variable: VariableId, binding: &Binding, proposals: &mut Vec<RawInline>) {
        let anchor = self
            .counterpart(variable)
            .and_then(|other| binding.get(other));
        if let Some(anchor) = anchor {
            let mut permitted: HashSet<RawInline> = HashSet::new();
            for handle in self
                .index
                .neighbours_above(Inline::new(*anchor), self.score_floor)
            {
                permitted.insert(handle.raw);
            }
            proposals.retain(|candidate| permitted.contains(candidate));
        }
    }

    /// `true` until both sides are bound; afterwards whether the index
    /// reports a similarity of at least `score_floor`. A missing similarity
    /// counts as unsatisfied.
    fn satisfied(&self, binding: &Binding) -> bool {
        let (Some(lhs), Some(rhs)) = (binding.get(self.a.index), binding.get(self.b.index)) else {
            return true;
        };
        self.index
            .cosine_between(Inline::new(*lhs), Inline::new(*rhs))
            .map_or(false, |similarity| similarity >= self.score_floor)
    }
}
/// Query constraint that limits one embedding-handle variable to a fixed
/// list of candidate raw values.
pub struct SimilarTo {
/// Variable this constraint applies to.
var: Variable<Handle<Embedding>>,
/// Raw values the variable is allowed to take.
candidates: Vec<RawInline>,
}
impl SimilarTo {
pub fn from_candidates(
var: Variable<Handle<Embedding>>,
candidates: Vec<RawInline>,
) -> Self {
Self { var, candidates }
}
}
/// Constraint behaviour of [`SimilarTo`]: every answer comes from the stored
/// candidate list and concerns only `var`.
impl<'a> Constraint<'a> for SimilarTo {
    /// The constraint touches exactly one variable: `var`.
    fn variables(&self) -> VariableSet {
        VariableSet::new_singleton(self.var.index)
    }

    /// Number of stored candidates for `var`; `None` for any other variable.
    fn estimate(&self, variable: VariableId, _binding: &Binding) -> Option<usize> {
        (variable == self.var.index).then(|| self.candidates.len())
    }

    /// Appends every stored candidate to `proposals` when asked about `var`.
    fn propose(&self, variable: VariableId, _binding: &Binding, proposals: &mut Vec<RawInline>) {
        if variable == self.var.index {
            proposals.extend(self.candidates.iter().copied());
        }
    }

    /// Drops every proposal that is not a stored candidate when asked about
    /// `var`; leaves `proposals` untouched for other variables.
    fn confirm(&self, variable: VariableId, _binding: &Binding, proposals: &mut Vec<RawInline>) {
        if variable == self.var.index {
            // Hash the candidates once so each proposal is a constant-time lookup.
            let permitted: HashSet<RawInline> = self.candidates.iter().copied().collect();
            proposals.retain(|candidate| permitted.contains(candidate));
        }
    }

    /// `true` while `var` is unbound, otherwise whether the bound value is
    /// one of the stored candidates.
    fn satisfied(&self, binding: &Binding) -> bool {
        binding
            .get(self.var.index)
            .map_or(true, |bound| self.candidates.contains(bound))
    }
}
// Unit tests for the constraints in this file: `BM25Filter` built through
// `matches`/`matches_text`, the `score`/`score_text` helpers, and the
// `Similar` constraint obtained from attached flat and HNSW indexes.
#[cfg(test)]
mod tests {
use super::*;
use crate::bm25::BM25Builder;
use crate::tokens::hash_tokens;
use triblespace_core::blob::MemoryBlobStore;
use triblespace_core::id::Id;
use triblespace_core::repo::BlobStore;
use triblespace_core::inline::{IntoInline, InlineEncoding};
// Builds an `Id` whose 16 bytes all equal `byte`.
// NOTE(review): the `unwrap` presumably only fails for the all-zero id — confirm.
fn id(byte: u8) -> Id {
Id::new([byte; 16]).unwrap()
}
// Reads a raw value as a `GenId` and converts it to an `Id`; `None` when the
// conversion fails.
fn raw_value_to_id(raw: &RawInline) -> Option<Id> {
Inline::<GenId>::new(*raw).try_from_inline::<Id>().ok()
}
// Encodes `id` as a `GenId` inline value and returns its raw form.
fn id_to_raw_value(id: Id) -> RawInline {
GenId::inline_from(id).raw
}
// Three-document corpus shared by most BM25 tests:
// "fox" and "quick" occur in documents 1 and 3, "brown" in 1 and 2.
fn sample_index() -> BM25Index {
let mut b: BM25Builder = BM25Builder::new();
b.insert(id(1), hash_tokens("the quick brown fox"));
b.insert(id(2), hash_tokens("the lazy brown dog"));
b.insert(id(3), hash_tokens("quick silver fox jumps"));
b.build_naive()
}
// `variables()` must report the `doc` variable and nothing else.
#[test]
fn matches_filter_variables_is_singleton_of_doc() {
let idx = sample_index();
let mut ctx = triblespace_core::query::VariableContext::new();
let doc: Variable<GenId> = ctx.next_variable();
let terms = hash_tokens("fox");
let c = idx.matches(doc, &terms, 0.0);
let vars = c.variables();
assert!(vars.is_set(doc.index));
// Count set bits among the first 32 variable indices only.
let mut found = 0;
for i in 0..32 {
if vars.is_set(i) {
found += 1;
}
}
assert_eq!(found, 1);
}
// `estimate` is the number of matching documents for `doc` and `None` for a
// variable the filter does not constrain.
#[test]
fn matches_filter_estimate_is_match_count() {
let idx = sample_index();
let mut ctx = triblespace_core::query::VariableContext::new();
let doc: Variable<GenId> = ctx.next_variable();
let terms = hash_tokens("fox");
let c = idx.matches(doc, &terms, 0.0);
let binding = Binding::default();
// "fox" occurs in documents 1 and 3.
assert_eq!(c.estimate(doc.index, &binding), Some(2));
let unk_binding = Binding::default();
// 255 is a variable index the filter was not built for.
assert_eq!(c.estimate(255, &unk_binding), None);
}
// `propose` yields exactly the documents containing the query term.
#[test]
fn matches_filter_proposes_matching_docs() {
let idx = sample_index();
let mut ctx = triblespace_core::query::VariableContext::new();
let doc: Variable<GenId> = ctx.next_variable();
let terms = hash_tokens("fox");
let c = idx.matches(doc, &terms, 0.0);
let binding = Binding::default();
let mut props: Vec<RawInline> = Vec::new();
c.propose(doc.index, &binding, &mut props);
assert_eq!(props.len(), 2);
let ids: HashSet<Id> = props
.iter()
.map(|r| raw_value_to_id(r).expect("valid GenId value"))
.collect();
assert!(ids.contains(&id(1)));
assert!(ids.contains(&id(3)));
}
// `confirm` removes proposals that are not among the matching documents.
#[test]
fn matches_filter_confirm_filters_non_matching_docs() {
let idx = sample_index();
let mut ctx = triblespace_core::query::VariableContext::new();
let doc: Variable<GenId> = ctx.next_variable();
let terms = hash_tokens("fox");
let c = idx.matches(doc, &terms, 0.0);
let binding = Binding::default();
// Start from all three documents; document 2 does not contain "fox".
let mut props: Vec<RawInline> = vec![
id_to_raw_value(id(1)),
id_to_raw_value(id(2)),
id_to_raw_value(id(3)),
];
c.confirm(doc.index, &binding, &mut props);
let ids: HashSet<Id> = props.iter().map(|r| raw_value_to_id(r).unwrap()).collect();
assert_eq!(ids.len(), 2);
assert!(ids.contains(&id(1)));
assert!(!ids.contains(&id(2)));
assert!(ids.contains(&id(3)));
}
// `satisfied` accepts an unbound `doc`, a bound matching document, and
// rejects a bound non-matching document.
#[test]
fn matches_filter_satisfied_checks_bound_doc() {
let idx = sample_index();
let mut ctx = triblespace_core::query::VariableContext::new();
let doc: Variable<GenId> = ctx.next_variable();
let terms = hash_tokens("fox");
let c = idx.matches(doc, &terms, 0.0);
let empty = Binding::default();
assert!(c.satisfied(&empty));
let mut bound = Binding::default();
bound.set(doc.index, &id_to_raw_value(id(1)));
assert!(c.satisfied(&bound));
let mut unmatching = Binding::default();
unmatching.set(doc.index, &id_to_raw_value(id(2)));
assert!(!c.satisfied(&unmatching));
}
// A multi-term query proposes documents hit by any of the terms and leaves
// out documents hit by none.
#[test]
fn matches_multi_term_aggregates_across_terms() {
let idx = sample_index();
let mut ctx = triblespace_core::query::VariableContext::new();
let doc: Variable<GenId> = ctx.next_variable();
let terms = hash_tokens("quick fox");
let c = idx.matches(doc, &terms, 0.0);
let mut props = Vec::new();
c.propose(doc.index, &Binding::default(), &mut props);
let ids: HashSet<Id> = props
.iter()
.map(|r| raw_value_to_id(r).expect("genid"))
.collect();
assert!(ids.contains(&id(1)));
assert!(ids.contains(&id(3)));
assert!(!ids.contains(&id(2)));
}
// `matches_text` must propose the same document set as `matches` called with
// `hash_tokens` of the same text.
#[test]
fn matches_text_matches_explicit_tokens() {
let idx = sample_index();
let mut ctx = triblespace_core::query::VariableContext::new();
let doc_a: Variable<GenId> = ctx.next_variable();
let doc_b: Variable<GenId> = ctx.next_variable();
let explicit = idx.matches(doc_a, &hash_tokens("quick fox"), 0.0);
let sugar = idx.matches_text(doc_b, "quick fox", 0.0);
let mut props_a = Vec::new();
let mut props_b = Vec::new();
explicit.propose(doc_a.index, &Binding::default(), &mut props_a);
sugar.propose(doc_b.index, &Binding::default(), &mut props_b);
// Compare as sets: proposal order is not part of the contract checked here.
let set_a: HashSet<Id> = props_a
.iter()
.map(|r| raw_value_to_id(r).expect("genid"))
.collect();
let set_b: HashSet<Id> = props_b
.iter()
.map(|r| raw_value_to_id(r).expect("genid"))
.collect();
assert_eq!(
set_a, set_b,
"matches_text yields the same doc set as matches(hash_tokens(...))",
);
}
// `score_text` must equal `score` called with `hash_tokens` of the same text.
#[test]
fn score_text_matches_explicit_tokens() {
let idx = sample_index();
let s_explicit = idx.score(&id(1).to_inline(), &hash_tokens("quick fox"));
let s_sugar = idx.score_text(&id(1).to_inline(), "quick fox");
assert_eq!(s_explicit, s_sugar);
}
// A score floor between the full-match and partial-match scores keeps the
// former and drops the latter.
#[test]
fn matches_score_floor_drops_low_scoring_docs() {
let mut b: BM25Builder = BM25Builder::new();
b.insert(id(1), hash_tokens("fox quick brown jumps"));
b.insert(id(2), hash_tokens("only fox here, nothing else"));
b.insert(id(3), hash_tokens("unrelated"));
let idx = b.build_naive();
let terms = hash_tokens("fox quick brown jumps");
let s1 = idx.score(&id(1).to_inline(), &terms);
let s2 = idx.score(&id(2).to_inline(), &terms);
// Guard the fixture itself before relying on the midpoint below.
assert!(s1 > s2, "fixture: full-match should beat partial");
let mut ctx = triblespace_core::query::VariableContext::new();
let doc: Variable<GenId> = ctx.next_variable();
let c_low = idx.matches(doc, &terms, 0.0);
// The midpoint lies strictly between the two scores.
let c_mid = idx.matches(doc, &terms, (s1 + s2) / 2.0);
let mut low_props = Vec::new();
c_low.propose(doc.index, &Binding::default(), &mut low_props);
let low_ids: HashSet<Id> =
low_props.iter().map(|r| raw_value_to_id(r).unwrap()).collect();
assert!(low_ids.contains(&id(1)));
assert!(low_ids.contains(&id(2)));
let mut mid_props = Vec::new();
c_mid.propose(doc.index, &Binding::default(), &mut mid_props);
let mid_ids: HashSet<Id> =
mid_props.iter().map(|r| raw_value_to_id(r).unwrap()).collect();
assert!(mid_ids.contains(&id(1)));
assert!(!mid_ids.contains(&id(2)));
}
// `score` must agree with a sum recomputed by hand from the posting lists,
// and be zero for a document that matches no query term.
#[test]
fn score_helper_matches_aggregated_sum() {
let idx = sample_index();
let terms = hash_tokens("quick fox");
for byte in [1u8, 3] {
let doc_value: Inline<GenId> = id(byte).to_inline();
let helper_score = idx.score(&doc_value, &terms);
let target = GenId::inline_from(id(byte)).raw;
// Reference computation: first posting per term for the target document.
let mut expected = 0.0_f32;
for t in &terms {
for (d, s) in idx.query_term(t) {
if d.raw == target {
expected += s;
break;
}
}
}
assert!(
(helper_score - expected).abs() < 1e-6,
"score helper drifted from posting-list sum for doc {byte}"
);
}
// Document 2 contains neither "quick" nor "fox".
let doc2_value: Inline<GenId> = id(2).to_inline();
assert_eq!(idx.score(&doc2_value, &terms), 0.0);
}
// An empty term list produces a filter with no entries.
#[test]
fn matches_empty_query_yields_no_rows() {
let idx = sample_index();
let mut ctx = triblespace_core::query::VariableContext::new();
let doc: Variable<GenId> = ctx.next_variable();
let terms: Vec<triblespace_core::inline::Inline<crate::tokens::WordHash>> = Vec::new();
let c = idx.matches(doc, &terms, 0.0);
assert_eq!(c.estimate(doc.index, &Binding::default()), Some(0));
let mut props = Vec::new();
c.propose(doc.index, &Binding::default(), &mut props);
assert!(props.is_empty());
}
// Terms that occur in no document produce a filter with no entries.
#[test]
fn matches_no_matching_docs_yields_no_rows() {
let idx = sample_index();
let mut ctx = triblespace_core::query::VariableContext::new();
let doc: Variable<GenId> = ctx.next_variable();
let terms = hash_tokens("aardvark zeppelin");
let c = idx.matches(doc, &terms, 0.0);
assert_eq!(c.estimate(doc.index, &Binding::default()), Some(0));
let mut props = Vec::new();
c.propose(doc.index, &Binding::default(), &mut props);
assert!(props.is_empty());
}
// Similarity fixture: three 3-dimensional vectors stored in an in-memory blob
// store and indexed both by a flat index and by an HNSW index (seed 42).
// Vectors 0 and 2 point in nearly the same direction; vector 1 is orthogonal
// to vector 0.
fn sample_sim() -> (
crate::hnsw::FlatIndex,
crate::hnsw::HNSWIndex,
MemoryBlobStore,
[Inline<Handle<Embedding>>; 3],
) {
use crate::hnsw::{FlatBuilder, HNSWBuilder};
let mut store = MemoryBlobStore::new();
let vecs = [
vec![1.0f32, 0.0, 0.0],
vec![0.0, 1.0, 0.0],
vec![0.9, 0.1, 0.0],
];
// Placeholder handles, overwritten below with the stored embeddings' handles.
let mut handles: [Inline<Handle<Embedding>>; 3] =
[Inline::new([0u8; 32]); 3];
for (i, v) in vecs.iter().enumerate() {
handles[i] =
crate::schemas::put_embedding::<_>(&mut store, v.clone()).unwrap();
}
let mut flat = FlatBuilder::new(3);
for h in handles.iter() {
flat.insert(*h);
}
let mut hnsw = HNSWBuilder::new(3).with_seed(42);
for (i, v) in vecs.iter().enumerate() {
hnsw.insert(handles[i], v.clone()).unwrap();
}
(flat.build(), hnsw.build_naive(), store, handles)
}
// Flat index: with `a` bound to vector 0, proposals for `b` include vectors 0
// and 2 and exclude vector 1 at a floor of 0.8.
#[test]
fn flat_similar_proposes_candidates_above_floor() {
let (flat, _hnsw, mut store, handles) = sample_sim();
let reader = store.reader().unwrap();
let view = flat.attach(&reader);
let mut ctx = triblespace_core::query::VariableContext::new();
let a: Variable<Handle<Embedding>> = ctx.next_variable();
let b: Variable<Handle<Embedding>> = ctx.next_variable();
let c = view.similar(a, b, 0.8);
let mut binding = Binding::default();
binding.set(a.index, &handles[0].raw);
let mut props = Vec::new();
c.propose(b.index, &binding, &mut props);
let got: HashSet<RawInline> = props.iter().copied().collect();
assert!(got.contains(&handles[0].raw));
assert!(got.contains(&handles[2].raw));
assert!(!got.contains(&handles[1].raw));
}
// Flat index: the constraint also works in the other direction, proposing
// for `a` when only `b` is bound.
#[test]
fn flat_similar_symmetric_bind_on_b() {
let (flat, _hnsw, mut store, handles) = sample_sim();
let reader = store.reader().unwrap();
let view = flat.attach(&reader);
let mut ctx = triblespace_core::query::VariableContext::new();
let a: Variable<Handle<Embedding>> = ctx.next_variable();
let b: Variable<Handle<Embedding>> = ctx.next_variable();
let c = view.similar(a, b, 0.8);
let mut binding = Binding::default();
binding.set(b.index, &handles[2].raw);
let mut props = Vec::new();
c.propose(a.index, &binding, &mut props);
let got: HashSet<RawInline> = props.iter().copied().collect();
assert!(got.contains(&handles[0].raw));
assert!(got.contains(&handles[2].raw));
}
// Flat index: with both sides bound, `satisfied` accepts the similar pair
// (0, 2) and rejects the orthogonal pair (0, 1).
#[test]
fn flat_similar_satisfied_both_bound() {
let (flat, _hnsw, mut store, handles) = sample_sim();
let reader = store.reader().unwrap();
let view = flat.attach(&reader);
let mut ctx = triblespace_core::query::VariableContext::new();
let a: Variable<Handle<Embedding>> = ctx.next_variable();
let b: Variable<Handle<Embedding>> = ctx.next_variable();
let c = view.similar(a, b, 0.8);
let mut good = Binding::default();
good.set(a.index, &handles[0].raw);
good.set(b.index, &handles[2].raw);
assert!(c.satisfied(&good));
let mut bad = Binding::default();
bad.set(a.index, &handles[0].raw);
bad.set(b.index, &handles[1].raw);
assert!(!c.satisfied(&bad));
}
// HNSW index: same expectations as the flat-index proposal test.
#[test]
fn hnsw_similar_proposes_candidates_above_floor() {
let (_flat, hnsw, mut store, handles) = sample_sim();
let reader = store.reader().unwrap();
let view = hnsw.attach(&reader);
let mut ctx = triblespace_core::query::VariableContext::new();
let a: Variable<Handle<Embedding>> = ctx.next_variable();
let b: Variable<Handle<Embedding>> = ctx.next_variable();
let c = view.similar(a, b, 0.8);
let mut binding = Binding::default();
binding.set(a.index, &handles[0].raw);
let mut props = Vec::new();
c.propose(b.index, &binding, &mut props);
let got: HashSet<RawInline> = props.iter().copied().collect();
assert!(got.contains(&handles[0].raw));
assert!(got.contains(&handles[2].raw));
assert!(!got.contains(&handles[1].raw));
}
// With nothing bound, `estimate` reports `usize::MAX` for both sides and
// `None` for a variable the constraint does not mention.
#[test]
fn similar_estimate_saturates_when_other_unbound() {
let (flat, _hnsw, mut store, _handles) = sample_sim();
let reader = store.reader().unwrap();
let view = flat.attach(&reader);
let mut ctx = triblespace_core::query::VariableContext::new();
let a: Variable<Handle<Embedding>> = ctx.next_variable();
let b: Variable<Handle<Embedding>> = ctx.next_variable();
let unrelated: Variable<GenId> = ctx.next_variable();
let c = view.similar(a, b, 0.8);
assert_eq!(c.estimate(a.index, &Binding::default()), Some(usize::MAX));
assert_eq!(c.estimate(b.index, &Binding::default()), Some(usize::MAX));
assert_eq!(c.estimate(unrelated.index, &Binding::default()), None);
}
}