#![allow(missing_docs)]
use crate::error::{IoError, Result};
use crate::lineage::DataLineage;
use crate::schema::Schema;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// Supported on-disk data formats for catalogued datasets.
///
/// Marked `#[non_exhaustive]` so new variants can be added without a
/// breaking change; downstream `match`es must keep a wildcard arm.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum DataFormat {
    /// Comma-separated values.
    Csv,
    /// Tab-separated values.
    Tsv,
    /// Newline-delimited JSON (one object per line).
    JsonLines,
    /// A single JSON document.
    Json,
    /// Apache Parquet columnar format.
    Parquet,
    /// HDF5 hierarchical data format.
    Hdf5,
    /// NetCDF scientific data format.
    NetCdf,
    /// Arrow IPC / Feather format.
    ArrowIpc,
    /// MATLAB `.mat` file.
    Matlab,
    /// Matrix Market sparse-matrix exchange format.
    MatrixMarket,
    /// NumPy `.npy` / `.npz` array files.
    Numpy,
    /// WAV audio.
    Wav,
    /// Any other format, identified by a free-form lowercase name.
    Custom(String),
    /// Format could not be determined.
    Unknown,
}
impl DataFormat {
    /// Canonical lowercase name for this format (for display and map keys).
    ///
    /// For [`DataFormat::Custom`] this returns the user-supplied name verbatim.
    pub fn as_str(&self) -> &str {
        match self {
            DataFormat::Csv => "csv",
            DataFormat::Tsv => "tsv",
            DataFormat::JsonLines => "jsonlines",
            DataFormat::Json => "json",
            DataFormat::Parquet => "parquet",
            DataFormat::Hdf5 => "hdf5",
            DataFormat::NetCdf => "netcdf",
            DataFormat::ArrowIpc => "arrow_ipc",
            DataFormat::Matlab => "matlab",
            DataFormat::MatrixMarket => "matrix_market",
            DataFormat::Numpy => "numpy",
            DataFormat::Wav => "wav",
            DataFormat::Custom(name) => name.as_str(),
            DataFormat::Unknown => "unknown",
        }
    }

    /// Infers a format from a file extension (case-insensitive, without the dot).
    ///
    /// An empty or all-whitespace extension yields [`DataFormat::Unknown`]
    /// rather than a meaningless `Custom("")`; any other unrecognized
    /// extension is preserved as [`DataFormat::Custom`].
    pub fn from_extension(ext: &str) -> Self {
        match ext.trim().to_lowercase().as_str() {
            // No usable extension: don't fabricate an empty Custom name.
            "" => DataFormat::Unknown,
            "csv" => DataFormat::Csv,
            "tsv" | "tab" => DataFormat::Tsv,
            "jsonl" | "ndjson" => DataFormat::JsonLines,
            "json" => DataFormat::Json,
            "parquet" => DataFormat::Parquet,
            "h5" | "hdf5" | "hdf" => DataFormat::Hdf5,
            "nc" | "netcdf" | "nc4" => DataFormat::NetCdf,
            "arrow" | "ipc" | "feather" => DataFormat::ArrowIpc,
            "mat" => DataFormat::Matlab,
            "mtx" | "mm" => DataFormat::MatrixMarket,
            "npy" | "npz" => DataFormat::Numpy,
            "wav" => DataFormat::Wav,
            other => DataFormat::Custom(other.to_string()),
        }
    }
}
// Display mirrors `as_str`, so `format!("{fmt}")` and map keys agree.
impl std::fmt::Display for DataFormat {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// A single dataset registered in a [`DataCatalog`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetEntry {
    /// Stable unique identifier, generated at build time.
    pub id: Uuid,
    /// Human-readable name; unique within a catalog.
    pub name: String,
    /// Optional free-form description.
    pub description: Option<String>,
    /// Where the data lives (path, URL, ...). May be empty if unset.
    pub location: String,
    /// On-disk format of the data.
    pub format: DataFormat,
    /// Optional column/field schema.
    pub schema: Option<Schema>,
    /// Free-form tags for search and grouping (deduplicated by `add_tag`).
    pub tags: Vec<String>,
    /// Arbitrary key/value metadata.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Creation timestamp (set once at build time).
    pub created_at: DateTime<Utc>,
    /// Last-modified timestamp; bumped by `touch` and `set_metadata`.
    pub updated_at: DateTime<Utc>,
    /// Number of rows, if known.
    pub row_count: Option<u64>,
    /// Number of columns, if known.
    pub column_count: Option<u64>,
    /// Total size in bytes, if known.
    pub size_bytes: Option<u64>,
    /// Owning user or team, if known.
    pub owner: Option<String>,
    /// Dataset version label, if any.
    pub version: Option<String>,
    /// Whether the dataset is currently active (defaults to `true`).
    pub is_active: bool,
}
impl DatasetEntry {
    /// Starts building a new entry; see [`DatasetEntryBuilder`].
    pub fn builder() -> DatasetEntryBuilder {
        DatasetEntryBuilder::default()
    }

