reflex/embedding/reranker/
config.rs

1use std::path::PathBuf;
2
/// Default verification threshold (cross-encoder score).
///
/// Mirrors `crate::constants::DEFAULT_VERIFICATION_THRESHOLD` so reranker code can refer to it
/// by a local name; this is the threshold `RerankerConfig` falls back to when none is supplied.
pub const DEFAULT_THRESHOLD: f32 = crate::constants::DEFAULT_VERIFICATION_THRESHOLD;
5
/// Maximum sequence length used for reranker tokenization.
///
/// NOTE(review): not referenced in this file; presumably a token count enforced by the
/// tokenizer/model code (truncation of the query/candidate pair) — confirm at the use site.
pub const MAX_SEQ_LEN: usize = 512;
8
9#[derive(Debug, Clone)]
10/// Configuration for [`Reranker`](super::Reranker).
11pub struct RerankerConfig {
12    /// Directory containing `config.json`, `model.safetensors`, and tokenizer files.
13    pub model_path: Option<PathBuf>,
14
15    /// Minimum score to consider a candidate verified.
16    pub threshold: f32,
17}
18
19impl Default for RerankerConfig {
20    fn default() -> Self {
21        Self {
22            model_path: None,
23            threshold: DEFAULT_THRESHOLD,
24        }
25    }
26}
27
28impl RerankerConfig {
29    /// Creates a config for a model directory.
30    pub fn new<P: Into<PathBuf>>(model_path: P) -> Self {
31        Self {
32            model_path: Some(model_path.into()),
33            threshold: DEFAULT_THRESHOLD,
34        }
35    }
36
37    /// Creates a config that runs without a model (stub scoring).
38    pub fn stub() -> Self {
39        Self {
40            model_path: None,
41            threshold: DEFAULT_THRESHOLD,
42        }
43    }
44
45    /// Sets the threshold.
46    pub fn with_threshold(mut self, threshold: f32) -> Self {
47        assert!(
48            (0.0..=1.0).contains(&threshold),
49            "threshold must be between 0.0 and 1.0"
50        );
51        self.threshold = threshold;
52        self
53    }
54
55    /// Validates basic invariants.
56    pub fn validate(&self) -> Result<(), String> {
57        if !(0.0..=1.0).contains(&self.threshold) {
58            return Err(format!(
59                "threshold must be between 0.0 and 1.0, got {}",
60                self.threshold
61            ));
62        }
63
64        if let Some(ref path) = self.model_path
65            && path.as_os_str().is_empty()
66        {
67            return Err("model_path cannot be empty when provided".to_string());
68        }
69
70        Ok(())
71    }
72
73    /// Loads config from `REFLEX_RERANKER_PATH` and `REFLEX_RERANKER_THRESHOLD`.
74    pub fn from_env() -> Self {
75        let model_path = std::env::var("REFLEX_RERANKER_PATH")
76            .ok()
77            .map(|v| v.trim().to_string())
78            .filter(|v| !v.is_empty())
79            .map(PathBuf::from);
80
81        let threshold = std::env::var("REFLEX_RERANKER_THRESHOLD")
82            .ok()
83            .and_then(|v| v.parse().ok())
84            .unwrap_or(DEFAULT_THRESHOLD);
85
86        Self {
87            model_path,
88            threshold,
89        }
90    }
91}