use serde::Serialize;
use crate::task_state::TaskOperatingState;
/// How context should be gathered for a user turn, as selected by
/// `decide_retrieval_strategy`.
///
/// NOTE(review): the derived `Serialize` has no `rename_all`, so serde emits
/// the variant names as written (e.g. `"CacheOnly"`), which differs from the
/// snake_case names returned by `as_str` — confirm consumers expect this.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum RetrievalStrategy {
/// Selected for acknowledgements and first-turn short greetings.
CacheOnly,
/// Default when no more specific rule applies.
IndexedRetrieval,
/// Selected for a `CurrentEvents` intent or a freshness phrase in the prompt.
LiveDiscovery,
/// Selected whenever the prompt contains a verification phrase; highest priority.
DirectVerification,
}
impl RetrievalStrategy {
pub fn as_str(&self) -> &'static str {
match self {
Self::CacheOnly => "cache_only",
Self::IndexedRetrieval => "indexed_retrieval",
Self::LiveDiscovery => "live_discovery",
Self::DirectVerification => "direct_verification",
}
}
}
/// Result of `decide_retrieval_strategy`: the chosen strategy together with
/// the evidence that led to it.
#[derive(Debug, Clone, Serialize)]
pub struct RetrievalDecision {
/// Strategy selected for this turn.
pub strategy: RetrievalStrategy,
/// Short tags explaining the choice, such as `marker:verify`,
/// `intent:CurrentEvents`, `intent:Acknowledgement`,
/// `first_turn_greeting` or `default`.
pub signals: Vec<String>,
/// Heuristic confidence attached to the rule that fired. The values
/// `decide_retrieval_strategy` assigns are fixed constants in 0.70..=0.90.
/// NOTE(review): nothing here indicates these are calibrated probabilities.
pub confidence: f64,
}
/// Prompt phrases suggesting the user wants up-to-date information.
///
/// A match selects `RetrievalStrategy::LiveDiscovery`, unless a verification
/// phrase or a `CurrentEvents` intent was found first. Entries must be
/// lowercase ASCII because the prompt is ASCII-lowercased before matching.
const FRESHNESS_MARKERS: &[&str] = &[
"latest",
"current",
"today",
"right now",
"breaking",
"recent news",
"what's happening",
"what is happening",
];
/// Prompt phrases indicating a request to check a claim.
///
/// A match selects `RetrievalStrategy::DirectVerification`, which takes
/// priority over every other rule. Entries must be lowercase ASCII because
/// the prompt is ASCII-lowercased before matching.
///
/// NOTE(review): "actually" is a common filler word ("I actually need ...")
/// and is likely to produce false positives — consider dropping it.
const VERIFICATION_MARKERS: &[&str] = &[
"verify",
"confirm",
"is it true",
"fact check",
"fact-check",
"actually",
"really true",
"double check",
"double-check",
"check if",
];
/// Chooses how context should be retrieved for the current user turn.
///
/// The rules below are evaluated in priority order; the first one that
/// fires determines the result:
///
/// 1. A verification phrase in the prompt selects
///    [`RetrievalStrategy::DirectVerification`] (confidence 0.85).
/// 2. A `CurrentEvents` intent selects [`RetrievalStrategy::LiveDiscovery`] (0.90).
/// 3. A freshness phrase in the prompt selects [`RetrievalStrategy::LiveDiscovery`] (0.75).
/// 4. An `Acknowledgement` intent selects [`RetrievalStrategy::CacheOnly`] (0.90).
/// 5. A short greeting on the first turn selects [`RetrievalStrategy::CacheOnly`] (0.80).
/// 6. Anything else falls back to [`RetrievalStrategy::IndexedRetrieval`] (0.70).
///
/// Intent names are compared ASCII-case-insensitively. Prompt phrases are
/// matched case-insensitively and only where they begin a word. As a result,
/// "concurrent" does not trigger the `current` marker and "factually" does
/// not trigger `actually`, while inflected forms such as "verifying" or
/// "currently" still match. Every phrase that matched for the winning rule
/// is recorded in `signals` as `marker:<phrase>`.
///
/// `_task_state` is currently unused.
pub fn decide_retrieval_strategy(
    intent_names: &[String],
    _task_state: &TaskOperatingState,
    user_prompt: &str,
    is_first_turn: bool,
) -> RetrievalDecision {
    // Fold typographic apostrophes into ASCII ones so prompts such as
    // "what’s happening" or "what’s up" match the ASCII marker/greeting lists.
    let lower = user_prompt.to_ascii_lowercase().replace('\u{2019}', "'");
    let has_intent = |name: &str| intent_names.iter().any(|i| i.eq_ignore_ascii_case(name));

    let verification_signals = matched_markers(&lower, VERIFICATION_MARKERS);
    if !verification_signals.is_empty() {
        return RetrievalDecision {
            strategy: RetrievalStrategy::DirectVerification,
            signals: verification_signals,
            confidence: 0.85,
        };
    }

    if has_intent("currentevents") {
        return RetrievalDecision {
            strategy: RetrievalStrategy::LiveDiscovery,
            signals: vec!["intent:CurrentEvents".into()],
            confidence: 0.90,
        };
    }

    let freshness_signals = matched_markers(&lower, FRESHNESS_MARKERS);
    if !freshness_signals.is_empty() {
        return RetrievalDecision {
            strategy: RetrievalStrategy::LiveDiscovery,
            signals: freshness_signals,
            confidence: 0.75,
        };
    }

    if has_intent("acknowledgement") {
        return RetrievalDecision {
            strategy: RetrievalStrategy::CacheOnly,
            signals: vec!["intent:Acknowledgement".into()],
            confidence: 0.90,
        };
    }

    if is_first_turn && is_short_greeting(&lower) {
        return RetrievalDecision {
            strategy: RetrievalStrategy::CacheOnly,
            signals: vec!["first_turn_greeting".into()],
            confidence: 0.80,
        };
    }

    RetrievalDecision {
        strategy: RetrievalStrategy::IndexedRetrieval,
        signals: vec!["default".into()],
        confidence: 0.70,
    }
}

