Skip to main content

cognee_search/orchestration/
search_execution_builder.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use cognee_database::{IngestDb, SearchHistoryDb};
5use cognee_embedding::EmbeddingEngine;
6use cognee_graph::GraphDBTrait;
7use cognee_llm::Llm;
8use cognee_session::SessionManager;
9use cognee_vector::VectorDB;
10
11use crate::orchestration::{SearchOrchestrator, SearchTypeRegistry};
12use crate::retrievers::{
13    ChunksRetriever, CodingRulesRetriever, CompletionRetriever, CypherSearchRetriever,
14    FeedbackRetriever, FeelingLuckyRetriever, GraphCompletionContextExtensionRetriever,
15    GraphCompletionCotRetriever, GraphCompletionRetriever, GraphSummaryCompletionRetriever,
16    LexicalRetriever, NaturalLanguageRetriever, SearchRetrieverRef, SummariesRetriever,
17    TemporalRetriever, TripletRetriever,
18};
19use crate::types::SearchType;
20
21pub struct SearchBuilder {
22    retrievers: HashMap<SearchType, SearchRetrieverRef>,
23    database: Arc<dyn SearchHistoryDb>,
24    dataset_resolver: Option<Arc<dyn IngestDb>>,
25    session_manager: Option<Arc<SessionManager>>,
26}
27
28impl SearchBuilder {
29    pub fn new(
30        vector_db: Arc<dyn VectorDB>,
31        embedding_engine: Arc<dyn EmbeddingEngine>,
32        graph_db: Arc<dyn GraphDBTrait>,
33        llm: Arc<dyn Llm>,
34        database: Arc<dyn SearchHistoryDb>,
35    ) -> Self {
36        Self {
37            retrievers: HashMap::new(),
38            database,
39            dataset_resolver: None,
40            session_manager: None,
41        }
42        .register_standard_retrievers(vector_db, embedding_engine, graph_db, llm)
43    }
44
45    pub fn with_session_manager(mut self, session_manager: Arc<SessionManager>) -> Self {
46        self.session_manager = Some(session_manager);
47        self
48    }
49
50    /// Wire in an `IngestDb`-backed resolver so dataset name strings can be
51    /// translated to UUIDs when `SearchRequest.datasets` is set. Without one,
52    /// requests carrying `datasets` fail with `SearchError::InvalidInput`.
53    pub fn with_dataset_resolver(mut self, resolver: Arc<dyn IngestDb>) -> Self {
54        self.dataset_resolver = Some(resolver);
55        self
56    }
57
58    pub fn register_retriever(mut self, retriever: SearchRetrieverRef) -> Self {
59        self.retrievers.insert(retriever.search_type(), retriever);
60        self
61    }
62
63    fn register_standard_retrievers(
64        mut self,
65        vector_db: Arc<dyn VectorDB>,
66        embedding_engine: Arc<dyn EmbeddingEngine>,
67        graph_db: Arc<dyn GraphDBTrait>,
68        llm: Arc<dyn Llm>,
69    ) -> Self {
70        self.retrievers.insert(
71            SearchType::Chunks,
72            Arc::new(ChunksRetriever::new(
73                Arc::clone(&vector_db),
74                Arc::clone(&embedding_engine),
75                None,
76            )),
77        );
78
79        self.retrievers.insert(
80            SearchType::Summaries,
81            Arc::new(SummariesRetriever::new(
82                Arc::clone(&vector_db),
83                Arc::clone(&embedding_engine),
84                None,
85            )),
86        );
87
88        self.retrievers.insert(
89            SearchType::RagCompletion,
90            Arc::new(CompletionRetriever::new(
91                Arc::clone(&vector_db),
92                Arc::clone(&embedding_engine),
93                Arc::clone(&llm),
94                None,
95                None,
96                None,
97                None,
98                None,
99            )),
100        );
101
102        self.retrievers.insert(
103            SearchType::TripletCompletion,
104            Arc::new(TripletRetriever::new(
105                Arc::clone(&vector_db),
106                Arc::clone(&embedding_engine),
107                Arc::clone(&llm),
108                None,
109                None,
110                None,
111                None,
112                None,
113            )),
114        );
115
116        self.retrievers.insert(
117            SearchType::GraphCompletion,
118            Arc::new(GraphCompletionRetriever::new(
119                Arc::clone(&vector_db),
120                Arc::clone(&embedding_engine),
121                Arc::clone(&graph_db),
122                Arc::clone(&llm),
123                None,
124                None,
125                None,
126                None,
127                None,
128                None,
129                None,
130            )),
131        );
132
133        self.retrievers.insert(
134            SearchType::GraphSummaryCompletion,
135            Arc::new(GraphSummaryCompletionRetriever::new(
136                Arc::clone(&vector_db),
137                Arc::clone(&embedding_engine),
138                Arc::clone(&graph_db),
139                Arc::clone(&llm),
140                None,
141                None,
142                None,
143                None,
144                None,
145                None,
146                None,
147            )),
148        );
149
150        self.retrievers.insert(
151            SearchType::GraphCompletionContextExtension,
152            Arc::new(GraphCompletionContextExtensionRetriever::new(
153                Arc::clone(&vector_db),
154                Arc::clone(&embedding_engine),
155                Arc::clone(&graph_db),
156                Arc::clone(&llm),
157                None,
158                None,
159                None,
160                None,
161                None,
162                None,
163                None,
164                None,
165            )),
166        );
167
168        self.retrievers.insert(
169            SearchType::GraphCompletionCot,
170            Arc::new(GraphCompletionCotRetriever::new(
171                Arc::clone(&vector_db),
172                Arc::clone(&embedding_engine),
173                Arc::clone(&graph_db),
174                Arc::clone(&llm),
175                None,
176                None,
177                None,
178                None,
179                None,
180                None,
181                None,
182                None,
183            )),
184        );
185
186        self.retrievers.insert(
187            SearchType::Cypher,
188            Arc::new(CypherSearchRetriever::new(Arc::clone(&graph_db))),
189        );
190
191        self.retrievers.insert(
192            SearchType::NaturalLanguage,
193            Arc::new(NaturalLanguageRetriever::new(
194                Arc::clone(&graph_db),
195                Arc::clone(&llm),
196                None,
197                None,
198            )),
199        );
200
201        self.retrievers.insert(
202            SearchType::Temporal,
203            Arc::new(TemporalRetriever::new(
204                Arc::clone(&vector_db),
205                Arc::clone(&embedding_engine),
206                Arc::clone(&graph_db),
207                Arc::clone(&llm),
208                None,
209                None,
210                None,
211                None,
212                None,
213                None,
214                None,
215                None,
216            )),
217        );
218
219        self.retrievers.insert(
220            SearchType::ChunksLexical,
221            Arc::new(LexicalRetriever::new(
222                Arc::clone(&graph_db),
223                None,
224                false,
225                None,
226                false,
227            )),
228        );
229
230        self.retrievers.insert(
231            SearchType::Feedback,
232            Arc::new(FeedbackRetriever::new(
233                Arc::clone(&graph_db),
234                Arc::clone(&llm),
235                None,
236                None,
237            )),
238        );
239
240        self.retrievers.insert(
241            SearchType::CodingRules,
242            Arc::new(CodingRulesRetriever::new(Arc::clone(&graph_db), None)),
243        );
244
245        let feeling_lucky_retrievers = self.retrievers.clone();
246        self.retrievers.insert(
247            SearchType::FeelingLucky,
248            Arc::new(FeelingLuckyRetriever::new(
249                llm,
250                feeling_lucky_retrievers,
251                Some(SearchType::RagCompletion),
252                None,
253            )),
254        );
255
256        self
257    }
258
259    pub fn build(self) -> SearchOrchestrator {
260        let mut registry = SearchTypeRegistry::new();
261        for retriever in self.retrievers.values() {
262            registry.register(Arc::clone(retriever));
263        }
264
265        let mut orchestrator = SearchOrchestrator::new(registry).with_database(self.database);
266        if let Some(resolver) = self.dataset_resolver {
267            // Auto-enable access tracking when an IngestDb resolver is available,
268            // so last_accessed timestamps on Data records are updated on every search.
269            orchestrator = orchestrator
270                .with_dataset_resolver(resolver)
271                .with_access_tracking();
272        }
273        if let Some(session_manager) = self.session_manager {
274            orchestrator = orchestrator.with_session_manager(session_manager);
275        }
276        orchestrator
277    }
278}
279
// End-to-end checks that a `SearchBuilder` wires custom retrievers into the
// orchestrator. All backends below are inert stubs; only the retriever
// registered by each test produces meaningful output.
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use std::borrow::Cow;
    use std::collections::HashMap;
    use std::sync::Arc;

    use async_trait::async_trait;
    use cognee_database::{DatabaseError, SearchHistoryDb, SearchHistoryEntry};
    use cognee_embedding::EmbeddingResult;
    use cognee_embedding::engine::EmbeddingEngine;
    use cognee_graph::{EdgeData, GraphDBResult, GraphDBTrait, GraphNode, NodeData};
    use cognee_llm::{
        GenerationOptions, GenerationResponse, Llm, LlmError, LlmResult, Message, TokenUsage,
    };
    use cognee_session::SessionContext;
    use cognee_vector::{SearchResult, VectorDB, VectorDBResult, VectorPoint};
    use serde_json::json;
    use uuid::Uuid;

    use super::SearchBuilder;
    use crate::retrievers::SearchRetriever;
    use crate::types::{
        SearchContext, SearchError, SearchItem, SearchOutput, SearchParams, SearchRequest,
        SearchType,
    };

    /// Property bag used throughout the graph trait signatures.
    type Props = HashMap<Cow<'static, str>, serde_json::Value>;
    /// `(nodes, edges)` pair returned by the graph snapshot queries.
    type GraphSnapshot = (Vec<GraphNode>, Vec<EdgeData>);

    /// Embedding stub: always yields a single fixed 2-d vector.
    struct TestEmbedding;

    #[async_trait]
    impl EmbeddingEngine for TestEmbedding {
        async fn embed(&self, _texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
            let fixed: Vec<f32> = vec![0.1, 0.2];
            Ok(vec![fixed])
        }

        fn dimension(&self) -> usize {
            2
        }

        fn batch_size(&self) -> usize {
            8
        }

        fn max_sequence_length(&self) -> usize {
            128
        }
    }

    /// Vector store stub: accepts writes, finds nothing, reports empty.
    struct TestVectorDb;

    #[async_trait]
    impl VectorDB for TestVectorDb {
        async fn create_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
            _dimension: usize,
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn has_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
        ) -> VectorDBResult<bool> {
            Ok(false)
        }

        async fn index_points(
            &self,
            _data_type: &str,
            _field_name: &str,
            _points: &[VectorPoint],
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn search_similar(
            &self,
            _data_type: &str,
            _field_name: &str,
            _query_vector: &[f32],
            _top_k: usize,
        ) -> VectorDBResult<Vec<SearchResult>> {
            Ok(Vec::new())
        }

        async fn delete_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn collection_size(
            &self,
            _data_type: &str,
            _field_name: &str,
        ) -> VectorDBResult<usize> {
            Ok(0)
        }
    }

    /// Graph store stub: empty graph, every mutation is a no-op.
    struct TestGraphDb;

    #[async_trait]
    impl GraphDBTrait for TestGraphDb {
        async fn initialize(&self) -> GraphDBResult<()> {
            Ok(())
        }

        async fn is_empty(&self) -> GraphDBResult<bool> {
            Ok(true)
        }

        async fn query(
            &self,
            _query: &str,
            _params: Option<Props>,
        ) -> GraphDBResult<Vec<Vec<serde_json::Value>>> {
            Ok(Vec::new())
        }

        async fn delete_graph(&self) -> GraphDBResult<()> {
            Ok(())
        }

        async fn has_node(&self, _node_id: &str) -> GraphDBResult<bool> {
            Ok(false)
        }

        async fn add_node_raw(&self, _node: serde_json::Value) -> GraphDBResult<()> {
            Ok(())
        }

        async fn add_nodes_raw(&self, _nodes: Vec<serde_json::Value>) -> GraphDBResult<()> {
            Ok(())
        }

        async fn delete_node(&self, _node_id: &str) -> GraphDBResult<()> {
            Ok(())
        }

        async fn delete_nodes(&self, _node_ids: &[String]) -> GraphDBResult<()> {
            Ok(())
        }

        async fn get_node(&self, _node_id: &str) -> GraphDBResult<Option<NodeData>> {
            Ok(None)
        }

        async fn get_nodes(&self, _node_ids: &[String]) -> GraphDBResult<Vec<NodeData>> {
            Ok(Vec::new())
        }

        async fn has_edge(
            &self,
            _source_id: &str,
            _target_id: &str,
            _relationship_name: &str,
        ) -> GraphDBResult<bool> {
            Ok(false)
        }

        async fn has_edges(&self, _edges: &[EdgeData]) -> GraphDBResult<Vec<EdgeData>> {
            Ok(Vec::new())
        }

        async fn add_edge(
            &self,
            _source_id: &str,
            _target_id: &str,
            _relationship_name: &str,
            _properties: Option<Props>,
        ) -> GraphDBResult<()> {
            Ok(())
        }

        async fn add_edges(&self, _edges: &[EdgeData]) -> GraphDBResult<()> {
            Ok(())
        }

        async fn get_edges(&self, _node_id: &str) -> GraphDBResult<Vec<EdgeData>> {
            Ok(Vec::new())
        }

        async fn get_neighbors(&self, _node_id: &str) -> GraphDBResult<Vec<NodeData>> {
            Ok(Vec::new())
        }

        async fn get_connections(
            &self,
            _node_id: &str,
        ) -> GraphDBResult<Vec<(NodeData, Props, NodeData)>> {
            Ok(Vec::new())
        }

        async fn get_graph_data(&self) -> GraphDBResult<GraphSnapshot> {
            Ok((Vec::new(), Vec::new()))
        }

        async fn get_graph_metrics(&self, _include_optional: bool) -> GraphDBResult<Props> {
            Ok(HashMap::new())
        }

        async fn get_filtered_graph_data(
            &self,
            _attribute_filters: &HashMap<Cow<'static, str>, Vec<serde_json::Value>>,
        ) -> GraphDBResult<GraphSnapshot> {
            Ok((Vec::new(), Vec::new()))
        }

        async fn get_nodeset_subgraph(
            &self,
            _node_type: &str,
            _node_names: &[String],
            _node_name_filter_operator: &str,
        ) -> GraphDBResult<GraphSnapshot> {
            Ok((Vec::new(), Vec::new()))
        }
    }

    /// LLM stub: canned "ok" completion; structured output is unsupported.
    struct TestLlm;

    #[async_trait]
    impl Llm for TestLlm {
        async fn generate(
            &self,
            _messages: Vec<Message>,
            _options: Option<GenerationOptions>,
        ) -> LlmResult<GenerationResponse> {
            let usage = TokenUsage {
                prompt_tokens: 1,
                completion_tokens: 1,
                total_tokens: 2,
            };
            Ok(GenerationResponse {
                content: String::from("ok"),
                model: String::from("test"),
                usage: Some(usage),
                finish_reason: Some(String::from("stop")),
            })
        }

        async fn create_structured_output_with_messages_raw(
            &self,
            _messages: Vec<Message>,
            _json_schema: &serde_json::Value,
            _options: Option<GenerationOptions>,
        ) -> LlmResult<serde_json::Value> {
            Err(LlmError::ConfigError(String::from("not used in this test")))
        }

        fn model(&self) -> &str {
            "test"
        }
    }

    /// Search-history stub: hands out fresh ids and stores nothing.
    struct TestDatabase;

    #[async_trait]
    impl SearchHistoryDb for TestDatabase {
        async fn log_query(
            &self,
            _query_text: &str,
            _query_type: &str,
            _user_id: Option<Uuid>,
        ) -> Result<Uuid, DatabaseError> {
            Ok(Uuid::new_v4())
        }

        async fn log_result(
            &self,
            _query_id: Uuid,
            _serialized_result: &str,
            _user_id: Option<Uuid>,
        ) -> Result<Uuid, DatabaseError> {
            Ok(Uuid::new_v4())
        }

        async fn get_history(
            &self,
            _user_id: Option<Uuid>,
            _limit: Option<usize>,
        ) -> Result<Vec<SearchHistoryEntry>, DatabaseError> {
            Ok(Vec::new())
        }
    }

    /// Overrides the standard `Chunks` retriever with a fixed completion.
    struct FakeRetriever;

    #[async_trait]
    impl SearchRetriever for FakeRetriever {
        fn search_type(&self) -> SearchType {
            SearchType::Chunks
        }

        async fn get_context(
            &self,
            _query: &str,
            _params: &SearchParams,
        ) -> Result<SearchContext, SearchError> {
            Ok(Vec::new())
        }

        async fn get_completion(
            &self,
            _query: &str,
            _context: Option<SearchContext>,
            _session: &SessionContext,
            _params: &SearchParams,
        ) -> Result<SearchOutput, SearchError> {
            Ok(SearchOutput::Text(String::from("builder-executed")))
        }
    }

    /// Builds an orchestrator over the stub backends, with `retriever`
    /// registered on top of the standard retriever set.
    fn orchestrator_with(
        retriever: crate::retrievers::SearchRetrieverRef,
    ) -> crate::orchestration::SearchOrchestrator {
        SearchBuilder::new(
            Arc::new(TestVectorDb),
            Arc::new(TestEmbedding),
            Arc::new(TestGraphDb),
            Arc::new(TestLlm),
            Arc::new(TestDatabase),
        )
        .register_retriever(retriever)
        .build()
    }

    /// A "hello" query (top_k = 3, no combined context) for `search_type`;
    /// every other optional field is left unset.
    fn request_for(search_type: SearchType, only_context: bool) -> SearchRequest {
        SearchRequest {
            query_text: String::from("hello"),
            search_type,
            top_k: Some(3),
            datasets: None,
            dataset_ids: None,
            system_prompt: None,
            system_prompt_path: None,
            only_context: Some(only_context),
            use_combined_context: Some(false),
            session_id: None,
            node_type: None,
            node_name: None,
            node_name_filter_operator: None,
            wide_search_top_k: None,
            triplet_distance_penalty: None,
            save_interaction: None,
            user_id: None,
            verbose: None,
            feedback_influence: None,
            retriever_specific_config: None,
            response_schema: None,
            custom_search_type: None,
            auto_feedback_detection: None,
            neighborhood_depth: None,
            neighborhood_seed_top_k: None,
            summarize_context: None,
        }
    }

    #[tokio::test]
    async fn executes_search_via_builder_entrypoint() {
        let orchestrator = orchestrator_with(Arc::new(FakeRetriever));
        let request = request_for(SearchType::Chunks, false);

        let response = orchestrator.search(&request).await.unwrap();

        let SearchOutput::Text(text) = response.result else {
            panic!("expected text result");
        };
        assert_eq!(text, "builder-executed");
    }

    #[tokio::test]
    async fn supports_context_only_execution_through_entrypoint() {
        /// `Summaries` retriever whose context is one fixed item; its
        /// completion must never be surfaced in context-only mode.
        struct ContextRetriever;

        #[async_trait]
        impl SearchRetriever for ContextRetriever {
            fn search_type(&self) -> SearchType {
                SearchType::Summaries
            }

            async fn get_context(
                &self,
                _query: &str,
                _params: &SearchParams,
            ) -> Result<SearchContext, SearchError> {
                let item = SearchItem {
                    id: None,
                    score: Some(0.9),
                    payload: json!({ "text": "summary item" }),
                };
                Ok(vec![item])
            }

            async fn get_completion(
                &self,
                _query: &str,
                _context: Option<SearchContext>,
                _session: &SessionContext,
                _params: &SearchParams,
            ) -> Result<SearchOutput, SearchError> {
                Ok(SearchOutput::Text(String::from("unused")))
            }
        }

        let orchestrator = orchestrator_with(Arc::new(ContextRetriever));
        let request = request_for(SearchType::Summaries, true);

        let response = orchestrator.search(&request).await.unwrap();

        let SearchOutput::Items(items) = response.result else {
            panic!("expected items result");
        };
        assert_eq!(items.len(), 1);
        assert_eq!(items[0].payload["text"], "summary item");
    }
}