agent_chain_core/documents/
transformers.rs1use std::collections::HashMap;
7
8use async_trait::async_trait;
9use serde_json::Value;
10
11use super::Document;
12
13#[async_trait]
44pub trait BaseDocumentTransformer: Send + Sync {
45 async fn transform_documents(
56 &self,
57 documents: Vec<Document>,
58 kwargs: HashMap<String, Value>,
59 ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>>;
60
61 fn transform_documents_sync(
76 &self,
77 documents: Vec<Document>,
78 kwargs: HashMap<String, Value>,
79 ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>>
80 where
81 Self: Sized,
82 {
83 let _ = (documents, kwargs);
87 Err("Sync version not implemented - use transform_documents instead".into())
88 }
89}
90
/// Document transformer backed by a plain function or closure.
///
/// The wrapped callable is stored as-is. The `Fn(Document) -> Document + Send + Sync`
/// requirement is stated on the `impl` blocks that actually call it rather than
/// on the type definition, so the struct itself places no bounds on `F`.
pub struct FunctionTransformer<F> {
    /// Callable applied to each document.
    transform_fn: F,
}
98
99impl<F> FunctionTransformer<F>
100where
101 F: Fn(Document) -> Document + Send + Sync,
102{
103 pub fn new(transform_fn: F) -> Self {
105 Self { transform_fn }
106 }
107}
108
109#[async_trait]
110impl<F> BaseDocumentTransformer for FunctionTransformer<F>
111where
112 F: Fn(Document) -> Document + Send + Sync,
113{
114 async fn transform_documents(
115 &self,
116 documents: Vec<Document>,
117 _kwargs: HashMap<String, Value>,
118 ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>> {
119 Ok(documents.into_iter().map(&self.transform_fn).collect())
120 }
121
122 fn transform_documents_sync(
123 &self,
124 documents: Vec<Document>,
125 _kwargs: HashMap<String, Value>,
126 ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>> {
127 Ok(documents.into_iter().map(&self.transform_fn).collect())
128 }
129}
130
/// Document transformer that keeps only the documents accepted by a predicate.
///
/// The wrapped predicate is stored as-is. The `Fn(&Document) -> bool + Send + Sync`
/// requirement is stated on the `impl` blocks that actually call it rather than
/// on the type definition, so the struct itself places no bounds on `F`.
pub struct FilterTransformer<F> {
    /// Predicate evaluated against each document; `true` means "keep".
    filter_fn: F,
}
138
139impl<F> FilterTransformer<F>
140where
141 F: Fn(&Document) -> bool + Send + Sync,
142{
143 pub fn new(filter_fn: F) -> Self {
145 Self { filter_fn }
146 }
147}
148
149#[async_trait]
150impl<F> BaseDocumentTransformer for FilterTransformer<F>
151where
152 F: Fn(&Document) -> bool + Send + Sync,
153{
154 async fn transform_documents(
155 &self,
156 documents: Vec<Document>,
157 _kwargs: HashMap<String, Value>,
158 ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>> {
159 Ok(documents.into_iter().filter(&self.filter_fn).collect())
160 }
161
162 fn transform_documents_sync(
163 &self,
164 documents: Vec<Document>,
165 _kwargs: HashMap<String, Value>,
166 ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>> {
167 Ok(documents.into_iter().filter(&self.filter_fn).collect())
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Test-only transformer: upper-cases each document's content and carries
    /// its metadata over unchanged.
    struct UppercaseTransformer;

    #[async_trait]
    impl BaseDocumentTransformer for UppercaseTransformer {
        async fn transform_documents(
            &self,
            documents: Vec<Document>,
            _kwargs: HashMap<String, Value>,
        ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>> {
            let mut shouted = Vec::with_capacity(documents.len());
            for doc in documents {
                let upper = doc.page_content.to_uppercase();
                shouted.push(Document::new(upper).with_metadata(doc.metadata));
            }
            Ok(shouted)
        }
    }

    /// A hand-written trait implementation transforms every document.
    #[tokio::test]
    async fn test_transform_documents() {
        let inputs = vec![Document::new("hello world"), Document::new("goodbye world")];
        let expected = ["HELLO WORLD", "GOODBYE WORLD"];

        let outputs = UppercaseTransformer
            .transform_documents(inputs, HashMap::new())
            .await
            .unwrap();

        assert_eq!(outputs.len(), expected.len());
        for (doc, want) in outputs.iter().zip(expected.iter()) {
            assert_eq!(doc.page_content, *want);
        }
    }

    /// `FunctionTransformer` applies the supplied closure to the document.
    #[tokio::test]
    async fn test_function_transformer() {
        let prefixer = FunctionTransformer::new(|doc: Document| {
            let tagged = format!("[PROCESSED] {}", doc.page_content);
            Document::new(tagged)
        });

        let outputs = prefixer
            .transform_documents(vec![Document::new("test")], HashMap::new())
            .await
            .unwrap();

        assert_eq!(outputs[0].page_content, "[PROCESSED] test");
    }

    /// `FilterTransformer` keeps only documents whose content is longer than
    /// five bytes.
    #[tokio::test]
    async fn test_filter_transformer() {
        let long_only = FilterTransformer::new(|doc: &Document| doc.page_content.len() > 5);
        let inputs = vec![
            Document::new("hi"),
            Document::new("hello world"),
            Document::new("bye"),
        ];
        let expected = ["hello world"];

        let outputs = long_only
            .transform_documents(inputs, HashMap::new())
            .await
            .unwrap();

        assert_eq!(outputs.len(), expected.len());
        for (doc, want) in outputs.iter().zip(expected.iter()) {
            assert_eq!(doc.page_content, *want);
        }
    }
}