Skip to main content

codegraph/
db.rs

1use crate::types::*;
2use anyhow::{Context, Result};
3use rusqlite::{params, Connection, OptionalExtension};
4use std::path::{Path, PathBuf};
5use std::str::FromStr;
6
/// Handle to the SQLite database that stores the code graph: files, nodes, edges and
/// unresolved references.
7pub struct Database {
    /// Open SQLite connection used by every query on this type.
8    conn: Connection,
    /// Filesystem path the connection was opened from; `stats` reads its size on disk.
9    path: PathBuf,
10}
11
12impl Database {
13    pub fn initialize(path: impl AsRef<Path>) -> Result<Self> {
14        let db = Self::open_raw(path)?;
15        db.create_schema()?;
16        Ok(db)
17    }
18
19    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
20        let db = Self::open_raw(path)?;
21        db.create_schema()?;
22        Ok(db)
23    }
24
25    fn open_raw(path: impl AsRef<Path>) -> Result<Self> {
26        let path = path.as_ref().to_path_buf();
27        if let Some(parent) = path.parent() {
28            std::fs::create_dir_all(parent)?;
29        }
30        let conn =
31            Connection::open(&path).with_context(|| format!("opening {}", path.display()))?;
32        conn.pragma_update(None, "foreign_keys", "ON")?;
33        conn.pragma_update(None, "journal_mode", "WAL")?;
34        conn.pragma_update(None, "busy_timeout", 120_000)?;
35        Ok(Self { conn, path })
36    }
37
38    fn create_schema(&self) -> Result<()> {
39        self.conn.execute_batch(
40            r#"
41            CREATE TABLE IF NOT EXISTS schema_versions (
42                version INTEGER PRIMARY KEY,
43                applied_at INTEGER NOT NULL,
44                description TEXT
45            );
46            INSERT OR IGNORE INTO schema_versions (version, applied_at, description)
47            VALUES (1, strftime('%s', 'now') * 1000, 'Rust schema');
48
49            CREATE TABLE IF NOT EXISTS nodes (
50                id TEXT PRIMARY KEY,
51                kind TEXT NOT NULL,
52                name TEXT NOT NULL,
53                qualified_name TEXT NOT NULL,
54                file_path TEXT NOT NULL,
55                language TEXT NOT NULL,
56                start_line INTEGER NOT NULL,
57                end_line INTEGER NOT NULL,
58                start_column INTEGER NOT NULL,
59                end_column INTEGER NOT NULL,
60                docstring TEXT,
61                signature TEXT,
62                visibility TEXT,
63                is_exported INTEGER DEFAULT 0,
64                is_async INTEGER DEFAULT 0,
65                is_static INTEGER DEFAULT 0,
66                is_abstract INTEGER DEFAULT 0,
67                decorators TEXT,
68                type_parameters TEXT,
69                updated_at INTEGER NOT NULL
70            );
71
72            CREATE TABLE IF NOT EXISTS edges (
73                id INTEGER PRIMARY KEY AUTOINCREMENT,
74                source TEXT NOT NULL,
75                target TEXT NOT NULL,
76                kind TEXT NOT NULL,
77                metadata TEXT,
78                line INTEGER,
79                col INTEGER,
80                provenance TEXT DEFAULT NULL,
81                FOREIGN KEY (source) REFERENCES nodes(id) ON DELETE CASCADE,
82                FOREIGN KEY (target) REFERENCES nodes(id) ON DELETE CASCADE
83            );
84
85            CREATE TABLE IF NOT EXISTS files (
86                path TEXT PRIMARY KEY,
87                content_hash TEXT NOT NULL,
88                language TEXT NOT NULL,
89                size INTEGER NOT NULL,
90                modified_at INTEGER NOT NULL,
91                indexed_at INTEGER NOT NULL,
92                node_count INTEGER DEFAULT 0,
93                errors TEXT
94            );
95
96            CREATE TABLE IF NOT EXISTS unresolved_refs (
97                id INTEGER PRIMARY KEY AUTOINCREMENT,
98                from_node_id TEXT NOT NULL,
99                reference_name TEXT NOT NULL,
100                reference_kind TEXT NOT NULL,
101                line INTEGER NOT NULL,
102                col INTEGER NOT NULL,
103                candidates TEXT,
104                file_path TEXT NOT NULL DEFAULT '',
105                language TEXT NOT NULL DEFAULT 'unknown',
106                FOREIGN KEY (from_node_id) REFERENCES nodes(id) ON DELETE CASCADE
107            );
108
109            CREATE INDEX IF NOT EXISTS idx_nodes_kind ON nodes(kind);
110            CREATE INDEX IF NOT EXISTS idx_nodes_name ON nodes(name);
111            CREATE INDEX IF NOT EXISTS idx_nodes_file_path ON nodes(file_path);
112            CREATE INDEX IF NOT EXISTS idx_nodes_language ON nodes(language);
113            CREATE INDEX IF NOT EXISTS idx_edges_kind ON edges(kind);
114            CREATE INDEX IF NOT EXISTS idx_edges_source_kind ON edges(source, kind);
115            CREATE INDEX IF NOT EXISTS idx_edges_target_kind ON edges(target, kind);
116            CREATE INDEX IF NOT EXISTS idx_files_language ON files(language);
117            CREATE INDEX IF NOT EXISTS idx_unresolved_name ON unresolved_refs(reference_name);
118            "#,
119        )?;
120        Ok(())
121    }
122
123    pub fn clear_all(&self) -> Result<()> {
124        self.conn.execute_batch(
125            "DELETE FROM edges; DELETE FROM unresolved_refs; DELETE FROM nodes; DELETE FROM files;",
126        )?;
127        Ok(())
128    }
129
130    pub fn insert_file(&self, file: &FileRecord) -> Result<()> {
131        self.conn.execute(
132            "INSERT OR REPLACE INTO files (path, content_hash, language, size, modified_at, indexed_at, node_count) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
133            params![file.path, file.content_hash, file.language.as_str(), file.size as i64, file.modified_at, file.indexed_at, file.node_count],
134        )?;
135        Ok(())
136    }
137
138    pub fn insert_nodes(&self, nodes: &[Node]) -> Result<()> {
139        let mut stmt = self.conn.prepare(
140            "INSERT OR REPLACE INTO nodes (id, kind, name, qualified_name, file_path, language, start_line, end_line, start_column, end_column, docstring, signature, visibility, is_exported, is_async, is_static, is_abstract, decorators, type_parameters, updated_at)
141             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, NULL, NULL, ?18)"
142        )?;
143        for n in nodes {
144            stmt.execute(params![
145                n.id,
146                n.kind.as_str(),
147                n.name,
148                n.qualified_name,
149                n.file_path,
150                n.language.as_str(),
151                n.start_line,
152                n.end_line,
153                n.start_column,
154                n.end_column,
155                n.docstring,
156                n.signature,
157                n.visibility,
158                n.is_exported as i64,
159                n.is_async as i64,
160                n.is_static as i64,
161                n.is_abstract as i64,
162                n.updated_at
163            ])?;
164        }
165        Ok(())
166    }
167
168    pub fn insert_edges(&self, edges: &[Edge]) -> Result<()> {
169        let mut stmt = self.conn.prepare("INSERT INTO edges (source, target, kind, line, col, provenance) VALUES (?1, ?2, ?3, ?4, ?5, ?6)")?;
170        for e in edges {
171            stmt.execute(params![
172                e.source,
173                e.target,
174                e.kind.as_str(),
175                e.line,
176                e.col,
177                e.provenance
178            ])?;
179        }
180        Ok(())
181    }
182
183    pub fn insert_unresolved_refs(&self, refs: &[UnresolvedReference]) -> Result<()> {
184        let mut stmt = self.conn.prepare(
185            "INSERT INTO unresolved_refs (from_node_id, reference_name, reference_kind, line, col, file_path, language) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)"
186        )?;
187        for r in refs {
188            stmt.execute(params![
189                r.from_node_id,
190                r.reference_name,
191                r.reference_kind.as_str(),
192                r.line,
193                r.column,
194                r.file_path,
195                r.language.as_str()
196            ])?;
197        }
198        Ok(())
199    }
200
201    pub fn resolve_references_by_name(&self) -> Result<()> {
202        let mut refs = self.conn.prepare("SELECT from_node_id, reference_name, reference_kind, line, col, language FROM unresolved_refs")?;
203        let rows = refs.query_map([], |row| {
204            Ok((
205                row.get::<_, String>(0)?,
206                row.get::<_, String>(1)?,
207                row.get::<_, String>(2)?,
208                row.get::<_, Option<i64>>(3)?,
209                row.get::<_, Option<i64>>(4)?,
210                row.get::<_, String>(5)?,
211            ))
212        })?;
213        for row in rows {
214            let (from, name, kind, line, col, lang) = row?;
215            let target: Option<String> = self.conn.query_row(
216                "SELECT id FROM nodes WHERE name = ?1 AND language = ?2 AND id != ?3 ORDER BY CASE kind WHEN 'function' THEN 0 WHEN 'method' THEN 1 WHEN 'struct' THEN 2 WHEN 'trait' THEN 3 ELSE 9 END LIMIT 1",
217                params![name, lang, from],
218                |row| row.get(0),
219            ).optional()?;
220            if let Some(target) = target {
221                self.conn.execute(
222                    "INSERT INTO edges (source, target, kind, line, col, provenance) VALUES (?1, ?2, ?3, ?4, ?5, 'heuristic')",
223                    params![from, target, kind, line, col],
224                )?;
225            }
226        }
227        Ok(())
228    }
229
230    pub fn edge_count(&self) -> Result<i64> {
231        Ok(self
232            .conn
233            .query_row("SELECT COUNT(*) FROM edges", [], |r| r.get(0))?)
234    }
235
236    pub fn stats(&self) -> Result<GraphStats> {
237        let file_count = self
238            .conn
239            .query_row("SELECT COUNT(*) FROM files", [], |r| r.get(0))?;
240        let node_count = self
241            .conn
242            .query_row("SELECT COUNT(*) FROM nodes", [], |r| r.get(0))?;
243        let edge_count = self
244            .conn
245            .query_row("SELECT COUNT(*) FROM edges", [], |r| r.get(0))?;
246        let db_size_bytes = std::fs::metadata(&self.path)
247            .map(|m| m.len() as i64)
248            .unwrap_or_default();
249        let files_by_language = grouped_counts(
250            &self.conn,
251            "SELECT language, COUNT(*) FROM files GROUP BY language",
252        )?;
253        let nodes_by_kind =
254            grouped_counts(&self.conn, "SELECT kind, COUNT(*) FROM nodes GROUP BY kind")?;
255        Ok(GraphStats {
256            file_count,
257            node_count,
258            edge_count,
259            db_size_bytes,
260            files_by_language,
261            nodes_by_kind,
262        })
263    }
264
265    pub fn search_nodes(&self, query: &str, options: SearchOptions) -> Result<Vec<SearchResult>> {
266        let limit = if options.limit <= 0 {
267            10
268        } else {
269            options.limit
270        };
271        let pattern = format!("%{}%", query);
272        let exact = query.to_string();
273        let prefix = format!("{}%", query);
274
275        let base = "SELECT id, kind, name, qualified_name, file_path, language, start_line, end_line, start_column, end_column, docstring, signature, visibility, is_exported, is_async, is_static, is_abstract, updated_at FROM nodes";
276        let order = " ORDER BY CASE WHEN name = ? THEN 0 WHEN name LIKE ? THEN 1 ELSE 2 END, length(name) LIMIT ?";
277
278        let rows = match (options.kind, options.language) {
279            (Some(k), Some(l)) => {
280                let sql = format!("{base} WHERE (name LIKE ? OR qualified_name LIKE ? OR signature LIKE ? OR file_path LIKE ?) AND kind = ? AND language = ?{order}");
281                let mut stmt = self.conn.prepare(&sql)?;
282                let nodes = collect_nodes(stmt.query_map(
283                    params![
284                        pattern,
285                        pattern,
286                        pattern,
287                        pattern,
288                        k.as_str(),
289                        l.as_str(),
290                        exact,
291                        prefix,
292                        limit
293                    ],
294                    node_from_row,
295                )?)?;
296                nodes
297            }
298            (Some(k), None) => {
299                let sql = format!("{base} WHERE (name LIKE ? OR qualified_name LIKE ? OR signature LIKE ? OR file_path LIKE ?) AND kind = ?{order}");
300                let mut stmt = self.conn.prepare(&sql)?;
301                let nodes = collect_nodes(stmt.query_map(
302                    params![
303                        pattern,
304                        pattern,
305                        pattern,
306                        pattern,
307                        k.as_str(),
308                        exact,
309                        prefix,
310                        limit
311                    ],
312                    node_from_row,
313                )?)?;
314                nodes
315            }
316            (None, Some(l)) => {
317                let sql = format!("{base} WHERE (name LIKE ? OR qualified_name LIKE ? OR signature LIKE ? OR file_path LIKE ?) AND language = ?{order}");
318                let mut stmt = self.conn.prepare(&sql)?;
319                let nodes = collect_nodes(stmt.query_map(
320                    params![
321                        pattern,
322                        pattern,
323                        pattern,
324                        pattern,
325                        l.as_str(),
326                        exact,
327                        prefix,
328                        limit
329                    ],
330                    node_from_row,
331                )?)?;
332                nodes
333            }
334            (None, None) => {
335                let sql = format!("{base} WHERE (name LIKE ? OR qualified_name LIKE ? OR signature LIKE ? OR file_path LIKE ?){order}");
336                let mut stmt = self.conn.prepare(&sql)?;
337                let nodes = collect_nodes(stmt.query_map(
338                    params![pattern, pattern, pattern, pattern, exact, prefix, limit],
339                    node_from_row,
340                )?)?;
341                nodes
342            }
343        };
344        Ok(rows
345            .into_iter()
346            .map(|node| SearchResult { node, score: 1.0 })
347            .collect())
348    }
349
350    pub fn get_node(&self, id: &str) -> Result<Option<Node>> {
351        self.conn
352            .query_row("SELECT id, kind, name, qualified_name, file_path, language, start_line, end_line, start_column, end_column, docstring, signature, visibility, is_exported, is_async, is_static, is_abstract, updated_at FROM nodes WHERE id = ?1", [id], node_from_row)
353            .optional()
354            .map_err(Into::into)
355    }
356
357    pub fn get_nodes_by_name(&self, name: &str, limit: i64) -> Result<Vec<Node>> {
358        let mut stmt = self.conn.prepare("SELECT id, kind, name, qualified_name, file_path, language, start_line, end_line, start_column, end_column, docstring, signature, visibility, is_exported, is_async, is_static, is_abstract, updated_at FROM nodes WHERE name = ?1 ORDER BY file_path, start_line LIMIT ?2")?;
359        let nodes = collect_nodes(stmt.query_map(params![name, limit], node_from_row)?)?;
360        Ok(nodes)
361    }
362
363    pub fn get_all_files(&self) -> Result<Vec<FileRecord>> {
364        let mut stmt = self.conn.prepare("SELECT path, content_hash, language, size, modified_at, indexed_at, node_count FROM files ORDER BY path")?;
365        let rows = stmt.query_map([], |row| {
366            let language: String = row.get(2)?;
367            Ok(FileRecord {
368                path: row.get(0)?,
369                content_hash: row.get(1)?,
370                language: Language::from_str(&language).unwrap_or(Language::Unknown),
371                size: row.get::<_, i64>(3)? as u64,
372                modified_at: row.get(4)?,
373                indexed_at: row.get(5)?,
374                node_count: row.get(6)?,
375            })
376        })?;
377        let mut out = Vec::new();
378        for row in rows {
379            out.push(row?);
380        }
381        Ok(out)
382    }
383
384    pub fn get_nodes_in_file(&self, file_path: &str) -> Result<Vec<Node>> {
385        let mut stmt = self.conn.prepare("SELECT id, kind, name, qualified_name, file_path, language, start_line, end_line, start_column, end_column, docstring, signature, visibility, is_exported, is_async, is_static, is_abstract, updated_at FROM nodes WHERE file_path = ?1 ORDER BY start_line, start_column")?;
386        let nodes = collect_nodes(stmt.query_map([file_path], node_from_row)?)?;
387        Ok(nodes)
388    }
389
390    pub fn get_incoming_edges(
391        &self,
392        node_id: &str,
393        kinds: Option<&[EdgeKind]>,
394    ) -> Result<Vec<Edge>> {
395        self.get_edges(node_id, EdgeDirection::Incoming, kinds)
396    }
397
398    pub fn get_outgoing_edges(
399        &self,
400        node_id: &str,
401        kinds: Option<&[EdgeKind]>,
402    ) -> Result<Vec<Edge>> {
403        self.get_edges(node_id, EdgeDirection::Outgoing, kinds)
404    }
405
406    pub fn get_file_dependents(&self, file_path: &str) -> Result<Vec<String>> {
407        let mut out = std::collections::BTreeSet::new();
408        for node in self.get_nodes_in_file(file_path)? {
409            let edges = self.get_incoming_edges(
410                &node.id,
411                Some(&[
412                    EdgeKind::Calls,
413                    EdgeKind::References,
414                    EdgeKind::Imports,
415                    EdgeKind::Extends,
416                    EdgeKind::Implements,
417                ]),
418            )?;
419            for edge in edges {
420                if let Some(source) = self.get_node(&edge.source)? {
421                    if source.file_path != file_path {
422                        out.insert(source.file_path);
423                    }
424                }
425            }
426        }
427        Ok(out.into_iter().collect())
428    }
429
430    fn get_edges(
431        &self,
432        node_id: &str,
433        direction: EdgeDirection,
434        kinds: Option<&[EdgeKind]>,
435    ) -> Result<Vec<Edge>> {
436        let column = match direction {
437            EdgeDirection::Incoming => "target",
438            EdgeDirection::Outgoing => "source",
439        };
440        let mut sql = format!(
441            "SELECT id, source, target, kind, line, col, provenance FROM edges WHERE {column} = ?"
442        );
443        if let Some(kinds) = kinds {
444            if !kinds.is_empty() {
445                sql.push_str(" AND kind IN (");
446                sql.push_str(
447                    &std::iter::repeat("?")
448                        .take(kinds.len())
449                        .collect::<Vec<_>>()
450                        .join(","),
451                );
452                sql.push(')');
453            }
454        }
455        sql.push_str(" ORDER BY id");
456
457        let mut values = vec![node_id.to_string()];
458        if let Some(kinds) = kinds {
459            values.extend(kinds.iter().map(|k| k.as_str().to_string()));
460        }
461        let mut stmt = self.conn.prepare(&sql)?;
462        let rows = stmt.query_map(rusqlite::params_from_iter(values.iter()), edge_from_row)?;
463        let mut out = Vec::new();
464        for row in rows {
465            out.push(row?);
466        }
467        Ok(out)
468    }
469}
470
/// Selects which endpoint column of `edges` a lookup in `Database::get_edges` matches on.
471enum EdgeDirection {
    /// Match on the `target` column.
472    Incoming,
    /// Match on the `source` column.
473    Outgoing,
474}
475
476fn collect_nodes(
477    rows: rusqlite::MappedRows<'_, impl FnMut(&rusqlite::Row<'_>) -> rusqlite::Result<Node>>,
478) -> Result<Vec<Node>> {
479    let mut out = Vec::new();
480    for row in rows {
481        out.push(row?);
482    }
483    Ok(out)
484}
485
486fn grouped_counts(conn: &Connection, sql: &str) -> Result<Vec<(String, i64)>> {
487    let mut stmt = conn.prepare(sql)?;
488    let rows = stmt.query_map([], |r| Ok((r.get(0)?, r.get(1)?)))?;
489    let mut out = Vec::new();
490    for row in rows {
491        out.push(row?);
492    }
493    Ok(out)
494}
495
496fn node_from_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<Node> {
497    let kind: String = row.get(1)?;
498    let language: String = row.get(5)?;
499    Ok(Node {
500        id: row.get(0)?,
501        kind: parse_kind(&kind),
502        name: row.get(2)?,
503        qualified_name: row.get(3)?,
504        file_path: row.get(4)?,
505        language: Language::from_str(&language).unwrap_or(Language::Unknown),
506        start_line: row.get(6)?,
507        end_line: row.get(7)?,
508        start_column: row.get(8)?,
509        end_column: row.get(9)?,
510        docstring: row.get(10)?,
511        signature: row.get(11)?,
512        visibility: row.get(12)?,
513        is_exported: row.get::<_, i64>(13)? != 0,
514        is_async: row.get::<_, i64>(14)? != 0,
515        is_static: row.get::<_, i64>(15)? != 0,
516        is_abstract: row.get::<_, i64>(16)? != 0,
517        updated_at: row.get(17)?,
518    })
519}
520
521fn edge_from_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<Edge> {
522    let kind: String = row.get(3)?;
523    Ok(Edge {
524        id: row.get(0)?,
525        source: row.get(1)?,
526        target: row.get(2)?,
527        kind: parse_edge_kind(&kind),
528        line: row.get(4)?,
529        col: row.get(5)?,
530        provenance: row.get(6)?,
531    })
532}
533
534fn parse_kind(s: &str) -> NodeKind {
535    match s {
536        "file" => NodeKind::File,
537        "module" => NodeKind::Module,
538        "class" => NodeKind::Class,
539        "struct" => NodeKind::Struct,
540        "interface" => NodeKind::Interface,
541        "trait" => NodeKind::Trait,
542        "protocol" => NodeKind::Protocol,
543        "function" => NodeKind::Function,
544        "method" => NodeKind::Method,
545        "property" => NodeKind::Property,
546        "field" => NodeKind::Field,
547        "variable" => NodeKind::Variable,
548        "constant" => NodeKind::Constant,
549        "enum" => NodeKind::Enum,
550        "enum_member" => NodeKind::EnumMember,
551        "type_alias" => NodeKind::TypeAlias,
552        "namespace" => NodeKind::Namespace,
553        "parameter" => NodeKind::Parameter,
554        "import" => NodeKind::Import,
555        "export" => NodeKind::Export,
556        "route" => NodeKind::Route,
557        "component" => NodeKind::Component,
558        _ => NodeKind::Variable,
559    }
560}
561
562fn parse_edge_kind(s: &str) -> EdgeKind {
563    match s {
564        "contains" => EdgeKind::Contains,
565        "calls" => EdgeKind::Calls,
566        "imports" => EdgeKind::Imports,
567        "exports" => EdgeKind::Exports,
568        "extends" => EdgeKind::Extends,
569        "implements" => EdgeKind::Implements,
570        "references" => EdgeKind::References,
571        "type_of" => EdgeKind::TypeOf,
572        "returns" => EdgeKind::Returns,
573        "instantiates" => EdgeKind::Instantiates,
574        "overrides" => EdgeKind::Overrides,
575        "decorates" => EdgeKind::Decorates,
576        _ => EdgeKind::References,
577    }
578}