use crate::orchestration::{
SearchTypeRegistry, merge_scoped_contexts, prepare_search_result, scope_context_by_datasets,
};
use crate::types::{SearchError, SearchOutput, SearchParams, SearchRequest, SearchResponse};
use crate::utils::detect_feedback;
use cognee_database::{IngestDb, SearchHistoryDb, SearchHistoryEntry};
use cognee_llm::Llm;
use cognee_session::{SessionContext, SessionManager, UsedGraphElementIds};
use std::sync::Arc;
/// Emits the `cognee.search EXECUTION STARTED` telemetry event for `request`.
///
/// The event is attributed to `request.user_id` and carries the cognee version and the
/// telemetry tenant id as its properties.
#[cfg(feature = "telemetry")]
fn emit_search_started(request: &SearchRequest) {
    let properties = serde_json::json!({
        "cognee_version": cognee_telemetry::cognee_version(),
        "tenant_id": cognee_telemetry::tenant_id_for_telemetry(None),
    });
    cognee_telemetry::send_telemetry(
        "cognee.search EXECUTION STARTED",
        request.user_id,
        Some(properties),
    );
}

/// No-op stand-in compiled when the `telemetry` feature is disabled.
#[cfg(not(feature = "telemetry"))]
#[inline]
fn emit_search_started(_request: &SearchRequest) {}
/// Emits the `cognee.search EXECUTION COMPLETED` telemetry event for `request`.
///
/// The event is attributed to `request.user_id` and carries the cognee version and the
/// telemetry tenant id as its properties.
#[cfg(feature = "telemetry")]
fn emit_search_completed(request: &SearchRequest) {
    let properties = serde_json::json!({
        "cognee_version": cognee_telemetry::cognee_version(),
        "tenant_id": cognee_telemetry::tenant_id_for_telemetry(None),
    });
    cognee_telemetry::send_telemetry(
        "cognee.search EXECUTION COMPLETED",
        request.user_id,
        Some(properties),
    );
}

/// No-op stand-in compiled when the `telemetry` feature is disabled.
#[cfg(not(feature = "telemetry"))]
#[inline]
fn emit_search_completed(_request: &SearchRequest) {}
/// Truncates `gc` to at most `max_chars` characters (Unicode scalar values).
///
/// Returns `gc` unchanged when `max_chars` is `None` or when the string already has
/// `max_chars` characters or fewer. The result is always a prefix of `gc` that ends on a
/// character boundary, so slicing can never panic on multibyte text.
fn apply_context_char_limit(gc: &str, max_chars: Option<usize>) -> &str {
    let Some(limit) = max_chars else {
        return gc;
    };
    // The byte offset of the character at index `limit` is exactly where the kept prefix
    // must end. If there is no such character, the whole string fits within the limit.
    // Counting characters (not bytes) keeps the limit meaningful for non-ASCII text.
    match gc.char_indices().nth(limit) {
        Some((cut, _)) => &gc[..cut],
        None => gc,
    }
}
/// Collects the graph node and edge ids referenced by the items of `context`.
///
/// For every item, the string values under the payload keys `source_id` and `target_id`
/// are gathered as node ids and the string value under `edge_id` as an edge id. Ids are
/// deduplicated and returned in ascending order.
///
/// Returns `None` when `context` is `None` or when no id of either kind was found.
fn build_used_graph_element_ids(
    context: Option<&[crate::types::SearchItem]>,
) -> Option<UsedGraphElementIds> {
    use std::collections::BTreeSet;

    let items = context?;

    // Ordered sets give deduplication and sorted output in a single step.
    let mut nodes: BTreeSet<String> = BTreeSet::new();
    let mut edges: BTreeSet<String> = BTreeSet::new();

    for item in items {
        // Only string-valued payload entries count; anything else is ignored.
        let text_field = |key: &str| {
            item.payload
                .get(key)
                .and_then(|value| value.as_str())
                .map(str::to_owned)
        };
        nodes.extend(text_field("source_id"));
        nodes.extend(text_field("target_id"));
        edges.extend(text_field("edge_id"));
    }

    if nodes.is_empty() && edges.is_empty() {
        return None;
    }

    Some(UsedGraphElementIds {
        node_ids: nodes.into_iter().collect(),
        edge_ids: edges.into_iter().collect(),
    })
}
/// Coordinates a search request end to end: retriever selection, optional dataset-name
/// resolution, context retrieval and scoping, session handling, and history logging.
///
/// Only `registry` is required; every other collaborator is optional and the features
/// that depend on it are skipped (or rejected, for dataset-name filters) when it is unset.
pub struct SearchOrchestrator {
/// Retrievers, looked up by search type or by community retriever name.
registry: SearchTypeRegistry,
/// Search-history store used to log queries and results and to read history back.
database: Option<Arc<dyn SearchHistoryDb>>,
/// Resolves dataset names to ids; also passed to the access-timestamp update.
dataset_resolver: Option<Arc<dyn IngestDb>>,
/// Session store used for conversation history, graph context, QA and feedback.
session_manager: Option<Arc<SessionManager>>,
/// LLM used for automatic feedback detection.
llm: Option<Arc<dyn Llm>>,
/// When `true`, retrieved context triggers access tracking.
enable_access_tracking: bool,
}
impl SearchOrchestrator {
    /// Creates an orchestrator that dispatches to the retrievers in `registry`.
    ///
    /// All optional collaborators start unset and access tracking is disabled; use the
    /// `with_*` builder methods to wire them in.
    pub fn new(registry: SearchTypeRegistry) -> Self {
        Self {
            registry,
            database: None,
            dataset_resolver: None,
            session_manager: None,
            llm: None,
            enable_access_tracking: false,
        }
    }

    /// Sets the search-history store used for query/result logging and `get_history`.
    pub fn with_database(mut self, database: Arc<dyn SearchHistoryDb>) -> Self {
        self.database = Some(database);
        self
    }

    /// Sets the resolver used to turn dataset names into dataset ids.
    ///
    /// The same handle is used to persist access timestamps when access tracking is on.
    pub fn with_dataset_resolver(mut self, resolver: Arc<dyn IngestDb>) -> Self {
        self.dataset_resolver = Some(resolver);
        self
    }

    /// Sets the session manager used for history, graph context, QA and feedback storage.
    pub fn with_session_manager(mut self, session_manager: Arc<SessionManager>) -> Self {
        self.session_manager = Some(session_manager);
        self
    }

    /// Sets the LLM used for automatic feedback detection.
    pub fn with_llm(mut self, llm: Arc<dyn Llm>) -> Self {
        self.llm = Some(llm);
        self
    }

