use crate::graph::node_index::NodeIndex;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use super::frozen::CodeGraph;
use super::indexes::GraphIndexes;
use super::interner::{global_interner, StrKey};
use super::store_models::{CodeEdge, CodeNode, ExtraProps};
/// Version tag written into every cache file by `save_cache`.
/// `load_cache` discards any cache whose stored `version` differs from this value.
const CODEGRAPH_CACHE_VERSION: u32 = 2;
/// On-disk representation of a `CodeGraph`, serialized with `bitcode` via serde.
///
/// NOTE(review): the field order is presumably part of the serialized layout —
/// confirm before reordering, and bump `CODEGRAPH_CACHE_VERSION` if it changes.
#[derive(Serialize, Deserialize)]
struct CodeGraphCache {
/// Cache format version; compared against `CODEGRAPH_CACHE_VERSION` on load.
version: u32,
/// `CARGO_PKG_VERSION` of the binary that wrote the cache; must match on load.
binary_version: String,
/// Copy of the graph's nodes. Their `name`, `qualified_name`, `file_path` and
/// `language` keys are remapped through `string_table` when the cache is loaded.
nodes: Vec<CodeNode>,
/// Edges as `(source index, target index, edge)` with indices stored as raw `u32`.
edges: Vec<(u32, u32, CodeEdge)>,
/// Resolved key string -> raw node index, taken from the graph's node index map.
/// (Keys are presumably qualified names — verify against the graph builder.)
node_index: HashMap<String, u32>,
/// Raw interner key -> resolved string for every key referenced by a node's
/// `name`, `qualified_name`, `file_path` or `language` field.
string_table: HashMap<u32, String>,
/// `(key string, props)` pairs for node index entries that have extra props.
extra_props: Vec<(String, SerializableExtraProps)>,
}
/// Serializable mirror of `ExtraProps`: `save_cache` resolves each optional
/// interned key to an owned string, and `load_cache` interns the strings again.
#[derive(Serialize, Deserialize)]
struct SerializableExtraProps {
/// Resolved string for `ExtraProps::params`, if set.
params: Option<String>,
/// Resolved string for `ExtraProps::doc_comment`, if set.
doc_comment: Option<String>,
/// Resolved string for `ExtraProps::decorators`, if set.
decorators: Option<String>,
/// Resolved string for `ExtraProps::author`, if set.
author: Option<String>,
/// Resolved string for `ExtraProps::last_modified`, if set.
last_modified: Option<String>,
}
impl CodeGraph {
pub fn save_cache(&self, cache_path: &Path) -> Result<()> {
let i = global_interner();
let mut string_table: HashMap<u32, String> = HashMap::new();
for node in self.nodes() {
for &key in &[
node.name,
node.qualified_name,
node.file_path,
node.language,
] {
let raw = key.as_u32();
string_table
.entry(raw)
.or_insert_with(|| i.resolve(key).to_string());
}
}
let node_index: HashMap<String, u32> = self
.node_index_map()
.iter()
.map(|(&key, &idx)| (i.resolve(key).to_string(), idx.as_u32()))
.collect();
let edges: Vec<(u32, u32, CodeEdge)> = self
.edge_list()
.iter()
.map(|&(src, tgt, e)| (src.as_u32(), tgt.as_u32(), e))
.collect();
let extra_props_ser: Vec<(String, SerializableExtraProps)> = self
.node_index_map()
.keys()
.filter_map(|&qn_key| {
let ep = self.extra_props(qn_key)?;
let qn_str = i.resolve(qn_key).to_string();
let ser = SerializableExtraProps {
params: ep.params.map(|k| i.resolve(k).to_string()),
doc_comment: ep.doc_comment.map(|k| i.resolve(k).to_string()),
decorators: ep.decorators.map(|k| i.resolve(k).to_string()),
author: ep.author.map(|k| i.resolve(k).to_string()),
last_modified: ep.last_modified.map(|k| i.resolve(k).to_string()),
};
Some((qn_str, ser))
})
.collect();
let cache = CodeGraphCache {
version: CODEGRAPH_CACHE_VERSION,
binary_version: env!("CARGO_PKG_VERSION").to_string(),
nodes: self.nodes().to_vec(),
edges,
node_index,
string_table,
extra_props: extra_props_ser,
};
let bytes = bitcode::serialize(&cache).context("Failed to serialize CodeGraph cache")?;
if let Some(parent) = cache_path.parent() {
std::fs::create_dir_all(parent)?;
}
let tmp_path = cache_path.with_extension("bin.tmp");
std::fs::write(&tmp_path, &bytes).context("Failed to write CodeGraph cache")?;
std::fs::rename(&tmp_path, cache_path).context("Failed to finalize CodeGraph cache")?;
Ok(())
}
pub fn load_cache(cache_path: &Path) -> Option<Self> {
let bytes = std::fs::read(cache_path).ok()?;
let cache: CodeGraphCache = bitcode::deserialize(&bytes).ok()?;
if cache.version != CODEGRAPH_CACHE_VERSION
|| cache.binary_version != env!("CARGO_PKG_VERSION")
{
tracing::info!("CodeGraph cache version mismatch, rebuilding");
return None;
}
let i = global_interner();
let remap: HashMap<u32, StrKey> = cache
.string_table
.iter()
.map(|(&raw, s)| (raw, i.intern(s)))
.collect();
let mut nodes = cache.nodes;
for node in &mut nodes {
if let Some(&new) = remap.get(&node.name.as_u32()) {
node.name = new;
}
if let Some(&new) = remap.get(&node.qualified_name.as_u32()) {
node.qualified_name = new;
}
if let Some(&new) = remap.get(&node.file_path.as_u32()) {
node.file_path = new;
}
if let Some(&new) = remap.get(&node.language.as_u32()) {
node.language = new;
}
}
let node_index: HashMap<StrKey, NodeIndex> = cache
.node_index
.iter()
.map(|(key_str, &idx)| (i.intern(key_str), NodeIndex::new(idx)))
.collect();
let edges: Vec<(NodeIndex, NodeIndex, CodeEdge)> = cache
.edges
.into_iter()
.map(|(src, tgt, e)| (NodeIndex::new(src), NodeIndex::new(tgt), e))
.collect();
let mut extra_props: HashMap<StrKey, ExtraProps> = HashMap::new();
for (qn_str, ser) in cache.extra_props {
let qn_key = i.intern(&qn_str);
let ep = ExtraProps {
params: ser.params.as_deref().map(|s| i.intern(s)),
doc_comment: ser.doc_comment.as_deref().map(|s| i.intern(s)),
decorators: ser.decorators.as_deref().map(|s| i.intern(s)),
author: ser.author.as_deref().map(|s| i.intern(s)),
last_modified: ser.last_modified.as_deref().map(|s| i.intern(s)),
};
extra_props.insert(qn_key, ep);
}
let indexes = GraphIndexes::build_from_vecs(&nodes, &edges, &node_index, None);
tracing::info!("Loaded CodeGraph cache ({} nodes)", node_index.len());
Some(CodeGraph::from_parts(
nodes,
node_index,
edges,
extra_props,
indexes,
))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::builder::GraphBuilder;
    use crate::graph::store_models::NodeKind;

    /// A graph written with `save_cache` and read back with `load_cache` keeps
    /// its node/edge counts, node attributes, extra props and call edges.
    #[test]
    fn test_save_load_roundtrip() {
        // Build: one file, two functions (foo -> bar), one class.
        let mut gb = GraphBuilder::new();
        gb.add_node(CodeNode::file("a.py"));
        let foo = gb.add_node(CodeNode::function("foo", "a.py").with_lines(1, 10));
        let bar = gb.add_node(CodeNode::function("bar", "a.py").with_lines(12, 20));
        gb.add_node(CodeNode::class("MyClass", "a.py"));
        gb.add_edge(foo, bar, CodeEdge::calls());

        // Attach an author to foo through the builder's interner.
        let interner = gb.interner();
        let foo_qn = interner.intern("a.py::foo");
        let foo_props = ExtraProps {
            author: Some(interner.intern("alice")),
            ..Default::default()
        };
        gb.set_extra_props(foo_qn, foo_props);
        let original = gb.freeze();

        // Persist into a temporary directory and read it back.
        let tmp_dir = tempfile::tempdir().unwrap();
        let path = tmp_dir.path().join("graph.bin");
        original.save_cache(&path).unwrap();
        assert!(path.exists());
        let restored = CodeGraph::load_cache(&path).unwrap();

        // Counts survive the roundtrip.
        assert_eq!(restored.node_count(), original.node_count());
        assert_eq!(restored.edge_count(), original.edge_count());
        assert_eq!(restored.functions().len(), 2);
        assert_eq!(restored.classes().len(), 1);
        assert_eq!(restored.files().len(), 1);

        // Node attributes survive the roundtrip.
        let interner = restored.interner();
        let (_, foo_node) = restored.node_by_name("a.py::foo").unwrap();
        assert_eq!(foo_node.kind, NodeKind::Function);
        assert_eq!(foo_node.line_start, 1);
        assert_eq!(foo_node.line_end, 10);

        // Extra props survive the roundtrip.
        let props = restored
            .extra_props(interner.intern("a.py::foo"))
            .unwrap();
        assert_eq!(interner.resolve(props.author.unwrap()), "alice");

        // The call edge survives the roundtrip in both directions.
        let (foo_idx, _) = restored.node_by_name("a.py::foo").unwrap();
        let (bar_idx, _) = restored.node_by_name("a.py::bar").unwrap();
        assert_eq!(restored.callees(foo_idx).len(), 1);
        assert_eq!(restored.callers(bar_idx).len(), 1);
    }

    /// Loading from a path that does not exist yields `None`.
    #[test]
    fn test_load_missing_file() {
        let missing = Path::new("/nonexistent/path/graph.bin");
        assert!(CodeGraph::load_cache(missing).is_none());
    }

    /// Loading a file whose bytes are not a serialized cache yields `None`.
    #[test]
    fn test_load_corrupt_file() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let path = tmp_dir.path().join("graph.bin");
        std::fs::write(&path, b"not valid bincode").unwrap();
        assert!(CodeGraph::load_cache(&path).is_none());
    }
}