    /// Marks the entry as modified now.
    pub fn touch(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Adds `tag` if not already present.
    ///
    /// Bumps `updated_at` when the tag set actually changes, consistent
    /// with [`DatasetEntry::set_metadata`].
    pub fn add_tag(&mut self, tag: impl Into<String>) {
        let tag = tag.into();
        if !self.tags.contains(&tag) {
            self.tags.push(tag);
            self.updated_at = Utc::now();
        }
    }

    /// Removes every occurrence of `tag`.
    ///
    /// Bumps `updated_at` only when something was actually removed.
    pub fn remove_tag(&mut self, tag: &str) {
        let before = self.tags.len();
        self.tags.retain(|t| t != tag);
        if self.tags.len() != before {
            self.updated_at = Utc::now();
        }
    }

    /// Returns `true` if `tag` is present.
    pub fn has_tag(&self, tag: &str) -> bool {
        self.tags.iter().any(|t| t == tag)
    }

    /// Inserts or overwrites a metadata value and bumps `updated_at`.
    pub fn set_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
        self.metadata.insert(key.into(), value);
        self.updated_at = Utc::now();
    }

    /// Looks up a metadata value by key.
    pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
        self.metadata.get(key)
    }
}
/// Fluent builder for [`DatasetEntry`]; obtain via [`DatasetEntry::builder`].
///
/// Only `name` is required; everything else has a sensible default
/// (`location` defaults to empty, `format` to [`DataFormat::Unknown`]).
#[derive(Debug, Default)]
pub struct DatasetEntryBuilder {
    name: Option<String>,
    description: Option<String>,
    location: Option<String>,
    format: Option<DataFormat>,
    schema: Option<Schema>,
    tags: Vec<String>,
    metadata: HashMap<String, serde_json::Value>,
    row_count: Option<u64>,
    column_count: Option<u64>,
    size_bytes: Option<u64>,
    owner: Option<String>,
    version: Option<String>,
}
impl DatasetEntryBuilder {
pub fn name(mut self, name: impl Into<String>) -> Self {
self.name = Some(name.into());
self
}
pub fn description(mut self, desc: impl Into<String>) -> Self {
self.description = Some(desc.into());
self
}
pub fn location(mut self, loc: impl Into<String>) -> Self {
self.location = Some(loc.into());
self
}
pub fn format(mut self, fmt: DataFormat) -> Self {
self.format = Some(fmt);
self
}
pub fn schema(mut self, schema: Schema) -> Self {
self.schema = Some(schema);
self
}
pub fn tag(mut self, tag: impl Into<String>) -> Self {
self.tags.push(tag.into());
self
}
pub fn metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
self.metadata.insert(key.into(), value);
self
}
pub fn row_count(mut self, n: u64) -> Self {
self.row_count = Some(n);
self
}
pub fn column_count(mut self, n: u64) -> Self {
self.column_count = Some(n);
self
}
pub fn size_bytes(mut self, bytes: u64) -> Self {
self.size_bytes = Some(bytes);
self
}
pub fn owner(mut self, owner: impl Into<String>) -> Self {
self.owner = Some(owner.into());
self
}
pub fn version(mut self, version: impl Into<String>) -> Self {
self.version = Some(version.into());
self
}
pub fn build(self) -> DatasetEntry {
let now = Utc::now();
DatasetEntry {
id: Uuid::new_v4(),
name: self.name.expect("DatasetEntry requires a name"),
description: self.description,
location: self.location.unwrap_or_default(),
format: self.format.unwrap_or(DataFormat::Unknown),
schema: self.schema,
tags: self.tags,
metadata: self.metadata,
created_at: now,
updated_at: now,
row_count: self.row_count,
column_count: self.column_count,
size_bytes: self.size_bytes,
owner: self.owner,
version: self.version,
is_active: true,
}
}
}
/// A collection of [`DatasetEntry`]s with unique-name lookup and lineage.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct DataCatalog {
    /// Primary storage, keyed by entry id.
    entries: HashMap<Uuid, DatasetEntry>,
    /// Secondary index name -> id. Not serialized; it is derived state and
    /// must be reconstructed via `rebuild_index` after deserialization
    /// (`from_json` / `CatalogSerializer::from_json_bytes` do this).
    #[serde(skip)]
    name_index: HashMap<String, Uuid>,
    /// Optional catalog name.
    pub name: Option<String>,
    /// Optional catalog description.
    pub description: Option<String>,
    /// Provenance/lineage information for the catalogued data.
    pub lineage: DataLineage,
}
impl DataCatalog {
    /// Creates an empty, unnamed catalog.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty catalog with a human-readable name.
    pub fn with_name(name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..Default::default()
        }
    }

    /// Registers a new dataset and returns its id.
    ///
    /// # Errors
    /// Returns [`IoError::ValidationError`] if an entry with the same name
    /// already exists; use [`DataCatalog::upsert`] to replace instead.
    pub fn register(&mut self, entry: DatasetEntry) -> Result<Uuid> {
        if self.name_index.contains_key(&entry.name) {
            return Err(IoError::ValidationError(format!(
                "Dataset '{}' already registered in catalog",
                entry.name
            )));
        }
        let id = entry.id;
        self.name_index.insert(entry.name.clone(), id);
        self.entries.insert(id, entry);
        Ok(id)
    }

    /// Inserts `entry`, replacing any existing entry with the same name
    /// or the same id, and returns the entry's id.
    pub fn upsert(&mut self, entry: DatasetEntry) -> Uuid {
        let id = entry.id;
        // Replace whatever entry currently owns this name.
        if let Some(old_id) = self.name_index.get(&entry.name).copied() {
            self.entries.remove(&old_id);
        }
        // If an entry with this id still exists under a *different* name,
        // drop its stale name-index mapping; otherwise a lookup by the old
        // name would resolve to the new entry.
        if let Some(existing) = self.entries.get(&id) {
            if existing.name != entry.name {
                self.name_index.remove(&existing.name);
            }
        }
        self.name_index.insert(entry.name.clone(), id);
        self.entries.insert(id, entry);
        id
    }

    /// Looks up an entry by its unique name.
    pub fn get_by_name(&self, name: &str) -> Option<&DatasetEntry> {
        let id = self.name_index.get(name)?;
        self.entries.get(id)
    }

    /// Looks up an entry by id.
    pub fn get_by_id(&self, id: Uuid) -> Option<&DatasetEntry> {
        self.entries.get(&id)
    }

    /// Mutable lookup by name.
    pub fn get_by_name_mut(&mut self, name: &str) -> Option<&mut DatasetEntry> {
        let id = *self.name_index.get(name)?;
        self.entries.get_mut(&id)
    }

    /// Removes an entry by name, returning it if it existed.
    pub fn deregister(&mut self, name: &str) -> Option<DatasetEntry> {
        let id = self.name_index.remove(name)?;
        self.entries.remove(&id)
    }

    /// Iterates over all entries in unspecified order.
    pub fn all_entries(&self) -> impl Iterator<Item = &DatasetEntry> {
        self.entries.values()
    }

    /// Number of registered entries.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` if the catalog has no entries.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Reconstructs the name index from `entries`. Must be called after
    /// deserialization, since the index is `#[serde(skip)]`.
    pub fn rebuild_index(&mut self) {
        self.name_index.clear();
        for (id, entry) in &self.entries {
            self.name_index.insert(entry.name.clone(), *id);
        }
    }

    /// Serializes the catalog to pretty-printed JSON.
    ///
    /// # Errors
    /// Returns [`IoError::SerializationError`] on serde failure.
    pub fn to_json(&self) -> Result<String> {
        serde_json::to_string_pretty(self)
            .map_err(|e| IoError::SerializationError(e.to_string()))
    }

    /// Parses a catalog from JSON and rebuilds the name index.
    ///
    /// # Errors
    /// Returns [`IoError::DeserializationError`] on serde failure.
    pub fn from_json(json: &str) -> Result<Self> {
        let mut catalog: Self =
            serde_json::from_str(json).map_err(|e| IoError::DeserializationError(e.to_string()))?;
        catalog.rebuild_index();
        Ok(catalog)
    }

    /// Writes the catalog as JSON to `path`.
    ///
    /// NOTE: this is a plain write; see [`CatalogSerializer::save`] for an
    /// atomic write-then-rename variant.
    pub fn save_to_file(&self, path: &str) -> Result<()> {
        let json = self.to_json()?;
        std::fs::write(path, json).map_err(IoError::Io)
    }

    /// Reads and parses a catalog from a JSON file at `path`.
    pub fn load_from_file(path: &str) -> Result<Self> {
        let json = std::fs::read_to_string(path).map_err(IoError::Io)?;
        Self::from_json(&json)
    }

    /// All distinct tags across all entries, sorted ascending.
    pub fn all_tags(&self) -> Vec<String> {
        // BTreeSet deduplicates and yields sorted order in one pass.
        let mut tags = std::collections::BTreeSet::new();
        for entry in self.entries.values() {
            tags.extend(entry.tags.iter().cloned());
        }
        tags.into_iter().collect()
    }

    /// All distinct formats across all entries, in first-seen order.
    pub fn all_formats(&self) -> Vec<DataFormat> {
        let mut seen: Vec<DataFormat> = Vec::new();
        for entry in self.entries.values() {
            if !seen.contains(&entry.format) {
                seen.push(entry.format.clone());
            }
        }
        seen
    }
}
/// A composable predicate over [`DatasetEntry`], evaluated by `matches`.
#[derive(Clone)]
pub enum SearchFilter {
    /// Entry carries this exact tag.
    Tag(String),
    /// Entry name contains this substring (case-insensitive).
    NameContains(String),
    /// Entry format equals this format exactly.
    Format(DataFormat),
    /// Entry owner equals this string exactly (entries without an owner never match).
    Owner(String),
    /// Entry has a schema attached.
    HasSchema,
    /// Entry is marked active.
    IsActive,
    /// All sub-filters match (an empty list matches everything).
    And(Vec<SearchFilter>),
    /// At least one sub-filter matches (an empty list matches nothing).
    Or(Vec<SearchFilter>),
    /// Negation of the inner filter.
    Not(Box<SearchFilter>),
    /// Named arbitrary predicate; the name is used only for `Debug` output.
    #[allow(clippy::type_complexity)]
    Custom(String, std::sync::Arc<dyn Fn(&DatasetEntry) -> bool + Send + Sync>),
}
impl std::fmt::Debug for SearchFilter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SearchFilter::Tag(s) => write!(f, "Tag({s:?})"),
SearchFilter::NameContains(s) => write!(f, "NameContains({s:?})"),
SearchFilter::Format(fmt) => write!(f, "Format({fmt:?})"),
SearchFilter::Owner(s) => write!(f, "Owner({s:?})"),
SearchFilter::HasSchema => write!(f, "HasSchema"),
SearchFilter::IsActive => write!(f, "IsActive"),
SearchFilter::And(v) => write!(f, "And({v:?})"),
SearchFilter::Or(v) => write!(f, "Or({v:?})"),
SearchFilter::Not(b) => write!(f, "Not({b:?})"),
SearchFilter::Custom(name, _) => write!(f, "Custom({name:?}, <fn>)"),
}
}
}
impl SearchFilter {
pub fn matches(&self, entry: &DatasetEntry) -> bool {
match self {
SearchFilter::Tag(tag) => entry.has_tag(tag),
SearchFilter::NameContains(substr) => entry
.name
.to_lowercase()
.contains(&substr.to_lowercase()),
SearchFilter::Format(fmt) => &entry.format == fmt,
SearchFilter::Owner(owner) => {
entry.owner.as_deref() == Some(owner.as_str())
}
SearchFilter::HasSchema => entry.schema.is_some(),
SearchFilter::IsActive => entry.is_active,
SearchFilter::And(filters) => filters.iter().all(|f| f.matches(entry)),
SearchFilter::Or(filters) => filters.iter().any(|f| f.matches(entry)),
SearchFilter::Not(filter) => !filter.matches(entry),
SearchFilter::Custom(_, pred) => pred(entry),
}
}
}
/// Fluent query builder over a borrowed [`DataCatalog`].
pub struct CatalogSearcher<'a> {
    // Catalog being queried; borrowed for the searcher's lifetime.
    catalog: &'a DataCatalog,
    // Filters combined with logical AND at search time.
    filters: Vec<SearchFilter>,
    // Optional cap applied after sorting.
    max_results: Option<usize>,
    // Result ordering; defaults to name ascending.
    sort_by: SortOrder,
}
/// Result ordering for [`CatalogSearcher::search`].
#[derive(Debug, Clone, Default)]
pub enum SortOrder {
    /// Entry name, ascending (the default).
    #[default]
    NameAsc,
    /// Entry name, descending.
    NameDesc,
    /// Most recently created first (`created_at` descending).
    NewestFirst,
    /// Oldest created first (`created_at` ascending).
    OldestFirst,
    /// Highest row count first; entries without a row count sort as 0.
    LargestFirst,
}
impl<'a> CatalogSearcher<'a> {
    /// Creates a searcher over `catalog` with no filters, no limit, and
    /// name-ascending order.
    pub fn new(catalog: &'a DataCatalog) -> Self {
        Self {
            catalog,
            filters: Vec::new(),
            max_results: None,
            sort_by: SortOrder::NameAsc,
        }
    }

