1use crate::traits::BlockStore;
40use crate::vcs::VersionControl;
41use bytes::Bytes;
42use ipfrs_core::{Block, Cid, Error, Result};
43use serde::{Deserialize, Serialize};
44use std::collections::HashMap;
45use std::sync::Arc;
46
47#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
49pub struct ProvenanceMetadata {
50 pub layer: String,
52 pub timestamp: u64,
54 pub training_config: String,
56 #[serde(
58 serialize_with = "serialize_option_cid",
59 deserialize_with = "deserialize_option_cid"
60 )]
61 pub parent: Option<Cid>,
62 pub step: Option<u64>,
64 pub metadata: HashMap<String, String>,
66}
67
68fn serialize_option_cid<S>(cid: &Option<Cid>, serializer: S) -> std::result::Result<S::Ok, S::Error>
70where
71 S: serde::Serializer,
72{
73 match cid {
74 Some(c) => serializer.serialize_some(&c.to_bytes()),
75 None => serializer.serialize_none(),
76 }
77}
78
79fn deserialize_option_cid<'de, D>(deserializer: D) -> std::result::Result<Option<Cid>, D::Error>
80where
81 D: serde::Deserializer<'de>,
82{
83 let opt: Option<Vec<u8>> = Deserialize::deserialize(deserializer)?;
84 match opt {
85 Some(bytes) => Cid::try_from(bytes)
86 .map(Some)
87 .map_err(serde::de::Error::custom),
88 None => Ok(None),
89 }
90}
91
92impl ProvenanceMetadata {
93 pub fn new(layer: String, training_config: String) -> Self {
95 Self {
96 layer,
97 timestamp: std::time::SystemTime::now()
98 .duration_since(std::time::UNIX_EPOCH)
99 .unwrap()
100 .as_secs(),
101 training_config,
102 parent: None,
103 step: None,
104 metadata: HashMap::new(),
105 }
106 }
107
108 pub fn with_parent(mut self, parent: Cid) -> Self {
110 self.parent = Some(parent);
111 self
112 }
113
114 pub fn with_step(mut self, step: u64) -> Self {
116 self.step = Some(step);
117 self
118 }
119
120 pub fn with_metadata(mut self, key: String, value: String) -> Self {
122 self.metadata.insert(key, value);
123 self
124 }
125}
126
127#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct GradientData {
130 pub shape: Vec<usize>,
132 pub dtype: String,
134 pub data: Vec<u8>,
136 pub is_delta: bool,
138 pub provenance: Option<ProvenanceMetadata>,
140}
141
142pub struct DeltaEncoder;
144
145impl DeltaEncoder {
146 pub fn encode_delta(base: &[f32], target: &[f32]) -> Result<Vec<u8>> {
150 if base.len() != target.len() {
151 return Err(Error::Storage(
152 "Base and target must have same length".to_string(),
153 ));
154 }
155
156 let delta: Vec<f32> = target.iter().zip(base.iter()).map(|(t, b)| t - b).collect();
158
159 let mut sparse_delta = Vec::new();
161
162 for (idx, &value) in delta.iter().enumerate() {
163 if value.abs() > 1e-10 {
164 sparse_delta.extend_from_slice(&(idx as u32).to_le_bytes());
167 sparse_delta.extend_from_slice(&value.to_le_bytes());
168 }
169 }
170
171 Ok(sparse_delta)
172 }
173
174 pub fn decode_delta(base: &[f32], delta_bytes: &[u8]) -> Result<Vec<f32>> {
176 let mut result = base.to_vec();
177
178 let mut offset = 0;
180 while offset + 8 <= delta_bytes.len() {
181 let idx_bytes = &delta_bytes[offset..offset + 4];
182 let value_bytes = &delta_bytes[offset + 4..offset + 8];
183
184 let idx = u32::from_le_bytes([idx_bytes[0], idx_bytes[1], idx_bytes[2], idx_bytes[3]])
185 as usize;
186 let value = f32::from_le_bytes([
187 value_bytes[0],
188 value_bytes[1],
189 value_bytes[2],
190 value_bytes[3],
191 ]);
192
193 if idx < result.len() {
194 result[idx] += value;
195 }
196
197 offset += 8;
198 }
199
200 Ok(result)
201 }
202
203 pub fn compression_ratio(original_size: usize, compressed_size: usize) -> f64 {
205 if compressed_size == 0 {
206 return 0.0;
207 }
208 original_size as f64 / compressed_size as f64
209 }
210}
211
212pub struct GradientStore<S: BlockStore> {
214 store: Arc<S>,
216 vcs: Option<Arc<VersionControl<S>>>,
218}
219
220impl<S: BlockStore> GradientStore<S> {
221 pub fn new(store: Arc<S>) -> Self {
223 Self { store, vcs: None }
224 }
225
226 pub fn with_vcs(store: Arc<S>, vcs: Arc<VersionControl<S>>) -> Self {
228 Self {
229 store,
230 vcs: Some(vcs),
231 }
232 }
233
234 pub fn vcs(&self) -> Option<&Arc<VersionControl<S>>> {
236 self.vcs.as_ref()
237 }
238
239 pub async fn store_gradient(
241 &self,
242 data: &[f32],
243 shape: Vec<usize>,
244 provenance: Option<ProvenanceMetadata>,
245 ) -> Result<Cid> {
246 let gradient_data = GradientData {
247 shape,
248 dtype: "f32".to_string(),
249 data: Self::encode_f32_slice(data),
250 is_delta: false,
251 provenance,
252 };
253
254 self.store_gradient_data(&gradient_data).await
255 }
256
257 pub async fn store_gradient_delta(
259 &self,
260 base_cid: &Cid,
261 target: &[f32],
262 shape: Vec<usize>,
263 provenance: Option<ProvenanceMetadata>,
264 ) -> Result<Cid> {
265 let base_data = self.load_gradient(base_cid).await?;
267 let base = Self::decode_f32_slice(&base_data.data)?;
268
269 let delta_bytes = DeltaEncoder::encode_delta(&base, target)?;
271
272 let mut prov = provenance
273 .unwrap_or_else(|| ProvenanceMetadata::new("unknown".to_string(), "delta".to_string()));
274 prov.parent = Some(*base_cid);
275
276 let gradient_data = GradientData {
277 shape,
278 dtype: "f32".to_string(),
279 data: delta_bytes,
280 is_delta: true,
281 provenance: Some(prov),
282 };
283
284 self.store_gradient_data(&gradient_data).await
285 }
286
287 pub async fn load_gradient(&self, cid: &Cid) -> Result<GradientData> {
289 let block = self
290 .store
291 .get(cid)
292 .await?
293 .ok_or_else(|| Error::NotFound(format!("Gradient not found: {cid}")))?;
294
295 let gradient_data: GradientData =
296 oxicode::serde::decode_owned_from_slice(block.data(), oxicode::config::standard())
297 .map(|(v, _)| v)
298 .map_err(|e| {
299 Error::Serialization(format!("Failed to deserialize gradient: {e}"))
300 })?;
301
302 Ok(gradient_data)
303 }
304
305 pub fn reconstruct_gradient<'a>(
307 &'a self,
308 cid: &'a Cid,
309 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<f32>>> + Send + 'a>> {
310 Box::pin(async move {
311 let gradient_data = self.load_gradient(cid).await?;
312
313 if !gradient_data.is_delta {
314 return Self::decode_f32_slice(&gradient_data.data);
316 }
317
318 let parent_cid = gradient_data
320 .provenance
321 .as_ref()
322 .and_then(|p| p.parent)
323 .ok_or_else(|| Error::Storage("Delta gradient missing parent CID".to_string()))?;
324
325 let base = self.reconstruct_gradient(&parent_cid).await?;
326 DeltaEncoder::decode_delta(&base, &gradient_data.data)
327 })
328 }
329
330 async fn store_gradient_data(&self, gradient_data: &GradientData) -> Result<Cid> {
332 let bytes = oxicode::serde::encode_to_vec(gradient_data, oxicode::config::standard())
333 .map_err(|e| Error::Serialization(format!("Failed to serialize gradient: {e}")))?;
334
335 let block = Block::new(Bytes::from(bytes))?;
336 let cid = *block.cid();
337 self.store.put(&block).await?;
338
339 Ok(cid)
340 }
341
342 fn encode_f32_slice(data: &[f32]) -> Vec<u8> {
344 let mut bytes = Vec::with_capacity(data.len() * 4);
345 for &value in data {
346 bytes.extend_from_slice(&value.to_le_bytes());
347 }
348 bytes
349 }
350
351 fn decode_f32_slice(bytes: &[u8]) -> Result<Vec<f32>> {
353 if !bytes.len().is_multiple_of(4) {
354 return Err(Error::Storage("Invalid f32 data length".to_string()));
355 }
356
357 let mut data = Vec::with_capacity(bytes.len() / 4);
358 for chunk in bytes.chunks_exact(4) {
359 let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
360 data.push(value);
361 }
362
363 Ok(data)
364 }
365
366 pub async fn compute_compression_stats(&self, cid: &Cid) -> Result<CompressionStats> {
368 let gradient_data = self.load_gradient(cid).await?;
369
370 let original_size = gradient_data.shape.iter().product::<usize>() * 4; let compressed_size = gradient_data.data.len();
372
373 let ratio = DeltaEncoder::compression_ratio(original_size, compressed_size);
374
375 Ok(CompressionStats {
376 original_size,
377 compressed_size,
378 compression_ratio: ratio,
379 is_delta: gradient_data.is_delta,
380 })
381 }
382}
383
/// Size comparison between a stored gradient payload and its dense representation.
#[derive(Clone, Debug, PartialEq)]
pub struct CompressionStats {
    /// Dense size in bytes implied by the gradient's shape.
    pub original_size: usize,
    /// Size in bytes of the payload actually stored.
    pub compressed_size: usize,
    /// `original_size / compressed_size`, or `0.0` when the stored payload is empty.
    pub compression_ratio: f64,
    /// Whether the stored payload is a delta against a parent gradient.
    pub is_delta: bool,
}
396
#[cfg(test)]
mod tests {
    use super::*;
    use crate::blockstore::{BlockStoreConfig, SledBlockStore};

