use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::io;
use std::path::PathBuf;
use super::cross_file_types::{
CallGraphIR, CallSite, ClassDef, FileIR, FuncDef, ImportDef,
ProjectCallGraphV2, VarType,
};
/// Version tag written to the `_version` field of serialized IR and required
/// to match exactly when a document is read back.
pub const IR_VERSION: &str = "1.0";
/// Errors produced while converting call-graph IR to or from JSON.
#[derive(Debug, thiserror::Error)]
pub enum SerializationError {
/// The document's `_version` string differs from the version this code expects.
#[error("IR version mismatch: expected {expected}, got {actual}")]
IRVersionMismatch {
/// Version this code expects.
expected: String,
/// Version found in the document.
actual: String,
},
/// The input text failed to parse and was classified as malformed input;
/// carries the parser's message.
#[error("Invalid JSON format: {0}")]
InvalidFormat(String),
/// A required field was absent (or not of the expected type); carries the
/// field name.
#[error("Missing required field: {0}")]
MissingField(String),
/// Error reported by `serde_json`; also created automatically through `From`.
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
/// I/O failure; created automatically from `std::io::Error` through `From`.
#[error("IO error: {0}")]
Io(#[from] io::Error),
}
/// JSON shape of a whole `CallGraphIR` document.
#[derive(Serialize, Deserialize)]
struct CallGraphIRJson {
/// Format version; appears as `_version` in the JSON.
#[serde(rename = "_version")]
version: String,
/// Project root path as a string.
root: String,
/// Language identifier of the analyzed project.
language: String,
/// Per-file records keyed by a path string.
files: HashMap<String, FileIRJson>,
}
/// JSON shape of a single file's IR.
///
/// Every list field is marked `#[serde(default)]`, so a list that is missing
/// from the input deserializes as empty instead of failing.
#[derive(Serialize, Deserialize)]
struct FileIRJson {
/// Path of the file as a string (required).
path: String,
/// Function definitions found in the file.
#[serde(default)]
functions: Vec<FuncDef>,
/// Class definitions found in the file.
#[serde(default)]
classes: Vec<ClassDef>,
/// Import statements found in the file.
#[serde(default)]
imports: Vec<ImportDef>,
/// Call sites as one flat list (not grouped by caller).
#[serde(default)]
calls: Vec<CallSite>,
/// Variable type records for the file.
#[serde(default)]
var_types: Vec<VarType>,
}
impl CallGraphIR {
pub fn to_json(&self) -> Result<String, SerializationError> {
let value = self.to_json_value();
serde_json::to_string(&value).map_err(SerializationError::Json)
}
pub fn to_json_value(&self) -> Value {
let json_ir = self.to_json_representation();
serde_json::to_value(&json_ir).expect("CallGraphIR should serialize to JSON")
}
pub fn from_json(json: &str) -> Result<Self, SerializationError> {
let value: Value = serde_json::from_str(json).map_err(|e| {
if json.contains("not valid") || !json.trim().starts_with('{') {
SerializationError::InvalidFormat(e.to_string())
} else {
SerializationError::Json(e)
}
})?;
Self::from_json_value(value)
}
pub fn from_json_value(value: Value) -> Result<Self, SerializationError> {
let version = value
.get("_version")
.and_then(|v| v.as_str())
.ok_or_else(|| SerializationError::MissingField("_version".to_string()))?;
if version != IR_VERSION {
return Err(SerializationError::IRVersionMismatch {
expected: IR_VERSION.to_string(),
actual: version.to_string(),
});
}
let json_ir: CallGraphIRJson =
serde_json::from_value(value).map_err(SerializationError::Json)?;
Ok(Self::from_json_representation(json_ir))
}
fn to_json_representation(&self) -> CallGraphIRJson {
let mut files = HashMap::new();
for (path, file_ir) in &self.files {
let path_str = normalize_path_string(&path.to_string_lossy());
let file_json = FileIRJson {
path: path_str.clone(),
functions: file_ir.funcs.clone(),
classes: file_ir.classes.clone(),
imports: file_ir.imports.clone(),
calls: flatten_calls(&file_ir.calls),
var_types: file_ir.var_types.clone(),
};
files.insert(path_str, file_json);
}
CallGraphIRJson {
version: IR_VERSION.to_string(),
root: normalize_path_string(&self.root.to_string_lossy()),
language: self.language.clone(),
files,
}
}
fn from_json_representation(json_ir: CallGraphIRJson) -> Self {
let mut ir = Self::new(PathBuf::from(&json_ir.root), &json_ir.language);
for (_path_key, file_json) in json_ir.files {
let file_ir = FileIR {
path: PathBuf::from(&file_json.path),
funcs: file_json.functions,
classes: file_json.classes,
imports: file_json.imports,
var_types: file_json.var_types,
calls: unflatten_calls(file_json.calls),
};
ir.add_file(file_ir);
}
ir.build_indices();
ir
}
}
impl ProjectCallGraphV2 {
pub fn edges_to_json(&self) -> Value {
let mut edges: Vec<_> = self.edges().collect();
edges.sort_by(|a, b| {
(&a.src_file, &a.src_func, &a.dst_file, &a.dst_func).cmp(&(
&b.src_file,
&b.src_func,
&b.dst_file,
&b.dst_func,
))
});
Value::Array(
edges
.into_iter()
.map(|e| {
serde_json::json!([
normalize_path_string(&e.src_file.to_string_lossy()),
&e.src_func,
normalize_path_string(&e.dst_file.to_string_lossy()),
&e.dst_func
])
})
.collect(),
)
}
}
/// Returns `path` with every backslash replaced by a forward slash.
///
/// All other characters are copied through unchanged.
fn normalize_path_string(path: &str) -> String {
    path.chars()
        .map(|ch| if ch == '\\' { '/' } else { ch })
        .collect()
}
/// Collects the call sites of every caller into one list ordered by caller,
/// then target, then line.
///
/// The sort is stable, so call sites that compare equal keep the relative
/// order they had inside their caller's list.
fn flatten_calls(calls: &HashMap<String, Vec<CallSite>>) -> Vec<CallSite> {
    let mut flattened: Vec<CallSite> = Vec::new();
    for sites in calls.values() {
        flattened.extend(sites.iter().cloned());
    }
    flattened.sort_by(|lhs, rhs| {
        lhs.caller
            .cmp(&rhs.caller)
            .then_with(|| lhs.target.cmp(&rhs.target))
            .then_with(|| lhs.line.cmp(&rhs.line))
    });
    flattened
}
/// Groups a flat list of call sites by their `caller` field.
///
/// Within each group, call sites keep the order they had in `calls`.
fn unflatten_calls(calls: Vec<CallSite>) -> HashMap<String, Vec<CallSite>> {
    calls.into_iter().fold(
        HashMap::new(),
        |mut by_caller: HashMap<String, Vec<CallSite>>, site| {
            let caller = site.caller.clone();
            by_caller.entry(caller).or_default().push(site);
            by_caller
        },
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Backslashes are converted; forward-slash paths pass through untouched.
    #[test]
    fn test_normalize_path_string() {
        let cases = [
            ("src\\main.py", "src/main.py"),
            ("src/main.py", "src/main.py"),
            ("C:\\Users\\test\\project", "C:/Users/test/project"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(normalize_path_string(input), *expected);
        }
    }

    /// Pins the on-disk format version.
    #[test]
    fn test_ir_version_constant() {
        assert_eq!(IR_VERSION, "1.0");
    }

    /// The mismatch message mentions both the expected and the actual version.
    #[test]
    fn test_error_display() {
        let rendered = SerializationError::IRVersionMismatch {
            expected: "1.0".to_string(),
            actual: "0.5".to_string(),
        }
        .to_string();
        for needle in ["1.0", "0.5"].iter() {
            assert!(rendered.contains(*needle));
        }
    }
}