// forgekit_core/storage/ops.rs

1use std::path::{Path, PathBuf};
2use std::sync::Arc;
3
4use sqlitegraph::backend::NodeSpec;
5use sqlitegraph::config::{open_graph, GraphConfig};
6
7use crate::error::{ForgeError, Result};
8use crate::types::{Language, Location, Reference, ReferenceKind, Symbol, SymbolId, SymbolKind};
9
10use super::store::{StoredReference, UnifiedGraphStore};
11use super::BackendKind;
12
13impl UnifiedGraphStore {
14    pub async fn insert_symbol(&self, symbol: &Symbol) -> Result<SymbolId> {
15        let config = match self.backend_kind {
16            BackendKind::SQLite => GraphConfig::sqlite(),
17            BackendKind::NativeV3 => GraphConfig::native(),
18        };
19        let backend = open_graph(&self.db_path, &config)
20            .map_err(|e| ForgeError::DatabaseError(format!("Failed to open graph: {}", e)))?;
21
22        let kind = match symbol.kind {
23            SymbolKind::Function | SymbolKind::Method => "fn",
24            SymbolKind::Struct => "struct",
25            SymbolKind::Enum => "enum",
26            SymbolKind::Trait => "trait",
27            SymbolKind::Impl => "impl",
28            SymbolKind::Module => "module",
29            SymbolKind::TypeAlias => "type",
30            SymbolKind::Constant | SymbolKind::Static => "const",
31            SymbolKind::Parameter | SymbolKind::LocalVariable | SymbolKind::Field => "variable",
32            SymbolKind::Macro => "macro",
33            SymbolKind::Use => "use",
34        };
35
36        let node = NodeSpec {
37            kind: kind.to_string(),
38            name: symbol.name.to_string(),
39            file_path: Some(symbol.location.file_path.to_string_lossy().into_owned()),
40            data: symbol.metadata.clone(),
41        };
42
43        let id = backend
44            .insert_node(node)
45            .map_err(|e| ForgeError::DatabaseError(format!("Insert node failed: {}", e)))?;
46
47        Ok(SymbolId(id))
48    }
49
50    pub async fn insert_reference(&self, reference: &Reference) -> Result<()> {
51        if self.backend_kind == BackendKind::NativeV3 {
52            let mut refs = self
53                .references
54                .lock()
55                .expect("invariant: references mutex not poisoned");
56
57            let to_symbol = format!("sym_{}", reference.to.0);
58
59            refs.push(StoredReference {
60                to_symbol,
61                kind: reference.kind,
62                file_path: reference.location.file_path.clone(),
63                line_number: reference.location.line_number,
64            });
65        }
66        Ok(())
67    }
68
69    pub async fn query_symbols(&self, name: &str) -> Result<Vec<Symbol>> {
70        let conn = rusqlite::Connection::open(&self.db_path)
71            .map_err(|e| ForgeError::DatabaseError(format!("Open db failed: {}", e)))?;
72
73        let pattern = format!("%{}%", name);
74        let mut stmt = conn
75            .prepare(
76                "SELECT id, kind, name, file_path FROM graph_entities WHERE name LIKE ?1 LIMIT 50",
77            )
78            .map_err(|e| ForgeError::DatabaseError(format!("Prepare failed: {}", e)))?;
79
80        let symbols = stmt
81            .query_map(rusqlite::params![pattern], |row| {
82                let id: i64 = row.get(0)?;
83                let sym_name: String = row.get(2)?;
84                let file_path: Option<String> = row.get(3)?;
85                Ok((id, sym_name, file_path))
86            })
87            .map_err(|e| ForgeError::DatabaseError(format!("Query failed: {}", e)))?
88            .flatten()
89            .map(|(id, sym_name, file_path)| Symbol {
90                id: SymbolId(id),
91                name: Arc::from(sym_name.as_str()),
92                fully_qualified_name: Arc::from(sym_name.as_str()),
93                kind: SymbolKind::Function,
94                language: Language::Rust,
95                location: Location {
96                    file_path: file_path
97                        .map(PathBuf::from)
98                        .unwrap_or_else(|| PathBuf::from("")),
99                    byte_start: 0,
100                    byte_end: 0,
101                    line_number: 0,
102                },
103                parent_id: None,
104                metadata: serde_json::Value::Null,
105            })
106            .collect();
107
108        Ok(symbols)
109    }
110
111    pub async fn get_symbol(&self, _id: SymbolId) -> Result<Symbol> {
112        Err(ForgeError::SymbolNotFound("Not implemented".to_string()))
113    }
114
115    pub async fn symbol_exists(&self, id: SymbolId) -> Result<bool> {
116        let conn = rusqlite::Connection::open(&self.db_path)
117            .map_err(|e| ForgeError::DatabaseError(format!("Open db failed: {}", e)))?;
118        let exists: i64 = conn
119            .query_row(
120                "SELECT EXISTS(SELECT 1 FROM graph_entities WHERE id = ?1)",
121                rusqlite::params![id.0],
122                |row| row.get(0),
123            )
124            .map_err(|e| ForgeError::DatabaseError(format!("Query failed: {}", e)))?;
125        Ok(exists > 0)
126    }
127
128    pub async fn query_references(&self, symbol_id: SymbolId) -> Result<Vec<Reference>> {
129        if self.backend_kind == BackendKind::NativeV3 {
130            let refs = self
131                .references
132                .lock()
133                .expect("invariant: references mutex not poisoned");
134            let target_symbol = format!("sym_{}", symbol_id.0);
135
136            let mut result = Vec::new();
137            for stored in refs.iter() {
138                if stored.to_symbol == target_symbol {
139                    result.push(Reference {
140                        from: SymbolId(0),
141                        to: symbol_id,
142                        from_name: None,
143                        to_name: None,
144                        kind: stored.kind,
145                        location: Location {
146                            file_path: stored.file_path.clone(),
147                            byte_start: 0,
148                            byte_end: 0,
149                            line_number: stored.line_number,
150                        },
151                    });
152                }
153            }
154            return Ok(result);
155        }
156
157        Ok(Vec::new())
158    }
159
160    pub async fn get_all_symbols(&self) -> Result<Vec<Symbol>> {
161        let conn = rusqlite::Connection::open(&self.db_path)
162            .map_err(|e| ForgeError::DatabaseError(format!("Open db failed: {}", e)))?;
163        let mut stmt = conn
164            .prepare("SELECT id, kind, name, file_path FROM graph_entities LIMIT 1000")
165            .map_err(|e| ForgeError::DatabaseError(format!("Prepare failed: {}", e)))?;
166        let symbols = stmt
167            .query_map([], |row| {
168                let id: i64 = row.get(0)?;
169                let sym_name: String = row.get(2)?;
170                let file_path: Option<String> = row.get(3)?;
171                Ok((id, sym_name, file_path))
172            })
173            .map_err(|e| ForgeError::DatabaseError(format!("Query failed: {}", e)))?
174            .flatten()
175            .map(|(id, sym_name, file_path)| Symbol {
176                id: SymbolId(id),
177                name: Arc::from(sym_name.as_str()),
178                fully_qualified_name: Arc::from(sym_name.as_str()),
179                kind: SymbolKind::Function,
180                language: Language::Rust,
181                location: Location {
182                    file_path: file_path
183                        .map(PathBuf::from)
184                        .unwrap_or_else(|| PathBuf::from("")),
185                    byte_start: 0,
186                    byte_end: 0,
187                    line_number: 0,
188                },
189                parent_id: None,
190                metadata: serde_json::Value::Null,
191            })
192            .collect();
193        Ok(symbols)
194    }
195
196    pub async fn symbol_count(&self) -> Result<usize> {
197        let conn = rusqlite::Connection::open(&self.db_path)
198            .map_err(|e| ForgeError::DatabaseError(format!("Open db failed: {}", e)))?;
199        let count: i64 = conn
200            .query_row("SELECT COUNT(*) FROM graph_entities", [], |row| row.get(0))
201            .map_err(|e| ForgeError::DatabaseError(format!("Query failed: {}", e)))?;
202        Ok(count as usize)
203    }
204
205    pub async fn index_cross_file_references(&self) -> Result<usize> {
206        if self.backend_kind != BackendKind::NativeV3 {
207            return Ok(0);
208        }
209
210        self.legacy_index_cross_file_references().await
211    }
212
213    async fn legacy_index_cross_file_references(&self) -> Result<usize> {
214        use regex::Regex;
215        use tokio::fs;
216
217        let mut symbols: std::collections::HashMap<String, (PathBuf, usize)> =
218            std::collections::HashMap::new();
219        self.collect_symbols_recursive(&self.codebase_path, &mut symbols)
220            .await?;
221
222        let reference_pattern = Regex::new(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\s*\(")
223            .expect("invariant: static regex pattern is valid");
224
225        {
226            let mut refs = self
227                .references
228                .lock()
229                .expect("invariant: references mutex not poisoned");
230            refs.clear();
231        }
232
233        let mut found_refs: Vec<StoredReference> = Vec::new();
234
235        for (symbol_name, (_file_path, _)) in &symbols {
236            for (target_file, _) in symbols.values() {
237                if let Ok(content) = fs::read_to_string(target_file).await {
238                    for (line_num, line) in content.lines().enumerate() {
239                        if line.contains("fn ") || line.contains("struct ") {
240                            continue;
241                        }
242
243                        for cap in reference_pattern.captures_iter(line) {
244                            if let Some(matched) = cap.get(1) {
245                                if matched.as_str() == symbol_name {
246                                    found_refs.push(StoredReference {
247                                        to_symbol: format!("sym_{}", symbol_name),
248                                        kind: ReferenceKind::Call,
249                                        file_path: target_file.clone(),
250                                        line_number: line_num + 1,
251                                    });
252                                }
253                            }
254                        }
255                    }
256                }
257            }
258        }
259
260        let ref_count = found_refs.len();
261        self.references
262            .lock()
263            .expect("invariant: references mutex not poisoned")
264            .extend(found_refs);
265
266        Ok(ref_count)
267    }
268
269    async fn collect_symbols_recursive(
270        &self,
271        dir: &Path,
272        symbols: &mut std::collections::HashMap<String, (PathBuf, usize)>,
273    ) -> Result<()> {
274        use tokio::fs;
275
276        let mut entries = fs::read_dir(dir)
277            .await
278            .map_err(|e| ForgeError::DatabaseError(format!("Failed to read dir: {}", e)))?;
279
280        while let Some(entry) = entries
281            .next_entry()
282            .await
283            .map_err(|e| ForgeError::DatabaseError(format!("Failed to read entry: {}", e)))?
284        {
285            let path = entry.path();
286            if path.is_dir() {
287                Box::pin(self.collect_symbols_recursive(&path, symbols)).await?;
288            } else if path.extension().map(|e| e == "rs").unwrap_or(false) {
289                if let Ok(content) = fs::read_to_string(&path).await {
290                    for (line_num, line) in content.lines().enumerate() {
291                        if let Some(fn_pos) = line.find("fn ") {
292                            let after_fn = &line[fn_pos + 3..];
293                            if let Some(end_pos) =
294                                after_fn.find(|c: char| c.is_whitespace() || c == '(')
295                            {
296                                let name = after_fn[..end_pos].trim().to_string();
297                                if !name.is_empty() {
298                                    symbols.insert(name, (path.clone(), line_num + 1));
299                                }
300                            }
301                        }
302                        if let Some(struct_pos) = line.find("struct ") {
303                            let after_struct = &line[struct_pos + 7..];
304                            if let Some(end_pos) = after_struct
305                                .find(|c: char| c.is_whitespace() || c == '{' || c == ';')
306                            {
307                                let name = after_struct[..end_pos].trim().to_string();
308                                if !name.is_empty() {
309                                    symbols.insert(name, (path.clone(), line_num + 1));
310                                }
311                            }
312                        }
313                    }
314                }
315            }
316        }
317
318        Ok(())
319    }
320
321    pub async fn query_references_for_symbol(&self, symbol_name: &str) -> Result<Vec<Reference>> {
322        if self.backend_kind != BackendKind::NativeV3 {
323            return Ok(Vec::new());
324        }
325
326        let refs = self
327            .references
328            .lock()
329            .expect("invariant: references mutex not poisoned");
330        let mut result = Vec::new();
331
332        for stored in refs.iter() {
333            if stored.to_symbol == format!("sym_{}", symbol_name)
334                || stored.to_symbol.contains(symbol_name)
335            {
336                result.push(Reference {
337                    from: SymbolId(0),
338                    to: SymbolId(0),
339                    from_name: None,
340                    to_name: None,
341                    kind: stored.kind,
342                    location: Location {
343                        file_path: stored.file_path.clone(),
344                        byte_start: 0,
345                        byte_end: 0,
346                        line_number: stored.line_number,
347                    },
348                });
349            }
350        }
351
352        Ok(result)
353    }
354}