    /// Opens a `SledBlockStore` in a fresh directory under the platform temp dir.
    ///
    /// The process id is part of the directory name so that concurrent test runs do not
    /// delete or reuse each other's stores. Any leftover directory is removed first.
    fn fresh_store(name: &str) -> Arc<SledBlockStore> {
        let path = std::env::temp_dir().join(format!("{name}-{}", std::process::id()));
        let _ = std::fs::remove_dir_all(&path);

        let config = BlockStoreConfig {
            path,
            cache_size: 10 * 1024 * 1024,
        };
        Arc::new(SledBlockStore::new(config).unwrap())
    }

    /// Asserts that both slices have the same length and agree element-wise within 1e-5.
    fn assert_all_close(expected: &[f32], actual: &[f32]) {
        assert_eq!(expected.len(), actual.len(), "Length mismatch");
        for (i, (&orig, &recon)) in expected.iter().zip(actual.iter()).enumerate() {
            assert!(
                (orig - recon).abs() < 1e-5,
                "Mismatch at index {i}: {orig} vs {recon}"
            );
        }
    }

    #[test]
    fn test_delta_encoding() {
        let base = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let target = vec![1.1f32, 2.0, 3.2, 4.0, 5.0];

        let delta_bytes = DeltaEncoder::encode_delta(&base, &target).unwrap();
        let reconstructed = DeltaEncoder::decode_delta(&base, &delta_bytes).unwrap();

        assert_all_close(&target, &reconstructed);
    }

