cognis/retrievers/
contextual_compression.rs1use std::sync::Arc;
9
10use async_trait::async_trait;
11use futures::future::join_all;
12
13use cognis_core::{Message, Result, Runnable, RunnableConfig};
14use cognis_llm::chat::ChatOptions;
15use cognis_llm::Client;
16use cognis_rag::Document;
17
/// Built-in relevance prompt used until `with_prompt` replaces it.
///
/// `{query}` and `{passage}` are placeholders filled in for every candidate document;
/// the model is instructed to answer with the single word `yes` or `no`.
const DEFAULT_PROMPT: &str = concat!(
    "Decide if the passage is relevant to the query. ",
    "Reply with exactly one word: `yes` or `no`. No explanation.",
    "\n\n",
    "Query: {query}",
    "\n\n",
    "Passage: {passage}"
);
21
22pub struct ContextualCompressionRetriever {
24 inner: Arc<dyn Runnable<String, Vec<Document>>>,
25 client: Client,
26 prompt: String,
27}
28
29impl ContextualCompressionRetriever {
30 pub fn new(inner: Arc<dyn Runnable<String, Vec<Document>>>, client: Client) -> Self {
32 Self {
33 inner,
34 client,
35 prompt: DEFAULT_PROMPT.to_string(),
36 }
37 }
38
39 pub fn with_prompt(mut self, p: impl Into<String>) -> Self {
41 self.prompt = p.into();
42 self
43 }
44
45 fn render(&self, query: &str, passage: &str) -> String {
46 self.prompt
47 .replace("{query}", query)
48 .replace("{passage}", passage)
49 }
50
51 async fn keep(&self, query: &str, passage: &str) -> Result<bool> {
52 let prompt = self.render(query, passage);
53 let resp = self
54 .client
55 .chat(vec![Message::human(prompt)], ChatOptions::default())
56 .await?;
57 let answer = resp.message.content().trim().to_lowercase();
58 Ok(matches!(answer.as_str(), "yes" | "y" | "true"))
59 }
60}
61
62#[async_trait]
63impl Runnable<String, Vec<Document>> for ContextualCompressionRetriever {
64 async fn invoke(&self, query: String, config: RunnableConfig) -> Result<Vec<Document>> {
65 let docs = self.inner.invoke(query.clone(), config).await?;
66 let checks = docs.iter().map(|d| self.keep(&query, &d.content));
67 let verdicts: Vec<bool> = join_all(checks)
68 .await
69 .into_iter()
70 .collect::<Result<Vec<_>>>()?;
71 Ok(docs
72 .into_iter()
73 .zip(verdicts)
74 .filter_map(|(d, keep)| if keep { Some(d) } else { None })
75 .collect())
76 }
77
78 fn name(&self) -> &str {
79 "ContextualCompressionRetriever"
80 }
81}
82
#[cfg(test)]
mod tests {
    use super::*;

    use cognis_core::{Message, Result, RunnableStream};
    use cognis_llm::chat::{ChatOptions, ChatResponse, HealthStatus, StreamChunk, Usage};
    use cognis_llm::provider::{LLMProvider, Provider};

    /// Inner retriever stub: ignores the query and always yields the same documents.
    struct StaticInner(Vec<Document>);

    #[async_trait]
    impl Runnable<String, Vec<Document>> for StaticInner {
        async fn invoke(&self, _q: String, _: RunnableConfig) -> Result<Vec<Document>> {
            Ok(self.0.clone())
        }
    }

    /// Fake LLM judge. It answers `yes` when the passage part of the prompt contains
    /// any configured keyword (case-insensitive) and `no` otherwise.
    struct KeywordJudge {
        keep_if_contains: Vec<String>,
    }

    impl KeywordJudge {
        /// Computes the canned reply for a rendered prompt.
        fn verdict(&self, prompt: &str) -> &'static str {
            // Look only after the `Passage:` marker so keywords in the query don't
            // influence the verdict; fall back to the whole prompt if it's missing.
            let haystack = match prompt.split_once("Passage:") {
                Some((_, passage)) => passage.to_lowercase(),
                None => prompt.to_lowercase(),
            };
            let hit = self
                .keep_if_contains
                .iter()
                .map(|kw| kw.to_lowercase())
                .any(|kw| haystack.contains(&kw));
            if hit {
                "yes"
            } else {
                "no"
            }
        }
    }

    #[async_trait]
    impl LLMProvider for KeywordJudge {
        fn name(&self) -> &str {
            "judge"
        }

        fn provider_type(&self) -> Provider {
            Provider::Ollama
        }

        async fn chat_completion(
            &self,
            messages: Vec<Message>,
            _opts: ChatOptions,
        ) -> Result<ChatResponse> {
            let prompt = match messages.last() {
                Some(m) => m.content().to_string(),
                None => String::new(),
            };
            Ok(ChatResponse {
                message: Message::ai(self.verdict(&prompt)),
                usage: Some(Usage::default()),
                finish_reason: "stop".into(),
                model: "judge".into(),
            })
        }

        async fn chat_completion_stream(
            &self,
            _: Vec<Message>,
            _: ChatOptions,
        ) -> Result<RunnableStream<StreamChunk>> {
            unimplemented!()
        }

        async fn health_check(&self) -> Result<HealthStatus> {
            Ok(HealthStatus::Healthy { latency_ms: 0 })
        }
    }

    #[tokio::test]
    async fn drops_irrelevant_docs() {
        let docs = vec![
            Document::new("rust ownership rules").with_id("a"),
            Document::new("a recipe for sourdough").with_id("b"),
            Document::new("rust borrow checker").with_id("c"),
        ];
        let inner: Arc<dyn Runnable<String, Vec<Document>>> = Arc::new(StaticInner(docs));
        let judge = KeywordJudge {
            keep_if_contains: vec!["rust".to_string()],
        };
        let retriever = ContextualCompressionRetriever::new(inner, Client::new(Arc::new(judge)));

        let kept = retriever
            .invoke("rust safety".to_string(), RunnableConfig::default())
            .await
            .unwrap();

        let ids: Vec<_> = kept.iter().filter_map(|d| d.id.clone()).collect();
        assert_eq!(ids, vec!["a", "c"]);
    }
}