#![allow(
clippy::unwrap_used,
clippy::expect_used,
reason = "test code — panics are acceptable failures"
)]
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use cognee_database::{DatabaseError, SearchHistoryDb, SearchHistoryEntry};
use cognee_session::{FsSessionStore, SessionContext, SessionManager, SessionStore};
use uuid::Uuid;
use cognee_search::retrievers::SearchRetriever;
use cognee_search::types::{
SearchContext, SearchError, SearchOutput, SearchParams, SearchRequest, SearchType,
};
use cognee_search::{SearchOrchestrator, SearchTypeRegistry};
/// Test double for the temporal retriever: it records what it is called with
/// instead of doing any real retrieval.
struct FakeTemporalRetriever {
/// Clone of every `SessionContext` passed to `get_completion`, in call order.
captured_sessions: Arc<Mutex<Vec<SessionContext>>>,
/// Number of `get_completion` calls made so far.
call_count: Arc<Mutex<u32>>,
}
impl FakeTemporalRetriever {
fn new(
captured_sessions: Arc<Mutex<Vec<SessionContext>>>,
call_count: Arc<Mutex<u32>>,
) -> Self {
Self {
captured_sessions,
call_count,
}
}
}
#[async_trait]
impl SearchRetriever for FakeTemporalRetriever {
    /// This fake always registers itself as the temporal retriever.
    fn search_type(&self) -> SearchType {
        SearchType::Temporal
    }

    /// Returns an empty context; both arguments are ignored.
    async fn get_context(
        &self,
        _query: &str,
        _params: &SearchParams,
    ) -> Result<SearchContext, SearchError> {
        Ok(Vec::new())
    }

    /// Records the session it was called with, bumps the call counter, and
    /// answers with `"temporal answer #<n>"`, where `<n>` is the counter value
    /// after the increment.
    ///
    /// # Panics
    ///
    /// Panics if either mutex is poisoned.
    async fn get_completion(
        &self,
        _query: &str,
        _context: Option<SearchContext>,
        session: &SessionContext,
        _params: &SearchParams,
    ) -> Result<SearchOutput, SearchError> {
        // Keep a copy of the session so the test can inspect it afterwards.
        {
            let mut recorded = self.captured_sessions.lock().unwrap();
            recorded.push(session.clone());
        }

        // Increment under the lock and copy the new value out for the answer.
        let ordinal = {
            let mut calls = self.call_count.lock().unwrap();
            *calls += 1;
            *calls
        };

        let answer = format!("temporal answer #{}", ordinal);
        Ok(SearchOutput::Text(answer))
    }
}
/// Search-history database stub: its `SearchHistoryDb` methods ignore their
/// arguments and always return `Ok`.
struct NoOpHistoryDb;
#[async_trait]
impl SearchHistoryDb for NoOpHistoryDb {
    /// Ignores the query and returns a freshly generated v4 UUID.
    async fn log_query(
        &self,
        _query_text: &str,
        _query_type: &str,
        _user_id: Option<Uuid>,
    ) -> Result<Uuid, DatabaseError> {
        let query_id = Uuid::new_v4();
        Ok(query_id)
    }

    /// Ignores the result and returns a freshly generated v4 UUID.
    async fn log_result(
        &self,
        _query_id: Uuid,
        _serialized_result: &str,
        _user_id: Option<Uuid>,
    ) -> Result<Uuid, DatabaseError> {
        let result_id = Uuid::new_v4();
        Ok(result_id)
    }

