Skip to main content

khive_fusion/
strategy.rs

1//! Fusion strategy types with invariant validation.
2
3use serde::{Deserialize, Serialize};
4use std::fmt;
5
/// Default RRF smoothing constant, `k = 60`.
///
/// This is the value used when Reciprocal Rank Fusion was introduced by
/// Cormack, Clarke & Büttcher, "Reciprocal Rank Fusion outperforms Condorcet
/// and individual Rank Learning Methods" (SIGIR 2009), and it is the customary
/// default in the literature.
pub const DEFAULT_RRF_K: usize = 60;
8
/// Error returned when a [`FusionStrategy`] fails invariant validation.
///
/// Every variant is fieldless, so the error is `Copy` and can be compared,
/// hashed, or used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FusionStrategyError {
    /// RRF `k` was 0; it must be >= 1 to avoid division by zero.
    RrfKZero,
    /// A weight of a weighted strategy was NaN.
    WeightNaN,
    /// A weight of a weighted strategy was positive or negative infinity.
    WeightInfinite,
    /// A custom strategy was given an empty name.
    CustomNameEmpty,
}

impl fmt::Display for FusionStrategyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fixed, human-readable messages; these are also what deserialization
        // errors carry, so keep them stable.
        let message = match self {
            Self::RrfKZero => "Rrf k must be >= 1",
            Self::WeightNaN => "Weighted weights must not contain NaN",
            Self::WeightInfinite => "Weighted weights must not contain infinity",
            Self::CustomNameEmpty => "Custom strategy name must not be empty",
        };
        f.write_str(message)
    }
}

