use crate::StorageError;
use crate::blob_store::types::BlobId;
use crate::compat::HashMap;
use crate::ivfpq::{ReadOnlyIvfPqIndex, SearchParams};
use crate::storage_traits::StorageRead;
use crate::transactions::ReadTransaction;
use alloc::collections::BinaryHeap;
use alloc::string::ToString;
use alloc::vec::Vec;
use super::provider::BlobQueryProvider;
use super::scoring::{causal_bfs, normalize_causal, normalize_semantic, normalize_temporal};
use super::types::{
CandidateEntry, HeapEntry, ScoredBlob, SignalScores, SignalWeights, heap_to_sorted,
};
/// Builder for a ranked blob query that fuses up to three signals — semantic
/// (vector similarity), temporal and causal — plus optional namespace/tag filters.
///
/// `P` supplies blob metadata and secondary-index lookups; `R` is the storage
/// reader handed to the IVF-PQ index when the semantic signal is used. Both
/// default to `ReadTransaction`.
pub struct CompositeQuery<'a, P = ReadTransaction, R = ReadTransaction>
where
    P: BlobQueryProvider,
    R: StorageRead,
{
    /// Source of metadata, time-range, namespace, tag and causal lookups.
    provider: &'a P,
    /// Reader passed to the vector index search; needed for the semantic signal.
    storage_reader: Option<&'a R>,
    /// IVF-PQ index searched by the semantic signal.
    vector_index: Option<&'a ReadOnlyIvfPqIndex>,
    /// Query embedding for the semantic signal.
    query_vector: Option<&'a [f32]>,
    /// Explicit search parameters; when `None`, defaults are derived at execution.
    search_params: Option<SearchParams>,
    /// Neighbour count requested by the default search parameters.
    semantic_candidates: Option<usize>,
    /// `(start_ns, end_ns)` window used to gather temporal candidates.
    time_range: Option<(u64, u64)>,
    /// Root blob for the causal breadth-first search.
    causal_root: Option<BlobId>,
    /// Depth limit for the causal breadth-first search.
    causal_max_hops: usize,
    /// Optional namespace every result must belong to.
    namespace: Option<&'a str>,
    /// Tags every result must carry (all of them).
    tags: Vec<&'a str>,
    /// Per-signal weights, normalised at execution time.
    weights: SignalWeights,
    /// Maximum number of results returned.
    top_k: usize,
}
impl<'a> CompositeQuery<'a, ReadTransaction, ReadTransaction> {
pub fn new(txn: &'a ReadTransaction) -> Self {
Self {
provider: txn,
storage_reader: Some(txn),
vector_index: None,
query_vector: None,
search_params: None,
semantic_candidates: None,
time_range: None,
causal_root: None,
causal_max_hops: 32,
namespace: None,
tags: Vec::new(),
weights: SignalWeights::default(),
top_k: 10,
}
}
}
impl<'a, P: BlobQueryProvider, R: StorageRead> CompositeQuery<'a, P, R> {
    /// Creates a query over an arbitrary [`BlobQueryProvider`].
    ///
    /// No storage reader is attached, so enabling the semantic signal also requires
    /// [`Self::with_storage_reader`]. Defaults:
    /// - no vector index, query vector, time range or causal root;
    /// - `SignalWeights::default()` weights;
    /// - `causal_max_hops = 32` and `top_k = 10`;
    /// - no namespace or tag filters.
    pub fn with_provider(provider: &'a P) -> Self {
        Self {
            provider,
            storage_reader: None,
            vector_index: None,
            query_vector: None,
            search_params: None,
            semantic_candidates: None,
            time_range: None,
            causal_root: None,
            causal_max_hops: 32,
            namespace: None,
            tags: Vec::new(),
            weights: SignalWeights::default(),
            top_k: 10,
        }
    }

    /// Attaches the [`StorageRead`] handed to the IVF-PQ index search.
    ///
    /// Required whenever the semantic signal has a positive weight.
    #[must_use]
    pub fn with_storage_reader(mut self, reader: &'a R) -> Self {
        self.storage_reader = Some(reader);
        self
    }

