Skip to main content

cognis_rag/retrievers/
compressor_pipeline.rs

1//! Multi-stage doc-list transformer chain.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6
7use cognis_core::{Result, Runnable, RunnableConfig};
8
9use crate::document::Document;
10
11type Stage = Arc<dyn Runnable<Vec<Document>, Vec<Document>>>;
12
13/// Chain N doc-list transformers back-to-back. Each stage's output feeds
14/// the next.
15///
16/// Already expressible via repeated `.pipe()` — this type just gives the
17/// pattern a name and a builder.
18pub struct CompressorPipeline {
19    stages: Vec<Stage>,
20}
21
22impl Default for CompressorPipeline {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl CompressorPipeline {
29    /// Empty pipeline.
30    pub fn new() -> Self {
31        Self { stages: Vec::new() }
32    }
33
34    /// Append a stage.
35    pub fn stage(mut self, s: Stage) -> Self {
36        self.stages.push(s);
37        self
38    }
39}
40
41#[async_trait]
42impl Runnable<Vec<Document>, Vec<Document>> for CompressorPipeline {
43    async fn invoke(
44        &self,
45        mut input: Vec<Document>,
46        config: RunnableConfig,
47    ) -> Result<Vec<Document>> {
48        for s in &self.stages {
49            input = s.invoke(input, config.clone()).await?;
50        }
51        Ok(input)
52    }
53    fn name(&self) -> &str {
54        "CompressorPipeline"
55    }
56}
57
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` documents whose contents are "0", "1", ….
    fn numbered_docs(n: usize) -> Vec<Document> {
        (0..n).map(|i| Document::new(i.to_string())).collect()
    }

    /// Keeps the documents at even positions (0, 2, 4, …).
    struct DropOdd;
    #[async_trait]
    impl Runnable<Vec<Document>, Vec<Document>> for DropOdd {
        async fn invoke(&self, input: Vec<Document>, _: RunnableConfig) -> Result<Vec<Document>> {
            Ok(input.into_iter().step_by(2).collect())
        }
    }

    /// Keeps at most the first two documents.
    struct Take2;
    #[async_trait]
    impl Runnable<Vec<Document>, Vec<Document>> for Take2 {
        async fn invoke(&self, input: Vec<Document>, _: RunnableConfig) -> Result<Vec<Document>> {
            Ok(input.into_iter().take(2).collect())
        }
    }

    #[tokio::test]
    async fn stages_run_in_order() {
        let p = CompressorPipeline::new()
            .stage(Arc::new(DropOdd))
            .stage(Arc::new(Take2));
        let out = p
            .invoke(numbered_docs(6), RunnableConfig::default())
            .await
            .unwrap();
        // DropOdd → positions 0,2,4 (3 items) → Take2 → 2 items.
        assert_eq!(out.len(), 2);
    }

    #[tokio::test]
    async fn reversed_stage_order_changes_result() {
        let p = CompressorPipeline::new()
            .stage(Arc::new(Take2))
            .stage(Arc::new(DropOdd));
        let out = p
            .invoke(numbered_docs(6), RunnableConfig::default())
            .await
            .unwrap();
        // Take2 → 2 items → DropOdd keeps only position 0 → 1 item.
        assert_eq!(out.len(), 1);
    }

    #[tokio::test]
    async fn empty_pipeline_passes_input_through() {
        let out = CompressorPipeline::new()
            .invoke(numbered_docs(6), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(out.len(), 6);
    }

    #[tokio::test]
    async fn default_is_an_empty_pipeline() {
        let out = CompressorPipeline::default()
            .invoke(numbered_docs(3), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(out.len(), 3);
    }
}