ipfrs_tensorlogic/
version_control.rs

//! Model version control system for ML models
//!
//! This module provides Git-like version control for ML models:
//! - Commit/checkout operations
//! - Branching and fast-forward merging
//! - Layer-level diff operations for models
//! - Model history tracking

9use ipfrs_core::Cid;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet};
12use thiserror::Error;
13
14/// Errors that can occur during version control operations
15#[derive(Debug, Error)]
16pub enum VersionControlError {
17    #[error("Commit not found: {0}")]
18    CommitNotFound(String),
19
20    #[error("Branch not found: {0}")]
21    BranchNotFound(String),
22
23    #[error("Branch already exists: {0}")]
24    BranchAlreadyExists(String),
25
26    #[error("Merge conflict in layer: {0}")]
27    MergeConflict(String),
28
29    #[error("Invalid commit ID: {0}")]
30    InvalidCommitId(String),
31
32    #[error("Cannot merge: {0}")]
33    CannotMerge(String),
34
35    #[error("Detached HEAD state")]
36    DetachedHead,
37
38    #[error("No parent commit")]
39    NoParentCommit,
40}
41
42/// A commit in the model version history
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ModelCommit {
45    /// Unique commit ID (CID of the commit itself)
46    #[serde(serialize_with = "crate::serialize_cid")]
47    #[serde(deserialize_with = "crate::deserialize_cid")]
48    pub id: Cid,
49
50    /// Parent commit ID(s) (empty for initial commit, multiple for merges)
51    #[serde(serialize_with = "serialize_cid_vec")]
52    #[serde(deserialize_with = "deserialize_cid_vec")]
53    pub parents: Vec<Cid>,
54
55    /// Model CID (points to the actual model data)
56    #[serde(serialize_with = "crate::serialize_cid")]
57    #[serde(deserialize_with = "crate::deserialize_cid")]
58    pub model: Cid,
59
60    /// Commit message
61    pub message: String,
62
63    /// Author
64    pub author: String,
65
66    /// Timestamp
67    pub timestamp: i64,
68
69    /// Metadata (hyperparameters, training info, etc.)
70    pub metadata: HashMap<String, String>,
71}
72
73impl ModelCommit {
74    /// Create a new commit
75    pub fn new(id: Cid, parents: Vec<Cid>, model: Cid, message: String, author: String) -> Self {
76        Self {
77            id,
78            parents,
79            model,
80            message,
81            author,
82            timestamp: chrono::Utc::now().timestamp(),
83            metadata: HashMap::new(),
84        }
85    }
86
87    /// Add metadata to the commit
88    pub fn with_metadata(mut self, key: String, value: String) -> Self {
89        self.metadata.insert(key, value);
90        self
91    }
92
93    /// Check if this is a merge commit
94    pub fn is_merge(&self) -> bool {
95        self.parents.len() > 1
96    }
97
98    /// Check if this is an initial commit
99    pub fn is_initial(&self) -> bool {
100        self.parents.is_empty()
101    }
102}
103
104/// A branch in the version control system
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct Branch {
107    /// Branch name
108    pub name: String,
109
110    /// Current commit CID
111    #[serde(serialize_with = "crate::serialize_cid")]
112    #[serde(deserialize_with = "crate::deserialize_cid")]
113    pub head: Cid,
114
115    /// Branch description
116    pub description: Option<String>,
117}
118
119impl Branch {
120    /// Create a new branch
121    pub fn new(name: String, head: Cid) -> Self {
122        Self {
123            name,
124            head,
125            description: None,
126        }
127    }
128
129    /// Add description to the branch
130    pub fn with_description(mut self, description: String) -> Self {
131        self.description = Some(description);
132        self
133    }
134}
135
136/// Model version control repository
137#[derive(Debug, Clone)]
138pub struct ModelRepository {
139    /// All commits (commit ID -> commit)
140    commits: HashMap<String, ModelCommit>,
141
142    /// All branches (branch name -> branch)
143    branches: HashMap<String, Branch>,
144
145    /// Current branch name (None if detached HEAD)
146    current_branch: Option<String>,
147
148    /// Current HEAD commit
149    head: Option<Cid>,
150}
151
152impl ModelRepository {
153    /// Create a new empty repository
154    pub fn new() -> Self {
155        Self {
156            commits: HashMap::new(),
157            branches: HashMap::new(),
158            current_branch: None,
159            head: None,
160        }
161    }
162
163    /// Initialize repository with an initial commit
164    pub fn init(&mut self, initial_commit: ModelCommit) -> Result<(), VersionControlError> {
165        let commit_id = initial_commit.id.to_string();
166        let commit_cid = initial_commit.id;
167
168        self.commits.insert(commit_id, initial_commit);
169
170        // Create main branch
171        let main_branch = Branch::new("main".to_string(), commit_cid);
172        self.branches.insert("main".to_string(), main_branch);
173        self.current_branch = Some("main".to_string());
174        self.head = Some(commit_cid);
175
176        Ok(())
177    }
178
179    /// Create a commit
180    pub fn commit(
181        &mut self,
182        model: Cid,
183        message: String,
184        author: String,
185    ) -> Result<ModelCommit, VersionControlError> {
186        let parents = if let Some(head) = self.head {
187            vec![head]
188        } else {
189            vec![]
190        };
191
192        // In a real implementation, we would compute the CID from the commit content
193        // For now, use a placeholder
194        let commit_id = Cid::default();
195
196        let commit = ModelCommit::new(commit_id, parents, model, message, author);
197
198        self.commits.insert(commit_id.to_string(), commit.clone());
199        self.head = Some(commit_id);
200
201        // Update current branch if we're on one
202        if let Some(branch_name) = &self.current_branch {
203            if let Some(branch) = self.branches.get_mut(branch_name) {
204                branch.head = commit_id;
205            }
206        }
207
208        Ok(commit)
209    }
210
211    /// Checkout to a specific commit or branch
212    pub fn checkout(&mut self, target: &str) -> Result<(), VersionControlError> {
213        // Try to interpret target as a branch name first
214        if let Some(branch) = self.branches.get(target) {
215            self.current_branch = Some(target.to_string());
216            self.head = Some(branch.head);
217            return Ok(());
218        }
219
220        // Try to interpret target as a commit ID
221        if let Some(commit) = self.commits.get(target) {
222            self.current_branch = None; // Detached HEAD
223            self.head = Some(commit.id);
224            return Ok(());
225        }
226
227        Err(VersionControlError::CommitNotFound(target.to_string()))
228    }
229
230    /// Create a new branch
231    pub fn create_branch(
232        &mut self,
233        name: String,
234        start_point: Option<Cid>,
235    ) -> Result<(), VersionControlError> {
236        if self.branches.contains_key(&name) {
237            return Err(VersionControlError::BranchAlreadyExists(name));
238        }
239
240        let head = start_point
241            .or(self.head)
242            .ok_or(VersionControlError::NoParentCommit)?;
243
244        let branch = Branch::new(name.clone(), head);
245        self.branches.insert(name, branch);
246
247        Ok(())
248    }
249
250    /// Delete a branch
251    pub fn delete_branch(&mut self, name: &str) -> Result<(), VersionControlError> {
252        if !self.branches.contains_key(name) {
253            return Err(VersionControlError::BranchNotFound(name.to_string()));
254        }
255
256        if let Some(current) = &self.current_branch {
257            if current == name {
258                return Err(VersionControlError::CannotMerge(
259                    "Cannot delete current branch".to_string(),
260                ));
261            }
262        }
263
264        self.branches.remove(name);
265        Ok(())
266    }
267
268    /// List all branches
269    pub fn list_branches(&self) -> Vec<&Branch> {
270        self.branches.values().collect()
271    }
272
273    /// Get current branch name
274    pub fn current_branch(&self) -> Option<&str> {
275        self.current_branch.as_deref()
276    }
277
278    /// Get current HEAD commit
279    pub fn head_commit(&self) -> Option<&ModelCommit> {
280        self.head
281            .as_ref()
282            .and_then(|cid| self.commits.get(&cid.to_string()))
283    }
284
285    /// Get a commit by ID
286    pub fn get_commit(&self, commit_id: &str) -> Option<&ModelCommit> {
287        self.commits.get(commit_id)
288    }
289
290    /// Get commit history from a starting commit
291    pub fn get_history(&self, start: &Cid, max_count: Option<usize>) -> Vec<&ModelCommit> {
292        let mut history = Vec::new();
293        let mut visited = HashSet::new();
294        let mut queue = vec![start];
295
296        while let Some(cid) = queue.pop() {
297            if visited.contains(cid) {
298                continue;
299            }
300
301            if let Some(max) = max_count {
302                if history.len() >= max {
303                    break;
304                }
305            }
306
307            if let Some(commit) = self.commits.get(&cid.to_string()) {
308                visited.insert(*cid);
309                history.push(commit);
310
311                // Add parents to queue
312                for parent in &commit.parents {
313                    if !visited.contains(parent) {
314                        queue.push(parent);
315                    }
316                }
317            }
318        }
319
320        history
321    }
322
323    /// Perform a fast-forward merge
324    pub fn merge_fast_forward(&mut self, branch: &str) -> Result<(), VersionControlError> {
325        let target_branch = self
326            .branches
327            .get(branch)
328            .ok_or_else(|| VersionControlError::BranchNotFound(branch.to_string()))?;
329
330        let target_head = target_branch.head;
331
332        // Update current branch
333        if let Some(current_name) = &self.current_branch {
334            if let Some(current_branch) = self.branches.get_mut(current_name) {
335                current_branch.head = target_head;
336            }
337        } else {
338            return Err(VersionControlError::DetachedHead);
339        }
340
341        self.head = Some(target_head);
342
343        Ok(())
344    }
345
346    /// Check if fast-forward merge is possible
347    pub fn can_fast_forward(&self, branch: &str) -> Result<bool, VersionControlError> {
348        let target_branch = self
349            .branches
350            .get(branch)
351            .ok_or_else(|| VersionControlError::BranchNotFound(branch.to_string()))?;
352
353        let current_head = self.head.ok_or(VersionControlError::NoParentCommit)?;
354
355        // Check if current head is an ancestor of target head
356        let history = self.get_history(&target_branch.head, None);
357
358        Ok(history.iter().any(|c| c.id == current_head))
359    }
360}
361
362impl Default for ModelRepository {
363    fn default() -> Self {
364        Self::new()
365    }
366}
367
368/// Model difference representation
369#[derive(Debug, Clone, Serialize, Deserialize)]
370pub struct ModelDiff {
371    /// Layers that were added
372    pub added_layers: Vec<String>,
373
374    /// Layers that were removed
375    pub removed_layers: Vec<String>,
376
377    /// Layers that were modified
378    pub modified_layers: Vec<LayerDiff>,
379}
380
381/// Difference in a single layer
382#[derive(Debug, Clone, Serialize, Deserialize)]
383pub struct LayerDiff {
384    /// Layer name
385    pub name: String,
386
387    /// Shape changed
388    pub shape_changed: bool,
389
390    /// Previous shape
391    pub old_shape: Vec<usize>,
392
393    /// New shape
394    pub new_shape: Vec<usize>,
395
396    /// L2 norm of the difference
397    pub l2_diff: f32,
398
399    /// Maximum absolute difference
400    pub max_diff: f32,
401}
402
/// Stateless helper that computes layer-level differences between models.
///
/// All functionality lives in associated functions such as
/// `ModelDiffer::diff`; the type itself carries no data, so it is freely
/// `Copy`, `Default`, and `Debug`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ModelDiffer;
405
406impl ModelDiffer {
407    /// Compute diff between two models
408    pub fn diff(
409        model_a: &HashMap<String, Vec<f32>>,
410        model_b: &HashMap<String, Vec<f32>>,
411    ) -> ModelDiff {
412        let mut added_layers = Vec::new();
413        let mut removed_layers = Vec::new();
414        let mut modified_layers = Vec::new();
415
416        let keys_a: HashSet<_> = model_a.keys().collect();
417        let keys_b: HashSet<_> = model_b.keys().collect();
418
419        // Find added layers
420        for key in keys_b.difference(&keys_a) {
421            added_layers.push((*key).clone());
422        }
423
424        // Find removed layers
425        for key in keys_a.difference(&keys_b) {
426            removed_layers.push((*key).clone());
427        }
428
429        // Find modified layers
430        for key in keys_a.intersection(&keys_b) {
431            let values_a = &model_a[*key];
432            let values_b = &model_b[*key];
433
434            let shape_changed = values_a.len() != values_b.len();
435
436            if shape_changed || !values_equal(values_a, values_b) {
437                let (l2_diff, max_diff) = compute_diffs(values_a, values_b);
438
439                modified_layers.push(LayerDiff {
440                    name: (*key).clone(),
441                    shape_changed,
442                    old_shape: vec![values_a.len()],
443                    new_shape: vec![values_b.len()],
444                    l2_diff,
445                    max_diff,
446                });
447            }
448        }
449
450        ModelDiff {
451            added_layers,
452            removed_layers,
453            modified_layers,
454        }
455    }
456
457    /// Check if diff has any changes
458    pub fn has_changes(diff: &ModelDiff) -> bool {
459        !diff.added_layers.is_empty()
460            || !diff.removed_layers.is_empty()
461            || !diff.modified_layers.is_empty()
462    }
463}
464
/// `true` when `a` and `b` have the same length and every pair of elements
/// differs by strictly less than `1e-6` (absolute tolerance). Any NaN makes the
/// comparison fail.
fn values_equal(a: &[f32], b: &[f32]) -> bool {
    const TOLERANCE: f32 = 1e-6;

    a.len() == b.len()
        && a
            .iter()
            .zip(b.iter())
            .all(|(&lhs, &rhs)| (lhs - rhs).abs() < TOLERANCE)
}
473
/// Return `(l2, max)` for the element-wise difference of `a` and `b`.
///
/// `l2` is the Euclidean norm of the differences and `max` the largest absolute
/// difference. Only the overlapping prefix (`min(a.len(), b.len())` elements) is
/// compared; trailing elements of the longer slice are ignored. Both values are
/// `0.0` when the overlap is empty. A NaN difference makes `l2` NaN but is
/// skipped by `max` (`f32::max` ignores NaN).
fn compute_diffs(a: &[f32], b: &[f32]) -> (f32, f32) {
    let (sum_of_squares, largest) =
        a.iter()
            .zip(b)
            .fold((0.0f32, 0.0f32), |(acc, largest), (x, y)| {
                let delta = (x - y).abs();
                (acc + delta * delta, largest.max(delta))
            });

    (sum_of_squares.sqrt(), largest)
}
491
492// Helper functions for serializing/deserializing Vec<Cid>
493fn serialize_cid_vec<S>(cids: &[Cid], serializer: S) -> Result<S::Ok, S::Error>
494where
495    S: serde::Serializer,
496{
497    use serde::Serialize;
498    let strings: Vec<String> = cids.iter().map(|c| c.to_string()).collect();
499    strings.serialize(serializer)
500}
501
502fn deserialize_cid_vec<'de, D>(deserializer: D) -> Result<Vec<Cid>, D::Error>
503where
504    D: serde::Deserializer<'de>,
505{
506    use serde::Deserialize;
507    let strings = Vec::<String>::deserialize(deserializer)?;
508    strings
509        .into_iter()
510        .map(|s| s.parse().map_err(serde::de::Error::custom))
511        .collect()
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// A root commit (no parents) using placeholder CIDs.
    fn root_commit() -> ModelCommit {
        ModelCommit::new(
            Cid::default(),
            vec![],
            Cid::default(),
            "Initial commit".to_string(),
            "test@example.com".to_string(),
        )
    }

    /// A repository initialized with `root_commit()`.
    fn initialized_repo() -> ModelRepository {
        let mut repo = ModelRepository::new();
        repo.init(root_commit()).unwrap();
        repo
    }

    /// Build a layer map from `(name, values)` pairs.
    fn layers(entries: Vec<(&str, Vec<f32>)>) -> HashMap<String, Vec<f32>> {
        entries
            .into_iter()
            .map(|(name, values)| (name.to_string(), values))
            .collect()
    }

    #[test]
    fn test_model_commit() {
        let commit = root_commit();

        assert!(commit.is_initial());
        assert!(!commit.is_merge());
    }

    #[test]
    fn test_repository_init() {
        let repo = initialized_repo();

        assert_eq!(repo.current_branch(), Some("main"));
        assert!(repo.head_commit().is_some());
    }

    #[test]
    fn test_branch_creation() {
        let mut repo = initialized_repo();
        repo.create_branch("develop".to_string(), None).unwrap();

        assert_eq!(repo.list_branches().len(), 2);
    }

    #[test]
    fn test_checkout() {
        let mut repo = initialized_repo();
        repo.create_branch("develop".to_string(), None).unwrap();
        repo.checkout("develop").unwrap();

        assert_eq!(repo.current_branch(), Some("develop"));
    }

    #[test]
    fn test_model_diff() {
        let model_a = layers(vec![
            ("layer1", vec![1.0, 2.0, 3.0]),
            ("layer2", vec![4.0, 5.0]),
        ]);
        let model_b = layers(vec![
            ("layer1", vec![1.1, 2.1, 3.1]),
            ("layer3", vec![6.0, 7.0]),
        ]);

        let diff = ModelDiffer::diff(&model_a, &model_b);

        assert_eq!(diff.added_layers.len(), 1);
        assert_eq!(diff.removed_layers.len(), 1);
        assert_eq!(diff.modified_layers.len(), 1);
        assert!(ModelDiffer::has_changes(&diff));
    }

    #[test]
    fn test_layer_diff() {
        let (l2_diff, max_diff) = compute_diffs(&[1.0, 2.0, 3.0], &[1.5, 2.5, 3.5]);

        assert!(l2_diff > 0.0);
        assert!((max_diff - 0.5).abs() < 1e-6);
    }
}