    /// Adds a filter; all added filters must match (logical AND).
    pub fn filter(mut self, f: SearchFilter) -> Self {
        self.filters.push(f);
        self
    }

    /// Caps the number of returned entries.
    pub fn limit(mut self, n: usize) -> Self {
        self.max_results = Some(n);
        self
    }

    /// Sets the result ordering.
    pub fn sort_by(mut self, order: SortOrder) -> Self {
        self.sort_by = order;
        self
    }

    /// Runs the query: filter, then sort, then apply the limit.
    pub fn search(&self) -> Vec<&DatasetEntry> {
        let passes_all = |entry: &&DatasetEntry| self.filters.iter().all(|f| f.matches(entry));
        let mut hits: Vec<&DatasetEntry> = self.catalog.all_entries().filter(passes_all).collect();
        match &self.sort_by {
            SortOrder::NameAsc => hits.sort_by(|x, y| x.name.cmp(&y.name)),
            SortOrder::NameDesc => hits.sort_by(|x, y| y.name.cmp(&x.name)),
            // `Reverse` flips the key order, giving a descending stable sort.
            SortOrder::NewestFirst => hits.sort_by_key(|e| std::cmp::Reverse(e.created_at)),
            SortOrder::OldestFirst => hits.sort_by_key(|e| e.created_at),
            // Missing row counts sort as 0, i.e. last.
            SortOrder::LargestFirst => {
                hits.sort_by_key(|e| std::cmp::Reverse(e.row_count.unwrap_or(0)))
            }
        }
        if let Some(cap) = self.max_results {
            hits.truncate(cap);
        }
        hits
    }
}
pub struct CatalogSerializer;
impl CatalogSerializer {
pub fn to_json_bytes(catalog: &DataCatalog) -> Result<Vec<u8>> {
serde_json::to_vec_pretty(catalog)
.map_err(|e| IoError::SerializationError(e.to_string()))
}
pub fn from_json_bytes(bytes: &[u8]) -> Result<DataCatalog> {
let mut catalog: DataCatalog =
serde_json::from_slice(bytes).map_err(|e| IoError::DeserializationError(e.to_string()))?;
catalog.rebuild_index();
Ok(catalog)
}
pub fn save(catalog: &DataCatalog, path: &str) -> Result<()> {
let bytes = Self::to_json_bytes(catalog)?;
let tmp = format!("{path}.tmp");
std::fs::write(&tmp, &bytes).map_err(IoError::Io)?;
std::fs::rename(&tmp, path).map_err(IoError::Io)?;
Ok(())
}
pub fn load(path: &str) -> Result<DataCatalog> {
let bytes = std::fs::read(path).map_err(IoError::Io)?;
Self::from_json_bytes(&bytes)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{FieldType, Schema, SchemaField};

    // Builds a minimal entry with the given name, format, and tags.
    // NOTE(review): location always uses a `.csv` suffix regardless of
    // `format`; harmless for these tests but slightly misleading.
    fn sample_entry(name: &str, format: DataFormat, tags: &[&str]) -> DatasetEntry {
        let mut b = DatasetEntry::builder()
            .name(name)
            .location(format!("/data/{name}.csv"))
            .format(format);
        for t in tags {
            b = b.tag(*t);
        }
        b.build()
    }

    // Registered entries are retrievable by the returned id.
    #[test]
    fn test_catalog_register_and_retrieve() {
        let mut cat = DataCatalog::new();
        let entry = sample_entry("iris", DataFormat::Csv, &["biology", "ml"]);
        let id = cat.register(entry).unwrap();
        let found = cat.get_by_id(id).unwrap();
        assert_eq!(found.name, "iris");
        assert_eq!(found.format, DataFormat::Csv);
    }

    // Registering a second entry with the same name is rejected.
    #[test]
    fn test_catalog_duplicate_name_error() {
        let mut cat = DataCatalog::new();
        cat.register(sample_entry("iris", DataFormat::Csv, &[])).unwrap();
        let result = cat.register(sample_entry("iris", DataFormat::Csv, &[]));
        assert!(result.is_err());
    }

    // Tag filter matches only the entries carrying that tag.
    #[test]
    fn test_catalog_search_by_tag() {
        let mut cat = DataCatalog::new();
        cat.register(sample_entry("iris", DataFormat::Csv, &["biology"])).unwrap();
        cat.register(sample_entry("mnist", DataFormat::Numpy, &["image", "ml"])).unwrap();
        cat.register(sample_entry("titanic", DataFormat::Csv, &["tabular", "ml"])).unwrap();
        let searcher = CatalogSearcher::new(&cat)
            .filter(SearchFilter::Tag("ml".to_string()));
        let results = searcher.search();
        assert_eq!(results.len(), 2);
    }

    // Format filter compares formats exactly.
    #[test]
    fn test_catalog_search_by_format() {
        let mut cat = DataCatalog::new();
        cat.register(sample_entry("a", DataFormat::Csv, &[])).unwrap();
        cat.register(sample_entry("b", DataFormat::Csv, &[])).unwrap();
        cat.register(sample_entry("c", DataFormat::Parquet, &[])).unwrap();
        let searcher = CatalogSearcher::new(&cat)
            .filter(SearchFilter::Format(DataFormat::Csv));
        let results = searcher.search();
        assert_eq!(results.len(), 2);
    }

    // Substring filter on names ("validation_set" has no "data").
    #[test]
    fn test_catalog_search_name_contains() {
        let mut cat = DataCatalog::new();
        cat.register(sample_entry("train_data", DataFormat::Csv, &[])).unwrap();
        cat.register(sample_entry("test_data", DataFormat::Csv, &[])).unwrap();
        cat.register(sample_entry("validation_set", DataFormat::Csv, &[])).unwrap();
        let searcher = CatalogSearcher::new(&cat)
            .filter(SearchFilter::NameContains("data".to_string()));
        let results = searcher.search();
        assert_eq!(results.len(), 2);
    }

    // JSON round-trip preserves entries (incl. schema) and rebuilds the
    // name index so get_by_name works on the restored catalog.
    #[test]
    fn test_catalog_json_roundtrip() {
        let mut cat = DataCatalog::with_name("my_catalog");
        let entry = DatasetEntry::builder()
            .name("ds1")
            .location("/data/ds1.parquet")
            .format(DataFormat::Parquet)
            .tag("finance")
            .row_count(10000)
            .schema(
                Schema::builder()
                    .field(SchemaField::new("price", FieldType::Float64))
                    .build(),
            )
            .build();
        cat.register(entry).unwrap();
        let json = cat.to_json().unwrap();
        let restored = DataCatalog::from_json(&json).unwrap();
        assert_eq!(restored.len(), 1);
        let ds = restored.get_by_name("ds1").unwrap();
        assert_eq!(ds.row_count, Some(10000));
        assert!(ds.schema.is_some());
    }

    // Save/load round-trip through CatalogSerializer's temp-file path.
    #[test]
    fn test_catalog_save_load_file() {
        let tmp = std::env::temp_dir().join("test_catalog_save.json");
        let path = tmp.to_str().expect("temp path");
        let mut cat = DataCatalog::new();
        cat.register(sample_entry("test", DataFormat::Csv, &["temp"])).unwrap();
        CatalogSerializer::save(&cat, path).unwrap();
        let loaded = CatalogSerializer::load(path).unwrap();
        assert_eq!(loaded.len(), 1);
        assert!(loaded.get_by_name("test").is_some());
        // Best-effort cleanup; failure to remove is ignored.
        let _ = std::fs::remove_file(tmp);
    }

    // Deregister removes both the entry and its name-index mapping.
    #[test]
    fn test_catalog_deregister() {
        let mut cat = DataCatalog::new();
        cat.register(sample_entry("del_me", DataFormat::Csv, &[])).unwrap();
        assert!(cat.get_by_name("del_me").is_some());
        let removed = cat.deregister("del_me");
        assert!(removed.is_some());
        assert!(cat.get_by_name("del_me").is_none());
        assert_eq!(cat.len(), 0);
    }

    // Conjunction of filters: only the csv entry tagged "ml" survives.
    #[test]
    fn test_search_filter_and() {
        let mut cat = DataCatalog::new();
        cat.register(sample_entry("train_csv", DataFormat::Csv, &["ml"])).unwrap();
        cat.register(sample_entry("train_parquet", DataFormat::Parquet, &["ml"])).unwrap();
        cat.register(sample_entry("raw_csv", DataFormat::Csv, &["raw"])).unwrap();
        let searcher = CatalogSearcher::new(&cat)
            .filter(SearchFilter::And(vec![
                SearchFilter::Tag("ml".to_string()),
                SearchFilter::Format(DataFormat::Csv),
            ]));
        let results = searcher.search();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].name, "train_csv");
    }

    // all_tags deduplicates ("y" appears on both entries).
    #[test]
    fn test_all_tags() {
        let mut cat = DataCatalog::new();
        cat.register(sample_entry("a", DataFormat::Csv, &["x", "y"])).unwrap();
        cat.register(sample_entry("b", DataFormat::Csv, &["y", "z"])).unwrap();
        let tags = cat.all_tags();
        assert_eq!(tags.len(), 3);
        assert!(tags.contains(&"x".to_string()));
        assert!(tags.contains(&"y".to_string()));
        assert!(tags.contains(&"z".to_string()));
    }

    // Extension-to-format mapping for a few representative extensions.
    #[test]
    fn test_data_format_from_extension() {
        assert_eq!(DataFormat::from_extension("csv"), DataFormat::Csv);
        assert_eq!(DataFormat::from_extension("parquet"), DataFormat::Parquet);
        assert_eq!(DataFormat::from_extension("jsonl"), DataFormat::JsonLines);
        assert_eq!(DataFormat::from_extension("h5"), DataFormat::Hdf5);
    }
}