use ipfrs_core::Cid;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use thiserror::Error;
/// Errors reported by the model version-control types in this module.
///
/// Each `String` payload is interpolated into the variant's display message.
#[derive(Debug, Error)]
pub enum VersionControlError {
/// A lookup target could not be resolved. `ModelRepository::checkout` returns
/// this when the target matches neither a branch name nor a stored commit key.
#[error("Commit not found: {0}")]
CommitNotFound(String),
/// No branch with the given name exists in the repository.
#[error("Branch not found: {0}")]
BranchNotFound(String),
/// `ModelRepository::create_branch` was asked to create a name that is already taken.
#[error("Branch already exists: {0}")]
BranchAlreadyExists(String),
/// A merge conflict; per the message, the payload identifies the conflicting layer.
#[error("Merge conflict in layer: {0}")]
MergeConflict(String),
/// A commit identifier was rejected as invalid; the payload carries the offending value.
#[error("Invalid commit ID: {0}")]
InvalidCommitId(String),
/// An operation could not proceed; the payload explains why.
/// `ModelRepository::delete_branch` also uses this variant when asked to
/// delete the branch that is currently checked out.
#[error("Cannot merge: {0}")]
CannotMerge(String),
/// The operation needs a checked-out branch but HEAD is detached.
#[error("Detached HEAD state")]
DetachedHead,
/// The operation needs an existing HEAD/starting commit and none is available.
#[error("No parent commit")]
NoParentCommit,
}
/// One recorded version of a model: the model's content ID plus authorship
/// information and links to the commits it was derived from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCommit {
/// Identifier of this commit; (de)serialized through the crate-level CID helpers.
#[serde(serialize_with = "crate::serialize_cid")]
#[serde(deserialize_with = "crate::deserialize_cid")]
pub id: Cid,
/// Parent commit IDs. Empty for an initial commit, more than one for a merge
/// (see `is_initial` / `is_merge`). Serialized as a list of CID strings.
#[serde(serialize_with = "serialize_cid_vec")]
#[serde(deserialize_with = "deserialize_cid_vec")]
pub parents: Vec<Cid>,
/// Content ID of the model this commit records.
#[serde(serialize_with = "crate::serialize_cid")]
#[serde(deserialize_with = "crate::deserialize_cid")]
pub model: Cid,
/// Commit message supplied by the caller.
pub message: String,
/// Author string supplied by the caller; no format is enforced here.
pub author: String,
/// Creation time; `ModelCommit::new` fills it with `chrono::Utc::now().timestamp()`.
pub timestamp: i64,
/// Free-form key/value annotations; empty after `new`, extended via `with_metadata`.
pub metadata: HashMap<String, String>,
}
impl ModelCommit {
    /// Builds a commit stamped with the current UTC time and no metadata.
    ///
    /// `parents` is stored as given: pass an empty vector for an initial
    /// commit and two or more IDs for a merge commit.
    pub fn new(id: Cid, parents: Vec<Cid>, model: Cid, message: String, author: String) -> Self {
        let timestamp = chrono::Utc::now().timestamp();
        let metadata = HashMap::new();
        Self {
            id,
            parents,
            model,
            message,
            author,
            timestamp,
            metadata,
        }
    }

    /// Builder-style helper: records `key` → `value` in the metadata map
    /// (replacing any earlier value for `key`) and hands the commit back.
    pub fn with_metadata(self, key: String, value: String) -> Self {
        let mut annotated = self;
        annotated.metadata.insert(key, value);
        annotated
    }

    /// `true` when the commit has at least two parents.
    pub fn is_merge(&self) -> bool {
        self.parents.len() >= 2
    }

    /// `true` when the commit has no parents at all.
    pub fn is_initial(&self) -> bool {
        self.parents.first().is_none()
    }
}
/// A named pointer to a commit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Branch {
/// Branch name; `ModelRepository` also uses it as the key of its branch map.
pub name: String,
/// ID of the commit the branch currently points at; (de)serialized through
/// the crate-level CID helpers.
#[serde(serialize_with = "crate::serialize_cid")]
#[serde(deserialize_with = "crate::deserialize_cid")]
pub head: Cid,
/// Optional human-readable description; `None` unless set via `with_description`.
pub description: Option<String>,
}
impl Branch {
    /// Creates a branch called `name` pointing at `head`, with no description.
    pub fn new(name: String, head: Cid) -> Self {
        let description = None;
        Self {
            name,
            head,
            description,
        }
    }

