1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5use crate::db::Database;
6use crate::embedding::{self, Embedder};
7
/// Retrieval strategy used by [`search`].
///
/// The enum is a plain, field-less tag, so it is `Copy` and implements full
/// equality and hashing, which lets it be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SearchMode {
    /// Full-text search only.
    Fts,
    /// Embedding-similarity search only (requires an embedder).
    Vector,
    /// Full-text and vector results fused together; falls back to full-text
    /// alone when no embedder is supplied.
    Hybrid,
}
14
15impl std::str::FromStr for SearchMode {
16 type Err = anyhow::Error;
17 fn from_str(s: &str) -> Result<Self> {
18 match s.to_lowercase().as_str() {
19 "fts" => Ok(Self::Fts),
20 "vector" => Ok(Self::Vector),
21 "hybrid" => Ok(Self::Hybrid),
22 _ => anyhow::bail!("Invalid search mode: {}. Use fts, vector, or hybrid", s),
23 }
24 }
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct SearchResult {
29 pub issue_id: String,
30 pub identifier: String,
31 pub title: String,
32 pub state_name: String,
33 pub priority: i32,
34 pub score: f64,
35 #[serde(skip_serializing_if = "Option::is_none")]
36 pub fts_rank: Option<usize>,
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub vector_rank: Option<usize>,
39 #[serde(skip_serializing_if = "Option::is_none")]
40 pub similarity: Option<f32>,
41}
42
43pub async fn search(
45 db: &Database,
46 query: &str,
47 mode: SearchMode,
48 team_key: Option<&str>,
49 state_filter: Option<&str>,
50 limit: usize,
51 embedder: Option<&Embedder>,
52 rrf_k: u32,
53 workspace_id: &str,
54) -> Result<Vec<SearchResult>> {
55 let results = match mode {
56 SearchMode::Fts => fts_search(db, query, limit * 2, workspace_id)?,
57 SearchMode::Vector => {
58 let embedder =
59 embedder.ok_or_else(|| anyhow::anyhow!("Embedder required for vector search"))?;
60 vector_search(db, query, team_key, limit * 2, embedder, workspace_id).await?
61 }
62 SearchMode::Hybrid => {
63 let fts_results = fts_search(db, query, limit * 3, workspace_id)?;
64
65 if let Some(embedder) = embedder {
66 let vec_results =
67 vector_search(db, query, team_key, limit * 3, embedder, workspace_id).await?;
68 reciprocal_rank_fusion(fts_results, vec_results, rrf_k, 0.3, 0.7)
69 } else {
70 fts_results
72 }
73 }
74 };
75
76 let results: Vec<_> = results
78 .into_iter()
79 .filter(|_r| {
80 if let Some(_team) = team_key {
81 true
84 } else {
85 true
86 }
87 })
88 .filter(|r| {
89 if let Some(state) = state_filter {
90 r.state_name.to_lowercase().contains(&state.to_lowercase())
91 } else {
92 true
93 }
94 })
95 .take(limit)
96 .collect();
97
98 Ok(results)
99}
100
101fn fts_search(
102 db: &Database,
103 query: &str,
104 limit: usize,
105 workspace_id: &str,
106) -> Result<Vec<SearchResult>> {
107 let fts_query = build_fts_query(query);
109 let fts_results = db.fts_search(&fts_query, limit, workspace_id)?;
110
111 Ok(fts_results
112 .into_iter()
113 .enumerate()
114 .map(|(rank, r)| SearchResult {
115 issue_id: r.issue_id,
116 identifier: r.identifier,
117 title: r.title,
118 state_name: r.state_name,
119 priority: r.priority,
120 score: -r.bm25_score, fts_rank: Some(rank + 1),
122 vector_rank: None,
123 similarity: None,
124 })
125 .collect())
126}
127
128async fn vector_search(
129 db: &Database,
130 query: &str,
131 team_key: Option<&str>,
132 limit: usize,
133 embedder: &Embedder,
134 workspace_id: &str,
135) -> Result<Vec<SearchResult>> {
136 let query_embedding = embedder.embed_single(query).await?;
137
138 let chunks = if let Some(team) = team_key {
139 db.get_chunks_for_team(team, workspace_id)?
140 } else {
141 db.get_all_chunks(workspace_id)?
142 };
143
144 let mut issue_max_sim: HashMap<String, (f32, String)> = HashMap::new(); for chunk in &chunks {
148 let chunk_embedding = embedding::bytes_to_embedding(&chunk.embedding);
149 let sim = embedding::cosine_similarity(&query_embedding, &chunk_embedding);
150
151 let entry = issue_max_sim
152 .entry(chunk.issue_id.clone())
153 .or_insert((0.0, chunk.identifier.clone()));
154 if sim > entry.0 {
155 entry.0 = sim;
156 }
157 }
158
159 let mut results: Vec<_> = issue_max_sim.into_iter().collect();
161 results.sort_by(|a, b| b.1 .0.partial_cmp(&a.1 .0).unwrap());
162
163 let results: Vec<_> = results
165 .into_iter()
166 .take(limit)
167 .enumerate()
168 .filter_map(|(rank, (issue_id, (sim, _identifier)))| {
169 let issue = db.get_issue(&issue_id).ok()??;
170 Some(SearchResult {
171 issue_id,
172 identifier: issue.identifier,
173 title: issue.title,
174 state_name: issue.state_name,
175 priority: issue.priority,
176 score: sim as f64,
177 fts_rank: None,
178 vector_rank: Some(rank + 1),
179 similarity: Some(sim),
180 })
181 })
182 .collect();
183
184 Ok(results)
185}
186
187fn reciprocal_rank_fusion(
189 fts_results: Vec<SearchResult>,
190 vec_results: Vec<SearchResult>,
191 k: u32,
192 fts_weight: f64,
193 vec_weight: f64,
194) -> Vec<SearchResult> {
195 let mut scores: HashMap<String, (f64, SearchResult)> = HashMap::new();
196 let k = k as f64;
197
198 for (rank, result) in fts_results.into_iter().enumerate() {
199 let rrf_score = fts_weight / (k + (rank + 1) as f64);
200 let entry = scores
201 .entry(result.issue_id.clone())
202 .or_insert((0.0, result.clone()));
203 entry.0 += rrf_score;
204 entry.1.fts_rank = result.fts_rank;
205 }
206
207 for (rank, result) in vec_results.into_iter().enumerate() {
208 let rrf_score = vec_weight / (k + (rank + 1) as f64);
209 let entry = scores
210 .entry(result.issue_id.clone())
211 .or_insert((0.0, result.clone()));
212 entry.0 += rrf_score;
213 entry.1.vector_rank = result.vector_rank;
214 entry.1.similarity = result.similarity;
215 }
216
217 let mut results: Vec<_> = scores
218 .into_values()
219 .map(|(score, mut result)| {
220 result.score = score;
221 result
222 })
223 .collect();
224
225 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
226 results
227}
228
229pub async fn find_duplicates(
231 db: &Database,
232 text: &str,
233 team_key: Option<&str>,
234 threshold: f32,
235 limit: usize,
236 embedder: &Embedder,
237 rrf_k: u32,
238 workspace_id: &str,
239) -> Result<Vec<SearchResult>> {
240 let mut results = search(
241 db,
242 text,
243 SearchMode::Hybrid,
244 team_key,
245 None,
246 limit,
247 Some(embedder),
248 rrf_k,
249 workspace_id,
250 )
251 .await?;
252
253 let _vec_results = vector_search(db, text, team_key, limit, embedder, workspace_id).await?;
255
256 results.retain(|r| r.similarity.unwrap_or(0.0) >= threshold || r.score > 0.01);
258
259 Ok(results)
260}
261
/// Turns free-form user input into a full-text query string.
///
/// The input is split on whitespace. From every word, all characters other
/// than alphanumerics, `_` and `-` are removed; words that end up empty are
/// dropped. Each surviving word is wrapped in double quotes and the words are
/// joined with ` OR `. When nothing survives, a single empty quoted phrase
/// (`""`) is returned.
fn build_fts_query(input: &str) -> String {
    let mut terms: Vec<String> = Vec::new();

    for token in input.split_whitespace() {
        // Strip everything that is not a word character or hyphen.
        let mut cleaned = token.to_owned();
        cleaned.retain(|ch| ch.is_alphanumeric() || matches!(ch, '_' | '-'));

        if !cleaned.is_empty() {
            terms.push(format!("\"{}\"", cleaned));
        }
    }

    if terms.is_empty() {
        return String::from("\"\"");
    }
    terms.join(" OR ")
}