    /// Enables access tracking for retrieved context.
    pub fn with_access_tracking(mut self) -> Self {
        self.enable_access_tracking = true;
        self
    }

    /// Returns logged search history for `user_id`, limited to `limit` entries.
    ///
    /// Returns an empty list when no database is configured.
    ///
    /// # Errors
    ///
    /// Propagates the database error, converted into [`SearchError`].
    pub async fn get_history(
        &self,
        user_id: Option<uuid::Uuid>,
        limit: Option<usize>,
    ) -> Result<Vec<SearchHistoryEntry>, SearchError> {
        let Some(database) = &self.database else {
            return Ok(Vec::new());
        };
        Ok(database.get_history(user_id, limit).await?)
    }

    /// Registers `retriever` under `name` so requests can select it through
    /// `SearchRequest::custom_search_type`.
    pub fn with_community_retriever(
        mut self,
        name: impl Into<String>,
        retriever: crate::retrievers::SearchRetrieverRef,
    ) -> Self {
        self.registry.register_named(name, retriever);
        self
    }

    /// Runs every request in order and returns one response per request.
    ///
    /// Requests are executed sequentially.
    ///
    /// # Errors
    ///
    /// Stops at the first failing request and returns its error; responses produced
    /// before the failure are discarded.
    pub async fn search_batch(
        &self,
        requests: &[SearchRequest],
    ) -> Result<Vec<SearchResponse>, SearchError> {
        let mut responses = Vec::with_capacity(requests.len());
        for request in requests {
            responses.push(self.search(request).await?);
        }
        Ok(responses)
    }