    /// Enables the semantic signal: nearest neighbours of `query` in `index`.
    ///
    /// `weight` is clamped with `f32::max(0.0)`, so negative and NaN weights become `0.0`.
    #[must_use]
    pub fn semantic(
        mut self,
        index: &'a ReadOnlyIvfPqIndex,
        query: &'a [f32],
        weight: f32,
    ) -> Self {
        self.vector_index = Some(index);
        self.query_vector = Some(query);
        self.weights.semantic = weight.max(0.0);
        self
    }

    /// Overrides the IVF-PQ search parameters used by the semantic signal.
    ///
    /// When set, [`Self::semantic_candidates`] has no effect.
    #[must_use]
    pub fn search_params(mut self, params: SearchParams) -> Self {
        self.search_params = Some(params);
        self
    }

    /// Sets how many neighbours the default search parameters request.
    ///
    /// Defaults to `5 * top_k`. Ignored when explicit [`Self::search_params`] are given.
    #[must_use]
    pub fn semantic_candidates(mut self, n: usize) -> Self {
        self.semantic_candidates = Some(n);
        self
    }

    /// Sets the temporal signal weight (negative and NaN become `0.0`).
    #[must_use]
    pub fn temporal(mut self, weight: f32) -> Self {
        self.weights.temporal = weight.max(0.0);
        self
    }

    /// Sets the `(start_ns, end_ns)` window passed to `blobs_in_time_range`.
    ///
    /// The window is only consulted when the temporal weight is positive. It *adds*
    /// candidates; it does not filter candidates found by other signals. It is
    /// required when temporal is the only active signal.
    #[must_use]
    pub fn time_range(mut self, start_ns: u64, end_ns: u64) -> Self {
        self.time_range = Some((start_ns, end_ns));
        self
    }

    /// Enables the causal signal: blobs reached from `root` by the causal BFS,
    /// scored from their hop distance. Negative and NaN weights become `0.0`.
    #[must_use]
    pub fn causal(mut self, root: BlobId, weight: f32) -> Self {
        self.causal_root = Some(root);
        self.weights.causal = weight.max(0.0);
        self
    }

    /// Sets the depth limit passed to the causal BFS (default 32).
    #[must_use]
    pub fn causal_max_hops(mut self, max_hops: usize) -> Self {
        self.causal_max_hops = max_hops;
        self
    }

    /// Keeps only results returned by `blobs_in_namespace(ns)`.
    #[must_use]
    pub fn namespace(mut self, ns: &'a str) -> Self {
        self.namespace = Some(ns);
        self
    }

    /// Adds a required tag. Results must carry *every* tag added (AND semantics).
    #[must_use]
    pub fn tag(mut self, tag: &'a str) -> Self {
        self.tags.push(tag);
        self
    }

    /// Sets the maximum number of results. It must be at least 1 (checked by `execute`).
    ///
    /// Very large values (e.g. `usize::MAX` meaning "everything") are fine.
    #[must_use]
    pub fn top_k(mut self, k: usize) -> Self {
        self.top_k = k;
        self
    }

    /// Returns the candidate for `id`, inserting an empty entry on first sight.
    fn candidate_entry(
        candidates: &mut HashMap<BlobId, CandidateEntry>,
        id: BlobId,
    ) -> &mut CandidateEntry {
        candidates.entry(id).or_insert_with(|| CandidateEntry {
            blob_id: id,
            meta: None,
            raw_distance: None,
            wall_clock_ns: None,
            causal_hops: None,
        })
    }

