// omni_index/query.rs
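//! Query-side search support: load and save the cached document state and
//! BM25 index, parse `path:`/`ext:` filters out of a query string, and
//! execute ranked, filtered searches over the loaded index.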

use crate::cache::{bm25_path, state_path};
use crate::search::{Bm25Index, Bm25Params, FieldWeights};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fs;
use std::path::{Path, PathBuf};

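/// One indexed code span: the symbol it belongs to, its location in the
/// source file, a display preview, and the exact text that was indexed.
/// Line and column numbers are stored zero-based; `execute_query` shifts
/// them to one-based when building results.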
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchDoc {
    pub symbol: String,
    pub file: String,
    pub start_byte: usize,
    pub end_byte: usize,
    pub start_line: usize,
    pub end_line: usize,
    pub start_col: usize,
    pub end_col: usize,
    pub preview: String,
    pub indexed_text: String,
}

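/// The serializable document list persisted to the cache directory,
/// alongside the BM25 index, between runs.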
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SearchState {
    pub docs: Vec<SearchDoc>,
}

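/// A fully loaded in-memory index: the root it was built for, the document
/// metadata, and the BM25 index over those documents.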
#[derive(Debug, Clone)]
pub struct SearchIndex {
    pub root: PathBuf,
    pub docs: Vec<SearchDoc>,
    pub bm25: Bm25Index,
}

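/// Path and extension filters parsed out of a query. Path patterns are
/// matched as substrings of `SearchDoc::file`; extensions are compared
/// case-insensitively.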
#[derive(Debug, Clone, Default)]
pub struct QueryFilters {
    pub include_paths: Vec<String>,
    pub exclude_paths: Vec<String>,
    pub include_exts: Vec<String>,
    pub exclude_exts: Vec<String>,
}

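/// A single scored hit. Unlike `SearchDoc`, line and column numbers here
/// are one-based.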
#[derive(Debug, Clone, Serialize)]
pub struct QueryResult {
    pub doc_id: u32,
    pub symbol: String,
    pub file: String,
    pub start_byte: usize,
    pub end_byte: usize,
    pub start_line: usize,
    pub end_line: usize,
    pub start_col: usize,
    pub end_col: usize,
    pub score: f32,
    pub preview: String,
}

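/// The response for one query: the index root, the query text that was
/// searched, the requested result count, and the ranked hits.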
#[derive(Debug, Clone, Serialize)]
pub struct QueryResponse {
    pub root: String,
    pub query: String,
    pub top_k: usize,
    pub results: Vec<QueryResult>,
}

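/// Loads the persisted document list for `root`, or `Ok(None)` if no state
/// file exists yet.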
pub fn load_search_state(root: &Path) -> Result<Option<SearchState>> {
    let path = state_path(root);
    if !path.exists() {
        return Ok(None);
    }
    let data =
        fs::read(&path).with_context(|| format!("Failed to read state: {}", path.display()))?;
    let state: SearchState = bincode::deserialize(&data)
        .with_context(|| format!("Failed to decode state: {}", path.display()))?;
    Ok(Some(state))
}

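/// Serializes the document list with bincode and writes it into the cache
/// directory for `root`, creating the directory if necessary.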
pub fn save_search_state(root: &Path, state: &SearchState) -> Result<()> {
    crate::cache::ensure_cache_dir(root)?;
    let path = state_path(root);
    let data = bincode::serialize(state)?;
    fs::write(&path, data).with_context(|| format!("Failed to write state: {}", path.display()))?;
    Ok(())
}

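/// Loads the persisted BM25 index for `root`, or `Ok(None)` if it has not
/// been built yet.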
pub fn load_bm25(root: &Path) -> Result<Option<Bm25Index>> {
    let path = bm25_path(root);
    if !path.exists() {
        return Ok(None);
    }
    let data =
        fs::read(&path).with_context(|| format!("Failed to read BM25: {}", path.display()))?;
    let index: Bm25Index = bincode::deserialize(&data)
        .with_context(|| format!("Failed to decode BM25: {}", path.display()))?;
    Ok(Some(index))
}

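/// Serializes the BM25 index with bincode and writes it into the cache
/// directory for `root`, creating the directory if necessary.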
pub fn save_bm25(root: &Path, index: &Bm25Index) -> Result<()> {
    crate::cache::ensure_cache_dir(root)?;
    let path = bm25_path(root);
    let data = bincode::serialize(index)?;
    fs::write(&path, data).with_context(|| format!("Failed to write BM25: {}", path.display()))?;
    Ok(())
}

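/// Loads both cached artifacts and assembles a `SearchIndex`. Returns
/// `Ok(None)` if either the document state or the BM25 index is missing.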
pub fn load_search_index(root: &Path) -> Result<Option<SearchIndex>> {
    let Some(state) = load_search_state(root)? else {
        return Ok(None);
    };
    let Some(bm25) = load_bm25(root)? else {
        return Ok(None);
    };
    Ok(Some(SearchIndex {
        root: root.to_path_buf(),
        docs: state.docs,
        bm25,
    }))
}

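/// Splits a raw query string into plain search terms and structured filters.
///
/// Tokens of the form `path:<substring>` and `ext:<extension>` (optionally
/// prefixed with `-` to negate) are pulled out of the query; everything else
/// is kept as a search term. Extensions are lowercased and a leading `.` is
/// dropped. Tokens in `extra_filters` are handled the same way.
///
/// For example, `parse_query_filters("tokenize path:src/ -ext:md", &[])`
/// returns the term string `"tokenize"` plus filters with `include_paths`
/// of `["src/"]` and `exclude_exts` of `["md"]`.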
pub fn parse_query_filters(query: &str, extra_filters: &[String]) -> (String, QueryFilters) {
    let mut filters = QueryFilters::default();
    let mut terms = Vec::new();

    let mut handle_token = |token: &str| {
        let token = token.trim();
        if token.is_empty() {
            return;
        }
        let (negated, rest) = token
            .strip_prefix('-')
            .map(|t| (true, t))
            .unwrap_or((false, token));
        if let Some(path) = rest.strip_prefix("path:") {
            if negated {
                filters.exclude_paths.push(path.to_string());
            } else {
                filters.include_paths.push(path.to_string());
            }
            return;
        }
        if let Some(ext) = rest.strip_prefix("ext:") {
            let ext = ext.trim_start_matches('.');
            if negated {
                filters.exclude_exts.push(ext.to_lowercase());
            } else {
                filters.include_exts.push(ext.to_lowercase());
            }
            return;
        }
        terms.push(token.to_string());
    };

    for token in query.split_whitespace() {
        handle_token(token);
    }

    for filter in extra_filters {
        handle_token(filter);
    }

    (terms.join(" "), filters)
}

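/// Runs `query` against the BM25 index and returns up to `top_k` hits that
/// pass `filters`.
///
/// The BM25 search over-fetches (five times `top_k`, capped at 1000) so that
/// filtering still leaves enough candidates. Surviving hits are sorted by
/// descending score, with ties broken by file path and then start byte,
/// truncated to `top_k`, and converted to one-based line/column positions.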
pub fn execute_query(
    index: &SearchIndex,
    query: &str,
    top_k: usize,
    filters: &QueryFilters,
) -> QueryResponse {
    let search_k = top_k.saturating_mul(5).max(top_k).min(1000);
    let results = index.bm25.search(
        query,
        &FieldWeights::default(),
        Bm25Params::default(),
        search_k,
    );

    let mut filtered = Vec::new();

    for result in results {
        let doc_id = result.doc_id as usize;
        if doc_id >= index.docs.len() {
            continue;
        }
        let doc = &index.docs[doc_id];
        if !matches_filters(doc, filters) {
            continue;
        }
        filtered.push(QueryResult {
            doc_id: result.doc_id,
            symbol: doc.symbol.clone(),
            file: doc.file.clone(),
            start_byte: doc.start_byte,
            end_byte: doc.end_byte,
            start_line: doc.start_line + 1,
            end_line: doc.end_line + 1,
            start_col: doc.start_col + 1,
            end_col: doc.end_col + 1,
            score: result.score,
            preview: doc.preview.clone(),
        });
    }

    filtered.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.file.cmp(&b.file))
            .then_with(|| a.start_byte.cmp(&b.start_byte))
    });

    if filtered.len() > top_k {
        filtered.truncate(top_k);
    }

    QueryResponse {
        root: index.root.display().to_string(),
        query: query.to_string(),
        top_k,
        results: filtered,
    }
}

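/// Returns true if `doc` passes the filters: its file path must contain at
/// least one `include_paths` pattern (when any are given) and none of the
/// `exclude_paths` patterns, and its lowercased extension must satisfy the
/// extension include/exclude lists.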
fn matches_filters(doc: &SearchDoc, filters: &QueryFilters) -> bool {
    if !filters.include_paths.is_empty() {
        let mut matched = false;
        for pat in &filters.include_paths {
            if doc.file.contains(pat) {
                matched = true;
                break;
            }
        }
        if !matched {
            return false;
        }
    }

    for pat in &filters.exclude_paths {
        if doc.file.contains(pat) {
            return false;
        }
    }

    let ext = Path::new(&doc.file)
        .extension()
        .and_then(|e| e.to_str())
        .unwrap_or("")
        .to_lowercase();

    if !filters.include_exts.is_empty() && !filters.include_exts.contains(&ext) {
        return false;
    }

    if !filters.exclude_exts.is_empty() && filters.exclude_exts.contains(&ext) {
        return false;
    }

    true
}

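/// Rebuilds the BM25 index from scratch over `docs`, using each document's
/// position in the slice as its `doc_id`. Path tokens, tokenized symbol
/// names, and the indexed text are added as separate fields; the remaining
/// two fields are left empty.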
pub fn rebuild_bm25(docs: &[SearchDoc]) -> Bm25Index {
    let mut index = Bm25Index::new();

    for (doc_id, doc) in docs.iter().enumerate() {
        let doc_id = doc_id as u32;
        let path_tokens = crate::search::path_tokens(Path::new(&doc.file));
        let ident_tokens = crate::search::tokenize(&doc.symbol);
        let code_text = doc.indexed_text.as_str();

        index.add_document(
            doc_id,
            path_tokens,
            ident_tokens,
            std::iter::empty::<&str>(),
            std::iter::empty::<&str>(),
            code_text,
        );
    }

    index.finalize();
    index
}

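/// Returns a copy of `docs` with every document whose `file` appears in
/// `files` removed.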
pub fn prune_docs_for_files(docs: &[SearchDoc], files: &HashSet<String>) -> Vec<SearchDoc> {
    docs.iter()
        .filter(|doc| !files.contains(&doc.file))
        .cloned()
        .collect()
}