    /// Executes a single search request.
    ///
    /// Steps, in order:
    /// 1. pick the retriever (a named community retriever takes precedence over the
    ///    retriever registered for `search_type`);
    /// 2. resolve dataset names to ids when names were given without explicit ids;
    /// 3. log the query when interaction saving is enabled and a database is configured;
    /// 4. fetch context when the request is context-only, combined-context or
    ///    dataset-scoped, and scope/merge it by dataset;
    /// 5. for context-only requests, return the context as the result;
    /// 6. otherwise load session state, optionally handle detected feedback, ask the
    ///    retriever for a completion, store the QA pair in the session and log the result.
    ///
    /// # Errors
    ///
    /// * [`SearchError::InvalidInput`] when the named community retriever is not
    ///   registered, or when dataset names are given without a dataset resolver or
    ///   without `user_id`.
    /// * [`SearchError::DatasetNotFound`] when none of the dataset names resolve.
    /// * Errors from the registry lookup, the dataset resolver and the retriever are
    ///   propagated.
    #[tracing::instrument(
        name = "cognee.search",
        skip(self, request),
        fields(
            cognee.search.type = %format!("{:?}", request.search_type),
            cognee.search.query.len = request.query_text.len(),
        )
    )]
    pub async fn search(
        &self,
        request: &SearchRequest,
    ) -> Result<SearchResponse, crate::types::SearchError> {
        emit_search_started(request);

        let retriever: crate::retrievers::SearchRetrieverRef =
            if let Some(ref custom_type) = request.custom_search_type {
                self.registry.get_by_name(custom_type).ok_or_else(|| {
                    SearchError::InvalidInput(format!(
                        "No community retriever registered for '{custom_type}'"
                    ))
                })?
            } else {
                self.registry.get(request.search_type)?
            };

        // Dataset names are only resolved when no explicit ids were supplied; explicit
        // ids always win. The resolved request is an owned clone shadowing `request`.
        let resolved_request_owned;
        let request: &SearchRequest = match (&request.datasets, &request.dataset_ids) {
            (Some(names), maybe_ids)
                if !names.is_empty() && maybe_ids.as_ref().is_none_or(|ids| ids.is_empty()) =>
            {
                let resolver = self.dataset_resolver.as_ref().ok_or_else(|| {
                    SearchError::InvalidInput(
                        "dataset name filter requested but no dataset resolver is wired \
                         into the SearchOrchestrator (call SearchBuilder::with_dataset_resolver)"
                            .to_string(),
                    )
                })?;
                let owner_id = request.user_id.ok_or_else(|| {
                    SearchError::InvalidInput(
                        "dataset name filter requires SearchRequest.user_id to identify the owner"
                            .to_string(),
                    )
                })?;

                let mut resolved = Vec::with_capacity(names.len());
                let mut missing = Vec::new();
                for name in names {
                    match resolver.get_dataset_by_name(name, owner_id, None).await? {
                        Some(ds) => resolved.push(ds.id),
                        None => missing.push(name.clone()),
                    }
                }

                // Nothing resolved is an error; a partial match proceeds with a warning.
                if resolved.is_empty() {
                    return Err(SearchError::DatasetNotFound(missing.join(", ")));
                }
                if !missing.is_empty() {
                    tracing::warn!(
                        missing = ?missing,
                        "some requested dataset names did not resolve; proceeding with the resolved subset"
                    );
                }

                let mut clone = request.clone();
                clone.dataset_ids = Some(resolved);
                resolved_request_owned = clone;
                &resolved_request_owned
            }
            _ => request,
        };

        let params = SearchParams::from(request);
        let use_dataset_scope = request
            .dataset_ids
            .as_ref()
            .is_some_and(|ids| !ids.is_empty());

        // Query logging is best-effort: a failed insert simply leaves no query id, which
        // in turn disables result logging for this request.
        let should_save_interaction = request.save_interaction.unwrap_or(true);
        let query_type = format!("{:?}", request.search_type);
        let mut logged_query_id = None;
        if should_save_interaction
            && let Some(database) = &self.database
            && let Ok(query_id) = database
                .log_query(&request.query_text, &query_type, request.user_id)
                .await
        {
            logged_query_id = Some(query_id);
        }

        let include_context =
            request.only_context() || request.use_combined_context() || use_dataset_scope;
        let base_context = if include_context {
            let ctx = retriever.get_context(&request.query_text, &params).await?;
            if self.enable_access_tracking && !ctx.is_empty() {
                if let Some(resolver) = &self.dataset_resolver {
                    // Access tracking must never fail the search itself.
                    if let Err(e) =
                        crate::utils::update_node_access_timestamps(resolver.as_ref(), &ctx).await
                    {
                        tracing::warn!(
                            error = %e,
                            "access tracking: failed to persist last_accessed timestamps"
                        );
                    }
                } else {
                    // Without a resolver the ids can only be reported, not persisted.
                    let accessed_ids: Vec<String> = ctx
                        .iter()
                        .filter_map(|item| {
                            item.payload
                                .get("data_id")
                                .and_then(|v| v.as_str())
                                .map(String::from)
                        })
                        .collect();
                    if !accessed_ids.is_empty() {
                        tracing::debug!(
                            data_ids = ?accessed_ids,
                            "access tracking: would update last_accessed for {} data records \
                             but no IngestDb resolver is wired",
                            accessed_ids.len()
                        );
                    }
                }
            }
            Some(ctx)
        } else {
            None
        };

        let scoped_contexts = match (&request.dataset_ids, &base_context) {
            (Some(dataset_ids), Some(context)) if !dataset_ids.is_empty() => {
                Some(scope_context_by_datasets(context, dataset_ids))
            }
            _ => None,
        };

        // Context handed to the retriever: merged across datasets when combined context
        // is requested, otherwise the slice belonging to the first requested dataset
        // (empty if that dataset has no items), otherwise the unscoped context.
        let context = if let Some(scoped_context_map) = &scoped_contexts {
            if request.use_combined_context() {
                Some(merge_scoped_contexts(scoped_context_map))
            } else if let Some(dataset_ids) = request.dataset_ids.as_ref() {
                let first_key = dataset_ids.first().map(|id| id.to_string());
                first_key
                    .and_then(|key| scoped_context_map.get(&key).cloned())
                    .or_else(|| Some(vec![]))
            } else {
                base_context.clone()
            }
        } else {
            base_context.clone()
        };

        if request.only_context() {
            let output_context = context.unwrap_or_default();
            let mut response = prepare_search_result(
                request.search_type,
                SearchOutput::Items(output_context.clone()),
                Some(output_context),
                request.dataset_ids.clone(),
                true,
                request.use_combined_context(),
                request.verbose(),
            );
            // Per-dataset requests expose the full scoped map instead of a single slice.
            if let Some(scoped_context_map) = scoped_contexts
                && !request.use_combined_context()
            {
                response.context = Some(scoped_context_map);
            }
            self.log_result_if_enabled(logged_query_id, &response, request.user_id)
                .await;
            emit_search_completed(request);
            return Ok(response);
        }

        let user_id_str = request.user_id.map(|id| id.to_string());

        // Session state is loaded best-effort: load failures fall back to empty history
        // and no graph context.
        let session_context = if let (Some(session_id), Some(sm)) =
            (&request.session_id, &self.session_manager)
        {
            let (history, formatted_history) = sm
                .load_history_both(Some(session_id), user_id_str.as_deref())
                .await
                .unwrap_or_default();
            let graph_context = sm
                .get_graph_context(Some(session_id), user_id_str.as_deref())
                .await
                .ok()
                .flatten();
            let formatted_history = if let Some(gc) =
                graph_context.as_deref().filter(|s| !s.is_empty())
            {
                // No limit is applied here (`None`), so the graph context is used in full.
                let gc = apply_context_char_limit(gc, None);
                format!(
                    "Background knowledge from the knowledge graph:\n{gc}\n\n{formatted_history}"
                )
            } else {
                formatted_history
            };
            SessionContext {
                session_id: Some(session_id.clone()),
                history,
                formatted_history,
                graph_context,
            }
        } else {
            SessionContext {
                session_id: request.session_id.clone(),
                ..SessionContext::default()
            }
        };

        // The id of the latest stored QA must be read before this turn's QA is saved,
        // because detected feedback refers to the previous answer.
        let last_qa_id: Option<String> = if let (Some(session_id), Some(sm)) =
            (&request.session_id, &self.session_manager)
            && request.auto_feedback_detection.unwrap_or(false)
            && !session_id.is_empty()
        {
            sm.latest_qa_id(Some(session_id), user_id_str.as_deref())
                .await
                .unwrap_or(None)
        } else {
            None
        };

        if request.auto_feedback_detection.unwrap_or(false)
            && let (Some(session_id), Some(llm)) = (&request.session_id, &self.llm)
            && !session_id.is_empty()
        {
            let detection = detect_feedback(llm.as_ref(), &request.query_text).await;
            if detection.feedback_detected
                && let (Some(prior_id), Some(sm)) = (&last_qa_id, &self.session_manager)
            {
                // Scores are rounded and clamped into the 1..=5 range before storing.
                let score: Option<i32> = detection.feedback_score.map(|s| {
                    let s = s.round() as i32;
                    s.clamp(1, 5)
                });
                // Fall back to the raw user message when no feedback text was extracted.
                let feedback_text = detection
                    .feedback_text
                    .as_deref()
                    .map(str::trim)
                    .filter(|t| !t.is_empty())
                    .map(String::from)
                    .unwrap_or_else(|| format!("User message: {}", request.query_text.trim()));
                if let Err(e) = sm
                    .add_feedback(
                        Some(session_id),
                        user_id_str.as_deref(),
                        prior_id,
                        Some(&feedback_text),
                        score,
                    )
                    .await
                {
                    tracing::warn!(
                        prior_qa_id = %prior_id,
                        "auto-feedback persistence failed, proceeding without storing: {e}"
                    );
                }
                // Pure feedback (no follow-up question) is acknowledged without running
                // the retriever.
                // NOTE(review): unlike the other return paths, this one does not call
                // `log_result_if_enabled` even though the query may have been logged —
                // confirm whether that is intentional.
                if !detection.contains_followup_question {
                    let acknowledgment = detection
                        .response_to_user
                        .unwrap_or_else(|| "Thank you for your feedback!".to_string());
                    let response = prepare_search_result(
                        request.search_type,
                        SearchOutput::Text(acknowledgment),
                        None,
                        request.dataset_ids.clone(),
                        false,
                        request.use_combined_context(),
                        request.verbose(),
                    );
                    emit_search_completed(request);
                    return Ok(response);
                }
            }
        }

        let output = retriever
            .get_completion(
                &request.query_text,
                context.clone(),
                &session_context,
                &params,
            )
            .await?;

        // Only textual answers are stored in the session.
        if let (Some(session_id), Some(sm)) = (&request.session_id, &self.session_manager)
            && let SearchOutput::Text(ref answer) = output
        {
            let ctx_json = if request.summarize_context == Some(true) {
                context.as_ref().and_then(|c| serde_json::to_string(c).ok())
            } else {
                Some(String::new())
            };
            let used_graph_element_ids = build_used_graph_element_ids(context.as_deref());
            let saved = sm
                .save_qa(
                    Some(session_id),
                    user_id_str.as_deref(),
                    &request.query_text,
                    answer,
                    ctx_json.as_deref(),
                    used_graph_element_ids,
                )
                .await;
            // Best-effort, like feedback persistence: report the failure but still
            // return the answer to the caller.
            if saved.is_err() {
                tracing::warn!("session QA persistence failed, proceeding without storing");
            }
        }

        let mut response = prepare_search_result(
            request.search_type,
            output,
            context,
            request.dataset_ids.clone(),
            false,
            request.use_combined_context(),
            request.verbose(),
        );
        // Per-dataset requests expose the full scoped map instead of a single slice.
        if let Some(scoped_context_map) = scoped_contexts
            && !request.use_combined_context()
        {
            response.context = Some(scoped_context_map);
        }
        self.log_result_if_enabled(logged_query_id, &response, request.user_id)
            .await;
        emit_search_completed(request);
        Ok(response)
    }

    /// Stores the serialized `response` against `query_id` in the search history.
    ///
    /// Does nothing when the query was not logged (`query_id` is `None`) or when no
    /// database is configured. Serialization and storage failures are ignored so that
    /// history logging can never fail a search.
    async fn log_result_if_enabled(
        &self,
        query_id: Option<uuid::Uuid>,
        response: &SearchResponse,
        user_id: Option<uuid::Uuid>,
    ) {
        let (Some(query_id), Some(database)) = (query_id, &self.database) else {
            return;
        };
        if let Ok(serialized_response) = serde_json::to_string(response) {
            let _ = database
                .log_result(query_id, &serialized_response, user_id)
                .await;
        }
    }
}
#[cfg(test)]
#[allow(
clippy::unwrap_used,
clippy::expect_used,
reason = "test code — panics are acceptable failures"
)]
mod tests {
use crate::orchestration::SearchTypeRegistry;
use crate::orchestration::{CONTEXT_LABEL_COMBINED, CONTEXT_LABEL_DEFAULT};
use crate::retrievers::SearchRetriever;
use crate::types::{
SearchContext, SearchError, SearchOutput, SearchParams, SearchRequest, SearchType,
};
use async_trait::async_trait;
use cognee_database::IngestDb;
use cognee_database::ops as db_ops;
use cognee_database::{SearchHistoryDb, SearchHistoryEntryType, connect, initialize};
use cognee_models::Dataset;
use cognee_session::SessionContext;
use serde_json::json;
use std::sync::Arc;
use uuid::Uuid;
/// Minimal `Chunks` retriever stub: yields one fixed context item and a fixed text answer.
struct FakeChunksRetriever;

#[async_trait]
impl SearchRetriever for FakeChunksRetriever {
    fn search_type(&self) -> SearchType {
        SearchType::Chunks
    }

