cognis_rag/
cross_encoder.rs1use std::sync::Arc;
12
13use async_trait::async_trait;
14
15use cognis_core::{Result, Runnable, RunnableConfig};
16
17use crate::document::Document;
18
19#[async_trait]
21pub trait CrossEncoder: Send + Sync {
22 async fn score(&self, query: &str, docs: &[Document]) -> Result<Vec<f32>>;
25}
26
27pub struct FnCrossEncoder<F>
30where
31 F: Fn(&str, &Document) -> f32 + Send + Sync,
32{
33 pub f: F,
35}
36
37#[async_trait]
38impl<F> CrossEncoder for FnCrossEncoder<F>
39where
40 F: Fn(&str, &Document) -> f32 + Send + Sync,
41{
42 async fn score(&self, query: &str, docs: &[Document]) -> Result<Vec<f32>> {
43 Ok(docs.iter().map(|d| (self.f)(query, d)).collect())
44 }
45}
46
47pub struct CrossEncoderReranker {
49 inner: Arc<dyn Runnable<String, Vec<Document>>>,
50 encoder: Arc<dyn CrossEncoder>,
51 top_k: usize,
52}
53
54impl CrossEncoderReranker {
55 pub fn new(
57 inner: Arc<dyn Runnable<String, Vec<Document>>>,
58 encoder: Arc<dyn CrossEncoder>,
59 top_k: usize,
60 ) -> Self {
61 Self {
62 inner,
63 encoder,
64 top_k,
65 }
66 }
67}
68
69#[async_trait]
70impl Runnable<String, Vec<Document>> for CrossEncoderReranker {
71 async fn invoke(&self, query: String, config: RunnableConfig) -> Result<Vec<Document>> {
72 let docs = self.inner.invoke(query.clone(), config).await?;
73 if docs.is_empty() {
74 return Ok(docs);
75 }
76 let scores = self.encoder.score(&query, &docs).await?;
77 let mut paired: Vec<(f32, Document)> = scores.into_iter().zip(docs).collect();
78 paired.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
79 Ok(paired
80 .into_iter()
81 .take(self.top_k)
82 .map(|(_, d)| d)
83 .collect())
84 }
85 fn name(&self) -> &str {
86 "CrossEncoderReranker"
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double: a runnable that ignores its input and returns a fixed document list.
    struct StaticInner(Vec<Document>);

    #[async_trait]
    impl Runnable<String, Vec<Document>> for StaticInner {
        async fn invoke(&self, _: String, _: RunnableConfig) -> Result<Vec<Document>> {
            Ok(self.0.clone())
        }
    }

    /// With `top_k = 2`, both documents mentioning "rust" must be kept and the
    /// non-matching one must be dropped.
    #[tokio::test]
    async fn reranks_by_score() {
        let inner: Arc<dyn Runnable<String, Vec<Document>>> = Arc::new(StaticInner(vec![
            Document::new("apple pie").with_id("a"),
            Document::new("rust crab").with_id("b"),
            Document::new("rust ferris").with_id("c"),
        ]));
        // Score = number of occurrences of "rust" in the document content.
        let enc: Arc<dyn CrossEncoder> = Arc::new(FnCrossEncoder {
            f: |_q: &str, d: &Document| d.content.matches("rust").count() as f32,
        });
        let r = CrossEncoderReranker::new(inner, enc, 2);
        let out = r
            .invoke("rust".into(), RunnableConfig::default())
            .await
            .unwrap();
        let ids: Vec<_> = out.iter().filter_map(|d| d.id.clone()).collect();
        assert_eq!(ids.len(), 2);
        // Each matching document is checked on its own so a failure names the missing one.
        assert!(ids.contains(&"b".to_string()));
        assert!(ids.contains(&"c".to_string()));
        assert!(!ids.contains(&"a".to_string()));
    }
}