1use super::*;
6
7#[derive(Debug, Clone, Copy)]
10pub enum KindScope {
11 All,
13 CodeOnly,
15 Exact(SymbolKind),
17}
18
19impl Database {
20 pub fn upsert_symbol_content(
27 &self,
28 symbol_id: &str,
29 symbol_name: &str,
30 content: &str,
31 header: &str,
32 ) -> Result<()> {
33 let normalized = normalize_symbol_name(symbol_name);
34 self.conn.execute(
39 "DELETE FROM symbol_content WHERE symbol_id = ?1",
40 params![symbol_id],
41 )?;
42 self.conn.execute(
43 "INSERT INTO symbol_content (symbol_id, content, header, normalized_name)
44 VALUES (?1, ?2, ?3, ?4)",
45 params![symbol_id, content, header, normalized],
46 )?;
47 Ok(())
48 }
49
50 pub fn insert_symbol_contents(&self, items: &[(String, String, String, String)]) -> Result<()> {
54 let tx = self.conn.unchecked_transaction()?;
55 self.insert_symbol_contents_in_tx(items)?;
56 tx.commit()?;
57 Ok(())
58 }
59
60 pub fn insert_symbol_contents_in_tx(
63 &self,
64 items: &[(String, String, String, String)],
65 ) -> Result<()> {
66 let mut del = self
69 .conn
70 .prepare_cached("DELETE FROM symbol_content WHERE symbol_id = ?1")?;
71 let mut ins = self.conn.prepare_cached(
72 "INSERT INTO symbol_content (symbol_id, content, header, normalized_name)
73 VALUES (?1, ?2, ?3, ?4)",
74 )?;
75 for (symbol_id, name, content, header) in items {
76 let normalized = normalize_symbol_name(name);
77 del.execute(params![symbol_id])?;
78 ins.execute(params![symbol_id, content, header, normalized])?;
79 }
80 Ok(())
81 }
82
83 pub fn clear_symbol_content_for_file(&self, file_path: &str) -> Result<()> {
85 self.conn.execute(
86 "DELETE FROM symbol_content WHERE symbol_id IN
87 (SELECT id FROM symbols WHERE file_path = ?1)",
88 params![file_path],
89 )?;
90 Ok(())
91 }
92
93 pub fn get_symbol_content(&self, symbol_id: &str) -> Result<Option<(String, String)>> {
95 self.conn
96 .query_row(
97 "SELECT content, header FROM symbol_content WHERE symbol_id = ?1",
98 params![symbol_id],
99 |row| Ok((row.get(0)?, row.get(1)?)),
100 )
101 .optional()
102 .context("Failed to query symbol content")
103 }
104
105 pub fn get_symbol_contents_batch(
109 &self,
110 symbol_ids: &[String],
111 ) -> Result<std::collections::HashMap<String, (String, String)>> {
112 let mut result = std::collections::HashMap::with_capacity(symbol_ids.len());
113 if symbol_ids.is_empty() {
114 return Ok(result);
115 }
116 for chunk in symbol_ids.chunks(Self::FILE_CHUNK_SIZE) {
117 let placeholders: Vec<&str> = chunk.iter().map(|_| "?").collect();
118 let sql = format!(
119 "SELECT symbol_id, content, header FROM symbol_content WHERE symbol_id IN ({})",
120 placeholders.join(",")
121 );
122 let mut stmt = self.conn.prepare(&sql)?;
123 let params: Vec<Box<dyn rusqlite::types::ToSql>> = chunk
124 .iter()
125 .map(|id| Box::new(id.clone()) as Box<dyn rusqlite::types::ToSql>)
126 .collect();
127 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
128 params.iter().map(|p| p.as_ref()).collect();
129 let rows = stmt
130 .query_map(param_refs.as_slice(), |row| {
131 Ok((
132 row.get::<_, String>(0)?,
133 row.get::<_, String>(1)?,
134 row.get::<_, String>(2)?,
135 ))
136 })?
137 .collect::<std::result::Result<Vec<_>, _>>()?;
138 for (id, content, header) in rows {
139 result.insert(id, (content, header));
140 }
141 }
142 Ok(result)
143 }
144
145 pub fn fts5_search(&self, query: &str, limit: u32) -> Result<Vec<String>> {
151 self.fts5_search_kinded(query, limit, KindScope::All)
152 }
153
154 pub fn fts5_search_kinded(
157 &self,
158 query: &str,
159 limit: u32,
160 scope: KindScope,
161 ) -> Result<Vec<String>> {
162 let (where_kind, kind_param): (&str, Option<&str>) = match scope {
164 KindScope::All => ("", None),
165 KindScope::CodeOnly => ("AND s.kind NOT IN ('document', 'import')", None),
166 KindScope::Exact(k) => ("AND s.kind = ?3", Some(k.as_str())),
167 };
168 let sql = if matches!(scope, KindScope::All) {
169 "SELECT sc.symbol_id
170 FROM symbol_fts f
171 JOIN symbol_content sc ON sc.rowid = f.rowid
172 WHERE symbol_fts MATCH ?1
173 ORDER BY rank
174 LIMIT ?2"
175 .to_string()
176 } else {
177 format!(
178 "SELECT sc.symbol_id
179 FROM symbol_fts f
180 JOIN symbol_content sc ON sc.rowid = f.rowid
181 JOIN symbols s ON s.id = sc.symbol_id
182 WHERE symbol_fts MATCH ?1 {where_kind}
183 ORDER BY rank
184 LIMIT ?2"
185 )
186 };
187 let ctx =
188 || format!("fts5_search_kinded (scope={scope:?}, query={query:?}, limit={limit})");
189 let mut stmt = self.conn.prepare(&sql).with_context(ctx)?;
190 let rows: Vec<String> = match kind_param {
191 Some(k) => stmt
192 .query_map(params![query, limit, k], |row| row.get(0))
193 .with_context(ctx)?
194 .collect::<std::result::Result<_, _>>()
195 .with_context(ctx)?,
196 None => stmt
197 .query_map(params![query, limit], |row| row.get(0))
198 .with_context(ctx)?
199 .collect::<std::result::Result<_, _>>()
200 .with_context(ctx)?,
201 };
202 Ok(rows)
203 }
204
205 pub fn get_or_create_embedding_id(&self, symbol_id: &str) -> Result<i64> {
211 let existing: Option<i64> = self
213 .conn
214 .query_row(
215 "SELECT id FROM symbol_embedding_map WHERE symbol_id = ?1",
216 params![symbol_id],
217 |row| row.get(0),
218 )
219 .optional()?;
220
221 if let Some(id) = existing {
222 return Ok(id);
223 }
224
225 self.conn.execute(
227 "INSERT INTO symbol_embedding_map (symbol_id) VALUES (?1)",
228 params![symbol_id],
229 )?;
230 Ok(self.conn.last_insert_rowid())
231 }
232
233 pub fn symbol_id_for_embedding(&self, embedding_id: i64) -> Result<Option<String>> {
235 self.conn
236 .query_row(
237 "SELECT symbol_id FROM symbol_embedding_map WHERE id = ?1",
238 params![embedding_id],
239 |row| row.get(0),
240 )
241 .optional()
242 .context("Failed to query embedding map")
243 }
244
245 pub fn symbol_ids_for_embeddings(&self, embedding_ids: &[i64]) -> Result<Vec<(i64, String)>> {
247 if embedding_ids.is_empty() {
248 return Ok(Vec::new());
249 }
250 let mut all_results = Vec::with_capacity(embedding_ids.len());
251 for chunk in embedding_ids.chunks(Self::FILE_CHUNK_SIZE) {
252 let placeholders: Vec<String> = chunk.iter().map(|_| "?".to_string()).collect();
253 let sql = format!(
254 "SELECT id, symbol_id FROM symbol_embedding_map WHERE id IN ({})",
255 placeholders.join(",")
256 );
257 let mut stmt = self.conn.prepare(&sql)?;
258 let params: Vec<Box<dyn rusqlite::types::ToSql>> = chunk
259 .iter()
260 .map(|id| Box::new(*id) as Box<dyn rusqlite::types::ToSql>)
261 .collect();
262 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
263 params.iter().map(|p| p.as_ref()).collect();
264 let rows = stmt
265 .query_map(param_refs.as_slice(), |row| Ok((row.get(0)?, row.get(1)?)))?
266 .collect::<std::result::Result<Vec<_>, _>>()?;
267 all_results.extend(rows);
268 }
269 Ok(all_results)
270 }
271
272 pub fn upsert_embedding(&self, embedding_id: i64, embedding: &[u8]) -> Result<()> {
279 self.conn.execute(
281 "DELETE FROM symbol_vec WHERE rowid = ?1",
282 params![embedding_id],
283 )?;
284 self.conn.execute(
285 "INSERT INTO symbol_vec (rowid, embedding) VALUES (?1, ?2)",
286 params![embedding_id, embedding],
287 )?;
288 Ok(())
289 }
290
291 pub fn insert_embeddings(&self, items: &[(i64, Vec<u8>)]) -> Result<()> {
293 let tx = self.conn.unchecked_transaction()?;
294 for (id, embedding) in items {
295 self.conn
296 .execute("DELETE FROM symbol_vec WHERE rowid = ?1", params![id])?;
297 self.conn.execute(
298 "INSERT INTO symbol_vec (rowid, embedding) VALUES (?1, ?2)",
299 params![id, embedding],
300 )?;
301 }
302 tx.commit()?;
303 Ok(())
304 }
305
306 pub fn vector_search(&self, query_embedding: &[u8], limit: u32) -> Result<Vec<(i64, f64)>> {
310 let mut stmt = self.conn.prepare(
311 "SELECT rowid, distance
312 FROM symbol_vec
313 WHERE embedding MATCH ?1
314 ORDER BY distance
315 LIMIT ?2",
316 )?;
317 let rows = stmt
318 .query_map(params![query_embedding, limit], |row| {
319 Ok((row.get(0)?, row.get(1)?))
320 })?
321 .collect::<std::result::Result<Vec<_>, _>>()?;
322 Ok(rows)
323 }
324
325 pub fn embedding_count(&self) -> Result<u32> {
329 Ok(self.conn.query_row(
330 "SELECT COUNT(*) FROM symbol_embedding_map em
331 JOIN symbol_vec sv ON sv.rowid = em.id",
332 [],
333 |row| row.get(0),
334 )?)
335 }
336
337 pub fn has_embedding(&self, symbol_id: &str) -> Result<bool> {
339 let map_id: Option<i64> = self
340 .conn
341 .query_row(
342 "SELECT id FROM symbol_embedding_map WHERE symbol_id = ?1",
343 params![symbol_id],
344 |row| row.get(0),
345 )
346 .optional()?;
347
348 if let Some(id) = map_id {
349 let exists: bool = self.conn.query_row(
350 "SELECT EXISTS(SELECT 1 FROM symbol_vec WHERE rowid = ?1)",
351 params![id],
352 |row| row.get(0),
353 )?;
354 Ok(exists)
355 } else {
356 Ok(false)
357 }
358 }
359
360 pub fn clear_rag_data_for_file(&self, file_path: &str) -> Result<()> {
362 self.conn.execute(
364 "DELETE FROM symbol_vec WHERE rowid IN
365 (SELECT em.id FROM symbol_embedding_map em
366 JOIN symbols s ON em.symbol_id = s.id
367 WHERE s.file_path = ?1)",
368 params![file_path],
369 )?;
370 self.conn.execute(
372 "DELETE FROM symbol_embedding_map WHERE symbol_id IN
373 (SELECT id FROM symbols WHERE file_path = ?1)",
374 params![file_path],
375 )?;
376 self.clear_symbol_content_for_file(file_path)?;
378 Ok(())
379 }
380
381 pub(crate) fn delete_embedding_rows_for_id_in_tx(&self, id: &str) -> Result<()> {
387 self.conn
388 .prepare_cached(
389 "DELETE FROM symbol_vec WHERE rowid IN \
390 (SELECT id FROM symbol_embedding_map WHERE symbol_id = ?1)",
391 )?
392 .execute(params![id])?;
393 self.conn
394 .prepare_cached("DELETE FROM symbol_embedding_map WHERE symbol_id = ?1")?
395 .execute(params![id])?;
396 Ok(())
397 }
398
399 pub fn clear_embeddings_for_symbols_in_tx(&self, ids: &[String]) -> Result<()> {
405 for id in ids {
406 self.delete_embedding_rows_for_id_in_tx(id)?;
407 }
408 Ok(())
409 }
410
411 pub fn clear_content_for_symbols_in_tx(&self, ids: &[String]) -> Result<()> {
416 if ids.is_empty() {
417 return Ok(());
418 }
419 let mut del = self
420 .conn
421 .prepare_cached("DELETE FROM symbol_content WHERE symbol_id = ?1")?;
422 for id in ids {
423 del.execute(params![id])?;
424 }
425 Ok(())
426 }
427
428 pub fn get_symbol(&self, id: &str) -> Result<Option<Symbol>> {
430 self.conn
431 .query_row(
432 "SELECT id, name, kind, file_path, start_line, end_line, start_byte, end_byte,
433 parent_id, signature, visibility, is_async, docstring, in_degree,
434 content_hash, subtree_hash
435 FROM symbols WHERE id = ?1",
436 params![id],
437 row_to_symbol,
438 )
439 .optional()
440 .context("Failed to query symbol")
441 }
442
443 pub fn get_symbols_by_ids(&self, ids: &[String]) -> Result<Vec<Symbol>> {
445 if ids.is_empty() {
446 return Ok(Vec::new());
447 }
448 let placeholders: Vec<&str> = ids.iter().map(|_| "?").collect();
449 let sql = format!(
450 "SELECT id, name, kind, file_path, start_line, end_line, start_byte, end_byte,
451 parent_id, signature, visibility, is_async, docstring, in_degree,
452 content_hash, subtree_hash
453 FROM symbols WHERE id IN ({})",
454 placeholders.join(",")
455 );
456 let mut stmt = self.conn.prepare(&sql)?;
457 let params: Vec<Box<dyn rusqlite::types::ToSql>> = ids
458 .iter()
459 .map(|id| Box::new(id.clone()) as Box<dyn rusqlite::types::ToSql>)
460 .collect();
461 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
462 params.iter().map(|p| p.as_ref()).collect();
463 let rows: std::collections::HashMap<String, Symbol> = stmt
464 .query_map(param_refs.as_slice(), row_to_symbol)?
465 .filter_map(|r| r.ok())
466 .map(|s| (s.id.clone(), s))
467 .collect();
468 Ok(ids.iter().filter_map(|id| rows.get(id).cloned()).collect())
470 }
471
472 pub fn symbols_needing_embeddings(&self) -> Result<Vec<String>> {
476 let mut stmt = self.conn.prepare(
477 "SELECT sc.symbol_id FROM symbol_content sc
478 JOIN symbols s ON s.id = sc.symbol_id
479 WHERE s.kind NOT IN (?1, ?2)
480 AND NOT EXISTS (
481 SELECT 1 FROM symbol_embedding_map em
482 JOIN symbol_vec sv ON sv.rowid = em.id
483 WHERE em.symbol_id = sc.symbol_id
484 )",
485 )?;
486 let rows = stmt
487 .query_map(
488 params![SymbolKind::Variable.as_str(), SymbolKind::Import.as_str(),],
489 |row| row.get(0),
490 )?
491 .collect::<std::result::Result<Vec<_>, _>>()?;
492 Ok(rows)
493 }
494
495 pub fn symbol_content_count(&self) -> Result<u32> {
497 Ok(self
498 .conn
499 .query_row("SELECT COUNT(*) FROM symbol_content", [], |row| row.get(0))?)
500 }
501
502 pub fn all_content_symbol_ids(&self) -> Result<Vec<String>> {
504 let mut stmt = self.conn.prepare(
505 "SELECT sc.symbol_id FROM symbol_content sc
506 JOIN symbols s ON s.id = sc.symbol_id
507 WHERE s.kind NOT IN (?1, ?2)
508 ORDER BY sc.symbol_id",
509 )?;
510 let rows = stmt
511 .query_map(
512 params![SymbolKind::Variable.as_str(), SymbolKind::Import.as_str(),],
513 |row| row.get(0),
514 )?
515 .collect::<std::result::Result<Vec<_>, _>>()?;
516 Ok(rows)
517 }
518
519 pub fn clear_all_embeddings(&self) -> Result<()> {
521 self.conn.execute("DELETE FROM symbol_vec", [])?;
522 self.conn.execute("DELETE FROM symbol_embedding_map", [])?;
523 Ok(())
524 }
525}