    /// Always returns a single item whose payload is `{ "text": "context value" }`.
    async fn get_context(
        &self,
        _query: &str,
        _params: &SearchParams,
    ) -> Result<SearchContext, SearchError> {
        let only_item = crate::types::SearchItem {
            id: None,
            score: Some(0.9),
            payload: json!({ "text": "context value" }),
        };
        Ok(vec![only_item])
    }

    /// Always answers with the text `answer value`, ignoring every argument.
    async fn get_completion(
        &self,
        _query: &str,
        _context: Option<SearchContext>,
        _session: &SessionContext,
        _params: &SearchParams,
    ) -> Result<SearchOutput, SearchError> {
        let answer = String::from("answer value");
        Ok(SearchOutput::Text(answer))
    }
}
/// A completion request is answered by the retriever registered for its search type and
/// carries neither context nor graphs.
#[tokio::test]
async fn routes_to_registered_retriever_for_completion() {
    let mut retrievers = SearchTypeRegistry::new();
    retrievers.register(Arc::new(FakeChunksRetriever));
    let engine = super::SearchOrchestrator::new(retrievers);

    let request = SearchRequest {
        query_text: String::from("hello"),
        search_type: SearchType::Chunks,
        top_k: Some(3),
        datasets: None,
        dataset_ids: None,
        system_prompt: None,
        system_prompt_path: None,
        only_context: Some(false),
        use_combined_context: Some(false),
        session_id: None,
        node_type: None,
        node_name: None,
        node_name_filter_operator: None,
        wide_search_top_k: None,
        triplet_distance_penalty: None,
        save_interaction: None,
        user_id: None,
        verbose: None,
        feedback_influence: None,
        retriever_specific_config: None,
        response_schema: None,
        custom_search_type: None,
        auto_feedback_detection: None,
        neighborhood_depth: None,
        neighborhood_seed_top_k: None,
        summarize_context: None,
    };

    let reply = engine.search(&request).await.unwrap();
    let SearchOutput::Text(answer) = reply.result else {
        panic!("unexpected output kind");
    };
    assert_eq!(answer, "answer value");
    assert!(reply.context.is_none());
    assert!(reply.graphs.is_none());
}
/// A context-only request with combined context returns the retrieved items as the result
/// and labels the context with `CONTEXT_LABEL_COMBINED`.
#[tokio::test]
async fn routes_to_registered_retriever_for_context() {
    let mut retrievers = SearchTypeRegistry::new();
    retrievers.register(Arc::new(FakeChunksRetriever));
    let engine = super::SearchOrchestrator::new(retrievers);

    let request = SearchRequest {
        query_text: String::from("hello"),
        search_type: SearchType::Chunks,
        top_k: Some(3),
        datasets: None,
        dataset_ids: None,
        system_prompt: None,
        system_prompt_path: None,
        only_context: Some(true),
        use_combined_context: Some(true),
        session_id: None,
        node_type: None,
        node_name: None,
        node_name_filter_operator: None,
        wide_search_top_k: None,
        triplet_distance_penalty: None,
        save_interaction: None,
        user_id: None,
        verbose: None,
        feedback_influence: None,
        retriever_specific_config: None,
        response_schema: None,
        custom_search_type: None,
        auto_feedback_detection: None,
        neighborhood_depth: None,
        neighborhood_seed_top_k: None,
        summarize_context: None,
    };

    let reply = engine.search(&request).await.unwrap();
    assert!(reply.only_context);

    let SearchOutput::Items(items) = reply.result else {
        panic!("unexpected output kind");
    };
    assert_eq!(items.len(), 1);
    assert_eq!(items[0].payload["text"], "context value");

    let context = reply.context.expect("context should exist");
    assert!(context.contains_key(CONTEXT_LABEL_COMBINED));
    assert!(reply.graphs.is_none());
}
/// A context-only request without combined context labels the context with
/// `CONTEXT_LABEL_DEFAULT`.
#[tokio::test]
async fn routes_to_registered_retriever_for_default_context_label() {
    let mut retrievers = SearchTypeRegistry::new();
    retrievers.register(Arc::new(FakeChunksRetriever));
    let engine = super::SearchOrchestrator::new(retrievers);

    let request = SearchRequest {
        query_text: String::from("hello"),
        search_type: SearchType::Chunks,
        top_k: Some(3),
        datasets: None,
        dataset_ids: None,
        system_prompt: None,
        system_prompt_path: None,
        only_context: Some(true),
        use_combined_context: Some(false),
        session_id: None,
        node_type: None,
        node_name: None,
        node_name_filter_operator: None,
        wide_search_top_k: None,
        triplet_distance_penalty: None,
        save_interaction: None,
        user_id: None,
        verbose: None,
        feedback_influence: None,
        retriever_specific_config: None,
        response_schema: None,
        custom_search_type: None,
        auto_feedback_detection: None,
        neighborhood_depth: None,
        neighborhood_seed_top_k: None,
        summarize_context: None,
    };

    let labelled_context = engine
        .search(&request)
        .await
        .unwrap()
        .context
        .expect("context should exist");
    assert!(labelled_context.contains_key(CONTEXT_LABEL_DEFAULT));
}
/// When context made of a graph triplet is fetched, the response exposes a default graph
/// with the two endpoint nodes and the single edge.
#[tokio::test]
async fn includes_graph_when_context_is_fetched() {
    /// `GraphCompletion` stub returning one Alice-KNOWS-Bob triplet as context.
    struct FakeGraphRetriever;

    #[async_trait]
    impl SearchRetriever for FakeGraphRetriever {
        fn search_type(&self) -> SearchType {
            SearchType::GraphCompletion
        }

        async fn get_context(
            &self,
            _query: &str,
            _params: &SearchParams,
        ) -> Result<SearchContext, SearchError> {
            let triplet = crate::types::SearchItem {
                id: None,
                score: Some(0.9),
                payload: json!({
                    "source_id": "a",
                    "target_id": "b",
                    "source_name": "Alice",
                    "target_name": "Bob",
                    "relationship": "KNOWS"
                }),
            };
            Ok(vec![triplet])
        }

        async fn get_completion(
            &self,
            _query: &str,
            _context: Option<SearchContext>,
            _session: &SessionContext,
            _params: &SearchParams,
        ) -> Result<SearchOutput, SearchError> {
            Ok(SearchOutput::Text(String::from("graph answer")))
        }
    }

    let mut retrievers = SearchTypeRegistry::new();
    retrievers.register(Arc::new(FakeGraphRetriever));
    let engine = super::SearchOrchestrator::new(retrievers);

    let request = SearchRequest {
        query_text: String::from("hello"),
        search_type: SearchType::GraphCompletion,
        top_k: Some(3),
        datasets: None,
        dataset_ids: None,
        system_prompt: None,
        system_prompt_path: None,
        only_context: Some(true),
        use_combined_context: Some(false),
        session_id: None,
        node_type: None,
        node_name: None,
        node_name_filter_operator: None,
        wide_search_top_k: None,
        triplet_distance_penalty: None,
        save_interaction: None,
        user_id: None,
        verbose: None,
        feedback_influence: None,
        retriever_specific_config: None,
        response_schema: None,
        custom_search_type: None,
        auto_feedback_detection: None,
        neighborhood_depth: None,
        neighborhood_seed_top_k: None,
        summarize_context: None,
    };

    let reply = engine.search(&request).await.unwrap();
    let graphs = reply
        .graphs
        .expect("graphs should be present when context is fetched");
    let default_graph = graphs
        .get(CONTEXT_LABEL_DEFAULT)
        .expect("default graph should exist");
    assert_eq!(default_graph.nodes.len(), 2);
    assert_eq!(default_graph.edges.len(), 1);
}
/// With two dataset ids and no combined context, the response context is keyed per
/// dataset id and each dataset receives exactly its own item.
#[tokio::test]
async fn fans_out_context_by_dataset_when_dataset_scope_enabled() {
    /// `Chunks` stub returning one context item tagged with each of its two dataset ids.
    struct FakeDatasetRetriever {
        dataset_a: uuid::Uuid,
        dataset_b: uuid::Uuid,
    }

    #[async_trait]
    impl SearchRetriever for FakeDatasetRetriever {
        fn search_type(&self) -> SearchType {
            SearchType::Chunks
        }

        async fn get_context(
            &self,
            _query: &str,
            _params: &SearchParams,
        ) -> Result<SearchContext, SearchError> {
            let from_a = crate::types::SearchItem {
                id: None,
                score: Some(0.9),
                payload: json!({
                    "dataset_id": self.dataset_a.to_string(),
                    "text": "A context"
                }),
            };
            let from_b = crate::types::SearchItem {
                id: None,
                score: Some(0.8),
                payload: json!({
                    "dataset_id": self.dataset_b.to_string(),
                    "text": "B context"
                }),
            };
            Ok(vec![from_a, from_b])
        }

        async fn get_completion(
            &self,
            _query: &str,
            context: Option<SearchContext>,
            _session: &SessionContext,
            _params: &SearchParams,
        ) -> Result<SearchOutput, SearchError> {
            Ok(SearchOutput::Items(context.unwrap_or_default()))
        }
    }

    let dataset_a = uuid::Uuid::new_v4();
    let dataset_b = uuid::Uuid::new_v4();

    let mut retrievers = SearchTypeRegistry::new();
    retrievers.register(Arc::new(FakeDatasetRetriever {
        dataset_a,
        dataset_b,
    }));
    let engine = super::SearchOrchestrator::new(retrievers);

    let request = SearchRequest {
        query_text: String::from("hello"),
        search_type: SearchType::Chunks,
        top_k: Some(3),
        datasets: None,
        dataset_ids: Some(vec![dataset_a, dataset_b]),
        system_prompt: None,
        system_prompt_path: None,
        only_context: Some(true),
        use_combined_context: Some(false),
        session_id: None,
        node_type: None,
        node_name: None,
        node_name_filter_operator: None,
        wide_search_top_k: None,
        triplet_distance_penalty: None,
        save_interaction: None,
        user_id: None,
        verbose: None,
        feedback_influence: None,
        retriever_specific_config: None,
        response_schema: None,
        custom_search_type: None,
        auto_feedback_detection: None,
        neighborhood_depth: None,
        neighborhood_seed_top_k: None,
        summarize_context: None,
    };

    let reply = engine.search(&request).await.unwrap();
    let context_map = reply.context.expect("scoped context map must exist");
    assert_eq!(context_map[&dataset_a.to_string()].len(), 1);
    assert_eq!(context_map[&dataset_b.to_string()].len(), 1);
}
/// With combined context enabled, items from both datasets reach the retriever together
/// and the response context is labelled `CONTEXT_LABEL_COMBINED`.
#[tokio::test]
async fn merges_scoped_context_when_combined_context_enabled() {
    /// `Chunks` stub returning one item per dataset and echoing its context as the result.
    struct FakeDatasetRetriever {
        dataset_a: uuid::Uuid,
        dataset_b: uuid::Uuid,
    }

    #[async_trait]
    impl SearchRetriever for FakeDatasetRetriever {
        fn search_type(&self) -> SearchType {
            SearchType::Chunks
        }

        async fn get_context(
            &self,
            _query: &str,
            _params: &SearchParams,
        ) -> Result<SearchContext, SearchError> {
            let from_a = crate::types::SearchItem {
                id: None,
                score: Some(0.9),
                payload: json!({
                    "dataset_id": self.dataset_a.to_string(),
                    "text": "A context"
                }),
            };
            let from_b = crate::types::SearchItem {
                id: None,
                score: Some(0.8),
                payload: json!({
                    "dataset_id": self.dataset_b.to_string(),
                    "text": "B context"
                }),
            };
            Ok(vec![from_a, from_b])
        }

        async fn get_completion(
            &self,
            _query: &str,
            context: Option<SearchContext>,
            _session: &SessionContext,
            _params: &SearchParams,
        ) -> Result<SearchOutput, SearchError> {
            Ok(SearchOutput::Items(context.unwrap_or_default()))
        }
    }

    let dataset_a = uuid::Uuid::new_v4();
    let dataset_b = uuid::Uuid::new_v4();

    let mut retrievers = SearchTypeRegistry::new();
    retrievers.register(Arc::new(FakeDatasetRetriever {
        dataset_a,
        dataset_b,
    }));
    let engine = super::SearchOrchestrator::new(retrievers);

    let request = SearchRequest {
        query_text: String::from("hello"),
        search_type: SearchType::Chunks,
        top_k: Some(3),
        datasets: None,
        dataset_ids: Some(vec![dataset_a, dataset_b]),
        system_prompt: None,
        system_prompt_path: None,
        only_context: Some(false),
        use_combined_context: Some(true),
        session_id: None,
        node_type: None,
        node_name: None,
        node_name_filter_operator: None,
        wide_search_top_k: None,
        triplet_distance_penalty: None,
        save_interaction: None,
        user_id: None,
        verbose: Some(true),
        feedback_influence: None,
        retriever_specific_config: None,
        response_schema: None,
        custom_search_type: None,
        auto_feedback_detection: None,
        neighborhood_depth: None,
        neighborhood_seed_top_k: None,
        summarize_context: None,
    };

    let reply = engine.search(&request).await.unwrap();
    let SearchOutput::Items(items) = reply.result else {
        panic!("expected items output");
    };
    assert_eq!(items.len(), 2);

    let context = reply.context.expect("combined context must exist");
    assert!(context.contains_key(CONTEXT_LABEL_COMBINED));
}
#[tokio::test]
async fn persists_query_and_result_when_save_interaction_enabled() {
let mut registry = SearchTypeRegistry::new();
registry.register(Arc::new(FakeChunksRetriever));
let db = connect("sqlite::memory:").await.unwrap();
initialize(&db).await.unwrap();
let db = Arc::new(db);
let orchestrator = super::SearchOrchestrator::new(registry)
.with_database(db.clone() as Arc<dyn SearchHistoryDb>);
let request = SearchRequest {
query_text: "hello".to_string(),
search_type: SearchType::Chunks,
top_k: Some(3),
datasets: None,
dataset_ids: None,
system_prompt: None,
system_prompt_path: None,
only_context: Some(false),
use_combined_context: Some(false),
session_id: None,
node_type: None,
node_name: None,
node_name_filter_operator: None,
wide_search_top_k: None,
triplet_distance_penalty: None,
save_interaction: Some(true),
user_id: None,
verbose: None,
feedback_influence: None,
retriever_specific_config: None,
response_schema: None,
custom_search_type: None,
auto_feedback_detection: None,
neighborhood_depth: None,
neighborhood_seed_top_k: None,
summarize_context: None,
};
let _ = orchestrator.search(&request).await.unwrap();
let history = orchestrator.get_history(None, Some(10)).await.unwrap();
assert_eq!(history.len(), 2);
assert!(
history
.iter()
.any(|entry| entry.entry_type == SearchHistoryEntryType::Query)
);
assert!(
history
.iter()
.any(|entry| entry.entry_type == SearchHistoryEntryType::Result)
);
}
/// `search_batch` yields one response per request, each answered by the registered
/// retriever.
#[tokio::test]
async fn search_batch_returns_one_response_per_request() {
    let mut retrievers = SearchTypeRegistry::new();
    retrievers.register(Arc::new(FakeChunksRetriever));
    let engine = super::SearchOrchestrator::new(retrievers);

    // The two requests differ only in their query text.
    let completion_request = |text: &str| SearchRequest {
        query_text: text.to_string(),
        search_type: SearchType::Chunks,
        top_k: Some(3),
        datasets: None,
        dataset_ids: None,
        system_prompt: None,
        system_prompt_path: None,
        only_context: Some(false),
        use_combined_context: Some(false),
        session_id: None,
        node_type: None,
        node_name: None,
        node_name_filter_operator: None,
        wide_search_top_k: None,
        triplet_distance_penalty: None,
        save_interaction: None,
        user_id: None,
        verbose: None,
        feedback_influence: None,
        retriever_specific_config: None,
        response_schema: None,
        custom_search_type: None,
        auto_feedback_detection: None,
        neighborhood_depth: None,
        neighborhood_seed_top_k: None,
        summarize_context: None,
    };
    let requests = vec![completion_request("first"), completion_request("second")];

    let replies = engine.search_batch(&requests).await.unwrap();
    assert_eq!(replies.len(), 2);
    for reply in &replies {
        let SearchOutput::Text(answer) = &reply.result else {
            panic!("unexpected output kind");
        };
        assert_eq!(answer, "answer value");
    }
}
/// A request naming a community retriever through `custom_search_type` is served by that
/// retriever even though nothing is registered for its search type.
#[tokio::test]
async fn routes_to_community_retriever_by_name() {
    let engine = super::SearchOrchestrator::new(SearchTypeRegistry::new())
        .with_community_retriever("my_custom", Arc::new(FakeChunksRetriever));

    let request = SearchRequest {
        query_text: String::from("hello"),
        search_type: SearchType::Chunks,
        top_k: Some(3),
        datasets: None,
        dataset_ids: None,
        system_prompt: None,
        system_prompt_path: None,
        only_context: Some(false),
        use_combined_context: Some(false),
        session_id: None,
        node_type: None,
        node_name: None,
        node_name_filter_operator: None,
        wide_search_top_k: None,
        triplet_distance_penalty: None,
        save_interaction: None,
        user_id: None,
        verbose: None,
        feedback_influence: None,
        retriever_specific_config: None,
        response_schema: None,
        custom_search_type: Some(String::from("my_custom")),
        auto_feedback_detection: None,
        neighborhood_depth: None,
        neighborhood_seed_top_k: None,
        summarize_context: None,
    };

    let reply = engine.search(&request).await.unwrap();
    let SearchOutput::Text(answer) = reply.result else {
        panic!("unexpected output kind");
    };
    assert_eq!(answer, "answer value");
}
/// `Chunks` retriever stub for the dataset-name resolution tests: returns one context item
/// tagged with each of its two dataset ids and echoes the received context as the result.
struct ResolutionFixtureRetriever {
    dataset_a: uuid::Uuid,
    dataset_b: uuid::Uuid,
}

#[async_trait]
impl SearchRetriever for ResolutionFixtureRetriever {
    fn search_type(&self) -> SearchType {
        SearchType::Chunks
    }

