Skip to main content

cognee_search/retrievers/
base_retriever.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use cognee_session::SessionContext;
5
6use crate::types::{SearchContext, SearchError, SearchOutput, SearchParams, SearchType};
7
/// Shared, reference-counted handle to a type-erased [`SearchRetriever`].
pub type SearchRetrieverRef = Arc<dyn SearchRetriever>;
9
10#[async_trait]
11pub trait SearchRetriever: Send + Sync {
12    fn search_type(&self) -> SearchType;
13
14    async fn get_context(
15        &self,
16        query: &str,
17        params: &SearchParams,
18    ) -> Result<SearchContext, SearchError>;
19
20    async fn get_completion(
21        &self,
22        query: &str,
23        context: Option<SearchContext>,
24        session: &SessionContext,
25        params: &SearchParams,
26    ) -> Result<SearchOutput, SearchError>;
27
28    /// Process multiple queries in sequence and return their contexts.
29    ///
30    /// Default: loops over [`get_context`]. Override for efficient batching.
31    async fn get_context_batch(
32        &self,
33        queries: &[String],
34        params: &SearchParams,
35    ) -> Result<Vec<SearchContext>, SearchError> {
36        let mut results = Vec::with_capacity(queries.len());
37        for query in queries {
38            results.push(self.get_context(query, params).await?);
39        }
40        Ok(results)
41    }
42
43    /// Process multiple queries and return their completions.
44    ///
45    /// Default: loops over [`get_completion`]. Override for efficient batching.
46    async fn get_completion_batch(
47        &self,
48        queries: &[String],
49        contexts: Option<Vec<SearchContext>>,
50        session: &SessionContext,
51        params: &SearchParams,
52    ) -> Result<Vec<SearchOutput>, SearchError> {
53        let mut results = Vec::with_capacity(queries.len());
54        for (i, query) in queries.iter().enumerate() {
55            let ctx = contexts.as_ref().and_then(|cs| cs.get(i).cloned());
56            results.push(self.get_completion(query, ctx, session, params).await?);
57        }
58        Ok(results)
59    }
60}