use std::sync::Arc;
use tracing::info;
use super::events::{EventEmitter, QueryEvent};
use super::types::QueryResultItem;
use crate::config::Config;
use crate::document::{DocumentTree, NodeId, ReasoningIndex};
use crate::error::{Error, Result};
use crate::retrieval::content::ContentAggregatorConfig;
use crate::retrieval::stream::RetrieveEventReceiver;
use crate::retrieval::{RetrievalResult, RetrieveOptions, RetrieveResponse};
/// Client-side façade over a shared [`crate::retrieval::PipelineRetriever`].
///
/// Publishes query lifecycle events through an [`EventEmitter`] and carries a
/// set of default [`RetrieveOptions`].
pub(crate) struct RetrieverClient {
    /// Retrieval pipeline, shared by reference count between clones of this client.
    retriever: Arc<crate::retrieval::PipelineRetriever>,
    /// Global configuration, shared by reference count.
    config: Arc<Config>,
    /// Sink for `QueryEvent`s emitted by the query methods.
    events: EventEmitter,
    /// Options derived from a `RetrieverClientConfig` via `with_config`.
    default_options: RetrieveOptions,
}
/// Tunables used by `RetrieverClient::with_config` to build default retrieval options.
#[derive(Debug, Clone)]
pub(crate) struct RetrieverClientConfig {
    /// Number of results requested by default (fed to `RetrieveOptions::with_top_k`).
    pub default_top_k: usize,
    /// Token budget used by default (fed to `RetrieveOptions::with_max_tokens`).
    pub default_token_budget: usize,
    /// Optional content-aggregation settings.
    /// NOTE(review): not consumed by `RetrieverClient::with_config` — confirm intent.
    pub content_config: Option<ContentAggregatorConfig>,
    /// Whether retrieval caching is enabled (fed to `RetrieveOptions::with_enable_cache`).
    pub enable_cache: bool,
}
impl Default for RetrieverClientConfig {
    /// Five results, a 4000-token budget, no content-aggregator override, caching on.
    fn default() -> Self {
        let (default_top_k, default_token_budget) = (5, 4000);
        Self {
            enable_cache: true,
            content_config: None,
            default_token_budget,
            default_top_k,
        }
    }
}
impl RetrieverClient {
    /// Minimum length, in Unicode scalar values (not bytes), for a token to count as a keyword.
    const MIN_KEYWORD_CHARS: usize = 4;
    /// Maximum number of keywords taken from a single piece of content.
    const MAX_KEYWORDS: usize = 20;
    /// Jaccard similarity a node must strictly exceed to be reported by `find_similar`.
    const SIMILARITY_THRESHOLD: f32 = 0.3;

    /// Creates a client that owns `retriever`, with a fresh event emitter and
    /// default retrieval options.
    pub fn new(retriever: crate::retrieval::PipelineRetriever, config: Arc<Config>) -> Self {
        Self::from_arc(Arc::new(retriever), config, EventEmitter::new())
    }

    /// Replaces the event emitter used to publish query lifecycle events.
    pub fn with_events(mut self, events: EventEmitter) -> Self {
        self.events = events;
        self
    }

    /// Replaces the default retrieval options with ones built from `config`
    /// (top-k, token budget and cache flag).
    ///
    /// NOTE(review): `config.content_config` is not applied here; confirm
    /// whether `RetrieveOptions` should carry it.
    pub fn with_config(mut self, config: RetrieverClientConfig) -> Self {
        self.default_options = RetrieveOptions::new()
            .with_top_k(config.default_top_k)
            .with_max_tokens(config.default_token_budget)
            .with_enable_cache(config.enable_cache);
        self
    }

    /// Builds a client around an already shared retriever, so several clients
    /// can reuse one pipeline instance.
    pub(crate) fn from_arc(
        retriever: Arc<crate::retrieval::PipelineRetriever>,
        config: Arc<Config>,
        events: EventEmitter,
    ) -> Self {
        Self {
            retriever,
            config,
            events,
            default_options: RetrieveOptions::default(),
        }
    }

    /// Runs `question` against `tree` without a reasoning index.
    ///
    /// See [`Self::query_with_reasoning_index`] for events and errors.
    pub async fn query(
        &self,
        tree: &DocumentTree,
        question: &str,
        options: &RetrieveOptions,
    ) -> Result<QueryResultItem> {
        self.query_with_reasoning_index(tree, question, options, None)
            .await
    }