    async fn get_context(
        &self,
        _query: &str,
        _params: &SearchParams,
    ) -> Result<SearchContext, SearchError> {
        let from_a = crate::types::SearchItem {
            id: None,
            score: Some(0.9),
            payload: json!({
                "dataset_id": self.dataset_a.to_string(),
                "text": "A context"
            }),
        };
        let from_b = crate::types::SearchItem {
            id: None,
            score: Some(0.8),
            payload: json!({
                "dataset_id": self.dataset_b.to_string(),
                "text": "B context"
            }),
        };
        Ok(vec![from_a, from_b])
    }

    async fn get_completion(
        &self,
        _query: &str,
        context: Option<SearchContext>,
        _session: &SessionContext,
        _params: &SearchParams,
    ) -> Result<SearchOutput, SearchError> {
        let echoed = context.unwrap_or_default();
        Ok(SearchOutput::Items(echoed))
    }
}
/// Baseline request for the dataset-resolution tests: a context-only `Chunks` search for
/// `hello` with interaction saving turned off and no datasets or user set. Tests override
/// the fields they care about.
fn dataset_request_template() -> SearchRequest {
    let template = SearchRequest {
        query_text: String::from("hello"),
        search_type: SearchType::Chunks,
        top_k: Some(3),
        datasets: None,
        dataset_ids: None,
        system_prompt: None,
        system_prompt_path: None,
        only_context: Some(true),
        use_combined_context: Some(false),
        session_id: None,
        node_type: None,
        node_name: None,
        node_name_filter_operator: None,
        wide_search_top_k: None,
        triplet_distance_penalty: None,
        save_interaction: Some(false),
        user_id: None,
        verbose: None,
        feedback_influence: None,
        retriever_specific_config: None,
        response_schema: None,
        custom_search_type: None,
        auto_feedback_detection: None,
        neighborhood_depth: None,
        neighborhood_seed_top_k: None,
        summarize_context: None,
    };
    template
}
/// Connects to `sqlite::memory:`, initializes it, and returns the shared connection.
async fn fresh_db() -> Arc<cognee_database::DatabaseConnection> {
    let connection = connect("sqlite::memory:").await.unwrap();
    initialize(&connection).await.unwrap();
    Arc::new(connection)
}
/// Inserts a dataset called `name` owned by `owner` into `db` and returns the stored row.
///
/// Panics with `seed dataset` if the insert fails.
async fn seed_dataset(
    db: &cognee_database::DatabaseConnection,
    name: &str,
    owner: Uuid,
) -> Dataset {
    let record = Dataset::new(name.to_string(), owner, None, Uuid::new_v4());
    db_ops::datasets::create_dataset(db, record)
        .await
        .expect("seed dataset")
}
/// A dataset name owned by the requesting user is resolved to its id, and the scoped
/// context contains that dataset but not the unrelated one.
#[tokio::test]
async fn resolves_dataset_names_to_ids_and_scopes_results() {
    let owner = Uuid::new_v4();
    let db = fresh_db().await;
    let dataset = seed_dataset(&db, "real", owner).await;
    let other = Uuid::new_v4();

    let mut retrievers = SearchTypeRegistry::new();
    retrievers.register(Arc::new(ResolutionFixtureRetriever {
        dataset_a: dataset.id,
        dataset_b: other,
    }));
    let engine =
        super::SearchOrchestrator::new(retrievers).with_dataset_resolver(db as Arc<dyn IngestDb>);

    let mut request = dataset_request_template();
    request.datasets = Some(vec![String::from("real")]);
    request.user_id = Some(owner);

    let reply = engine.search(&request).await.unwrap();
    let context_map = reply.context.expect("scoped context map");
    assert!(context_map.contains_key(&dataset.id.to_string()));
    assert!(!context_map.contains_key(&other.to_string()));
}
/// A dataset name that exists only for another owner does not resolve for the requesting
/// user, so the search fails with `DatasetNotFound`.
#[tokio::test]
async fn dataset_name_resolution_is_owner_scoped() {
    let owner_a = Uuid::new_v4();
    let owner_b = Uuid::new_v4();
    let db = fresh_db().await;
    let _ = seed_dataset(&db, "shared_name", owner_a).await;

    let mut retrievers = SearchTypeRegistry::new();
    retrievers.register(Arc::new(FakeChunksRetriever));
    let engine =
        super::SearchOrchestrator::new(retrievers).with_dataset_resolver(db as Arc<dyn IngestDb>);

    // Ask as owner B for a name that only owner A has.
    let mut request = dataset_request_template();
    request.datasets = Some(vec![String::from("shared_name")]);
    request.user_id = Some(owner_b);

    let err = engine.search(&request).await.expect_err("must error");
    assert!(
        matches!(err, SearchError::DatasetNotFound(_)),
        "got {err:?}"
    );
}
/// When none of the requested dataset names resolve, the search fails with
/// `DatasetNotFound` and the error lists every missing name.
#[tokio::test]
async fn errors_when_all_dataset_names_are_unknown() {
    let owner = Uuid::new_v4();
    let db = fresh_db().await;

    let mut retrievers = SearchTypeRegistry::new();
    retrievers.register(Arc::new(FakeChunksRetriever));
    let engine =
        super::SearchOrchestrator::new(retrievers).with_dataset_resolver(db as Arc<dyn IngestDb>);

    let mut request = dataset_request_template();
    request.datasets = Some(vec![
        String::from("does_not_exist"),
        String::from("also_missing"),
    ]);
    request.user_id = Some(owner);

    let err = engine.search(&request).await.expect_err("must error");
    let SearchError::DatasetNotFound(joined) = err else {
        panic!("expected DatasetNotFound, got {err:?}");
    };
    for expected_name in ["does_not_exist", "also_missing"] {
        assert!(
            joined.contains(expected_name),
            "missing names list: {joined:?}"
        );
    }
}
/// When only some of the requested dataset names exist, the unknown ones are dropped and the
/// search still succeeds, scoped to the datasets that did resolve.
///
/// The fixture retriever is registered with a second dataset id that was never requested;
/// asserting its absence verifies that the response is actually scoped to the resolved
/// dataset, mirroring `resolves_dataset_names_to_ids_and_scopes_results`.
#[tokio::test]
async fn partial_resolution_drops_unknown_names_and_succeeds() {
    let owner = Uuid::new_v4();
    let db = fresh_db().await;
    let dataset = seed_dataset(&db, "real", owner).await;
    // Id of a dataset the request never names; it must not leak into the scoped context.
    let other = Uuid::new_v4();
    let mut registry = SearchTypeRegistry::new();
    registry.register(Arc::new(ResolutionFixtureRetriever {
        dataset_a: dataset.id,
        dataset_b: other,
    }));
    let orchestrator =
        super::SearchOrchestrator::new(registry).with_dataset_resolver(db as Arc<dyn IngestDb>);
    let request = SearchRequest {
        // "real" was seeded above; "missing" does not exist for this owner.
        datasets: Some(vec!["real".into(), "missing".into()]),
        user_id: Some(owner),
        ..dataset_request_template()
    };
    let response = orchestrator
        .search(&request)
        .await
        .expect("partial resolution must succeed");
    let context = response.context.expect("scoped context");
    assert!(context.contains_key(&dataset.id.to_string()));
    assert!(
        !context.contains_key(&other.to_string()),
        "context must be scoped to the resolved dataset only"
    );
}
/// An empty `datasets` list must not cause an error: the search succeeds even though the
/// orchestrator has no dataset resolver and the request carries no user id.
#[tokio::test]
async fn empty_datasets_vec_behaves_like_none() {
    let orchestrator = {
        let mut registry = SearchTypeRegistry::new();
        registry.register(Arc::new(FakeChunksRetriever));
        super::SearchOrchestrator::new(registry)
    };

    let request = SearchRequest {
        only_context: Some(false),
        user_id: None,
        datasets: Some(Vec::new()),
        ..dataset_request_template()
    };

    // Only success matters here; the response content is not inspected.
    let outcome = orchestrator.search(&request).await;
    outcome.expect("empty datasets list must not error");
}
/// When explicit `dataset_ids` accompany `datasets` names, the search succeeds without any
/// resolver configured, even though the name is bogus, and the context is keyed by the
/// supplied id.
#[tokio::test]
async fn dataset_ids_take_precedence_over_names() {
    let explicit_id = Uuid::new_v4();

    let fixture = ResolutionFixtureRetriever {
        dataset_a: explicit_id,
        dataset_b: Uuid::new_v4(),
    };
    let mut registry = SearchTypeRegistry::new();
    registry.register(Arc::new(fixture));
    // Deliberately no dataset resolver: the bogus name must never need resolving.
    let orchestrator = super::SearchOrchestrator::new(registry);

    let request = SearchRequest {
        dataset_ids: Some(vec![explicit_id]),
        datasets: Some(vec!["bogus".into()]),
        ..dataset_request_template()
    };

    let outcome = orchestrator.search(&request).await;
    let response = outcome.expect("explicit dataset_ids must succeed without resolver");
    let scoped = response.context.expect("scoped context");
    assert!(scoped.contains_key(&explicit_id.to_string()));
}
/// Dataset names given to an orchestrator that has no dataset resolver configured are
/// rejected with `SearchError::InvalidInput`.
#[tokio::test]
async fn errors_when_dataset_names_supplied_without_resolver() {
    let orchestrator = {
        let mut registry = SearchTypeRegistry::new();
        registry.register(Arc::new(FakeChunksRetriever));
        super::SearchOrchestrator::new(registry)
    };

    let request = SearchRequest {
        user_id: Some(Uuid::new_v4()),
        datasets: Some(vec!["whatever".into()]),
        ..dataset_request_template()
    };

    match orchestrator.search(&request).await.expect_err("must error") {
        SearchError::InvalidInput(_) => {}
        err => panic!("got {err:?}"),
    }
}
/// Dataset names supplied without a `user_id` are rejected with `SearchError::InvalidInput`,
/// even when a dataset resolver is configured.
#[tokio::test]
async fn errors_when_dataset_names_supplied_without_user_id() {
    let resolver: Arc<dyn IngestDb> = fresh_db().await;

    let mut registry = SearchTypeRegistry::new();
    registry.register(Arc::new(FakeChunksRetriever));
    let orchestrator = super::SearchOrchestrator::new(registry).with_dataset_resolver(resolver);

    let request = SearchRequest {
        user_id: None,
        datasets: Some(vec!["whatever".into()]),
        ..dataset_request_template()
    };

    match orchestrator.search(&request).await.expect_err("must error") {
        SearchError::InvalidInput(_) => {}
        err => panic!("got {err:?}"),
    }
}
/// A request whose `custom_search_type` names a retriever that was never registered is
/// rejected with `SearchError::InvalidInput`.
///
/// The registry is left empty on purpose, so the name "nonexistent" cannot match anything.
#[tokio::test]
async fn returns_error_for_unknown_community_retriever_name() {
    let registry = SearchTypeRegistry::new();
    let orchestrator = super::SearchOrchestrator::new(registry);
    // Every field is spelled out so the test does not depend on any shared request template.
    let request = SearchRequest {
        query_text: "hello".to_string(),
        search_type: SearchType::Chunks,
        top_k: Some(3),
        datasets: None,
        dataset_ids: None,
        system_prompt: None,
        system_prompt_path: None,
        only_context: Some(false),
        use_combined_context: Some(false),
        session_id: None,
        node_type: None,
        node_name: None,
        node_name_filter_operator: None,
        wide_search_top_k: None,
        triplet_distance_penalty: None,
        save_interaction: None,
        user_id: None,
        verbose: None,
        feedback_influence: None,
        retriever_specific_config: None,
        response_schema: None,
        custom_search_type: Some("nonexistent".to_string()),
        auto_feedback_detection: None,
        neighborhood_depth: None,
        neighborhood_seed_top_k: None,
        summarize_context: None,
    };
    // `expect_err` both checks for failure and extracts the error, replacing the separate
    // `is_err` assertion plus `unwrap_err` and matching the other tests in this module.
    let err = orchestrator
        .search(&request)
        .await
        .expect_err("expected error for unknown community retriever");
    assert!(
        matches!(err, SearchError::InvalidInput(_)),
        "expected InvalidInput error, got: {err:?}"
    );
}
}