    /// Always reports an empty history, whatever the user or limit.
    async fn get_history(
        &self,
        _user_id: Option<Uuid>,
        _limit: Option<usize>,
    ) -> Result<Vec<SearchHistoryEntry>, DatabaseError> {
        Ok(Vec::new())
    }
}
/// Builds the `SearchRequest` used by the tests in this file.
///
/// The request is a `SearchType::Temporal` search for `query`, optionally
/// bound to `session_id`. `only_context` and `use_combined_context` are
/// `Some(false)`, `save_interaction` is `Some(true)`, and every other field
/// is `None`.
fn temporal_request(query: &str, session_id: Option<&str>) -> SearchRequest {
    let query_text = query.to_owned();
    let session_id = session_id.map(str::to_owned);

    SearchRequest {
        // What to search for, how, and in which session.
        query_text,
        search_type: SearchType::Temporal,
        session_id,

        // Flags these tests set explicitly.
        only_context: Some(false),
        use_combined_context: Some(false),
        save_interaction: Some(true),

        // Everything else is left unset.
        top_k: None,
        datasets: None,
        dataset_ids: None,
        system_prompt: None,
        system_prompt_path: None,
        node_type: None,
        node_name: None,
        node_name_filter_operator: None,
        wide_search_top_k: None,
        triplet_distance_penalty: None,
        user_id: None,
        verbose: None,
        feedback_influence: None,
        retriever_specific_config: None,
        response_schema: None,
        custom_search_type: None,
        auto_feedback_detection: None,
        neighborhood_depth: None,
        neighborhood_seed_top_k: None,
        summarize_context: None,
    }
}
/// A single temporal search must hand the retriever the requested session id
/// with no prior history, and must leave exactly one QA entry in the session
/// store.
#[tokio::test]
async fn temporal_search_passes_session_context() {
    let session_id = "test-temporal-session-1";
    let question = "What events happened in 2024?";

    // Session storage lives in a throwaway directory.
    let temp_dir = tempfile::tempdir().expect("tempdir creation must succeed");
    let sessions_root = temp_dir.path().join("sessions");
    let session_store = Arc::new(FsSessionStore::new(sessions_root));
    let session_manager = Arc::new(SessionManager::new(session_store.clone()));

    // Shared handles through which the fake retriever reports what it saw.
    let captured_sessions = Arc::new(Mutex::new(Vec::<SessionContext>::new()));
    let call_count = Arc::new(Mutex::new(0_u32));
    let retriever = Arc::new(FakeTemporalRetriever::new(
        Arc::clone(&captured_sessions),
        Arc::clone(&call_count),
    ));

    let mut registry = SearchTypeRegistry::new();
    registry.register(retriever);
    let orchestrator = SearchOrchestrator::new(registry)
        .with_database(Arc::new(NoOpHistoryDb))
        .with_session_manager(session_manager.clone());

    let request = temporal_request(question, Some(session_id));
    let response = orchestrator
        .search(&request)
        .await
        .expect("temporal search with session should succeed");

    // The orchestrator should surface the fake's text answer.
    let text = match &response.result {
        SearchOutput::Text(body) => body,
        other => panic!("expected Text output, got: {other:?}"),
    };
    assert!(
        text.contains("temporal answer"),
        "expected temporal answer, got: {text}"
    );

    // Scoped so the mutex guard is released before the next `.await`.
    {
        let recorded = captured_sessions.lock().unwrap();
        assert_eq!(recorded.len(), 1, "retriever should be called exactly once");

        let first_call = &recorded[0];
        assert_eq!(
            first_call.session_id.as_deref(),
            Some(session_id),
            "session_id should be forwarded to the retriever"
        );
        assert!(
            first_call.history.is_empty(),
            "first search in a session should have empty history"
        );
        assert!(
            first_call.formatted_history.is_empty(),
            "first search in a session should have empty formatted history"
        );
    }

    // The exchange must have been written to the session store.
    let entries = session_store
        .get_all_qa_entries(session_id, None)
        .await
        .expect("reading session entries should succeed");
    assert_eq!(
        entries.len(),
        1,
        "one QA entry should be stored after a single search"
    );

    let stored = &entries[0];
    assert_eq!(stored.question, question);
    assert!(stored.answer.contains("temporal answer #1"));
}
/// Two temporal searches in the same session: the second retriever call must
/// receive the first exchange as history (both structured and formatted), and
/// the session store must end up with both QA entries in order.
#[tokio::test]
async fn temporal_search_multiple_queries_in_session() {
    let session_id = "test-temporal-session-multi";
    let first_question = "What happened in World War II?";
    let second_question = "What happened after the war ended?";

    // Session storage lives in a throwaway directory.
    let temp_dir = tempfile::tempdir().expect("tempdir creation must succeed");
    let sessions_root = temp_dir.path().join("sessions");
    let session_store = Arc::new(FsSessionStore::new(sessions_root));
    let session_manager = Arc::new(SessionManager::new(session_store.clone()));

    // Shared handles through which the fake retriever reports what it saw.
    let captured_sessions = Arc::new(Mutex::new(Vec::<SessionContext>::new()));
    let call_count = Arc::new(Mutex::new(0_u32));
    let retriever = Arc::new(FakeTemporalRetriever::new(
        Arc::clone(&captured_sessions),
        Arc::clone(&call_count),
    ));

    let mut registry = SearchTypeRegistry::new();
    registry.register(retriever);
    let orchestrator = SearchOrchestrator::new(registry)
        .with_database(Arc::new(NoOpHistoryDb))
        .with_session_manager(session_manager.clone());

    // First exchange.
    let first_request = temporal_request(first_question, Some(session_id));
    let response1 = orchestrator
        .search(&first_request)
        .await
        .expect("first temporal search should succeed");
    let text = match &response1.result {
        SearchOutput::Text(body) => body,
        other => panic!("expected Text output for first query, got: {other:?}"),
    };
    assert!(
        text.contains("temporal answer #1"),
        "first response: {text}"
    );

    // Second exchange, same session.
    let second_request = temporal_request(second_question, Some(session_id));
    let response2 = orchestrator
        .search(&second_request)
        .await
        .expect("second temporal search should succeed");
    let text = match &response2.result {
        SearchOutput::Text(body) => body,
        other => panic!("expected Text output for second query, got: {other:?}"),
    };
    assert!(
        text.contains("temporal answer #2"),
        "second response: {text}"
    );

    // Scoped so the mutex guard is released before the next `.await`.
    {
        let recorded = captured_sessions.lock().unwrap();
        assert_eq!(recorded.len(), 2, "retriever should be called twice");

        let first_call = &recorded[0];
        let second_call = &recorded[1];

        assert!(
            first_call.history.is_empty(),
            "first call should have empty history"
        );

        // Structured history seen by the second call.
        assert!(
            !second_call.history.is_empty(),
            "second call should have non-empty history from the first exchange"
        );
        assert_eq!(
            second_call.history.len(),
            2,
            "second call should have 2 history messages (user + assistant from first exchange)"
        );
        assert_eq!(
            second_call.history[0].content, first_question,
            "first history message should be the first question"
        );
        assert!(
            second_call.history[1].content.contains("temporal answer #1"),
            "second history message should be the first answer"
        );

        // Pre-rendered history seen by the second call.
        assert!(
            !second_call.formatted_history.is_empty(),
            "second call should have non-empty formatted_history"
        );
        assert!(
            second_call.formatted_history.contains(first_question),
            "formatted_history should contain the first question"
        );
    }

    // Both exchanges must have been written to the session store, in order.
    let entries = session_store
        .get_all_qa_entries(session_id, None)
        .await
        .expect("reading session entries should succeed");
    assert_eq!(
        entries.len(),
        2,
        "two QA entries should be stored after two searches"
    );
    assert_eq!(entries[0].question, first_question);
    assert_eq!(entries[1].question, second_question);
}