Skip to main content

coding_agent_search/search/
reranker.rs

//! Reranker trait and types for cross-encoder reranking.
//!
//! This module re-exports frankensearch's synchronous
//! [`SyncRerank`](frankensearch::SyncRerank) trait under the name [`Reranker`].
//! All reranking implementations must implement `Reranker`, which provides a
//! synchronous reranking interface suited to cass's synchronous call sites.
//!
//! When the frankensearch search pipeline needs an async reranker, wrap any
//! `Reranker` implementor in [`SyncRerankerAdapter`](frankensearch::SyncRerankerAdapter)
//! to obtain an implementation of frankensearch's async `frankensearch::Reranker` trait.
//!
//! # Implementations
//!
//! - **FastEmbed Reranker**: Uses the ms-marco-MiniLM-L-6-v2 cross-encoder via FastEmbed.
//!   Requires a model download with user consent.
16
17use std::fmt;
18
19pub use frankensearch::SearchError as RerankerError;
20pub use frankensearch::SearchResult as RerankerResult;
21pub use frankensearch::SyncRerank as Reranker;
22pub use frankensearch::{RerankDocument, RerankScore};
23
24/// Convenience function to rerank raw text documents.
25///
26/// Wraps `&[&str]` documents into [`RerankDocument`] structs and extracts
27/// the resulting scores back into a `Vec<f32>` in original document order.
28pub fn rerank_texts(
29    reranker: &dyn Reranker,
30    query: &str,
31    documents: &[&str],
32) -> RerankerResult<Vec<f32>> {
33    let rerank_docs: Vec<RerankDocument> = documents
34        .iter()
35        .enumerate()
36        .map(|(i, text)| RerankDocument {
37            doc_id: i.to_string(),
38            text: text.to_string(),
39        })
40        .collect();
41
42    let scores = reranker.rerank_sync(query, &rerank_docs)?;
43
44    // Convert RerankScore vec back to Vec<f32> in original document order
45    let mut result = vec![0.0f32; documents.len()];
46    for rs in &scores {
47        if let Ok(idx) = rs.doc_id.parse::<usize>()
48            && idx < result.len()
49        {
50            result[idx] = rs.score;
51        }
52    }
53    Ok(result)
54}
55
/// Metadata about a reranker, for display and logging.
///
/// Build one from a live reranker with [`RerankerInfo::from_reranker`]. The
/// `Display` impl renders it as `<id> (available)` or `<id> (unavailable)`.
///
/// Derives `PartialEq`/`Eq` so infos can be compared directly (e.g. in tests or
/// when de-duplicating log output).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RerankerInfo {
    /// The reranker's identifier, as reported by the reranker's `id()`.
    pub id: String,
    /// Whether the reranker reported itself as available when this info was captured.
    pub is_available: bool,
}
64
65impl RerankerInfo {
66    /// Create info from a reranker instance.
67    pub fn from_reranker(reranker: &dyn Reranker) -> Self {
68        Self {
69            id: reranker.id().to_string(),
70            is_available: reranker.is_available(),
71        }
72    }
73}
74
75impl fmt::Display for RerankerInfo {
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        let status = if self.is_available {
78            "available"
79        } else {
80            "unavailable"
81        };
82        write!(f, "{} ({})", self.id, status)
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::search::fastembed_reranker::FastEmbedReranker;
    use std::path::PathBuf;

    /// Directory of the checked-in int8 ms-marco MiniLM reranker fixture.
    fn fastembed_fixture_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/models/xenova-ms-marco-minilm-l6-v2-int8")
    }

    /// Load the fixture reranker, panicking if the fixture cannot be loaded.
    fn load_fastembed_fixture() -> FastEmbedReranker {
        FastEmbedReranker::load_from_dir(&fastembed_fixture_dir())
            .expect("fastembed reranker fixture should load")
    }

    #[test]
    fn test_reranker_trait_basic() {
        let reranker = load_fastembed_fixture();
        let scores = rerank_texts(
            &reranker,
            "test query",
            &["short", "medium length doc", "longer document text"],
        )
        .unwrap();

        assert_eq!(scores.len(), 3);
        for score in scores {
            assert!(score.is_finite());
        }
    }

    #[test]
    fn test_reranker_unavailable() {
        // Loading from an empty directory must fail. `match` is used instead of
        // `expect_err`, which would require `FastEmbedReranker: Debug`.
        let tmp = tempfile::tempdir().expect("tempdir");
        let err = match FastEmbedReranker::load_from_dir(tmp.path()) {
            Ok(_) => panic!("expected unavailable error"),
            Err(err) => err,
        };
        assert!(matches!(
            err,
            RerankerError::RerankFailed { .. }
                | RerankerError::EmbedderUnavailable { .. }
                | RerankerError::RerankerUnavailable { .. }
        ));
    }

    #[test]
    fn test_reranker_empty_query_error() {
        let reranker = load_fastembed_fixture();
        let result = rerank_texts(&reranker, "", &["doc"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_reranker_empty_documents_error() {
        let reranker = load_fastembed_fixture();
        let result = rerank_texts(&reranker, "query", &[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_reranker_info() {
        let reranker = load_fastembed_fixture();
        let info = RerankerInfo::from_reranker(&reranker);

        assert_eq!(info.id, FastEmbedReranker::reranker_id_static());
        assert!(info.is_available);

        // `contains("available")` would also match "unavailable", so pin the
        // exact rendering instead.
        let display = format!("{info}");
        assert_eq!(
            display,
            format!("{} (available)", FastEmbedReranker::reranker_id_static())
        );
    }

    #[test]
    fn test_reranker_error_display() {
        let err = RerankerError::RerankFailed {
            model: "test".to_string(),
            source: Box::new(std::io::Error::other("inference error")),
        };
        assert!(err.to_string().contains("inference error"));
    }
}