agent_chain_core/
retrievers.rs

//! The **Retriever** trait returns `Document`s given a text **query**.
2//!
3//! It is more general than a vector store. A retriever does not need to be able to
4//! store documents, only to return (or retrieve) them. Vector stores can be used as
5//! the backbone of a retriever, but there are other types of retrievers as well.
6//!
7//! # Example
8//!
9//! ```ignore
10//! use agent_chain_core::retrievers::BaseRetriever;
11//! use agent_chain_core::documents::Document;
12//! use agent_chain_core::callbacks::CallbackManagerForRetrieverRun;
13//! use agent_chain_core::error::Result;
14//! use async_trait::async_trait;
15//!
16//! struct SimpleRetriever {
17//!     docs: Vec<Document>,
18//!     k: usize,
19//! }
20//!
21//! #[async_trait]
22//! impl BaseRetriever for SimpleRetriever {
23//!     fn get_relevant_documents(
24//!         &self,
25//!         query: &str,
26//!         _run_manager: Option<&CallbackManagerForRetrieverRun>,
27//!     ) -> Result<Vec<Document>> {
28//!         Ok(self.docs.iter().take(self.k).cloned().collect())
29//!     }
30//! }
31//! ```
32
33use std::collections::HashMap;
34use std::fmt::Debug;
35use std::sync::Arc;
36
37use async_trait::async_trait;
38use serde::{Deserialize, Serialize};
39use serde_json::Value;
40
41#[cfg(feature = "specta")]
42use specta::Type;
43
44use crate::callbacks::{
45    AsyncCallbackManager, AsyncCallbackManagerForRetrieverRun, CallbackManager,
46    CallbackManagerForRetrieverRun,
47};
48use crate::documents::Document;
49use crate::error::Result;
50use crate::runnables::{RunnableConfig, ensure_config};
51
/// Input accepted by a retriever: the query text as an owned [`String`].
pub type RetrieverInput = String;
54
55/// Type alias for retriever output (a list of documents).
56pub type RetrieverOutput = Vec<Document>;
57
58/// LangSmith parameters for tracing.
59#[cfg_attr(feature = "specta", derive(Type))]
60#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
61pub struct LangSmithRetrieverParams {
62    /// Retriever name.
63    #[serde(skip_serializing_if = "Option::is_none")]
64    pub ls_retriever_name: Option<String>,
65
66    /// Vector store provider.
67    #[serde(skip_serializing_if = "Option::is_none")]
68    pub ls_vector_store_provider: Option<String>,
69
70    /// Embedding provider.
71    #[serde(skip_serializing_if = "Option::is_none")]
72    pub ls_embedding_provider: Option<String>,
73
74    /// Embedding model.
75    #[serde(skip_serializing_if = "Option::is_none")]
76    pub ls_embedding_model: Option<String>,
77}
78
79impl LangSmithRetrieverParams {
80    /// Create new LangSmithRetrieverParams with the given retriever name.
81    pub fn new(retriever_name: impl Into<String>) -> Self {
82        Self {
83            ls_retriever_name: Some(retriever_name.into()),
84            ..Default::default()
85        }
86    }
87
88    /// Set the vector store provider.
89    pub fn with_vector_store_provider(mut self, provider: impl Into<String>) -> Self {
90        self.ls_vector_store_provider = Some(provider.into());
91        self
92    }
93
94    /// Set the embedding provider.
95    pub fn with_embedding_provider(mut self, provider: impl Into<String>) -> Self {
96        self.ls_embedding_provider = Some(provider.into());
97        self
98    }
99
100    /// Set the embedding model.
101    pub fn with_embedding_model(mut self, model: impl Into<String>) -> Self {
102        self.ls_embedding_model = Some(model.into());
103        self
104    }
105
106    /// Convert to a HashMap for use in metadata.
107    pub fn to_metadata(&self) -> HashMap<String, Value> {
108        let mut metadata = HashMap::new();
109        if let Some(ref name) = self.ls_retriever_name {
110            metadata.insert("ls_retriever_name".to_string(), Value::String(name.clone()));
111        }
112        if let Some(ref provider) = self.ls_vector_store_provider {
113            metadata.insert(
114                "ls_vector_store_provider".to_string(),
115                Value::String(provider.clone()),
116            );
117        }
118        if let Some(ref provider) = self.ls_embedding_provider {
119            metadata.insert(
120                "ls_embedding_provider".to_string(),
121                Value::String(provider.clone()),
122            );
123        }
124        if let Some(ref model) = self.ls_embedding_model {
125            metadata.insert(
126                "ls_embedding_model".to_string(),
127                Value::String(model.clone()),
128            );
129        }
130        metadata
131    }
132}
133
134/// Abstract base trait for a document retrieval system.
135///
136/// A retrieval system is defined as something that can take string queries and return
137/// the most 'relevant' documents from some source.
138///
139/// # Usage
140///
141/// A retriever follows the standard `Runnable` interface, and should be used via the
142/// standard `Runnable` methods of `invoke`, `ainvoke`, `batch`, `abatch`.
143///
144/// # Implementation
145///
146/// When implementing a custom retriever, the struct should implement the
147/// [`get_relevant_documents`][Self::get_relevant_documents] method to define the logic
148/// for retrieving documents.
149///
150/// Optionally, an async native implementation can be provided by overriding the
151/// [`aget_relevant_documents`][Self::aget_relevant_documents] method.
152///
153/// # Example
154///
155/// ```ignore
156/// use agent_chain_core::retrievers::BaseRetriever;
157/// use agent_chain_core::documents::Document;
158/// use agent_chain_core::callbacks::CallbackManagerForRetrieverRun;
159/// use agent_chain_core::error::Result;
160/// use async_trait::async_trait;
161///
162/// struct SimpleRetriever {
163///     docs: Vec<Document>,
164///     k: usize,
165/// }
166///
167/// #[async_trait]
168/// impl BaseRetriever for SimpleRetriever {
169///     fn get_relevant_documents(
170///         &self,
171///         query: &str,
172///         _run_manager: Option<&CallbackManagerForRetrieverRun>,
173///     ) -> Result<Vec<Document>> {
174///         // Return the first k documents from the list of documents
175///         Ok(self.docs.iter().take(self.k).cloned().collect())
176///     }
177///
178///     // Optionally provide async native implementation
179///     async fn aget_relevant_documents(
180///         &self,
181///         query: &str,
182///         _run_manager: Option<&AsyncCallbackManagerForRetrieverRun>,
183///     ) -> Result<Vec<Document>> {
184///         Ok(self.docs.iter().take(self.k).cloned().collect())
185///     }
186/// }
187/// ```
188#[async_trait]
189pub trait BaseRetriever: Send + Sync + Debug {
190    /// Get the name of this retriever.
191    fn get_name(&self) -> String {
192        let type_name = std::any::type_name::<Self>();
193        type_name
194            .rsplit("::")
195            .next()
196            .unwrap_or(type_name)
197            .to_string()
198    }
199
200    /// Optional list of tags associated with the retriever.
201    ///
202    /// These tags will be associated with each call to this retriever,
203    /// and passed as arguments to the handlers defined in `callbacks`.
204    fn tags(&self) -> Option<&[String]> {
205        None
206    }
207
208    /// Optional metadata associated with the retriever.
209    ///
210    /// This metadata will be associated with each call to this retriever,
211    /// and passed as arguments to the handlers defined in `callbacks`.
212    fn metadata(&self) -> Option<&HashMap<String, Value>> {
213        None
214    }
215
216    /// Get standard params for tracing.
217    fn get_ls_params(&self) -> LangSmithRetrieverParams {
218        let name = self.get_name();
219        let default_name = if let Some(stripped) = name.strip_prefix("Retriever") {
220            stripped.to_lowercase()
221        } else if let Some(stripped) = name.strip_suffix("Retriever") {
222            stripped.to_lowercase()
223        } else {
224            name.to_lowercase()
225        };
226
227        LangSmithRetrieverParams::new(default_name)
228    }
229
230    /// Get documents relevant to a query.
231    ///
232    /// This is the main method that retriever implementations should override.
233    ///
234    /// # Arguments
235    ///
236    /// * `query` - String to find relevant documents for.
237    /// * `run_manager` - Optional callback handler to use.
238    ///
239    /// # Returns
240    ///
241    /// List of relevant documents.
242    fn get_relevant_documents(
243        &self,
244        query: &str,
245        run_manager: Option<&CallbackManagerForRetrieverRun>,
246    ) -> Result<Vec<Document>>;
247
248    /// Asynchronously get documents relevant to a query.
249    ///
250    /// The default implementation runs the sync version.
251    ///
252    /// # Arguments
253    ///
254    /// * `query` - String to find relevant documents for.
255    /// * `run_manager` - Optional async callback handler to use.
256    ///
257    /// # Returns
258    ///
259    /// List of relevant documents.
260    async fn aget_relevant_documents(
261        &self,
262        query: &str,
263        run_manager: Option<&AsyncCallbackManagerForRetrieverRun>,
264    ) -> Result<Vec<Document>> {
265        let sync_run_manager = run_manager.map(|rm| rm.get_sync());
266        self.get_relevant_documents(query, sync_run_manager.as_ref())
267    }
268
269    /// Invoke the retriever to get relevant documents.
270    ///
271    /// Main entry point for synchronous retriever invocations.
272    ///
273    /// # Arguments
274    ///
275    /// * `input` - The query string.
276    /// * `config` - Optional configuration for the retriever.
277    ///
278    /// # Returns
279    ///
280    /// List of relevant documents.
281    fn invoke(&self, input: &str, config: Option<RunnableConfig>) -> Result<Vec<Document>> {
282        let config = ensure_config(config);
283
284        // Build inheritable metadata
285        let mut inheritable_metadata = config.metadata.clone();
286        inheritable_metadata.extend(self.get_ls_params().to_metadata());
287
288        // Configure callback manager
289        let callback_manager = CallbackManager::configure(
290            config.callbacks.clone(),
291            None,
292            Some(config.tags.clone()),
293            self.tags().map(|t| t.to_vec()),
294            Some(inheritable_metadata),
295            self.metadata().cloned(),
296            false,
297        );
298
299        // Start retriever run
300        let run_manager =
301            callback_manager.on_retriever_start(&HashMap::new(), input, config.run_id);
302
303        // Get the run name
304        let _run_name = config.run_name.clone().unwrap_or_else(|| self.get_name());
305
306        // Execute retrieval
307        match self.get_relevant_documents(input, Some(&run_manager)) {
308            Ok(result) => {
309                // Convert documents to JSON values for callback
310                let docs_json: Vec<Value> = result
311                    .iter()
312                    .map(|doc| serde_json::to_value(doc).unwrap_or(Value::Null))
313                    .collect();
314                run_manager.on_retriever_end(&docs_json);
315                Ok(result)
316            }
317            Err(e) => {
318                run_manager.on_retriever_error(&e);
319                Err(e)
320            }
321        }
322    }
323
324    /// Asynchronously invoke the retriever to get relevant documents.
325    ///
326    /// Main entry point for asynchronous retriever invocations.
327    ///
328    /// # Arguments
329    ///
330    /// * `input` - The query string.
331    /// * `config` - Optional configuration for the retriever.
332    ///
333    /// # Returns
334    ///
335    /// List of relevant documents.
336    async fn ainvoke(&self, input: &str, config: Option<RunnableConfig>) -> Result<Vec<Document>> {
337        let config = ensure_config(config);
338
339        // Build inheritable metadata
340        let mut inheritable_metadata = config.metadata.clone();
341        inheritable_metadata.extend(self.get_ls_params().to_metadata());
342
343        // Configure callback manager
344        let callback_manager = AsyncCallbackManager::configure(
345            config.callbacks.clone(),
346            None,
347            Some(config.tags.clone()),
348            self.tags().map(|t| t.to_vec()),
349            Some(inheritable_metadata),
350            self.metadata().cloned(),
351            false,
352        );
353
354        // Start retriever run
355        let run_manager = callback_manager
356            .on_retriever_start(&HashMap::new(), input, config.run_id)
357            .await;
358
359        // Execute retrieval
360        let result = self
361            .aget_relevant_documents(input, Some(&run_manager))
362            .await;
363
364        match &result {
365            Ok(docs) => {
366                // Convert documents to JSON values for callback
367                let docs_json: Vec<Value> = docs
368                    .iter()
369                    .map(|doc| serde_json::to_value(doc).unwrap_or(Value::Null))
370                    .collect();
371                run_manager.on_retriever_end(&docs_json).await;
372            }
373            Err(e) => {
374                // Use sync version for error handling to avoid Send issues
375                run_manager.get_sync().on_retriever_error(e);
376            }
377        }
378
379        result
380    }
381
382    /// Transform multiple inputs into outputs in parallel.
383    ///
384    /// # Arguments
385    ///
386    /// * `inputs` - List of query strings.
387    /// * `config` - Optional configuration for the retriever.
388    ///
389    /// # Returns
390    ///
391    /// List of results, one for each input.
392    fn batch(
393        &self,
394        inputs: Vec<&str>,
395        config: Option<RunnableConfig>,
396    ) -> Vec<Result<Vec<Document>>> {
397        let config = ensure_config(config);
398        inputs
399            .into_iter()
400            .map(|input| self.invoke(input, Some(config.clone())))
401            .collect()
402    }
403
404    /// Asynchronously transform multiple inputs into outputs.
405    ///
406    /// # Arguments
407    ///
408    /// * `inputs` - List of query strings.
409    /// * `config` - Optional configuration for the retriever.
410    ///
411    /// # Returns
412    ///
413    /// List of results, one for each input.
414    async fn abatch(
415        &self,
416        inputs: Vec<&str>,
417        config: Option<RunnableConfig>,
418    ) -> Vec<Result<Vec<Document>>> {
419        let config = ensure_config(config);
420        let mut results = Vec::with_capacity(inputs.len());
421        for input in inputs {
422            results.push(self.ainvoke(input, Some(config.clone())).await);
423        }
424        results
425    }
426}
427
428/// A type-erased retriever that can be stored in collections.
429pub type DynRetriever = Arc<dyn BaseRetriever>;
430
431/// Convert any retriever into a DynRetriever.
432pub fn to_dyn<R>(retriever: R) -> DynRetriever
433where
434    R: BaseRetriever + 'static,
435{
436    Arc::new(retriever)
437}
438
439/// A simple retriever that returns documents from a static list.
440///
441/// This is useful for testing or for simple use cases where documents
442/// are known ahead of time.
443#[derive(Debug, Clone)]
444pub struct SimpleRetriever {
445    /// The list of documents to return.
446    pub docs: Vec<Document>,
447    /// The maximum number of documents to return.
448    pub k: usize,
449    /// Optional tags for this retriever.
450    tags: Option<Vec<String>>,
451    /// Optional metadata for this retriever.
452    metadata: Option<HashMap<String, Value>>,
453}
454
455impl SimpleRetriever {
456    /// Create a new SimpleRetriever with the given documents.
457    pub fn new(docs: Vec<Document>) -> Self {
458        Self {
459            docs,
460            k: 5,
461            tags: None,
462            metadata: None,
463        }
464    }
465
466    /// Set the maximum number of documents to return.
467    pub fn with_k(mut self, k: usize) -> Self {
468        self.k = k;
469        self
470    }
471
472    /// Set the tags for this retriever.
473    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
474        self.tags = Some(tags);
475        self
476    }
477
478    /// Set the metadata for this retriever.
479    pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
480        self.metadata = Some(metadata);
481        self
482    }
483}
484
485#[async_trait]
486impl BaseRetriever for SimpleRetriever {
487    fn tags(&self) -> Option<&[String]> {
488        self.tags.as_deref()
489    }
490
491    fn metadata(&self) -> Option<&HashMap<String, Value>> {
492        self.metadata.as_ref()
493    }
494
495    fn get_relevant_documents(
496        &self,
497        _query: &str,
498        _run_manager: Option<&CallbackManagerForRetrieverRun>,
499    ) -> Result<Vec<Document>> {
500        Ok(self.docs.iter().take(self.k).cloned().collect())
501    }
502}
503
504/// A retriever that filters documents based on a predicate function.
505#[derive(Clone)]
506pub struct FilterRetriever<R, F>
507where
508    R: BaseRetriever,
509    F: Fn(&Document) -> bool + Send + Sync,
510{
511    /// The underlying retriever.
512    pub retriever: R,
513    /// The filter predicate.
514    pub filter: F,
515}
516
517impl<R, F> Debug for FilterRetriever<R, F>
518where
519    R: BaseRetriever,
520    F: Fn(&Document) -> bool + Send + Sync,
521{
522    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
523        f.debug_struct("FilterRetriever")
524            .field("retriever", &self.retriever)
525            .finish()
526    }
527}
528
529impl<R, F> FilterRetriever<R, F>
530where
531    R: BaseRetriever,
532    F: Fn(&Document) -> bool + Send + Sync,
533{
534    /// Create a new FilterRetriever.
535    pub fn new(retriever: R, filter: F) -> Self {
536        Self { retriever, filter }
537    }
538}
539
540#[async_trait]
541impl<R, F> BaseRetriever for FilterRetriever<R, F>
542where
543    R: BaseRetriever,
544    F: Fn(&Document) -> bool + Send + Sync,
545{
546    fn get_name(&self) -> String {
547        format!("FilterRetriever<{}>", self.retriever.get_name())
548    }
549
550    fn tags(&self) -> Option<&[String]> {
551        self.retriever.tags()
552    }
553
554    fn metadata(&self) -> Option<&HashMap<String, Value>> {
555        self.retriever.metadata()
556    }
557
558    fn get_relevant_documents(
559        &self,
560        query: &str,
561        run_manager: Option<&CallbackManagerForRetrieverRun>,
562    ) -> Result<Vec<Document>> {
563        let docs = self.retriever.get_relevant_documents(query, run_manager)?;
564        Ok(docs.into_iter().filter(&self.filter).collect())
565    }
566
567    async fn aget_relevant_documents(
568        &self,
569        query: &str,
570        run_manager: Option<&AsyncCallbackManagerForRetrieverRun>,
571    ) -> Result<Vec<Document>> {
572        let docs = self
573            .retriever
574            .aget_relevant_documents(query, run_manager)
575            .await?;
576        Ok(docs.into_iter().filter(&self.filter).collect())
577    }
578}
579
#[cfg(test)]
mod tests {
    use super::*;

