1use serde::{Deserialize, Serialize};
4use std::fmt;
5
/// Default `k` for [`FusionStrategy::Rrf`]; used by [`FusionStrategy::rrf`] and
/// `FusionStrategy::default()`.
pub const DEFAULT_RRF_K: usize = 60;
8
/// Validation failure produced when building a [`FusionStrategy`] from unchecked
/// parameters, either through the `try_*` constructors or during deserialization.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FusionStrategyError {
    /// An `Rrf` strategy was given `k == 0`.
    RrfKZero,
    /// A `Weighted` strategy contained a NaN weight.
    WeightNaN,
    /// A `Weighted` strategy contained `+inf` or `-inf`.
    WeightInfinite,
    /// A `Custom` strategy was given an empty name.
    CustomNameEmpty,
}

impl fmt::Display for FusionStrategyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::RrfKZero => "Rrf k must be >= 1",
            Self::WeightNaN => "Weighted weights must not contain NaN",
            Self::WeightInfinite => "Weighted weights must not contain infinity",
            Self::CustomNameEmpty => "Custom strategy name must not be empty",
        };
        f.write_str(message)
    }
}

impl std::error::Error for FusionStrategyError {}
34
/// Unvalidated deserialization shape of [`FusionStrategy`].
///
/// serde fills this in without any checks; `TryFrom<RawFusionStrategy> for FusionStrategy`
/// then routes each parameterised variant through the matching `try_*` constructor.
/// Tags are snake_case (e.g. `"vector_only"`), and every variant also accepts its
/// PascalCase name via `alias` (e.g. `"VectorOnly"`).
// Keep the variant list and order identical to `FusionStrategy`: serde hands variant
// indices to formats that encode enums positionally, so a mismatch here would break
// round-trips through such formats.
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
enum RawFusionStrategy {
    #[serde(alias = "Rrf")]
    Rrf { k: usize },
    #[serde(alias = "Weighted")]
    Weighted { weights: Vec<f64> },
    #[serde(alias = "Union")]
    Union,
    #[serde(alias = "VectorOnly")]
    VectorOnly,
    #[serde(alias = "KeywordOnly")]
    KeywordOnly,
    #[serde(alias = "Custom")]
    Custom {
        name: String,
        // Stored verbatim; no validation is applied to the parameters.
        params: serde_json::Value,
    },
}
55
56impl TryFrom<RawFusionStrategy> for FusionStrategy {
57 type Error = FusionStrategyError;
58
59 fn try_from(raw: RawFusionStrategy) -> Result<Self, Self::Error> {
60 match raw {
61 RawFusionStrategy::Rrf { k } => FusionStrategy::try_rrf(k),
62 RawFusionStrategy::Weighted { weights } => FusionStrategy::try_weighted(weights),
63 RawFusionStrategy::Union => Ok(FusionStrategy::Union),
64 RawFusionStrategy::VectorOnly => Ok(FusionStrategy::VectorOnly),
65 RawFusionStrategy::KeywordOnly => Ok(FusionStrategy::KeywordOnly),
66 RawFusionStrategy::Custom { name, params } => FusionStrategy::try_custom(name, params),
67 }
68 }
69}
70
71#[derive(Debug, Clone, PartialEq, Serialize)]
75#[serde(rename_all = "snake_case")]
76#[serde(try_from = "RawFusionStrategy")]
77pub enum FusionStrategy {
78 Rrf {
80 k: usize,
82 },
83
84 Weighted {
86 weights: Vec<f64>,
88 },
89
90 Union,
92
93 VectorOnly,
95
96 KeywordOnly,
98
99 Custom {
101 name: String,
103 params: serde_json::Value,
105 },
106}
107
108impl<'de> Deserialize<'de> for FusionStrategy {
109 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
110 where
111 D: serde::Deserializer<'de>,
112 {
113 let raw = RawFusionStrategy::deserialize(deserializer)?;
114 FusionStrategy::try_from(raw).map_err(serde::de::Error::custom)
115 }
116}
117
118impl Default for FusionStrategy {
119 fn default() -> Self {
120 Self::Rrf { k: DEFAULT_RRF_K }
121 }
122}
123
124impl FusionStrategy {
125 #[inline]
127 pub fn rrf() -> Self {
128 Self::Rrf { k: DEFAULT_RRF_K }
129 }
130
131 #[inline]
133 pub fn try_rrf(k: usize) -> Result<Self, FusionStrategyError> {
134 if k == 0 {
135 return Err(FusionStrategyError::RrfKZero);
136 }
137 Ok(Self::Rrf { k })
138 }
139
140 #[inline]
144 pub fn rrf_with_k(k: usize) -> Self {
145 Self::Rrf { k: k.max(1) }
146 }
147
148 pub fn try_weighted(weights: Vec<f64>) -> Result<Self, FusionStrategyError> {
150 for w in &weights {
151 if w.is_nan() {
152 return Err(FusionStrategyError::WeightNaN);
153 }
154 if w.is_infinite() {
155 return Err(FusionStrategyError::WeightInfinite);
156 }
157 }
158 Ok(Self::Weighted { weights })
159 }
160
161 #[inline]
165 pub fn weighted(weights: Vec<f64>) -> Self {
166 Self::try_weighted(weights).expect("weights must be finite")
167 }
168
169 #[inline]
171 pub fn union() -> Self {
172 Self::Union
173 }
174
175 pub fn try_custom(
177 name: String,
178 params: serde_json::Value,
179 ) -> Result<Self, FusionStrategyError> {
180 if name.is_empty() {
181 return Err(FusionStrategyError::CustomNameEmpty);
182 }
183 Ok(Self::Custom { name, params })
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fusion_strategy_default() {
        let default = FusionStrategy::default();
        assert_eq!(default, FusionStrategy::Rrf { k: 60 });
    }

    #[test]
    fn test_fusion_strategy_builders() {
        assert_eq!(FusionStrategy::rrf(), FusionStrategy::Rrf { k: 60 });
        assert_eq!(
            FusionStrategy::rrf_with_k(20),
            FusionStrategy::Rrf { k: 20 }
        );
        assert_eq!(FusionStrategy::rrf_with_k(0), FusionStrategy::Rrf { k: 1 });
        assert_eq!(
            FusionStrategy::weighted(vec![0.5, 0.5]),
            FusionStrategy::Weighted {
                weights: vec![0.5, 0.5]
            }
        );
        assert_eq!(FusionStrategy::union(), FusionStrategy::Union);
    }

    #[test]
    #[should_panic(expected = "weights must be finite")]
    fn test_weighted_panics_on_non_finite() {
        let _ = FusionStrategy::weighted(vec![0.5, f64::NAN]);
    }

    #[test]
    fn test_try_rrf_rejects_zero() {
        assert_eq!(
            FusionStrategy::try_rrf(0),
            Err(FusionStrategyError::RrfKZero)
        );
        assert!(FusionStrategy::try_rrf(1).is_ok());
        assert!(FusionStrategy::try_rrf(60).is_ok());
    }

    #[test]
    fn test_try_weighted_rejects_nan() {
        assert_eq!(
            FusionStrategy::try_weighted(vec![0.5, f64::NAN]),
            Err(FusionStrategyError::WeightNaN)
        );
    }

    #[test]
    fn test_try_weighted_rejects_infinity() {
        assert_eq!(
            FusionStrategy::try_weighted(vec![f64::INFINITY, 0.5]),
            Err(FusionStrategyError::WeightInfinite)
        );
        assert_eq!(
            FusionStrategy::try_weighted(vec![0.5, f64::NEG_INFINITY]),
            Err(FusionStrategyError::WeightInfinite)
        );
    }

    #[test]
    fn test_try_weighted_reports_first_invalid_weight() {
        assert_eq!(
            FusionStrategy::try_weighted(vec![f64::INFINITY, f64::NAN]),
            Err(FusionStrategyError::WeightInfinite)
        );
        assert_eq!(
            FusionStrategy::try_weighted(vec![f64::NAN, f64::INFINITY]),
            Err(FusionStrategyError::WeightNaN)
        );
    }

    #[test]
    fn test_try_weighted_accepts_valid() {
        assert!(FusionStrategy::try_weighted(vec![0.5, 0.5]).is_ok());
        assert!(FusionStrategy::try_weighted(vec![0.0, 0.0]).is_ok());
        assert!(FusionStrategy::try_weighted(vec![-1.0, 1.0]).is_ok());
        assert!(FusionStrategy::try_weighted(vec![]).is_ok());
    }

    #[test]
    fn test_try_custom_rejects_empty_name() {
        assert_eq!(
            FusionStrategy::try_custom(String::new(), serde_json::Value::Null),
            Err(FusionStrategyError::CustomNameEmpty)
        );
    }

    #[test]
    fn test_try_custom_accepts_valid() {
        let result = FusionStrategy::try_custom(
            "decay_weighted".to_string(),
            serde_json::json!({"decay": 0.95}),
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_serde_roundtrip_rrf() {
        let strategy = FusionStrategy::Rrf { k: 60 };
        let json = serde_json::to_string(&strategy).unwrap();
        let deserialized: FusionStrategy = serde_json::from_str(&json).unwrap();
        assert_eq!(strategy, deserialized);
    }

    #[test]
    fn test_serde_roundtrip_weighted() {
        let strategy = FusionStrategy::Weighted {
            weights: vec![0.6, 0.4],
        };
        let json = serde_json::to_string(&strategy).unwrap();
        let deserialized: FusionStrategy = serde_json::from_str(&json).unwrap();
        assert_eq!(strategy, deserialized);
    }

    #[test]
    fn test_serde_roundtrip_custom() {
        let strategy = FusionStrategy::Custom {
            name: "decay_weighted".to_string(),
            params: serde_json::json!({"decay": 0.95}),
        };
        let json = serde_json::to_string(&strategy).unwrap();
        let deserialized: FusionStrategy = serde_json::from_str(&json).unwrap();
        assert_eq!(strategy, deserialized);
    }

    #[test]
    fn test_serde_unit_variants_use_snake_case_and_roundtrip() {
        assert_eq!(
            serde_json::to_string(&FusionStrategy::Union).unwrap(),
            r#""union""#
        );
        assert_eq!(
            serde_json::to_string(&FusionStrategy::VectorOnly).unwrap(),
            r#""vector_only""#
        );
        assert_eq!(
            serde_json::to_string(&FusionStrategy::KeywordOnly).unwrap(),
            r#""keyword_only""#
        );
        for strategy in vec![
            FusionStrategy::Union,
            FusionStrategy::VectorOnly,
            FusionStrategy::KeywordOnly,
        ] {
            let json = serde_json::to_string(&strategy).unwrap();
            let deserialized: FusionStrategy = serde_json::from_str(&json).unwrap();
            assert_eq!(strategy, deserialized);
        }
    }

    #[test]
    fn test_serde_accepts_pascal_case_aliases() {
        let rrf: FusionStrategy = serde_json::from_str(r#"{"Rrf":{"k":5}}"#).unwrap();
        assert_eq!(rrf, FusionStrategy::Rrf { k: 5 });
        let vector: FusionStrategy = serde_json::from_str(r#""VectorOnly""#).unwrap();
        assert_eq!(vector, FusionStrategy::VectorOnly);
        // Aliased input must still be validated.
        let zero: Result<FusionStrategy, _> = serde_json::from_str(r#"{"Rrf":{"k":0}}"#);
        assert!(zero.is_err());
    }

    #[test]
    fn test_serde_rejects_unknown_variant() {
        let result: Result<FusionStrategy, _> = serde_json::from_str(r#""bogus""#);
        assert!(result.is_err());
    }

    #[test]
    fn test_serde_rejects_rrf_k_zero() {
        let json = r#"{"rrf":{"k":0}}"#;
        let result: Result<FusionStrategy, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_serde_error_message_comes_from_validation() {
        let err = serde_json::from_str::<FusionStrategy>(r#"{"rrf":{"k":0}}"#).unwrap_err();
        assert!(
            err.to_string().contains("Rrf k must be >= 1"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn test_serde_rejects_nan_weights() {
        // JSON has no literal for NaN or infinity, so exercise the validation step
        // that the `Deserialize` impl runs on the parsed `RawFusionStrategy`.
        assert_eq!(
            FusionStrategy::try_from(RawFusionStrategy::Weighted {
                weights: vec![f64::NAN],
            }),
            Err(FusionStrategyError::WeightNaN)
        );
        assert_eq!(
            FusionStrategy::try_from(RawFusionStrategy::Weighted {
                weights: vec![f64::NEG_INFINITY],
            }),
            Err(FusionStrategyError::WeightInfinite)
        );
    }

    #[test]
    fn test_raw_conversion_validates_parameterised_variants() {
        assert_eq!(
            FusionStrategy::try_from(RawFusionStrategy::Rrf { k: 0 }),
            Err(FusionStrategyError::RrfKZero)
        );
        assert_eq!(
            FusionStrategy::try_from(RawFusionStrategy::Custom {
                name: String::new(),
                params: serde_json::Value::Null,
            }),
            Err(FusionStrategyError::CustomNameEmpty)
        );
        assert_eq!(
            FusionStrategy::try_from(RawFusionStrategy::Union),
            Ok(FusionStrategy::Union)
        );
    }

    #[test]
    fn test_serde_rejects_custom_empty_name() {
        let json = r#"{"custom":{"name":"","params":null}}"#;
        let result: Result<FusionStrategy, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }
}