use std::sync::Arc;
use tracing::info;
use super::types::QueryResultItem;
use crate::config::Config;
use crate::document::{DocumentTree, ReasoningIndex};
use crate::error::{Error, Result};
use crate::events::{EventEmitter, QueryEvent};
use crate::retrieval::stream::RetrieveEventReceiver;
use crate::retrieval::{RetrieveOptions, RetrieveResponse};
/// Query client that wraps a shared [`crate::retrieval::PipelineRetriever`] and
/// publishes query lifecycle events through an [`EventEmitter`].
pub(crate) struct RetrieverClient {
/// Retrieval pipeline, held behind an `Arc` so it can be shared cheaply.
retriever: Arc<crate::retrieval::PipelineRetriever>,
/// Shared configuration handle.
/// NOTE(review): no method visible in this file reads it — confirm it is still needed.
config: Arc<Config>,
/// Emitter used to publish `QueryEvent::Started` / `QueryEvent::Complete`.
events: EventEmitter,
/// Retrieval options initialised from `RetrieveOptions::default()`.
/// NOTE(review): the query methods visible here take explicit options and do
/// not fall back to this value — confirm the intended use.
default_options: RetrieveOptions,
}
impl RetrieverClient {
/// Builds a client that owns `retriever` (wrapped in an `Arc`) together with
/// the shared `config`.
///
/// The client starts with a fresh [`EventEmitter`] and with
/// `RetrieveOptions::default()` as its default options.
pub fn new(retriever: crate::retrieval::PipelineRetriever, config: Arc<Config>) -> Self {
    let retriever = Arc::new(retriever);
    let events = EventEmitter::new();
    let default_options = RetrieveOptions::default();

    Self {
        retriever,
        config,
        events,
        default_options,
    }
}
pub fn with_events(mut self, events: EventEmitter) -> Self {
self.events = events;
self
}
pub async fn query_with_reasoning_index(
&self,
tree: &DocumentTree,
question: &str,
options: &RetrieveOptions,
reasoning_index: Option<ReasoningIndex>,
) -> Result<QueryResultItem> {
self.events.emit_query(QueryEvent::Started {
query: question.to_string(),
});
info!("Querying: {:?}", question);
let response = self
.retriever
.retrieve_with_reasoning_index(tree, question, options, reasoning_index)
.await
.map_err(|e| Error::Retrieval(e.to_string()))?;
let result = self.build_query_result(&response);
self.events.emit_query(QueryEvent::Complete {
total_results: result.node_ids.len(),
confidence: result.score,
});
Ok(result)
}
/// Starts a streaming retrieval for `question` over `tree` and returns the
/// receiver on which the retrieval events arrive.
///
/// Emits [`QueryEvent::Started`] before the retrieval is kicked off. A
/// background Tokio task then waits for the handle returned by
/// `retrieve_streaming` and emits [`QueryEvent::Complete`] once it resolves.
/// That completion event always reports `total_results: 0` and
/// `confidence: 0.0`, because this method does not aggregate the streamed
/// results.
///
/// Uses `tokio::spawn`, so it must be called from within a Tokio runtime.
///
/// # Errors
///
/// The body shown here always returns `Ok`; the `Result` return type is kept
/// for interface compatibility with callers.
pub async fn query_stream(
    &self,
    tree: &DocumentTree,
    question: &str,
    options: &RetrieveOptions,
) -> Result<RetrieveEventReceiver> {
    self.events.emit_query(QueryEvent::Started {
        query: question.to_string(),
    });
    info!("Streaming query: {:?}", question);

    let (handle, rx) = self.retriever.retrieve_streaming(tree, question, options);

    // The emitter is cloned so the detached task can outlive `&self`.
    let events = self.events.clone();
    tokio::spawn(async move {
        // Best effort: the handle's outcome is intentionally ignored so that
        // a completion event is emitted however the retrieval task ended.
        let _ = handle.await;
        events.emit_query(QueryEvent::Complete {
            total_results: 0,
            confidence: 0.0,
        });
    });

    Ok(rx)
}
fn build_query_result(&self, response: &RetrieveResponse) -> QueryResultItem {
let node_ids: Vec<String> = response
.results
.iter()
.filter_map(|r| r.node_id.clone())
.collect();
let content_parts: Vec<String> = response
.results
.iter()
.map(|r| {
let mut parts = vec![format!("## {}", r.title)];
if let Some(ref content) = r.content {
parts.push(content.clone());
}
parts.join("\n\n")
})
.collect();
let content = if content_parts.is_empty() {
response.content.clone()
} else {
content_parts.join("\n\n---\n\n")
};
QueryResultItem {
doc_id: String::new(), node_ids,
content,
score: response.confidence,
}
}
}
/// Cloning shares the retriever and configuration (reference-count bumps
/// only) and clones the event emitter and default options.
impl Clone for RetrieverClient {
    fn clone(&self) -> Self {
        // Destructure exhaustively so a newly added field cannot be
        // forgotten here without a compile error.
        let Self {
            retriever,
            config,
            events,
            default_options,
        } = self;

        Self {
            retriever: Arc::clone(retriever),
            config: Arc::clone(config),
            events: events.clone(),
            default_options: default_options.clone(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed client carries default options with a positive
    /// `top_k`.
    #[test]
    fn test_retriever_client_creation() {
        let shared_config = Arc::new(Config::default());
        let pipeline = crate::retrieval::PipelineRetriever::new();

        let client = RetrieverClient::new(pipeline, shared_config);

        assert!(client.default_options.top_k > 0);
    }
}