    /// Runs `question` against `tree`, optionally guided by `reasoning_index`.
    ///
    /// Emits `QueryEvent::Started` before retrieval and `QueryEvent::Complete`
    /// (result count and confidence) after a successful retrieval.
    ///
    /// # Errors
    /// Any retrieval failure is converted to `Error::Retrieval` with the
    /// error's string form. No `Complete` event is emitted in that case.
    pub async fn query_with_reasoning_index(
        &self,
        tree: &DocumentTree,
        question: &str,
        options: &RetrieveOptions,
        reasoning_index: Option<ReasoningIndex>,
    ) -> Result<QueryResultItem> {
        self.events.emit_query(QueryEvent::Started {
            query: question.to_string(),
        });
        info!("Querying: {:?}", question);

        let response = self
            .retriever
            .retrieve_with_reasoning_index(tree, question, options, reasoning_index)
            .await
            .map_err(|e| Error::Retrieval(e.to_string()))?;

        let result = self.build_query_result(&response);
        self.events.emit_query(QueryEvent::Complete {
            total_results: result.node_ids.len(),
            confidence: result.score,
        });
        Ok(result)
    }

    /// Starts a streaming retrieval and returns its event receiver immediately.
    ///
    /// Emits `QueryEvent::Started` right away. A detached tokio task awaits
    /// the retrieval handle and then emits `QueryEvent::Complete` with
    /// placeholder values (0 results, 0.0 confidence), whatever the outcome.
    pub async fn query_stream(
        &self,
        tree: &DocumentTree,
        question: &str,
        options: &RetrieveOptions,
    ) -> Result<RetrieveEventReceiver> {
        self.events.emit_query(QueryEvent::Started {
            query: question.to_string(),
        });
        info!("Streaming query: {:?}", question);

        let (handle, rx) = self.retriever.retrieve_streaming(tree, question, options);
        let events = self.events.clone();
        tokio::spawn(async move {
            // The handle's outcome (including any error) is not inspected here;
            // this task only signals that the stream has finished.
            let _ = handle.await;
            events.emit_query(QueryEvent::Complete {
                total_results: 0,
                confidence: 0.0,
            });
        });
        Ok(rx)
    }

    /// Flattens a retrieval response into a single `QueryResultItem`.
    ///
    /// Each result becomes a `## title` section followed by its content (if
    /// any). Sections are separated by `---` rules. When there are no results,
    /// the response's own `content` is used. `doc_id` is left empty —
    /// presumably filled in by the caller (TODO confirm).
    fn build_query_result(&self, response: &RetrieveResponse) -> QueryResultItem {
        let node_ids: Vec<String> = response
            .results
            .iter()
            .filter_map(|r| r.node_id.clone())
            .collect();

        let content_parts: Vec<String> = response
            .results
            .iter()
            .map(|r| {
                let mut parts = vec![format!("## {}", r.title)];
                if let Some(ref content) = r.content {
                    parts.push(content.clone());
                }
                parts.join("\n\n")
            })
            .collect();

        let content = if content_parts.is_empty() {
            response.content.clone()
        } else {
            content_parts.join("\n\n---\n\n")
        };

        QueryResultItem {
            doc_id: String::new(),
            node_ids,
            content,
            score: response.confidence,
        }
    }

    /// Finds up to `top_k` nodes whose keyword sets resemble `node_id`'s.
    ///
    /// Walks the whole tree depth-first from the root, skipping the target
    /// itself but still visiting its subtree. Each node is scored by Jaccard
    /// similarity of extracted keywords. Nodes scoring above
    /// `SIMILARITY_THRESHOLD` are kept, sorted by descending score.
    /// Returns an empty list when the target is missing, has empty content, or
    /// yields no keywords.
    pub fn find_similar(
        &self,
        tree: &DocumentTree,
        node_id: NodeId,
        top_k: usize,
    ) -> Result<Vec<RetrievalResult>> {
        // Borrow the target's content instead of cloning it just to test emptiness.
        let target_keywords = match tree.get(node_id) {
            Some(target) if !target.content.is_empty() => self.extract_keywords(&target.content),
            _ => return Ok(Vec::new()),
        };
        // With no target keywords (or no room for results), every score would be
        // 0 / the output truncated to nothing — skip the traversal.
        if target_keywords.is_empty() || top_k == 0 {
            return Ok(Vec::new());
        }

        let mut results = Vec::new();
        let mut stack = vec![tree.root()];
        while let Some(current_id) = stack.pop() {
            // Children are always queued, so the target's own subtree is searched too.
            stack.extend(tree.children(current_id));
            if current_id == node_id {
                continue;
            }
            if let Some(node) = tree.get(current_id) {
                let node_keywords = self.extract_keywords(&node.content);
                let similarity = self.calculate_similarity(&target_keywords, &node_keywords);
                if similarity > Self::SIMILARITY_THRESHOLD {
                    results.push(
                        RetrievalResult::new(&node.title)
                            .with_node_id(format!("{:?}", current_id))
                            .with_content(node.content.clone())
                            .with_score(similarity)
                            .with_depth(tree.depth(current_id)),
                    );
                }
            }
        }

        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results.truncate(top_k);
        Ok(results)
    }