/// Returns a `marker:<phrase>` signal for every entry of `markers` that
/// begins a word somewhere in `haystack`.
fn matched_markers(haystack: &str, markers: &[&str]) -> Vec<String> {
    markers
        .iter()
        .copied()
        .filter(|marker| contains_at_word_start(haystack, marker))
        .map(|marker| format!("marker:{marker}"))
        .collect()
}

/// True when `needle` occurs in `haystack` at a position that is not
/// preceded by an alphanumeric character (i.e. at the start of a word).
fn contains_at_word_start(haystack: &str, needle: &str) -> bool {
    haystack.match_indices(needle).any(|(idx, _)| {
        // `match_indices` yields char-boundary offsets, so slicing is safe.
        !matches!(haystack[..idx].chars().next_back(), Some(c) if c.is_alphanumeric())
    })
}
/// Returns `true` when `lower` (expected to be lowercased already) is a short
/// greeting such as "hello", "hi there", "hey!" or "good morning, team".
///
/// After trimming, the text must be at most 30 characters long. It must
/// either equal one of the known greetings or start with one immediately
/// followed by whitespace or one of `!`, `?`, `.`, `,`. Requiring a separator
/// keeps words that merely begin with a greeting ("history", "you", "hi-fi")
/// from matching.
fn is_short_greeting(lower: &str) -> bool {
    // Characters (besides whitespace) allowed directly after a greeting.
    const SEPARATORS: &[char] = &['!', '?', '.', ','];
    const GREETINGS: &[&str] = &[
        "hi",
        "hey",
        "hello",
        "yo",
        "sup",
        "howdy",
        "good morning",
        "good afternoon",
        "good evening",
        "what's up",
        "whats up",
        "hola",
        "greetings",
    ];

    let trimmed = lower.trim();
    // Count characters rather than bytes so non-ASCII text is not penalised.
    if trimmed.chars().count() > 30 {
        return false;
    }

    GREETINGS
        .iter()
        .any(|greeting| match trimmed.strip_prefix(*greeting) {
            Some(rest) => match rest.chars().next() {
                None => true,
                Some(c) => c.is_whitespace() || SEPARATORS.contains(&c),
            },
            None => false,
        })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::task_state::{TaskStateInput, synthesize};

    /// Neutral task-state input shared by every test case.
    fn base_input() -> TaskStateInput {
        TaskStateInput {
            user_content: "test message".into(),
            intents: vec![],
            authority: "SelfGenerated".into(),
            retrieval_metrics: None,
            tool_search_stats: None,
            mcp_tools_available: false,
            taskable_agent_count: 0,
            fit_agent_count: 0,
            fit_agent_names: vec![],
            enabled_skill_count: 0,
            matching_skill_count: 0,
            missing_skills: vec![],
            remaining_budget_tokens: 8000,
            provider_breaker_open: false,
            inference_mode: "standard".into(),
            decomposition_proposal: None,
            explicit_specialist_workflow: false,
            named_tool_match: false,
            recent_response_skeletons: vec![],
            recent_user_message_lengths: vec![],
            self_echo_fragments: vec![],
            declared_action: None,
            previous_turn_had_protocol_issues: false,
            normalization_retry_streak: 0,
        }
    }

    /// Synthesises a task state from `base_input` and runs the strategy
    /// function on `prompt` with the given intents and first-turn flag.
    fn decide(intents: &[&str], prompt: &str, is_first_turn: bool) -> RetrievalDecision {
        let input = base_input();
        let state = synthesize(&input);
        let intent_names: Vec<String> = intents.iter().map(|name| (*name).to_owned()).collect();
        decide_retrieval_strategy(&intent_names, &state, prompt, is_first_turn)
    }

    /// True when some recorded signal contains `needle` as a substring.
    fn signal_contains(decision: &RetrievalDecision, needle: &str) -> bool {
        decision.signals.iter().any(|signal| signal.contains(needle))
    }

    /// True when some recorded signal is exactly `expected`.
    fn signal_equals(decision: &RetrievalDecision, expected: &str) -> bool {
        decision.signals.iter().any(|signal| signal == expected)
    }

    #[test]
    fn current_events_intent_selects_live_discovery() {
        let decision = decide(&["CurrentEvents"], "what's going on in the world", false);
        assert_eq!(decision.strategy, RetrievalStrategy::LiveDiscovery);
        assert!(signal_contains(&decision, "CurrentEvents"));
        assert!(decision.confidence > 0.5);
    }

    #[test]
    fn verify_keyword_selects_direct_verification() {
        let decision = decide(&[], "can you verify that claim", false);
        assert_eq!(decision.strategy, RetrievalStrategy::DirectVerification);
        assert!(signal_contains(&decision, "verify"));
        assert!(decision.confidence > 0.5);
    }

    #[test]
    fn acknowledgement_selects_cache_only() {
        let decision = decide(&["Acknowledgement"], "ok got it", false);
        assert_eq!(decision.strategy, RetrievalStrategy::CacheOnly);
        assert!(signal_contains(&decision, "Acknowledgement"));
    }

    #[test]
    fn normal_prompt_selects_indexed_retrieval() {
        let decision = decide(&[], "tell me about the architecture of the system", false);
        assert_eq!(decision.strategy, RetrievalStrategy::IndexedRetrieval);
        assert!(signal_equals(&decision, "default"));
    }

    #[test]
    fn freshness_markers_select_live_discovery() {
        for marker in &["latest", "right now", "breaking", "what's happening"] {
            let prompt = format!("tell me the {marker} developments");
            let decision = decide(&[], &prompt, false);
            assert_eq!(
                decision.strategy,
                RetrievalStrategy::LiveDiscovery,
                "marker '{marker}' should select LiveDiscovery"
            );
        }
    }

    #[test]
    fn first_turn_greeting_selects_cache_only() {
        let decision = decide(&[], "hello", true);
        assert_eq!(decision.strategy, RetrievalStrategy::CacheOnly);
        assert!(signal_equals(&decision, "first_turn_greeting"));
    }

    #[test]
    fn first_turn_non_greeting_uses_indexed_retrieval() {
        let decision = decide(&[], "explain the theory of relativity", true);
        assert_eq!(decision.strategy, RetrievalStrategy::IndexedRetrieval);
    }

    #[test]
    fn verification_takes_priority_over_freshness() {
        let decision = decide(&[], "can you verify the latest claims", false);
        assert_eq!(decision.strategy, RetrievalStrategy::DirectVerification);
    }

    #[test]
    fn fact_check_selects_direct_verification() {
        let decision = decide(&[], "fact check this statement for me", false);
        assert_eq!(decision.strategy, RetrievalStrategy::DirectVerification);
        assert!(signal_contains(&decision, "fact check"));
    }
}