use crate::traits::BlockStore;
use crate::vcs::VersionControl;
use bytes::Bytes;
use ipfrs_core::{Block, Cid, Error, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
/// Provenance attached to a stored gradient: which layer it belongs to, when
/// it was produced, and how it relates to other gradients in a delta chain.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProvenanceMetadata {
    /// Name of the model layer this gradient was computed for.
    pub layer: String,
    /// Creation time, seconds since the Unix epoch.
    pub timestamp: u64,
    /// Free-form training-configuration string (e.g. "lr=0.001").
    pub training_config: String,
    /// CID of the gradient this one was derived from; set for delta
    /// gradients. Serialized as the CID's raw bytes via the custom helper
    /// functions below.
    #[serde(
        serialize_with = "serialize_option_cid",
        deserialize_with = "deserialize_option_cid"
    )]
    pub parent: Option<Cid>,
    /// Training step number, when known.
    pub step: Option<u64>,
    /// Arbitrary key/value annotations (e.g. optimizer name).
    pub metadata: HashMap<String, String>,
}
/// Serializes an optional [`Cid`] as its raw byte representation, so the
/// field round-trips through formats that do not understand `Cid` natively.
fn serialize_option_cid<S>(cid: &Option<Cid>, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    if let Some(inner) = cid {
        serializer.serialize_some(&inner.to_bytes())
    } else {
        serializer.serialize_none()
    }
}
fn deserialize_option_cid<'de, D>(deserializer: D) -> std::result::Result<Option<Cid>, D::Error>
where
D: serde::Deserializer<'de>,
{
let opt: Option<Vec<u8>> = Deserialize::deserialize(deserializer)?;
match opt {
Some(bytes) => Cid::try_from(bytes)
.map(Some)
.map_err(serde::de::Error::custom),
None => Ok(None),
}
}
impl ProvenanceMetadata {
    /// Creates provenance for `layer` with the given training-config string,
    /// timestamped with the current Unix time in seconds.
    ///
    /// If the system clock reports a time before the Unix epoch (the only
    /// failure mode of `duration_since`), the timestamp falls back to 0
    /// instead of panicking as the previous `.unwrap()` did.
    pub fn new(layer: String, training_config: String) -> Self {
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0);
        Self {
            layer,
            timestamp,
            training_config,
            parent: None,
            step: None,
            metadata: HashMap::new(),
        }
    }

    /// Builder: records the CID of the gradient this one was derived from.
    pub fn with_parent(mut self, parent: Cid) -> Self {
        self.parent = Some(parent);
        self
    }

    /// Builder: records the training step this gradient belongs to.
    pub fn with_step(mut self, step: u64) -> Self {
        self.step = Some(step);
        self
    }

    /// Builder: attaches an arbitrary key/value annotation.
    pub fn with_metadata(mut self, key: String, value: String) -> Self {
        self.metadata.insert(key, value);
        self
    }
}
/// A gradient tensor as persisted in the block store.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientData {
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Element type tag; always "f32" for data written by this module.
    pub dtype: String,
    /// Raw payload: little-endian f32 values when `is_delta` is false, or
    /// sparse (u32 index, f32 delta) records when `is_delta` is true.
    pub data: Vec<u8>,
    /// Whether `data` is a sparse delta against a parent gradient.
    pub is_delta: bool,
    /// Optional provenance; delta gradients rely on `provenance.parent`
    /// to locate their base gradient.
    pub provenance: Option<ProvenanceMetadata>,
}
/// Stateless helpers for sparse delta encoding/decoding of f32 gradients.
pub struct DeltaEncoder;
impl DeltaEncoder {
pub fn encode_delta(base: &[f32], target: &[f32]) -> Result<Vec<u8>> {
if base.len() != target.len() {
return Err(Error::Storage(
"Base and target must have same length".to_string(),
));
}
let delta: Vec<f32> = target.iter().zip(base.iter()).map(|(t, b)| t - b).collect();
let mut sparse_delta = Vec::new();
for (idx, &value) in delta.iter().enumerate() {
if value.abs() > 1e-10 {
sparse_delta.extend_from_slice(&(idx as u32).to_le_bytes());
sparse_delta.extend_from_slice(&value.to_le_bytes());
}
}
Ok(sparse_delta)
}
pub fn decode_delta(base: &[f32], delta_bytes: &[u8]) -> Result<Vec<f32>> {
let mut result = base.to_vec();
let mut offset = 0;
while offset + 8 <= delta_bytes.len() {
let idx_bytes = &delta_bytes[offset..offset + 4];
let value_bytes = &delta_bytes[offset + 4..offset + 8];
let idx = u32::from_le_bytes([idx_bytes[0], idx_bytes[1], idx_bytes[2], idx_bytes[3]])
as usize;
let value = f32::from_le_bytes([
value_bytes[0],
value_bytes[1],
value_bytes[2],
value_bytes[3],
]);
if idx < result.len() {
result[idx] += value;
}
offset += 8;
}
Ok(result)
}
pub fn compression_ratio(original_size: usize, compressed_size: usize) -> f64 {
if compressed_size == 0 {
return 0.0;
}
original_size as f64 / compressed_size as f64
}
}
/// Content-addressed storage for gradients on top of any [`BlockStore`],
/// with optional sparse delta compression and version-control integration.
pub struct GradientStore<S: BlockStore> {
    /// Underlying block store holding serialized [`GradientData`] blocks.
    store: Arc<S>,
    /// Optional version-control handle, exposed via the `vcs()` accessor.
    vcs: Option<Arc<VersionControl<S>>>,
}
impl<S: BlockStore> GradientStore<S> {
    /// Creates a gradient store over `store` with no version control attached.
    pub fn new(store: Arc<S>) -> Self {
        Self { store, vcs: None }
    }

