1use std::collections::HashMap;
2use std::sync::Arc;
3
4use cognee_database::{IngestDb, SearchHistoryDb};
5use cognee_embedding::EmbeddingEngine;
6use cognee_graph::GraphDBTrait;
7use cognee_llm::Llm;
8use cognee_session::SessionManager;
9use cognee_vector::VectorDB;
10
11use crate::orchestration::{SearchOrchestrator, SearchTypeRegistry};
12use crate::retrievers::{
13 ChunksRetriever, CodingRulesRetriever, CompletionRetriever, CypherSearchRetriever,
14 FeedbackRetriever, FeelingLuckyRetriever, GraphCompletionContextExtensionRetriever,
15 GraphCompletionCotRetriever, GraphCompletionRetriever, GraphSummaryCompletionRetriever,
16 LexicalRetriever, NaturalLanguageRetriever, SearchRetrieverRef, SummariesRetriever,
17 TemporalRetriever, TripletRetriever,
18};
19use crate::types::SearchType;
20
21pub struct SearchBuilder {
22 retrievers: HashMap<SearchType, SearchRetrieverRef>,
23 database: Arc<dyn SearchHistoryDb>,
24 dataset_resolver: Option<Arc<dyn IngestDb>>,
25 session_manager: Option<Arc<SessionManager>>,
26}
27
28impl SearchBuilder {
29 pub fn new(
30 vector_db: Arc<dyn VectorDB>,
31 embedding_engine: Arc<dyn EmbeddingEngine>,
32 graph_db: Arc<dyn GraphDBTrait>,
33 llm: Arc<dyn Llm>,
34 database: Arc<dyn SearchHistoryDb>,
35 ) -> Self {
36 Self {
37 retrievers: HashMap::new(),
38 database,
39 dataset_resolver: None,
40 session_manager: None,
41 }
42 .register_standard_retrievers(vector_db, embedding_engine, graph_db, llm)
43 }
44
45 pub fn with_session_manager(mut self, session_manager: Arc<SessionManager>) -> Self {
46 self.session_manager = Some(session_manager);
47 self
48 }
49
50 pub fn with_dataset_resolver(mut self, resolver: Arc<dyn IngestDb>) -> Self {
54 self.dataset_resolver = Some(resolver);
55 self
56 }
57
58 pub fn register_retriever(mut self, retriever: SearchRetrieverRef) -> Self {
59 self.retrievers.insert(retriever.search_type(), retriever);
60 self
61 }
62
63 fn register_standard_retrievers(
64 mut self,
65 vector_db: Arc<dyn VectorDB>,
66 embedding_engine: Arc<dyn EmbeddingEngine>,
67 graph_db: Arc<dyn GraphDBTrait>,
68 llm: Arc<dyn Llm>,
69 ) -> Self {
70 self.retrievers.insert(
71 SearchType::Chunks,
72 Arc::new(ChunksRetriever::new(
73 Arc::clone(&vector_db),
74 Arc::clone(&embedding_engine),
75 None,
76 )),
77 );
78
79 self.retrievers.insert(
80 SearchType::Summaries,
81 Arc::new(SummariesRetriever::new(
82 Arc::clone(&vector_db),
83 Arc::clone(&embedding_engine),
84 None,
85 )),
86 );
87
88 self.retrievers.insert(
89 SearchType::RagCompletion,
90 Arc::new(CompletionRetriever::new(
91 Arc::clone(&vector_db),
92 Arc::clone(&embedding_engine),
93 Arc::clone(&llm),
94 None,
95 None,
96 None,
97 None,
98 None,
99 )),
100 );
101
102 self.retrievers.insert(
103 SearchType::TripletCompletion,
104 Arc::new(TripletRetriever::new(
105 Arc::clone(&vector_db),
106 Arc::clone(&embedding_engine),
107 Arc::clone(&llm),
108 None,
109 None,
110 None,
111 None,
112 None,
113 )),
114 );
115
116 self.retrievers.insert(
117 SearchType::GraphCompletion,
118 Arc::new(GraphCompletionRetriever::new(
119 Arc::clone(&vector_db),
120 Arc::clone(&embedding_engine),
121 Arc::clone(&graph_db),
122 Arc::clone(&llm),
123 None,
124 None,
125 None,
126 None,
127 None,
128 None,
129 None,
130 )),
131 );
132
133 self.retrievers.insert(
134 SearchType::GraphSummaryCompletion,
135 Arc::new(GraphSummaryCompletionRetriever::new(
136 Arc::clone(&vector_db),
137 Arc::clone(&embedding_engine),
138 Arc::clone(&graph_db),
139 Arc::clone(&llm),
140 None,
141 None,
142 None,
143 None,
144 None,
145 None,
146 None,
147 )),
148 );
149
150 self.retrievers.insert(
151 SearchType::GraphCompletionContextExtension,
152 Arc::new(GraphCompletionContextExtensionRetriever::new(
153 Arc::clone(&vector_db),
154 Arc::clone(&embedding_engine),
155 Arc::clone(&graph_db),
156 Arc::clone(&llm),
157 None,
158 None,
159 None,
160 None,
161 None,
162 None,
163 None,
164 None,
165 )),
166 );
167
168 self.retrievers.insert(
169 SearchType::GraphCompletionCot,
170 Arc::new(GraphCompletionCotRetriever::new(
171 Arc::clone(&vector_db),
172 Arc::clone(&embedding_engine),
173 Arc::clone(&graph_db),
174 Arc::clone(&llm),
175 None,
176 None,
177 None,
178 None,
179 None,
180 None,
181 None,
182 None,
183 )),
184 );
185
186 self.retrievers.insert(
187 SearchType::Cypher,
188 Arc::new(CypherSearchRetriever::new(Arc::clone(&graph_db))),
189 );
190
191 self.retrievers.insert(
192 SearchType::NaturalLanguage,
193 Arc::new(NaturalLanguageRetriever::new(
194 Arc::clone(&graph_db),
195 Arc::clone(&llm),
196 None,
197 None,
198 )),
199 );
200
201 self.retrievers.insert(
202 SearchType::Temporal,
203 Arc::new(TemporalRetriever::new(
204 Arc::clone(&vector_db),
205 Arc::clone(&embedding_engine),
206 Arc::clone(&graph_db),
207 Arc::clone(&llm),
208 None,
209 None,
210 None,
211 None,
212 None,
213 None,
214 None,
215 None,
216 )),
217 );
218
219 self.retrievers.insert(
220 SearchType::ChunksLexical,
221 Arc::new(LexicalRetriever::new(
222 Arc::clone(&graph_db),
223 None,
224 false,
225 None,
226 false,
227 )),
228 );
229
230 self.retrievers.insert(
231 SearchType::Feedback,
232 Arc::new(FeedbackRetriever::new(
233 Arc::clone(&graph_db),
234 Arc::clone(&llm),
235 None,
236 None,
237 )),
238 );
239
240 self.retrievers.insert(
241 SearchType::CodingRules,
242 Arc::new(CodingRulesRetriever::new(Arc::clone(&graph_db), None)),
243 );
244
245 let feeling_lucky_retrievers = self.retrievers.clone();
246 self.retrievers.insert(
247 SearchType::FeelingLucky,
248 Arc::new(FeelingLuckyRetriever::new(
249 llm,
250 feeling_lucky_retrievers,
251 Some(SearchType::RagCompletion),
252 None,
253 )),
254 );
255
256 self
257 }
258
259 pub fn build(self) -> SearchOrchestrator {
260 let mut registry = SearchTypeRegistry::new();
261 for retriever in self.retrievers.values() {
262 registry.register(Arc::clone(retriever));
263 }
264
265 let mut orchestrator = SearchOrchestrator::new(registry).with_database(self.database);
266 if let Some(resolver) = self.dataset_resolver {
267 orchestrator = orchestrator
270 .with_dataset_resolver(resolver)
271 .with_access_tracking();
272 }
273 if let Some(session_manager) = self.session_manager {
274 orchestrator = orchestrator.with_session_manager(session_manager);
275 }
276 orchestrator
277 }
278}
279
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use async_trait::async_trait;
    use cognee_database::{DatabaseError, SearchHistoryDb, SearchHistoryEntry};
    use cognee_embedding::EmbeddingResult;
    use cognee_embedding::engine::EmbeddingEngine;
    use cognee_graph::{EdgeData, GraphDBResult, GraphDBTrait, GraphNode, NodeData};
    use cognee_llm::{
        GenerationOptions, GenerationResponse, Llm, LlmError, LlmResult, Message, TokenUsage,
    };
    use cognee_vector::{SearchResult, VectorDB, VectorDBResult, VectorPoint};

    use serde_json::json;
    use std::borrow::Cow;
    use uuid::Uuid;

    use cognee_session::SessionContext;

    use super::SearchBuilder;
    use crate::retrievers::SearchRetriever;
    use crate::types::{
        SearchContext, SearchError, SearchOutput, SearchParams, SearchRequest, SearchType,
    };

    /// Embedding double.
    ///
    /// Always returns one fixed 2-dimensional vector, regardless of how many texts
    /// are passed in.
    struct TestEmbedding;

    #[async_trait]
    impl EmbeddingEngine for TestEmbedding {
        async fn embed(&self, _texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
            Ok(vec![vec![0.1, 0.2]])
        }

        fn dimension(&self) -> usize {
            2
        }

        fn batch_size(&self) -> usize {
            8
        }

        fn max_sequence_length(&self) -> usize {
            128
        }
    }

    /// Vector-store double: reports no collections and no search hits; writes are no-ops.
    struct TestVectorDb;

    #[async_trait]
    impl VectorDB for TestVectorDb {
        async fn create_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
            _dimension: usize,
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn has_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
        ) -> VectorDBResult<bool> {
            Ok(false)
        }

        async fn index_points(
            &self,
            _data_type: &str,
            _field_name: &str,
            _points: &[VectorPoint],
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn search_similar(
            &self,
            _data_type: &str,
            _field_name: &str,
            _query_vector: &[f32],
            _top_k: usize,
        ) -> VectorDBResult<Vec<SearchResult>> {
            Ok(vec![])
        }

        async fn delete_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn collection_size(
            &self,
            _data_type: &str,
            _field_name: &str,
        ) -> VectorDBResult<usize> {
            Ok(0)
        }
    }

    /// Graph double.
    ///
    /// Reports an empty graph: every lookup or query returns nothing, and every
    /// mutation is an `Ok(())` no-op.
    struct TestGraphDb;

    #[async_trait]
    impl GraphDBTrait for TestGraphDb {
        async fn initialize(&self) -> GraphDBResult<()> {
            Ok(())
        }

        async fn is_empty(&self) -> GraphDBResult<bool> {
            Ok(true)
        }

        async fn query(
            &self,
            _query: &str,
            _params: Option<HashMap<Cow<'static, str>, serde_json::Value>>,
        ) -> GraphDBResult<Vec<Vec<serde_json::Value>>> {
            Ok(vec![])
        }

        async fn delete_graph(&self) -> GraphDBResult<()> {
            Ok(())
        }

        async fn has_node(&self, _node_id: &str) -> GraphDBResult<bool> {
            Ok(false)
        }

        async fn add_node_raw(&self, _node: serde_json::Value) -> GraphDBResult<()> {
            Ok(())
        }

        async fn add_nodes_raw(&self, _nodes: Vec<serde_json::Value>) -> GraphDBResult<()> {
            Ok(())
        }

        async fn delete_node(&self, _node_id: &str) -> GraphDBResult<()> {
            Ok(())
        }

        async fn delete_nodes(&self, _node_ids: &[String]) -> GraphDBResult<()> {
            Ok(())
        }

        async fn get_node(&self, _node_id: &str) -> GraphDBResult<Option<NodeData>> {
            Ok(None)
        }

        async fn get_nodes(&self, _node_ids: &[String]) -> GraphDBResult<Vec<NodeData>> {
            Ok(vec![])
        }

        async fn has_edge(
            &self,
            _source_id: &str,
            _target_id: &str,
            _relationship_name: &str,
        ) -> GraphDBResult<bool> {
            Ok(false)
        }

        async fn has_edges(&self, _edges: &[EdgeData]) -> GraphDBResult<Vec<EdgeData>> {
            Ok(vec![])
        }

        async fn add_edge(
            &self,
            _source_id: &str,
            _target_id: &str,
            _relationship_name: &str,
            _properties: Option<HashMap<Cow<'static, str>, serde_json::Value>>,
        ) -> GraphDBResult<()> {
            Ok(())
        }

        async fn add_edges(&self, _edges: &[EdgeData]) -> GraphDBResult<()> {
            Ok(())
        }

        async fn get_edges(&self, _node_id: &str) -> GraphDBResult<Vec<EdgeData>> {
            Ok(vec![])
        }

        async fn get_neighbors(&self, _node_id: &str) -> GraphDBResult<Vec<NodeData>> {
            Ok(vec![])
        }

        async fn get_connections(
            &self,
            _node_id: &str,
        ) -> GraphDBResult<
            Vec<(
                NodeData,
                HashMap<Cow<'static, str>, serde_json::Value>,
                NodeData,
            )>,
        > {
            Ok(vec![])
        }

        async fn get_graph_data(&self) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
            Ok((vec![], vec![]))
        }

        async fn get_graph_metrics(
            &self,
            _include_optional: bool,
        ) -> GraphDBResult<HashMap<Cow<'static, str>, serde_json::Value>> {
            Ok(HashMap::new())
        }

        async fn get_filtered_graph_data(
            &self,
            _attribute_filters: &HashMap<Cow<'static, str>, Vec<serde_json::Value>>,
        ) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
            Ok((vec![], vec![]))
        }

        async fn get_nodeset_subgraph(
            &self,
            _node_type: &str,
            _node_names: &[String],
            _node_name_filter_operator: &str,
        ) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
            Ok((vec![], vec![]))
        }
    }

    /// LLM double.
    ///
    /// `generate` always answers "ok"; structured output always fails with
    /// `LlmError::ConfigError`.
    struct TestLlm;

    #[async_trait]
    impl Llm for TestLlm {
        async fn generate(
            &self,
            _messages: Vec<Message>,
            _options: Option<GenerationOptions>,
        ) -> LlmResult<GenerationResponse> {
            Ok(GenerationResponse {
                content: "ok".to_string(),
                model: "test".to_string(),
                usage: Some(TokenUsage {
                    prompt_tokens: 1,
                    completion_tokens: 1,
                    total_tokens: 2,
                }),
                finish_reason: Some("stop".to_string()),
            })
        }

        async fn create_structured_output_with_messages_raw(
            &self,
            _messages: Vec<Message>,
            _json_schema: &serde_json::Value,
            _options: Option<GenerationOptions>,
        ) -> LlmResult<serde_json::Value> {
            Err(LlmError::ConfigError("not used in this test".to_string()))
        }

        fn model(&self) -> &str {
            "test"
        }
    }

    /// Search-history double.
    ///
    /// Logging succeeds and returns fresh random UUIDs; the history is always empty.
    struct TestDatabase;

    #[async_trait]
    impl SearchHistoryDb for TestDatabase {
        async fn log_query(
            &self,
            _query_text: &str,
            _query_type: &str,
            _user_id: Option<Uuid>,
        ) -> Result<Uuid, DatabaseError> {
            Ok(Uuid::new_v4())
        }

        async fn log_result(
            &self,
            _query_id: Uuid,
            _serialized_result: &str,
            _user_id: Option<Uuid>,
        ) -> Result<Uuid, DatabaseError> {
            Ok(Uuid::new_v4())
        }

        async fn get_history(
            &self,
            _user_id: Option<Uuid>,
            _limit: Option<usize>,
        ) -> Result<Vec<SearchHistoryEntry>, DatabaseError> {
            Ok(vec![])
        }
    }

    /// Retriever double that reports `SearchType::Chunks`.
    ///
    /// When registered it replaces the standard chunks retriever. Its completion
    /// returns the marker text "builder-executed" so a test can tell it ran.
    struct FakeRetriever;

    #[async_trait]
    impl SearchRetriever for FakeRetriever {
        fn search_type(&self) -> SearchType {
            SearchType::Chunks
        }

        async fn get_context(
            &self,
            _query: &str,
            _params: &SearchParams,
        ) -> Result<SearchContext, SearchError> {
            Ok(vec![])
        }

        async fn get_completion(
            &self,
            _query: &str,
            _context: Option<SearchContext>,
            _session: &SessionContext,
            _params: &SearchParams,
        ) -> Result<SearchOutput, SearchError> {
            Ok(SearchOutput::Text("builder-executed".to_string()))
        }
    }

    /// Creates a `SearchBuilder` wired to the inert test doubles above.
    fn test_builder() -> SearchBuilder {
        SearchBuilder::new(
            Arc::new(TestVectorDb),
            Arc::new(TestEmbedding),
            Arc::new(TestGraphDb),
            Arc::new(TestLlm),
            Arc::new(TestDatabase),
        )
    }

    /// Creates a request for the query "hello" with `top_k = 3` against `search_type`.
    ///
    /// Only `only_context` varies between tests. `use_combined_context` is
    /// `Some(false)`, and every other optional field is `None`.
    fn test_request(search_type: SearchType, only_context: bool) -> SearchRequest {
        SearchRequest {
            query_text: "hello".to_string(),
            search_type,
            top_k: Some(3),
            datasets: None,
            dataset_ids: None,
            system_prompt: None,
            system_prompt_path: None,
            only_context: Some(only_context),
            use_combined_context: Some(false),
            session_id: None,
            node_type: None,
            node_name: None,
            node_name_filter_operator: None,
            wide_search_top_k: None,
            triplet_distance_penalty: None,
            save_interaction: None,
            user_id: None,
            verbose: None,
            feedback_influence: None,
            retriever_specific_config: None,
            response_schema: None,
            custom_search_type: None,
            auto_feedback_detection: None,
            neighborhood_depth: None,
            neighborhood_seed_top_k: None,
            summarize_context: None,
        }
    }

    #[tokio::test]
    async fn executes_search_via_builder_entrypoint() {
        // FakeRetriever overrides the standard Chunks retriever registered by `new`.
        let orchestrator = test_builder()
            .register_retriever(Arc::new(FakeRetriever))
            .build();

        let request = test_request(SearchType::Chunks, false);
        let response = orchestrator.search(&request).await.unwrap();

        match response.result {
            SearchOutput::Text(text) => assert_eq!(text, "builder-executed"),
            _ => panic!("expected text result"),
        }
    }

    #[tokio::test]
    async fn supports_context_only_execution_through_entrypoint() {
        /// Summaries retriever whose context is a single canned item; its completion
        /// should never be used for a context-only request.
        struct ContextRetriever;

        #[async_trait]
        impl SearchRetriever for ContextRetriever {
            fn search_type(&self) -> SearchType {
                SearchType::Summaries
            }

            async fn get_context(
                &self,
                _query: &str,
                _params: &SearchParams,
            ) -> Result<SearchContext, SearchError> {
                Ok(vec![crate::types::SearchItem {
                    id: None,
                    score: Some(0.9),
                    payload: json!({ "text": "summary item" }),
                }])
            }

            async fn get_completion(
                &self,
                _query: &str,
                _context: Option<SearchContext>,
                _session: &SessionContext,
                _params: &SearchParams,
            ) -> Result<SearchOutput, SearchError> {
                Ok(SearchOutput::Text("unused".to_string()))
            }
        }

        let orchestrator = test_builder()
            .register_retriever(Arc::new(ContextRetriever))
            .build();

        // `only_context = true`: the orchestrator is expected to return the raw items.
        let request = test_request(SearchType::Summaries, true);
        let response = orchestrator.search(&request).await.unwrap();

        match response.result {
            SearchOutput::Items(items) => {
                assert_eq!(items.len(), 1);
                assert_eq!(items[0].payload["text"], "summary item");
            }
            _ => panic!("expected items result"),
        }
    }
}