ipfrs_semantic/
dynamic.rs

1//! Dynamic embedding updates for evolving embedding spaces
2//!
3//! This module provides mechanisms for:
4//! - Online embedding updates
5//! - Version migration
6//! - Incremental fine-tuning
7//! - Embedding space evolution tracking
8
9use crate::VectorIndex;
10use ipfrs_core::{Cid, Error, Result};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14
15/// Version of an embedding model
16#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
17pub struct ModelVersion {
18    /// Major version
19    pub major: u32,
20    /// Minor version
21    pub minor: u32,
22    /// Patch version
23    pub patch: u32,
24    /// Optional tag (e.g., "alpha", "beta")
25    pub tag: Option<String>,
26}
27
28impl ModelVersion {
29    /// Create a new model version
30    pub fn new(major: u32, minor: u32, patch: u32) -> Self {
31        Self {
32            major,
33            minor,
34            patch,
35            tag: None,
36        }
37    }
38
39    /// Create a version with a tag
40    pub fn with_tag(mut self, tag: String) -> Self {
41        self.tag = Some(tag);
42        self
43    }
44
45    /// Check if this version is compatible with another (same major version)
46    pub fn is_compatible_with(&self, other: &ModelVersion) -> bool {
47        self.major == other.major
48    }
49}
50
51impl std::fmt::Display for ModelVersion {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        write!(f, "{}.{}.{}", self.major, self.minor, self.patch)?;
54        if let Some(tag) = &self.tag {
55            write!(f, "-{}", tag)?;
56        }
57        Ok(())
58    }
59}
60
61/// Embedding transformation for migrating between versions
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct EmbeddingTransform {
64    /// Source version
65    pub from_version: ModelVersion,
66    /// Target version
67    pub to_version: ModelVersion,
68    /// Transformation matrix (if dimensions change)
69    pub transform_matrix: Option<Vec<Vec<f32>>>,
70    /// Bias vector
71    pub bias: Option<Vec<f32>>,
72}
73
74impl EmbeddingTransform {
75    /// Create an identity transform (no change)
76    pub fn identity(version: ModelVersion) -> Self {
77        Self {
78            from_version: version.clone(),
79            to_version: version,
80            transform_matrix: None,
81            bias: None,
82        }
83    }
84
85    /// Create a linear transformation
86    pub fn linear(
87        from_version: ModelVersion,
88        to_version: ModelVersion,
89        matrix: Vec<Vec<f32>>,
90    ) -> Self {
91        Self {
92            from_version,
93            to_version,
94            transform_matrix: Some(matrix),
95            bias: None,
96        }
97    }
98
99    /// Apply transformation to an embedding
100    pub fn apply(&self, embedding: &[f32]) -> Vec<f32> {
101        let mut result = embedding.to_vec();
102
103        // Apply matrix transformation if present
104        if let Some(matrix) = &self.transform_matrix {
105            let out_dim = matrix[0].len();
106            let mut transformed = vec![0.0; out_dim];
107
108            for (i, row) in matrix.iter().enumerate() {
109                if i >= embedding.len() {
110                    break;
111                }
112                for (j, &val) in row.iter().enumerate() {
113                    transformed[j] += embedding[i] * val;
114                }
115            }
116
117            result = transformed;
118        }
119
120        // Apply bias if present
121        if let Some(bias) = &self.bias {
122            for (i, &b) in bias.iter().enumerate() {
123                if i < result.len() {
124                    result[i] += b;
125                }
126            }
127        }
128
129        result
130    }
131}
132
133/// Dynamic index that supports multiple embedding versions
134pub struct DynamicIndex {
135    /// Indices for each version
136    indices: Arc<RwLock<HashMap<ModelVersion, VectorIndex>>>,
137    /// Current active version
138    active_version: Arc<RwLock<ModelVersion>>,
139    /// Transformations between versions
140    transforms: Arc<RwLock<HashMap<(ModelVersion, ModelVersion), EmbeddingTransform>>>,
141    /// Embedding dimension
142    dimension: usize,
143}
144
145impl DynamicIndex {
146    /// Create a new dynamic index
147    pub fn new(initial_version: ModelVersion, dimension: usize) -> Result<Self> {
148        let mut indices = HashMap::new();
149        let index = VectorIndex::with_defaults(dimension)?;
150        indices.insert(initial_version.clone(), index);
151
152        Ok(Self {
153            indices: Arc::new(RwLock::new(indices)),
154            active_version: Arc::new(RwLock::new(initial_version)),
155            transforms: Arc::new(RwLock::new(HashMap::new())),
156            dimension,
157        })
158    }
159
160    /// Get the current active version
161    pub fn active_version(&self) -> ModelVersion {
162        self.active_version.read().unwrap().clone()
163    }
164
165    /// Add a new version with optional transform from previous version
166    pub fn add_version(
167        &self,
168        version: ModelVersion,
169        transform: Option<EmbeddingTransform>,
170    ) -> Result<()> {
171        let mut indices = self.indices.write().unwrap();
172
173        if indices.contains_key(&version) {
174            return Err(Error::InvalidInput(format!(
175                "Version {} already exists",
176                version
177            )));
178        }
179
180        let index = VectorIndex::with_defaults(self.dimension)?;
181        indices.insert(version.clone(), index);
182
183        // Add transform if provided
184        if let Some(t) = transform {
185            let mut transforms = self.transforms.write().unwrap();
186            transforms.insert((t.from_version.clone(), t.to_version.clone()), t);
187        }
188
189        Ok(())
190    }
191
192    /// Set the active version
193    pub fn set_active_version(&self, version: ModelVersion) -> Result<()> {
194        let indices = self.indices.read().unwrap();
195
196        if !indices.contains_key(&version) {
197            return Err(Error::InvalidInput(format!(
198                "Version {} does not exist",
199                version
200            )));
201        }
202
203        let mut active = self.active_version.write().unwrap();
204        *active = version;
205
206        Ok(())
207    }
208
209    /// Insert an embedding for a specific version
210    pub fn insert(
211        &self,
212        cid: &Cid,
213        embedding: &[f32],
214        version: Option<ModelVersion>,
215    ) -> Result<()> {
216        let version = version.unwrap_or_else(|| self.active_version());
217
218        let mut indices = self.indices.write().unwrap();
219        let index = indices
220            .get_mut(&version)
221            .ok_or_else(|| Error::InvalidInput(format!("Version {} does not exist", version)))?;
222
223        index.insert(cid, embedding)?;
224        Ok(())
225    }
226
227    /// Update an existing embedding
228    pub fn update(
229        &self,
230        cid: &Cid,
231        new_embedding: &[f32],
232        version: Option<ModelVersion>,
233    ) -> Result<()> {
234        let version = version.unwrap_or_else(|| self.active_version());
235
236        let mut indices = self.indices.write().unwrap();
237        let index = indices
238            .get_mut(&version)
239            .ok_or_else(|| Error::InvalidInput(format!("Version {} does not exist", version)))?;
240
241        // First delete the old embedding
242        index.delete(cid)?;
243        // Then insert the new one
244        index.insert(cid, new_embedding)?;
245
246        Ok(())
247    }
248
249    /// Migrate embeddings from one version to another
250    pub fn migrate(&self, from: &ModelVersion, to: &ModelVersion) -> Result<usize> {
251        let transforms = self.transforms.read().unwrap();
252        let transform = transforms
253            .get(&(from.clone(), to.clone()))
254            .ok_or_else(|| Error::InvalidInput(format!("No transform from {} to {}", from, to)))?
255            .clone();
256        drop(transforms);
257
258        // Get all embeddings from source version
259        let indices = self.indices.read().unwrap();
260        let source_index = indices.get(from).ok_or_else(|| {
261            Error::InvalidInput(format!("Source version {} does not exist", from))
262        })?;
263
264        // Ensure target version exists
265        if !indices.contains_key(to) {
266            return Err(Error::InvalidInput(format!(
267                "Target version {} does not exist",
268                to
269            )));
270        }
271
272        // Get all embeddings from source index
273        let embeddings = source_index.get_all_embeddings();
274        drop(indices);
275
276        // Apply transformation and insert into target index
277        let mut migrated_count = 0;
278        for (cid, embedding) in embeddings {
279            // Apply transformation
280            let transformed = transform.apply(&embedding);
281
282            // Insert into target index
283            let mut indices = self.indices.write().unwrap();
284            if let Some(target_index) = indices.get_mut(to) {
285                // Only insert if not already present
286                if !target_index.contains(&cid) {
287                    target_index.insert(&cid, &transformed)?;
288                    migrated_count += 1;
289                }
290            }
291            drop(indices);
292        }
293
294        Ok(migrated_count)
295    }
296
297    /// Get statistics for all versions
298    pub fn version_stats(&self) -> HashMap<ModelVersion, VersionStats> {
299        let indices = self.indices.read().unwrap();
300
301        indices
302            .iter()
303            .map(|(version, index)| {
304                let stats = VersionStats {
305                    version: version.clone(),
306                    num_embeddings: index.len(),
307                    is_active: version == &self.active_version(),
308                };
309                (version.clone(), stats)
310            })
311            .collect()
312    }
313}
314
315/// Statistics for a specific version
316#[derive(Debug, Clone, Serialize, Deserialize)]
317pub struct VersionStats {
318    /// Version identifier
319    pub version: ModelVersion,
320    /// Number of embeddings in this version
321    pub num_embeddings: usize,
322    /// Whether this is the active version
323    pub is_active: bool,
324}
325
326/// Online updater for incremental fine-tuning
327pub struct OnlineUpdater {
328    /// Learning rate for updates
329    learning_rate: f32,
330    /// Momentum factor
331    momentum: f32,
332    /// Velocity (for momentum)
333    velocity: Arc<RwLock<HashMap<Cid, Vec<f32>>>>,
334}
335
336impl OnlineUpdater {
337    /// Create a new online updater
338    pub fn new(learning_rate: f32, momentum: f32) -> Self {
339        Self {
340            learning_rate,
341            momentum,
342            velocity: Arc::new(RwLock::new(HashMap::new())),
343        }
344    }
345
346    /// Update an embedding with a gradient
347    pub fn update(&self, cid: &Cid, embedding: &[f32], gradient: &[f32]) -> Vec<f32> {
348        let mut velocity = self.velocity.write().unwrap();
349
350        // Get or initialize velocity for this CID
351        let v = velocity
352            .entry(*cid)
353            .or_insert_with(|| vec![0.0; embedding.len()]);
354
355        // Update velocity with momentum
356        for i in 0..embedding.len().min(gradient.len()) {
357            v[i] = self.momentum * v[i] - self.learning_rate * gradient[i];
358        }
359
360        // Apply velocity to embedding
361        embedding
362            .iter()
363            .zip(v.iter())
364            .map(|(&e, &vel)| e + vel)
365            .collect()
366    }
367
368    /// Clear velocity history
369    pub fn reset(&self) {
370        let mut velocity = self.velocity.write().unwrap();
371        velocity.clear();
372    }
373
374    /// Get statistics
375    pub fn stats(&self) -> OnlineUpdaterStats {
376        let velocity = self.velocity.read().unwrap();
377
378        OnlineUpdaterStats {
379            learning_rate: self.learning_rate,
380            momentum: self.momentum,
381            num_tracked: velocity.len(),
382        }
383    }
384}
385
386/// Statistics for online updater
387#[derive(Debug, Clone, Serialize, Deserialize)]
388pub struct OnlineUpdaterStats {
389    /// Learning rate
390    pub learning_rate: f32,
391    /// Momentum
392    pub momentum: f32,
393    /// Number of tracked embeddings
394    pub num_tracked: usize,
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use multihash_codetable::{Code, MultihashDigest};

    /// Build a deterministic CIDv1 (codec 0x55, SHA2-256 digest) from `data`.
    fn cid_for(data: &str) -> Cid {
        let digest = Code::Sha2_256.digest(data.as_bytes());
        Cid::new_v1(0x55, digest)
    }

    #[test]
    fn test_model_version() {
        let base = ModelVersion::new(1, 0, 0);
        let minor_bump = ModelVersion::new(1, 1, 0);
        let major_bump = ModelVersion::new(2, 0, 0);

        // Same major => compatible; different major => not.
        assert!(base.is_compatible_with(&minor_bump));
        assert!(!base.is_compatible_with(&major_bump));

        assert_eq!(base.to_string(), "1.0.0");
        let tagged = base.with_tag("alpha".into());
        assert_eq!(tagged.to_string(), "1.0.0-alpha");
    }

    #[test]
    fn test_embedding_transform() {
        let from = ModelVersion::new(1, 0, 0);
        let to = ModelVersion::new(1, 1, 0);

        // Identity leaves the vector untouched.
        let input = vec![1.0, 2.0, 3.0];
        let output = EmbeddingTransform::identity(from.clone()).apply(&input);
        assert_eq!(output, input);

        // A diagonal 2x2 matrix doubles the second component.
        let diagonal = vec![vec![1.0, 0.0], vec![0.0, 2.0]];
        let scaled = EmbeddingTransform::linear(from, to, diagonal).apply(&[1.0, 2.0]);
        assert_eq!(scaled, vec![1.0, 4.0]);
    }

    #[test]
    fn test_dynamic_index_creation() {
        let initial = ModelVersion::new(1, 0, 0);
        let dynamic = DynamicIndex::new(initial.clone(), 128).unwrap();

        assert_eq!(dynamic.active_version(), initial);
    }

    #[test]
    fn test_add_version() {
        let first = ModelVersion::new(1, 0, 0);
        let second = ModelVersion::new(1, 1, 0);

        let dynamic = DynamicIndex::new(first.clone(), 128).unwrap();
        dynamic.add_version(second.clone(), None).unwrap();

        let stats = dynamic.version_stats();
        assert_eq!(stats.len(), 2);
        assert!([&first, &second].iter().all(|v| stats.contains_key(*v)));
    }

    #[test]
    fn test_set_active_version() {
        let first = ModelVersion::new(1, 0, 0);
        let second = ModelVersion::new(1, 1, 0);

        let dynamic = DynamicIndex::new(first.clone(), 128).unwrap();
        dynamic.add_version(second.clone(), None).unwrap();

        // Adding a version does not switch to it.
        assert_eq!(dynamic.active_version(), first);

        dynamic.set_active_version(second.clone()).unwrap();
        assert_eq!(dynamic.active_version(), second);
    }

    #[test]
    fn test_insert_and_update() {
        let dynamic = DynamicIndex::new(ModelVersion::new(1, 0, 0), 3).unwrap();
        let cid = cid_for("test_embedding");
        let only_count = |d: &DynamicIndex| {
            d.version_stats()
                .values()
                .next()
                .unwrap()
                .num_embeddings
        };

        dynamic.insert(&cid, &[1.0, 2.0, 3.0], None).unwrap();
        assert_eq!(only_count(&dynamic), 1);

        // Replacing the vector must not change the count.
        dynamic.update(&cid, &[4.0, 5.0, 6.0], None).unwrap();
        assert_eq!(only_count(&dynamic), 1);
    }

    #[test]
    fn test_online_updater() {
        let updater = OnlineUpdater::new(0.1, 0.9);
        let cid = cid_for("test");

        let updated = updater.update(&cid, &[1.0, 1.0, 1.0], &[0.1, 0.1, 0.1]);

        // A positive gradient with a positive learning rate lowers the value.
        assert!(updated[0] < 1.0);
        assert_eq!(updated.len(), 3);
        assert_eq!(updater.stats().num_tracked, 1);
    }

    #[test]
    fn test_updater_momentum() {
        let updater = OnlineUpdater::new(0.1, 0.9);
        let cid = cid_for("test");
        let gradient = [0.1_f32];

        let start = [1.0_f32];
        let after_one = updater.update(&cid, &start, &gradient);
        let after_two = updater.update(&cid, &after_one, &gradient);

        // Momentum makes the second step larger than the first.
        let first_step = (start[0] - after_one[0]).abs();
        let second_step = (after_one[0] - after_two[0]).abs();
        assert!(second_step > first_step);
    }
}