1use std::collections::{HashMap, HashSet};
9use std::hash::{Hash, Hasher};
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use serde_json::Value;
14
15use cognis_core::{Result, Runnable, RunnableConfig};
16
17use crate::document::Document;
18
/// Stateless document transformer that rearranges a list of documents.
///
/// Documents at even positions (0, 2, 4, …) keep their relative order at the
/// front of the output; documents at odd positions follow in reverse order.
/// See [`LongContextReorder::reorder`].
///
/// NOTE(review): presumably meant for relevance-ranked input, so that the
/// top-ranked documents land at both ends of the list — confirm with callers.
#[derive(Clone, Copy, Debug, Default)]
pub struct LongContextReorder;
28
29impl LongContextReorder {
30 pub fn new() -> Self {
32 Self
33 }
34
35 pub fn reorder(docs: Vec<Document>) -> Vec<Document> {
38 let mut head: Vec<Document> = Vec::with_capacity(docs.len());
39 let mut tail: Vec<Document> = Vec::with_capacity(docs.len());
40 for (i, d) in docs.into_iter().enumerate() {
41 if i % 2 == 0 {
42 head.push(d);
43 } else {
44 tail.push(d);
45 }
46 }
47 tail.reverse();
48 head.extend(tail);
49 head
50 }
51}
52
53#[async_trait]
54impl Runnable<Vec<Document>, Vec<Document>> for LongContextReorder {
55 async fn invoke(&self, input: Vec<Document>, _: RunnableConfig) -> Result<Vec<Document>> {
56 Ok(Self::reorder(input))
57 }
58 fn name(&self) -> &str {
59 "LongContextReorder"
60 }
61}
62
63pub struct Dedup {
70 key_fn: Arc<dyn Fn(&Document) -> String + Send + Sync>,
71}
72
73impl Default for Dedup {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl Dedup {
80 pub fn new() -> Self {
82 Self {
83 key_fn: Arc::new(|d: &Document| d.content.trim().to_string()),
84 }
85 }
86
87 pub fn by<F>(key_fn: F) -> Self
89 where
90 F: Fn(&Document) -> String + Send + Sync + 'static,
91 {
92 Self {
93 key_fn: Arc::new(key_fn),
94 }
95 }
96
97 pub fn dedup(&self, docs: Vec<Document>) -> Vec<Document> {
99 let mut seen: HashSet<u64> = HashSet::new();
100 let mut out = Vec::with_capacity(docs.len());
101 for d in docs {
102 let key = (self.key_fn)(&d);
103 let mut h = std::collections::hash_map::DefaultHasher::new();
104 key.hash(&mut h);
105 if seen.insert(h.finish()) {
106 out.push(d);
107 }
108 }
109 out
110 }
111}
112
113#[async_trait]
114impl Runnable<Vec<Document>, Vec<Document>> for Dedup {
115 async fn invoke(&self, input: Vec<Document>, _: RunnableConfig) -> Result<Vec<Document>> {
116 Ok(self.dedup(input))
117 }
118 fn name(&self) -> &str {
119 "Dedup"
120 }
121}
122
123pub type EnrichmentFn = Arc<dyn Fn(&mut Document) -> Result<()> + Send + Sync>;
125
126pub struct Enrichment {
131 f: EnrichmentFn,
132 name: &'static str,
133}
134
135impl Enrichment {
136 pub fn new<F>(f: F) -> Self
138 where
139 F: Fn(&mut Document) -> Result<()> + Send + Sync + 'static,
140 {
141 Self {
142 f: Arc::new(f),
143 name: "Enrichment",
144 }
145 }
146
147 pub fn with_name(mut self, name: &'static str) -> Self {
149 self.name = name;
150 self
151 }
152}
153
154#[async_trait]
155impl Runnable<Vec<Document>, Vec<Document>> for Enrichment {
156 async fn invoke(&self, mut input: Vec<Document>, _: RunnableConfig) -> Result<Vec<Document>> {
157 for d in &mut input {
158 (self.f)(d)?;
159 }
160 Ok(input)
161 }
162 fn name(&self) -> &str {
163 self.name
164 }
165}
166
167#[derive(Debug, Default, Clone)]
173pub struct MetadataTransformer {
174 fields: HashMap<String, Value>,
175 only_missing: bool,
176}
177
178impl MetadataTransformer {
179 pub fn new() -> Self {
181 Self::default()
182 }
183
184 pub fn from_map(fields: HashMap<String, Value>) -> Self {
186 Self {
187 fields,
188 only_missing: false,
189 }
190 }
191
192 pub fn set(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
194 self.fields.insert(key.into(), value.into());
195 self
196 }
197
198 pub fn merge_only_missing(mut self) -> Self {
200 self.only_missing = true;
201 self
202 }
203
204 pub fn apply(&self, mut docs: Vec<Document>) -> Vec<Document> {
206 for d in &mut docs {
207 for (k, v) in &self.fields {
208 if self.only_missing && d.metadata.contains_key(k) {
209 continue;
210 }
211 d.metadata.insert(k.clone(), v.clone());
212 }
213 }
214 docs
215 }
216}
217
218#[async_trait]
219impl Runnable<Vec<Document>, Vec<Document>> for MetadataTransformer {
220 async fn invoke(&self, input: Vec<Document>, _: RunnableConfig) -> Result<Vec<Document>> {
221 Ok(self.apply(input))
222 }
223 fn name(&self) -> &str {
224 "MetadataTransformer"
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a document whose content and id are both `id`.
    fn doc(id: &str) -> Document {
        let base = Document::new(id);
        base.with_id(id)
    }

    /// Odd-length input: even positions first, odd positions reversed after.
    #[test]
    fn reorder_pattern() {
        let input = vec![doc("1"), doc("2"), doc("3"), doc("4"), doc("5")];
        let reordered = LongContextReorder::reorder(input);
        let ids: Vec<_> = reordered.iter().filter_map(|d| d.id.clone()).collect();
        assert_eq!(ids, vec!["1", "3", "5", "4", "2"]);
    }

    /// An empty list stays empty.
    #[test]
    fn empty_passes_through() {
        assert!(LongContextReorder::reorder(Vec::new()).is_empty());
    }

    /// A single document is returned as-is.
    #[test]
    fn single_doc_passes_through() {
        let reordered = LongContextReorder::reorder(vec![doc("only")]);
        assert_eq!(reordered.len(), 1);
        assert_eq!(reordered[0].id.as_deref(), Some("only"));
    }

    /// The `Runnable` entry point applies the same reordering.
    #[tokio::test]
    async fn runnable_invoke() {
        let input = vec![doc("a"), doc("b"), doc("c")];
        let reordered = LongContextReorder::new()
            .invoke(input, RunnableConfig::default())
            .await
            .unwrap();
        let ids: Vec<_> = reordered.iter().filter_map(|d| d.id.clone()).collect();
        assert_eq!(ids, vec!["a", "c", "b"]);
    }

    /// Default key is trimmed content; the first occurrence wins.
    #[test]
    fn dedup_by_content_keeps_first_seen() {
        let input = vec![
            Document::new("hello"),
            Document::new("world"),
            Document::new(" hello "),
            Document::new("rust"),
        ];
        let unique = Dedup::new().dedup(input);
        let contents: Vec<_> = unique.iter().map(|d| d.content.clone()).collect();
        assert_eq!(contents, vec!["hello", "world", "rust"]);
    }

    /// A custom key function (here: the id) decides what counts as a duplicate.
    #[test]
    fn dedup_by_id_uses_custom_key() {
        let input = vec![
            Document::new("a body").with_id("a"),
            Document::new("a body").with_id("b"),
            Document::new("c body").with_id("a"),
        ];
        let by_id = Dedup::by(|d| d.id.clone().unwrap_or_default());
        let unique = by_id.dedup(input);
        let ids: Vec<_> = unique.iter().filter_map(|d| d.id.clone()).collect();
        assert_eq!(ids, vec!["a", "b"]);
    }

    /// The enrichment callback runs once per document and may mutate it.
    #[tokio::test]
    async fn enrichment_applies_per_doc() {
        let shout = Enrichment::new(|d: &mut Document| {
            d.content = d.content.to_uppercase();
            d.metadata
                .insert("seen".into(), serde_json::Value::Bool(true));
            Ok(())
        });
        let input = vec![Document::new("hi"), Document::new("ho")];
        let enriched = shout
            .invoke(input, RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(enriched[0].content, "HI");
        assert_eq!(enriched[1].content, "HO");
        assert!(enriched[0].metadata.contains_key("seen"));
    }

    /// Without `merge_only_missing`, existing keys are replaced.
    #[test]
    fn metadata_transformer_overwrites_by_default() {
        let input = vec![
            Document::new("d1").with_metadata("source", serde_json::json!("old")),
            Document::new("d2"),
        ];
        let transformer = MetadataTransformer::new().set("source", "new");
        let updated = transformer.apply(input);
        let expected = serde_json::json!("new");
        assert_eq!(updated[0].metadata.get("source").unwrap(), &expected);
        assert_eq!(updated[1].metadata.get("source").unwrap(), &expected);
    }

    /// With `merge_only_missing`, existing keys survive and absent ones are filled.
    #[test]
    fn metadata_transformer_only_missing_preserves_existing() {
        let input = vec![
            Document::new("d1").with_metadata("source", serde_json::json!("old")),
            Document::new("d2"),
        ];
        let transformer = MetadataTransformer::new()
            .set("source", "new")
            .merge_only_missing();
        let updated = transformer.apply(input);
        assert_eq!(
            updated[0].metadata.get("source").unwrap(),
            &serde_json::json!("old")
        );
        assert_eq!(
            updated[1].metadata.get("source").unwrap(),
            &serde_json::json!("new")
        );
    }
}