1use std::path::{Path, PathBuf};
2use std::sync::Arc;
3
4use sqlitegraph::backend::NodeSpec;
5use sqlitegraph::config::{open_graph, GraphConfig};
6
7use crate::error::{ForgeError, Result};
8use crate::types::{Language, Location, Reference, ReferenceKind, Symbol, SymbolId, SymbolKind};
9
10use super::store::{StoredReference, UnifiedGraphStore};
11use super::BackendKind;
12
13impl UnifiedGraphStore {
14 pub async fn insert_symbol(&self, symbol: &Symbol) -> Result<SymbolId> {
15 let config = match self.backend_kind {
16 BackendKind::SQLite => GraphConfig::sqlite(),
17 BackendKind::NativeV3 => GraphConfig::native(),
18 };
19 let backend = open_graph(&self.db_path, &config)
20 .map_err(|e| ForgeError::DatabaseError(format!("Failed to open graph: {}", e)))?;
21
22 let kind = match symbol.kind {
23 SymbolKind::Function | SymbolKind::Method => "fn",
24 SymbolKind::Struct => "struct",
25 SymbolKind::Enum => "enum",
26 SymbolKind::Trait => "trait",
27 SymbolKind::Impl => "impl",
28 SymbolKind::Module => "module",
29 SymbolKind::TypeAlias => "type",
30 SymbolKind::Constant | SymbolKind::Static => "const",
31 SymbolKind::Parameter | SymbolKind::LocalVariable | SymbolKind::Field => "variable",
32 SymbolKind::Macro => "macro",
33 SymbolKind::Use => "use",
34 };
35
36 let node = NodeSpec {
37 kind: kind.to_string(),
38 name: symbol.name.to_string(),
39 file_path: Some(symbol.location.file_path.to_string_lossy().into_owned()),
40 data: symbol.metadata.clone(),
41 };
42
43 let id = backend
44 .insert_node(node)
45 .map_err(|e| ForgeError::DatabaseError(format!("Insert node failed: {}", e)))?;
46
47 Ok(SymbolId(id))
48 }
49
50 pub async fn insert_reference(&self, reference: &Reference) -> Result<()> {
51 if self.backend_kind == BackendKind::NativeV3 {
52 let mut refs = self
53 .references
54 .lock()
55 .expect("invariant: references mutex not poisoned");
56
57 let to_symbol = format!("sym_{}", reference.to.0);
58
59 refs.push(StoredReference {
60 to_symbol,
61 kind: reference.kind,
62 file_path: reference.location.file_path.clone(),
63 line_number: reference.location.line_number,
64 });
65 }
66 Ok(())
67 }
68
69 pub async fn query_symbols(&self, name: &str) -> Result<Vec<Symbol>> {
70 let conn = rusqlite::Connection::open(&self.db_path)
71 .map_err(|e| ForgeError::DatabaseError(format!("Open db failed: {}", e)))?;
72
73 let pattern = format!("%{}%", name);
74 let mut stmt = conn
75 .prepare(
76 "SELECT id, kind, name, file_path FROM graph_entities WHERE name LIKE ?1 LIMIT 50",
77 )
78 .map_err(|e| ForgeError::DatabaseError(format!("Prepare failed: {}", e)))?;
79
80 let symbols = stmt
81 .query_map(rusqlite::params![pattern], |row| {
82 let id: i64 = row.get(0)?;
83 let sym_name: String = row.get(2)?;
84 let file_path: Option<String> = row.get(3)?;
85 Ok((id, sym_name, file_path))
86 })
87 .map_err(|e| ForgeError::DatabaseError(format!("Query failed: {}", e)))?
88 .flatten()
89 .map(|(id, sym_name, file_path)| Symbol {
90 id: SymbolId(id),
91 name: Arc::from(sym_name.as_str()),
92 fully_qualified_name: Arc::from(sym_name.as_str()),
93 kind: SymbolKind::Function,
94 language: Language::Rust,
95 location: Location {
96 file_path: file_path
97 .map(PathBuf::from)
98 .unwrap_or_else(|| PathBuf::from("")),
99 byte_start: 0,
100 byte_end: 0,
101 line_number: 0,
102 },
103 parent_id: None,
104 metadata: serde_json::Value::Null,
105 })
106 .collect();
107
108 Ok(symbols)
109 }
110
111 pub async fn get_symbol(&self, _id: SymbolId) -> Result<Symbol> {
112 Err(ForgeError::SymbolNotFound("Not implemented".to_string()))
113 }
114
115 pub async fn symbol_exists(&self, id: SymbolId) -> Result<bool> {
116 let conn = rusqlite::Connection::open(&self.db_path)
117 .map_err(|e| ForgeError::DatabaseError(format!("Open db failed: {}", e)))?;
118 let exists: i64 = conn
119 .query_row(
120 "SELECT EXISTS(SELECT 1 FROM graph_entities WHERE id = ?1)",
121 rusqlite::params![id.0],
122 |row| row.get(0),
123 )
124 .map_err(|e| ForgeError::DatabaseError(format!("Query failed: {}", e)))?;
125 Ok(exists > 0)
126 }
127
128 pub async fn query_references(&self, symbol_id: SymbolId) -> Result<Vec<Reference>> {
129 if self.backend_kind == BackendKind::NativeV3 {
130 let refs = self
131 .references
132 .lock()
133 .expect("invariant: references mutex not poisoned");
134 let target_symbol = format!("sym_{}", symbol_id.0);
135
136 let mut result = Vec::new();
137 for stored in refs.iter() {
138 if stored.to_symbol == target_symbol {
139 result.push(Reference {
140 from: SymbolId(0),
141 to: symbol_id,
142 from_name: None,
143 to_name: None,
144 kind: stored.kind,
145 location: Location {
146 file_path: stored.file_path.clone(),
147 byte_start: 0,
148 byte_end: 0,
149 line_number: stored.line_number,
150 },
151 });
152 }
153 }
154 return Ok(result);
155 }
156
157 Ok(Vec::new())
158 }
159
160 pub async fn get_all_symbols(&self) -> Result<Vec<Symbol>> {
161 let conn = rusqlite::Connection::open(&self.db_path)
162 .map_err(|e| ForgeError::DatabaseError(format!("Open db failed: {}", e)))?;
163 let mut stmt = conn
164 .prepare("SELECT id, kind, name, file_path FROM graph_entities LIMIT 1000")
165 .map_err(|e| ForgeError::DatabaseError(format!("Prepare failed: {}", e)))?;
166 let symbols = stmt
167 .query_map([], |row| {
168 let id: i64 = row.get(0)?;
169 let sym_name: String = row.get(2)?;
170 let file_path: Option<String> = row.get(3)?;
171 Ok((id, sym_name, file_path))
172 })
173 .map_err(|e| ForgeError::DatabaseError(format!("Query failed: {}", e)))?
174 .flatten()
175 .map(|(id, sym_name, file_path)| Symbol {
176 id: SymbolId(id),
177 name: Arc::from(sym_name.as_str()),
178 fully_qualified_name: Arc::from(sym_name.as_str()),
179 kind: SymbolKind::Function,
180 language: Language::Rust,
181 location: Location {
182 file_path: file_path
183 .map(PathBuf::from)
184 .unwrap_or_else(|| PathBuf::from("")),
185 byte_start: 0,
186 byte_end: 0,
187 line_number: 0,
188 },
189 parent_id: None,
190 metadata: serde_json::Value::Null,
191 })
192 .collect();
193 Ok(symbols)
194 }
195
196 pub async fn symbol_count(&self) -> Result<usize> {
197 let conn = rusqlite::Connection::open(&self.db_path)
198 .map_err(|e| ForgeError::DatabaseError(format!("Open db failed: {}", e)))?;
199 let count: i64 = conn
200 .query_row("SELECT COUNT(*) FROM graph_entities", [], |row| row.get(0))
201 .map_err(|e| ForgeError::DatabaseError(format!("Query failed: {}", e)))?;
202 Ok(count as usize)
203 }
204
205 pub async fn index_cross_file_references(&self) -> Result<usize> {
206 if self.backend_kind != BackendKind::NativeV3 {
207 return Ok(0);
208 }
209
210 self.legacy_index_cross_file_references().await
211 }
212
213 async fn legacy_index_cross_file_references(&self) -> Result<usize> {
214 use regex::Regex;
215 use tokio::fs;
216
217 let mut symbols: std::collections::HashMap<String, (PathBuf, usize)> =
218 std::collections::HashMap::new();
219 self.collect_symbols_recursive(&self.codebase_path, &mut symbols)
220 .await?;
221
222 let reference_pattern = Regex::new(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\s*\(")
223 .expect("invariant: static regex pattern is valid");
224
225 {
226 let mut refs = self
227 .references
228 .lock()
229 .expect("invariant: references mutex not poisoned");
230 refs.clear();
231 }
232
233 let mut found_refs: Vec<StoredReference> = Vec::new();
234
235 for (symbol_name, (_file_path, _)) in &symbols {
236 for (target_file, _) in symbols.values() {
237 if let Ok(content) = fs::read_to_string(target_file).await {
238 for (line_num, line) in content.lines().enumerate() {
239 if line.contains("fn ") || line.contains("struct ") {
240 continue;
241 }
242
243 for cap in reference_pattern.captures_iter(line) {
244 if let Some(matched) = cap.get(1) {
245 if matched.as_str() == symbol_name {
246 found_refs.push(StoredReference {
247 to_symbol: format!("sym_{}", symbol_name),
248 kind: ReferenceKind::Call,
249 file_path: target_file.clone(),
250 line_number: line_num + 1,
251 });
252 }
253 }
254 }
255 }
256 }
257 }
258 }
259
260 let ref_count = found_refs.len();
261 self.references
262 .lock()
263 .expect("invariant: references mutex not poisoned")
264 .extend(found_refs);
265
266 Ok(ref_count)
267 }
268
269 async fn collect_symbols_recursive(
270 &self,
271 dir: &Path,
272 symbols: &mut std::collections::HashMap<String, (PathBuf, usize)>,
273 ) -> Result<()> {
274 use tokio::fs;
275
276 let mut entries = fs::read_dir(dir)
277 .await
278 .map_err(|e| ForgeError::DatabaseError(format!("Failed to read dir: {}", e)))?;
279
280 while let Some(entry) = entries
281 .next_entry()
282 .await
283 .map_err(|e| ForgeError::DatabaseError(format!("Failed to read entry: {}", e)))?
284 {
285 let path = entry.path();
286 if path.is_dir() {
287 Box::pin(self.collect_symbols_recursive(&path, symbols)).await?;
288 } else if path.extension().map(|e| e == "rs").unwrap_or(false) {
289 if let Ok(content) = fs::read_to_string(&path).await {
290 for (line_num, line) in content.lines().enumerate() {
291 if let Some(fn_pos) = line.find("fn ") {
292 let after_fn = &line[fn_pos + 3..];
293 if let Some(end_pos) =
294 after_fn.find(|c: char| c.is_whitespace() || c == '(')
295 {
296 let name = after_fn[..end_pos].trim().to_string();
297 if !name.is_empty() {
298 symbols.insert(name, (path.clone(), line_num + 1));
299 }
300 }
301 }
302 if let Some(struct_pos) = line.find("struct ") {
303 let after_struct = &line[struct_pos + 7..];
304 if let Some(end_pos) = after_struct
305 .find(|c: char| c.is_whitespace() || c == '{' || c == ';')
306 {
307 let name = after_struct[..end_pos].trim().to_string();
308 if !name.is_empty() {
309 symbols.insert(name, (path.clone(), line_num + 1));
310 }
311 }
312 }
313 }
314 }
315 }
316 }
317
318 Ok(())
319 }
320
321 pub async fn query_references_for_symbol(&self, symbol_name: &str) -> Result<Vec<Reference>> {
322 if self.backend_kind != BackendKind::NativeV3 {
323 return Ok(Vec::new());
324 }
325
326 let refs = self
327 .references
328 .lock()
329 .expect("invariant: references mutex not poisoned");
330 let mut result = Vec::new();
331
332 for stored in refs.iter() {
333 if stored.to_symbol == format!("sym_{}", symbol_name)
334 || stored.to_symbol.contains(symbol_name)
335 {
336 result.push(Reference {
337 from: SymbolId(0),
338 to: SymbolId(0),
339 from_name: None,
340 to_name: None,
341 kind: stored.kind,
342 location: Location {
343 file_path: stored.file_path.clone(),
344 byte_start: 0,
345 byte_end: 0,
346 line_number: stored.line_number,
347 },
348 });
349 }
350 }
351
352 Ok(result)
353 }
354}