agent_chain_core/documents/compressor.rs

//! Document compressor.
//!
//! This module provides the [`BaseDocumentCompressor`] trait for post-processing
//! of retrieved documents.

6use async_trait::async_trait;
7
8use super::Document;
9use crate::callbacks::Callbacks;
10
11/// Base trait for document compressors.
12///
13/// This abstraction is primarily used for post-processing of retrieved documents.
14///
15/// [`Document`] objects matching a given query are first retrieved.
16/// Then the list of documents can be further processed.
17///
18/// For example, one could re-rank the retrieved documents using an LLM.
19///
20/// Users should favor using a `RunnableLambda` instead of implementing this
21/// trait directly when possible.
22///
23/// # Example
24///
25/// ```ignore
26/// use agent_chain_core::documents::{BaseDocumentCompressor, Document};
27/// use agent_chain_core::callbacks::Callbacks;
28/// use async_trait::async_trait;
29///
30/// struct MyCompressor {
31///     threshold: f64,
32/// }
33///
34/// #[async_trait]
35/// impl BaseDocumentCompressor for MyCompressor {
36///     async fn compress_documents(
37///         &self,
38///         documents: Vec<Document>,
39///         query: &str,
40///         _callbacks: Option<Callbacks>,
41///     ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>> {
42///         // Filter documents based on some criteria
43///         Ok(documents
44///             .into_iter()
45///             .filter(|doc| doc.page_content.contains(query))
46///             .collect())
47///     }
48/// }
49/// ```
50#[async_trait]
51pub trait BaseDocumentCompressor: Send + Sync {
52    /// Compress retrieved documents given the query context.
53    ///
54    /// # Arguments
55    ///
56    /// * `documents` - The retrieved [`Document`] objects.
57    /// * `query` - The query context.
58    /// * `callbacks` - Optional [`Callbacks`] to run during compression.
59    ///
60    /// # Returns
61    ///
62    /// The compressed documents, or an error if compression fails.
63    async fn compress_documents(
64        &self,
65        documents: Vec<Document>,
66        query: &str,
67        callbacks: Option<Callbacks>,
68    ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>>;
69
70    /// Synchronously compress retrieved documents given the query context.
71    ///
72    /// This is a blocking version of [`compress_documents`][Self::compress_documents].
73    /// The default implementation runs the async version using a blocking runtime.
74    ///
75    /// # Arguments
76    ///
77    /// * `documents` - The retrieved [`Document`] objects.
78    /// * `query` - The query context.
79    /// * `callbacks` - Optional [`Callbacks`] to run during compression.
80    ///
81    /// # Returns
82    ///
83    /// The compressed documents, or an error if compression fails.
84    fn compress_documents_sync(
85        &self,
86        documents: Vec<Document>,
87        query: &str,
88        callbacks: Option<Callbacks>,
89    ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>>
90    where
91        Self: Sized,
92    {
93        // Note: In a real implementation, this would need to be handled differently
94        // as we can't easily call async from sync without a runtime.
95        // This is a placeholder that indicates the sync version needs to be implemented.
96        let _ = (documents, query, callbacks);
97        Err("Sync version not implemented - use compress_documents instead".into())
98    }
99}
100
// Unit tests for `BaseDocumentCompressor`: the async entry point via a simple
// substring-filtering compressor, and the provided sync default.
#[cfg(test)]
mod tests {
    use super::*;

    /// Test compressor that keeps only documents whose content contains the query.
    struct TestCompressor;

    #[async_trait]
    impl BaseDocumentCompressor for TestCompressor {
        async fn compress_documents(
            &self,
            documents: Vec<Document>,
            query: &str,
            _callbacks: Option<Callbacks>,
        ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>> {
            Ok(documents
                .into_iter()
                .filter(|doc| doc.page_content.contains(query))
                .collect())
        }
    }

    #[tokio::test]
    async fn test_compress_documents() {
        let compressor = TestCompressor;
        let documents = vec![
            Document::new("Hello world"),
            Document::new("Goodbye world"),
            Document::new("Hello again"),
        ];

        let result = compressor
            .compress_documents(documents, "Hello", None)
            .await
            .unwrap();

        assert_eq!(result.len(), 2);
        assert!(result.iter().all(|doc| doc.page_content.contains("Hello")));
        // Filtering must keep the surviving documents in their original order.
        assert_eq!(result[0].page_content, "Hello world");
        assert_eq!(result[1].page_content, "Hello again");
    }

    #[tokio::test]
    async fn test_compress_documents_empty_input() {
        let result = TestCompressor
            .compress_documents(Vec::new(), "Hello", None)
            .await
            .unwrap();

        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_compress_documents_no_match() {
        let documents = vec![Document::new("Goodbye world")];

        let result = TestCompressor
            .compress_documents(documents, "Hello", None)
            .await
            .unwrap();

        assert!(result.is_empty());
    }

    #[test]
    fn test_compress_documents_sync_default_returns_error() {
        // The provided sync method is a stub; implementors that do not override
        // it must get an error rather than silently empty output.
        let result = TestCompressor.compress_documents_sync(
            vec![Document::new("Hello world")],
            "Hello",
            None,
        );

        assert!(result.is_err());
    }
}