    /// `get_relevant_documents` returns the first `k` documents, in order.
    #[test]
    fn test_simple_retriever() {
        let corpus = vec![
            Document::new("Hello world"),
            Document::new("Goodbye world"),
            Document::new("Hello again"),
        ];
        let retriever = SimpleRetriever::new(corpus.clone()).with_k(2);

        let found = retriever.get_relevant_documents("test", None).unwrap();

        assert_eq!(found.len(), 2);
        assert_eq!(found[0].page_content, "Hello world");
        assert_eq!(found[1].page_content, "Goodbye world");
    }

    /// `invoke` with `k` larger than the corpus returns every document.
    #[test]
    fn test_simple_retriever_invoke() {
        let corpus = vec![Document::new("Hello world"), Document::new("Goodbye world")];
        let retriever = SimpleRetriever::new(corpus).with_k(5);

        let found = retriever.invoke("test query", None).unwrap();

        assert_eq!(found.len(), 2);
    }

    /// Only documents accepted by the predicate survive.
    #[test]
    fn test_filter_retriever() {
        let corpus = vec![
            Document::new("Hello world"),
            Document::new("Goodbye world"),
            Document::new("Hello again"),
        ];
        let inner = SimpleRetriever::new(corpus);
        let filtered = FilterRetriever::new(inner, |doc| doc.page_content.contains("Hello"));

        let found = filtered.get_relevant_documents("test", None).unwrap();

        assert_eq!(found.len(), 2);
        for doc in &found {
            assert!(doc.page_content.contains("Hello"));
        }
    }

