ipfrs_storage/
gradient.rs

1//! Gradient and tensor storage with delta encoding
2//!
3//! Provides efficient storage for neural network gradients and tensors:
4//! - Delta encoding (store changes only)
5//! - Sparse gradient compression
6//! - Provenance metadata tracking
7//! - Integration with version control
8//!
9//! # Example
10//!
11//! ```rust,ignore
//! use ipfrs_storage::{GradientStore, ProvenanceMetadata, SledBlockStore, BlockStoreConfig};
13//! use std::sync::Arc;
14//! use std::path::PathBuf;
15//!
16//! # async fn example() -> ipfrs_core::Result<()> {
17//! // Create block store
18//! let store = Arc::new(SledBlockStore::new(BlockStoreConfig {
19//!     path: PathBuf::from(".ipfrs/gradients"),
20//!     cache_size: 100 * 1024 * 1024,
21//! })?);
22//!
23//! // Create gradient store
24//! let gradient_store = GradientStore::new(store);
25//!
//! // Store a gradient with delta encoding
//! let gradient = vec![1.0f32, 2.0, 3.0, 4.0];
//! let metadata = ProvenanceMetadata::new(
//!     "layer1".to_string(),
//!     "lr=0.001".to_string(),
//! );
//!
//! let cid = gradient_store
//!     .store_gradient(&gradient, vec![2, 2], Some(metadata))
//!     .await?;
35//! # Ok(())
36//! # }
37//! ```
38
39use crate::traits::BlockStore;
40use crate::vcs::VersionControl;
41use bytes::Bytes;
42use ipfrs_core::{Block, Cid, Error, Result};
43use serde::{Deserialize, Serialize};
44use std::collections::HashMap;
45use std::sync::Arc;
46
/// Provenance metadata for tracking gradient origins
///
/// Records where a gradient came from (layer, training step, hyperparameter
/// configuration) and, for delta-encoded gradients, the CID of the base
/// gradient the delta was computed against.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProvenanceMetadata {
    /// Layer name or identifier
    pub layer: String,
    /// Unix timestamp when gradient was computed
    pub timestamp: u64,
    /// Training configuration (hyperparameters, etc.)
    pub training_config: String,
    /// Optional parent gradient CID (for delta encoding)
    ///
    /// Serialized as the CID's raw bytes via the custom serde helpers
    /// defined below in this module.
    #[serde(
        serialize_with = "serialize_option_cid",
        deserialize_with = "deserialize_option_cid"
    )]
    pub parent: Option<Cid>,
    /// Training step/epoch number
    pub step: Option<u64>,
    /// Additional metadata
    pub metadata: HashMap<String, String>,
}
67
68// Custom serialization for Option<Cid>
69fn serialize_option_cid<S>(cid: &Option<Cid>, serializer: S) -> std::result::Result<S::Ok, S::Error>
70where
71    S: serde::Serializer,
72{
73    match cid {
74        Some(c) => serializer.serialize_some(&c.to_bytes()),
75        None => serializer.serialize_none(),
76    }
77}
78
79fn deserialize_option_cid<'de, D>(deserializer: D) -> std::result::Result<Option<Cid>, D::Error>
80where
81    D: serde::Deserializer<'de>,
82{
83    let opt: Option<Vec<u8>> = Deserialize::deserialize(deserializer)?;
84    match opt {
85        Some(bytes) => Cid::try_from(bytes)
86            .map(Some)
87            .map_err(serde::de::Error::custom),
88        None => Ok(None),
89    }
90}
91
92impl ProvenanceMetadata {
93    /// Create new provenance metadata
94    pub fn new(layer: String, training_config: String) -> Self {
95        Self {
96            layer,
97            timestamp: std::time::SystemTime::now()
98                .duration_since(std::time::UNIX_EPOCH)
99                .unwrap()
100                .as_secs(),
101            training_config,
102            parent: None,
103            step: None,
104            metadata: HashMap::new(),
105        }
106    }
107
108    /// Set parent gradient CID
109    pub fn with_parent(mut self, parent: Cid) -> Self {
110        self.parent = Some(parent);
111        self
112    }
113
114    /// Set training step
115    pub fn with_step(mut self, step: u64) -> Self {
116        self.step = Some(step);
117        self
118    }
119
120    /// Add custom metadata
121    pub fn with_metadata(mut self, key: String, value: String) -> Self {
122        self.metadata.insert(key, value);
123        self
124    }
125}
126
/// Gradient data with metadata
///
/// The `data` payload is either a full little-endian f32 buffer or a sparse
/// delta stream (see `DeltaEncoder`), distinguished by `is_delta`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientData {
    /// Shape of the gradient tensor
    pub shape: Vec<usize>,
    /// Data type (f32, f64, etc.)
    pub dtype: String,
    /// Encoded gradient data (delta or full)
    pub data: Vec<u8>,
    /// Whether this is a delta or full gradient
    pub is_delta: bool,
    /// Provenance metadata
    ///
    /// For delta gradients this carries the parent CID required to
    /// reconstruct the full tensor.
    pub provenance: Option<ProvenanceMetadata>,
}
141
/// Delta encoder for efficient gradient storage
///
/// Stateless namespace for the sparse delta encode/decode routines; deltas
/// are stored as 8-byte records of a little-endian `u32` index followed by
/// a little-endian `f32` difference.
pub struct DeltaEncoder;
144
145impl DeltaEncoder {
146    /// Encode delta between base and target gradients
147    ///
148    /// Returns compressed delta representation
149    pub fn encode_delta(base: &[f32], target: &[f32]) -> Result<Vec<u8>> {
150        if base.len() != target.len() {
151            return Err(Error::Storage(
152                "Base and target must have same length".to_string(),
153            ));
154        }
155
156        // Compute delta
157        let delta: Vec<f32> = target.iter().zip(base.iter()).map(|(t, b)| t - b).collect();
158
159        // Sparse encoding: store only non-zero deltas
160        let mut sparse_delta = Vec::new();
161
162        for (idx, &value) in delta.iter().enumerate() {
163            if value.abs() > 1e-10 {
164                // Threshold for sparsity
165                // Store index and value
166                sparse_delta.extend_from_slice(&(idx as u32).to_le_bytes());
167                sparse_delta.extend_from_slice(&value.to_le_bytes());
168            }
169        }
170
171        Ok(sparse_delta)
172    }
173
174    /// Decode delta and apply to base gradient
175    pub fn decode_delta(base: &[f32], delta_bytes: &[u8]) -> Result<Vec<f32>> {
176        let mut result = base.to_vec();
177
178        // Read sparse delta entries
179        let mut offset = 0;
180        while offset + 8 <= delta_bytes.len() {
181            let idx_bytes = &delta_bytes[offset..offset + 4];
182            let value_bytes = &delta_bytes[offset + 4..offset + 8];
183
184            let idx = u32::from_le_bytes([idx_bytes[0], idx_bytes[1], idx_bytes[2], idx_bytes[3]])
185                as usize;
186            let value = f32::from_le_bytes([
187                value_bytes[0],
188                value_bytes[1],
189                value_bytes[2],
190                value_bytes[3],
191            ]);
192
193            if idx < result.len() {
194                result[idx] += value;
195            }
196
197            offset += 8;
198        }
199
200        Ok(result)
201    }
202
203    /// Compute compression ratio
204    pub fn compression_ratio(original_size: usize, compressed_size: usize) -> f64 {
205        if compressed_size == 0 {
206            return 0.0;
207        }
208        original_size as f64 / compressed_size as f64
209    }
210}
211
/// Gradient store for managing gradients with delta encoding
///
/// Wraps a content-addressed `BlockStore` and persists serialized
/// `GradientData` values as blocks addressed by CID.
pub struct GradientStore<S: BlockStore> {
    /// Underlying block store
    store: Arc<S>,
    /// Optional version control integration
    vcs: Option<Arc<VersionControl<S>>>,
}
219
220impl<S: BlockStore> GradientStore<S> {
221    /// Create a new gradient store
222    pub fn new(store: Arc<S>) -> Self {
223        Self { store, vcs: None }
224    }
225
226    /// Create with version control integration
227    pub fn with_vcs(store: Arc<S>, vcs: Arc<VersionControl<S>>) -> Self {
228        Self {
229            store,
230            vcs: Some(vcs),
231        }
232    }
233
234    /// Get the version control system, if available
235    pub fn vcs(&self) -> Option<&Arc<VersionControl<S>>> {
236        self.vcs.as_ref()
237    }
238
239    /// Store a gradient (full)
240    pub async fn store_gradient(
241        &self,
242        data: &[f32],
243        shape: Vec<usize>,
244        provenance: Option<ProvenanceMetadata>,
245    ) -> Result<Cid> {
246        let gradient_data = GradientData {
247            shape,
248            dtype: "f32".to_string(),
249            data: Self::encode_f32_slice(data),
250            is_delta: false,
251            provenance,
252        };
253
254        self.store_gradient_data(&gradient_data).await
255    }
256
257    /// Store a gradient as delta from base
258    pub async fn store_gradient_delta(
259        &self,
260        base_cid: &Cid,
261        target: &[f32],
262        shape: Vec<usize>,
263        provenance: Option<ProvenanceMetadata>,
264    ) -> Result<Cid> {
265        // Load base gradient
266        let base_data = self.load_gradient(base_cid).await?;
267        let base = Self::decode_f32_slice(&base_data.data)?;
268
269        // Encode delta
270        let delta_bytes = DeltaEncoder::encode_delta(&base, target)?;
271
272        let mut prov = provenance
273            .unwrap_or_else(|| ProvenanceMetadata::new("unknown".to_string(), "delta".to_string()));
274        prov.parent = Some(*base_cid);
275
276        let gradient_data = GradientData {
277            shape,
278            dtype: "f32".to_string(),
279            data: delta_bytes,
280            is_delta: true,
281            provenance: Some(prov),
282        };
283
284        self.store_gradient_data(&gradient_data).await
285    }
286
287    /// Load a gradient and reconstruct if it's a delta
288    pub async fn load_gradient(&self, cid: &Cid) -> Result<GradientData> {
289        let block = self
290            .store
291            .get(cid)
292            .await?
293            .ok_or_else(|| Error::NotFound(format!("Gradient not found: {cid}")))?;
294
295        let gradient_data: GradientData =
296            oxicode::serde::decode_owned_from_slice(block.data(), oxicode::config::standard())
297                .map(|(v, _)| v)
298                .map_err(|e| {
299                    Error::Serialization(format!("Failed to deserialize gradient: {e}"))
300                })?;
301
302        Ok(gradient_data)
303    }
304
305    /// Reconstruct full gradient from delta chain
306    pub fn reconstruct_gradient<'a>(
307        &'a self,
308        cid: &'a Cid,
309    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<f32>>> + Send + 'a>> {
310        Box::pin(async move {
311            let gradient_data = self.load_gradient(cid).await?;
312
313            if !gradient_data.is_delta {
314                // Already full gradient
315                return Self::decode_f32_slice(&gradient_data.data);
316            }
317
318            // Recursively reconstruct from base
319            let parent_cid = gradient_data
320                .provenance
321                .as_ref()
322                .and_then(|p| p.parent)
323                .ok_or_else(|| Error::Storage("Delta gradient missing parent CID".to_string()))?;
324
325            let base = self.reconstruct_gradient(&parent_cid).await?;
326            DeltaEncoder::decode_delta(&base, &gradient_data.data)
327        })
328    }
329
330    /// Store gradient data as a block
331    async fn store_gradient_data(&self, gradient_data: &GradientData) -> Result<Cid> {
332        let bytes = oxicode::serde::encode_to_vec(gradient_data, oxicode::config::standard())
333            .map_err(|e| Error::Serialization(format!("Failed to serialize gradient: {e}")))?;
334
335        let block = Block::new(Bytes::from(bytes))?;
336        let cid = *block.cid();
337        self.store.put(&block).await?;
338
339        Ok(cid)
340    }
341
342    /// Encode f32 slice to bytes
343    fn encode_f32_slice(data: &[f32]) -> Vec<u8> {
344        let mut bytes = Vec::with_capacity(data.len() * 4);
345        for &value in data {
346            bytes.extend_from_slice(&value.to_le_bytes());
347        }
348        bytes
349    }
350
351    /// Decode bytes to f32 slice
352    fn decode_f32_slice(bytes: &[u8]) -> Result<Vec<f32>> {
353        if !bytes.len().is_multiple_of(4) {
354            return Err(Error::Storage("Invalid f32 data length".to_string()));
355        }
356
357        let mut data = Vec::with_capacity(bytes.len() / 4);
358        for chunk in bytes.chunks_exact(4) {
359            let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
360            data.push(value);
361        }
362
363        Ok(data)
364    }
365
366    /// Get compression statistics for a gradient chain
367    pub async fn compute_compression_stats(&self, cid: &Cid) -> Result<CompressionStats> {
368        let gradient_data = self.load_gradient(cid).await?;
369
370        let original_size = gradient_data.shape.iter().product::<usize>() * 4; // f32 = 4 bytes
371        let compressed_size = gradient_data.data.len();
372
373        let ratio = DeltaEncoder::compression_ratio(original_size, compressed_size);
374
375        Ok(CompressionStats {
376            original_size,
377            compressed_size,
378            compression_ratio: ratio,
379            is_delta: gradient_data.is_delta,
380        })
381    }
382}
383
/// Compression statistics
///
/// Produced by `GradientStore::compute_compression_stats`; sizes are in
/// bytes and the "original" size is the dense f32 size implied by the
/// gradient's shape.
#[derive(Debug, Clone, PartialEq)]
pub struct CompressionStats {
    /// Original uncompressed size in bytes
    pub original_size: usize,
    /// Compressed size in bytes
    pub compressed_size: usize,
    /// Compression ratio (original / compressed)
    pub compression_ratio: f64,
    /// Whether this is a delta encoding
    pub is_delta: bool,
}
396
#[cfg(test)]
mod tests {
    use super::*;
    use crate::blockstore::{BlockStoreConfig, SledBlockStore};
    use std::path::PathBuf;

