1use crate::error::{DbxError, DbxResult};
8use arrow::datatypes::Schema;
9use dashmap::DashMap;
10use std::sync::Arc;
11
12#[derive(Debug, Clone)]
14pub struct SchemaVersion {
15 pub version: u64,
17 pub schema: Arc<Schema>,
19 pub created_at: u64,
21 pub description: String,
23}
24
25pub struct SchemaVersionManager {
32 versions: DashMap<String, Vec<SchemaVersion>>,
34 current_versions: DashMap<String, u64>,
36 current_cache: DashMap<String, Arc<Schema>>,
38}
39
40impl SchemaVersionManager {
41 pub fn new() -> Self {
43 Self {
44 versions: DashMap::new(),
45 current_versions: DashMap::new(),
46 current_cache: DashMap::new(),
47 }
48 }
49
50 pub fn register_table(&self, table: &str, schema: Arc<Schema>) -> DbxResult<u64> {
52 let version = SchemaVersion {
53 version: 1,
54 schema: schema.clone(),
55 created_at: Self::now(),
56 description: "Initial schema".to_string(),
57 };
58
59 self.versions.insert(table.to_string(), vec![version]);
60 self.current_versions.insert(table.to_string(), 1);
61 self.current_cache.insert(table.to_string(), schema);
62
63 Ok(1)
64 }
65
66 pub fn alter_table(
68 &self,
69 table: &str,
70 new_schema: Arc<Schema>,
71 description: &str,
72 ) -> DbxResult<u64> {
73 let mut history = self
74 .versions
75 .get_mut(table)
76 .ok_or_else(|| DbxError::TableNotFound(table.to_string()))?;
77
78 let new_version = history.last().map(|v| v.version + 1).unwrap_or(1);
79
80 history.push(SchemaVersion {
81 version: new_version,
82 schema: new_schema.clone(),
83 created_at: Self::now(),
84 description: description.to_string(),
85 });
86
87 self.current_versions.insert(table.to_string(), new_version);
88 self.current_cache.insert(table.to_string(), new_schema);
89
90 Ok(new_version)
91 }
92
93 pub fn get_current(&self, table: &str) -> DbxResult<Arc<Schema>> {
95 self.current_cache
96 .get(table)
97 .map(|r| r.value().clone())
98 .ok_or_else(|| DbxError::TableNotFound(table.to_string()))
99 }
100
101 pub fn get_at_version(&self, table: &str, version: u64) -> DbxResult<Arc<Schema>> {
103 let history = self
104 .versions
105 .get(table)
106 .ok_or_else(|| DbxError::TableNotFound(table.to_string()))?;
107
108 history
109 .iter()
110 .find(|v| v.version == version)
111 .map(|v| v.schema.clone())
112 .ok_or_else(|| {
113 DbxError::Serialization(format!("Version {version} not found for {table}"))
114 })
115 }
116
117 pub fn version_history(&self, table: &str) -> DbxResult<Vec<SchemaVersion>> {
119 self.versions
120 .get(table)
121 .map(|r| r.value().clone())
122 .ok_or_else(|| DbxError::TableNotFound(table.to_string()))
123 }
124
125 pub fn current_version(&self, table: &str) -> DbxResult<u64> {
127 self.current_versions
128 .get(table)
129 .map(|r| *r.value())
130 .ok_or_else(|| DbxError::TableNotFound(table.to_string()))
131 }
132
133 pub fn rollback(&self, table: &str, target_version: u64) -> DbxResult<()> {
135 let schema = {
137 let history = self
138 .versions
139 .get(table)
140 .ok_or_else(|| DbxError::TableNotFound(table.to_string()))?;
141
142 history
143 .iter()
144 .find(|v| v.version == target_version)
145 .map(|v| v.schema.clone())
146 .ok_or_else(|| {
147 DbxError::Serialization(format!(
148 "Version {target_version} not found for {table}"
149 ))
150 })?
151 };
152
153 self.current_versions
154 .insert(table.to_string(), target_version);
155 self.current_cache.insert(table.to_string(), schema);
156
157 Ok(())
158 }
159
160 fn now() -> u64 {
161 std::time::SystemTime::now()
162 .duration_since(std::time::UNIX_EPOCH)
163 .unwrap_or_default()
164 .as_millis() as u64
165 }
166}
167
168impl Default for SchemaVersionManager {
169 fn default() -> Self {
170 Self::new()
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field};

    /// Builds a schema with one nullable field per `(name, type)` pair.
    fn make_schema(fields: &[(&str, DataType)]) -> Arc<Schema> {
        Arc::new(Schema::new(
            fields
                .iter()
                .map(|(n, t)| Field::new(*n, t.clone(), true))
                .collect::<Vec<_>>(),
        ))
    }

    #[test]
    fn test_register_and_get() {
        let mgr = SchemaVersionManager::new();
        let schema = make_schema(&[("id", DataType::Int64), ("name", DataType::Utf8)]);
        mgr.register_table("users", schema.clone()).unwrap();

        let current = mgr.get_current("users").unwrap();
        assert_eq!(current.fields().len(), 2);
        assert_eq!(mgr.current_version("users").unwrap(), 1);
    }

    #[test]
    fn test_alter_table() {
        let mgr = SchemaVersionManager::new();
        let v1 = make_schema(&[("id", DataType::Int64), ("name", DataType::Utf8)]);
        mgr.register_table("users", v1).unwrap();

        let v2 = make_schema(&[
            ("id", DataType::Int64),
            ("name", DataType::Utf8),
            ("email", DataType::Utf8),
        ]);
        let ver = mgr.alter_table("users", v2, "Add email column").unwrap();
        assert_eq!(ver, 2);

        let current = mgr.get_current("users").unwrap();
        assert_eq!(current.fields().len(), 3);
    }

    #[test]
    fn test_version_history() {
        let mgr = SchemaVersionManager::new();
        let v1 = make_schema(&[("id", DataType::Int64)]);
        mgr.register_table("users", v1).unwrap();

        let v2 = make_schema(&[("id", DataType::Int64), ("name", DataType::Utf8)]);
        mgr.alter_table("users", v2, "Add name").unwrap();

        let history = mgr.version_history("users").unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].version, 1);
        assert_eq!(history[1].version, 2);
    }

    #[test]
    fn test_get_at_version() {
        let mgr = SchemaVersionManager::new();
        let v1 = make_schema(&[("id", DataType::Int64)]);
        mgr.register_table("users", v1).unwrap();

        let v2 = make_schema(&[("id", DataType::Int64), ("name", DataType::Utf8)]);
        mgr.alter_table("users", v2, "Add name").unwrap();

        let old = mgr.get_at_version("users", 1).unwrap();
        assert_eq!(old.fields().len(), 1);

        let new = mgr.get_at_version("users", 2).unwrap();
        assert_eq!(new.fields().len(), 2);
    }

    #[test]
    fn test_rollback() {
        let mgr = SchemaVersionManager::new();
        let v1 = make_schema(&[("id", DataType::Int64)]);
        mgr.register_table("users", v1).unwrap();

        let v2 = make_schema(&[("id", DataType::Int64), ("name", DataType::Utf8)]);
        mgr.alter_table("users", v2, "Add name").unwrap();

        assert_eq!(mgr.current_version("users").unwrap(), 2);

        mgr.rollback("users", 1).unwrap();
        assert_eq!(mgr.current_version("users").unwrap(), 1);

        let current = mgr.get_current("users").unwrap();
        assert_eq!(current.fields().len(), 1);
    }

    #[test]
    fn test_unknown_table_errors() {
        let mgr = SchemaVersionManager::new();
        let schema = make_schema(&[("id", DataType::Int64)]);

        assert!(matches!(
            mgr.get_current("missing"),
            Err(DbxError::TableNotFound(_))
        ));
        assert!(matches!(
            mgr.current_version("missing"),
            Err(DbxError::TableNotFound(_))
        ));
        assert!(matches!(
            mgr.version_history("missing"),
            Err(DbxError::TableNotFound(_))
        ));
        assert!(matches!(
            mgr.get_at_version("missing", 1),
            Err(DbxError::TableNotFound(_))
        ));
        assert!(matches!(
            mgr.alter_table("missing", schema, "noop"),
            Err(DbxError::TableNotFound(_))
        ));
        assert!(matches!(
            mgr.rollback("missing", 1),
            Err(DbxError::TableNotFound(_))
        ));
    }

    #[test]
    fn test_unknown_version_errors_and_leaves_state_unchanged() {
        let mgr = SchemaVersionManager::new();
        mgr.register_table("users", make_schema(&[("id", DataType::Int64)]))
            .unwrap();

        assert!(matches!(
            mgr.get_at_version("users", 2),
            Err(DbxError::Serialization(_))
        ));
        assert!(matches!(
            mgr.rollback("users", 7),
            Err(DbxError::Serialization(_))
        ));

        // A failed rollback must not move the current version or cached schema.
        assert_eq!(mgr.current_version("users").unwrap(), 1);
        assert_eq!(mgr.get_current("users").unwrap().fields().len(), 1);
    }

    #[test]
    fn test_alter_after_rollback_appends_new_version() {
        let mgr = SchemaVersionManager::new();
        mgr.register_table("users", make_schema(&[("id", DataType::Int64)]))
            .unwrap();
        mgr.alter_table(
            "users",
            make_schema(&[("id", DataType::Int64), ("name", DataType::Utf8)]),
            "Add name",
        )
        .unwrap();
        mgr.rollback("users", 1).unwrap();

        let v3 = make_schema(&[
            ("id", DataType::Int64),
            ("name", DataType::Utf8),
            ("email", DataType::Utf8),
        ]);
        let ver = mgr.alter_table("users", v3, "Add email").unwrap();

        // Numbering continues after the newest entry; version 2 is preserved.
        assert_eq!(ver, 3);
        assert_eq!(mgr.current_version("users").unwrap(), 3);
        assert_eq!(mgr.get_current("users").unwrap().fields().len(), 3);
        assert_eq!(mgr.version_history("users").unwrap().len(), 3);
        assert_eq!(mgr.get_at_version("users", 2).unwrap().fields().len(), 2);
    }
}