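// agent_chain_core/documents/transformers.rs

//! Document transformers.
//!
//! This module provides the [`BaseDocumentTransformer`] trait for document
//! transformation operations, together with two closure-based implementations:
//! [`FunctionTransformer`] (maps every document) and [`FilterTransformer`]
//! (keeps only the documents that satisfy a predicate).
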
use std::collections::HashMap;
use std::fmt;

use async_trait::async_trait;
use serde_json::Value;

use super::Document;

13/// Abstract base trait for document transformation.
14///
15/// A document transformation takes a sequence of [`Document`] objects and returns a
16/// sequence of transformed [`Document`] objects.
17///
18/// # Example
19///
20/// ```ignore
21/// use agent_chain_core::documents::{BaseDocumentTransformer, Document};
22/// use async_trait::async_trait;
23/// use std::collections::HashMap;
24/// use serde_json::Value;
25///
26/// struct EmbeddingsRedundantFilter {
27///     similarity_threshold: f64,
28/// }
29///
30/// #[async_trait]
31/// impl BaseDocumentTransformer for EmbeddingsRedundantFilter {
32///     async fn transform_documents(
33///         &self,
34///         documents: Vec<Document>,
35///         _kwargs: HashMap<String, Value>,
36///     ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>> {
37///         // Filter out redundant documents based on embeddings similarity
38///         // This is a simplified example
39///         Ok(documents)
40///     }
41/// }
42/// ```
43#[async_trait]
44pub trait BaseDocumentTransformer: Send + Sync {
45    /// Transform a list of documents.
46    ///
47    /// # Arguments
48    ///
49    /// * `documents` - A sequence of [`Document`] objects to be transformed.
50    /// * `kwargs` - Additional keyword arguments for transformation.
51    ///
52    /// # Returns
53    ///
54    /// A sequence of transformed [`Document`] objects, or an error if transformation fails.
55    async fn transform_documents(
56        &self,
57        documents: Vec<Document>,
58        kwargs: HashMap<String, Value>,
59    ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>>;
60
61    /// Synchronously transform a list of documents.
62    ///
63    /// This is a blocking version of [`transform_documents`][Self::transform_documents].
64    /// The default implementation indicates that the sync version needs to be implemented
65    /// or the async version should be used.
66    ///
67    /// # Arguments
68    ///
69    /// * `documents` - A sequence of [`Document`] objects to be transformed.
70    /// * `kwargs` - Additional keyword arguments for transformation.
71    ///
72    /// # Returns
73    ///
74    /// A sequence of transformed [`Document`] objects, or an error if transformation fails.
75    fn transform_documents_sync(
76        &self,
77        documents: Vec<Document>,
78        kwargs: HashMap<String, Value>,
79    ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>>
80    where
81        Self: Sized,
82    {
83        // Note: In a real implementation, this would need to be handled differently
84        // as we can't easily call async from sync without a runtime.
85        // This is a placeholder that indicates the sync version needs to be implemented.
86        let _ = (documents, kwargs);
87        Err("Sync version not implemented - use transform_documents instead".into())
88    }
89}
90
91/// A simple document transformer that applies a function to each document.
92pub struct FunctionTransformer<F>
93where
94    F: Fn(Document) -> Document + Send + Sync,
95{
96    transform_fn: F,
97}
98
99impl<F> FunctionTransformer<F>
100where
101    F: Fn(Document) -> Document + Send + Sync,
102{
103    /// Create a new FunctionTransformer with the given function.
104    pub fn new(transform_fn: F) -> Self {
105        Self { transform_fn }
106    }
107}
108
109#[async_trait]
110impl<F> BaseDocumentTransformer for FunctionTransformer<F>
111where
112    F: Fn(Document) -> Document + Send + Sync,
113{
114    async fn transform_documents(
115        &self,
116        documents: Vec<Document>,
117        _kwargs: HashMap<String, Value>,
118    ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>> {
119        Ok(documents.into_iter().map(&self.transform_fn).collect())
120    }
121
122    fn transform_documents_sync(
123        &self,
124        documents: Vec<Document>,
125        _kwargs: HashMap<String, Value>,
126    ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>> {
127        Ok(documents.into_iter().map(&self.transform_fn).collect())
128    }
129}
130
131/// A document transformer that filters documents based on a predicate.
132pub struct FilterTransformer<F>
133where
134    F: Fn(&Document) -> bool + Send + Sync,
135{
136    filter_fn: F,
137}
138
139impl<F> FilterTransformer<F>
140where
141    F: Fn(&Document) -> bool + Send + Sync,
142{
143    /// Create a new FilterTransformer with the given predicate.
144    pub fn new(filter_fn: F) -> Self {
145        Self { filter_fn }
146    }
147}
148
149#[async_trait]
150impl<F> BaseDocumentTransformer for FilterTransformer<F>
151where
152    F: Fn(&Document) -> bool + Send + Sync,
153{
154    async fn transform_documents(
155        &self,
156        documents: Vec<Document>,
157        _kwargs: HashMap<String, Value>,
158    ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>> {
159        Ok(documents.into_iter().filter(&self.filter_fn).collect())
160    }
161
162    fn transform_documents_sync(
163        &self,
164        documents: Vec<Document>,
165        _kwargs: HashMap<String, Value>,
166    ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>> {
167        Ok(documents.into_iter().filter(&self.filter_fn).collect())
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Test transformer that upper-cases page content and carries metadata over.
    /// It implements only the async method, so it also exercises the default
    /// synchronous fallback of the trait.
    struct UppercaseTransformer;

    #[async_trait]
    impl BaseDocumentTransformer for UppercaseTransformer {
        async fn transform_documents(
            &self,
            documents: Vec<Document>,
            _kwargs: HashMap<String, Value>,
        ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>> {
            Ok(documents
                .into_iter()
                .map(|doc| {
                    Document::new(doc.page_content.to_uppercase()).with_metadata(doc.metadata)
                })
                .collect())
        }
    }

    #[tokio::test]
    async fn test_transform_documents() {
        let transformer = UppercaseTransformer;
        let documents = vec![Document::new("hello world"), Document::new("goodbye world")];

        let result = transformer
            .transform_documents(documents, HashMap::new())
            .await
            .unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].page_content, "HELLO WORLD");
        assert_eq!(result[1].page_content, "GOODBYE WORLD");
    }

    #[test]
    fn test_default_sync_returns_error() {
        let transformer = UppercaseTransformer;

        let result =
            transformer.transform_documents_sync(vec![Document::new("hello")], HashMap::new());

        let err = result
            .err()
            .expect("default sync implementation should return an error");
        assert!(err.to_string().contains("Sync version not implemented"));
    }

    #[tokio::test]
    async fn test_function_transformer() {
        let transformer = FunctionTransformer::new(|doc| {
            Document::new(format!("[PROCESSED] {}", doc.page_content))
        });

        let documents = vec![Document::new("test")];

        let result = transformer
            .transform_documents(documents, HashMap::new())
            .await
            .unwrap();

        assert_eq!(result[0].page_content, "[PROCESSED] test");
    }

    #[test]
    fn test_function_transformer_sync() {
        let transformer = FunctionTransformer::new(|doc| {
            Document::new(format!("[PROCESSED] {}", doc.page_content))
        });

        let documents = vec![Document::new("first"), Document::new("second")];

        let result = transformer
            .transform_documents_sync(documents, HashMap::new())
            .unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].page_content, "[PROCESSED] first");
        assert_eq!(result[1].page_content, "[PROCESSED] second");
    }

    #[tokio::test]
    async fn test_function_transformer_empty_input() {
        let transformer = FunctionTransformer::new(|doc| doc);

        let result = transformer
            .transform_documents(Vec::new(), HashMap::new())
            .await
            .unwrap();

        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_filter_transformer() {
        let transformer = FilterTransformer::new(|doc| doc.page_content.len() > 5);

        let documents = vec![
            Document::new("hi"),
            Document::new("hello world"),
            Document::new("bye"),
        ];

        let result = transformer
            .transform_documents(documents, HashMap::new())
            .await
            .unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].page_content, "hello world");
    }

    #[test]
    fn test_filter_transformer_sync_preserves_order() {
        let transformer = FilterTransformer::new(|doc| doc.page_content.len() > 5);

        let documents = vec![
            Document::new("hello world"),
            Document::new("hi"),
            Document::new("goodbye world"),
        ];

        let result = transformer
            .transform_documents_sync(documents, HashMap::new())
            .unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].page_content, "hello world");
        assert_eq!(result[1].page_content, "goodbye world");
    }
}