use async_trait::async_trait;
use super::types::{RetrieveOptions, RetrieveResponse};
use crate::document::DocumentTree;
/// Convenience alias for results produced by [`Retriever`] implementations.
pub type RetrieverResult<T> = std::result::Result<T, RetrieverError>;

/// Errors that a [`Retriever`] can report.
///
/// Variants carrying a `String` embed that detail message in their
/// `Display` output (generated by `thiserror`).
#[derive(Debug, thiserror::Error)]
pub enum RetrieverError {
    /// The supplied document tree could not be used; the payload explains why.
    #[error("Invalid document tree: {0}")]
    InvalidTree(String),

    /// Retrieval finished without finding any relevant node for the query.
    #[error("No relevant nodes found for query")]
    NoResults,

    /// A failure reported while talking to an LLM.
    #[error("LLM error: {0}")]
    LlmError(String),

    /// A failure related to embeddings.
    #[error("Embedding error: {0}")]
    EmbeddingError(String),

    /// A failure related to caching.
    #[error("Cache error: {0}")]
    CacheError(String),

    /// Invalid or inconsistent configuration.
    #[error("Configuration error: {0}")]
    ConfigError(String),

    /// Any other unexpected internal failure.
    #[error("Internal error: {0}")]
    Internal(String),
}
/// A strategy for selecting content from a [`DocumentTree`] in response to a query.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait Retriever: Send + Sync {
    /// Runs retrieval for `query` against `tree`, configured by `options`.
    ///
    /// # Errors
    ///
    /// Returns a [`RetrieverError`] when retrieval fails; which variants are
    /// produced is up to the implementation.
    async fn retrieve(
        &self,
        tree: &DocumentTree,
        query: &str,
        options: &RetrieveOptions,
    ) -> RetrieverResult<RetrieveResponse>;

    /// Returns a name identifying this retriever.
    fn name(&self) -> &str;

    /// Reports whether this retriever can honour `options`.
    ///
    /// The default implementation accepts every set of options.
    fn supports_options(&self, _options: &RetrieveOptions) -> bool {
        true
    }

    /// Returns a rough, heuristic estimate of the cost of retrieving over `tree`.
    ///
    /// The default implementation ignores `options` and scales linearly with
    /// `tree.node_count()`: one LLM call per two nodes (rounded down) and 100
    /// tokens per node. The token product saturates at `usize::MAX` instead of
    /// overflowing, so very large trees cannot trigger an overflow panic.
    fn estimate_cost(&self, tree: &DocumentTree, _options: &RetrieveOptions) -> CostEstimate {
        // Heuristic number of nodes covered by a single LLM call.
        const NODES_PER_LLM_CALL: usize = 2;
        // Heuristic token count attributed to each node.
        const TOKENS_PER_NODE: usize = 100;

        let node_count = tree.node_count();
        CostEstimate {
            llm_calls: node_count / NODES_PER_LLM_CALL,
            // Saturate: plain `*` panics on overflow in debug builds, which is
            // reachable on 32-bit targets; a clamped estimate is still useful.
            tokens: node_count.saturating_mul(TOKENS_PER_NODE),
        }
    }
}
/// A rough cost estimate for one retrieval, as returned by `Retriever::estimate_cost`.
///
/// Both counts default to zero.
#[derive(Clone, Copy, Debug, Default)]
pub struct CostEstimate {
    /// Estimated number of LLM calls.
    pub llm_calls: usize,
    /// Estimated number of tokens.
    pub tokens: usize,
}
/// Per-query state for a single retrieval pass.
///
/// Holds the raw query, a lowercased copy and its whitespace-separated tokens,
/// plus public counters that start at zero and are compared against a token
/// budget (presumably advanced by the caller as retrieval proceeds — confirm
/// against call sites).
#[derive(Debug, Clone)]
pub struct RetrievalContext {
    /// The query exactly as supplied to [`RetrievalContext::new`].
    pub query: String,
    /// `query` lowercased.
    pub query_normalized: String,
    /// `query_normalized` split on whitespace.
    pub query_tokens: Vec<String>,
    /// Depth counter; starts at 0.
    pub current_depth: usize,
    /// Result counter; starts at 0.
    pub results_count: usize,
    /// Tokens gathered so far; starts at 0.
    pub tokens_collected: usize,
    /// Token budget that `tokens_collected` is measured against.
    pub max_tokens: usize,
    /// The `sufficiency_enabled` flag passed to [`RetrievalContext::new`].
    pub sufficiency_enabled: bool,
}

impl RetrievalContext {
    /// Builds a fresh context for `query` with every counter at zero.
    ///
    /// The query is lowercased with `str::to_lowercase` (Unicode-aware), and
    /// the lowercased text is split on whitespace to form `query_tokens`.
    pub fn new(query: &str, max_tokens: usize, sufficiency_enabled: bool) -> Self {
        let lowered = query.to_lowercase();
        let tokens = lowered
            .split_whitespace()
            .map(String::from)
            .collect::<Vec<_>>();

        Self {
            query: query.to_owned(),
            query_normalized: lowered,
            query_tokens: tokens,
            current_depth: 0,
            results_count: 0,
            tokens_collected: 0,
            max_tokens,
            sufficiency_enabled,
        }
    }

    /// Returns `true` once `tokens_collected` has reached or exceeded `max_tokens`.
    ///
    /// A `max_tokens` of zero therefore counts as already reached.
    pub fn is_token_limit_reached(&self) -> bool {
        self.max_tokens <= self.tokens_collected
    }

    /// Fraction of the token budget consumed, clamped to at most `1.0`.
    ///
    /// Returns `0.0` when `max_tokens` is zero to avoid dividing by zero
    /// (note: `is_token_limit_reached` reports `true` in that same case).
    pub fn token_utilization(&self) -> f32 {
        match self.max_tokens {
            0 => 0.0,
            budget => f32::min(self.tokens_collected as f32 / budget as f32, 1.0),
        }
    }
}