    /// Builder methods populate each field and `to_metadata` emits all four.
    #[test]
    fn test_langsmith_params() {
        let params = LangSmithRetrieverParams::new("my_retriever")
            .with_vector_store_provider("pinecone")
            .with_embedding_provider("openai")
            .with_embedding_model("text-embedding-3-small");

        assert_eq!(params.ls_retriever_name.as_deref(), Some("my_retriever"));
        assert_eq!(params.ls_vector_store_provider.as_deref(), Some("pinecone"));
        assert_eq!(params.ls_embedding_provider.as_deref(), Some("openai"));
        assert_eq!(
            params.ls_embedding_model.as_deref(),
            Some("text-embedding-3-small")
        );

        assert_eq!(params.to_metadata().len(), 4);
    }

    /// `ainvoke` falls back to the sync retrieval and returns every document.
    #[tokio::test]
    async fn test_simple_retriever_ainvoke() {
        let corpus = vec![Document::new("Hello world"), Document::new("Goodbye world")];
        let retriever = SimpleRetriever::new(corpus).with_k(5);

        let found = retriever.ainvoke("test query", None).await.unwrap();

        assert_eq!(found.len(), 2);
    }

    /// `batch` yields one successful result per query.
    #[test]
    fn test_batch() {
        let corpus = vec![Document::new("Hello world"), Document::new("Goodbye world")];
        let retriever = SimpleRetriever::new(corpus);

        let outcomes = retriever.batch(vec!["query1", "query2"], None);

        assert_eq!(outcomes.len(), 2);
        for outcome in &outcomes {
            assert!(outcome.is_ok());
        }
    }

    /// The default tracing name strips the `Retriever` suffix and lowercases.
    #[test]
    fn test_get_ls_params() {
        struct TestRetriever;

        impl Debug for TestRetriever {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.debug_struct("TestRetriever").finish()
            }
        }

        #[async_trait]
        impl BaseRetriever for TestRetriever {
            fn get_relevant_documents(
                &self,
                _query: &str,
                _run_manager: Option<&CallbackManagerForRetrieverRun>,
            ) -> Result<Vec<Document>> {
                Ok(Vec::new())
            }
        }

        let params = TestRetriever.get_ls_params();

        assert_eq!(params.ls_retriever_name.as_deref(), Some("test"));
    }
}