    /// Runs the query and returns at most `top_k` scored blobs.
    ///
    /// Pipeline:
    /// 1. Gather candidates from every active signal: semantic neighbours, blobs in
    ///    the time range, and blobs reachable from the causal root.
    /// 2. Load missing metadata and drop candidates that have none.
    /// 3. Apply the namespace and tag filters.
    /// 4. Normalise each active signal over the *surviving* candidates. Combine the
    ///    signals with the normalised weights; a candidate missing a signal scores 0
    ///    for it.
    /// 5. Push into a heap that pops whenever it exceeds `top_k` (ordering is defined
    ///    by `HeapEntry`), then sort via `heap_to_sorted`.
    ///
    /// # Errors
    /// - An invalid-configuration error when `validate` rejects the settings.
    /// - An internal error if a semantic input is unexpectedly missing.
    /// - Any error from the provider, the causal BFS or the vector index search.
    pub fn execute(self) -> crate::Result<Vec<ScoredBlob>> {
        self.validate()?;
        let (w_sem, w_tmp, w_cau) = self.weights.normalized_f64();
        let sem_active = w_sem > 0.0;
        let tmp_active = w_tmp > 0.0;
        let cau_active = w_cau > 0.0;
        let mut candidates: HashMap<BlobId, CandidateEntry> = HashMap::new();

        // 1a. Semantic candidates: nearest neighbours from the IVF-PQ index.
        if sem_active {
            let query = self.query_vector.ok_or_else(|| {
                StorageError::Internal(
                    "semantic search enabled but no query vector provided".to_string(),
                )
            })?;
            // Saturating, and lazily evaluated: a huge `top_k` must not overflow here.
            let k = self
                .semantic_candidates
                .unwrap_or_else(|| self.top_k.max(1).saturating_mul(5));
            let reader = self.storage_reader.ok_or_else(|| {
                StorageError::Internal(
                    "semantic search requires a StorageRead (use with_storage_reader)".to_string(),
                )
            })?;
            let index = self.vector_index.ok_or_else(|| {
                StorageError::Internal(
                    "semantic search enabled but no vector index provided".to_string(),
                )
            })?;
            let params = self.search_params.unwrap_or_else(|| SearchParams {
                nprobe: 16,
                candidates: k.saturating_mul(2),
                k,
                rerank: true,
                diversity: crate::probe_select::DiversityConfig { lambda: 0.0 },
                filter: None,
            });
            let results = index.search(reader, query, &params)?;
            for neighbor in &results {
                if let Some((id, meta)) = self
                    .provider
                    .blob_by_sequence(neighbor.key)
                    .map_err(Into::into)?
                {
                    let entry = Self::candidate_entry(&mut candidates, id);
                    entry.raw_distance = Some(neighbor.distance);
                    entry.wall_clock_ns = Some(meta.wall_clock_ns);
                    // `meta` is owned here, so move it instead of cloning.
                    entry.meta = Some(meta);
                }
            }
        }

        // 1b. Temporal candidates: everything the provider reports in the window.
        if tmp_active && let Some((start, end)) = self.time_range {
            let temporal_results = self
                .provider
                .blobs_in_time_range(start, end)
                .map_err(Into::into)?;
            for (tkey, meta) in &temporal_results {
                let entry = Self::candidate_entry(&mut candidates, tkey.blob_id);
                entry.wall_clock_ns = Some(meta.wall_clock_ns);
                if entry.meta.is_none() {
                    entry.meta = Some(meta.clone());
                }
            }
        }

        // 1c. Causal candidates: blobs reachable from the root within the hop limit.
        if cau_active && let Some(root) = self.causal_root.as_ref() {
            let distances = causal_bfs(self.provider, root, self.causal_max_hops)?;
            for (id, hops) in &distances {
                Self::candidate_entry(&mut candidates, *id).causal_hops = Some(*hops);
            }
        }

        // 2. Hydrate metadata for candidates that only came from causal BFS (or lack it).
        for entry in candidates.values_mut() {
            if entry.meta.is_none() {
                if let Some(meta) = self
                    .provider
                    .get_blob_meta(&entry.blob_id)
                    .map_err(Into::into)?
                {
                    entry.wall_clock_ns = Some(meta.wall_clock_ns);
                    entry.meta = Some(meta);
                }
            } else if entry.wall_clock_ns.is_none() {
                entry.wall_clock_ns = entry.meta.as_ref().map(|m| m.wall_clock_ns);
            }
        }
        candidates.retain(|_, e| e.meta.is_some());

        // 3. Namespace and tag filters (tags are ANDed together).
        if let Some(ns) = self.namespace {
            let ns_blobs: crate::compat::HashSet<BlobId> = self
                .provider
                .blobs_in_namespace(ns)
                .map_err(Into::into)?
                .into_iter()
                .map(|(id, _)| id)
                .collect();
            candidates.retain(|id, _| ns_blobs.contains(id));
        }
        for tag in &self.tags {
            let tag_blobs: crate::compat::HashSet<BlobId> = self
                .provider
                .blobs_by_tag(tag)
                .map_err(Into::into)?
                .into_iter()
                .collect();
            candidates.retain(|id, _| tag_blobs.contains(id));
        }
        if candidates.is_empty() {
            return Ok(Vec::new());
        }

        // 4. Per-signal normalisation, always over the surviving candidate set.
        let sem_scores = if sem_active {
            let pairs: Vec<(BlobId, f32)> = candidates
                .values()
                .filter_map(|e| e.raw_distance.map(|d| (e.blob_id, d)))
                .collect();
            normalize_semantic(&pairs)
        } else {
            HashMap::new()
        };
        let tmp_scores = if tmp_active {
            let pairs: Vec<(BlobId, u64)> = candidates
                .values()
                .filter_map(|e| e.wall_clock_ns.map(|t| (e.blob_id, t)))
                .collect();
            normalize_temporal(&pairs)
        } else {
            HashMap::new()
        };
        let cau_scores = if cau_active {
            // Use only survivors, like the semantic/temporal signals, so blobs dropped
            // for missing metadata or by namespace/tag filters cannot skew normalisation.
            let hops: HashMap<BlobId, _> = candidates
                .values()
                .filter_map(|e| e.causal_hops.map(|h| (e.blob_id, h)))
                .collect();
            normalize_causal(&hops)
        } else {
            HashMap::new()
        };

        // 5. Bounded heap. Capacity is limited by the candidate count, because a caller
        // may pass an enormous `top_k`; `top_k + 1` would overflow or over-allocate.
        let capacity = self.top_k.min(candidates.len()).saturating_add(1);
        let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(capacity);
        // Consume the map so each entry's metadata can be moved rather than cloned.
        for (_, mut entry) in candidates {
            let Some(meta) = entry.meta.take() else {
                continue;
            };
            let id = entry.blob_id;
            let sem = sem_active.then(|| sem_scores.get(&id).copied().unwrap_or(0.0));
            let tmp = tmp_active.then(|| tmp_scores.get(&id).copied().unwrap_or(0.0));
            let cau = cau_active.then(|| cau_scores.get(&id).copied().unwrap_or(0.0));
            let score = w_sem * sem.unwrap_or(0.0)
                + w_tmp * tmp.unwrap_or(0.0)
                + w_cau * cau.unwrap_or(0.0);
            heap.push(HeapEntry {
                score,
                blob_id: id,
                meta,
                signals: SignalScores {
                    semantic: sem,
                    temporal: tmp,
                    causal: cau,
                },
            });
            if heap.len() > self.top_k {
                heap.pop();
            }
        }
        Ok(heap_to_sorted(heap))
    }

