use anyhow::Result;
use indexmap::IndexMap;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::BTreeSet;
use std::path::{Path, PathBuf};
/// Identifier of a symbol, stored as a single string.
///
/// [`SymbolId::new`] builds the string as `module_path::name`, or as just
/// `name` when the module path is empty. Equality, ordering and hashing are
/// derived, so they operate on the wrapped string.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct SymbolId(pub String);
impl SymbolId {
    /// Builds an identifier from a module path and a symbol name.
    ///
    /// An empty `module_path` yields the bare `name`; otherwise the two parts
    /// are joined with `::`.
    pub fn new(module_path: &str, name: &str) -> Self {
        match module_path {
            "" => Self(name.to_owned()),
            prefix => Self([prefix, name].join("::")),
        }
    }

    /// Borrows the identifier as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
impl std::fmt::Display for SymbolId {
    /// Writes the identifier string.
    ///
    /// Uses `Formatter::pad` so that width, fill, alignment and precision
    /// given by the caller (for example `{:>20}`) are applied to the output;
    /// with a plain `{}` the output is exactly the wrapped string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.0)
    }
}
/// Category assigned to a [`SymbolEntry`].
///
/// Because of `rename_all = "snake_case"`, serde reads and writes the variant
/// names in snake case (e.g. `TypeAlias` as `"type_alias"`, `TestSuite` as
/// `"test_suite"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SymbolKind {
Function,
Method,
Class,
Struct,
Trait,
Impl,
Import,
TypeAlias,
Const,
Test,
TestSuite,
// NOTE(review): presumably the fallback for symbols that could not be
// classified — confirm against the code that produces entries.
Unknown,
}
/// One symbol record held by [`SymbolRegistry`] and persisted in its JSON file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SymbolEntry {
/// Identifier of the symbol; `SymbolRegistry::insert` uses it as the map key.
pub id: SymbolId,
/// Category of the symbol.
pub kind: SymbolKind,
/// Source text of the symbol; this is the input hashed into `content_hash`.
pub source: String,
/// Hash of `source` as computed by `SymbolRegistry::content_hash` when the
/// entry is built with [`SymbolEntry::new`]. The field is public, so it can
/// drift from `source`; `SymbolRegistry::verify_hashes` reports such entries.
pub content_hash: String,
/// Language label supplied by the caller (the tests in this file use `"rust"`).
pub language: String,
/// Starts empty in [`SymbolEntry::new`].
/// NOTE(review): presumably the ids of symbols this one depends on — confirm.
pub dependencies: BTreeSet<SymbolId>,
/// Starts as `None` in [`SymbolEntry::new`].
/// NOTE(review): presumably the file this symbol has been assigned to — confirm.
pub assigned_file: Option<PathBuf>,
/// Starts as `None` in [`SymbolEntry::new`].
/// NOTE(review): presumably, for test symbols, the symbol under test — confirm.
pub test_covers: Option<SymbolId>,
}
impl SymbolEntry {
pub fn new(id: SymbolId, kind: SymbolKind, source: String, language: &str) -> Self {
let content_hash = SymbolRegistry::content_hash(&source);
Self {
id,
kind,
source,
content_hash,
language: language.to_string(),
dependencies: BTreeSet::new(),
assigned_file: None,
test_covers: None,
}
}
}
/// Serialized form of the registry: the JSON document that
/// `SymbolRegistry::save` writes and `SymbolRegistry::load` parses.
#[derive(Debug, Serialize, Deserialize)]
struct RegistryState {
/// Schema version; `save` writes `SymbolRegistry::VERSION` here.
/// NOTE(review): `load` parses this field but does not check its value.
version: u32,
/// All entries, in the registry's iteration order at save time.
symbols: Vec<SymbolEntry>,
}
/// Collection of [`SymbolEntry`] values keyed by [`SymbolId`], with JSON
/// persistence under a project root.
pub struct SymbolRegistry {
// Insertion-ordered map; `insert` and `load` call `sort_keys`, so iteration
// yields entries in ascending id order.
entries: IndexMap<SymbolId, SymbolEntry>,
/// Directory against which the registry file path is resolved.
pub project_root: PathBuf,
}
impl SymbolRegistry {
const REGISTRY_PATH: &'static str = ".open-mpm/state/symbol-registry.json";
const VERSION: u32 = 1;
pub fn new(project_root: PathBuf) -> Self {
Self {
entries: IndexMap::new(),
project_root,
}
}
pub fn content_hash(source: &str) -> String {
let mut hasher = Sha256::new();
hasher.update(source.as_bytes());
format!("{:x}", hasher.finalize())
}
pub fn insert(&mut self, entry: SymbolEntry) {
self.entries.insert(entry.id.clone(), entry);
self.entries.sort_keys();
}
pub fn remove(&mut self, id: &SymbolId) -> Option<SymbolEntry> {
self.entries.shift_remove(id)
}
pub fn get(&self, id: &SymbolId) -> Option<&SymbolEntry> {
self.entries.get(id)
}
pub fn iter(&self) -> impl Iterator<Item = (&SymbolId, &SymbolEntry)> {
self.entries.iter()
}
pub fn len(&self) -> usize {
self.entries.len()
}
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
pub fn verify_hashes(&self) -> Vec<SymbolId> {
self.entries
.iter()
.filter(|(_, e)| e.content_hash != Self::content_hash(&e.source))
.map(|(id, _)| id.clone())
.collect()
}
pub fn registry_path(&self) -> PathBuf {
self.project_root.join(Self::REGISTRY_PATH)
}
pub fn load(project_root: &Path) -> Result<Self> {
let path = project_root.join(Self::REGISTRY_PATH);
if !path.exists() {
return Ok(Self::new(project_root.to_path_buf()));
}
let json = std::fs::read_to_string(&path)?;
let state: RegistryState = serde_json::from_str(&json)?;
let mut registry = Self::new(project_root.to_path_buf());
for entry in state.symbols {
registry.entries.insert(entry.id.clone(), entry);
}
registry.entries.sort_keys();
Ok(registry)
}
pub fn save(&self) -> Result<()> {
let path = self.registry_path();
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
let symbols: Vec<SymbolEntry> = self.entries.values().cloned().collect();
let state = RegistryState {
version: Self::VERSION,
symbols,
};
let json = serde_json::to_string_pretty(&state)?;
std::fs::write(&path, json)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a `Function` entry in the "rust" language from the given parts.
    fn function_entry(module: &str, name: &str, source: &str) -> SymbolEntry {
        SymbolEntry::new(
            SymbolId::new(module, name),
            SymbolKind::Function,
            source.into(),
            "rust",
        )
    }

    /// Creates a scratch directory and an empty registry rooted in it.
    /// The directory guard is returned so it outlives the registry.
    fn scratch_registry() -> (TempDir, SymbolRegistry) {
        let dir = TempDir::new().unwrap();
        let registry = SymbolRegistry::new(dir.path().to_path_buf());
        (dir, registry)
    }

    #[test]
    fn test_symbol_id_new_with_module() {
        let qualified = SymbolId::new("api::handlers", "process_request");
        assert_eq!(qualified.as_str(), "api::handlers::process_request");
    }

    #[test]
    fn test_symbol_id_root_module() {
        let bare = SymbolId::new("", "main");
        assert_eq!(bare.as_str(), "main");
    }

    #[test]
    fn test_content_hash_stable() {
        let first = SymbolRegistry::content_hash("fn foo() {}");
        let second = SymbolRegistry::content_hash("fn foo() {}");
        assert_eq!(first, second);
        assert_ne!(first, SymbolRegistry::content_hash("fn bar() {}"));
    }

    #[test]
    fn test_registry_sorted_on_insert() {
        let (_dir, mut registry) = scratch_registry();
        // Insert out of order on purpose; iteration must come back sorted.
        registry.insert(function_entry("", "z_func", "fn z_func() {}"));
        registry.insert(function_entry("", "a_func", "fn a_func() {}"));
        let mut ids: Vec<&str> = Vec::new();
        for (id, _) in registry.iter() {
            ids.push(id.as_str());
        }
        assert_eq!(ids, vec!["a_func", "z_func"]);
    }

    #[test]
    fn test_registry_save_load_roundtrip() {
        let (dir, mut registry) = scratch_registry();
        registry.insert(function_entry("mod", "foo", "fn foo() {}"));
        registry.save().unwrap();
        let reloaded = SymbolRegistry::load(dir.path()).unwrap();
        assert_eq!(reloaded.len(), 1);
        assert!(reloaded.get(&SymbolId::new("mod", "foo")).is_some());
    }

    #[test]
    fn test_verify_hashes_detects_mismatch() {
        let (_dir, mut registry) = scratch_registry();
        let mut tampered = function_entry("", "foo", "fn foo() {}");
        tampered.content_hash = "badhash".into();
        registry.insert(tampered);
        let stale = registry.verify_hashes();
        assert_eq!(stale.len(), 1);
    }
}