    /// Creates a gradient store that also carries a version-control handle.
    pub fn with_vcs(store: Arc<S>, vcs: Arc<VersionControl<S>>) -> Self {
        Self {
            store,
            vcs: Some(vcs),
        }
    }

    /// Returns the attached version-control handle, if any.
    pub fn vcs(&self) -> Option<&Arc<VersionControl<S>>> {
        self.vcs.as_ref()
    }

    /// Stores a full (non-delta) f32 gradient and returns its content CID.
    ///
    /// `shape` is recorded verbatim; it is not validated against
    /// `data.len()` here (NOTE(review): callers must keep them consistent).
    pub async fn store_gradient(
        &self,
        data: &[f32],
        shape: Vec<usize>,
        provenance: Option<ProvenanceMetadata>,
    ) -> Result<Cid> {
        let gradient_data = GradientData {
            shape,
            dtype: "f32".to_string(),
            data: Self::encode_f32_slice(data),
            is_delta: false,
            provenance,
        };
        self.store_gradient_data(&gradient_data).await
    }

    /// Stores `target` as a sparse delta against the gradient at `base_cid`,
    /// recording `base_cid` as the provenance parent (overwriting any parent
    /// the caller supplied) so [`Self::reconstruct_gradient`] can replay the
    /// chain.
    ///
    /// # Errors
    /// Fails if the base cannot be loaded/reconstructed or differs in length
    /// from `target`.
    pub async fn store_gradient_delta(
        &self,
        base_cid: &Cid,
        target: &[f32],
        shape: Vec<usize>,
        provenance: Option<ProvenanceMetadata>,
    ) -> Result<Cid> {
        // Bug fix: the previous code decoded the base block's raw bytes with
        // decode_f32_slice without checking `is_delta`. If the base was
        // itself a delta, its sparse records (8-byte multiples, so the length
        // check passed) were misread as dense f32s, corrupting the chain.
        // Reconstructing follows the parent chain and is identical to the
        // old path for non-delta bases.
        let base = self.reconstruct_gradient(base_cid).await?;
        let delta_bytes = DeltaEncoder::encode_delta(&base, target)?;
        // Default provenance when the caller supplied none.
        let mut prov = provenance
            .unwrap_or_else(|| ProvenanceMetadata::new("unknown".to_string(), "delta".to_string()));
        prov.parent = Some(*base_cid);
        let gradient_data = GradientData {
            shape,
            dtype: "f32".to_string(),
            data: delta_bytes,
            is_delta: true,
            provenance: Some(prov),
        };
        self.store_gradient_data(&gradient_data).await
    }