    #[test]
    fn test_sparse_delta() {
        // Only three of a thousand elements change, so the delta must be far smaller
        // than the dense encoding.
        let base = vec![0.0f32; 1000];
        let mut target = vec![0.0f32; 1000];
        target[10] = 1.5;
        target[500] = 2.3;
        target[999] = -0.7;

        let delta_bytes = DeltaEncoder::encode_delta(&base, &target).unwrap();

        let full_size = 1000 * 4;
        let delta_size = delta_bytes.len();
        assert!(delta_size < full_size / 10, "Delta not sparse enough");

        let reconstructed = DeltaEncoder::decode_delta(&base, &delta_bytes).unwrap();
        assert_all_close(&target, &reconstructed);
    }

    #[tokio::test]
    async fn test_gradient_store() {
        let gradient_store = GradientStore::new(fresh_store("ipfrs-gradient-test"));

        let gradient = vec![1.0f32, 2.0, 3.0, 4.0];
        let shape = vec![2, 2];

        let cid = gradient_store
            .store_gradient(&gradient, shape, None)
            .await
            .unwrap();

        let loaded = gradient_store.load_gradient(&cid).await.unwrap();
        assert_eq!(loaded.shape, vec![2, 2]);
        assert!(!loaded.is_delta);
    }

    #[tokio::test]
    async fn test_gradient_delta_chain() {
        let gradient_store = GradientStore::new(fresh_store("ipfrs-gradient-delta-test"));

        // Full gradient that the delta is expressed against.
        let base_grad = vec![1.0f32, 2.0, 3.0, 4.0];
        let base_cid = gradient_store
            .store_gradient(&base_grad, vec![2, 2], None)
            .await
            .unwrap();

        let target_grad = vec![1.1f32, 2.0, 3.2, 4.0];
        let delta_cid = gradient_store
            .store_gradient_delta(&base_cid, &target_grad, vec![2, 2], None)
            .await
            .unwrap();

        // Reconstructing from the delta CID must yield the target values.
        let reconstructed = gradient_store
            .reconstruct_gradient(&delta_cid)
            .await
            .unwrap();

        assert_all_close(&target_grad, &reconstructed);
    }

    #[test]
    fn test_provenance_metadata() {
        let metadata = ProvenanceMetadata::new("layer1".to_string(), "lr=0.001".to_string())
            .with_step(100)
            .with_metadata("optimizer".to_string(), "adam".to_string());

        assert_eq!(metadata.layer, "layer1");
        assert_eq!(metadata.step, Some(100));
        assert_eq!(
            metadata.metadata.get("optimizer").unwrap(),
            &"adam".to_string()
        );
    }
}