    /// Checks that the configured signals have the inputs they need.
    ///
    /// # Errors
    /// Returns an invalid-configuration error when:
    /// - no weight is active;
    /// - the semantic weight is positive but the query vector, index or storage
    ///   reader is missing;
    /// - the causal weight is positive without a root;
    /// - temporal is the only positive weight and no time range is set;
    /// - `top_k` is zero.
    fn validate(&self) -> crate::Result<()> {
        if !self.weights.any_active() {
            return Err(StorageError::invalid_config(
                "CompositeQuery: at least one signal must have weight > 0",
            ));
        }
        if self.weights.semantic > 0.0 && self.query_vector.is_none() {
            return Err(StorageError::invalid_config(
                "CompositeQuery: semantic signal requires a query vector",
            ));
        }
        if self.weights.semantic > 0.0 && self.vector_index.is_none() {
            return Err(StorageError::invalid_config(
                "CompositeQuery: semantic signal requires an IVF-PQ index",
            ));
        }
        if self.weights.semantic > 0.0 && self.storage_reader.is_none() {
            return Err(StorageError::invalid_config(
                "CompositeQuery: semantic signal requires a StorageRead (use with_storage_reader)",
            ));
        }
        if self.weights.causal > 0.0 && self.causal_root.is_none() {
            return Err(StorageError::invalid_config(
                "CompositeQuery: causal signal requires a root blob_id",
            ));
        }
        // Without semantic/causal candidates, the time range is the only candidate source.
        if self.weights.temporal > 0.0
            && self.weights.semantic <= 0.0
            && self.weights.causal <= 0.0
            && self.time_range.is_none()
        {
            return Err(StorageError::invalid_config(
                "CompositeQuery: temporal-only signal requires a time_range",
            ));
        }
        if self.top_k == 0 {
            return Err(StorageError::invalid_config(
                "CompositeQuery: top_k must be >= 1",
            ));
        }
        Ok(())
    }
}