1use std::collections::HashMap;
34use std::fmt::Debug;
35use std::sync::Arc;
36
37use async_trait::async_trait;
38use serde::{Deserialize, Serialize};
39use serde_json::Value;
40
41#[cfg(feature = "specta")]
42use specta::Type;
43
44use crate::callbacks::{
45 AsyncCallbackManager, AsyncCallbackManagerForRetrieverRun, CallbackManager,
46 CallbackManagerForRetrieverRun,
47};
48use crate::documents::Document;
49use crate::error::Result;
50use crate::runnables::{RunnableConfig, ensure_config};
51
52pub type RetrieverInput = String;
54
55pub type RetrieverOutput = Vec<Document>;
57
58#[cfg_attr(feature = "specta", derive(Type))]
60#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
61pub struct LangSmithRetrieverParams {
62 #[serde(skip_serializing_if = "Option::is_none")]
64 pub ls_retriever_name: Option<String>,
65
66 #[serde(skip_serializing_if = "Option::is_none")]
68 pub ls_vector_store_provider: Option<String>,
69
70 #[serde(skip_serializing_if = "Option::is_none")]
72 pub ls_embedding_provider: Option<String>,
73
74 #[serde(skip_serializing_if = "Option::is_none")]
76 pub ls_embedding_model: Option<String>,
77}
78
79impl LangSmithRetrieverParams {
80 pub fn new(retriever_name: impl Into<String>) -> Self {
82 Self {
83 ls_retriever_name: Some(retriever_name.into()),
84 ..Default::default()
85 }
86 }
87
88 pub fn with_vector_store_provider(mut self, provider: impl Into<String>) -> Self {
90 self.ls_vector_store_provider = Some(provider.into());
91 self
92 }
93
94 pub fn with_embedding_provider(mut self, provider: impl Into<String>) -> Self {
96 self.ls_embedding_provider = Some(provider.into());
97 self
98 }
99
100 pub fn with_embedding_model(mut self, model: impl Into<String>) -> Self {
102 self.ls_embedding_model = Some(model.into());
103 self
104 }
105
106 pub fn to_metadata(&self) -> HashMap<String, Value> {
108 let mut metadata = HashMap::new();
109 if let Some(ref name) = self.ls_retriever_name {
110 metadata.insert("ls_retriever_name".to_string(), Value::String(name.clone()));
111 }
112 if let Some(ref provider) = self.ls_vector_store_provider {
113 metadata.insert(
114 "ls_vector_store_provider".to_string(),
115 Value::String(provider.clone()),
116 );
117 }
118 if let Some(ref provider) = self.ls_embedding_provider {
119 metadata.insert(
120 "ls_embedding_provider".to_string(),
121 Value::String(provider.clone()),
122 );
123 }
124 if let Some(ref model) = self.ls_embedding_model {
125 metadata.insert(
126 "ls_embedding_model".to_string(),
127 Value::String(model.clone()),
128 );
129 }
130 metadata
131 }
132}
133
134#[async_trait]
189pub trait BaseRetriever: Send + Sync + Debug {
190 fn get_name(&self) -> String {
192 let type_name = std::any::type_name::<Self>();
193 type_name
194 .rsplit("::")
195 .next()
196 .unwrap_or(type_name)
197 .to_string()
198 }
199
200 fn tags(&self) -> Option<&[String]> {
205 None
206 }
207
208 fn metadata(&self) -> Option<&HashMap<String, Value>> {
213 None
214 }
215
216 fn get_ls_params(&self) -> LangSmithRetrieverParams {
218 let name = self.get_name();
219 let default_name = if let Some(stripped) = name.strip_prefix("Retriever") {
220 stripped.to_lowercase()
221 } else if let Some(stripped) = name.strip_suffix("Retriever") {
222 stripped.to_lowercase()
223 } else {
224 name.to_lowercase()
225 };
226
227 LangSmithRetrieverParams::new(default_name)
228 }
229
230 fn get_relevant_documents(
243 &self,
244 query: &str,
245 run_manager: Option<&CallbackManagerForRetrieverRun>,
246 ) -> Result<Vec<Document>>;
247
248 async fn aget_relevant_documents(
261 &self,
262 query: &str,
263 run_manager: Option<&AsyncCallbackManagerForRetrieverRun>,
264 ) -> Result<Vec<Document>> {
265 let sync_run_manager = run_manager.map(|rm| rm.get_sync());
266 self.get_relevant_documents(query, sync_run_manager.as_ref())
267 }
268
269 fn invoke(&self, input: &str, config: Option<RunnableConfig>) -> Result<Vec<Document>> {
282 let config = ensure_config(config);
283
284 let mut inheritable_metadata = config.metadata.clone();
286 inheritable_metadata.extend(self.get_ls_params().to_metadata());
287
288 let callback_manager = CallbackManager::configure(
290 config.callbacks.clone(),
291 None,
292 Some(config.tags.clone()),
293 self.tags().map(|t| t.to_vec()),
294 Some(inheritable_metadata),
295 self.metadata().cloned(),
296 false,
297 );
298
299 let run_manager =
301 callback_manager.on_retriever_start(&HashMap::new(), input, config.run_id);
302
303 let _run_name = config.run_name.clone().unwrap_or_else(|| self.get_name());
305
306 match self.get_relevant_documents(input, Some(&run_manager)) {
308 Ok(result) => {
309 let docs_json: Vec<Value> = result
311 .iter()
312 .map(|doc| serde_json::to_value(doc).unwrap_or(Value::Null))
313 .collect();
314 run_manager.on_retriever_end(&docs_json);
315 Ok(result)
316 }
317 Err(e) => {
318 run_manager.on_retriever_error(&e);
319 Err(e)
320 }
321 }
322 }
323
324 async fn ainvoke(&self, input: &str, config: Option<RunnableConfig>) -> Result<Vec<Document>> {
337 let config = ensure_config(config);
338
339 let mut inheritable_metadata = config.metadata.clone();
341 inheritable_metadata.extend(self.get_ls_params().to_metadata());
342
343 let callback_manager = AsyncCallbackManager::configure(
345 config.callbacks.clone(),
346 None,
347 Some(config.tags.clone()),
348 self.tags().map(|t| t.to_vec()),
349 Some(inheritable_metadata),
350 self.metadata().cloned(),
351 false,
352 );
353
354 let run_manager = callback_manager
356 .on_retriever_start(&HashMap::new(), input, config.run_id)
357 .await;
358
359 let result = self
361 .aget_relevant_documents(input, Some(&run_manager))
362 .await;
363
364 match &result {
365 Ok(docs) => {
366 let docs_json: Vec<Value> = docs
368 .iter()
369 .map(|doc| serde_json::to_value(doc).unwrap_or(Value::Null))
370 .collect();
371 run_manager.on_retriever_end(&docs_json).await;
372 }
373 Err(e) => {
374 run_manager.get_sync().on_retriever_error(e);
376 }
377 }
378
379 result
380 }
381
382 fn batch(
393 &self,
394 inputs: Vec<&str>,
395 config: Option<RunnableConfig>,
396 ) -> Vec<Result<Vec<Document>>> {
397 let config = ensure_config(config);
398 inputs
399 .into_iter()
400 .map(|input| self.invoke(input, Some(config.clone())))
401 .collect()
402 }
403
404 async fn abatch(
415 &self,
416 inputs: Vec<&str>,
417 config: Option<RunnableConfig>,
418 ) -> Vec<Result<Vec<Document>>> {
419 let config = ensure_config(config);
420 let mut results = Vec::with_capacity(inputs.len());
421 for input in inputs {
422 results.push(self.ainvoke(input, Some(config.clone())).await);
423 }
424 results
425 }
426}
427
428pub type DynRetriever = Arc<dyn BaseRetriever>;
430
431pub fn to_dyn<R>(retriever: R) -> DynRetriever
433where
434 R: BaseRetriever + 'static,
435{
436 Arc::new(retriever)
437}
438
439#[derive(Debug, Clone)]
444pub struct SimpleRetriever {
445 pub docs: Vec<Document>,
447 pub k: usize,
449 tags: Option<Vec<String>>,
451 metadata: Option<HashMap<String, Value>>,
453}
454
455impl SimpleRetriever {
456 pub fn new(docs: Vec<Document>) -> Self {
458 Self {
459 docs,
460 k: 5,
461 tags: None,
462 metadata: None,
463 }
464 }
465
466 pub fn with_k(mut self, k: usize) -> Self {
468 self.k = k;
469 self
470 }
471
472 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
474 self.tags = Some(tags);
475 self
476 }
477
478 pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
480 self.metadata = Some(metadata);
481 self
482 }
483}
484
485#[async_trait]
486impl BaseRetriever for SimpleRetriever {
487 fn tags(&self) -> Option<&[String]> {
488 self.tags.as_deref()
489 }
490
491 fn metadata(&self) -> Option<&HashMap<String, Value>> {
492 self.metadata.as_ref()
493 }
494
495 fn get_relevant_documents(
496 &self,
497 _query: &str,
498 _run_manager: Option<&CallbackManagerForRetrieverRun>,
499 ) -> Result<Vec<Document>> {
500 Ok(self.docs.iter().take(self.k).cloned().collect())
501 }
502}
503
504#[derive(Clone)]
506pub struct FilterRetriever<R, F>
507where
508 R: BaseRetriever,
509 F: Fn(&Document) -> bool + Send + Sync,
510{
511 pub retriever: R,
513 pub filter: F,
515}
516
517impl<R, F> Debug for FilterRetriever<R, F>
518where
519 R: BaseRetriever,
520 F: Fn(&Document) -> bool + Send + Sync,
521{
522 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
523 f.debug_struct("FilterRetriever")
524 .field("retriever", &self.retriever)
525 .finish()
526 }
527}
528
529impl<R, F> FilterRetriever<R, F>
530where
531 R: BaseRetriever,
532 F: Fn(&Document) -> bool + Send + Sync,
533{
534 pub fn new(retriever: R, filter: F) -> Self {
536 Self { retriever, filter }
537 }
538}
539
540#[async_trait]
541impl<R, F> BaseRetriever for FilterRetriever<R, F>
542where
543 R: BaseRetriever,
544 F: Fn(&Document) -> bool + Send + Sync,
545{
546 fn get_name(&self) -> String {
547 format!("FilterRetriever<{}>", self.retriever.get_name())
548 }
549
550 fn tags(&self) -> Option<&[String]> {
551 self.retriever.tags()
552 }
553
554 fn metadata(&self) -> Option<&HashMap<String, Value>> {
555 self.retriever.metadata()
556 }
557
558 fn get_relevant_documents(
559 &self,
560 query: &str,
561 run_manager: Option<&CallbackManagerForRetrieverRun>,
562 ) -> Result<Vec<Document>> {
563 let docs = self.retriever.get_relevant_documents(query, run_manager)?;
564 Ok(docs.into_iter().filter(&self.filter).collect())
565 }
566
567 async fn aget_relevant_documents(
568 &self,
569 query: &str,
570 run_manager: Option<&AsyncCallbackManagerForRetrieverRun>,
571 ) -> Result<Vec<Document>> {
572 let docs = self
573 .retriever
574 .aget_relevant_documents(query, run_manager)
575 .await?;
576 Ok(docs.into_iter().filter(&self.filter).collect())
577 }
578}
579
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one document per text, preserving order.
    fn docs_from(texts: &[&'static str]) -> Vec<Document> {
        texts.iter().map(|text| Document::new(*text)).collect()
    }

    #[test]
    fn test_simple_retriever() {
        let retriever =
            SimpleRetriever::new(docs_from(&["Hello world", "Goodbye world", "Hello again"]))
                .with_k(2);

        let found = retriever.get_relevant_documents("test", None).unwrap();

        assert_eq!(found.len(), 2);
        for (doc, expected) in found.iter().zip(["Hello world", "Goodbye world"]) {
            assert_eq!(doc.page_content, expected);
        }
    }

    #[test]
    fn test_simple_retriever_invoke() {
        let retriever = SimpleRetriever::new(docs_from(&["Hello world", "Goodbye world"])).with_k(5);

        let found = retriever.invoke("test query", None).unwrap();

        assert_eq!(found.len(), 2);
    }

    #[test]
    fn test_filter_retriever() {
        let base = SimpleRetriever::new(docs_from(&["Hello world", "Goodbye world", "Hello again"]));
        let filtered = FilterRetriever::new(base, |doc| doc.page_content.contains("Hello"));

        let kept = filtered.get_relevant_documents("test", None).unwrap();

        assert_eq!(kept.len(), 2);
        assert!(kept.iter().all(|doc| doc.page_content.contains("Hello")));
    }

    #[test]
    fn test_langsmith_params() {
        let params = LangSmithRetrieverParams::new("my_retriever")
            .with_vector_store_provider("pinecone")
            .with_embedding_provider("openai")
            .with_embedding_model("text-embedding-3-small");

        let expectations = [
            (&params.ls_retriever_name, "my_retriever"),
            (&params.ls_vector_store_provider, "pinecone"),
            (&params.ls_embedding_provider, "openai"),
            (&params.ls_embedding_model, "text-embedding-3-small"),
        ];
        for (field, value) in expectations {
            assert_eq!(field.as_deref(), Some(value));
        }

        assert_eq!(params.to_metadata().len(), 4);
    }

    #[tokio::test]
    async fn test_simple_retriever_ainvoke() {
        let retriever = SimpleRetriever::new(docs_from(&["Hello world", "Goodbye world"])).with_k(5);

        let found = retriever.ainvoke("test query", None).await.unwrap();

        assert_eq!(found.len(), 2);
    }

    #[test]
    fn test_batch() {
        let retriever = SimpleRetriever::new(docs_from(&["Hello world", "Goodbye world"]));

        let outcomes = retriever.batch(vec!["query1", "query2"], None);

        assert_eq!(outcomes.len(), 2);
        assert!(outcomes.iter().all(Result::is_ok));
    }

    #[test]
    fn test_get_ls_params() {
        #[derive(Debug)]
        struct TestRetriever;

        #[async_trait]
        impl BaseRetriever for TestRetriever {
            fn get_relevant_documents(
                &self,
                _query: &str,
                _run_manager: Option<&CallbackManagerForRetrieverRun>,
            ) -> Result<Vec<Document>> {
                Ok(Vec::new())
            }
        }

        let params = TestRetriever.get_ls_params();

        assert_eq!(params.ls_retriever_name, Some("test".to_string()));
    }
}