oxirs_vec/hybrid_search/
types.rs1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct HybridQuery {
9 pub query_text: String,
11 pub query_vector: Option<Vec<f32>>,
13 pub top_k: usize,
15 pub weights: SearchWeights,
17 pub filters: HashMap<String, String>,
19}
20
21#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
23pub struct SearchWeights {
24 pub keyword_weight: f32,
26 pub semantic_weight: f32,
28 pub recency_weight: f32,
30}
31
32impl Default for SearchWeights {
33 fn default() -> Self {
34 Self {
35 keyword_weight: 0.3,
36 semantic_weight: 0.7,
37 recency_weight: 0.0,
38 }
39 }
40}
41
42impl SearchWeights {
43 pub fn validate(&self) -> anyhow::Result<()> {
45 let sum = self.keyword_weight + self.semantic_weight + self.recency_weight;
46 if (sum - 1.0).abs() > 0.1 {
47 anyhow::bail!(
48 "Search weights should sum to approximately 1.0, got {}",
49 sum
50 );
51 }
52 if self.keyword_weight < 0.0 || self.semantic_weight < 0.0 || self.recency_weight < 0.0 {
53 anyhow::bail!("Search weights must be non-negative");
54 }
55 Ok(())
56 }
57
58 pub fn normalize(&mut self) {
60 let sum = self.keyword_weight + self.semantic_weight + self.recency_weight;
61 if sum > 0.0 {
62 self.keyword_weight /= sum;
63 self.semantic_weight /= sum;
64 self.recency_weight /= sum;
65 }
66 }
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct HybridResult {
72 pub doc_id: String,
74 pub score: f32,
76 pub score_breakdown: ScoreBreakdown,
78 pub metadata: HashMap<String, String>,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct ScoreBreakdown {
85 pub keyword_score: f32,
87 pub semantic_score: f32,
89 pub recency_score: f32,
91 pub keyword_rank: Option<usize>,
93 pub semantic_rank: Option<usize>,
95}
96
/// A document identifier paired with a score and a rank.
///
/// NOTE(review): nothing in this file constructs or reads this type, so
/// which search stage produces it and whether `rank` is 0- or 1-based must
/// be confirmed against the callers.
#[derive(Clone, Debug)]
pub struct DocumentScore {
    /// Identifier of the scored document.
    pub doc_id: String,
    /// Score assigned to the document.
    pub score: f32,
    /// Position of the document in its ranking.
    pub rank: usize,
}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct KeywordMatch {
111 pub doc_id: String,
113 pub score: f32,
115 pub matched_terms: Vec<String>,
117 pub term_frequencies: HashMap<String, usize>,
119}
120
121impl HybridResult {
122 pub fn new(
124 doc_id: String,
125 keyword_score: f32,
126 semantic_score: f32,
127 recency_score: f32,
128 weights: &SearchWeights,
129 ) -> Self {
130 let score = keyword_score * weights.keyword_weight
131 + semantic_score * weights.semantic_weight
132 + recency_score * weights.recency_weight;
133
134 Self {
135 doc_id,
136 score,
137 score_breakdown: ScoreBreakdown {
138 keyword_score,
139 semantic_score,
140 recency_score,
141 keyword_rank: None,
142 semantic_rank: None,
143 },
144 metadata: HashMap::new(),
145 }
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point values.
    const TOLERANCE: f32 = 0.001;

    /// Weights summing to 1.0 are accepted; weights summing to 1.3 are rejected.
    #[test]
    fn test_search_weights_validation() {
        let accepted = SearchWeights {
            keyword_weight: 0.3,
            semantic_weight: 0.7,
            recency_weight: 0.0,
        };
        assert!(accepted.validate().is_ok());

        let rejected = SearchWeights {
            keyword_weight: 0.5,
            semantic_weight: 0.8,
            recency_weight: 0.0,
        };
        assert!(rejected.validate().is_err());
    }

    /// Weights of 1:2:1 normalize to 0.25 / 0.5 / 0.25.
    #[test]
    fn test_weights_normalization() {
        let mut weights = SearchWeights {
            keyword_weight: 1.0,
            semantic_weight: 2.0,
            recency_weight: 1.0,
        };
        weights.normalize();

        let checks = [
            (weights.keyword_weight, 0.25),
            (weights.semantic_weight, 0.5),
            (weights.recency_weight, 0.25),
        ];
        for &(actual, expected) in checks.iter() {
            assert!((actual - expected).abs() < TOLERANCE);
        }
    }

    /// The combined score equals the weighted sum of the component scores.
    #[test]
    fn test_hybrid_result_scoring() {
        let weights = SearchWeights {
            keyword_weight: 0.4,
            semantic_weight: 0.6,
            recency_weight: 0.0,
        };
        let expected_score: f32 = 0.8 * 0.4 + 0.9 * 0.6;

        let result = HybridResult::new("doc1".to_string(), 0.8, 0.9, 0.0, &weights);

        assert!((result.score - expected_score).abs() < TOLERANCE);
    }
}