    /// Extracts up to `MAX_KEYWORDS` lowercase keywords from `content`.
    ///
    /// Splits on whitespace and strips leading/trailing non-alphanumeric
    /// characters, so "retrieval," matches "retrieval". Keeps tokens of at
    /// least `MIN_KEYWORD_CHARS` characters, measured in chars rather than
    /// bytes so non-ASCII text is filtered consistently. Duplicates are kept
    /// and count toward the limit.
    fn extract_keywords(&self, content: &str) -> Vec<String> {
        content
            .to_lowercase()
            .split_whitespace()
            .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric()))
            .filter(|w| w.chars().count() >= Self::MIN_KEYWORD_CHARS)
            .take(Self::MAX_KEYWORDS)
            .map(String::from)
            .collect()
    }

    /// Jaccard similarity (|A ∩ B| / |A ∪ B|) of the distinct keywords in the
    /// two slices. Returns 0.0 if either slice is empty.
    fn calculate_similarity(&self, set1: &[String], set2: &[String]) -> f32 {
        if set1.is_empty() || set2.is_empty() {
            return 0.0;
        }
        let set1_set: std::collections::HashSet<_> = set1.iter().collect();
        let set2_set: std::collections::HashSet<_> = set2.iter().collect();
        let intersection = set1_set.intersection(&set2_set).count();
        // Both inputs are non-empty, so the union is never zero.
        let union = set1_set.union(&set2_set).count();
        intersection as f32 / union as f32
    }

    /// Collects the structural context around `node_id`.
    ///
    /// - `target`: the node itself, with content (`None` if absent from `tree`).
    /// - `ancestors`: up to `ancestor_depth` title-only ancestors, nearest first,
    ///   starting at the parent. The target itself is not included.
    /// - `siblings`: title-only entries for the other children of the target's
    ///   parent (empty for the root).
    pub fn get_node_context(
        &self,
        tree: &DocumentTree,
        node_id: NodeId,
        ancestor_depth: usize,
    ) -> Result<NodeContext> {
        let target = tree.get(node_id).map(|n| {
            RetrievalResult::new(&n.title)
                .with_node_id(format!("{:?}", node_id))
                .with_content(n.content.clone())
                .with_depth(tree.depth(node_id))
        });

        // Walk upward from the parent; the target is reported separately above.
        let mut ancestors = Vec::new();
        let mut current = tree.parent(node_id);
        for _ in 0..ancestor_depth {
            let id = match current {
                Some(id) => id,
                None => break,
            };
            if let Some(node) = tree.get(id) {
                ancestors.push(
                    RetrievalResult::new(&node.title)
                        .with_node_id(format!("{:?}", id))
                        .with_depth(tree.depth(id)),
                );
            }
            current = tree.parent(id);
        }

        // Direct siblings only: other children of the same parent.
        let mut siblings = Vec::new();
        if let Some(parent_id) = tree.parent(node_id) {
            for child_id in tree.children(parent_id) {
                if child_id == node_id {
                    continue;
                }
                if let Some(sibling) = tree.get(child_id) {
                    siblings.push(
                        RetrievalResult::new(&sibling.title)
                            .with_node_id(format!("{:?}", child_id))
                            .with_depth(tree.depth(child_id)),
                    );
                }
            }
        }

        Ok(NodeContext {
            target,
            ancestors,
            siblings,
        })
    }

    /// Returns another handle to the shared retriever (reference-count bump, no copy).
    pub(crate) fn inner(&self) -> Arc<crate::retrieval::PipelineRetriever> {
        Arc::clone(&self.retriever)
    }
}
impl Clone for RetrieverClient {
    /// Shares the retriever and config by reference count and clones the
    /// event emitter and default options.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field forces this impl to be revisited.
        let Self {
            retriever,
            config,
            events,
            default_options,
        } = self;
        Self {
            retriever: Arc::clone(retriever),
            config: Arc::clone(config),
            events: events.clone(),
            default_options: default_options.clone(),
        }
    }
}
/// Structural neighbourhood of a node, as gathered by `RetrieverClient::get_node_context`.
#[derive(Debug, Clone)]
pub(crate) struct NodeContext {
    /// The requested node with its content, or `None` if it is not in the tree.
    pub target: Option<RetrievalResult>,
    /// Title-only entries for nodes above the target.
    pub ancestors: Vec<RetrievalResult>,
    /// Title-only entries for nodes that share a parent with a visited node.
    pub siblings: Vec<RetrievalResult>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built client must carry a usable (non-zero) default `top_k`.
    #[test]
    fn test_retriever_client_creation() {
        let shared_config = Arc::new(Config::default());
        let pipeline = crate::retrieval::PipelineRetriever::new();

        let client = RetrieverClient::new(pipeline, shared_config);

        assert!(client.default_options.top_k > 0);
    }
}