    /// Loads and deserializes the stored [`GradientData`] at `cid` without
    /// applying any delta reconstruction.
    pub async fn load_gradient(&self, cid: &Cid) -> Result<GradientData> {
        let block = self
            .store
            .get(cid)
            .await?
            .ok_or_else(|| Error::NotFound(format!("Gradient not found: {cid}")))?;
        let gradient_data: GradientData =
            oxicode::serde::decode_owned_from_slice(block.data(), oxicode::config::standard())
                .map(|(v, _)| v)
                .map_err(|e| {
                    Error::Serialization(format!("Failed to deserialize gradient: {e}"))
                })?;
        Ok(gradient_data)
    }

    /// Reconstructs the dense f32 values at `cid`, recursively applying
    /// delta gradients on top of their parents.
    ///
    /// Returns a boxed future because async recursion requires manual boxing.
    ///
    /// # Errors
    /// Fails if any gradient in the chain is missing, or a delta gradient
    /// has no parent CID in its provenance.
    pub fn reconstruct_gradient<'a>(
        &'a self,
        cid: &'a Cid,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<f32>>> + Send + 'a>> {
        Box::pin(async move {
            let gradient_data = self.load_gradient(cid).await?;
            if !gradient_data.is_delta {
                // Base case: a full gradient is just raw little-endian f32s.
                return Self::decode_f32_slice(&gradient_data.data);
            }
            let parent_cid = gradient_data
                .provenance
                .as_ref()
                .and_then(|p| p.parent)
                .ok_or_else(|| Error::Storage("Delta gradient missing parent CID".to_string()))?;
            let base = self.reconstruct_gradient(&parent_cid).await?;
            DeltaEncoder::decode_delta(&base, &gradient_data.data)
        })
    }

    /// Serializes `gradient_data` and writes it as a content-addressed block,
    /// returning the resulting CID.
    async fn store_gradient_data(&self, gradient_data: &GradientData) -> Result<Cid> {
        let bytes = oxicode::serde::encode_to_vec(gradient_data, oxicode::config::standard())
            .map_err(|e| Error::Serialization(format!("Failed to serialize gradient: {e}")))?;
        let block = Block::new(Bytes::from(bytes))?;
        let cid = *block.cid();
        self.store.put(&block).await?;
        Ok(cid)
    }

    /// Packs f32 values as contiguous little-endian bytes.
    fn encode_f32_slice(data: &[f32]) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(data.len() * 4);
        for &value in data {
            bytes.extend_from_slice(&value.to_le_bytes());
        }
        bytes
    }

    /// Unpacks contiguous little-endian bytes into f32 values.
    ///
    /// # Errors
    /// Fails if the byte length is not a multiple of 4.
    fn decode_f32_slice(bytes: &[u8]) -> Result<Vec<f32>> {
        if !bytes.len().is_multiple_of(4) {
            return Err(Error::Storage("Invalid f32 data length".to_string()));
        }
        let mut data = Vec::with_capacity(bytes.len() / 4);
        for chunk in bytes.chunks_exact(4) {
            let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
            data.push(value);
        }
        Ok(data)
    }

    /// Reports size statistics for the gradient at `cid`.
    ///
    /// The "original" size is the dense f32 size implied by the stored
    /// shape (element count × 4 bytes); the "compressed" size is the actual
    /// stored payload length.
    pub async fn compute_compression_stats(&self, cid: &Cid) -> Result<CompressionStats> {
        let gradient_data = self.load_gradient(cid).await?;
        let original_size = gradient_data.shape.iter().product::<usize>() * 4;
        let compressed_size = gradient_data.data.len();
        let ratio = DeltaEncoder::compression_ratio(original_size, compressed_size);
        Ok(CompressionStats {
            original_size,
            compressed_size,
            compression_ratio: ratio,
            is_delta: gradient_data.is_delta,
        })
    }
}
/// Size statistics for a stored gradient, as produced by
/// `GradientStore::compute_compression_stats`.
#[derive(Debug, Clone, PartialEq)]
pub struct CompressionStats {
    /// Dense size in bytes implied by the gradient's shape (f32 elements).
    pub original_size: usize,
    /// Actual stored payload size in bytes.
    pub compressed_size: usize,
    /// `original_size / compressed_size`; 0.0 when the payload is empty.
    pub compression_ratio: f64,
    /// Whether the stored payload is a sparse delta.
    pub is_delta: bool,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::blockstore::{BlockStoreConfig, SledBlockStore};
    use std::path::PathBuf;

