use crate::{Dataset, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};
use tenflowers_core::{Tensor, TensorError};
/// Identifier of a dataset snapshot (formatted as `v_<hex-uuid>` by
/// `DatasetVersionManager::generate_version_id`).
pub type VersionId = String;
/// Descriptive record persisted alongside every snapshot as `metadata.json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VersionMetadata {
// Unique id of this snapshot.
pub version_id: VersionId,
// Snapshot this one was derived from, if any.
pub parent_version: Option<VersionId>,
// Creation time, seconds since the UNIX epoch.
pub timestamp: u64,
// Free-form human description supplied at snapshot time.
pub description: String,
// Labels used for filtering (see `get_versions_by_tag`).
pub tags: Vec<String>,
// Arbitrary user-supplied key/value annotations.
pub custom_metadata: HashMap<String, String>,
// Fingerprint produced by `calculate_checksum` (shape-based, not a content hash).
pub checksum: String,
// Sample count, shapes, and estimated byte size.
pub size_info: DatasetSizeInfo,
}
/// Size summary of a dataset, computed from its first sample.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetSizeInfo {
// Number of (features, labels) pairs.
pub sample_count: usize,
// Shape of the first sample's feature tensor (`[0]` for an empty dataset).
pub feature_shape: Vec<usize>,
// Shape of the first sample's label tensor (`[0]` for an empty dataset).
pub label_shape: Vec<usize>,
// Estimated in-memory size; an approximation, not an on-disk measurement.
pub size_bytes: u64,
}
/// One node of the lineage graph: a version plus its provenance edges.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetLineage {
// Metadata of the version this node describes.
pub version: VersionMetadata,
// Transformations recorded against this version via `add_transformation`.
pub transformations: Vec<TransformationRecord>,
// Versions this one was derived from (currently at most the parent).
pub source_versions: Vec<VersionId>,
// Versions derived from this one (back-links added by `create_snapshot`).
pub child_versions: Vec<VersionId>,
}
/// A single transformation applied to a version, for audit/lineage purposes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformationRecord {
// Kind of transformation, e.g. "scale" (free-form string).
pub transform_type: String,
// Stringified parameters of the transformation.
pub parameters: HashMap<String, String>,
// When the transformation was recorded, seconds since the UNIX epoch.
pub timestamp: u64,
// Human-readable description.
pub description: String,
}
/// Manages immutable dataset snapshots on disk plus the lineage
/// (parent/child + transformation) graph connecting them.
#[derive(Debug)]
pub struct DatasetVersionManager {
// Root directory; one subdirectory per version plus `lineage.json`.
base_path: PathBuf,
// In-memory mirror of `lineage.json`, keyed by version id.
lineage_graph: HashMap<VersionId, DatasetLineage>,
// Most recently created snapshot, if any.
current_version: Option<VersionId>,
}
impl DatasetVersionManager {
/// Opens (or initializes) a version store rooted at `base_path`.
///
/// Creates the root directory when absent and restores any previously
/// persisted lineage graph from `lineage.json`.
///
/// # Errors
/// Fails when the directory cannot be created or an existing lineage
/// file cannot be read or parsed.
pub fn new<P: AsRef<Path>>(base_path: P) -> Result<Self> {
    let root = base_path.as_ref().to_path_buf();
    if !root.exists() {
        std::fs::create_dir_all(&root).map_err(|e| {
            TensorError::invalid_argument(format!("Failed to create version directory: {e}"))
        })?;
    }
    let mut manager = Self {
        base_path: root,
        lineage_graph: HashMap::new(),
        current_version: None,
    };
    manager.load_lineage_graph()?;
    Ok(manager)
}
/// Persists `dataset` as a new immutable snapshot and registers it in the
/// lineage graph.
///
/// Samples go to `<base>/<id>/samples.json`, metadata to
/// `<base>/<id>/metadata.json`; the new snapshot becomes the manager's
/// current version and the whole graph is re-persisted. Returns the
/// generated version id.
///
/// NOTE(review): a `parent_version` id that is not present in the lineage
/// graph is silently accepted (only the child back-link is skipped) —
/// confirm this is intended rather than an error.
pub fn create_snapshot<T>(
&mut self,
dataset: &dyn Dataset<T>,
description: String,
tags: Vec<String>,
parent_version: Option<VersionId>,
) -> Result<VersionId>
where
T: Clone + Default + serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
{
let version_id = self.generate_version_id();
// Wall-clock capture time, stored as seconds since the UNIX epoch.
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time before UNIX_EPOCH")
.as_secs();
let size_info = self.calculate_size_info(dataset)?;
let checksum = self.calculate_checksum(dataset)?;
let metadata = VersionMetadata {
version_id: version_id.clone(),
parent_version: parent_version.clone(),
timestamp,
description,
tags,
custom_metadata: HashMap::new(),
checksum,
size_info,
};
// Each snapshot lives in its own directory named after the version id.
let version_dir = self.base_path.join(&version_id);
std::fs::create_dir_all(&version_dir).map_err(|e| {
TensorError::invalid_argument(format!("Failed to create version directory: {e}"))
})?;
self.save_dataset_samples(dataset, &version_dir)?;
self.save_metadata(&metadata, &version_dir)?;
let lineage = DatasetLineage {
version: metadata,
transformations: Vec::new(),
source_versions: if let Some(parent) = &parent_version {
vec![parent.clone()]
} else {
Vec::new()
},
child_versions: Vec::new(),
};
self.lineage_graph.insert(version_id.clone(), lineage);
// Back-link from the parent so the lineage tree can be walked top-down.
if let Some(parent) = &parent_version {
if let Some(parent_lineage) = self.lineage_graph.get_mut(parent) {
parent_lineage.child_versions.push(version_id.clone());
}
}
self.current_version = Some(version_id.clone());
// Persist the whole graph so state survives process restarts.
self.save_lineage_graph()?;
Ok(version_id)
}
/// Loads a previously saved snapshot back into memory.
///
/// # Errors
/// Fails when the version directory does not exist or when the stored
/// metadata/sample files cannot be read or parsed.
pub fn load_snapshot<T>(&self, version_id: &str) -> Result<VersionedDataset<T>>
where
    T: Clone + Default + serde::de::DeserializeOwned + Send + Sync + 'static,
{
    let version_dir = self.base_path.join(version_id);
    if !version_dir.exists() {
        return Err(TensorError::invalid_argument(format!(
            "Version {version_id} not found"
        )));
    }
    Ok(VersionedDataset {
        metadata: self.load_metadata(&version_dir)?,
        samples: self.load_dataset_samples(&version_dir)?,
    })
}
/// Returns the lineage record for `version_id`, or `None` if unknown.
pub fn get_lineage(&self, version_id: &str) -> Option<&DatasetLineage> {
self.lineage_graph.get(version_id)
}
/// Records a transformation applied to an existing version and persists
/// the updated lineage graph.
///
/// # Errors
/// Fails when `version_id` is not in the lineage graph or when the graph
/// cannot be written back to disk.
pub fn add_transformation(
    &mut self,
    version_id: &str,
    transform_type: String,
    parameters: HashMap<String, String>,
    description: String,
) -> Result<()> {
    // Guard first: do not bother building the record for unknown versions.
    let lineage = match self.lineage_graph.get_mut(version_id) {
        Some(lineage) => lineage,
        None => {
            return Err(TensorError::invalid_argument(format!(
                "Version {version_id} not found"
            )))
        }
    };
    lineage.transformations.push(TransformationRecord {
        transform_type,
        parameters,
        timestamp: SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system time before UNIX_EPOCH")
            .as_secs(),
        description,
    });
    self.save_lineage_graph()
}
/// Returns metadata for every known version, in arbitrary (hash-map) order.
pub fn list_versions(&self) -> Vec<&VersionMetadata> {
    let mut versions = Vec::with_capacity(self.lineage_graph.len());
    for lineage in self.lineage_graph.values() {
        versions.push(&lineage.version);
    }
    versions
}
/// Returns metadata for all versions carrying the given tag.
pub fn get_versions_by_tag(&self, tag: &str) -> Vec<&VersionMetadata> {
    self.lineage_graph
        .values()
        .map(|lineage| &lineage.version)
        // Compare as &str; the previous `contains(&tag.to_string())`
        // allocated a fresh String for every candidate version.
        .filter(|version| version.tags.iter().any(|t| t == tag))
        .collect()
}
/// Builds the lineage tree rooted at `version_id`, or `None` if the
/// version is unknown.
pub fn get_lineage_tree(&self, version_id: &str) -> Option<LineageTree> {
    let lineage = self.lineage_graph.get(version_id)?;
    Some(self.build_lineage_tree(&lineage.version))
}
/// Recursively assembles the subtree rooted at `version` by following
/// `child_versions` edges; children whose ids are missing from the graph
/// are skipped.
fn build_lineage_tree(&self, version: &VersionMetadata) -> LineageTree {
    // Look the node up by id directly; the previous code cloned the id
    // into a local misleadingly named `children` first.
    let children = self
        .lineage_graph
        .get(&version.version_id)
        .map(|lineage| {
            lineage
                .child_versions
                .iter()
                .filter_map(|child_id| self.lineage_graph.get(child_id))
                .map(|child| self.build_lineage_tree(&child.version))
                .collect()
        })
        .unwrap_or_default();
    LineageTree {
        version: version.clone(),
        children,
    }
}
/// Generates a fresh, globally unique version id of the form `v_<32 hex>`.
fn generate_version_id(&self) -> VersionId {
    // `simple()` renders the UUID without hyphens directly, avoiding the
    // extra to_string + replace allocations. (Requires uuid 1.x; on the
    // 0.8 line the equivalent is `to_simple()`.)
    format!("v_{}", uuid::Uuid::new_v4().simple())
}
/// Computes sample count, per-sample shapes, and an estimated in-memory
/// size for `dataset`.
///
/// Shapes are taken from the first sample only, so a ragged dataset will
/// be misreported; the byte size is `count * elements * size_of::<T>()`,
/// an estimate that ignores per-tensor overhead.
fn calculate_size_info<T>(&self, dataset: &dyn Dataset<T>) -> Result<DatasetSizeInfo>
where
    T: Clone + Default + Send + Sync + 'static,
{
    let sample_count = dataset.len();
    if sample_count == 0 {
        // Sentinel shapes keep the struct well-formed for empty datasets.
        return Ok(DatasetSizeInfo {
            sample_count: 0,
            feature_shape: vec![0],
            label_shape: vec![0],
            size_bytes: 0,
        });
    }
    let (features, labels) = dataset.get(0)?;
    let feature_shape = features.shape().dims().to_vec();
    let label_shape = labels.shape().dims().to_vec();
    let elements_per_sample =
        feature_shape.iter().product::<usize>() + label_shape.iter().product::<usize>();
    // Use the actual element type's size: the old code hard-coded
    // size_of::<f32>() regardless of T. Saturating math avoids overflow
    // on pathological shapes.
    let size_bytes = (sample_count as u64)
        .saturating_mul(elements_per_sample as u64)
        .saturating_mul(std::mem::size_of::<T>() as u64);
    Ok(DatasetSizeInfo {
        sample_count,
        feature_shape,
        label_shape,
        size_bytes,
    })
}
/// Computes a cheap fingerprint for `dataset`.
///
/// NOTE(review): this is NOT a content hash — it mixes only the sample
/// count and the first sample's feature/label shapes, so two datasets
/// with identical layout but different values collide. Left unchanged
/// because the value is persisted in snapshot metadata; confirm whether
/// a real content digest is required before relying on this for
/// integrity checking.
fn calculate_checksum<T>(&self, dataset: &dyn Dataset<T>) -> Result<String>
where
T: Clone + Default + Send + Sync + 'static,
{
let len = dataset.len();
if len == 0 {
// Fixed sentinel for the empty dataset.
return Ok("empty_dataset".to_string());
}
let (first_features, first_labels) = dataset.get(0)?;
// Polynomial (base-31) accumulation with wrapping arithmetic.
let mut checksum_value = 0u64;
checksum_value = checksum_value.wrapping_mul(31).wrapping_add(len as u64);
for &dim in first_features.shape().dims() {
checksum_value = checksum_value.wrapping_mul(31).wrapping_add(dim as u64);
}
for &dim in first_labels.shape().dims() {
checksum_value = checksum_value.wrapping_mul(31).wrapping_add(dim as u64);
}
// These capture only the *length* of the Debug-formatted shape strings,
// which is largely redundant with the dims already mixed in above.
let features_hash = format!("{:?}", first_features.shape().dims()).len() as u64;
let labels_hash = format!("{:?}", first_labels.shape().dims()).len() as u64;
checksum_value = checksum_value.wrapping_mul(31).wrapping_add(features_hash);
checksum_value = checksum_value.wrapping_mul(31).wrapping_add(labels_hash);
Ok(format!("{checksum_value:016x}"))
}
/// Serializes every sample of `dataset` to `<version_dir>/samples.json`.
///
/// Each sample is stored as flattened feature/label data plus the shapes
/// needed to rebuild the tensors in `load_dataset_samples`.
fn save_dataset_samples<T>(&self, dataset: &dyn Dataset<T>, version_dir: &Path) -> Result<()>
where
    T: Clone + Default + serde::Serialize + Send + Sync + 'static,
{
    let samples_file = version_dir.join("samples.json");
    // Reserve up front: exactly one JSON entry per sample.
    let mut samples = Vec::with_capacity(dataset.len());
    for i in 0..dataset.len() {
        let (features, labels) = dataset.get(i)?;
        // Contiguous tensors expose their storage as a slice; otherwise
        // fall back to a single-element (scalar) read. `unwrap_or_default`
        // avoids eagerly constructing T::default() on the happy path.
        let features_data = if let Some(slice) = features.as_slice() {
            slice.to_vec()
        } else {
            vec![features.get(&[]).unwrap_or_default()]
        };
        let labels_data = if let Some(slice) = labels.as_slice() {
            slice.to_vec()
        } else {
            vec![labels.get(&[]).unwrap_or_default()]
        };
        samples.push(serde_json::json!({
            "features": features_data,
            "labels": labels_data,
            "feature_shape": features.shape().dims(),
            "label_shape": labels.shape().dims(),
        }));
    }
    let json_data = serde_json::to_string_pretty(&samples).map_err(|e| {
        TensorError::invalid_argument(format!("Failed to serialize samples: {e}"))
    })?;
    std::fs::write(samples_file, json_data).map_err(|e| {
        TensorError::invalid_argument(format!("Failed to write samples file: {e}"))
    })?;
    Ok(())
}
/// Reads `<version_dir>/samples.json` and rebuilds each sample as a pair
/// of tensors.
///
/// # Errors
/// Fails when the file is missing/unreadable, the JSON is malformed, or
/// a flat buffer does not match its recorded shape.
fn load_dataset_samples<T>(&self, version_dir: &Path) -> Result<Vec<(Tensor<T>, Tensor<T>)>>
where
    T: Clone + Default + serde::de::DeserializeOwned + Send + Sync + 'static,
{
    let samples_file = version_dir.join("samples.json");
    let json_data = std::fs::read_to_string(samples_file).map_err(|e| {
        TensorError::invalid_argument(format!("Failed to read samples file: {e}"))
    })?;
    let json_samples: Vec<serde_json::Value> =
        serde_json::from_str(&json_data).map_err(|e| {
            TensorError::invalid_argument(format!("Failed to parse samples JSON: {e}"))
        })?;
    let mut samples = Vec::with_capacity(json_samples.len());
    for sample in json_samples {
        let features_data: Vec<T> = serde_json::from_value(sample["features"].clone())
            .map_err(|e| {
                TensorError::invalid_argument(format!("Failed to parse features: {e}"))
            })?;
        let labels_data: Vec<T> =
            serde_json::from_value(sample["labels"].clone()).map_err(|e| {
                TensorError::invalid_argument(format!("Failed to parse labels: {e}"))
            })?;
        let feature_shape: Vec<usize> = serde_json::from_value(sample["feature_shape"].clone())
            .map_err(|e| {
                TensorError::invalid_argument(format!("Failed to parse feature shape: {e}"))
            })?;
        let label_shape: Vec<usize> = serde_json::from_value(sample["label_shape"].clone())
            .map_err(|e| {
                TensorError::invalid_argument(format!("Failed to parse label shape: {e}"))
            })?;
        // `[]` or `[0]` marks a scalar written by `save_dataset_samples`.
        // Compare against an array literal: the old `== vec![0]` allocated
        // a fresh Vec for every sample.
        let features_tensor = if feature_shape.is_empty() || feature_shape == [0] {
            Tensor::from_scalar(features_data.into_iter().next().unwrap_or_default())
        } else {
            Tensor::from_vec(features_data, &feature_shape)?
        };
        let labels_tensor = if label_shape.is_empty() || label_shape == [0] {
            Tensor::from_scalar(labels_data.into_iter().next().unwrap_or_default())
        } else {
            Tensor::from_vec(labels_data, &label_shape)?
        };
        samples.push((features_tensor, labels_tensor));
    }
    Ok(samples)
}
/// Writes `metadata` as pretty-printed JSON to `<version_dir>/metadata.json`.
fn save_metadata(&self, metadata: &VersionMetadata, version_dir: &Path) -> Result<()> {
    let json = serde_json::to_string_pretty(metadata).map_err(|e| {
        TensorError::invalid_argument(format!("Failed to serialize metadata: {e}"))
    })?;
    std::fs::write(version_dir.join("metadata.json"), json).map_err(|e| {
        TensorError::invalid_argument(format!("Failed to write metadata file: {e}"))
    })?;
    Ok(())
}
/// Reads and parses `<version_dir>/metadata.json`.
fn load_metadata(&self, version_dir: &Path) -> Result<VersionMetadata> {
    let path = version_dir.join("metadata.json");
    let contents = std::fs::read_to_string(path).map_err(|e| {
        TensorError::invalid_argument(format!("Failed to read metadata file: {e}"))
    })?;
    let metadata = serde_json::from_str(&contents).map_err(|e| {
        TensorError::invalid_argument(format!("Failed to parse metadata JSON: {e}"))
    })?;
    Ok(metadata)
}
/// Serializes the entire lineage graph to `<base_path>/lineage.json`.
fn save_lineage_graph(&self) -> Result<()> {
    let json = serde_json::to_string_pretty(&self.lineage_graph).map_err(|e| {
        TensorError::invalid_argument(format!("Failed to serialize lineage graph: {e}"))
    })?;
    std::fs::write(self.base_path.join("lineage.json"), json).map_err(|e| {
        TensorError::invalid_argument(format!("Failed to write lineage file: {e}"))
    })?;
    Ok(())
}
/// Restores the lineage graph from `<base_path>/lineage.json`, if present.
///
/// A missing file is not an error: it simply means no snapshots have
/// been recorded yet.
fn load_lineage_graph(&mut self) -> Result<()> {
    let lineage_file = self.base_path.join("lineage.json");
    if lineage_file.exists() {
        let contents = std::fs::read_to_string(lineage_file).map_err(|e| {
            TensorError::invalid_argument(format!("Failed to read lineage file: {e}"))
        })?;
        self.lineage_graph = serde_json::from_str(&contents).map_err(|e| {
            TensorError::invalid_argument(format!("Failed to parse lineage JSON: {e}"))
        })?;
    }
    Ok(())
}
}
/// A version together with the recursively expanded trees of its children.
#[derive(Debug, Clone)]
pub struct LineageTree {
// Metadata of this node's version.
pub version: VersionMetadata,
// Subtrees rooted at each child version.
pub children: Vec<LineageTree>,
}
/// An in-memory dataset reconstructed from an on-disk snapshot.
#[derive(Debug)]
pub struct VersionedDataset<T> {
// Metadata loaded from the snapshot's `metadata.json`.
metadata: VersionMetadata,
// Fully materialized (features, labels) pairs.
samples: Vec<(Tensor<T>, Tensor<T>)>,
}
impl<T> VersionedDataset<T>
where
T: Clone + Default + Send + Sync + 'static,
{
/// Full metadata recorded when this snapshot was created.
pub fn metadata(&self) -> &VersionMetadata {
&self.metadata
}
/// Identifier of the snapshot this dataset was loaded from.
pub fn version_id(&self) -> &str {
&self.metadata.version_id
}
}
impl<T> Dataset<T> for VersionedDataset<T>
where
    T: Clone + Default + Send + Sync + 'static,
{
    /// Number of (features, labels) pairs in the snapshot.
    fn len(&self) -> usize {
        self.samples.len()
    }
    /// Returns a clone of the sample at `index`, or an out-of-bounds error.
    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        match self.samples.get(index) {
            Some(pair) => Ok(pair.clone()),
            None => Err(TensorError::invalid_argument(format!(
                "Index {} out of bounds for dataset of length {}",
                index,
                self.samples.len()
            ))),
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TensorDataset;
use tempfile::TempDir;
// Creating a manager on a fresh temp dir should succeed and start empty.
#[test]
fn test_version_manager_creation() {
let temp_dir = TempDir::new().expect("test: temp dir creation should succeed");
let manager =
DatasetVersionManager::new(temp_dir.path()).expect("test: operation should succeed");
assert!(temp_dir.path().exists());
assert_eq!(manager.list_versions().len(), 0);
}
// Round-trip: snapshot a 2-sample dataset, reload it, and verify both
// the metadata (id) and the tensor contents survive.
#[test]
fn test_create_and_load_snapshot() {
let temp_dir = TempDir::new().expect("test: temp dir creation should succeed");
let mut manager =
DatasetVersionManager::new(temp_dir.path()).expect("test: operation should succeed");
let features_data = vec![1.0, 2.0, 3.0, 4.0];
let labels_data = vec![0.0, 1.0];
let features =
Tensor::from_vec(features_data, &[2, 2]).expect("test: tensor creation should succeed");
let labels =
Tensor::from_vec(labels_data, &[2]).expect("test: tensor creation should succeed");
let dataset = TensorDataset::new(features, labels);
let version_id = manager
.create_snapshot(
&dataset,
"Test snapshot".to_string(),
vec!["test".to_string()],
None,
)
.expect("test: operation should succeed");
assert!(!version_id.is_empty());
assert_eq!(manager.list_versions().len(), 1);
let loaded_dataset = manager
.load_snapshot::<f32>(&version_id)
.expect("test: operation should succeed");
assert_eq!(loaded_dataset.len(), 2);
assert_eq!(loaded_dataset.version_id(), &version_id);
// First sample: first row of features, first label element.
let (features, labels) = loaded_dataset.get(0).expect("index should be in bounds");
let features_slice = features.as_slice().expect("tensor should be contiguous");
assert_eq!(features_slice, &[1.0, 2.0]);
assert_eq!(labels.get(&[]).expect("test: get should succeed"), 0.0);
}
// Parent/child edges, transformation records, and the lineage tree
// should all reflect a two-version chain.
#[test]
fn test_lineage_tracking() {
let temp_dir = TempDir::new().expect("test: temp dir creation should succeed");
let mut manager =
DatasetVersionManager::new(temp_dir.path()).expect("test: operation should succeed");
let features_data1 = vec![1.0, 2.0];
let labels_data1 = vec![0.0];
let features1 = Tensor::from_vec(features_data1, &[1, 2])
.expect("test: tensor creation should succeed");
let labels1 =
Tensor::from_vec(labels_data1, &[1]).expect("test: tensor creation should succeed");
let dataset1 = TensorDataset::new(features1, labels1);
let version1 = manager
.create_snapshot(
&dataset1,
"Initial version".to_string(),
vec!["v1".to_string()],
None,
)
.expect("test: operation should succeed");
let features_data2 = vec![2.0, 4.0];
let labels_data2 = vec![1.0];
let features2 = Tensor::from_vec(features_data2, &[1, 2])
.expect("test: tensor creation should succeed");
let labels2 =
Tensor::from_vec(labels_data2, &[1]).expect("test: tensor creation should succeed");
let dataset2 = TensorDataset::new(features2, labels2);
// Second snapshot is derived from the first via `parent_version`.
let version2 = manager
.create_snapshot(
&dataset2,
"Scaled version".to_string(),
vec!["v2".to_string()],
Some(version1.clone()),
)
.expect("test: operation should succeed");
let mut params = HashMap::new();
params.insert("scale_factor".to_string(), "2.0".to_string());
manager
.add_transformation(
&version2,
"scale".to_string(),
params,
"Scale features by 2".to_string(),
)
.expect("test: operation should succeed");
let lineage = manager
.get_lineage(&version2)
.expect("test: operation should succeed");
assert_eq!(lineage.source_versions, vec![version1.clone()]);
assert_eq!(lineage.transformations.len(), 1);
assert_eq!(lineage.transformations[0].transform_type, "scale");
// Tree rooted at version1 must contain version2 as its only child.
let tree = manager
.get_lineage_tree(&version1)
.expect("test: operation should succeed");
assert_eq!(tree.version.version_id, version1);
assert_eq!(tree.children.len(), 1);
assert_eq!(tree.children[0].version.version_id, version2);
}
// Tag-based filtering should count versions by tag, including versions
// carrying multiple tags.
#[test]
fn test_version_filtering() {
let temp_dir = TempDir::new().expect("test: temp dir creation should succeed");
let mut manager =
DatasetVersionManager::new(temp_dir.path()).expect("test: operation should succeed");
let features_data = vec![1.0];
let labels_data = vec![0.0];
let features =
Tensor::from_vec(features_data, &[1, 1]).expect("test: tensor creation should succeed");
let labels =
Tensor::from_vec(labels_data, &[1]).expect("test: tensor creation should succeed");
let dataset = TensorDataset::new(features, labels);
let _version1 = manager
.create_snapshot(
&dataset,
"Version 1".to_string(),
vec!["production".to_string()],
None,
)
.expect("test: operation should succeed");
let _version2 = manager
.create_snapshot(
&dataset,
"Version 2".to_string(),
vec!["development".to_string()],
None,
)
.expect("test: operation should succeed");
let _version3 = manager
.create_snapshot(
&dataset,
"Version 3".to_string(),
vec!["production".to_string(), "validated".to_string()],
None,
)
.expect("test: operation should succeed");
let prod_versions = manager.get_versions_by_tag("production");
assert_eq!(prod_versions.len(), 2);
let dev_versions = manager.get_versions_by_tag("development");
assert_eq!(dev_versions.len(), 1);
let validated_versions = manager.get_versions_by_tag("validated");
assert_eq!(validated_versions.len(), 1);
}
}