Skip to main content

cognis/retrievers/
contextual_compression.rs

1//! `ContextualCompressionRetriever` — score each retrieved doc with the LLM
2//! and drop low-relevance ones.
3//!
4//! For each candidate doc the inner retriever returned, ask the LLM:
5//! "Given the query, is this passage relevant? Reply yes or no." Drop the
6//! `no`s. Order is preserved among the kept documents.
7
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use futures::future::join_all;
12
13use cognis_core::{Message, Result, Runnable, RunnableConfig};
14use cognis_llm::chat::ChatOptions;
15use cognis_llm::Client;
16use cognis_rag::Document;
17
/// Built-in relevance prompt; `{query}` and `{passage}` are filled in per document.
const DEFAULT_PROMPT: &str = "Decide if the passage is relevant to the query. \
    Reply with exactly one word: `yes` or `no`. No explanation.\n\n\
    Query: {query}\n\n\
    Passage: {passage}";
21
22/// Filters retrieved documents by an LLM relevance check.
23pub struct ContextualCompressionRetriever {
24    inner: Arc<dyn Runnable<String, Vec<Document>>>,
25    client: Client,
26    prompt: String,
27}
28
29impl ContextualCompressionRetriever {
30    /// Wrap an inner retriever with an LLM-driven filter.
31    pub fn new(inner: Arc<dyn Runnable<String, Vec<Document>>>, client: Client) -> Self {
32        Self {
33            inner,
34            client,
35            prompt: DEFAULT_PROMPT.to_string(),
36        }
37    }
38
39    /// Override the relevance prompt. Placeholders: `{query}`, `{passage}`.
40    pub fn with_prompt(mut self, p: impl Into<String>) -> Self {
41        self.prompt = p.into();
42        self
43    }
44
45    fn render(&self, query: &str, passage: &str) -> String {
46        self.prompt
47            .replace("{query}", query)
48            .replace("{passage}", passage)
49    }
50
51    async fn keep(&self, query: &str, passage: &str) -> Result<bool> {
52        let prompt = self.render(query, passage);
53        let resp = self
54            .client
55            .chat(vec![Message::human(prompt)], ChatOptions::default())
56            .await?;
57        let answer = resp.message.content().trim().to_lowercase();
58        Ok(matches!(answer.as_str(), "yes" | "y" | "true"))
59    }
60}
61
62#[async_trait]
63impl Runnable<String, Vec<Document>> for ContextualCompressionRetriever {
64    async fn invoke(&self, query: String, config: RunnableConfig) -> Result<Vec<Document>> {
65        let docs = self.inner.invoke(query.clone(), config).await?;
66        let checks = docs.iter().map(|d| self.keep(&query, &d.content));
67        let verdicts: Vec<bool> = join_all(checks)
68            .await
69            .into_iter()
70            .collect::<Result<Vec<_>>>()?;
71        Ok(docs
72            .into_iter()
73            .zip(verdicts)
74            .filter_map(|(d, keep)| if keep { Some(d) } else { None })
75            .collect())
76    }
77
78    fn name(&self) -> &str {
79        "ContextualCompressionRetriever"
80    }
81}
82
#[cfg(test)]
mod tests {
    use super::*;

    use cognis_core::{Message, Result, RunnableStream};
    use cognis_llm::chat::{ChatOptions, ChatResponse, HealthStatus, StreamChunk, Usage};
    use cognis_llm::provider::{LLMProvider, Provider};

    /// Inner-retriever stub: ignores the query and returns a fixed list.
    struct StaticInner(Vec<Document>);

    #[async_trait]
    impl Runnable<String, Vec<Document>> for StaticInner {
        async fn invoke(&self, _query: String, _config: RunnableConfig) -> Result<Vec<Document>> {
            let StaticInner(docs) = self;
            Ok(docs.clone())
        }
    }

    /// Fake LLM: answers "yes" when the passage part of the prompt contains
    /// any configured keyword (case-insensitively), and "no" otherwise.
    struct KeywordJudge {
        keep_if_contains: Vec<String>,
    }

    #[async_trait]
    impl LLMProvider for KeywordJudge {
        fn name(&self) -> &str {
            "judge"
        }

        fn provider_type(&self) -> Provider {
            Provider::Ollama
        }

        async fn chat_completion(
            &self,
            messages: Vec<Message>,
            _opts: ChatOptions,
        ) -> Result<ChatResponse> {
            let prompt = match messages.last() {
                Some(last) => last.content().to_string(),
                None => String::new(),
            };
            // Look only at the text after "Passage:". The query is part of the
            // prompt too, so matching on the whole prompt would accept every
            // document.
            let passage = match prompt.split_once("Passage:") {
                Some((_, after)) => after.to_lowercase(),
                None => prompt.to_lowercase(),
            };
            let hit = self
                .keep_if_contains
                .iter()
                .map(|kw| kw.to_lowercase())
                .any(|kw| passage.contains(&kw));
            let answer = if hit { "yes" } else { "no" };
            Ok(ChatResponse {
                message: Message::ai(answer),
                usage: Some(Usage::default()),
                finish_reason: "stop".into(),
                model: "judge".into(),
            })
        }

        async fn chat_completion_stream(
            &self,
            _: Vec<Message>,
            _: ChatOptions,
        ) -> Result<RunnableStream<StreamChunk>> {
            // Streaming is not exercised by these tests.
            unimplemented!()
        }

        async fn health_check(&self) -> Result<HealthStatus> {
            Ok(HealthStatus::Healthy { latency_ms: 0 })
        }
    }

    #[tokio::test]
    async fn drops_irrelevant_docs() {
        let corpus = vec![
            Document::new("rust ownership rules").with_id("a"),
            Document::new("a recipe for sourdough").with_id("b"),
            Document::new("rust borrow checker").with_id("c"),
        ];
        let inner: Arc<dyn Runnable<String, Vec<Document>>> = Arc::new(StaticInner(corpus));
        let judge = KeywordJudge {
            keep_if_contains: vec!["rust".to_string()],
        };
        let retriever = ContextualCompressionRetriever::new(inner, Client::new(Arc::new(judge)));

        let kept = retriever
            .invoke("rust safety".to_string(), RunnableConfig::default())
            .await
            .unwrap();

        // Only the two "rust" documents survive, in their original order.
        let ids: Vec<_> = kept.iter().filter_map(|d| d.id.clone()).collect();
        assert_eq!(ids, vec!["a", "c"]);
    }
}
173}