agent_chain_core/documents/compressor.rs
1//! Document compressor.
2//!
3//! This module provides the [`BaseDocumentCompressor`] trait for post-processing
4//! of retrieved documents.
5
6use async_trait::async_trait;
7
8use super::Document;
9use crate::callbacks::Callbacks;
10
11/// Base trait for document compressors.
12///
13/// This abstraction is primarily used for post-processing of retrieved documents.
14///
15/// [`Document`] objects matching a given query are first retrieved.
16/// Then the list of documents can be further processed.
17///
18/// For example, one could re-rank the retrieved documents using an LLM.
19///
20/// Users should favor using a `RunnableLambda` instead of implementing this
21/// trait directly when possible.
22///
23/// # Example
24///
25/// ```ignore
26/// use agent_chain_core::documents::{BaseDocumentCompressor, Document};
27/// use agent_chain_core::callbacks::Callbacks;
28/// use async_trait::async_trait;
29///
30/// struct MyCompressor {
31/// threshold: f64,
32/// }
33///
34/// #[async_trait]
35/// impl BaseDocumentCompressor for MyCompressor {
36/// async fn compress_documents(
37/// &self,
38/// documents: Vec<Document>,
39/// query: &str,
40/// _callbacks: Option<Callbacks>,
41/// ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>> {
42/// // Filter documents based on some criteria
43/// Ok(documents
44/// .into_iter()
45/// .filter(|doc| doc.page_content.contains(query))
46/// .collect())
47/// }
48/// }
49/// ```
50#[async_trait]
51pub trait BaseDocumentCompressor: Send + Sync {
52 /// Compress retrieved documents given the query context.
53 ///
54 /// # Arguments
55 ///
56 /// * `documents` - The retrieved [`Document`] objects.
57 /// * `query` - The query context.
58 /// * `callbacks` - Optional [`Callbacks`] to run during compression.
59 ///
60 /// # Returns
61 ///
62 /// The compressed documents, or an error if compression fails.
63 async fn compress_documents(
64 &self,
65 documents: Vec<Document>,
66 query: &str,
67 callbacks: Option<Callbacks>,
68 ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>>;
69
70 /// Synchronously compress retrieved documents given the query context.
71 ///
72 /// This is a blocking version of [`compress_documents`][Self::compress_documents].
73 /// The default implementation runs the async version using a blocking runtime.
74 ///
75 /// # Arguments
76 ///
77 /// * `documents` - The retrieved [`Document`] objects.
78 /// * `query` - The query context.
79 /// * `callbacks` - Optional [`Callbacks`] to run during compression.
80 ///
81 /// # Returns
82 ///
83 /// The compressed documents, or an error if compression fails.
84 fn compress_documents_sync(
85 &self,
86 documents: Vec<Document>,
87 query: &str,
88 callbacks: Option<Callbacks>,
89 ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>>
90 where
91 Self: Sized,
92 {
93 // Note: In a real implementation, this would need to be handled differently
94 // as we can't easily call async from sync without a runtime.
95 // This is a placeholder that indicates the sync version needs to be implemented.
96 let _ = (documents, query, callbacks);
97 Err("Sync version not implemented - use compress_documents instead".into())
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Test compressor that keeps only documents containing the query string.
    struct TestCompressor;

    #[async_trait]
    impl BaseDocumentCompressor for TestCompressor {
        async fn compress_documents(
            &self,
            documents: Vec<Document>,
            query: &str,
            _callbacks: Option<Callbacks>,
        ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>> {
            Ok(documents
                .into_iter()
                .filter(|doc| doc.page_content.contains(query))
                .collect())
        }
    }

    #[tokio::test]
    async fn test_compress_documents() {
        let compressor = TestCompressor;
        let documents = vec![
            Document::new("Hello world"),
            Document::new("Goodbye world"),
            Document::new("Hello again"),
        ];

        let result = compressor
            .compress_documents(documents, "Hello", None)
            .await
            .unwrap();

        assert_eq!(result.len(), 2);
        assert!(result.iter().all(|doc| doc.page_content.contains("Hello")));
    }

    #[tokio::test]
    async fn test_compress_documents_empty_input() {
        let compressor = TestCompressor;

        let result = compressor
            .compress_documents(Vec::new(), "Hello", None)
            .await
            .unwrap();

        assert!(result.is_empty());
    }

    #[test]
    fn test_compress_documents_sync_default_returns_error() {
        let compressor = TestCompressor;

        // `.err()` + `expect` avoids requiring `Document: Debug` (which
        // `unwrap_err` would need for the `Ok` payload).
        let err = compressor
            .compress_documents_sync(vec![Document::new("Hello world")], "Hello", None)
            .err()
            .expect("default sync implementation should return an error");

        assert!(err.to_string().contains("not implemented"));
    }
}