    /// Element-wise float comparison shared by the roundtrip tests.
    fn assert_close(expected: &[f32], actual: &[f32]) {
        for (i, (&orig, &recon)) in expected.iter().zip(actual.iter()).enumerate() {
            assert!(
                (orig - recon).abs() < 1e-5,
                "Mismatch at index {}: {} vs {}",
                i,
                orig,
                recon
            );
        }
    }

    /// Build a fresh on-disk block store rooted at `path`, clearing stale state.
    fn fresh_store(path: &str) -> Arc<SledBlockStore> {
        let config = BlockStoreConfig {
            path: PathBuf::from(path),
            cache_size: 10 * 1024 * 1024,
        };
        let _ = std::fs::remove_dir_all(&config.path);
        Arc::new(SledBlockStore::new(config).unwrap())
    }

    #[test]
    fn test_delta_encoding() {
        let base = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let target = vec![1.1f32, 2.0, 3.2, 4.0, 5.0];

        let encoded = DeltaEncoder::encode_delta(&base, &target).unwrap();
        let decoded = DeltaEncoder::decode_delta(&base, &encoded).unwrap();

        assert_close(&target, &decoded);
    }

    #[test]
    fn test_sparse_delta() {
        // Sparse gradient: only a few elements change
        let base = vec![0.0f32; 1000];
        let mut target = base.clone();
        target[10] = 1.5;
        target[500] = 2.3;
        target[999] = -0.7;

        let encoded = DeltaEncoder::encode_delta(&base, &target).unwrap();

        // Delta should be much smaller than full gradient
        let full_size = 1000 * 4; // 4000 bytes
        assert!(encoded.len() < full_size / 10, "Delta not sparse enough");

        let decoded = DeltaEncoder::decode_delta(&base, &encoded).unwrap();
        assert_close(&target, &decoded);
    }

