1use crate::VectorIndex;
10use ipfrs_core::{Cid, Error, Result};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14
15#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
17pub struct ModelVersion {
18 pub major: u32,
20 pub minor: u32,
22 pub patch: u32,
24 pub tag: Option<String>,
26}
27
28impl ModelVersion {
29 pub fn new(major: u32, minor: u32, patch: u32) -> Self {
31 Self {
32 major,
33 minor,
34 patch,
35 tag: None,
36 }
37 }
38
39 pub fn with_tag(mut self, tag: String) -> Self {
41 self.tag = Some(tag);
42 self
43 }
44
45 pub fn is_compatible_with(&self, other: &ModelVersion) -> bool {
47 self.major == other.major
48 }
49}
50
51impl std::fmt::Display for ModelVersion {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 write!(f, "{}.{}.{}", self.major, self.minor, self.patch)?;
54 if let Some(tag) = &self.tag {
55 write!(f, "-{}", tag)?;
56 }
57 Ok(())
58 }
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct EmbeddingTransform {
64 pub from_version: ModelVersion,
66 pub to_version: ModelVersion,
68 pub transform_matrix: Option<Vec<Vec<f32>>>,
70 pub bias: Option<Vec<f32>>,
72}
73
74impl EmbeddingTransform {
75 pub fn identity(version: ModelVersion) -> Self {
77 Self {
78 from_version: version.clone(),
79 to_version: version,
80 transform_matrix: None,
81 bias: None,
82 }
83 }
84
85 pub fn linear(
87 from_version: ModelVersion,
88 to_version: ModelVersion,
89 matrix: Vec<Vec<f32>>,
90 ) -> Self {
91 Self {
92 from_version,
93 to_version,
94 transform_matrix: Some(matrix),
95 bias: None,
96 }
97 }
98
99 pub fn apply(&self, embedding: &[f32]) -> Vec<f32> {
101 let mut result = embedding.to_vec();
102
103 if let Some(matrix) = &self.transform_matrix {
105 let out_dim = matrix[0].len();
106 let mut transformed = vec![0.0; out_dim];
107
108 for (i, row) in matrix.iter().enumerate() {
109 if i >= embedding.len() {
110 break;
111 }
112 for (j, &val) in row.iter().enumerate() {
113 transformed[j] += embedding[i] * val;
114 }
115 }
116
117 result = transformed;
118 }
119
120 if let Some(bias) = &self.bias {
122 for (i, &b) in bias.iter().enumerate() {
123 if i < result.len() {
124 result[i] += b;
125 }
126 }
127 }
128
129 result
130 }
131}
132
133pub struct DynamicIndex {
135 indices: Arc<RwLock<HashMap<ModelVersion, VectorIndex>>>,
137 active_version: Arc<RwLock<ModelVersion>>,
139 transforms: Arc<RwLock<HashMap<(ModelVersion, ModelVersion), EmbeddingTransform>>>,
141 dimension: usize,
143}
144
145impl DynamicIndex {
146 pub fn new(initial_version: ModelVersion, dimension: usize) -> Result<Self> {
148 let mut indices = HashMap::new();
149 let index = VectorIndex::with_defaults(dimension)?;
150 indices.insert(initial_version.clone(), index);
151
152 Ok(Self {
153 indices: Arc::new(RwLock::new(indices)),
154 active_version: Arc::new(RwLock::new(initial_version)),
155 transforms: Arc::new(RwLock::new(HashMap::new())),
156 dimension,
157 })
158 }
159
160 pub fn active_version(&self) -> ModelVersion {
162 self.active_version.read().unwrap().clone()
163 }
164
165 pub fn add_version(
167 &self,
168 version: ModelVersion,
169 transform: Option<EmbeddingTransform>,
170 ) -> Result<()> {
171 let mut indices = self.indices.write().unwrap();
172
173 if indices.contains_key(&version) {
174 return Err(Error::InvalidInput(format!(
175 "Version {} already exists",
176 version
177 )));
178 }
179
180 let index = VectorIndex::with_defaults(self.dimension)?;
181 indices.insert(version.clone(), index);
182
183 if let Some(t) = transform {
185 let mut transforms = self.transforms.write().unwrap();
186 transforms.insert((t.from_version.clone(), t.to_version.clone()), t);
187 }
188
189 Ok(())
190 }
191
192 pub fn set_active_version(&self, version: ModelVersion) -> Result<()> {
194 let indices = self.indices.read().unwrap();
195
196 if !indices.contains_key(&version) {
197 return Err(Error::InvalidInput(format!(
198 "Version {} does not exist",
199 version
200 )));
201 }
202
203 let mut active = self.active_version.write().unwrap();
204 *active = version;
205
206 Ok(())
207 }
208
209 pub fn insert(
211 &self,
212 cid: &Cid,
213 embedding: &[f32],
214 version: Option<ModelVersion>,
215 ) -> Result<()> {
216 let version = version.unwrap_or_else(|| self.active_version());
217
218 let mut indices = self.indices.write().unwrap();
219 let index = indices
220 .get_mut(&version)
221 .ok_or_else(|| Error::InvalidInput(format!("Version {} does not exist", version)))?;
222
223 index.insert(cid, embedding)?;
224 Ok(())
225 }
226
227 pub fn update(
229 &self,
230 cid: &Cid,
231 new_embedding: &[f32],
232 version: Option<ModelVersion>,
233 ) -> Result<()> {
234 let version = version.unwrap_or_else(|| self.active_version());
235
236 let mut indices = self.indices.write().unwrap();
237 let index = indices
238 .get_mut(&version)
239 .ok_or_else(|| Error::InvalidInput(format!("Version {} does not exist", version)))?;
240
241 index.delete(cid)?;
243 index.insert(cid, new_embedding)?;
245
246 Ok(())
247 }
248
249 pub fn migrate(&self, from: &ModelVersion, to: &ModelVersion) -> Result<usize> {
251 let transforms = self.transforms.read().unwrap();
252 let transform = transforms
253 .get(&(from.clone(), to.clone()))
254 .ok_or_else(|| Error::InvalidInput(format!("No transform from {} to {}", from, to)))?
255 .clone();
256 drop(transforms);
257
258 let indices = self.indices.read().unwrap();
260 let source_index = indices.get(from).ok_or_else(|| {
261 Error::InvalidInput(format!("Source version {} does not exist", from))
262 })?;
263
264 if !indices.contains_key(to) {
266 return Err(Error::InvalidInput(format!(
267 "Target version {} does not exist",
268 to
269 )));
270 }
271
272 let embeddings = source_index.get_all_embeddings();
274 drop(indices);
275
276 let mut migrated_count = 0;
278 for (cid, embedding) in embeddings {
279 let transformed = transform.apply(&embedding);
281
282 let mut indices = self.indices.write().unwrap();
284 if let Some(target_index) = indices.get_mut(to) {
285 if !target_index.contains(&cid) {
287 target_index.insert(&cid, &transformed)?;
288 migrated_count += 1;
289 }
290 }
291 drop(indices);
292 }
293
294 Ok(migrated_count)
295 }
296
297 pub fn version_stats(&self) -> HashMap<ModelVersion, VersionStats> {
299 let indices = self.indices.read().unwrap();
300
301 indices
302 .iter()
303 .map(|(version, index)| {
304 let stats = VersionStats {
305 version: version.clone(),
306 num_embeddings: index.len(),
307 is_active: version == &self.active_version(),
308 };
309 (version.clone(), stats)
310 })
311 .collect()
312 }
313}
314
315#[derive(Debug, Clone, Serialize, Deserialize)]
317pub struct VersionStats {
318 pub version: ModelVersion,
320 pub num_embeddings: usize,
322 pub is_active: bool,
324}
325
326pub struct OnlineUpdater {
328 learning_rate: f32,
330 momentum: f32,
332 velocity: Arc<RwLock<HashMap<Cid, Vec<f32>>>>,
334}
335
336impl OnlineUpdater {
337 pub fn new(learning_rate: f32, momentum: f32) -> Self {
339 Self {
340 learning_rate,
341 momentum,
342 velocity: Arc::new(RwLock::new(HashMap::new())),
343 }
344 }
345
346 pub fn update(&self, cid: &Cid, embedding: &[f32], gradient: &[f32]) -> Vec<f32> {
348 let mut velocity = self.velocity.write().unwrap();
349
350 let v = velocity
352 .entry(*cid)
353 .or_insert_with(|| vec![0.0; embedding.len()]);
354
355 for i in 0..embedding.len().min(gradient.len()) {
357 v[i] = self.momentum * v[i] - self.learning_rate * gradient[i];
358 }
359
360 embedding
362 .iter()
363 .zip(v.iter())
364 .map(|(&e, &vel)| e + vel)
365 .collect()
366 }
367
368 pub fn reset(&self) {
370 let mut velocity = self.velocity.write().unwrap();
371 velocity.clear();
372 }
373
374 pub fn stats(&self) -> OnlineUpdaterStats {
376 let velocity = self.velocity.read().unwrap();
377
378 OnlineUpdaterStats {
379 learning_rate: self.learning_rate,
380 momentum: self.momentum,
381 num_tracked: velocity.len(),
382 }
383 }
384}
385
386#[derive(Debug, Clone, Serialize, Deserialize)]
388pub struct OnlineUpdaterStats {
389 pub learning_rate: f32,
391 pub momentum: f32,
393 pub num_tracked: usize,
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use multihash_codetable::{Code, MultihashDigest};

    /// Builds a CIDv1 with codec `0x55` from the SHA2-256 digest of `data`.
    fn cid_for(data: &str) -> Cid {
        Cid::new_v1(0x55, Code::Sha2_256.digest(data.as_bytes()))
    }

    /// Compatibility follows the major component; `Display` appends the tag.
    #[test]
    fn test_model_version() {
        let base = ModelVersion::new(1, 0, 0);
        let same_major = ModelVersion::new(1, 1, 0);
        let next_major = ModelVersion::new(2, 0, 0);

        assert!(base.is_compatible_with(&same_major));
        assert!(!base.is_compatible_with(&next_major));

        assert_eq!(base.to_string(), "1.0.0");
        let tagged = base.with_tag("alpha".into());
        assert_eq!(tagged.to_string(), "1.0.0-alpha");
    }

    /// Identity keeps the embedding; a diagonal matrix scales per component.
    #[test]
    fn test_embedding_transform() {
        let from = ModelVersion::new(1, 0, 0);
        let to = ModelVersion::new(1, 1, 0);

        let input = vec![1.0, 2.0, 3.0];
        let unchanged = EmbeddingTransform::identity(from.clone()).apply(&input);
        assert_eq!(unchanged, input);

        let diagonal = vec![vec![1.0, 0.0], vec![0.0, 2.0]];
        let scaled = EmbeddingTransform::linear(from, to, diagonal).apply(&[1.0, 2.0]);
        assert_eq!(scaled, vec![1.0, 4.0]);
    }

    /// A fresh index reports its initial version as active.
    #[test]
    fn test_dynamic_index_creation() {
        let initial = ModelVersion::new(1, 0, 0);
        let dynamic = DynamicIndex::new(initial.clone(), 128).unwrap();

        assert_eq!(dynamic.active_version(), initial);
    }

    /// Adding a version makes it show up in the statistics.
    #[test]
    fn test_add_version() {
        let first = ModelVersion::new(1, 0, 0);
        let second = ModelVersion::new(1, 1, 0);

        let dynamic = DynamicIndex::new(first.clone(), 128).unwrap();
        dynamic.add_version(second.clone(), None).unwrap();

        let stats = dynamic.version_stats();
        assert_eq!(stats.len(), 2);
        for version in [&first, &second] {
            assert!(stats.contains_key(version));
        }
    }

    /// Switching the active version is reflected by `active_version`.
    #[test]
    fn test_set_active_version() {
        let first = ModelVersion::new(1, 0, 0);
        let second = ModelVersion::new(1, 1, 0);

        let dynamic = DynamicIndex::new(first.clone(), 128).unwrap();
        dynamic.add_version(second.clone(), None).unwrap();
        assert_eq!(dynamic.active_version(), first);

        dynamic.set_active_version(second.clone()).unwrap();
        assert_eq!(dynamic.active_version(), second);
    }

    /// Updating an existing entry keeps the embedding count at one.
    #[test]
    fn test_insert_and_update() {
        let dynamic = DynamicIndex::new(ModelVersion::new(1, 0, 0), 3).unwrap();
        let cid = cid_for("test_embedding");

        dynamic.insert(&cid, &[1.0, 2.0, 3.0], None).unwrap();
        let after_insert = dynamic.version_stats();
        assert_eq!(after_insert.values().next().unwrap().num_embeddings, 1);

        dynamic.update(&cid, &[4.0, 5.0, 6.0], None).unwrap();
        let after_update = dynamic.version_stats();
        assert_eq!(after_update.values().next().unwrap().num_embeddings, 1);
    }

    /// A positive gradient moves the embedding down and tracks the CID.
    #[test]
    fn test_online_updater() {
        let updater = OnlineUpdater::new(0.1, 0.9);
        let cid = cid_for("test");

        let embedding = vec![1.0, 1.0, 1.0];
        let gradient = vec![0.1, 0.1, 0.1];
        let updated = updater.update(&cid, &embedding, &gradient);

        assert!(updated[0] < 1.0);
        assert_eq!(updated.len(), 3);
        assert_eq!(updater.stats().num_tracked, 1);
    }

    /// With momentum, the second step in the same direction is larger.
    #[test]
    fn test_updater_momentum() {
        let updater = OnlineUpdater::new(0.1, 0.9);
        let cid = cid_for("test");

        let start = vec![1.0];
        let gradient = vec![0.1];

        let after_first = updater.update(&cid, &start, &gradient);
        let after_second = updater.update(&cid, &after_first, &gradient);

        let first_step = (start[0] - after_first[0]).abs();
        let second_step = (after_first[0] - after_second[0]).abs();
        assert!(second_step > first_step);
    }
}