use std::collections::{BTreeSet, HashMap, HashSet};
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Instant;
use tracing::{debug, info, info_span, warn};
use super::ai::ner::{
AuthContext as NerAuthContext, HeuristicFallback, LlmNer, NerError, NerProvider, NER_CAPABILITY,
};
use super::statement_frame::{EffectiveScope, ReadFrame};
use super::RedDBRuntime;
use crate::api::{RedDBError, RedDBResult};
use crate::application::SearchContextInput;
use crate::storage::schema::Value;
use crate::storage::unified::entity::{EntityData, EntityKind, UnifiedEntity};
/// Default cap on rows/hits gathered per retrieval stage and on the fused
/// source list when the caller does not supply an explicit limit.
pub const DEFAULT_ROW_CAP: usize = 20;
/// Keywords and literal identifiers extracted from an ASK question.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TokenSet {
    /// Lower-cased content/schema keywords (stop words removed).
    pub keywords: Vec<String>,
    /// Verbatim literal tokens (ID-shaped words, long digit runs).
    pub literals: Vec<String>,
}
impl TokenSet {
    /// True when extraction produced neither keywords nor literals.
    pub fn is_empty(&self) -> bool {
        let no_keywords = self.keywords.is_empty();
        let no_literals = self.literals.is_empty();
        no_keywords && no_literals
    }
}
/// Collections (and columns within them) whose schema vocabulary matched the
/// question's keywords; later stages restrict their searches to these.
#[derive(Debug, Clone, Default)]
pub struct CandidateCollections {
    /// Candidate collection names.
    pub collections: Vec<String>,
    /// Column-name hints keyed by collection, used to prioritise literal matching.
    pub columns_by_collection: HashMap<String, Vec<String>>,
}
/// A BM25 full-text hit produced by stage 3.
#[derive(Debug, Clone)]
pub struct TextHit {
    pub collection: String,
    pub entity_id: u64,
    /// BM25 relevance score (higher is better).
    pub score: f32,
}
/// A vector-similarity hit produced by stage 3b.
#[derive(Debug, Clone)]
pub struct VectorHit {
    pub collection: String,
    pub entity_id: u64,
    /// Similarity score as reported by the vector search (higher is better).
    pub score: f32,
}
/// A graph-traversal hit produced by stage 3c.
#[derive(Debug, Clone)]
pub struct GraphHit {
    pub collection: String,
    pub entity_id: u64,
    pub score: f32,
    /// Traversal depth (hop count) at which the entity was discovered.
    pub depth: usize,
    /// Whether the hit is a graph node or an edge.
    pub kind: GraphHitKind,
}
/// Distinguishes node hits from edge hits in graph traversal results.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GraphHitKind {
    Node,
    Edge,
}
/// A row whose field values contained one of the question's literal tokens
/// (stage 4 output).
#[derive(Debug, Clone)]
pub struct FilteredRow {
    pub collection: String,
    pub entity: UnifiedEntity,
    /// The literal token that matched.
    pub matched_literal: String,
    /// Column the literal was found in, when known.
    pub matched_column: Option<String>,
}
/// Wall-clock duration of each pipeline stage, in microseconds.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StageTimings {
    pub extract_us: u64,
    pub schema_us: u64,
    pub text_us: u64,
    pub vector_us: u64,
    pub graph_us: u64,
    pub filter_us: u64,
}
/// Accumulated output of every ASK pipeline stage.
#[derive(Debug, Clone)]
pub struct AskContext {
    /// The original natural-language question.
    pub question: String,
    /// Stage 1 output: extracted keywords and literals.
    pub tokens: TokenSet,
    /// Stage 2 output: schema-matched candidate collections/columns.
    pub candidates: CandidateCollections,
    /// Stage 3 output: BM25 text hits.
    pub text_hits: Vec<TextHit>,
    /// Stage 3b output: vector similarity hits.
    pub vector_hits: Vec<VectorHit>,
    /// Stage 3c output: graph traversal hits.
    pub graph_hits: Vec<GraphHit>,
    /// Stage 4 output: rows containing literal matches.
    pub filtered_rows: Vec<FilteredRow>,
    /// Maximum number of fused sources to surface.
    pub source_limit: usize,
    /// Per-stage wall-clock timings.
    pub timings: StageTimings,
}
impl Default for AskContext {
fn default() -> Self {
Self {
question: String::new(),
tokens: TokenSet::default(),
candidates: CandidateCollections::default(),
text_hits: Vec::new(),
vector_hits: Vec::new(),
graph_hits: Vec::new(),
filtered_rows: Vec::new(),
source_limit: DEFAULT_ROW_CAP,
timings: StageTimings::default(),
}
}
}
/// Index into one of the four `AskContext` hit buckets, identifying where a
/// fused result originally came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusedSourceRef {
    FilteredRow(usize),
    TextHit(usize),
    VectorHit(usize),
    GraphHit(usize),
}
/// A fused result: its originating bucket/index plus its fusion score.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FusedSource {
    /// Bucket + index the result came from.
    pub source: FusedSourceRef,
    /// Combined reciprocal-rank-fusion score across all buckets.
    pub rrf_score: f64,
}
/// Uninhabited namespace type for the staged ASK retrieval pipeline.
pub enum AskPipeline {}
impl AskPipeline {
    /// Runs the full pipeline with the default row cap (`DEFAULT_ROW_CAP`).
    pub fn execute(
        runtime: &RedDBRuntime,
        scope: &EffectiveScope,
        question: &str,
    ) -> RedDBResult<AskContext> {
        Self::execute_with_limit(runtime, scope, question, DEFAULT_ROW_CAP)
    }
    /// Runs the full pipeline with an explicit row cap and default
    /// score/depth parameters.
    pub fn execute_with_limit(
        runtime: &RedDBRuntime,
        scope: &EffectiveScope,
        question: &str,
        row_cap: usize,
    ) -> RedDBResult<AskContext> {
        Self::execute_with_limit_and_min_score(runtime, scope, question, row_cap, None, None)
    }
    /// Full entry point: extract tokens (stage 1), match schema (stage 2),
    /// then run text (3), vector (3b) and graph (3c) retrieval plus literal
    /// row filtering (4), timing each stage in microseconds.
    ///
    /// # Errors
    /// Returns `RedDBError::Query` when token extraction yields nothing
    /// usable; propagates failures from stages 1 and 2. Stages 3/3b/3c/4
    /// degrade to empty results rather than failing.
    pub fn execute_with_limit_and_min_score(
        runtime: &RedDBRuntime,
        scope: &EffectiveScope,
        question: &str,
        row_cap: usize,
        min_score: Option<f32>,
        graph_depth: Option<usize>,
    ) -> RedDBResult<AskContext> {
        let span = info_span!(
            "ask_pipeline.execute",
            tenant = ?scope.effective_scope(),
            question_len = question.len(),
            row_cap = row_cap,
            min_score = ?min_score,
            graph_depth = ?graph_depth,
        );
        let _enter = span.enter();
        // Stage 1: tokenize the question (heuristic or LLM-routed).
        let stage1 = Instant::now();
        let tokens = extract_tokens_routed(runtime, scope, question)?;
        let extract_us = stage1.elapsed().as_micros() as u64;
        debug!(
            target: "ask_pipeline",
            stage = "extract_tokens",
            keywords = ?tokens.keywords,
            literals = ?tokens.literals,
            elapsed_us = extract_us,
            "stage 1 done"
        );
        // Refuse to continue without tokens: every later stage would be an
        // unfocused scan.
        if tokens.is_empty() {
            warn!(
                target: "ask_pipeline",
                question_len = question.len(),
                "refused: empty token set"
            );
            return Err(RedDBError::Query(
                "ASK question yielded no usable tokens (heuristic NER produced empty keyword + literal set)"
                    .to_string(),
            ));
        }
        // Stage 2: map keywords onto candidate collections/columns.
        let stage2 = Instant::now();
        let candidates = match_schema(runtime, scope, &tokens)?;
        let schema_us = stage2.elapsed().as_micros() as u64;
        debug!(
            target: "ask_pipeline",
            stage = "match_schema",
            collections = ?candidates.collections,
            elapsed_us = schema_us,
            "stage 2 done"
        );
        // Stage 3: BM25 full-text retrieval over the candidates.
        let stage3 = Instant::now();
        let text_hits = text_search_bm25_scoped(runtime, scope, question, &candidates, row_cap);
        let text_us = stage3.elapsed().as_micros() as u64;
        debug!(
            target: "ask_pipeline",
            stage = "text_search_bm25_scoped",
            hits = text_hits.len(),
            elapsed_us = text_us,
            "stage 3 done"
        );
        // Stage 3b: embedding-based vector retrieval.
        let stage3b = Instant::now();
        let vector_hits =
            vector_search_scoped(runtime, scope, question, &candidates, row_cap, min_score);
        let vector_us = stage3b.elapsed().as_micros() as u64;
        debug!(
            target: "ask_pipeline",
            stage = "vector_search_scoped",
            hits = vector_hits.len(),
            elapsed_us = vector_us,
            "stage 3b done"
        );
        // Stage 3c: graph-expansion retrieval.
        let stage3c = Instant::now();
        let graph_hits = graph_search_scoped(
            runtime,
            scope,
            question,
            &candidates,
            row_cap,
            min_score,
            graph_depth,
        );
        let graph_us = stage3c.elapsed().as_micros() as u64;
        debug!(
            target: "ask_pipeline",
            stage = "graph_search_scoped",
            hits = graph_hits.len(),
            elapsed_us = graph_us,
            "stage 3c done"
        );
        // Stage 4: literal-value row filtering.
        let stage4 = Instant::now();
        let filtered_rows = filter_values(runtime, scope, &candidates, &tokens, row_cap);
        let filter_us = stage4.elapsed().as_micros() as u64;
        debug!(
            target: "ask_pipeline",
            stage = "filter_values",
            rows = filtered_rows.len(),
            elapsed_us = filter_us,
            "stage 4 done"
        );
        Ok(AskContext {
            question: question.to_string(),
            tokens,
            candidates,
            text_hits,
            vector_hits,
            graph_hits,
            filtered_rows,
            source_limit: row_cap,
            timings: StageTimings {
                extract_us,
                schema_us,
                text_us,
                vector_us,
                graph_us,
                filter_us,
            },
        })
    }
}
/// Convenience wrapper around [`fused_sources`] that drops the RRF scores and
/// keeps only the ordered source references.
pub fn fused_source_order(ctx: &AskContext) -> Vec<FusedSourceRef> {
    let mut order = Vec::new();
    for fused in fused_sources(ctx) {
        order.push(fused.source);
    }
    order
}
/// Fuses the four retrieval buckets (filtered rows, text, vector, graph) via
/// reciprocal-rank fusion and returns up to `ctx.source_limit` sources in
/// fused-score order.
///
/// Identity across buckets is `"collection/entity_id"`; when the same entity
/// appears in more than one bucket, the first registration (rows, then text,
/// vector, graph) decides which `FusedSourceRef` is reported.
pub fn fused_sources(ctx: &AskContext) -> Vec<FusedSource> {
    use super::ai::rrf_fuser::{fuse, Bucket, Candidate, RRF_K_DEFAULT};
    // Nothing to fuse: zero cap or every bucket empty.
    if ctx.source_limit == 0
        || (ctx.filtered_rows.is_empty()
            && ctx.text_hits.is_empty()
            && ctx.vector_hits.is_empty()
            && ctx.graph_hits.is_empty())
    {
        return Vec::new();
    }
    let mut refs: HashMap<String, FusedSourceRef> = HashMap::new();
    // Shared candidate builder: remembers the first source ref seen per
    // identity and yields the RRF candidate for the bucket being built.
    let register = |refs: &mut HashMap<String, FusedSourceRef>,
                    collection: &str,
                    entity_id: u64,
                    score: f64,
                    source: FusedSourceRef|
     -> Candidate {
        let id = source_identity(collection, entity_id);
        refs.entry(id.clone()).or_insert(source);
        Candidate { id, score }
    };
    // Filtered rows carry no native score; weight them uniformly at 1.0.
    let row_bucket = Bucket {
        candidates: ctx
            .filtered_rows
            .iter()
            .enumerate()
            .map(|(idx, row)| {
                register(
                    &mut refs,
                    &row.collection,
                    row.entity.id.raw(),
                    1.0,
                    FusedSourceRef::FilteredRow(idx),
                )
            })
            .collect(),
        min_score: None,
    };
    let text_bucket = Bucket {
        candidates: ctx
            .text_hits
            .iter()
            .enumerate()
            .map(|(idx, hit)| {
                register(
                    &mut refs,
                    &hit.collection,
                    hit.entity_id,
                    hit.score as f64,
                    FusedSourceRef::TextHit(idx),
                )
            })
            .collect(),
        min_score: None,
    };
    let vector_bucket = Bucket {
        candidates: ctx
            .vector_hits
            .iter()
            .enumerate()
            .map(|(idx, hit)| {
                register(
                    &mut refs,
                    &hit.collection,
                    hit.entity_id,
                    hit.score as f64,
                    FusedSourceRef::VectorHit(idx),
                )
            })
            .collect(),
        min_score: None,
    };
    let graph_bucket = Bucket {
        candidates: ctx
            .graph_hits
            .iter()
            .enumerate()
            .map(|(idx, hit)| {
                register(
                    &mut refs,
                    &hit.collection,
                    hit.entity_id,
                    hit.score as f64,
                    FusedSourceRef::GraphHit(idx),
                )
            })
            .collect(),
        min_score: None,
    };
    // Fuse, then map each fused identity back to its originating source ref.
    fuse(
        &[row_bucket, text_bucket, vector_bucket, graph_bucket],
        RRF_K_DEFAULT,
        ctx.source_limit,
    )
    .into_iter()
    .filter_map(|item| {
        refs.get(&item.id).copied().map(|source| FusedSource {
            source,
            rrf_score: item.rrf_score,
        })
    })
    .collect()
}
/// Canonical cross-bucket identity for an entity: `"<collection>/<id>"`.
fn source_identity(collection: &str, entity_id: u64) -> String {
    format!("{}/{}", collection, entity_id)
}
/// Stage 3: BM25 full-text search over the candidate collections visible to
/// `scope`, post-filtered through the tenant/RLS entity check.
///
/// Returns at most `top_k` hits; candidate collections outside the scope's
/// visible set are excluded before the index is queried.
pub fn text_search_bm25_scoped(
    runtime: &RedDBRuntime,
    scope: &EffectiveScope,
    question: &str,
    candidates: &CandidateCollections,
    top_k: usize,
) -> Vec<TextHit> {
    if candidates.collections.is_empty() || top_k == 0 {
        return Vec::new();
    }
    let visible = scope.visible_collections();
    // Intersect candidates with the scope's visible set (None = no restriction).
    let allowed: BTreeSet<String> = candidates
        .collections
        .iter()
        .filter(|collection| visible.is_none_or(|set| set.contains(*collection)))
        .cloned()
        .collect();
    if allowed.is_empty() {
        return Vec::new();
    }
    // Install tenant/auth thread-locals for the duration of the search.
    let _scope_guard = AskScopeGuard::install(scope);
    let snap_ctx = crate::runtime::impl_core::capture_current_snapshot();
    let mut rls_cache: HashMap<String, Option<crate::storage::query::ast::Filter>> = HashMap::new();
    // Bind the store once and reuse it (the original fetched it a second time
    // for the context-index call).
    let store = runtime.inner.db.store();
    store
        .context_index()
        .search_bm25(question, top_k, Some(&allowed))
        .into_iter()
        .filter_map(|hit| {
            let entity = store.get(&hit.collection, hit.entity_id)?;
            if !ask_entity_allowed(
                runtime,
                scope,
                &hit.collection,
                &entity,
                snap_ctx.as_ref(),
                &mut rls_cache,
            ) {
                return None;
            }
            Some(TextHit {
                collection: hit.collection,
                entity_id: hit.entity_id.raw(),
                score: hit.score,
            })
        })
        .collect()
}
/// Stage 1: token extraction, routed by the `ai.ner.backend` config key.
///
/// The default `heuristic` backend uses the local tokenizer; `llm` drives an
/// LLM NER provider (a stub when no endpoint/model is configured) and honours
/// the `ai.ner.fallback` policy on failure or auth denial.
fn extract_tokens_routed(
    runtime: &RedDBRuntime,
    scope: &EffectiveScope,
    question: &str,
) -> RedDBResult<TokenSet> {
    let backend = runtime.config_string("ai.ner.backend", "heuristic");
    if backend != "llm" {
        return Ok(extract_tokens(question));
    }
    let endpoint = runtime.config_string("ai.ner.endpoint", "");
    let model = runtime.config_string("ai.ner.model", "");
    // Unparseable timeout values silently fall back to 5000 ms.
    let timeout_ms = runtime
        .config_string("ai.ner.timeout_ms", "5000")
        .parse::<u32>()
        .unwrap_or(5000);
    // Unknown fallback strings default to the heuristic tokenizer.
    let fallback = match runtime
        .config_string("ai.ner.fallback", "use_heuristic")
        .as_str()
    {
        "empty_on_fail" => HeuristicFallback::EmptyOnFail,
        "propagate" => HeuristicFallback::Propagate,
        _ => HeuristicFallback::UseHeuristic,
    };
    // No endpoint and no model means "llm" mode with no real provider: use
    // the stub so the code path still exercises routing.
    let provider = if endpoint.is_empty() && model.is_empty() {
        NerProvider::Stub(super::ai::ner::StubBehavior::Empty)
    } else {
        NerProvider::OpenAiCompat { endpoint, model }
    };
    let mut ner = LlmNer::new(provider, fallback);
    ner.timeout_ms = timeout_ms;
    let auth = ScopeAuthAdapter(scope);
    // The LLM call is async; we can only block on it when a Tokio runtime is
    // reachable from this thread. Otherwise degrade to the heuristic path.
    let llm_result = match tokio::runtime::Handle::try_current() {
        Ok(handle) => {
            tokio::task::block_in_place(|| handle.block_on(ner.extract(question, scope, &auth)))
        }
        Err(_) => {
            warn!(
                target: "ask_pipeline",
                "ai.ner.backend=llm configured but no Tokio runtime reachable from extract_tokens; using heuristic fallback"
            );
            return Ok(extract_tokens(question));
        }
    };
    match llm_result {
        Ok(tokens) => Ok(tokens),
        // Auth denial is expected until the capability is wired up; log once.
        Err(NerError::AuthDenied) => {
            log_auth_denial_once();
            apply_fallback(fallback, question)
        }
        Err(err) => {
            warn!(
                target: "ask_pipeline",
                error = %err,
                "LlmNer extract failed; honouring HeuristicFallback policy"
            );
            apply_fallback(fallback, question)
        }
    }
}
/// Applies the configured `HeuristicFallback` policy after an LLM NER failure.
fn apply_fallback(fallback: HeuristicFallback, question: &str) -> RedDBResult<TokenSet> {
    match fallback {
        HeuristicFallback::Propagate => Err(RedDBError::Query(
            "ai.ner.backend=llm: extract failed and ai.ner.fallback=propagate".to_string(),
        )),
        HeuristicFallback::EmptyOnFail => Ok(TokenSet::default()),
        HeuristicFallback::UseHeuristic => Ok(extract_tokens(question)),
    }
}
/// Emits the "capability not yet wired" notice at most once per process.
fn log_auth_denial_once() {
    static EMITTED: AtomicBool = AtomicBool::new(false);
    // compare_exchange succeeds only for the first caller to flip the flag.
    let first_time = EMITTED
        .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
        .is_ok();
    if first_time {
        info!(
            target: "ask_pipeline",
            capability = NER_CAPABILITY,
            "LlmNer routing configured but capability `{}` not yet wired into auth engine; falling back to heuristic",
            NER_CAPABILITY
        );
    }
}
/// Adapts an `EffectiveScope` to the NER auth-context trait so the LLM NER
/// path can check capabilities without knowing about statement frames.
struct ScopeAuthAdapter<'a>(&'a EffectiveScope);
// Lifetimes on the impls are elidable (clippy::needless_lifetimes).
impl std::fmt::Debug for ScopeAuthAdapter<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ScopeAuthAdapter").finish_non_exhaustive()
    }
}
impl NerAuthContext for ScopeAuthAdapter<'_> {
    /// Delegates capability checks to the wrapped scope.
    fn has_capability(&self, capability: &str) -> bool {
        self.0.has_capability(capability)
    }
}
/// Heuristic tokenizer for ASK questions.
///
/// Splits on any character that is not alphanumeric, `_` or `-`, classifies
/// each word as a keyword or a literal via `classify_token`, and de-duplicates
/// both lists while preserving first-seen order.
///
/// The original implementation used a `peekable` iterator with an in-loop
/// end-of-input flush, a no-op `let _ = ch;`, and an unreachable trailing
/// flush; this is the equivalent straightforward form.
pub fn extract_tokens(question: &str) -> TokenSet {
    let mut keywords: Vec<String> = Vec::new();
    let mut literals: Vec<String> = Vec::new();
    let mut buf = String::new();
    for ch in question.chars() {
        if ch.is_alphanumeric() || ch == '_' || ch == '-' {
            buf.push(ch);
        } else if !buf.is_empty() {
            // Word boundary: classify the accumulated word and reset.
            classify_token(&buf, &mut keywords, &mut literals);
            buf.clear();
        }
    }
    // Flush a word that runs to the end of the question.
    if !buf.is_empty() {
        classify_token(&buf, &mut keywords, &mut literals);
    }
    // De-duplicate while keeping first-seen order.
    let mut seen = HashSet::new();
    keywords.retain(|tok| seen.insert(tok.clone()));
    let mut seen_lit = HashSet::new();
    literals.retain(|tok| seen_lit.insert(tok.clone()));
    TokenSet { keywords, literals }
}
/// Classifies a single word as a literal (ID-shaped or a long digit run), a
/// keyword (lower-cased, stop words removed), or noise (dropped).
fn classify_token(word: &str, keywords: &mut Vec<String>, literals: &mut Vec<String>) {
    // ID shape: >= 3 chars drawn from uppercase/digits/'-', containing at
    // least one digit and at least one uppercase letter or dash
    // (e.g. "FDD-12313").
    let is_upper_id_shape = word.len() >= 3
        && word
            .chars()
            .all(|c| c.is_ascii_digit() || c == '-' || c.is_ascii_uppercase())
        && word.chars().any(|c| c.is_ascii_digit())
        && word.chars().any(|c| c.is_ascii_uppercase() || c == '-');
    // Long digit runs (order numbers etc.) are literals too.
    let is_long_digit_run = word.len() >= 6 && word.chars().all(|c| c.is_ascii_digit());
    if is_upper_id_shape || is_long_digit_run {
        literals.push(word.to_string());
        return;
    }
    let trimmed = word.trim_matches(|c: char| !c.is_ascii_alphanumeric() && c != '_');
    if trimmed.len() < 2 {
        return;
    }
    // Keywords must start with an ASCII letter …
    if !trimmed
        .chars()
        .next()
        .map(|c| c.is_ascii_alphabetic())
        .unwrap_or(false)
    {
        return;
    }
    // … and contain only ASCII alphanumerics / underscores.
    if !trimmed
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || c == '_')
    {
        return;
    }
    let lower = trimmed.to_ascii_lowercase();
    if STOP_WORDS.binary_search(&lower.as_str()).is_ok() {
        return;
    }
    keywords.push(lower);
}
/// Stop words removed from keyword extraction (English + Portuguese).
///
/// MUST stay lexicographically sorted: `classify_token` looks words up with
/// `binary_search`, which requires a sorted slice. The previous ordering
/// ("que", "qual", "quais") was unsorted, so "qual" and "quais" were never
/// found and leaked into the keyword list.
const STOP_WORDS: &[&str] = &[
    "a", "about", "an", "and", "are", "as", "at", "be", "by", "do", "for", "from", "how", "in",
    "is", "it", "of", "on", "or", "quais", "qual", "que", "sobre", "te", "the", "to", "what",
    "where", "which", "with",
];
/// Stage 2: maps extracted keywords onto known collections/columns via the
/// runtime's schema vocabulary, restricted to the scope's visible collections.
pub fn match_schema(
    runtime: &RedDBRuntime,
    scope: &EffectiveScope,
    tokens: &TokenSet,
) -> RedDBResult<CandidateCollections> {
    // Without a visibility restriction, every known collection is fair game.
    let visible = match scope.visible_collections() {
        Some(set) => set.clone(),
        None => runtime
            .inner
            .db
            .store()
            .list_collections()
            .into_iter()
            .collect(),
    };
    let mut collections: BTreeSet<String> = BTreeSet::new();
    let mut columns_by_collection: HashMap<String, BTreeSet<String>> = HashMap::new();
    for keyword in &tokens.keywords {
        for hit in runtime.schema_vocabulary_lookup(keyword) {
            if !visible.contains(&hit.collection) {
                continue;
            }
            collections.insert(hit.collection.clone());
            if let Some(column) = hit.column {
                columns_by_collection
                    .entry(hit.collection)
                    .or_default()
                    .insert(column);
            }
        }
    }
    // BTreeSet intermediates give deterministic, sorted output vectors.
    let columns_by_collection = columns_by_collection
        .into_iter()
        .map(|(collection, columns)| (collection, columns.into_iter().collect()))
        .collect();
    Ok(CandidateCollections {
        collections: collections.into_iter().collect(),
        columns_by_collection,
    })
}
/// Stage 3b: embeds the question and runs vector-similarity search over each
/// candidate collection, post-filtered through the tenant/RLS entity check.
///
/// Returns at most `top_k` hits sorted by descending score (entity id breaks
/// ties). Returns empty when there are no candidates, the cap is zero, or no
/// embedding provider is configured.
pub fn vector_search_scoped(
    runtime: &RedDBRuntime,
    scope: &EffectiveScope,
    question: &str,
    candidates: &CandidateCollections,
    top_k: usize,
    min_score: Option<f32>,
) -> Vec<VectorHit> {
    // Zero-cap guard added for consistency with the text/graph stages: the
    // previous version still paid for the embedding call and per-collection
    // searches before truncating everything away.
    if candidates.collections.is_empty() || top_k == 0 {
        return Vec::new();
    }
    // No embedding provider configured (or the call failed): skip silently.
    let Some(embedding) = embed_question(runtime, question) else {
        return Vec::new();
    };
    let per_collection = top_k.max(1);
    let mut hits: Vec<VectorHit> = Vec::new();
    // Install tenant/auth thread-locals for the duration of the searches.
    let _scope_guard = AskScopeGuard::install(scope);
    let snap_ctx = crate::runtime::impl_core::capture_current_snapshot();
    let mut rls_cache: HashMap<String, Option<crate::storage::query::ast::Filter>> = HashMap::new();
    let store = runtime.inner.db.store();
    for collection in &candidates.collections {
        match super::authorized_search::AuthorizedSearch::execute_similar(
            runtime,
            scope,
            collection,
            &embedding,
            per_collection,
            min_score.unwrap_or(0.0),
        ) {
            Ok(results) => {
                for result in results {
                    let Some(entity) = store.get(collection, result.entity_id) else {
                        continue;
                    };
                    if !ask_entity_allowed(
                        runtime,
                        scope,
                        collection,
                        &entity,
                        snap_ctx.as_ref(),
                        &mut rls_cache,
                    ) {
                        continue;
                    }
                    hits.push(VectorHit {
                        collection: collection.clone(),
                        entity_id: result.entity_id.raw(),
                        score: result.score,
                    });
                }
            }
            Err(err) => {
                // Per-collection failures are non-fatal: log and keep going.
                debug!(
                    target: "ask_pipeline",
                    collection = collection,
                    err = %err,
                    "vector_search_scoped: collection skipped"
                );
            }
        }
    }
    // Highest score first; entity id as a deterministic tie-breaker.
    hits.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.entity_id.cmp(&b.entity_id))
    });
    hits.truncate(top_k);
    hits
}
/// Stage 3c: graph-expansion search over the candidate collections.
///
/// Runs the context search (the authorized variant when the scope carries a
/// visibility set), keeps only entities discovered via graph traversal, and
/// returns at most `top_k` hits sorted by score, depth, collection and id.
pub fn graph_search_scoped(
    runtime: &RedDBRuntime,
    scope: &EffectiveScope,
    question: &str,
    candidates: &CandidateCollections,
    top_k: usize,
    min_score: Option<f32>,
    graph_depth: Option<usize>,
) -> Vec<GraphHit> {
    if candidates.collections.is_empty() || top_k == 0 {
        return Vec::new();
    }
    // Depth defaults to the MCP ask tool's default and is clamped to >= 1.
    let depth = graph_depth
        .unwrap_or(crate::runtime::ai::mcp_ask_tool::DEPTH_DEFAULT as usize)
        .max(1);
    let _scope_guard = AskScopeGuard::install(scope);
    // Cross-reference following is disabled: only graph expansion is wanted here.
    let input = SearchContextInput {
        query: question.to_string(),
        field: None,
        vector: None,
        collections: Some(candidates.collections.clone()),
        graph_depth: Some(depth),
        graph_max_edges: None,
        max_cross_refs: Some(0),
        follow_cross_refs: Some(false),
        expand_graph: Some(true),
        global_scan: Some(true),
        reindex: Some(false),
        limit: Some(top_k),
        min_score,
    };
    // A visibility-restricted scope must go through the authorized search path.
    let result = match if scope.visible_collections().is_some() {
        super::authorized_search::AuthorizedSearch::execute_context(runtime, scope, input)
    } else {
        runtime.search_context(input)
    } {
        Ok(result) => result,
        Err(err) => {
            // Graph search is best-effort: failures degrade to no graph hits.
            debug!(
                target: "ask_pipeline",
                err = %err,
                "graph_search_scoped: context search skipped"
            );
            return Vec::new();
        }
    };
    let mut hits = Vec::new();
    for entity in result
        .graph
        .nodes
        .into_iter()
        .chain(result.graph.edges.into_iter())
    {
        // Only keep entities that were actually reached by graph traversal.
        let crate::runtime::DiscoveryMethod::GraphTraversal { depth, .. } = entity.discovery else {
            continue;
        };
        let kind = match entity.entity.kind {
            EntityKind::GraphNode(_) => GraphHitKind::Node,
            EntityKind::GraphEdge(_) => GraphHitKind::Edge,
            _ => continue,
        };
        hits.push(GraphHit {
            collection: entity.collection,
            entity_id: entity.entity.id.raw(),
            score: entity.score,
            depth,
            kind,
        });
    }
    // Score descending, then shallower depth, then collection/id for determinism.
    hits.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.depth.cmp(&b.depth))
            .then_with(|| a.collection.cmp(&b.collection))
            .then_with(|| a.entity_id.cmp(&b.entity_id))
    });
    hits.truncate(top_k);
    hits
}
/// Embeds the question text via the configured OpenAI-compatible provider.
///
/// Returns `None` when no compatible provider or API key is configured, or
/// when the embedding call fails — vector search is then skipped silently.
fn embed_question(runtime: &RedDBRuntime, question: &str) -> Option<Vec<f32>> {
    // AI config values live in the "red_config" KV namespace as text values.
    let kv_getter = |key: &str| -> RedDBResult<Option<String>> {
        match runtime.inner.db.get_kv("red_config", key) {
            Some((Value::Text(value), _)) => Ok(Some(value.to_string())),
            Some(_) => Ok(None),
            None => Ok(None),
        }
    };
    let provider = crate::ai::resolve_default_provider(&kv_getter);
    if !provider.is_openai_compatible() {
        return None;
    }
    let model = crate::ai::resolve_default_model(&provider, &kv_getter);
    let api_key = crate::ai::resolve_api_key(&provider, None, kv_getter).ok()?;
    let transport = crate::runtime::ai::transport::AiTransport::from_runtime(runtime);
    let request = crate::ai::OpenAiEmbeddingRequest {
        api_key,
        model,
        inputs: vec![question.to_string()],
        dimensions: None,
        api_base: provider.resolve_api_base(),
    };
    // block_on_ai wraps the async call; the nested Result is flattened and any
    // failure collapses to None.
    let response = crate::runtime::ai::block_on_ai(async move {
        crate::ai::openai_embeddings_async(&transport, request).await
    })
    .and_then(|result| result)
    .ok()?;
    // One input was sent, so only the first embedding matters.
    response.embeddings.into_iter().next()
}
/// Stage 4: scans the candidate collections for rows whose field values
/// contain one of the question's literal tokens.
///
/// Hinted columns from schema matching are tried first inside each row; rows
/// failing the tenant/RLS check are skipped. Stops after `row_cap` matches.
pub fn filter_values(
    runtime: &RedDBRuntime,
    scope: &EffectiveScope,
    candidates: &CandidateCollections,
    tokens: &TokenSet,
    row_cap: usize,
) -> Vec<FilteredRow> {
    if tokens.literals.is_empty() || candidates.collections.is_empty() {
        return Vec::new();
    }
    let visible = scope.visible_collections();
    let store = runtime.inner.db.store();
    let mut out: Vec<FilteredRow> = Vec::new();
    // Install tenant/auth thread-locals for the duration of the scan.
    let _scope_guard = AskScopeGuard::install(scope);
    let snap_ctx = crate::runtime::impl_core::capture_current_snapshot();
    let mut rls_cache: HashMap<String, Option<crate::storage::query::ast::Filter>> = HashMap::new();
    // Labeled loop so the global row cap can bail out of both loops at once.
    'collection: for collection in &candidates.collections {
        if let Some(set) = visible {
            if !set.contains(collection) {
                continue;
            }
        }
        let Some(manager) = store.get_collection(collection) else {
            continue;
        };
        // Columns suggested by schema matching are searched first.
        let hint_columns: &[String] = candidates
            .columns_by_collection
            .get(collection)
            .map(|v| v.as_slice())
            .unwrap_or(&[]);
        for entity in manager.query_all(|_| true) {
            if !ask_entity_allowed(
                runtime,
                scope,
                collection,
                &entity,
                snap_ctx.as_ref(),
                &mut rls_cache,
            ) {
                continue;
            }
            if let Some(hit) = literal_match_in_entity(&entity, &tokens.literals, hint_columns) {
                out.push(FilteredRow {
                    collection: collection.clone(),
                    entity,
                    matched_literal: hit.0,
                    matched_column: hit.1,
                });
                if out.len() >= row_cap {
                    break 'collection;
                }
            }
        }
    }
    out
}
/// Entity-level authorization for ASK results: the collection must be visible
/// to the scope, and the runtime's search RLS check must pass.
fn ask_entity_allowed(
    runtime: &RedDBRuntime,
    scope: &EffectiveScope,
    collection: &str,
    entity: &UnifiedEntity,
    snap_ctx: Option<&crate::runtime::impl_core::SnapshotContext>,
    rls_cache: &mut HashMap<String, Option<crate::storage::query::ast::Filter>>,
) -> bool {
    // No visibility set means every collection is visible.
    let collection_visible = match scope.visible_collections() {
        Some(set) => set.contains(collection),
        None => true,
    };
    collection_visible && runtime.search_entity_allowed(collection, entity, snap_ctx, rls_cache)
}
/// RAII guard that installs the scope's tenant and auth identity into the
/// runtime's thread-local context and restores the previous values on drop.
struct AskScopeGuard {
    // Thread-local tenant in effect before installation.
    prev_tenant: Option<String>,
    // Thread-local (user, role) identity in effect before installation.
    prev_auth: Option<(String, crate::auth::Role)>,
}
impl AskScopeGuard {
    /// Captures the current thread-local tenant/identity, then overwrites
    /// both with the scope's values (clearing them when the scope has none).
    fn install(scope: &EffectiveScope) -> Self {
        let prev_tenant = crate::runtime::impl_core::current_tenant();
        let prev_auth = crate::runtime::impl_core::current_auth_identity();
        match scope.effective_scope() {
            Some(tenant) => crate::runtime::impl_core::set_current_tenant(tenant.to_string()),
            None => crate::runtime::impl_core::clear_current_tenant(),
        }
        match scope.identity() {
            Some((user, role)) => {
                crate::runtime::impl_core::set_current_auth_identity(user.to_string(), role)
            }
            None => crate::runtime::impl_core::clear_current_auth_identity(),
        }
        Self {
            prev_tenant,
            prev_auth,
        }
    }
}
impl Drop for AskScopeGuard {
    /// Restores the captured tenant/identity, clearing the thread-locals when
    /// nothing was set before installation.
    fn drop(&mut self) {
        match self.prev_tenant.take() {
            Some(tenant) => crate::runtime::impl_core::set_current_tenant(tenant),
            None => crate::runtime::impl_core::clear_current_tenant(),
        }
        match self.prev_auth.take() {
            Some((user, role)) => crate::runtime::impl_core::set_current_auth_identity(user, role),
            None => crate::runtime::impl_core::clear_current_auth_identity(),
        }
    }
}
/// Scans a row entity for the first occurrence of any literal token.
///
/// Hinted columns are checked first (and reported with their column name);
/// the remaining fields are scanned afterwards, skipping hints already tried.
/// Non-row entities never match.
fn literal_match_in_entity(
    entity: &UnifiedEntity,
    literals: &[String],
    hint_columns: &[String],
) -> Option<(String, Option<String>)> {
    let EntityData::Row(row) = &entity.data else {
        return None;
    };
    // Pass 1: prioritised columns suggested by schema matching.
    for column in hint_columns {
        let Some(value) = row.get_field(column) else {
            continue;
        };
        if let Some(lit) = first_literal_in_value(value, literals) {
            return Some((lit, Some(column.clone())));
        }
    }
    // Pass 2: every other field.
    for (name, value) in row.iter_fields() {
        if hint_columns.iter().any(|hint| hint == name) {
            continue;
        }
        if let Some(lit) = first_literal_in_value(value, literals) {
            return Some((lit, Some(name.to_string())));
        }
    }
    None
}
/// Renders a scalar `Value` to text and returns the first literal that occurs
/// in it as a substring; non-scalar values never match.
fn first_literal_in_value(value: &Value, literals: &[String]) -> Option<String> {
    let rendered = match value {
        Value::Text(s) => s.to_string(),
        Value::Integer(i) => i.to_string(),
        Value::Float(f) => f.to_string(),
        Value::Boolean(b) => b.to_string(),
        Value::Json(j) => String::from_utf8_lossy(j).to_string(),
        _ => return None,
    };
    literals
        .iter()
        .find(|lit| rendered.contains(lit.as_str()))
        .cloned()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn extract_tokens_splits_keywords_and_literals() {
let tokens = extract_tokens("quais as novidades sobre o passport FDD-12313?");
assert!(tokens.keywords.contains(&"novidades".to_string()));
assert!(tokens.keywords.contains(&"passport".to_string()));
assert!(tokens.literals.contains(&"FDD-12313".to_string()));
assert!(!tokens.is_empty());
}
#[test]
fn extract_tokens_returns_empty_for_punctuation_only() {
let tokens = extract_tokens("??? ...");
assert!(tokens.is_empty());
}
#[test]
fn extract_tokens_long_digit_run_is_a_literal() {
let tokens = extract_tokens("show order 987654321 details");
assert!(tokens.literals.contains(&"987654321".to_string()));
assert!(tokens.keywords.contains(&"order".to_string()));
assert!(tokens.keywords.contains(&"details".to_string()));
assert!(tokens.keywords.contains(&"show".to_string()));
}
#[test]
fn extract_tokens_short_uppercase_word_is_keyword_not_literal() {
let tokens = extract_tokens("USA exports report");
assert!(tokens.keywords.contains(&"usa".to_string()));
assert!(tokens.literals.is_empty());
}
#[test]
fn extract_tokens_dedups() {
let tokens = extract_tokens("passport passport FDD-1 FDD-1");
assert_eq!(
tokens.keywords.iter().filter(|k| *k == "passport").count(),
1
);
assert_eq!(tokens.literals.iter().filter(|l| *l == "FDD-1").count(), 1);
}
#[test]
fn first_literal_in_value_substring_match() {
let lit = first_literal_in_value(
&Value::text("issue FDD-12313 reported by user"),
&["FDD-12313".to_string()],
);
assert_eq!(lit.as_deref(), Some("FDD-12313"));
}
#[test]
fn first_literal_in_value_no_match_returns_none() {
assert!(
first_literal_in_value(&Value::text("nothing here"), &["FDD-12313".to_string()],)
.is_none()
);
}
use crate::api::RedDBOptions;
use crate::auth::Role;
use crate::runtime::statement_frame::EffectiveScope;
use crate::runtime::RedDBRuntime;
use crate::storage::schema::Value;
use crate::storage::transaction::snapshot::Snapshot;
use crate::storage::unified::entity::{
EntityData, EntityId, EntityKind, RowData, UnifiedEntity,
};
use std::sync::Arc;
fn make_scope(visible: HashSet<String>) -> EffectiveScope {
EffectiveScope {
tenant: Some("acme".to_string()),
identity: Some(("alice".to_string(), Role::Read)),
snapshot: Snapshot {
xid: 0,
in_progress: HashSet::new(),
},
visible_collections: Some(visible),
}
}
fn fresh_runtime() -> RedDBRuntime {
RedDBRuntime::with_options(RedDBOptions::in_memory()).expect("runtime boots")
}
fn test_row(collection: &str, id: u64) -> FilteredRow {
FilteredRow {
collection: collection.to_string(),
entity: UnifiedEntity::new(
EntityId::new(id),
EntityKind::TableRow {
table: Arc::from(collection),
row_id: id,
},
EntityData::Row(RowData {
columns: Vec::new(),
named: Some(
[("body".to_string(), Value::text("ticket FDD-1".to_string()))]
.into_iter()
.collect(),
),
schema: None,
}),
),
matched_literal: "FDD-1".to_string(),
matched_column: Some("body".to_string()),
}
}
fn test_graph_hit(collection: &str, id: u64, score: f32, depth: usize) -> GraphHit {
GraphHit {
collection: collection.to_string(),
entity_id: id,
score,
depth,
kind: GraphHitKind::Node,
}
}
fn test_text_hit(collection: &str, id: u64, score: f32) -> TextHit {
TextHit {
collection: collection.to_string(),
entity_id: id,
score,
}
}
struct TenantGuard;
impl TenantGuard {
fn set(tenant: &str) -> Self {
crate::runtime::impl_core::set_current_tenant(tenant.to_string());
Self
}
}
impl Drop for TenantGuard {
fn drop(&mut self) {
crate::runtime::impl_core::clear_current_tenant();
}
}
fn row_text<'a>(entity: &'a UnifiedEntity, field: &str) -> Option<&'a str> {
let row = entity.data.as_row()?;
match row.get_field(field)? {
Value::Text(value) => Some(value.as_ref()),
_ => None,
}
}
#[test]
fn fused_source_order_uses_rrf_and_total_limit() {
let ctx = AskContext {
source_limit: 2,
filtered_rows: vec![test_row("incidents", 2), test_row("incidents", 1)],
vector_hits: vec![
VectorHit {
collection: "incidents".to_string(),
entity_id: 1,
score: 0.91,
},
VectorHit {
collection: "docs".to_string(),
entity_id: 9,
score: 0.88,
},
],
..AskContext::default()
};
let order = fused_source_order(&ctx);
assert_eq!(
order,
vec![
FusedSourceRef::FilteredRow(1),
FusedSourceRef::FilteredRow(0)
]
);
}
#[test]
fn fused_source_order_includes_graph_bucket() {
let ctx = AskContext {
source_limit: 4,
filtered_rows: vec![test_row("incidents", 1)],
text_hits: vec![test_text_hit("articles", 5, 1.2)],
vector_hits: vec![
VectorHit {
collection: "incidents".to_string(),
entity_id: 1,
score: 0.91,
},
VectorHit {
collection: "docs".to_string(),
entity_id: 9,
score: 0.88,
},
],
graph_hits: vec![test_graph_hit("topology", 7, 0.80, 1)],
..AskContext::default()
};
let order = fused_source_order(&ctx);
assert_eq!(
order,
vec![
FusedSourceRef::FilteredRow(0),
FusedSourceRef::TextHit(0),
FusedSourceRef::GraphHit(0),
FusedSourceRef::VectorHit(1),
]
);
}
#[test]
fn text_search_bm25_scoped_ranks_specific_document_first() {
let rt = fresh_runtime();
rt.execute_query("CREATE TABLE docs (body TEXT) WITH CONTEXT INDEX ON (body)")
.expect("create docs");
rt.execute_query("INSERT INTO docs (body) VALUES ('passport renewal')")
.expect("insert specific doc");
rt.execute_query(
"INSERT INTO docs (body) VALUES ('passport renewal travel hotel airline visa luggage itinerary')",
)
.expect("insert broad doc");
let scope = make_scope(["docs".to_string()].into_iter().collect());
let candidates = CandidateCollections {
collections: vec!["docs".to_string()],
columns_by_collection: HashMap::new(),
};
let hits = text_search_bm25_scoped(&rt, &scope, "passport renewal", &candidates, 10);
assert_eq!(hits.len(), 2);
assert!(
hits[0].score > hits[1].score,
"BM25 text bucket should prefer the shorter exact match: {hits:?}"
);
}
#[test]
fn text_search_bm25_scoped_filters_rls_denied_hits() {
let rt = fresh_runtime();
rt.execute_query(
"CREATE TABLE docs (id INT, tenant_id TEXT, body TEXT) WITH CONTEXT INDEX ON (body)",
)
.expect("create docs");
rt.execute_query(
"INSERT INTO docs (id, tenant_id, body) VALUES \
(1, 'acme', 'shared launch plan'), \
(2, 'globex', 'shared launch plan')",
)
.expect("seed docs");
rt.execute_query(
"CREATE POLICY tenant_only ON docs FOR SELECT USING (tenant_id = CURRENT_TENANT())",
)
.expect("create policy");
rt.execute_query("ALTER TABLE docs ENABLE ROW LEVEL SECURITY")
.expect("enable rls");
let _tenant = TenantGuard::set("acme");
let scope = make_scope(["docs".to_string()].into_iter().collect());
let candidates = CandidateCollections {
collections: vec!["docs".to_string()],
columns_by_collection: HashMap::new(),
};
let hits = text_search_bm25_scoped(&rt, &scope, "shared launch", &candidates, 10);
assert_eq!(hits.len(), 1, "RLS should hide the globex hit: {hits:?}");
let entity = rt
.inner
.db
.store()
.get(
"docs",
crate::storage::unified::entity::EntityId::new(hits[0].entity_id),
)
.expect("hit entity exists");
assert_eq!(row_text(&entity, "tenant_id"), Some("acme"));
}
#[test]
fn execute_pipeline_retrieves_known_good_bm25_source_order() {
let rt = fresh_runtime();
rt.execute_query("CREATE TABLE docs (body TEXT) WITH CONTEXT INDEX ON (body)")
.expect("create docs");
rt.execute_query("INSERT INTO docs (body) VALUES ('passport renewal')")
.expect("insert specific doc");
rt.execute_query(
"INSERT INTO docs (body) VALUES ('passport renewal travel hotel airline visa luggage itinerary')",
)
.expect("insert broad doc");
rt.schema_vocabulary_apply(
crate::runtime::schema_vocabulary::DdlEvent::CreateCollection {
collection: "docs".to_string(),
columns: vec!["body".into()],
type_tags: Vec::new(),
description: None,
},
);
let scope = make_scope(["docs".to_string()].into_iter().collect());
let ctx = AskPipeline::execute_with_limit_and_min_score(
&rt,
&scope,
"body passport renewal",
2,
None,
Some(1),
)
.expect("pipeline executes");
assert_eq!(ctx.text_hits.len(), 2);
assert!(
ctx.text_hits[0].score > ctx.text_hits[1].score,
"BM25 source order should prefer the shorter exact match: {:?}",
ctx.text_hits
);
assert!(matches!(
fused_source_order(&ctx).first(),
Some(FusedSourceRef::TextHit(0))
));
}
#[test]
fn graph_search_scoped_honors_depth() {
    let rt = fresh_runtime();
    // Seed a two-hop chain: alice -knows-> bob -knows-> carol.
    let seed: [(&str, &str); 5] = [
        (
            "INSERT INTO tales NODE (label, name) VALUES ('alice', 'Alice')",
            "insert alice",
        ),
        (
            "INSERT INTO tales NODE (label, name) VALUES ('bob', 'Bob')",
            "insert bob",
        ),
        (
            "INSERT INTO tales NODE (label, name) VALUES ('carol', 'Carol')",
            "insert carol",
        ),
        (
            "INSERT INTO tales EDGE (label, from, to) VALUES ('knows', 'alice', 'bob')",
            "insert alice-bob edge",
        ),
        (
            "INSERT INTO tales EDGE (label, from, to) VALUES ('knows', 'bob', 'carol')",
            "insert bob-carol edge",
        ),
    ];
    for (sql, what) in seed {
        rt.execute_query(sql).expect(what);
    }
    let scope = make_scope(["tales".to_string()].into_iter().collect());
    let candidates = CandidateCollections {
        collections: vec!["tales".to_string()],
        columns_by_collection: HashMap::new(),
    };
    // Same query at two depth caps: depth 1 must stop at bob, depth 2 reaches carol.
    let depth1 = graph_search_scoped(&rt, &scope, "alice", &candidates, 10, None, Some(1));
    let depth2 = graph_search_scoped(&rt, &scope, "alice", &candidates, 10, None, Some(2));
    assert!(
        depth1.iter().all(|hit| hit.depth <= 1),
        "DEPTH 1 returned hits beyond one hop: {depth1:?}"
    );
    assert!(
        depth2.iter().any(|hit| hit.depth == 2),
        "DEPTH 2 should include the second-hop graph hit: {depth2:?}"
    );
}
#[test]
fn filter_values_filters_rls_denied_rows() {
    let rt = fresh_runtime();
    rt.execute_query("CREATE TABLE docs (id INT, tenant_id TEXT, body TEXT)")
        .expect("create docs");
    // Two tenants hold the same incident literal; only one may be visible.
    rt.execute_query(
        "INSERT INTO docs (id, tenant_id, body) VALUES \
         (1, 'acme', 'incident FDD-12313'), \
         (2, 'globex', 'incident FDD-12313')",
    )
    .expect("seed docs");
    rt.execute_query(
        "CREATE POLICY tenant_only ON docs FOR SELECT USING (tenant_id = CURRENT_TENANT())",
    )
    .expect("create policy");
    rt.execute_query("ALTER TABLE docs ENABLE ROW LEVEL SECURITY")
        .expect("enable rls");
    // Pin the current tenant to 'acme' for the duration of the test.
    let _tenant = TenantGuard::set("acme");
    let scope = make_scope(["docs".to_string()].into_iter().collect());
    let mut columns_by_collection = HashMap::new();
    columns_by_collection.insert("docs".to_string(), vec!["body".to_string()]);
    let candidates = CandidateCollections {
        collections: vec!["docs".to_string()],
        columns_by_collection,
    };
    let tokens = TokenSet {
        keywords: vec!["incident".to_string()],
        literals: vec!["FDD-12313".to_string()],
    };
    let rows = filter_values(&rt, &scope, &candidates, &tokens, 10);
    assert_eq!(rows.len(), 1, "RLS should hide the globex row: {rows:?}");
    assert_eq!(row_text(&rows[0].entity, "tenant_id"), Some("acme"));
}
#[test]
fn execute_refuses_empty_token_set() {
    let rt = fresh_runtime();
    let scope = make_scope(HashSet::new());
    // Pure punctuation yields no keywords or literals, so Stage 1 must refuse.
    let err = AskPipeline::execute(&rt, &scope, "??? ...")
        .expect_err("empty token set must short-circuit");
    let msg = err.to_string();
    assert!(
        msg.contains("yielded no usable tokens"),
        "expected structured empty-token error, got: {msg}"
    );
}
#[test]
fn match_schema_intersects_with_visible_set() {
    let rt = fresh_runtime();
    // Both collections carry a "passport" column, but only "travel" is visible.
    for (collection, columns) in [
        ("travel", vec!["id".into(), "passport".into()]),
        ("secrets", vec!["passport".into()]),
    ] {
        rt.schema_vocabulary_apply(
            crate::runtime::schema_vocabulary::DdlEvent::CreateCollection {
                collection: collection.to_string(),
                columns,
                type_tags: Vec::new(),
                description: None,
            },
        );
    }
    let visible: HashSet<String> = ["travel".to_string()].into_iter().collect();
    let scope = make_scope(visible);
    let tokens = TokenSet {
        keywords: vec!["passport".to_string()],
        literals: Vec::new(),
    };
    let candidates = match_schema(&rt, &scope, &tokens).expect("ok");
    assert_eq!(candidates.collections, vec!["travel".to_string()]);
    assert!(!candidates.collections.contains(&"secrets".to_string()));
    // The visible collection keeps its column hints for the matched keyword.
    let cols = candidates
        .columns_by_collection
        .get("travel")
        .expect("hint columns");
    assert!(cols.contains(&"passport".to_string()));
}
use proptest::prelude::*;
// Proptest strategy: the regex literal generates collection names of 1-4
// lowercase ASCII letters (short names make overlaps between generated
// sets plausible).
fn arb_collection() -> impl Strategy<Value = String> {
    "[a-z]{1,4}"
}
// Proptest strategy for the visible-collection set: 0 to 5 distinct names
// drawn from `arb_collection`.
fn arb_visible() -> impl Strategy<Value = HashSet<String>> {
    prop::collection::hash_set(arb_collection(), 0..6)
}
// Proptest strategy for candidate collection names: 0 to 7 entries,
// duplicates allowed (it is a Vec, not a set).
fn arb_candidates() -> impl Strategy<Value = Vec<String>> {
    prop::collection::vec(arb_collection(), 0..8)
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(256))]
    // Property: Stage 4 (`filter_values`) must never return a row whose
    // collection is outside the caller-visible set, regardless of which
    // candidate collections or synthetic literals are supplied.
    #[test]
    fn stage4_rows_subset_of_visible_collections(
        visible in arb_visible(),
        candidate_names in arb_candidates(),
        literal_count in 0usize..3,
    ) {
        // One shared runtime across all 256 cases; building a fresh runtime
        // per case would dominate the test's wall-clock time.
        let rt = PROPTEST_RUNTIME.get_or_init(fresh_runtime);
        let candidates = CandidateCollections {
            collections: candidate_names,
            columns_by_collection: HashMap::new(),
        };
        // Synthetic literals "ID-0".."ID-2" exercise the literal-match path.
        let literals: Vec<String> = (0..literal_count)
            .map(|i| format!("ID-{i}"))
            .collect();
        let tokens = TokenSet {
            keywords: vec!["passport".to_string()],
            literals,
        };
        let scope = make_scope(visible.clone());
        let rows = filter_values(rt, &scope, &candidates, &tokens, DEFAULT_ROW_CAP);
        for row in &rows {
            prop_assert!(
                visible.contains(&row.collection),
                "Stage 4 leaked row collection={} not in visible={:?}",
                row.collection, visible
            );
        }
    }
}
// Lazily-initialized runtime shared by the proptest cases above; initialized
// once via `get_or_init(fresh_runtime)` on first use.
static PROPTEST_RUNTIME: std::sync::OnceLock<RedDBRuntime> = std::sync::OnceLock::new();
#[test]
fn integration_passport_fdd_12313_funnels_through_four_stages() {
    let rt = fresh_runtime();
    rt.execute_query("CREATE TABLE travel (id INT, passport TEXT, notes TEXT)")
        .expect("CREATE TABLE travel");
    rt.execute_query(
        "INSERT INTO travel (id, passport, notes) VALUES \
         (1, 'BR-001', 'unrelated note'), \
         (2, 'PT-002', 'incident FDD-12313 escalated'), \
         (3, 'US-003', 'standard renewal')",
    )
    .expect("seed rows");
    // A second, non-visible table containing the same literal: it must never
    // surface in the pipeline output.
    rt.execute_query("CREATE TABLE secrets (id INT, passport TEXT)")
        .expect("CREATE TABLE secrets");
    rt.execute_query("INSERT INTO secrets (id, passport) VALUES (99, 'FDD-12313')")
        .expect("seed secrets");
    let visible: HashSet<String> = ["travel".to_string()].into_iter().collect();
    let scope = make_scope(visible);
    let ctx = AskPipeline::execute(
        &rt,
        &scope,
        "quais as novidades sobre o passport FDD-12313?",
    )
    .expect("pipeline runs");
    // Stage 1: token extraction keeps both the keyword and the ID literal.
    assert!(ctx.tokens.keywords.contains(&"passport".to_string()));
    assert!(ctx.tokens.literals.contains(&"FDD-12313".to_string()));
    // Stage 2: schema matching narrows candidates to the visible table only.
    assert_eq!(ctx.candidates.collections, vec!["travel".to_string()]);
    let _ = &ctx.vector_hits;
    // Stage 4: the matching travel row surfaces with the literal attached.
    assert!(
        ctx.filtered_rows
            .iter()
            .any(|r| r.collection == "travel" && r.matched_literal == "FDD-12313"),
        "expected travel row with FDD-12313 match, got: {:?}",
        ctx.filtered_rows
    );
    for row in &ctx.filtered_rows {
        assert_ne!(
            row.collection, "secrets",
            "secrets row leaked into Stage 4 output"
        );
    }
    // Touch every stage-timing field; the previous version skipped
    // `text_us` and `graph_us` even though StageTimings carries all six.
    let _ = ctx.timings.extract_us
        + ctx.timings.schema_us
        + ctx.timings.text_us
        + ctx.timings.vector_us
        + ctx.timings.graph_us
        + ctx.timings.filter_us;
}
/// Test helper: persist a single string-valued entry into the runtime's
/// config tree under `key`.
fn write_config(rt: &RedDBRuntime, key: &str, value: &str) {
    rt.inner
        .db
        .store()
        .set_config_tree(key, &crate::serde_json::Value::String(value.to_owned()));
}
#[test]
fn routed_default_backend_runs_heuristic() {
    let rt = fresh_runtime();
    let scope = make_scope(HashSet::new());
    // No ai.ner.backend config written, so the routed extractor takes the
    // default heuristic path.
    let tokens = extract_tokens_routed(&rt, &scope, "passport FDD-12313")
        .expect("heuristic path is infallible");
    assert!(tokens.keywords.iter().any(|k| k == "passport"));
    assert!(tokens.literals.iter().any(|l| l == "FDD-12313"));
}
#[tokio::test(flavor = "multi_thread")]
async fn routed_llm_auth_denied_uses_heuristic_fallback() {
    let rt = fresh_runtime();
    write_config(&rt, "ai.ner.backend", "llm");
    write_config(&rt, "ai.ner.fallback", "use_heuristic");
    let scope = make_scope(HashSet::new());
    // The routed extractor blocks, so hop off the async executor first.
    let handle = tokio::task::spawn_blocking(move || {
        extract_tokens_routed(&rt, &scope, "passport FDD-12313")
    });
    let tokens = handle
        .await
        .unwrap()
        .expect("fallback policy keeps the call OK");
    assert!(tokens.keywords.iter().any(|k| k == "passport"));
    assert!(tokens.literals.iter().any(|l| l == "FDD-12313"));
}
#[tokio::test(flavor = "multi_thread")]
async fn routed_llm_auth_denied_empty_on_fail() {
    let rt = fresh_runtime();
    write_config(&rt, "ai.ner.backend", "llm");
    write_config(&rt, "ai.ner.fallback", "empty_on_fail");
    let scope = make_scope(HashSet::new());
    // Blocking extraction runs off the async executor.
    let handle = tokio::task::spawn_blocking(move || {
        extract_tokens_routed(&rt, &scope, "passport FDD-12313")
    });
    let tokens = handle
        .await
        .unwrap()
        .expect("empty_on_fail returns Ok with empty TokenSet");
    assert!(tokens.is_empty(), "expected empty TokenSet, got {tokens:?}");
}
#[tokio::test(flavor = "multi_thread")]
async fn routed_llm_auth_denied_propagate_returns_error() {
    let rt = fresh_runtime();
    write_config(&rt, "ai.ner.backend", "llm");
    write_config(&rt, "ai.ner.fallback", "propagate");
    let scope = make_scope(HashSet::new());
    // Blocking extraction runs off the async executor.
    let handle = tokio::task::spawn_blocking(move || {
        extract_tokens_routed(&rt, &scope, "passport FDD-12313")
    });
    let err = handle
        .await
        .unwrap()
        .expect_err("propagate must surface the error");
    let msg = err.to_string();
    assert!(
        msg.contains("propagate") || msg.contains("ai.ner.backend"),
        "expected propagate error message, got: {msg}"
    );
}
#[tokio::test(flavor = "multi_thread")]
async fn execute_with_llm_backend_falls_back_and_completes_pipeline() {
    let rt = fresh_runtime();
    write_config(&rt, "ai.ner.backend", "llm");
    rt.execute_query("CREATE TABLE travel (id INT, passport TEXT, notes TEXT)")
        .expect("CREATE TABLE travel");
    rt.execute_query(
        "INSERT INTO travel (id, passport, notes) VALUES \
         (2, 'PT-002', 'incident FDD-12313 escalated')",
    )
    .expect("seed rows");
    let visible: HashSet<String> = ["travel".to_string()].into_iter().collect();
    let scope = make_scope(visible);
    // Run the blocking pipeline off the async executor.
    let handle = tokio::task::spawn_blocking(move || {
        AskPipeline::execute(&rt, &scope, "passport FDD-12313")
    });
    let ctx = handle.await.unwrap().expect("pipeline runs");
    assert!(ctx.tokens.keywords.iter().any(|k| k == "passport"));
    assert!(ctx.tokens.literals.iter().any(|l| l == "FDD-12313"));
    assert_eq!(ctx.candidates.collections, vec!["travel".to_string()]);
    assert!(
        ctx.filtered_rows
            .iter()
            .any(|r| r.matched_literal == "FDD-12313"),
        "Stage 4 still runs after Stage 1 fallback"
    );
}
}