1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5use crate::db::Database;
6use crate::embedding::{self, Embedder};
7
/// Retrieval strategy used by [`search`].
///
/// `Eq` is derived alongside `PartialEq`: the enum has no fields, so equality
/// is total, and deriving it lets the type be used wherever `Eq` is required.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchMode {
    /// Full-text search only.
    Fts,
    /// Embedding-similarity search only; requires an embedder.
    Vector,
    /// Full-text and vector results merged with reciprocal rank fusion.
    Hybrid,
}
14
15impl std::str::FromStr for SearchMode {
16 type Err = anyhow::Error;
17 fn from_str(s: &str) -> Result<Self> {
18 match s.to_lowercase().as_str() {
19 "fts" => Ok(Self::Fts),
20 "vector" => Ok(Self::Vector),
21 "hybrid" => Ok(Self::Hybrid),
22 _ => anyhow::bail!("Invalid search mode: {}. Use fts, vector, or hybrid", s),
23 }
24 }
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct SearchResult {
29 pub issue_id: String,
30 pub identifier: String,
31 pub title: String,
32 pub state_name: String,
33 pub priority: i32,
34 pub score: f64,
35 #[serde(skip_serializing_if = "Option::is_none")]
36 pub fts_rank: Option<usize>,
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub vector_rank: Option<usize>,
39 #[serde(skip_serializing_if = "Option::is_none")]
40 pub similarity: Option<f32>,
41}
42
43pub struct SearchParams<'a> {
44 pub query: &'a str,
45 pub mode: SearchMode,
46 pub team_key: Option<&'a str>,
47 pub state_filter: Option<&'a str>,
48 pub limit: usize,
49 pub embedder: Option<&'a Embedder>,
50 pub rrf_k: u32,
51 pub workspace_id: &'a str,
52}
53
54pub async fn search(db: &Database, params: SearchParams<'_>) -> Result<Vec<SearchResult>> {
56 let SearchParams {
57 query,
58 mode,
59 team_key,
60 state_filter,
61 limit,
62 embedder,
63 rrf_k,
64 workspace_id,
65 } = params;
66 let results = match mode {
67 SearchMode::Fts => fts_search(db, query, limit * 2, workspace_id)?,
68 SearchMode::Vector => {
69 let embedder =
70 embedder.ok_or_else(|| anyhow::anyhow!("Embedder required for vector search"))?;
71 vector_search(db, query, team_key, limit * 2, embedder, workspace_id).await?
72 }
73 SearchMode::Hybrid => {
74 let fts_results = fts_search(db, query, limit * 3, workspace_id)?;
75
76 if let Some(embedder) = embedder {
77 let vec_results =
78 vector_search(db, query, team_key, limit * 3, embedder, workspace_id).await?;
79 reciprocal_rank_fusion(fts_results, vec_results, rrf_k, 0.3, 0.7)
80 } else {
81 fts_results
83 }
84 }
85 };
86
87 let results: Vec<_> = results
89 .into_iter()
90 .filter(|_r| {
91 if let Some(_team) = team_key {
92 true
95 } else {
96 true
97 }
98 })
99 .filter(|r| {
100 if let Some(state) = state_filter {
101 r.state_name.to_lowercase().contains(&state.to_lowercase())
102 } else {
103 true
104 }
105 })
106 .take(limit)
107 .collect();
108
109 Ok(results)
110}
111
112fn fts_search(
113 db: &Database,
114 query: &str,
115 limit: usize,
116 workspace_id: &str,
117) -> Result<Vec<SearchResult>> {
118 let fts_query = build_fts_query(query);
120 let fts_results = db.fts_search(&fts_query, limit, workspace_id)?;
121
122 Ok(fts_results
123 .into_iter()
124 .enumerate()
125 .map(|(rank, r)| SearchResult {
126 issue_id: r.issue_id,
127 identifier: r.identifier,
128 title: r.title,
129 state_name: r.state_name,
130 priority: r.priority,
131 score: -r.bm25_score, fts_rank: Some(rank + 1),
133 vector_rank: None,
134 similarity: None,
135 })
136 .collect())
137}
138
139async fn vector_search(
140 db: &Database,
141 query: &str,
142 team_key: Option<&str>,
143 limit: usize,
144 embedder: &Embedder,
145 workspace_id: &str,
146) -> Result<Vec<SearchResult>> {
147 let query_embedding = embedder.embed_single(query).await?;
148
149 let chunks = if let Some(team) = team_key {
150 db.get_chunks_for_team(team, workspace_id)?
151 } else {
152 db.get_all_chunks(workspace_id)?
153 };
154
155 let mut issue_max_sim: HashMap<String, (f32, String)> = HashMap::new(); for chunk in &chunks {
159 let chunk_embedding = embedding::bytes_to_embedding(&chunk.embedding);
160 let sim = embedding::cosine_similarity(&query_embedding, &chunk_embedding);
161
162 let entry = issue_max_sim
163 .entry(chunk.issue_id.clone())
164 .or_insert((0.0, chunk.identifier.clone()));
165 if sim > entry.0 {
166 entry.0 = sim;
167 }
168 }
169
170 let mut results: Vec<_> = issue_max_sim.into_iter().collect();
172 results.sort_by(|a, b| b.1 .0.partial_cmp(&a.1 .0).unwrap());
173
174 let results: Vec<_> = results
176 .into_iter()
177 .take(limit)
178 .enumerate()
179 .filter_map(|(rank, (issue_id, (sim, _identifier)))| {
180 let issue = db.get_issue(&issue_id).ok()??;
181 Some(SearchResult {
182 issue_id,
183 identifier: issue.identifier,
184 title: issue.title,
185 state_name: issue.state_name,
186 priority: issue.priority,
187 score: sim as f64,
188 fts_rank: None,
189 vector_rank: Some(rank + 1),
190 similarity: Some(sim),
191 })
192 })
193 .collect();
194
195 Ok(results)
196}
197
198fn reciprocal_rank_fusion(
200 fts_results: Vec<SearchResult>,
201 vec_results: Vec<SearchResult>,
202 k: u32,
203 fts_weight: f64,
204 vec_weight: f64,
205) -> Vec<SearchResult> {
206 let mut scores: HashMap<String, (f64, SearchResult)> = HashMap::new();
207 let k = k as f64;
208
209 for (rank, result) in fts_results.into_iter().enumerate() {
210 let rrf_score = fts_weight / (k + (rank + 1) as f64);
211 let entry = scores
212 .entry(result.issue_id.clone())
213 .or_insert((0.0, result.clone()));
214 entry.0 += rrf_score;
215 entry.1.fts_rank = result.fts_rank;
216 }
217
218 for (rank, result) in vec_results.into_iter().enumerate() {
219 let rrf_score = vec_weight / (k + (rank + 1) as f64);
220 let entry = scores
221 .entry(result.issue_id.clone())
222 .or_insert((0.0, result.clone()));
223 entry.0 += rrf_score;
224 entry.1.vector_rank = result.vector_rank;
225 entry.1.similarity = result.similarity;
226 }
227
228 let mut results: Vec<_> = scores
229 .into_values()
230 .map(|(score, mut result)| {
231 result.score = score;
232 result
233 })
234 .collect();
235
236 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
237 results
238}
239
240pub async fn find_duplicates(
242 db: &Database,
243 text: &str,
244 team_key: Option<&str>,
245 threshold: f32,
246 limit: usize,
247 embedder: &Embedder,
248 rrf_k: u32,
249 workspace_id: &str,
250) -> Result<Vec<SearchResult>> {
251 let mut results = search(
252 db,
253 SearchParams {
254 query: text,
255 mode: SearchMode::Hybrid,
256 team_key,
257 state_filter: None,
258 limit,
259 embedder: Some(embedder),
260 rrf_k,
261 workspace_id,
262 },
263 )
264 .await?;
265
266 let _vec_results = vector_search(db, text, team_key, limit, embedder, workspace_id).await?;
268
269 results.retain(|r| r.similarity.unwrap_or(0.0) >= threshold || r.score > 0.01);
271
272 Ok(results)
273}
274
/// Turns free-form user input into a full-text match expression.
///
/// The input is split on whitespace. Within each word only alphanumeric
/// characters, `_` and `-` are kept; words that end up empty are dropped. The
/// remaining words are each wrapped in double quotes and joined with ` OR `.
/// If nothing remains, the result is an empty quoted string (`""`).
fn build_fts_query(input: &str) -> String {
    let mut terms: Vec<String> = Vec::new();

    for raw in input.split_whitespace() {
        let mut kept = String::with_capacity(raw.len());
        for ch in raw.chars() {
            if ch.is_alphanumeric() || matches!(ch, '_' | '-') {
                kept.push(ch);
            }
        }

        if !kept.is_empty() {
            terms.push(format!("\"{}\"", kept));
        }
    }

    if terms.is_empty() {
        return "\"\"".to_string();
    }

    terms.join(" OR ")
}