    /// Asserts element-wise closeness (tolerance 1e-5) between two slices.
    fn assert_approx_eq(expected: &[f32], actual: &[f32]) {
        for (i, (&orig, &recon)) in expected.iter().zip(actual.iter()).enumerate() {
            assert!(
                (orig - recon).abs() < 1e-5,
                "Mismatch at index {}: {} vs {}",
                i,
                orig,
                recon
            );
        }
    }

    /// Builds a fresh sled-backed block store at `dir`, wiping any leftover
    /// state from a previous run.
    fn fresh_store(dir: &str) -> Arc<SledBlockStore> {
        let config = BlockStoreConfig {
            path: PathBuf::from(dir),
            cache_size: 10 * 1024 * 1024,
        };
        let _ = std::fs::remove_dir_all(&config.path);
        Arc::new(SledBlockStore::new(config).unwrap())
    }

    #[test]
    fn test_delta_encoding() {
        let base = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let target = vec![1.1f32, 2.0, 3.2, 4.0, 5.0];
        let encoded = DeltaEncoder::encode_delta(&base, &target).unwrap();
        let decoded = DeltaEncoder::decode_delta(&base, &encoded).unwrap();
        assert_approx_eq(&target, &decoded);
    }

    #[test]
    fn test_sparse_delta() {
        let base = vec![0.0f32; 1000];
        let mut target = vec![0.0f32; 1000];
        target[10] = 1.5;
        target[500] = 2.3;
        target[999] = -0.7;
        let encoded = DeltaEncoder::encode_delta(&base, &target).unwrap();
        // A 3-entry delta must be far smaller than the 4000-byte dense form.
        let full_size = 1000 * 4;
        assert!(encoded.len() < full_size / 10, "Delta not sparse enough");
        let decoded = DeltaEncoder::decode_delta(&base, &encoded).unwrap();
        assert_approx_eq(&target, &decoded);
    }

    #[tokio::test]
    async fn test_gradient_store() {
        let gradient_store = GradientStore::new(fresh_store("/tmp/ipfrs-gradient-test"));
        let gradient = vec![1.0f32, 2.0, 3.0, 4.0];
        let cid = gradient_store
            .store_gradient(&gradient, vec![2, 2], None)
            .await
            .unwrap();
        let loaded = gradient_store.load_gradient(&cid).await.unwrap();
        assert_eq!(loaded.shape, vec![2, 2]);
        assert!(!loaded.is_delta);
    }

    #[tokio::test]
    async fn test_gradient_delta_chain() {
        let gradient_store = GradientStore::new(fresh_store("/tmp/ipfrs-gradient-delta-test"));
        let base_grad = vec![1.0f32, 2.0, 3.0, 4.0];
        let base_cid = gradient_store
            .store_gradient(&base_grad, vec![2, 2], None)
            .await
            .unwrap();
        let target_grad = vec![1.1f32, 2.0, 3.2, 4.0];
        let delta_cid = gradient_store
            .store_gradient_delta(&base_cid, &target_grad, vec![2, 2], None)
            .await
            .unwrap();
        let reconstructed = gradient_store
            .reconstruct_gradient(&delta_cid)
            .await
            .unwrap();
        assert_approx_eq(&target_grad, &reconstructed);
    }

    #[test]
    fn test_provenance_metadata() {
        let metadata = ProvenanceMetadata::new("layer1".to_string(), "lr=0.001".to_string())
            .with_step(100)
            .with_metadata("optimizer".to_string(), "adam".to_string());
        assert_eq!(metadata.layer, "layer1");
        assert_eq!(metadata.step, Some(100));
        assert_eq!(
            metadata.metadata.get("optimizer").unwrap(),
            &"adam".to_string()
        );
    }
}