impl std::error::Error for FusionStrategyError {}
34
35/// Raw serde form used for deserialization before validation.
36#[derive(Deserialize)]
37#[serde(rename_all = "snake_case")]
38enum RawFusionStrategy {
39    #[serde(alias = "Rrf")]
40    Rrf { k: usize },
41    #[serde(alias = "Weighted")]
42    Weighted { weights: Vec<f64> },
43    #[serde(alias = "Union")]
44    Union,
45    #[serde(alias = "VectorOnly")]
46    VectorOnly,
47    #[serde(alias = "KeywordOnly")]
48    KeywordOnly,
49    #[serde(alias = "Custom")]
50    Custom {
51        name: String,
52        params: serde_json::Value,
53    },
54}
55
56impl TryFrom<RawFusionStrategy> for FusionStrategy {
57    type Error = FusionStrategyError;
58
59    fn try_from(raw: RawFusionStrategy) -> Result<Self, Self::Error> {
60        match raw {
61            RawFusionStrategy::Rrf { k } => FusionStrategy::try_rrf(k),
62            RawFusionStrategy::Weighted { weights } => FusionStrategy::try_weighted(weights),
63            RawFusionStrategy::Union => Ok(FusionStrategy::Union),
64            RawFusionStrategy::VectorOnly => Ok(FusionStrategy::VectorOnly),
65            RawFusionStrategy::KeywordOnly => Ok(FusionStrategy::KeywordOnly),
66            RawFusionStrategy::Custom { name, params } => FusionStrategy::try_custom(name, params),
67        }
68    }
69}
70
71/// Fusion strategy for combining ranked result lists.
72///
73/// Validated at construction and deserialization boundaries.
74#[derive(Debug, Clone, PartialEq, Serialize)]
75#[serde(rename_all = "snake_case")]
76#[serde(try_from = "RawFusionStrategy")]
77pub enum FusionStrategy {
78    /// Reciprocal Rank Fusion (default, recommended). Rank-based, distribution-agnostic.
79    Rrf {
80        /// Smoothing constant (>= 1). Default: 60.
81        k: usize,
82    },
83
84    /// Weighted linear combination of scores. Weights normalized to 1.0; must be finite.
85    Weighted {
86        /// Weights for each source (will be normalized). Must be finite.
87        weights: Vec<f64>,
88    },
89
90    /// Take union with max score per ID.
91    Union,
92
93    /// Skip BM25 entirely -- return only vector (HNSW) results.
94    VectorOnly,
95
96    /// Skip HNSW entirely -- return only BM25 keyword results.
97    KeywordOnly,
98
99    /// Pack-defined or user-defined custom strategy dispatched by name at runtime.
100    Custom {
101        /// Strategy identifier registered with the fusion executor registry.
102        name: String,
103        /// Opaque parameters consumed by the executor.
104        params: serde_json::Value,
105    },
106}
107
108impl<'de> Deserialize<'de> for FusionStrategy {
109    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
110    where
111        D: serde::Deserializer<'de>,
112    {
113        let raw = RawFusionStrategy::deserialize(deserializer)?;
114        FusionStrategy::try_from(raw).map_err(serde::de::Error::custom)
115    }
116}
117
118impl Default for FusionStrategy {
119    fn default() -> Self {
120        Self::Rrf { k: DEFAULT_RRF_K }
121    }
122}
123
124impl FusionStrategy {
125    /// Create an RRF strategy with default k=60.
126    #[inline]
127    pub fn rrf() -> Self {
128        Self::Rrf { k: DEFAULT_RRF_K }
129    }
130
131    /// Create an RRF strategy with custom k value. Returns error if k == 0.
132    #[inline]
133    pub fn try_rrf(k: usize) -> Result<Self, FusionStrategyError> {
134        if k == 0 {
135            return Err(FusionStrategyError::RrfKZero);
136        }
137        Ok(Self::Rrf { k })
138    }
139
140    /// Create an RRF strategy, clamping k to at least 1.
141    ///
142    /// Prefer [`try_rrf`](Self::try_rrf) at public API boundaries.
143    #[inline]
144    pub fn rrf_with_k(k: usize) -> Self {
145        Self::Rrf { k: k.max(1) }
146    }
147
148    /// Create a weighted strategy after validating weights are finite.
149    pub fn try_weighted(weights: Vec<f64>) -> Result<Self, FusionStrategyError> {
150        for w in &weights {
151            if w.is_nan() {
152                return Err(FusionStrategyError::WeightNaN);
153            }
154            if w.is_infinite() {
155                return Err(FusionStrategyError::WeightInfinite);
156            }
157        }
158        Ok(Self::Weighted { weights })
159    }
160
161    /// Create a weighted strategy. Panics on NaN/infinity.
162    ///
163    /// Prefer [`try_weighted`](Self::try_weighted) at public API boundaries.
164    #[inline]
165    pub fn weighted(weights: Vec<f64>) -> Self {
166        Self::try_weighted(weights).expect("weights must be finite")
167    }
168
169    /// Create a union strategy.
170    #[inline]
171    pub fn union() -> Self {
172        Self::Union
173    }
174
175    /// Create a custom strategy with name validation.
176    pub fn try_custom(
177        name: String,
178        params: serde_json::Value,
179    ) -> Result<Self, FusionStrategyError> {
180        if name.is_empty() {
181            return Err(FusionStrategyError::CustomNameEmpty);
182        }
183        Ok(Self::Custom { name, params })
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fusion_strategy_default() {
        let default = FusionStrategy::default();
        assert_eq!(default, FusionStrategy::Rrf { k: 60 });
    }

    #[test]
    fn test_fusion_strategy_builders() {
        assert_eq!(FusionStrategy::rrf(), FusionStrategy::Rrf { k: 60 });
        assert_eq!(
            FusionStrategy::rrf_with_k(20),
            FusionStrategy::Rrf { k: 20 }
        );
        assert_eq!(FusionStrategy::rrf_with_k(0), FusionStrategy::Rrf { k: 1 });
        assert_eq!(
            FusionStrategy::weighted(vec![0.5, 0.5]),
            FusionStrategy::Weighted {
                weights: vec![0.5, 0.5]
            }
        );
        assert_eq!(FusionStrategy::union(), FusionStrategy::Union);
    }

    #[test]
    fn test_try_rrf_rejects_zero() {
        assert_eq!(
            FusionStrategy::try_rrf(0),
            Err(FusionStrategyError::RrfKZero)
        );
        assert!(FusionStrategy::try_rrf(1).is_ok());
        assert!(FusionStrategy::try_rrf(60).is_ok());
    }

    #[test]
    fn test_try_weighted_rejects_nan() {
        assert_eq!(
            FusionStrategy::try_weighted(vec![0.5, f64::NAN]),
            Err(FusionStrategyError::WeightNaN)
        );
    }

    #[test]
    fn test_try_weighted_rejects_infinity() {
        assert_eq!(
            FusionStrategy::try_weighted(vec![f64::INFINITY, 0.5]),
            Err(FusionStrategyError::WeightInfinite)
        );
        assert_eq!(
            FusionStrategy::try_weighted(vec![0.5, f64::NEG_INFINITY]),
            Err(FusionStrategyError::WeightInfinite)
        );
    }

    #[test]
    fn test_try_weighted_reports_first_invalid_weight() {
        // The leftmost non-finite weight decides which error is returned.
        assert_eq!(
            FusionStrategy::try_weighted(vec![f64::INFINITY, f64::NAN]),
            Err(FusionStrategyError::WeightInfinite)
        );
        assert_eq!(
            FusionStrategy::try_weighted(vec![1.0, f64::NAN, f64::NEG_INFINITY]),
            Err(FusionStrategyError::WeightNaN)
        );
    }

    #[test]
    fn test_try_weighted_accepts_valid() {
        assert!(FusionStrategy::try_weighted(vec![0.5, 0.5]).is_ok());
        assert!(FusionStrategy::try_weighted(vec![0.0, 0.0]).is_ok());
        assert!(FusionStrategy::try_weighted(vec![-1.0, 1.0]).is_ok());
        assert!(FusionStrategy::try_weighted(vec![]).is_ok());
    }

    #[test]
    #[should_panic(expected = "weights must be finite")]
    fn test_weighted_panics_on_non_finite() {
        let _ = FusionStrategy::weighted(vec![0.5, f64::INFINITY]);
    }

    #[test]
    fn test_try_custom_rejects_empty_name() {
        assert_eq!(
            FusionStrategy::try_custom(String::new(), serde_json::Value::Null),
            Err(FusionStrategyError::CustomNameEmpty)
        );
    }

    #[test]
    fn test_try_custom_accepts_valid() {
        let result = FusionStrategy::try_custom(
            "decay_weighted".to_string(),
            serde_json::json!({"decay": 0.95}),
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_error_display_messages() {
        assert_eq!(
            FusionStrategyError::RrfKZero.to_string(),
            "Rrf k must be >= 1"
        );
        assert_eq!(
            FusionStrategyError::WeightNaN.to_string(),
            "Weighted weights must not contain NaN"
        );
        assert_eq!(
            FusionStrategyError::WeightInfinite.to_string(),
            "Weighted weights must not contain infinity"
        );
        assert_eq!(
            FusionStrategyError::CustomNameEmpty.to_string(),
            "Custom strategy name must not be empty"
        );
    }

    #[test]
    fn test_serde_roundtrip_rrf() {
        let strategy = FusionStrategy::Rrf { k: 60 };
        let json = serde_json::to_string(&strategy).unwrap();
        let deserialized: FusionStrategy = serde_json::from_str(&json).unwrap();
        assert_eq!(strategy, deserialized);
    }

    #[test]
    fn test_serde_roundtrip_weighted() {
        let strategy = FusionStrategy::Weighted {
            weights: vec![0.6, 0.4],
        };
        let json = serde_json::to_string(&strategy).unwrap();
        let deserialized: FusionStrategy = serde_json::from_str(&json).unwrap();
        assert_eq!(strategy, deserialized);
    }

    #[test]
    fn test_serde_roundtrip_custom() {
        let strategy = FusionStrategy::Custom {
            name: "decay_weighted".to_string(),
            params: serde_json::json!({"decay": 0.95}),
        };
        let json = serde_json::to_string(&strategy).unwrap();
        let deserialized: FusionStrategy = serde_json::from_str(&json).unwrap();
        assert_eq!(strategy, deserialized);
    }

    #[test]
    fn test_serde_roundtrip_unit_variants() {
        for strategy in [
            FusionStrategy::Union,
            FusionStrategy::VectorOnly,
            FusionStrategy::KeywordOnly,
        ] {
            let json = serde_json::to_string(&strategy).unwrap();
            let deserialized: FusionStrategy = serde_json::from_str(&json).unwrap();
            assert_eq!(strategy, deserialized);
        }
    }

    #[test]
    fn test_serde_wire_format_is_snake_case() {
        assert_eq!(
            serde_json::to_string(&FusionStrategy::rrf()).unwrap(),
            r#"{"rrf":{"k":60}}"#
        );
        assert_eq!(
            serde_json::to_string(&FusionStrategy::VectorOnly).unwrap(),
            r#""vector_only""#
        );
        assert_eq!(
            serde_json::to_string(&FusionStrategy::KeywordOnly).unwrap(),
            r#""keyword_only""#
        );
    }

    #[test]
    fn test_serde_accepts_pascal_case_aliases() {
        let rrf: FusionStrategy = serde_json::from_str(r#"{"Rrf":{"k":10}}"#).unwrap();
        assert_eq!(rrf, FusionStrategy::Rrf { k: 10 });

        let weighted: FusionStrategy =
            serde_json::from_str(r#"{"Weighted":{"weights":[1.0]}}"#).unwrap();
        assert_eq!(
            weighted,
            FusionStrategy::Weighted {
                weights: vec![1.0]
            }
        );

        let keyword: FusionStrategy = serde_json::from_str(r#""KeywordOnly""#).unwrap();
        assert_eq!(keyword, FusionStrategy::KeywordOnly);
    }

    #[test]
    fn test_serde_rejects_unknown_variant() {
        assert!(serde_json::from_str::<FusionStrategy>(r#""bogus""#).is_err());
    }

    #[test]
    fn test_serde_rejects_rrf_k_zero() {
        let json = r#"{"rrf":{"k":0}}"#;
        let err = serde_json::from_str::<FusionStrategy>(json).unwrap_err();
        // The validation error's Display text must survive into the serde error.
        assert!(
            err.to_string().contains("Rrf k must be >= 1"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_try_weighted_rejects_lone_nan() {
        // NaN has no JSON representation, so the serde path can never feed one
        // in; this exercises the validator that the serde path relies on.
        assert_eq!(
            FusionStrategy::try_weighted(vec![f64::NAN]),
            Err(FusionStrategyError::WeightNaN)
        );
    }

    #[test]
    fn test_serde_rejects_custom_empty_name() {
        let json = r#"{"custom":{"name":"","params":null}}"#;
        let err = serde_json::from_str::<FusionStrategy>(json).unwrap_err();
        assert!(
            err.to_string()
                .contains("Custom strategy name must not be empty"),
            "unexpected error: {err}"
        );
    }
}