1use std::collections::BTreeMap;
12
13use hirn_core::HirnError;
14use hirn_storage::PhysicalStore;
15use hirn_storage::store::VersionTag;
16
/// A named snapshot: one tag together with the version it resolves to in
/// each dataset.
///
/// `PartialEq`/`Eq` are derived so snapshots can be compared directly, e.g.
/// with `assert_eq!` in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Snapshot {
    /// Tag name identifying the snapshot.
    pub name: String,
    /// Version the tag points at, keyed by dataset name.
    pub versions: BTreeMap<String, u64>,
}
25
/// Outcome of creating a snapshot.
///
/// `PartialEq`/`Eq` are derived so reports can be compared directly, e.g.
/// with `assert_eq!` in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SnapshotReport {
    /// Tag that was applied.
    pub tag: String,
    /// Number of datasets that received the tag.
    pub datasets_tagged: usize,
}
34
/// Outcome of rolling back to a snapshot.
///
/// `PartialEq`/`Eq` are derived so reports can be compared directly, e.g.
/// with `assert_eq!` in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RollbackReport {
    /// Tag that was rolled back to.
    pub tag: String,
    /// Number of datasets that were checked out at the tagged version.
    pub datasets_rolled_back: usize,
}
43
44pub async fn create_snapshot(
46 storage: &dyn PhysicalStore,
47 tag: &str,
48) -> Result<SnapshotReport, HirnError> {
49 let datasets = storage
50 .list_datasets()
51 .await
52 .map_err(|e| HirnError::storage(e))?;
53
54 let mut tagged = 0usize;
55
56 for ds in &datasets {
57 storage
58 .tag(&ds.name, tag)
59 .await
60 .map_err(|e| HirnError::storage(e))?;
61 tagged += 1;
62 }
63
64 Ok(SnapshotReport {
65 tag: tag.to_string(),
66 datasets_tagged: tagged,
67 })
68}
69
70pub async fn list_snapshots(storage: &dyn PhysicalStore) -> Result<Vec<Snapshot>, HirnError> {
74 let datasets = storage
75 .list_datasets()
76 .await
77 .map_err(|e| HirnError::storage(e))?;
78
79 if datasets.is_empty() {
80 return Ok(Vec::new());
81 }
82
83 let mut tag_map: BTreeMap<String, BTreeMap<String, u64>> = BTreeMap::new();
85
86 for ds in &datasets {
87 let tags = storage
88 .list_tags(&ds.name)
89 .await
90 .map_err(|e| HirnError::storage(e))?;
91 for t in tags {
92 tag_map
93 .entry(t.name)
94 .or_default()
95 .insert(ds.name.clone(), t.version);
96 }
97 }
98
99 let num_datasets = datasets.len();
100 let snapshots = tag_map
101 .into_iter()
102 .filter(|(_, versions)| versions.len() == num_datasets)
103 .map(|(name, versions)| Snapshot { name, versions })
104 .collect();
105
106 Ok(snapshots)
107}
108
109pub async fn rollback(storage: &dyn PhysicalStore, tag: &str) -> Result<RollbackReport, HirnError> {
111 let datasets = storage
112 .list_datasets()
113 .await
114 .map_err(|e| HirnError::storage(e))?;
115
116 let mut rolled_back = 0usize;
117
118 for ds in &datasets {
119 let tags: Vec<VersionTag> = storage
120 .list_tags(&ds.name)
121 .await
122 .map_err(|e| HirnError::storage(e))?;
123
124 let target = tags.iter().find(|t| t.name == tag).ok_or_else(|| {
125 HirnError::storage(format!(
126 "snapshot tag '{}' not found on dataset '{}'",
127 tag, ds.name
128 ))
129 })?;
130
131 storage
132 .checkout(&ds.name, target.version)
133 .await
134 .map_err(|e| HirnError::storage(e))?;
135
136 rolled_back += 1;
137 }
138
139 Ok(RollbackReport {
140 tag: tag.to_string(),
141 datasets_rolled_back: rolled_back,
142 })
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use hirn_storage::memory_store::MemoryStore;

    /// Snapshotting a freshly constructed store tags zero datasets.
    #[tokio::test]
    async fn snapshot_empty_storage() {
        let store = MemoryStore::new();
        let SnapshotReport {
            datasets_tagged, ..
        } = create_snapshot(&store, "test-snap").await.unwrap();
        assert_eq!(datasets_tagged, 0);
    }

    /// Listing snapshots on a freshly constructed store yields nothing.
    #[tokio::test]
    async fn list_snapshots_empty_storage() {
        let store = MemoryStore::new();
        assert!(list_snapshots(&store).await.unwrap().is_empty());
    }

    /// Rolling back a freshly constructed store succeeds even though the tag
    /// was never created: there are no datasets that could be missing it.
    #[tokio::test]
    async fn rollback_empty_storage() {
        let store = MemoryStore::new();
        let RollbackReport {
            datasets_rolled_back,
            ..
        } = rollback(&store, "nonexistent").await.unwrap();
        assert_eq!(datasets_rolled_back, 0);
    }
}