use cognee_search::observability::{
COGNEE_RECALL_SCOPE, COGNEE_RECALL_SOURCE, COGNEE_RESULT_COUNT, COGNEE_SEARCH_QUERY,
COGNEE_SESSION_ENTRY_COUNT, COGNEE_SESSION_ID,
};
use cognee_search::recall_scope::{fetch_graph_context, run_graph, search_session, search_trace};
use cognee_search::{SearchOrchestrator, SearchResponse, SearchType};
use cognee_session::{SessionManager, SessionStore};
use tracing::{field, info};
use super::error::ApiError;
pub use cognee_search::recall_scope::{
RecallItem, RecallOptions, RecallScope, RecallSource, ScopeInput, normalize_scope,
};
/// Outcome of a [`recall`] call: the merged items plus metadata produced by the
/// graph search, when a graph search was actually run.
#[derive(Debug, Clone)]
pub struct RecallResult {
    /// Items from every source that was queried, concatenated in the order the
    /// sources were queried.
    pub items: Vec<RecallItem>,
    /// Search type reported by the graph search; `None` when no graph search ran.
    pub search_type_used: Option<SearchType>,
    /// Whether the graph search reported that it auto-routed the search type;
    /// `false` when no graph search ran.
    pub auto_routed: bool,
    /// Response returned by the graph search; `None` when no graph search ran.
    pub search_response: Option<SearchResponse>,
}
#[allow(clippy::too_many_arguments)]
pub async fn recall(
query_text: &str,
query_type: Option<SearchType>,
datasets: Option<Vec<String>>,
top_k: usize,
auto_route: bool,
session_id: Option<&str>,
user_id: Option<&str>,
search_orchestrator: &SearchOrchestrator,
session_store: Option<&dyn SessionStore>,
session_manager: Option<&SessionManager>,
scope: Option<Vec<RecallScope>>,
options: Option<RecallOptions>,
) -> Result<RecallResult, ApiError> {
let normalized: Vec<RecallScope> = match scope {
None => vec![RecallScope::Auto],
Some(v) if v.is_empty() => vec![RecallScope::Auto],
Some(v) => v,
};
let auto_mode = normalized.as_slice() == [RecallScope::Auto];
let (sources, auto_fallthrough): (Vec<RecallScope>, bool) = if auto_mode {
match (session_id, datasets.as_ref(), query_type) {
(Some(_), None, None) => (vec![RecallScope::Session, RecallScope::Graph], true),
(Some(_), _, _) => (vec![RecallScope::Session, RecallScope::Graph], false),
(None, _, _) => (vec![RecallScope::Graph], false),
}
} else {
(normalized, false)
};
let span_scope: String = sources
.iter()
.filter_map(|s| s.as_source().map(|src| src.as_str()))
.collect::<Vec<_>>()
.join(",");
let query_preview: &str = {
let mut end = query_text.len();
if query_text.chars().count() > 500 {
let mut idx = 0usize;
for (count, (byte_idx, _)) in query_text.char_indices().enumerate() {
if count == 500 {
idx = byte_idx;
break;
}
}
if idx > 0 {
end = idx;
}
}
&query_text[..end]
};
let span = tracing::info_span!(
"cognee.api.recall",
{ COGNEE_SEARCH_QUERY } = query_preview,
{ COGNEE_RECALL_SCOPE } = span_scope.as_str(),
{ COGNEE_SESSION_ID } = session_id.unwrap_or(""),
"cognee.recall.top_k" = top_k,
{ cognee_search::observability::COGNEE_SEARCH_TYPE } = field::Empty,
{ COGNEE_RECALL_SOURCE } = field::Empty,
{ COGNEE_RESULT_COUNT } = field::Empty,
{ COGNEE_SESSION_ENTRY_COUNT } = field::Empty,
);
let _enter = span.enter();
let mut merged: Vec<RecallItem> = Vec::new();
let mut graph_search_type: Option<SearchType> = None;
let mut graph_auto_routed = false;
let mut graph_response: Option<SearchResponse> = None;
let mut session_result_count: usize = 0;
for src in &sources {
if auto_fallthrough && *src == RecallScope::Graph && !merged.is_empty() {
break;
}
let part: Vec<RecallItem> = match src {
RecallScope::Auto => continue, RecallScope::Session => {
search_session(query_text, session_id, user_id, top_k, session_store)
.await
.map_err(|e| ApiError::Search(e.to_string()))?
}
RecallScope::Trace => {
search_trace(query_text, session_id, user_id, top_k, session_manager)
.await
.map_err(|e| ApiError::Search(e.to_string()))?
}
RecallScope::GraphContext => fetch_graph_context(session_id, user_id, session_manager)
.await
.map_err(|e| ApiError::Search(e.to_string()))?,
RecallScope::Graph => {
let (items, used_type, was_auto, response) = run_graph(
query_text,
query_type,
datasets.clone(),
top_k,
auto_route,
session_id,
search_orchestrator,
&span,
options.as_ref(),
)
.await
.map_err(|e| ApiError::Search(e.to_string()))?;
graph_search_type = Some(used_type);
graph_auto_routed = was_auto;
graph_response = Some(response);
items
}
};
if *src == RecallScope::Session {
session_result_count = part.len();
}
merged.extend(part);
}
let source_label: &str = if sources.iter().filter(|s| s.as_source().is_some()).count() == 1 {
sources
.iter()
.find_map(|s| s.as_source())
.map(|s| s.as_str())
.unwrap_or("graph")
} else {
"multi"
};
span.record(COGNEE_RECALL_SOURCE, source_label);
span.record(COGNEE_RESULT_COUNT, merged.len());
if session_result_count > 0 {
span.record(COGNEE_SESSION_ENTRY_COUNT, session_result_count);
}
info!(
results = merged.len(),
sources = ?sources,
session_id = session_id.unwrap_or("-"),
"recall: completed"
);
#[cfg(feature = "telemetry")]
{
let search_type_label = graph_search_type
.or(query_type)
.map(|t| format!("{t:?}"))
.unwrap_or_default();
cognee_telemetry::send_telemetry(
"cognee.recall",
user_id.unwrap_or("sdk"),
Some(serde_json::json!({
"query_length": query_text.len(),
"scope": span_scope,
"auto_route": auto_route,
"top_k": top_k,
"search_type": search_type_label,
"session_id": session_id,
"datasets": datasets,
"dataset_ids": serde_json::Value::Null,
})),
);
}
Ok(RecallResult {
items: merged,
search_type_used: graph_search_type,
auto_routed: graph_auto_routed,
search_response: graph_response,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fields set explicitly on `RecallOptions` keep their values, and every field
    /// left to `Default` stays unset.
    #[test]
    fn recall_options_fields_are_set_correctly() {
        let expected_names = vec!["Alice".to_string()];
        let options = RecallOptions {
            node_name: Some(expected_names.clone()),
            triplet_distance_penalty: Some(6.5),
            ..Default::default()
        };

        // Explicitly provided values.
        assert_eq!(options.triplet_distance_penalty, Some(6.5));
        assert_eq!(options.node_name.as_deref(), Some(expected_names.as_slice()));

        // Values inherited from `Default`.
        assert!(options.system_prompt.is_none());
        assert!(options.wide_search_top_k.is_none());
        assert!(options.feedback_influence.is_none());
        assert!(options.neighborhood_depth.is_none());
        assert!(options.neighborhood_seed_top_k.is_none());
    }
}