1use crate::cache::{bm25_path, state_path};
2use crate::search::{Bm25Index, Bm25Params, FieldWeights};
3use anyhow::{Context, Result};
4use serde::{Deserialize, Serialize};
5use std::collections::HashSet;
6use std::fs;
7use std::path::{Path, PathBuf};
8
/// One indexed code span (a symbol occurrence) persisted in the search state
/// and surfaced as a query hit.
///
/// NOTE(review): serialized with `bincode`, which encodes fields positionally —
/// reordering/removing fields changes the on-disk format and invalidates caches.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchDoc {
    // Symbol name; tokenized into the BM25 identifier field (see `rebuild_bm25`).
    pub symbol: String,
    // File path as a string; used for path filtering and BM25 path tokens.
    pub file: String,
    // Byte span of the symbol within the file.
    pub start_byte: usize,
    pub end_byte: usize,
    // Line/column span — presumably 0-based here, since `execute_query`
    // adds 1 when building results. TODO confirm at the indexing site.
    pub start_line: usize,
    pub end_line: usize,
    pub start_col: usize,
    pub end_col: usize,
    // Short excerpt returned verbatim in query results.
    pub preview: String,
    // Full text fed into the BM25 "code" field when rebuilding the index.
    pub indexed_text: String,
}
22
/// Persisted set of all indexed documents for a workspace root.
///
/// NOTE(review): bincode-encoded; field order is part of the on-disk format.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SearchState {
    // A doc's position in this vec is its BM25 doc id (see `rebuild_bm25`).
    pub docs: Vec<SearchDoc>,
}
27
/// In-memory search handle: the documents plus the BM25 index built over them,
/// assembled by `load_search_index`.
#[derive(Debug, Clone)]
pub struct SearchIndex {
    // Workspace root this index was loaded for.
    pub root: PathBuf,
    // Documents addressed by BM25 doc id (vector position).
    pub docs: Vec<SearchDoc>,
    // Ranked-retrieval index over `docs`.
    pub bm25: Bm25Index,
}
34
/// Path/extension restrictions parsed from `path:` / `ext:` query tokens;
/// a leading `-` on a token negates it (see `parse_query_filters`).
#[derive(Debug, Clone, Default)]
pub struct QueryFilters {
    // Substring patterns; when non-empty, a doc's file must match at least one.
    pub include_paths: Vec<String>,
    // Substring patterns; any match rejects the doc.
    pub exclude_paths: Vec<String>,
    // Lowercased extensions (no leading dot); when non-empty, the doc's
    // extension must be listed.
    pub include_exts: Vec<String>,
    // Lowercased extensions (no leading dot); any match rejects the doc.
    pub exclude_exts: Vec<String>,
}
42
/// A single scored hit returned by `execute_query`.
///
/// NOTE(review): `Serialize`-only — field order is visible in serialized output.
#[derive(Debug, Clone, Serialize)]
pub struct QueryResult {
    // BM25 document id; also the doc's index into `SearchIndex::docs`.
    pub doc_id: u32,
    pub symbol: String,
    pub file: String,
    pub start_byte: usize,
    pub end_byte: usize,
    // 1-based positions: `execute_query` adds 1 to the stored doc values.
    pub start_line: usize,
    pub end_line: usize,
    pub start_col: usize,
    pub end_col: usize,
    // BM25 relevance score (higher is better; results are sorted descending).
    pub score: f32,
    pub preview: String,
}
57
/// Full response for one query: echo of the request plus ranked results.
#[derive(Debug, Clone, Serialize)]
pub struct QueryResponse {
    // Workspace root, rendered via `Path::display`.
    pub root: String,
    // The query string as executed (filters already stripped by the caller).
    pub query: String,
    // Requested result cap; `results.len() <= top_k`.
    pub top_k: usize,
    pub results: Vec<QueryResult>,
}
65
66pub fn load_search_state(root: &Path) -> Result<Option<SearchState>> {
67 let path = state_path(root);
68 if !path.exists() {
69 return Ok(None);
70 }
71 let data =
72 fs::read(&path).with_context(|| format!("Failed to read state: {}", path.display()))?;
73 let state: SearchState = bincode::deserialize(&data)
74 .with_context(|| format!("Failed to decode state: {}", path.display()))?;
75 Ok(Some(state))
76}
77
78pub fn save_search_state(root: &Path, state: &SearchState) -> Result<()> {
79 crate::cache::ensure_cache_dir(root)?;
80 let path = state_path(root);
81 let data = bincode::serialize(state)?;
82 fs::write(&path, data).with_context(|| format!("Failed to write state: {}", path.display()))?;
83 Ok(())
84}
85
86pub fn load_bm25(root: &Path) -> Result<Option<Bm25Index>> {
87 let path = bm25_path(root);
88 if !path.exists() {
89 return Ok(None);
90 }
91 let data =
92 fs::read(&path).with_context(|| format!("Failed to read BM25: {}", path.display()))?;
93 let index: Bm25Index = bincode::deserialize(&data)
94 .with_context(|| format!("Failed to decode BM25: {}", path.display()))?;
95 Ok(Some(index))
96}
97
98pub fn save_bm25(root: &Path, index: &Bm25Index) -> Result<()> {
99 crate::cache::ensure_cache_dir(root)?;
100 let path = bm25_path(root);
101 let data = bincode::serialize(index)?;
102 fs::write(&path, data).with_context(|| format!("Failed to write BM25: {}", path.display()))?;
103 Ok(())
104}
105
106pub fn load_search_index(root: &Path) -> Result<Option<SearchIndex>> {
107 let Some(state) = load_search_state(root)? else {
108 return Ok(None);
109 };
110 let Some(bm25) = load_bm25(root)? else {
111 return Ok(None);
112 };
113 Ok(Some(SearchIndex {
114 root: root.to_path_buf(),
115 docs: state.docs,
116 bm25,
117 }))
118}
119
120pub fn parse_query_filters(query: &str, extra_filters: &[String]) -> (String, QueryFilters) {
121 let mut filters = QueryFilters::default();
122 let mut terms = Vec::new();
123
124 let mut handle_token = |token: &str| {
125 let token = token.trim();
126 if token.is_empty() {
127 return;
128 }
129 let (negated, rest) = token
130 .strip_prefix('-')
131 .map(|t| (true, t))
132 .unwrap_or((false, token));
133 if let Some(path) = rest.strip_prefix("path:") {
134 if negated {
135 filters.exclude_paths.push(path.to_string());
136 } else {
137 filters.include_paths.push(path.to_string());
138 }
139 return;
140 }
141 if let Some(ext) = rest.strip_prefix("ext:") {
142 let ext = ext.trim_start_matches('.');
143 if negated {
144 filters.exclude_exts.push(ext.to_lowercase());
145 } else {
146 filters.include_exts.push(ext.to_lowercase());
147 }
148 return;
149 }
150 terms.push(token.to_string());
151 };
152
153 for token in query.split_whitespace() {
154 handle_token(token);
155 }
156
157 for filter in extra_filters {
158 handle_token(filter);
159 }
160
161 (terms.join(" "), filters)
162}
163
164pub fn execute_query(
165 index: &SearchIndex,
166 query: &str,
167 top_k: usize,
168 filters: &QueryFilters,
169) -> QueryResponse {
170 let search_k = top_k.saturating_mul(5).max(top_k).min(1000);
171 let results = index.bm25.search(
172 query,
173 &FieldWeights::default(),
174 Bm25Params::default(),
175 search_k,
176 );
177
178 let mut filtered = Vec::new();
179
180 for result in results {
181 let doc_id = result.doc_id as usize;
182 if doc_id >= index.docs.len() {
183 continue;
184 }
185 let doc = &index.docs[doc_id];
186 if !matches_filters(doc, filters) {
187 continue;
188 }
189 filtered.push(QueryResult {
190 doc_id: result.doc_id,
191 symbol: doc.symbol.clone(),
192 file: doc.file.clone(),
193 start_byte: doc.start_byte,
194 end_byte: doc.end_byte,
195 start_line: doc.start_line + 1,
196 end_line: doc.end_line + 1,
197 start_col: doc.start_col + 1,
198 end_col: doc.end_col + 1,
199 score: result.score,
200 preview: doc.preview.clone(),
201 });
202 }
203
204 filtered.sort_by(|a, b| {
205 b.score
206 .partial_cmp(&a.score)
207 .unwrap_or(std::cmp::Ordering::Equal)
208 .then_with(|| a.file.cmp(&b.file))
209 .then_with(|| a.start_byte.cmp(&b.start_byte))
210 });
211
212 if filtered.len() > top_k {
213 filtered.truncate(top_k);
214 }
215
216 QueryResponse {
217 root: index.root.display().to_string(),
218 query: query.to_string(),
219 top_k,
220 results: filtered,
221 }
222}
223
224fn matches_filters(doc: &SearchDoc, filters: &QueryFilters) -> bool {
225 if !filters.include_paths.is_empty() {
226 let mut matched = false;
227 for pat in &filters.include_paths {
228 if doc.file.contains(pat) {
229 matched = true;
230 break;
231 }
232 }
233 if !matched {
234 return false;
235 }
236 }
237
238 for pat in &filters.exclude_paths {
239 if doc.file.contains(pat) {
240 return false;
241 }
242 }
243
244 let ext = Path::new(&doc.file)
245 .extension()
246 .and_then(|e| e.to_str())
247 .unwrap_or("")
248 .to_lowercase();
249
250 if !filters.include_exts.is_empty() && !filters.include_exts.contains(&ext) {
251 return false;
252 }
253
254 if !filters.exclude_exts.is_empty() && filters.exclude_exts.contains(&ext) {
255 return false;
256 }
257
258 true
259}
260
261pub fn rebuild_bm25(docs: &[SearchDoc]) -> Bm25Index {
262 let mut index = Bm25Index::new();
263
264 for (doc_id, doc) in docs.iter().enumerate() {
265 let doc_id = doc_id as u32;
266 let path_tokens = crate::search::path_tokens(Path::new(&doc.file));
267 let ident_tokens = crate::search::tokenize(&doc.symbol);
268 let code_text = doc.indexed_text.as_str();
269
270 index.add_document(
271 doc_id,
272 path_tokens,
273 ident_tokens,
274 std::iter::empty::<&str>(),
275 std::iter::empty::<&str>(),
276 code_text,
277 );
278 }
279
280 index.finalize();
281 index
282}
283
284pub fn prune_docs_for_files(docs: &[SearchDoc], files: &HashSet<String>) -> Vec<SearchDoc> {
285 docs.iter()
286 .filter(|doc| !files.contains(&doc.file))
287 .cloned()
288 .collect()
289}