reflex/embedding/reranker/
config.rs1use std::path::PathBuf;
2
3pub const DEFAULT_THRESHOLD: f32 = crate::constants::DEFAULT_VERIFICATION_THRESHOLD;
5
6pub const MAX_SEQ_LEN: usize = 512;
8
/// Settings for the reranker: where its model lives and the threshold to apply.
///
/// Construct with [`RerankerConfig::new`], [`RerankerConfig::stub`],
/// [`Default::default`] or [`RerankerConfig::from_env`], and call
/// [`RerankerConfig::validate`] before use.
#[derive(Debug, Clone, PartialEq)]
pub struct RerankerConfig {
    /// Path to the reranker model; `None` when no model is configured
    /// (as produced by `stub()` and `Default`).
    pub model_path: Option<PathBuf>,

    /// Threshold value; `validate` requires it to lie in `0.0..=1.0`.
    /// NOTE(review): presumably compared against reranker scores — confirm at
    /// the call site.
    pub threshold: f32,
}
18
19impl Default for RerankerConfig {
20 fn default() -> Self {
21 Self {
22 model_path: None,
23 threshold: DEFAULT_THRESHOLD,
24 }
25 }
26}
27
28impl RerankerConfig {
29 pub fn new<P: Into<PathBuf>>(model_path: P) -> Self {
31 Self {
32 model_path: Some(model_path.into()),
33 threshold: DEFAULT_THRESHOLD,
34 }
35 }
36
37 pub fn stub() -> Self {
39 Self {
40 model_path: None,
41 threshold: DEFAULT_THRESHOLD,
42 }
43 }
44
45 pub fn with_threshold(mut self, threshold: f32) -> Self {
47 assert!(
48 (0.0..=1.0).contains(&threshold),
49 "threshold must be between 0.0 and 1.0"
50 );
51 self.threshold = threshold;
52 self
53 }
54
55 pub fn validate(&self) -> Result<(), String> {
57 if !(0.0..=1.0).contains(&self.threshold) {
58 return Err(format!(
59 "threshold must be between 0.0 and 1.0, got {}",
60 self.threshold
61 ));
62 }
63
64 if let Some(ref path) = self.model_path
65 && path.as_os_str().is_empty()
66 {
67 return Err("model_path cannot be empty when provided".to_string());
68 }
69
70 Ok(())
71 }
72
73 pub fn from_env() -> Self {
75 let model_path = std::env::var("REFLEX_RERANKER_PATH")
76 .ok()
77 .map(|v| v.trim().to_string())
78 .filter(|v| !v.is_empty())
79 .map(PathBuf::from);
80
81 let threshold = std::env::var("REFLEX_RERANKER_THRESHOLD")
82 .ok()
83 .and_then(|v| v.parse().ok())
84 .unwrap_or(DEFAULT_THRESHOLD);
85
86 Self {
87 model_path,
88 threshold,
89 }
90 }
91}