    /// Builder-style helper that attaches (or replaces) the description.
    pub fn with_description(self, description: String) -> Self {
        Self {
            description: Some(description),
            ..self
        }
    }
}
/// In-memory repository of model commits and branches.
#[derive(Debug, Clone)]
pub struct ModelRepository {
/// Stored commits, keyed by the string form (`to_string()`) of each commit's ID.
commits: HashMap<String, ModelCommit>,
/// Branches keyed by branch name.
branches: HashMap<String, Branch>,
/// Name of the checked-out branch; `None` before initialization or after
/// checking out a commit directly (detached HEAD).
current_branch: Option<String>,
/// ID of the commit HEAD points at; `None` until a commit exists.
head: Option<Cid>,
}
impl ModelRepository {
pub fn new() -> Self {
Self {
commits: HashMap::new(),
branches: HashMap::new(),
current_branch: None,
head: None,
}
}
pub fn init(&mut self, initial_commit: ModelCommit) -> Result<(), VersionControlError> {
let commit_id = initial_commit.id.to_string();
let commit_cid = initial_commit.id;
self.commits.insert(commit_id, initial_commit);
let main_branch = Branch::new("main".to_string(), commit_cid);
self.branches.insert("main".to_string(), main_branch);
self.current_branch = Some("main".to_string());
self.head = Some(commit_cid);
Ok(())
}
pub fn commit(
&mut self,
model: Cid,
message: String,
author: String,
) -> Result<ModelCommit, VersionControlError> {
let parents = if let Some(head) = self.head {
vec![head]
} else {
vec![]
};
let commit_id = Cid::default();
let commit = ModelCommit::new(commit_id, parents, model, message, author);
self.commits.insert(commit_id.to_string(), commit.clone());
self.head = Some(commit_id);
if let Some(branch_name) = &self.current_branch {
if let Some(branch) = self.branches.get_mut(branch_name) {
branch.head = commit_id;
}
}
Ok(commit)
}
pub fn checkout(&mut self, target: &str) -> Result<(), VersionControlError> {
if let Some(branch) = self.branches.get(target) {
self.current_branch = Some(target.to_string());
self.head = Some(branch.head);
return Ok(());
}
if let Some(commit) = self.commits.get(target) {
self.current_branch = None; self.head = Some(commit.id);
return Ok(());
}
Err(VersionControlError::CommitNotFound(target.to_string()))
}
pub fn create_branch(
&mut self,
name: String,
start_point: Option<Cid>,
) -> Result<(), VersionControlError> {
if self.branches.contains_key(&name) {
return Err(VersionControlError::BranchAlreadyExists(name));
}
let head = start_point
.or(self.head)
.ok_or(VersionControlError::NoParentCommit)?;
let branch = Branch::new(name.clone(), head);
self.branches.insert(name, branch);
Ok(())
}
pub fn delete_branch(&mut self, name: &str) -> Result<(), VersionControlError> {
if !self.branches.contains_key(name) {
return Err(VersionControlError::BranchNotFound(name.to_string()));
}
if let Some(current) = &self.current_branch {
if current == name {
return Err(VersionControlError::CannotMerge(
"Cannot delete current branch".to_string(),
));
}
}
self.branches.remove(name);
Ok(())
}
pub fn list_branches(&self) -> Vec<&Branch> {
self.branches.values().collect()
}
pub fn current_branch(&self) -> Option<&str> {
self.current_branch.as_deref()
}
pub fn head_commit(&self) -> Option<&ModelCommit> {
self.head
.as_ref()
.and_then(|cid| self.commits.get(&cid.to_string()))
}
pub fn get_commit(&self, commit_id: &str) -> Option<&ModelCommit> {
self.commits.get(commit_id)
}
pub fn get_history(&self, start: &Cid, max_count: Option<usize>) -> Vec<&ModelCommit> {
let mut history = Vec::new();
let mut visited = HashSet::new();
let mut queue = vec![start];
while let Some(cid) = queue.pop() {
if visited.contains(cid) {
continue;
}
if let Some(max) = max_count {
if history.len() >= max {
break;
}
}
if let Some(commit) = self.commits.get(&cid.to_string()) {
visited.insert(*cid);
history.push(commit);
for parent in &commit.parents {
if !visited.contains(parent) {
queue.push(parent);
}
}
}
}
history
}
pub fn merge_fast_forward(&mut self, branch: &str) -> Result<(), VersionControlError> {
let target_branch = self
.branches
.get(branch)
.ok_or_else(|| VersionControlError::BranchNotFound(branch.to_string()))?;
let target_head = target_branch.head;
if let Some(current_name) = &self.current_branch {
if let Some(current_branch) = self.branches.get_mut(current_name) {
current_branch.head = target_head;
}
} else {
return Err(VersionControlError::DetachedHead);
}
self.head = Some(target_head);
Ok(())
}
pub fn can_fast_forward(&self, branch: &str) -> Result<bool, VersionControlError> {
let target_branch = self
.branches
.get(branch)
.ok_or_else(|| VersionControlError::BranchNotFound(branch.to_string()))?;
let current_head = self.head.ok_or(VersionControlError::NoParentCommit)?;
let history = self.get_history(&target_branch.head, None);
Ok(history.iter().any(|c| c.id == current_head))
}
}
impl Default for ModelRepository {
fn default() -> Self {
Self::new()
}
}
/// Layer-level difference between two models, as produced by `ModelDiffer::diff`
/// for a pair `(model_a, model_b)`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDiff {
/// Names of layers present in `model_b` but not in `model_a`.
pub added_layers: Vec<String>,
/// Names of layers present in `model_a` but not in `model_b`.
pub removed_layers: Vec<String>,
/// Layers present in both models whose length or values differ.
pub modified_layers: Vec<LayerDiff>,
}
/// Difference summary for a single layer that exists in both compared models.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerDiff {
/// Layer name (the key shared by both model maps).
pub name: String,
/// `true` when the two value vectors have different lengths.
pub shape_changed: bool,
/// Shape in the first model; `ModelDiffer::diff` records a single entry, the
/// flat value count.
pub old_shape: Vec<usize>,
/// Shape in the second model; likewise a single flat value count.
pub new_shape: Vec<usize>,
/// Euclidean (L2) norm of the element-wise difference, taken over the
/// overlapping prefix when the lengths differ.
pub l2_diff: f32,
/// Largest absolute element-wise difference over the same overlapping prefix.
pub max_diff: f32,
}
pub struct ModelDiffer;
impl ModelDiffer {
pub fn diff(
model_a: &HashMap<String, Vec<f32>>,
model_b: &HashMap<String, Vec<f32>>,
) -> ModelDiff {
let mut added_layers = Vec::new();
let mut removed_layers = Vec::new();
let mut modified_layers = Vec::new();
let keys_a: HashSet<_> = model_a.keys().collect();
let keys_b: HashSet<_> = model_b.keys().collect();
for key in keys_b.difference(&keys_a) {
added_layers.push((*key).clone());
}
for key in keys_a.difference(&keys_b) {
removed_layers.push((*key).clone());
}
for key in keys_a.intersection(&keys_b) {
let values_a = &model_a[*key];
let values_b = &model_b[*key];
let shape_changed = values_a.len() != values_b.len();
if shape_changed || !values_equal(values_a, values_b) {
let (l2_diff, max_diff) = compute_diffs(values_a, values_b);
modified_layers.push(LayerDiff {
name: (*key).clone(),
shape_changed,
old_shape: vec![values_a.len()],
new_shape: vec![values_b.len()],
l2_diff,
max_diff,
});
}
}
ModelDiff {
added_layers,
removed_layers,
modified_layers,
}
}
pub fn has_changes(diff: &ModelDiff) -> bool {
!diff.added_layers.is_empty()
|| !diff.removed_layers.is_empty()
|| !diff.modified_layers.is_empty()
}
}
/// Returns `true` when `a` and `b` have the same length and every pair of
/// corresponding elements is either exactly equal or differs by less than
/// `1e-6` in absolute value.
///
/// The exact-equality check matters for infinities: `inf - inf` is NaN, so a
/// pure tolerance test would report two identical infinite values as
/// different. NaN never compares equal to anything, including another NaN.
fn values_equal(a: &[f32], b: &[f32]) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b)
            .all(|(x, y)| x == y || (x - y).abs() < 1e-6)
}
/// Returns `(l2_diff, max_diff)` for the element-wise difference of `a` and `b`.
///
/// Only the overlapping prefix (the shorter of the two lengths) is compared.
/// `l2_diff` is the square root of the summed squared differences and
/// `max_diff` the largest absolute difference; both are `0.0` when there is
/// nothing to compare.
fn compute_diffs(a: &[f32], b: &[f32]) -> (f32, f32) {
    let (sum_of_squares, max_diff) = a.iter().zip(b.iter()).fold(
        (0.0_f32, 0.0_f32),
        |(sum_of_squares, peak), (x, y)| {
            let delta = (x - y).abs();
            (sum_of_squares + delta * delta, peak.max(delta))
        },
    );
    (sum_of_squares.sqrt(), max_diff)
}
/// serde `serialize_with` helper: writes a list of CIDs as a sequence of
/// their string representations.
fn serialize_cid_vec<S>(cids: &[Cid], serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    let mut encoded: Vec<String> = Vec::with_capacity(cids.len());
    for cid in cids {
        encoded.push(cid.to_string());
    }
    serde::Serialize::serialize(&encoded, serializer)
}
fn deserialize_cid_vec<'de, D>(deserializer: D) -> Result<Vec<Cid>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::Deserialize;
let strings = Vec::<String>::deserialize(deserializer)?;
strings
.into_iter()
.map(|s| s.parse().map_err(serde::de::Error::custom))
.collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the parentless commit that the repository tests start from.
    fn root_commit() -> ModelCommit {
        ModelCommit::new(
            Cid::default(),
            Vec::new(),
            Cid::default(),
            String::from("Initial commit"),
            String::from("test@example.com"),
        )
    }

    /// Returns a repository that has been initialized with `root_commit`.
    fn seeded_repo() -> ModelRepository {
        let mut repo = ModelRepository::new();
        repo.init(root_commit()).unwrap();
        repo
    }

    // A commit without parents is initial and not a merge.
    #[test]
    fn test_model_commit() {
        let commit = root_commit();
        assert!(commit.is_initial());
        assert!(!commit.is_merge());
    }

    // `init` checks out `main` and sets HEAD to the stored commit.
    #[test]
    fn test_repository_init() {
        let repo = seeded_repo();
        assert_eq!(repo.current_branch(), Some("main"));
        assert!(repo.head_commit().is_some());
    }

    // Creating a branch adds it next to `main`.
    #[test]
    fn test_branch_creation() {
        let mut repo = seeded_repo();
        repo.create_branch("develop".to_string(), None).unwrap();
        assert_eq!(repo.list_branches().len(), 2);
    }

    // Checking out a branch by name makes it the current branch.
    #[test]
    fn test_checkout() {
        let mut repo = seeded_repo();
        repo.create_branch("develop".to_string(), None).unwrap();
        repo.checkout("develop").unwrap();
        assert_eq!(repo.current_branch(), Some("develop"));
    }

    // One layer changed, one removed, one added.
    #[test]
    fn test_model_diff() {
        let model_a: HashMap<String, Vec<f32>> = vec![
            ("layer1".to_string(), vec![1.0, 2.0, 3.0]),
            ("layer2".to_string(), vec![4.0, 5.0]),
        ]
        .into_iter()
        .collect();
        let model_b: HashMap<String, Vec<f32>> = vec![
            ("layer1".to_string(), vec![1.1, 2.1, 3.1]),
            ("layer3".to_string(), vec![6.0, 7.0]),
        ]
        .into_iter()
        .collect();

        let diff = ModelDiffer::diff(&model_a, &model_b);
        assert_eq!(diff.added_layers.len(), 1);
        assert_eq!(diff.removed_layers.len(), 1);
        assert_eq!(diff.modified_layers.len(), 1);
        assert!(ModelDiffer::has_changes(&diff));
    }

    // Every element differs by 0.5, so that is also the maximum difference.
    #[test]
    fn test_layer_diff() {
        let a: Vec<f32> = vec![1.0, 2.0, 3.0];
        let b: Vec<f32> = vec![1.5, 2.5, 3.5];
        let (l2_diff, max_diff) = compute_diffs(&a, &b);
        assert!(l2_diff > 0.0);
        assert!((max_diff - 0.5).abs() < 1e-6);
    }
}