use async_trait::async_trait;
use super::super::RetrievalContext;
use super::super::types::{NavigationDecision, QueryComplexity};
use super::r#trait::{NodeEvaluation, RetrievalStrategy, StrategyCapabilities};
use crate::config::StrategyConfig;
use crate::document::{DocumentTree, NodeId};
#[async_trait]
/// Abstraction over an embedding backend used by [`SemanticStrategy`].
///
/// Implementations must be thread-safe (`Send + Sync`) since strategies are
/// shared across async tasks.
pub trait EmbeddingModel: Send + Sync {
    /// Embeds a single text into a vector of length [`Self::dimension`].
    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError>;
    /// Embeds many texts in one call; expected to return one vector per input,
    /// in order. NOTE(review): callers assume `output.len() == texts.len()` —
    /// confirm implementations uphold this.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError>;
    /// Dimensionality of the vectors produced by this model.
    fn dimension(&self) -> usize;
}
/// Errors produced by an [`EmbeddingModel`] implementation.
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingError {
    /// The backend failed while computing an embedding (e.g. model or I/O failure).
    #[error("Failed to generate embedding: {0}")]
    GenerationFailed(String),
    /// The input text was rejected before embedding (e.g. empty or too long).
    #[error("Invalid input: {0}")]
    InvalidInput(String),
}
/// Retrieval strategy that scores document-tree nodes by cosine similarity
/// between the query embedding and each node's embedding text.
pub struct SemanticStrategy {
    // Embedding backend used for both queries and node texts.
    model: Box<dyn EmbeddingModel>,
    // Whether node embeddings should be cached between evaluations.
    // NOTE(review): set by `with_cache` but not consulted anywhere in this
    // file — confirm it is read elsewhere or remove.
    cache_embeddings: bool,
    // Mid threshold: above it a leaf is an answer, an inner node is explored.
    similarity_threshold: f32,
    // Above this threshold a node is accepted as the answer outright.
    high_similarity_threshold: f32,
    // Below this threshold a node is skipped; between low and mid it is explored.
    low_similarity_threshold: f32,
}
impl SemanticStrategy {
    /// Creates a strategy with thresholds from [`StrategyConfig::default`].
    pub fn new(model: Box<dyn EmbeddingModel>) -> Self {
        Self::with_config(model, &StrategyConfig::default())
    }

    /// Creates a strategy with thresholds taken from `config`.
    /// Embedding caching defaults to enabled.
    pub fn with_config(model: Box<dyn EmbeddingModel>, config: &StrategyConfig) -> Self {
        Self {
            model,
            cache_embeddings: true,
            similarity_threshold: config.similarity_threshold,
            high_similarity_threshold: config.high_similarity_threshold,
            low_similarity_threshold: config.low_similarity_threshold,
        }
    }

    /// Builder-style toggle for embedding caching.
    pub fn with_cache(mut self, cache: bool) -> Self {
        self.cache_embeddings = cache;
        self
    }

    /// Builder-style override of the mid similarity threshold.
    pub fn with_threshold(mut self, threshold: f32) -> Self {
        self.similarity_threshold = threshold;
        self
    }

    /// Cosine similarity of `a` and `b`.
    ///
    /// Returns 0.0 for mismatched lengths, empty inputs, or a zero-magnitude
    /// vector, so callers never divide by zero.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() || a.is_empty() {
            return 0.0;
        }
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        if mag_a == 0.0 || mag_b == 0.0 {
            0.0
        } else {
            dot / (mag_a * mag_b)
        }
    }

    /// Builds the text that gets embedded for a node.
    ///
    /// Preference order: "title: summary", then "title: content" (content
    /// truncated to ~500 bytes), then the bare title. Returns an empty string
    /// when the node does not exist in `tree`.
    fn get_embedding_text(tree: &DocumentTree, node_id: NodeId) -> String {
        if let Some(node) = tree.get(node_id) {
            if !node.summary.is_empty() {
                format!("{}: {}", node.title, node.summary)
            } else if !node.content.is_empty() {
                // Cap content at 500 bytes, then back up to a UTF-8 char
                // boundary: a plain `&content[..500]` panics when byte 500
                // falls inside a multi-byte character.
                let mut end = node.content.len().min(500);
                while !node.content.is_char_boundary(end) {
                    end -= 1;
                }
                format!("{}: {}", node.title, &node.content[..end])
            } else {
                node.title.clone()
            }
        } else {
            String::new()
        }
    }
}
#[async_trait]
impl RetrievalStrategy for SemanticStrategy {
    /// Scores a single node: embeds the query and the node's text, then maps
    /// cosine similarity to a [`NavigationDecision`] via the configured
    /// thresholds. Embedding failures degrade to `Skip` rather than erroring.
    async fn evaluate_node(
        &self,
        tree: &DocumentTree,
        node_id: NodeId,
        context: &RetrievalContext,
    ) -> NodeEvaluation {
        // Shared "give up on this node" result for empty text / model errors.
        let skip = |reason: String| NodeEvaluation {
            score: 0.0,
            decision: NavigationDecision::Skip,
            reasoning: Some(reason),
        };
        let node_text = Self::get_embedding_text(tree, node_id);
        if node_text.is_empty() {
            return skip("Empty node".to_string());
        }
        let query_embedding = match self.model.embed(&context.query).await {
            Ok(e) => e,
            Err(e) => return skip(format!("Embedding error: {}", e)),
        };
        let node_embedding = match self.model.embed(&node_text).await {
            Ok(e) => e,
            Err(e) => return skip(format!("Embedding error: {}", e)),
        };
        let similarity = Self::cosine_similarity(&query_embedding, &node_embedding);
        let decision = if similarity > self.high_similarity_threshold {
            NavigationDecision::ThisIsTheAnswer
        } else if similarity > self.similarity_threshold {
            // Mid-range: a leaf has nothing further to explore, so accept it.
            if tree.is_leaf(node_id) {
                NavigationDecision::ThisIsTheAnswer
            } else {
                NavigationDecision::ExploreMore
            }
        } else if similarity > self.low_similarity_threshold {
            NavigationDecision::ExploreMore
        } else {
            NavigationDecision::Skip
        };
        NodeEvaluation {
            score: similarity,
            decision,
            reasoning: Some(format!("Semantic similarity: {:.3}", similarity)),
        }
    }

