Skip to main content

cognis_rag/
transformers.rs

1//! Document-list transformers — `Runnable<Vec<Document>, Vec<Document>>`.
2//!
3//! Rust-native take: V1 had a separate "DocumentTransformer" trait for
4//! pre-store / post-retrieval doc operations. In V2 these are just
5//! `Runnable`s — they compose with `.pipe()` and slot into chains
6//! anywhere a runnable is expected. No new trait surface.
7
8use std::collections::{HashMap, HashSet};
9use std::hash::{Hash, Hasher};
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use serde_json::Value;
14
15use cognis_core::{Result, Runnable, RunnableConfig};
16
17use crate::document::Document;
18
19/// Reorder retrieved documents so the most-relevant ones sit at the
20/// **head and tail** of the list. LLMs attend better to ends than to the
21/// middle of long contexts; this is the classic "lost in the middle" fix.
22///
23/// Assumes the input is already ranked best-first (the standard retriever
24/// output). Reshuffles into: `[1, 3, 5, ..., 6, 4, 2]` (best at index 0,
25/// next-best at last index).
26#[derive(Debug, Default, Clone, Copy)]
27pub struct LongContextReorder;
28
29impl LongContextReorder {
30    /// Construct.
31    pub fn new() -> Self {
32        Self
33    }
34
35    /// Reorder `docs` (assumed best-first ranked) so the best ranks live
36    /// at both ends. Pure function — useful for tests and ad-hoc use.
37    pub fn reorder(docs: Vec<Document>) -> Vec<Document> {
38        let mut head: Vec<Document> = Vec::with_capacity(docs.len());
39        let mut tail: Vec<Document> = Vec::with_capacity(docs.len());
40        for (i, d) in docs.into_iter().enumerate() {
41            if i % 2 == 0 {
42                head.push(d);
43            } else {
44                tail.push(d);
45            }
46        }
47        tail.reverse();
48        head.extend(tail);
49        head
50    }
51}
52
53#[async_trait]
54impl Runnable<Vec<Document>, Vec<Document>> for LongContextReorder {
55    async fn invoke(&self, input: Vec<Document>, _: RunnableConfig) -> Result<Vec<Document>> {
56        Ok(Self::reorder(input))
57    }
58    fn name(&self) -> &str {
59        "LongContextReorder"
60    }
61}
62
63/// Drop duplicate documents from the list. By default, two docs are
64/// considered duplicates if they have the same `content` (whitespace-
65/// trimmed). Use [`Dedup::by`] to dedupe on a custom key (e.g. by id,
66/// or by a normalized hash of the content).
67///
68/// First-seen wins; later duplicates are discarded.
69pub struct Dedup {
70    key_fn: Arc<dyn Fn(&Document) -> String + Send + Sync>,
71}
72
73impl Default for Dedup {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl Dedup {
80    /// Dedupe by trimmed content.
81    pub fn new() -> Self {
82        Self {
83            key_fn: Arc::new(|d: &Document| d.content.trim().to_string()),
84        }
85    }
86
87    /// Dedupe by a caller-supplied key function (e.g. `|d| d.id.clone().unwrap_or_default()`).
88    pub fn by<F>(key_fn: F) -> Self
89    where
90        F: Fn(&Document) -> String + Send + Sync + 'static,
91    {
92        Self {
93            key_fn: Arc::new(key_fn),
94        }
95    }
96
97    /// Pure form — useful for tests and ad-hoc use.
98    pub fn dedup(&self, docs: Vec<Document>) -> Vec<Document> {
99        let mut seen: HashSet<u64> = HashSet::new();
100        let mut out = Vec::with_capacity(docs.len());
101        for d in docs {
102            let key = (self.key_fn)(&d);
103            let mut h = std::collections::hash_map::DefaultHasher::new();
104            key.hash(&mut h);
105            if seen.insert(h.finish()) {
106                out.push(d);
107            }
108        }
109        out
110    }
111}
112
113#[async_trait]
114impl Runnable<Vec<Document>, Vec<Document>> for Dedup {
115    async fn invoke(&self, input: Vec<Document>, _: RunnableConfig) -> Result<Vec<Document>> {
116        Ok(self.dedup(input))
117    }
118    fn name(&self) -> &str {
119        "Dedup"
120    }
121}
122
123/// Per-document mutator used by [`Enrichment`].
124pub type EnrichmentFn = Arc<dyn Fn(&mut Document) -> Result<()> + Send + Sync>;
125
126/// Apply a per-document transform — e.g. tag, summarize, redact.
127/// The closure mutates the document in place; documents that fail can
128/// be filtered by returning an `Err` (the transform aborts on first
129/// error).
130pub struct Enrichment {
131    f: EnrichmentFn,
132    name: &'static str,
133}
134
135impl Enrichment {
136    /// Wrap a per-doc enrichment closure.
137    pub fn new<F>(f: F) -> Self
138    where
139        F: Fn(&mut Document) -> Result<()> + Send + Sync + 'static,
140    {
141        Self {
142            f: Arc::new(f),
143            name: "Enrichment",
144        }
145    }
146
147    /// Override the runnable name (shown in tracing/logs).
148    pub fn with_name(mut self, name: &'static str) -> Self {
149        self.name = name;
150        self
151    }
152}
153
154#[async_trait]
155impl Runnable<Vec<Document>, Vec<Document>> for Enrichment {
156    async fn invoke(&self, mut input: Vec<Document>, _: RunnableConfig) -> Result<Vec<Document>> {
157        for d in &mut input {
158            (self.f)(d)?;
159        }
160        Ok(input)
161    }
162    fn name(&self) -> &str {
163        self.name
164    }
165}
166
167/// Set / merge fixed metadata onto every document. Use to tag a batch
168/// with provenance (`source = "kb-2026-05"`), pipeline stage, etc.
169///
170/// Existing keys are overwritten by default; use [`MetadataTransformer::merge_only_missing`]
171/// to keep existing values.
172#[derive(Debug, Default, Clone)]
173pub struct MetadataTransformer {
174    fields: HashMap<String, Value>,
175    only_missing: bool,
176}
177
178impl MetadataTransformer {
179    /// Empty transformer — add fields with [`MetadataTransformer::set`].
180    pub fn new() -> Self {
181        Self::default()
182    }
183
184    /// Construct from a map. Keys are merged into every document.
185    pub fn from_map(fields: HashMap<String, Value>) -> Self {
186        Self {
187            fields,
188            only_missing: false,
189        }
190    }
191
192    /// Add a single key / value.
193    pub fn set(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
194        self.fields.insert(key.into(), value.into());
195        self
196    }
197
198    /// Don't overwrite keys the doc already has.
199    pub fn merge_only_missing(mut self) -> Self {
200        self.only_missing = true;
201        self
202    }
203
204    /// Pure form.
205    pub fn apply(&self, mut docs: Vec<Document>) -> Vec<Document> {
206        for d in &mut docs {
207            for (k, v) in &self.fields {
208                if self.only_missing && d.metadata.contains_key(k) {
209                    continue;
210                }
211                d.metadata.insert(k.clone(), v.clone());
212            }
213        }
214        docs
215    }
216}
217
218#[async_trait]
219impl Runnable<Vec<Document>, Vec<Document>> for MetadataTransformer {
220    async fn invoke(&self, input: Vec<Document>, _: RunnableConfig) -> Result<Vec<Document>> {
221        Ok(self.apply(input))
222    }
223    fn name(&self) -> &str {
224        "MetadataTransformer"
225    }
226}
227
#[cfg(test)]
mod tests {
    //! Unit tests for the document-list transformers.

    use super::*;

    /// Document whose content and id are both `id`.
    fn doc(id: &str) -> Document {
        Document::new(id).with_id(id)
    }

    #[test]
    fn reorder_pattern() {
        // Input ranked best-first: [1, 2, 3, 4, 5]
        // Expected: [1, 3, 5, 4, 2] — best at ends.
        let docs = vec![doc("1"), doc("2"), doc("3"), doc("4"), doc("5")];
        let out = LongContextReorder::reorder(docs);
        let ids: Vec<_> = out.iter().filter_map(|d| d.id.clone()).collect();
        assert_eq!(ids, vec!["1", "3", "5", "4", "2"]);
    }

    #[test]
    fn reorder_even_length() {
        // Ranked [1..=6] -> odd ranks ascending, then even ranks descending.
        let docs: Vec<Document> = (1..=6).map(|i| doc(&i.to_string())).collect();
        let out = LongContextReorder::reorder(docs);
        let ids: Vec<_> = out.iter().filter_map(|d| d.id.clone()).collect();
        assert_eq!(ids, vec!["1", "3", "5", "6", "4", "2"]);
    }

    #[test]
    fn empty_passes_through() {
        let out = LongContextReorder::reorder(Vec::new());
        assert!(out.is_empty());
    }

    #[test]
    fn single_doc_passes_through() {
        let out = LongContextReorder::reorder(vec![doc("only")]);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].id.as_deref(), Some("only"));
    }

    #[tokio::test]
    async fn runnable_invoke() {
        let r = LongContextReorder::new();
        let out = r
            .invoke(
                vec![doc("a"), doc("b"), doc("c")],
                RunnableConfig::default(),
            )
            .await
            .unwrap();
        let ids: Vec<_> = out.iter().filter_map(|d| d.id.clone()).collect();
        assert_eq!(ids, vec!["a", "c", "b"]);
    }

    #[test]
    fn dedup_by_content_keeps_first_seen() {
        let docs = vec![
            Document::new("hello"),
            Document::new("world"),
            Document::new(" hello "), // trimmed dup
            Document::new("rust"),
        ];
        let out = Dedup::new().dedup(docs);
        let contents: Vec<_> = out.iter().map(|d| d.content.clone()).collect();
        assert_eq!(contents, vec!["hello", "world", "rust"]);
    }

    #[test]
    fn dedup_by_id_uses_custom_key() {
        let docs = vec![
            Document::new("a body").with_id("a"),
            Document::new("a body").with_id("b"), // same content, different id
            Document::new("c body").with_id("a"), // dup id of first
        ];
        let out = Dedup::by(|d| d.id.clone().unwrap_or_default()).dedup(docs);
        let ids: Vec<_> = out.iter().filter_map(|d| d.id.clone()).collect();
        assert_eq!(ids, vec!["a", "b"]);
    }

    #[test]
    fn dedup_empty_passes_through() {
        assert!(Dedup::default().dedup(Vec::new()).is_empty());
    }

    #[tokio::test]
    async fn dedup_runnable_invoke() {
        let out = Dedup::new()
            .invoke(
                vec![doc("x"), doc("y"), doc("x")],
                RunnableConfig::default(),
            )
            .await
            .unwrap();
        let ids: Vec<_> = out.iter().filter_map(|d| d.id.clone()).collect();
        assert_eq!(ids, vec!["x", "y"]);
    }

    #[tokio::test]
    async fn enrichment_applies_per_doc() {
        let r = Enrichment::new(|d: &mut Document| {
            d.content = d.content.to_uppercase();
            d.metadata
                .insert("seen".into(), serde_json::Value::Bool(true));
            Ok(())
        });
        let out = r
            .invoke(
                vec![Document::new("hi"), Document::new("ho")],
                RunnableConfig::default(),
            )
            .await
            .unwrap();
        assert_eq!(out[0].content, "HI");
        assert_eq!(out[1].content, "HO");
        assert!(out[0].metadata.contains_key("seen"));
    }

    #[test]
    fn enrichment_with_name_overrides_runnable_name() {
        let r = Enrichment::new(|_: &mut Document| Ok(())).with_name("Tagger");
        assert_eq!(r.name(), "Tagger");
    }

    #[test]
    fn metadata_transformer_overwrites_by_default() {
        let docs = vec![
            Document::new("d1").with_metadata("source", serde_json::json!("old")),
            Document::new("d2"),
        ];
        let out = MetadataTransformer::new().set("source", "new").apply(docs);
        assert_eq!(
            out[0].metadata.get("source").unwrap(),
            &serde_json::json!("new")
        );
        assert_eq!(
            out[1].metadata.get("source").unwrap(),
            &serde_json::json!("new")
        );
    }

    #[test]
    fn metadata_transformer_only_missing_preserves_existing() {
        let docs = vec![
            Document::new("d1").with_metadata("source", serde_json::json!("old")),
            Document::new("d2"),
        ];
        let out = MetadataTransformer::new()
            .set("source", "new")
            .merge_only_missing()
            .apply(docs);
        assert_eq!(
            out[0].metadata.get("source").unwrap(),
            &serde_json::json!("old")
        );
        assert_eq!(
            out[1].metadata.get("source").unwrap(),
            &serde_json::json!("new")
        );
    }

    #[test]
    fn metadata_transformer_from_map_applies_all_fields() {
        let mut fields = HashMap::new();
        fields.insert("stage".to_string(), serde_json::json!("ingest"));
        let out = MetadataTransformer::from_map(fields).apply(vec![Document::new("d")]);
        assert_eq!(
            out[0].metadata.get("stage").unwrap(),
            &serde_json::json!("ingest")
        );
    }
}