use serde_json::Value;
use super::chain::{ChainResult, HookChain};
use super::events::{HookEvent, RecallExpandQuery};
use super::executor::ExecutorRegistry;
/// Verdict of the `PreRecallExpand` hook chain for a single recall request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PreRecallOutcome {
    /// The chain let the request through untouched.
    Allow,
    /// The chain allowed the request and supplied a (possibly rewritten)
    /// query / namespace / k triple to use instead of the caller's.
    Modified {
        query: String,
        namespace: String,
        k: u32,
    },
    /// The chain refused the request with a reason and a numeric code.
    Denied {
        reason: String,
        code: i32,
    },
}

impl PreRecallOutcome {
    /// Query to execute: the rewritten query for `Modified`, otherwise a copy
    /// of `original` (including for `Denied`, so callers can still log it).
    #[must_use]
    pub fn query(&self, original: &str) -> String {
        match self {
            Self::Modified { query: rewritten, .. } => rewritten.clone(),
            Self::Allow | Self::Denied { .. } => original.to_owned(),
        }
    }

    /// Namespace to search: the rewritten namespace for `Modified`, otherwise
    /// a copy of `original`.
    #[must_use]
    pub fn namespace(&self, original: &str) -> String {
        match self {
            Self::Modified { namespace: rewritten, .. } => rewritten.clone(),
            Self::Allow | Self::Denied { .. } => original.to_owned(),
        }
    }

    /// Result count: the rewritten `k` for `Modified`, otherwise `original`.
    #[must_use]
    pub fn k(&self, original: u32) -> u32 {
        match self {
            Self::Modified { k: rewritten, .. } => *rewritten,
            Self::Allow | Self::Denied { .. } => original,
        }
    }

    /// `true` only for the `Denied` variant.
    #[must_use]
    pub fn is_denied(&self) -> bool {
        match self {
            Self::Denied { .. } => true,
            Self::Allow | Self::Modified { .. } => false,
        }
    }
}
/// Fires the `PreRecallExpand` hook chain for a recall request and maps the
/// chain's verdict onto a [`PreRecallOutcome`].
///
/// * An empty chain returns [`PreRecallOutcome::Allow`] immediately, without
///   building a payload or touching `registry`.
/// * `ChainResult::Allow` becomes [`PreRecallOutcome::Allow`].
/// * `ChainResult::ModifiedAllow` becomes [`PreRecallOutcome::Modified`]. The
///   delta's `content` replaces `query` and its `namespace` replaces
///   `namespace`. Its `priority`, when positive and convertible to `u32`,
///   replaces `k`. Anything the delta leaves unset falls back to the
///   caller's original value.
/// * `ChainResult::Deny` is passed through as [`PreRecallOutcome::Denied`].
/// * `ChainResult::AskUser` is degraded to `Allow` with a warning, because
///   operator prompts cannot be serviced on the recall hot path.
pub async fn apply_pre_recall_expand(
    query: &str,
    namespace: &str,
    k: u32,
    chain: &HookChain,
    registry: &mut ExecutorRegistry,
) -> PreRecallOutcome {
    if chain.hooks().is_empty() {
        return PreRecallOutcome::Allow;
    }
    let payload_struct = RecallExpandQuery {
        query: query.to_string(),
        namespace: namespace.to_string(),
        k,
    };
    // Serialization should not normally fail for this payload (assuming a
    // derived `Serialize`). If it ever does, the hooks would receive a bare
    // `null` with no trace of why. Keep firing the chain so deny-style hooks
    // still run, but make the degradation visible instead of silent.
    let payload = serde_json::to_value(&payload_struct).unwrap_or_else(|err| {
        tracing::warn!(
            error = %err,
            "hooks: failed to serialize pre_recall_expand payload; firing chain with null payload"
        );
        Value::Null
    });
    let result = chain
        .fire(HookEvent::PreRecallExpand, payload, registry)
        .await;
    match result {
        ChainResult::Allow => PreRecallOutcome::Allow,
        ChainResult::ModifiedAllow(delta) => {
            let new_query = delta.content.unwrap_or_else(|| query.to_string());
            let new_namespace = delta.namespace.unwrap_or_else(|| namespace.to_string());
            // `priority` carries the k override. Non-positive values, or values
            // that don't fit in a u32, keep the caller's original `k`.
            let new_k = match delta.priority {
                Some(p) if p > 0 => u32::try_from(p).unwrap_or(k),
                _ => k,
            };
            PreRecallOutcome::Modified {
                query: new_query,
                namespace: new_namespace,
                k: new_k,
            }
        }
        ChainResult::Deny { reason, code } => PreRecallOutcome::Denied { reason, code },
        ChainResult::AskUser { .. } => {
            tracing::warn!(
                "hooks: pre_recall_expand returned AskUser; degrading to Allow \
                 (operator prompts are incompatible with the recall hot path)"
            );
            PreRecallOutcome::Allow
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Accessor contract: `Allow` echoes the caller's triple back unchanged.
    #[test]
    fn outcome_allow_uses_original_triple() {
        let outcome = PreRecallOutcome::Allow;
        assert!(!outcome.is_denied());
        assert_eq!(
            (outcome.query("orig"), outcome.namespace("ns"), outcome.k(7)),
            ("orig".to_string(), "ns".to_string(), 7)
        );
    }

    // Accessor contract: `Modified` overrides every element of the triple.
    #[test]
    fn outcome_modified_returns_rewritten_triple() {
        let outcome = PreRecallOutcome::Modified {
            query: String::from("rewrite"),
            namespace: String::from("team/x"),
            k: 25,
        };
        assert!(!outcome.is_denied());
        assert_eq!(
            (outcome.query("orig"), outcome.namespace("ns"), outcome.k(7)),
            ("rewrite".to_string(), "team/x".to_string(), 25)
        );
    }

    // Accessor contract: `Denied` still yields the originals so they can be logged.
    #[test]
    fn outcome_denied_falls_back_to_original_for_logging() {
        let outcome = PreRecallOutcome::Denied {
            reason: String::from("blocked"),
            code: 451,
        };
        assert!(outcome.is_denied());
        assert_eq!(
            (outcome.query("orig"), outcome.namespace("ns"), outcome.k(7)),
            ("orig".to_string(), "ns".to_string(), 7)
        );
    }

    // With no hooks configured the function must short-circuit to `Allow`.
    #[tokio::test]
    async fn empty_chain_is_allow_fast_path() {
        let empty_chain = HookChain::new(vec![]);
        let mut registry = ExecutorRegistry::new();
        let outcome =
            apply_pre_recall_expand("hello", "default", 10, &empty_chain, &mut registry).await;
        assert_eq!(PreRecallOutcome::Allow, outcome);
    }
}