    /// Batch variant: embeds the query once and all node texts via
    /// `embed_batch`, then scores each node with the same threshold logic as
    /// [`Self::evaluate_node`]. Always returns one evaluation per input node.
    async fn evaluate_nodes(
        &self,
        tree: &DocumentTree,
        node_ids: &[NodeId],
        context: &RetrievalContext,
    ) -> Vec<NodeEvaluation> {
        // On a batch-level failure every node degrades to Skip.
        let skip_all = |reason: String| {
            node_ids
                .iter()
                .map(|_| NodeEvaluation {
                    score: 0.0,
                    decision: NavigationDecision::Skip,
                    reasoning: Some(reason.clone()),
                })
                .collect::<Vec<_>>()
        };
        let query_embedding = match self.model.embed(&context.query).await {
            Ok(e) => e,
            Err(e) => return skip_all(format!("Embedding error: {}", e)),
        };
        let texts: Vec<String> = node_ids
            .iter()
            .map(|&id| Self::get_embedding_text(tree, id))
            .collect();
        let node_embeddings = match self.model.embed_batch(&texts).await {
            Ok(e) => e,
            Err(e) => return skip_all(format!("Embedding error: {}", e)),
        };
        node_ids
            .iter()
            .enumerate()
            .map(|(i, &node_id)| {
                // Index rather than zip: a model returning fewer embeddings
                // than inputs previously truncated the result Vec silently.
                // A missing embedding now yields an explicit Skip entry.
                let node_embedding = match node_embeddings.get(i) {
                    Some(e) => e,
                    None => {
                        return NodeEvaluation {
                            score: 0.0,
                            decision: NavigationDecision::Skip,
                            reasoning: Some("Missing embedding in batch result".to_string()),
                        };
                    }
                };
                let similarity = Self::cosine_similarity(&query_embedding, node_embedding);
                // Use the configured thresholds; the previous hard-coded 0.8
                // and 0.3 diverged from evaluate_node whenever the config
                // differed from the defaults.
                let decision = if similarity > self.high_similarity_threshold {
                    NavigationDecision::ThisIsTheAnswer
                } else if similarity > self.similarity_threshold {
                    if tree.is_leaf(node_id) {
                        NavigationDecision::ThisIsTheAnswer
                    } else {
                        NavigationDecision::ExploreMore
                    }
                } else if similarity > self.low_similarity_threshold {
                    NavigationDecision::ExploreMore
                } else {
                    NavigationDecision::Skip
                };
                NodeEvaluation {
                    score: similarity,
                    decision,
                    reasoning: Some(format!("Semantic similarity: {:.3}", similarity)),
                }
            })
            .collect()
    }

    /// Stable identifier used in logs/config to select this strategy.
    fn name(&self) -> &'static str {
        "semantic"
    }

    /// Declares what this strategy relies on: embeddings only, no LLM calls.
    fn capabilities(&self) -> StrategyCapabilities {
        StrategyCapabilities {
            uses_llm: false,
            uses_embeddings: true,
            supports_sufficiency: true,
            typical_latency_ms: 50,
        }
    }

    /// Semantic similarity works well for simple and medium queries; complex
    /// queries are expected to be routed to a different strategy.
    fn suitable_for_complexity(&self, complexity: QueryComplexity) -> bool {
        matches!(
            complexity,
            QueryComplexity::Simple | QueryComplexity::Medium
        )
    }
}