1use ipfrs_core::Cid;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet};
12use thiserror::Error;
13
14#[derive(Debug, Error)]
16pub enum VersionControlError {
17 #[error("Commit not found: {0}")]
18 CommitNotFound(String),
19
20 #[error("Branch not found: {0}")]
21 BranchNotFound(String),
22
23 #[error("Branch already exists: {0}")]
24 BranchAlreadyExists(String),
25
26 #[error("Merge conflict in layer: {0}")]
27 MergeConflict(String),
28
29 #[error("Invalid commit ID: {0}")]
30 InvalidCommitId(String),
31
32 #[error("Cannot merge: {0}")]
33 CannotMerge(String),
34
35 #[error("Detached HEAD state")]
36 DetachedHead,
37
38 #[error("No parent commit")]
39 NoParentCommit,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ModelCommit {
45 #[serde(serialize_with = "crate::serialize_cid")]
47 #[serde(deserialize_with = "crate::deserialize_cid")]
48 pub id: Cid,
49
50 #[serde(serialize_with = "serialize_cid_vec")]
52 #[serde(deserialize_with = "deserialize_cid_vec")]
53 pub parents: Vec<Cid>,
54
55 #[serde(serialize_with = "crate::serialize_cid")]
57 #[serde(deserialize_with = "crate::deserialize_cid")]
58 pub model: Cid,
59
60 pub message: String,
62
63 pub author: String,
65
66 pub timestamp: i64,
68
69 pub metadata: HashMap<String, String>,
71}
72
73impl ModelCommit {
74 pub fn new(id: Cid, parents: Vec<Cid>, model: Cid, message: String, author: String) -> Self {
76 Self {
77 id,
78 parents,
79 model,
80 message,
81 author,
82 timestamp: chrono::Utc::now().timestamp(),
83 metadata: HashMap::new(),
84 }
85 }
86
87 pub fn with_metadata(mut self, key: String, value: String) -> Self {
89 self.metadata.insert(key, value);
90 self
91 }
92
93 pub fn is_merge(&self) -> bool {
95 self.parents.len() > 1
96 }
97
98 pub fn is_initial(&self) -> bool {
100 self.parents.is_empty()
101 }
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct Branch {
107 pub name: String,
109
110 #[serde(serialize_with = "crate::serialize_cid")]
112 #[serde(deserialize_with = "crate::deserialize_cid")]
113 pub head: Cid,
114
115 pub description: Option<String>,
117}
118
119impl Branch {
120 pub fn new(name: String, head: Cid) -> Self {
122 Self {
123 name,
124 head,
125 description: None,
126 }
127 }
128
129 pub fn with_description(mut self, description: String) -> Self {
131 self.description = Some(description);
132 self
133 }
134}
135
136#[derive(Debug, Clone)]
138pub struct ModelRepository {
139 commits: HashMap<String, ModelCommit>,
141
142 branches: HashMap<String, Branch>,
144
145 current_branch: Option<String>,
147
148 head: Option<Cid>,
150}
151
152impl ModelRepository {
153 pub fn new() -> Self {
155 Self {
156 commits: HashMap::new(),
157 branches: HashMap::new(),
158 current_branch: None,
159 head: None,
160 }
161 }
162
163 pub fn init(&mut self, initial_commit: ModelCommit) -> Result<(), VersionControlError> {
165 let commit_id = initial_commit.id.to_string();
166 let commit_cid = initial_commit.id;
167
168 self.commits.insert(commit_id, initial_commit);
169
170 let main_branch = Branch::new("main".to_string(), commit_cid);
172 self.branches.insert("main".to_string(), main_branch);
173 self.current_branch = Some("main".to_string());
174 self.head = Some(commit_cid);
175
176 Ok(())
177 }
178
179 pub fn commit(
181 &mut self,
182 model: Cid,
183 message: String,
184 author: String,
185 ) -> Result<ModelCommit, VersionControlError> {
186 let parents = if let Some(head) = self.head {
187 vec![head]
188 } else {
189 vec![]
190 };
191
192 let commit_id = Cid::default();
195
196 let commit = ModelCommit::new(commit_id, parents, model, message, author);
197
198 self.commits.insert(commit_id.to_string(), commit.clone());
199 self.head = Some(commit_id);
200
201 if let Some(branch_name) = &self.current_branch {
203 if let Some(branch) = self.branches.get_mut(branch_name) {
204 branch.head = commit_id;
205 }
206 }
207
208 Ok(commit)
209 }
210
211 pub fn checkout(&mut self, target: &str) -> Result<(), VersionControlError> {
213 if let Some(branch) = self.branches.get(target) {
215 self.current_branch = Some(target.to_string());
216 self.head = Some(branch.head);
217 return Ok(());
218 }
219
220 if let Some(commit) = self.commits.get(target) {
222 self.current_branch = None; self.head = Some(commit.id);
224 return Ok(());
225 }
226
227 Err(VersionControlError::CommitNotFound(target.to_string()))
228 }
229
230 pub fn create_branch(
232 &mut self,
233 name: String,
234 start_point: Option<Cid>,
235 ) -> Result<(), VersionControlError> {
236 if self.branches.contains_key(&name) {
237 return Err(VersionControlError::BranchAlreadyExists(name));
238 }
239
240 let head = start_point
241 .or(self.head)
242 .ok_or(VersionControlError::NoParentCommit)?;
243
244 let branch = Branch::new(name.clone(), head);
245 self.branches.insert(name, branch);
246
247 Ok(())
248 }
249
250 pub fn delete_branch(&mut self, name: &str) -> Result<(), VersionControlError> {
252 if !self.branches.contains_key(name) {
253 return Err(VersionControlError::BranchNotFound(name.to_string()));
254 }
255
256 if let Some(current) = &self.current_branch {
257 if current == name {
258 return Err(VersionControlError::CannotMerge(
259 "Cannot delete current branch".to_string(),
260 ));
261 }
262 }
263
264 self.branches.remove(name);
265 Ok(())
266 }
267
268 pub fn list_branches(&self) -> Vec<&Branch> {
270 self.branches.values().collect()
271 }
272
273 pub fn current_branch(&self) -> Option<&str> {
275 self.current_branch.as_deref()
276 }
277
278 pub fn head_commit(&self) -> Option<&ModelCommit> {
280 self.head
281 .as_ref()
282 .and_then(|cid| self.commits.get(&cid.to_string()))
283 }
284
285 pub fn get_commit(&self, commit_id: &str) -> Option<&ModelCommit> {
287 self.commits.get(commit_id)
288 }
289
290 pub fn get_history(&self, start: &Cid, max_count: Option<usize>) -> Vec<&ModelCommit> {
292 let mut history = Vec::new();
293 let mut visited = HashSet::new();
294 let mut queue = vec![start];
295
296 while let Some(cid) = queue.pop() {
297 if visited.contains(cid) {
298 continue;
299 }
300
301 if let Some(max) = max_count {
302 if history.len() >= max {
303 break;
304 }
305 }
306
307 if let Some(commit) = self.commits.get(&cid.to_string()) {
308 visited.insert(*cid);
309 history.push(commit);
310
311 for parent in &commit.parents {
313 if !visited.contains(parent) {
314 queue.push(parent);
315 }
316 }
317 }
318 }
319
320 history
321 }
322
323 pub fn merge_fast_forward(&mut self, branch: &str) -> Result<(), VersionControlError> {
325 let target_branch = self
326 .branches
327 .get(branch)
328 .ok_or_else(|| VersionControlError::BranchNotFound(branch.to_string()))?;
329
330 let target_head = target_branch.head;
331
332 if let Some(current_name) = &self.current_branch {
334 if let Some(current_branch) = self.branches.get_mut(current_name) {
335 current_branch.head = target_head;
336 }
337 } else {
338 return Err(VersionControlError::DetachedHead);
339 }
340
341 self.head = Some(target_head);
342
343 Ok(())
344 }
345
346 pub fn can_fast_forward(&self, branch: &str) -> Result<bool, VersionControlError> {
348 let target_branch = self
349 .branches
350 .get(branch)
351 .ok_or_else(|| VersionControlError::BranchNotFound(branch.to_string()))?;
352
353 let current_head = self.head.ok_or(VersionControlError::NoParentCommit)?;
354
355 let history = self.get_history(&target_branch.head, None);
357
358 Ok(history.iter().any(|c| c.id == current_head))
359 }
360}
361
362impl Default for ModelRepository {
363 fn default() -> Self {
364 Self::new()
365 }
366}
367
368#[derive(Debug, Clone, Serialize, Deserialize)]
370pub struct ModelDiff {
371 pub added_layers: Vec<String>,
373
374 pub removed_layers: Vec<String>,
376
377 pub modified_layers: Vec<LayerDiff>,
379}
380
381#[derive(Debug, Clone, Serialize, Deserialize)]
383pub struct LayerDiff {
384 pub name: String,
386
387 pub shape_changed: bool,
389
390 pub old_shape: Vec<usize>,
392
393 pub new_shape: Vec<usize>,
395
396 pub l2_diff: f32,
398
399 pub max_diff: f32,
401}
402
/// Stateless namespace for comparing models given as `layer name -> flattened weights` maps.
pub struct ModelDiffer;
405
406impl ModelDiffer {
407 pub fn diff(
409 model_a: &HashMap<String, Vec<f32>>,
410 model_b: &HashMap<String, Vec<f32>>,
411 ) -> ModelDiff {
412 let mut added_layers = Vec::new();
413 let mut removed_layers = Vec::new();
414 let mut modified_layers = Vec::new();
415
416 let keys_a: HashSet<_> = model_a.keys().collect();
417 let keys_b: HashSet<_> = model_b.keys().collect();
418
419 for key in keys_b.difference(&keys_a) {
421 added_layers.push((*key).clone());
422 }
423
424 for key in keys_a.difference(&keys_b) {
426 removed_layers.push((*key).clone());
427 }
428
429 for key in keys_a.intersection(&keys_b) {
431 let values_a = &model_a[*key];
432 let values_b = &model_b[*key];
433
434 let shape_changed = values_a.len() != values_b.len();
435
436 if shape_changed || !values_equal(values_a, values_b) {
437 let (l2_diff, max_diff) = compute_diffs(values_a, values_b);
438
439 modified_layers.push(LayerDiff {
440 name: (*key).clone(),
441 shape_changed,
442 old_shape: vec![values_a.len()],
443 new_shape: vec![values_b.len()],
444 l2_diff,
445 max_diff,
446 });
447 }
448 }
449
450 ModelDiff {
451 added_layers,
452 removed_layers,
453 modified_layers,
454 }
455 }
456
457 pub fn has_changes(diff: &ModelDiff) -> bool {
459 !diff.added_layers.is_empty()
460 || !diff.removed_layers.is_empty()
461 || !diff.modified_layers.is_empty()
462 }
463}
464
/// Returns `true` when `a` and `b` have the same length and every pair of elements differs
/// by less than `1e-6` in absolute value.
///
/// A NaN in either slice makes the comparison fail, because `NaN < 1e-6` is `false`.
fn values_equal(a: &[f32], b: &[f32]) -> bool {
    a.len() == b.len() && a.iter().zip(b.iter()).all(|(&x, &y)| (x - y).abs() < 1e-6)
}
473
/// Returns `(l2, max)` over the element-wise absolute differences of `a` and `b`.
///
/// Only the overlapping prefix of `min(a.len(), b.len())` elements is compared; trailing
/// elements of the longer slice are ignored. Empty input yields `(0.0, 0.0)`.
fn compute_diffs(a: &[f32], b: &[f32]) -> (f32, f32) {
    let (sum_of_squares, peak) = a
        .iter()
        .zip(b)
        .map(|(x, y)| (x - y).abs())
        .fold((0.0_f32, 0.0_f32), |(sum, peak), delta| {
            (sum + delta * delta, peak.max(delta))
        });

    (sum_of_squares.sqrt(), peak)
}
491
492fn serialize_cid_vec<S>(cids: &[Cid], serializer: S) -> Result<S::Ok, S::Error>
494where
495 S: serde::Serializer,
496{
497 use serde::Serialize;
498 let strings: Vec<String> = cids.iter().map(|c| c.to_string()).collect();
499 strings.serialize(serializer)
500}
501
502fn deserialize_cid_vec<'de, D>(deserializer: D) -> Result<Vec<Cid>, D::Error>
503where
504 D: serde::Deserializer<'de>,
505{
506 use serde::Deserialize;
507 let strings = Vec::<String>::deserialize(deserializer)?;
508 strings
509 .into_iter()
510 .map(|s| s.parse().map_err(serde::de::Error::custom))
511 .collect()
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// Root commit shared by the repository tests.
    fn root_commit() -> ModelCommit {
        ModelCommit::new(
            Cid::default(),
            Vec::new(),
            Cid::default(),
            "Initial commit".to_string(),
            "test@example.com".to_string(),
        )
    }

    /// A repository initialised with `root_commit()`.
    fn initialised_repo() -> ModelRepository {
        let mut repo = ModelRepository::new();
        repo.init(root_commit()).unwrap();
        repo
    }

    #[test]
    fn test_model_commit() {
        let commit = root_commit();

        assert!(commit.is_initial());
        assert!(!commit.is_merge());
    }

    #[test]
    fn test_repository_init() {
        let repo = initialised_repo();

        assert_eq!(repo.current_branch(), Some("main"));
        assert!(repo.head_commit().is_some());
    }

    #[test]
    fn test_branch_creation() {
        let mut repo = initialised_repo();
        repo.create_branch("develop".to_string(), None).unwrap();

        assert_eq!(repo.list_branches().len(), 2);
    }

    #[test]
    fn test_checkout() {
        let mut repo = initialised_repo();
        repo.create_branch("develop".to_string(), None).unwrap();

        repo.checkout("develop").unwrap();

        assert_eq!(repo.current_branch(), Some("develop"));
    }

    #[test]
    fn test_model_diff() {
        let mut model_a: HashMap<String, Vec<f32>> = HashMap::new();
        model_a.insert("layer1".to_string(), vec![1.0, 2.0, 3.0]);
        model_a.insert("layer2".to_string(), vec![4.0, 5.0]);

        let mut model_b: HashMap<String, Vec<f32>> = HashMap::new();
        model_b.insert("layer1".to_string(), vec![1.1, 2.1, 3.1]);
        model_b.insert("layer3".to_string(), vec![6.0, 7.0]);

        let diff = ModelDiffer::diff(&model_a, &model_b);

        assert_eq!(diff.added_layers.len(), 1);
        assert_eq!(diff.removed_layers.len(), 1);
        assert_eq!(diff.modified_layers.len(), 1);
        assert!(ModelDiffer::has_changes(&diff));
    }

    #[test]
    fn test_layer_diff() {
        let old: [f32; 3] = [1.0, 2.0, 3.0];
        let new: [f32; 3] = [1.5, 2.5, 3.5];

        let (l2_diff, max_diff) = compute_diffs(&old, &new);

        assert!(l2_diff > 0.0);
        assert!((max_diff - 0.5).abs() < 1e-6);
    }
}
617}