    #[tokio::test]
    async fn test_gradient_store() {
        let gradient_store = GradientStore::new(fresh_store("/tmp/ipfrs-gradient-test"));

        let cid = gradient_store
            .store_gradient(&[1.0f32, 2.0, 3.0, 4.0], vec![2, 2], None)
            .await
            .unwrap();

        let loaded = gradient_store.load_gradient(&cid).await.unwrap();
        assert_eq!(loaded.shape, vec![2, 2]);
        assert!(!loaded.is_delta);
    }

    #[tokio::test]
    async fn test_gradient_delta_chain() {
        let gradient_store = GradientStore::new(fresh_store("/tmp/ipfrs-gradient-delta-test"));

        // Store base gradient
        let base_cid = gradient_store
            .store_gradient(&[1.0f32, 2.0, 3.0, 4.0], vec![2, 2], None)
            .await
            .unwrap();

        // Store delta
        let target_grad = vec![1.1f32, 2.0, 3.2, 4.0];
        let delta_cid = gradient_store
            .store_gradient_delta(&base_cid, &target_grad, vec![2, 2], None)
            .await
            .unwrap();

        // Reconstruct
        let reconstructed = gradient_store
            .reconstruct_gradient(&delta_cid)
            .await
            .unwrap();

        assert_close(&target_grad, &reconstructed);
    }

    #[test]
    fn test_provenance_metadata() {
        let metadata = ProvenanceMetadata::new("layer1".to_string(), "lr=0.001".to_string())
            .with_step(100)
            .with_metadata("optimizer".to_string(), "adam".to_string());

        assert_eq!(metadata.layer, "layer1");
        assert_eq!(metadata.step, Some(100));
        assert_eq!(
            metadata.metadata.get("optimizer"),
            Some(&"adam".to_string())
        );
    }
}