cognis_rag/retrievers/
compressor_pipeline.rs1use std::sync::Arc;
4
5use async_trait::async_trait;
6
7use cognis_core::{Result, Runnable, RunnableConfig};
8
9use crate::document::Document;
10
11type Stage = Arc<dyn Runnable<Vec<Document>, Vec<Document>>>;
12
13pub struct CompressorPipeline {
19 stages: Vec<Stage>,
20}
21
22impl Default for CompressorPipeline {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl CompressorPipeline {
29 pub fn new() -> Self {
31 Self { stages: Vec::new() }
32 }
33
34 pub fn stage(mut self, s: Stage) -> Self {
36 self.stages.push(s);
37 self
38 }
39}
40
41#[async_trait]
42impl Runnable<Vec<Document>, Vec<Document>> for CompressorPipeline {
43 async fn invoke(
44 &self,
45 mut input: Vec<Document>,
46 config: RunnableConfig,
47 ) -> Result<Vec<Document>> {
48 for s in &self.stages {
49 input = s.invoke(input, config.clone()).await?;
50 }
51 Ok(input)
52 }
53 fn name(&self) -> &str {
54 "CompressorPipeline"
55 }
56}
57
#[cfg(test)]
mod tests {
    use super::*;

    /// Test stage: keeps documents at even positions (0, 2, 4, ...).
    struct DropOdd;
    #[async_trait]
    impl Runnable<Vec<Document>, Vec<Document>> for DropOdd {
        async fn invoke(&self, input: Vec<Document>, _: RunnableConfig) -> Result<Vec<Document>> {
            Ok(input
                .into_iter()
                .enumerate()
                .filter(|(i, _)| i % 2 == 0)
                .map(|(_, d)| d)
                .collect())
        }
    }

    /// Test stage: keeps at most the first two documents.
    struct Take2;
    #[async_trait]
    impl Runnable<Vec<Document>, Vec<Document>> for Take2 {
        async fn invoke(&self, input: Vec<Document>, _: RunnableConfig) -> Result<Vec<Document>> {
            Ok(input.into_iter().take(2).collect())
        }
    }

    /// Builds `n` documents whose contents are "0", "1", ..., `n - 1`.
    fn docs(n: usize) -> Vec<Document> {
        (0..n).map(|i| Document::new(i.to_string())).collect()
    }

    #[tokio::test]
    async fn stages_run_in_order() {
        let p = CompressorPipeline::new()
            .stage(Arc::new(DropOdd))
            .stage(Arc::new(Take2));
        // DropOdd: 6 docs -> [0, 2, 4]; Take2: -> [0, 2].
        let out = p.invoke(docs(6), RunnableConfig::default()).await.unwrap();
        assert_eq!(out.len(), 2);
    }

    #[tokio::test]
    async fn reversed_order_changes_result() {
        let p = CompressorPipeline::new()
            .stage(Arc::new(Take2))
            .stage(Arc::new(DropOdd));
        // Take2: 6 docs -> [0, 1]; DropOdd: -> [0].
        let out = p.invoke(docs(6), RunnableConfig::default()).await.unwrap();
        assert_eq!(out.len(), 1);
    }

    #[tokio::test]
    async fn empty_pipeline_is_identity() {
        let out = CompressorPipeline::default()
            .invoke(docs(3), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(out.len(), 3);

        let out = CompressorPipeline::new()
            .invoke(Vec::new(), RunnableConfig::default())
            .await
            .unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn reports_its_name() {
        assert_eq!(CompressorPipeline::new().name(), "CompressorPipeline");
    }
}