Skip to main content

cognis_rag/
cross_encoder.rs

1//! Cross-encoder scoring trait + cross-encoder-based reranker.
2//!
3//! Distinct from LLM-as-judge reranking: a cross-encoder is a model that
4//! takes `(query, doc)` and emits a relevance score in one forward pass.
5//! Common production choice for first-stage RAG reranking.
6//!
7//! cognis doesn't bundle a model — implement [`CrossEncoder`] against
8//! whatever you actually run (Cohere rerank, hosted bge-reranker, local
9//! `tch-rs`, …).
10
11use std::sync::Arc;
12
13use async_trait::async_trait;
14
15use cognis_core::{Result, Runnable, RunnableConfig};
16
17use crate::document::Document;
18
/// Scores `(query, doc)` pairs.
///
/// The trait requires `Send + Sync`, so implementations can be shared behind
/// an `Arc<dyn CrossEncoder>` (which is how [`CrossEncoderReranker`] holds one).
#[async_trait]
pub trait CrossEncoder: Send + Sync {
    /// Score every `doc` against `query`. Higher = more relevant.
    /// Implementations should batch when possible.
    ///
    /// The reranker in this module pairs the returned scores with `docs` by
    /// position, so implementations are expected to return exactly one score
    /// per document, in the same order as `docs`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying scorer reports.
    async fn score(&self, query: &str, docs: &[Document]) -> Result<Vec<f32>>;
}
26
/// Closure-backed cross-encoder. Useful for tests; in production use a
/// real impl that calls a hosted scorer.
///
/// The closure receives the query and a single document and returns that
/// document's relevance score (higher = more relevant).
pub struct FnCrossEncoder<F>
where
    F: Fn(&str, &Document) -> f32 + Send + Sync,
{
    /// Per-doc scorer; invoked once per document, sequentially and in input
    /// order (there is no concurrency across docs).
    pub f: F,
}
36
37#[async_trait]
38impl<F> CrossEncoder for FnCrossEncoder<F>
39where
40    F: Fn(&str, &Document) -> f32 + Send + Sync,
41{
42    async fn score(&self, query: &str, docs: &[Document]) -> Result<Vec<f32>> {
43        Ok(docs.iter().map(|d| (self.f)(query, d)).collect())
44    }
45}
46
/// Wraps an inner retriever, then reranks its hits via a [`CrossEncoder`].
///
/// The reranker is itself a `Runnable<String, Vec<Document>>`, so it can be
/// used wherever the inner retriever could.
pub struct CrossEncoderReranker {
    /// Retriever whose results are reranked.
    inner: Arc<dyn Runnable<String, Vec<Document>>>,
    /// Scorer used to assign a relevance score to each retrieved document.
    encoder: Arc<dyn CrossEncoder>,
    /// Maximum number of documents returned after reranking.
    top_k: usize,
}
53
54impl CrossEncoderReranker {
55    /// Build with an inner retriever + cross-encoder + post-rerank top-k.
56    pub fn new(
57        inner: Arc<dyn Runnable<String, Vec<Document>>>,
58        encoder: Arc<dyn CrossEncoder>,
59        top_k: usize,
60    ) -> Self {
61        Self {
62            inner,
63            encoder,
64            top_k,
65        }
66    }
67}
68
69#[async_trait]
70impl Runnable<String, Vec<Document>> for CrossEncoderReranker {
71    async fn invoke(&self, query: String, config: RunnableConfig) -> Result<Vec<Document>> {
72        let docs = self.inner.invoke(query.clone(), config).await?;
73        if docs.is_empty() {
74            return Ok(docs);
75        }
76        let scores = self.encoder.score(&query, &docs).await?;
77        let mut paired: Vec<(f32, Document)> = scores.into_iter().zip(docs).collect();
78        paired.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
79        Ok(paired
80            .into_iter()
81            .take(self.top_k)
82            .map(|(_, d)| d)
83            .collect())
84    }
85    fn name(&self) -> &str {
86        "CrossEncoderReranker"
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Stub retriever that returns the same fixed documents for any query.
    struct StaticInner(Vec<Document>);

    #[async_trait]
    impl Runnable<String, Vec<Document>> for StaticInner {
        async fn invoke(&self, _: String, _: RunnableConfig) -> Result<Vec<Document>> {
            Ok(self.0.clone())
        }
    }

    #[tokio::test]
    async fn reranks_by_score() {
        let inner: Arc<dyn Runnable<String, Vec<Document>>> = Arc::new(StaticInner(vec![
            Document::new("apple pie").with_id("a"),
            Document::new("rust crab").with_id("b"),
            Document::new("rust ferris").with_id("c"),
        ]));
        // Score by count of "rust" in content.
        let enc: Arc<dyn CrossEncoder> = Arc::new(FnCrossEncoder {
            f: |_q: &str, d: &Document| d.content.matches("rust").count() as f32,
        });
        let r = CrossEncoderReranker::new(inner, enc, 2);
        let out = r
            .invoke("rust".into(), RunnableConfig::default())
            .await
            .unwrap();
        let ids: Vec<_> = out.iter().filter_map(|d| d.id.clone()).collect();
        // "b" and "c" tie on score, so the stable sort keeps retriever order;
        // "a" scores lowest and is cut by top_k = 2.
        assert_eq!(ids, vec!["b".to_string(), "c".to_string()]);
    }

    #[tokio::test]
    async fn empty_inner_yields_empty() {
        let inner: Arc<dyn Runnable<String, Vec<Document>>> =
            Arc::new(StaticInner(Vec::new()));
        let enc: Arc<dyn CrossEncoder> = Arc::new(FnCrossEncoder {
            f: |_q: &str, _d: &Document| 0.0_f32,
        });
        let r = CrossEncoderReranker::new(inner, enc, 2);
        let out = r
            .invoke("anything".into(), RunnableConfig::default())
            .await
            .unwrap();
        assert!(out.is_empty());
    }
}