Skip to main content

codegraph/
db.rs

1use crate::types::*;
2use anyhow::{Context, Result};
3use rusqlite::{params, Connection, OptionalExtension};
4use std::path::{Path, PathBuf};
5use std::str::FromStr;
6
7pub struct Database {
8    conn: Connection,
9    path: PathBuf,
10}
11
12impl Database {
13    pub fn initialize(path: impl AsRef<Path>) -> Result<Self> {
14        let db = Self::open_raw(path)?;
15        db.create_schema()?;
16        Ok(db)
17    }
18
19    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
20        let db = Self::open_raw(path)?;
21        db.create_schema()?;
22        Ok(db)
23    }
24
25    fn open_raw(path: impl AsRef<Path>) -> Result<Self> {
26        let path = path.as_ref().to_path_buf();
27        if let Some(parent) = path.parent() {
28            std::fs::create_dir_all(parent)?;
29        }
30        let conn =
31            Connection::open(&path).with_context(|| format!("opening {}", path.display()))?;
32        conn.pragma_update(None, "foreign_keys", "ON")?;
33        conn.pragma_update(None, "journal_mode", "WAL")?;
34        conn.pragma_update(None, "busy_timeout", 120_000)?;
35        Ok(Self { conn, path })
36    }
37
38    fn create_schema(&self) -> Result<()> {
39        self.conn.execute_batch(
40            r#"
41            CREATE TABLE IF NOT EXISTS schema_versions (
42                version INTEGER PRIMARY KEY,
43                applied_at INTEGER NOT NULL,
44                description TEXT
45            );
46            INSERT OR IGNORE INTO schema_versions (version, applied_at, description)
47            VALUES (1, strftime('%s', 'now') * 1000, 'Rust schema');
48
49            CREATE TABLE IF NOT EXISTS nodes (
50                id TEXT PRIMARY KEY,
51                kind TEXT NOT NULL,
52                name TEXT NOT NULL,
53                qualified_name TEXT NOT NULL,
54                file_path TEXT NOT NULL,
55                language TEXT NOT NULL,
56                start_line INTEGER NOT NULL,
57                end_line INTEGER NOT NULL,
58                start_column INTEGER NOT NULL,
59                end_column INTEGER NOT NULL,
60                docstring TEXT,
61                signature TEXT,
62                visibility TEXT,
63                is_exported INTEGER DEFAULT 0,
64                is_async INTEGER DEFAULT 0,
65                is_static INTEGER DEFAULT 0,
66                is_abstract INTEGER DEFAULT 0,
67                decorators TEXT,
68                type_parameters TEXT,
69                updated_at INTEGER NOT NULL
70            );
71
72            CREATE TABLE IF NOT EXISTS edges (
73                id INTEGER PRIMARY KEY AUTOINCREMENT,
74                source TEXT NOT NULL,
75                target TEXT NOT NULL,
76                kind TEXT NOT NULL,
77                metadata TEXT,
78                line INTEGER,
79                col INTEGER,
80                provenance TEXT DEFAULT NULL,
81                FOREIGN KEY (source) REFERENCES nodes(id) ON DELETE CASCADE,
82                FOREIGN KEY (target) REFERENCES nodes(id) ON DELETE CASCADE
83            );
84
85            CREATE TABLE IF NOT EXISTS files (
86                path TEXT PRIMARY KEY,
87                content_hash TEXT NOT NULL,
88                language TEXT NOT NULL,
89                size INTEGER NOT NULL,
90                modified_at INTEGER NOT NULL,
91                indexed_at INTEGER NOT NULL,
92                node_count INTEGER DEFAULT 0,
93                errors TEXT
94            );
95
96            CREATE TABLE IF NOT EXISTS unresolved_refs (
97                id INTEGER PRIMARY KEY AUTOINCREMENT,
98                from_node_id TEXT NOT NULL,
99                reference_name TEXT NOT NULL,
100                reference_kind TEXT NOT NULL,
101                line INTEGER NOT NULL,
102                col INTEGER NOT NULL,
103                candidates TEXT,
104                file_path TEXT NOT NULL DEFAULT '',
105                language TEXT NOT NULL DEFAULT 'unknown',
106                FOREIGN KEY (from_node_id) REFERENCES nodes(id) ON DELETE CASCADE
107            );
108
109            CREATE INDEX IF NOT EXISTS idx_nodes_kind ON nodes(kind);
110            CREATE INDEX IF NOT EXISTS idx_nodes_name ON nodes(name);
111            CREATE INDEX IF NOT EXISTS idx_nodes_file_path ON nodes(file_path);
112            CREATE INDEX IF NOT EXISTS idx_nodes_language ON nodes(language);
113            CREATE INDEX IF NOT EXISTS idx_edges_kind ON edges(kind);
114            CREATE INDEX IF NOT EXISTS idx_edges_source_kind ON edges(source, kind);
115            CREATE INDEX IF NOT EXISTS idx_edges_target_kind ON edges(target, kind);
116            CREATE INDEX IF NOT EXISTS idx_files_language ON files(language);
117            CREATE INDEX IF NOT EXISTS idx_unresolved_name ON unresolved_refs(reference_name);
118            "#,
119        )?;
120        Ok(())
121    }
122
123    pub fn clear_all(&self) -> Result<()> {
124        self.conn.execute_batch(
125            "DELETE FROM edges; DELETE FROM unresolved_refs; DELETE FROM nodes; DELETE FROM files;",
126        )?;
127        Ok(())
128    }
129
130    pub fn insert_file(&self, file: &FileRecord) -> Result<()> {
131        self.conn.execute(
132            "INSERT OR REPLACE INTO files (path, content_hash, language, size, modified_at, indexed_at, node_count) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
133            params![file.path, file.content_hash, file.language.as_str(), file.size as i64, file.modified_at, file.indexed_at, file.node_count],
134        )?;
135        Ok(())
136    }
137
138    pub fn insert_nodes(&self, nodes: &[Node]) -> Result<()> {
139        let mut stmt = self.conn.prepare(
140            "INSERT OR REPLACE INTO nodes (id, kind, name, qualified_name, file_path, language, start_line, end_line, start_column, end_column, docstring, signature, visibility, is_exported, is_async, is_static, is_abstract, decorators, type_parameters, updated_at)
141             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, NULL, NULL, ?18)"
142        )?;
143        for n in nodes {
144            stmt.execute(params![
145                n.id,
146                n.kind.as_str(),
147                n.name,
148                n.qualified_name,
149                n.file_path,
150                n.language.as_str(),
151                n.start_line,
152                n.end_line,
153                n.start_column,
154                n.end_column,
155                n.docstring,
156                n.signature,
157                n.visibility,
158                n.is_exported as i64,
159                n.is_async as i64,
160                n.is_static as i64,
161                n.is_abstract as i64,
162                n.updated_at
163            ])?;
164        }
165        Ok(())
166    }
167
168    pub fn insert_edges(&self, edges: &[Edge]) -> Result<()> {
169        let mut stmt = self.conn.prepare("INSERT INTO edges (source, target, kind, line, col, provenance) VALUES (?1, ?2, ?3, ?4, ?5, ?6)")?;
170        for e in edges {
171            stmt.execute(params![
172                e.source,
173                e.target,
174                e.kind.as_str(),
175                e.line,
176                e.col,
177                e.provenance
178            ])?;
179        }
180        Ok(())
181    }
182
183    pub fn insert_unresolved_refs(&self, refs: &[UnresolvedReference]) -> Result<()> {
184        let mut stmt = self.conn.prepare(
185            "INSERT INTO unresolved_refs (from_node_id, reference_name, reference_kind, line, col, file_path, language) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)"
186        )?;
187        for r in refs {
188            stmt.execute(params![
189                r.from_node_id,
190                r.reference_name,
191                r.reference_kind.as_str(),
192                r.line,
193                r.column,
194                r.file_path,
195                r.language.as_str()
196            ])?;
197        }
198        Ok(())
199    }
200
201    pub fn resolve_references_by_name(&self) -> Result<()> {
202        let mut refs = self.conn.prepare("SELECT from_node_id, reference_name, reference_kind, line, col, language FROM unresolved_refs")?;
203        let rows = refs.query_map([], |row| {
204            Ok((
205                row.get::<_, String>(0)?,
206                row.get::<_, String>(1)?,
207                row.get::<_, String>(2)?,
208                row.get::<_, Option<i64>>(3)?,
209                row.get::<_, Option<i64>>(4)?,
210                row.get::<_, String>(5)?,
211            ))
212        })?;
213        for row in rows {
214            let (from, name, kind, line, col, lang) = row?;
215            let target: Option<String> = self.conn.query_row(
216                "SELECT id FROM nodes WHERE name = ?1 AND language = ?2 AND id != ?3 ORDER BY CASE kind WHEN 'function' THEN 0 WHEN 'method' THEN 1 WHEN 'struct' THEN 2 WHEN 'trait' THEN 3 ELSE 9 END LIMIT 1",
217                params![name, lang, from],
218                |row| row.get(0),
219            ).optional()?;
220            if let Some(target) = target {
221                self.conn.execute(
222                    "INSERT INTO edges (source, target, kind, line, col, provenance) VALUES (?1, ?2, ?3, ?4, ?5, 'heuristic')",
223                    params![from, target, kind, line, col],
224                )?;
225            }
226        }
227        Ok(())
228    }
229
230    pub fn edge_count(&self) -> Result<i64> {
231        Ok(self
232            .conn
233            .query_row("SELECT COUNT(*) FROM edges", [], |r| r.get(0))?)
234    }
235
236    pub fn stats(&self) -> Result<GraphStats> {
237        let file_count = self
238            .conn
239            .query_row("SELECT COUNT(*) FROM files", [], |r| r.get(0))?;
240        let node_count = self
241            .conn
242            .query_row("SELECT COUNT(*) FROM nodes", [], |r| r.get(0))?;
243        let edge_count = self
244            .conn
245            .query_row("SELECT COUNT(*) FROM edges", [], |r| r.get(0))?;
246        let db_size_bytes = std::fs::metadata(&self.path)
247            .map(|m| m.len() as i64)
248            .unwrap_or_default();
249        let files_by_language = grouped_counts(
250            &self.conn,
251            "SELECT language, COUNT(*) FROM files GROUP BY language",
252        )?;
253        let nodes_by_kind =
254            grouped_counts(&self.conn, "SELECT kind, COUNT(*) FROM nodes GROUP BY kind")?;
255        Ok(GraphStats {
256            file_count,
257            node_count,
258            edge_count,
259            db_size_bytes,
260            files_by_language,
261            nodes_by_kind,
262        })
263    }
264
265    pub fn search_nodes(&self, query: &str, options: SearchOptions) -> Result<Vec<SearchResult>> {
266        let limit = if options.limit <= 0 {
267            10
268        } else {
269            options.limit
270        };
271        let pattern = format!("%{}%", query);
272        let exact = query.to_string();
273        let prefix = format!("{}%", query);
274
275        let base = "SELECT id, kind, name, qualified_name, file_path, language, start_line, end_line, start_column, end_column, docstring, signature, visibility, is_exported, is_async, is_static, is_abstract, updated_at FROM nodes";
276        let order = " ORDER BY CASE WHEN name = ? THEN 0 WHEN name LIKE ? THEN 1 ELSE 2 END, length(name) LIMIT ?";
277
278        let rows = match (options.kind, options.language) {
279            (Some(k), Some(l)) => {
280                let sql = format!("{base} WHERE (name LIKE ? OR qualified_name LIKE ? OR signature LIKE ?) AND kind = ? AND language = ?{order}");
281                let mut stmt = self.conn.prepare(&sql)?;
282                let nodes = collect_nodes(stmt.query_map(
283                    params![
284                        pattern,
285                        pattern,
286                        pattern,
287                        k.as_str(),
288                        l.as_str(),
289                        exact,
290                        prefix,
291                        limit
292                    ],
293                    node_from_row,
294                )?)?;
295                nodes
296            }
297            (Some(k), None) => {
298                let sql = format!("{base} WHERE (name LIKE ? OR qualified_name LIKE ? OR signature LIKE ?) AND kind = ?{order}");
299                let mut stmt = self.conn.prepare(&sql)?;
300                let nodes = collect_nodes(stmt.query_map(
301                    params![pattern, pattern, pattern, k.as_str(), exact, prefix, limit],
302                    node_from_row,
303                )?)?;
304                nodes
305            }
306            (None, Some(l)) => {
307                let sql = format!("{base} WHERE (name LIKE ? OR qualified_name LIKE ? OR signature LIKE ?) AND language = ?{order}");
308                let mut stmt = self.conn.prepare(&sql)?;
309                let nodes = collect_nodes(stmt.query_map(
310                    params![pattern, pattern, pattern, l.as_str(), exact, prefix, limit],
311                    node_from_row,
312                )?)?;
313                nodes
314            }
315            (None, None) => {
316                let sql = format!("{base} WHERE (name LIKE ? OR qualified_name LIKE ? OR signature LIKE ?){order}");
317                let mut stmt = self.conn.prepare(&sql)?;
318                let nodes = collect_nodes(stmt.query_map(
319                    params![pattern, pattern, pattern, exact, prefix, limit],
320                    node_from_row,
321                )?)?;
322                nodes
323            }
324        };
325        Ok(rows
326            .into_iter()
327            .map(|node| SearchResult { node, score: 1.0 })
328            .collect())
329    }
330
331    pub fn get_node(&self, id: &str) -> Result<Option<Node>> {
332        self.conn
333            .query_row("SELECT id, kind, name, qualified_name, file_path, language, start_line, end_line, start_column, end_column, docstring, signature, visibility, is_exported, is_async, is_static, is_abstract, updated_at FROM nodes WHERE id = ?1", [id], node_from_row)
334            .optional()
335            .map_err(Into::into)
336    }
337
338    pub fn get_nodes_by_name(&self, name: &str, limit: i64) -> Result<Vec<Node>> {
339        let mut stmt = self.conn.prepare("SELECT id, kind, name, qualified_name, file_path, language, start_line, end_line, start_column, end_column, docstring, signature, visibility, is_exported, is_async, is_static, is_abstract, updated_at FROM nodes WHERE name = ?1 ORDER BY file_path, start_line LIMIT ?2")?;
340        let nodes = collect_nodes(stmt.query_map(params![name, limit], node_from_row)?)?;
341        Ok(nodes)
342    }
343
344    pub fn get_all_files(&self) -> Result<Vec<FileRecord>> {
345        let mut stmt = self.conn.prepare("SELECT path, content_hash, language, size, modified_at, indexed_at, node_count FROM files ORDER BY path")?;
346        let rows = stmt.query_map([], |row| {
347            let language: String = row.get(2)?;
348            Ok(FileRecord {
349                path: row.get(0)?,
350                content_hash: row.get(1)?,
351                language: Language::from_str(&language).unwrap_or(Language::Unknown),
352                size: row.get::<_, i64>(3)? as u64,
353                modified_at: row.get(4)?,
354                indexed_at: row.get(5)?,
355                node_count: row.get(6)?,
356            })
357        })?;
358        let mut out = Vec::new();
359        for row in rows {
360            out.push(row?);
361        }
362        Ok(out)
363    }
364
365    pub fn get_nodes_in_file(&self, file_path: &str) -> Result<Vec<Node>> {
366        let mut stmt = self.conn.prepare("SELECT id, kind, name, qualified_name, file_path, language, start_line, end_line, start_column, end_column, docstring, signature, visibility, is_exported, is_async, is_static, is_abstract, updated_at FROM nodes WHERE file_path = ?1 ORDER BY start_line, start_column")?;
367        let nodes = collect_nodes(stmt.query_map([file_path], node_from_row)?)?;
368        Ok(nodes)
369    }
370
371    pub fn get_incoming_edges(
372        &self,
373        node_id: &str,
374        kinds: Option<&[EdgeKind]>,
375    ) -> Result<Vec<Edge>> {
376        self.get_edges(node_id, EdgeDirection::Incoming, kinds)
377    }
378
379    pub fn get_outgoing_edges(
380        &self,
381        node_id: &str,
382        kinds: Option<&[EdgeKind]>,
383    ) -> Result<Vec<Edge>> {
384        self.get_edges(node_id, EdgeDirection::Outgoing, kinds)
385    }
386
387    pub fn get_file_dependents(&self, file_path: &str) -> Result<Vec<String>> {
388        let mut out = std::collections::BTreeSet::new();
389        for node in self.get_nodes_in_file(file_path)? {
390            let edges = self.get_incoming_edges(
391                &node.id,
392                Some(&[
393                    EdgeKind::Calls,
394                    EdgeKind::References,
395                    EdgeKind::Imports,
396                    EdgeKind::Extends,
397                    EdgeKind::Implements,
398                ]),
399            )?;
400            for edge in edges {
401                if let Some(source) = self.get_node(&edge.source)? {
402                    if source.file_path != file_path {
403                        out.insert(source.file_path);
404                    }
405                }
406            }
407        }
408        Ok(out.into_iter().collect())
409    }
410
411    fn get_edges(
412        &self,
413        node_id: &str,
414        direction: EdgeDirection,
415        kinds: Option<&[EdgeKind]>,
416    ) -> Result<Vec<Edge>> {
417        let column = match direction {
418            EdgeDirection::Incoming => "target",
419            EdgeDirection::Outgoing => "source",
420        };
421        let mut sql = format!(
422            "SELECT id, source, target, kind, line, col, provenance FROM edges WHERE {column} = ?"
423        );
424        if let Some(kinds) = kinds {
425            if !kinds.is_empty() {
426                sql.push_str(" AND kind IN (");
427                sql.push_str(
428                    &std::iter::repeat("?")
429                        .take(kinds.len())
430                        .collect::<Vec<_>>()
431                        .join(","),
432                );
433                sql.push(')');
434            }
435        }
436        sql.push_str(" ORDER BY id");
437
438        let mut values = vec![node_id.to_string()];
439        if let Some(kinds) = kinds {
440            values.extend(kinds.iter().map(|k| k.as_str().to_string()));
441        }
442        let mut stmt = self.conn.prepare(&sql)?;
443        let rows = stmt.query_map(rusqlite::params_from_iter(values.iter()), edge_from_row)?;
444        let mut out = Vec::new();
445        for row in rows {
446            out.push(row?);
447        }
448        Ok(out)
449    }
450}
451
/// Which endpoint of an edge must equal the queried node id.
enum EdgeDirection {
    /// Match on `edges.target`: edges pointing at the node.
    Incoming,
    /// Match on `edges.source`: edges leaving the node.
    Outgoing,
}
456
457fn collect_nodes(
458    rows: rusqlite::MappedRows<'_, impl FnMut(&rusqlite::Row<'_>) -> rusqlite::Result<Node>>,
459) -> Result<Vec<Node>> {
460    let mut out = Vec::new();
461    for row in rows {
462        out.push(row?);
463    }
464    Ok(out)
465}
466
467fn grouped_counts(conn: &Connection, sql: &str) -> Result<Vec<(String, i64)>> {
468    let mut stmt = conn.prepare(sql)?;
469    let rows = stmt.query_map([], |r| Ok((r.get(0)?, r.get(1)?)))?;
470    let mut out = Vec::new();
471    for row in rows {
472        out.push(row?);
473    }
474    Ok(out)
475}
476
477fn node_from_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<Node> {
478    let kind: String = row.get(1)?;
479    let language: String = row.get(5)?;
480    Ok(Node {
481        id: row.get(0)?,
482        kind: parse_kind(&kind),
483        name: row.get(2)?,
484        qualified_name: row.get(3)?,
485        file_path: row.get(4)?,
486        language: Language::from_str(&language).unwrap_or(Language::Unknown),
487        start_line: row.get(6)?,
488        end_line: row.get(7)?,
489        start_column: row.get(8)?,
490        end_column: row.get(9)?,
491        docstring: row.get(10)?,
492        signature: row.get(11)?,
493        visibility: row.get(12)?,
494        is_exported: row.get::<_, i64>(13)? != 0,
495        is_async: row.get::<_, i64>(14)? != 0,
496        is_static: row.get::<_, i64>(15)? != 0,
497        is_abstract: row.get::<_, i64>(16)? != 0,
498        updated_at: row.get(17)?,
499    })
500}
501
502fn edge_from_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<Edge> {
503    let kind: String = row.get(3)?;
504    Ok(Edge {
505        id: row.get(0)?,
506        source: row.get(1)?,
507        target: row.get(2)?,
508        kind: parse_edge_kind(&kind),
509        line: row.get(4)?,
510        col: row.get(5)?,
511        provenance: row.get(6)?,
512    })
513}
514
515fn parse_kind(s: &str) -> NodeKind {
516    match s {
517        "file" => NodeKind::File,
518        "module" => NodeKind::Module,
519        "class" => NodeKind::Class,
520        "struct" => NodeKind::Struct,
521        "interface" => NodeKind::Interface,
522        "trait" => NodeKind::Trait,
523        "protocol" => NodeKind::Protocol,
524        "function" => NodeKind::Function,
525        "method" => NodeKind::Method,
526        "property" => NodeKind::Property,
527        "field" => NodeKind::Field,
528        "variable" => NodeKind::Variable,
529        "constant" => NodeKind::Constant,
530        "enum" => NodeKind::Enum,
531        "enum_member" => NodeKind::EnumMember,
532        "type_alias" => NodeKind::TypeAlias,
533        "namespace" => NodeKind::Namespace,
534        "parameter" => NodeKind::Parameter,
535        "import" => NodeKind::Import,
536        "export" => NodeKind::Export,
537        "route" => NodeKind::Route,
538        "component" => NodeKind::Component,
539        _ => NodeKind::Variable,
540    }
541}
542
543fn parse_edge_kind(s: &str) -> EdgeKind {
544    match s {
545        "contains" => EdgeKind::Contains,
546        "calls" => EdgeKind::Calls,
547        "imports" => EdgeKind::Imports,
548        "exports" => EdgeKind::Exports,
549        "extends" => EdgeKind::Extends,
550        "implements" => EdgeKind::Implements,
551        "references" => EdgeKind::References,
552        "type_of" => EdgeKind::TypeOf,
553        "returns" => EdgeKind::Returns,
554        "instantiates" => EdgeKind::Instantiates,
555        "overrides" => EdgeKind::Overrides,
556        "decorates" => EdgeKind::Decorates,
557        _ => EdgeKind::References,
558    }
559}