1use std::{any::TypeId, collections::HashMap, marker::PhantomData, sync::Arc};
7
8use tokio::sync::RwLock;
9
10use super::Event;
11
12pub trait VersionedEvent: Event {
14 fn version(&self) -> u32;
16
17 fn event_type(&self) -> &'static str;
19}
20
21pub trait Upcaster<From: Event, To: Event>: Send + Sync {
23 fn upcast(&self, from: From) -> To;
25}
26
/// An upcaster that converts via the target type's `From<Src>` implementation.
///
/// This is a zero-sized marker that stores no events. It uses
/// `PhantomData<fn(Src) -> Dst>` rather than `PhantomData<(Src, Dst)>` because
/// it never owns a `Src` or `Dst`. As a result it is `Send + Sync + Copy`
/// regardless of the event types, and it carries no drop-check obligations for them.
///
/// Trait bounds live on the impl blocks that need them, not on the struct.
pub struct AutoUpcaster<Src, Dst> {
    _phantom: PhantomData<fn(Src) -> Dst>,
}

// Hand-written (not derived) so that cloning/copying does not require
// `Src: Clone` / `Dst: Clone`; the marker itself is always trivially copyable.
impl<Src, Dst> Clone for AutoUpcaster<Src, Dst> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<Src, Dst> Copy for AutoUpcaster<Src, Dst> {}

// Hand-written so `Debug` is available without requiring `Src: Debug` / `Dst: Debug`.
impl<Src, Dst> std::fmt::Debug for AutoUpcaster<Src, Dst> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AutoUpcaster").finish()
    }
}
31
32impl<From: Event, To: Event> AutoUpcaster<From, To>
33where
34 To: std::convert::From<From>,
35{
36 pub fn new() -> Self {
38 Self {
39 _phantom: PhantomData,
40 }
41 }
42}
43
44impl<From: Event, To: Event> Default for AutoUpcaster<From, To>
45where
46 To: std::convert::From<From>,
47{
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl<From: Event, To: Event> Upcaster<From, To> for AutoUpcaster<From, To>
54where
55 To: std::convert::From<From>,
56{
57 fn upcast(&self, from: From) -> To {
58 from.into()
59 }
60}
61
62trait ErasedUpcaster<E: Event>: Send + Sync {
64 #[allow(dead_code)]
66 fn upcast_erased(&self, event: Box<dyn std::any::Any>) -> Option<E>;
67}
68
69struct UpcasterWrapper<From: Event, To: Event, U: Upcaster<From, To>> {
71 #[allow(dead_code)]
72 upcaster: Arc<U>,
73 _phantom: PhantomData<(From, To)>,
74}
75
76impl<From: Event, To: Event, U: Upcaster<From, To>> ErasedUpcaster<To>
77 for UpcasterWrapper<From, To, U>
78{
79 fn upcast_erased(&self, event: Box<dyn std::any::Any>) -> Option<To> {
80 match event.downcast::<From>() {
81 Ok(from_event) => Some(self.upcaster.upcast(*from_event)),
82 Err(_) => None,
83 }
84 }
85}
86
87type UpcasterMap<E> = HashMap<(TypeId, TypeId), Box<dyn ErasedUpcaster<E>>>;
89
/// One schema-migration step (`from_version -> to_version`) for a named event type.
///
/// Value equality and hashing are derived so paths can be compared, asserted on,
/// and deduplicated (e.g. collected into a `HashSet`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MigrationPath {
    /// Schema version the migration starts from.
    pub from_version: u32,
    /// Schema version the migration produces.
    pub to_version: u32,
    /// Name of the event type; the registry groups paths under this key.
    pub event_type: String,
}
100
101impl MigrationPath {
102 pub fn new(from: u32, to: u32, event_type: impl Into<String>) -> Self {
104 Self {
105 from_version: from,
106 to_version: to,
107 event_type: event_type.into(),
108 }
109 }
110}
111
112pub struct VersionRegistry<E: Event> {
114 upcasters: Arc<RwLock<UpcasterMap<E>>>,
116 migrations: Arc<RwLock<HashMap<String, Vec<MigrationPath>>>>,
118 _phantom: PhantomData<E>,
119}
120
121impl<E: Event> VersionRegistry<E> {
122 pub fn new() -> Self {
124 Self {
125 upcasters: Arc::new(RwLock::new(HashMap::new())),
126 migrations: Arc::new(RwLock::new(HashMap::new())),
127 _phantom: PhantomData,
128 }
129 }
130
131 pub async fn register_upcaster<F: Event + 'static, U: Upcaster<F, E> + 'static>(
133 &self,
134 upcaster: U,
135 ) {
136 let from_type = TypeId::of::<F>();
137 let to_type = TypeId::of::<E>();
138
139 let wrapper = UpcasterWrapper {
140 upcaster: Arc::new(upcaster),
141 _phantom: PhantomData,
142 };
143
144 let mut upcasters = self.upcasters.write().await;
145 upcasters.insert((from_type, to_type), Box::new(wrapper));
146 }
147
148 pub async fn register_migration(&self, path: MigrationPath) {
150 let mut migrations = self.migrations.write().await;
151 migrations
152 .entry(path.event_type.clone())
153 .or_insert_with(Vec::new)
154 .push(path);
155 }
156
157 pub async fn get_migrations(&self) -> Vec<MigrationPath> {
159 let migrations = self.migrations.read().await;
160 migrations.values().flatten().cloned().collect()
161 }
162
163 pub async fn get_migrations_for(&self, event_type: &str) -> Vec<MigrationPath> {
165 let migrations = self.migrations.read().await;
166 migrations.get(event_type).cloned().unwrap_or_default()
167 }
168
169 pub async fn has_upcaster<F: Event + 'static, T: Event + 'static>(&self) -> bool {
171 let from_type = TypeId::of::<F>();
172 let to_type = TypeId::of::<T>();
173 let upcasters = self.upcasters.read().await;
174 upcasters.contains_key(&(from_type, to_type))
175 }
176
177 pub async fn upcaster_count(&self) -> usize {
179 self.upcasters.read().await.len()
180 }
181
182 pub async fn migration_count(&self) -> usize {
184 self.migrations.read().await.values().map(|v| v.len()).sum()
185 }
186}
187
188impl<E: Event> Default for VersionRegistry<E> {
189 fn default() -> Self {
190 Self::new()
191 }
192}
193
194impl<E: Event> Clone for VersionRegistry<E> {
195 fn clone(&self) -> Self {
196 Self {
197 upcasters: Arc::clone(&self.upcasters),
198 migrations: Arc::clone(&self.migrations),
199 _phantom: PhantomData,
200 }
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
    struct UserCreatedV1 {
        user_id: String,
        email: String,
    }

    impl Event for UserCreatedV1 {}

    #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
    struct UserCreatedV2 {
        user_id: String,
        email: String,
        name: String,
    }

    impl Event for UserCreatedV2 {}

    impl From<UserCreatedV1> for UserCreatedV2 {
        fn from(v1: UserCreatedV1) -> Self {
            Self {
                user_id: v1.user_id,
                email: v1.email,
                name: "Unknown".to_string(),
            }
        }
    }

    #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
    enum TestEvent {
        // Only constructed by deserialization in real use; unused in these tests.
        #[allow(dead_code)]
        V1(UserCreatedV1),
        V2(UserCreatedV2),
    }

    impl Event for TestEvent {}

    impl From<UserCreatedV2> for TestEvent {
        fn from(v2: UserCreatedV2) -> Self {
            TestEvent::V2(v2)
        }
    }

    /// A representative V1 event shared by several tests.
    fn sample_v1() -> UserCreatedV1 {
        UserCreatedV1 {
            user_id: "123".to_string(),
            email: "test@example.com".to_string(),
        }
    }

    #[tokio::test]
    async fn test_registry_creation() {
        let registry: VersionRegistry<TestEvent> = VersionRegistry::new();
        assert_eq!(registry.upcaster_count().await, 0);
        assert_eq!(registry.migration_count().await, 0);
    }

    #[tokio::test]
    async fn test_upcaster_registration() {
        let registry: VersionRegistry<UserCreatedV2> = VersionRegistry::new();

        registry
            .register_upcaster(AutoUpcaster::<UserCreatedV1, UserCreatedV2>::new())
            .await;

        assert_eq!(registry.upcaster_count().await, 1);
        assert!(
            registry
                .has_upcaster::<UserCreatedV1, UserCreatedV2>()
                .await
        );
    }

    #[tokio::test]
    async fn test_has_upcaster_is_false_until_registered() {
        let registry: VersionRegistry<UserCreatedV2> = VersionRegistry::new();
        assert!(
            !registry
                .has_upcaster::<UserCreatedV1, UserCreatedV2>()
                .await
        );
    }

    #[tokio::test]
    async fn test_reregistering_upcaster_replaces_entry() {
        let registry: VersionRegistry<UserCreatedV2> = VersionRegistry::new();

        registry
            .register_upcaster(AutoUpcaster::<UserCreatedV1, UserCreatedV2>::new())
            .await;
        registry
            .register_upcaster(AutoUpcaster::<UserCreatedV1, UserCreatedV2>::new())
            .await;

        assert_eq!(registry.upcaster_count().await, 1);
    }

    #[tokio::test]
    async fn test_migration_path_registration() {
        let registry: VersionRegistry<TestEvent> = VersionRegistry::new();

        let path = MigrationPath::new(1, 2, "UserCreated");
        registry.register_migration(path).await;

        assert_eq!(registry.migration_count().await, 1);

        let migrations = registry.get_migrations_for("UserCreated").await;
        assert_eq!(migrations.len(), 1);
        assert_eq!(migrations[0].from_version, 1);
        assert_eq!(migrations[0].to_version, 2);
    }

    #[tokio::test]
    async fn test_multiple_migrations() {
        let registry: VersionRegistry<TestEvent> = VersionRegistry::new();

        registry
            .register_migration(MigrationPath::new(1, 2, "UserCreated"))
            .await;
        registry
            .register_migration(MigrationPath::new(2, 3, "UserCreated"))
            .await;

        assert_eq!(registry.migration_count().await, 2);

        let migrations = registry.get_migrations_for("UserCreated").await;
        assert_eq!(migrations.len(), 2);
        // Registration order is preserved within an event type.
        assert_eq!(migrations[0].from_version, 1);
        assert_eq!(migrations[1].from_version, 2);
    }

    #[tokio::test]
    async fn test_get_migrations_spans_all_event_types() {
        let registry: VersionRegistry<TestEvent> = VersionRegistry::new();

        registry
            .register_migration(MigrationPath::new(1, 2, "UserCreated"))
            .await;
        registry
            .register_migration(MigrationPath::new(1, 2, "OrderPlaced"))
            .await;
        registry
            .register_migration(MigrationPath::new(2, 3, "UserCreated"))
            .await;

        let all = registry.get_migrations().await;
        assert_eq!(all.len(), 3);
        assert_eq!(
            all.iter().filter(|p| p.event_type == "UserCreated").count(),
            2
        );
        assert_eq!(registry.get_migrations_for("OrderPlaced").await.len(), 1);
        assert!(registry.get_migrations_for("NeverRegistered").await.is_empty());
    }

    #[tokio::test]
    async fn test_clones_share_state() {
        let registry: VersionRegistry<UserCreatedV2> = VersionRegistry::default();
        let handle = registry.clone();

        handle
            .register_migration(MigrationPath::new(1, 2, "UserCreated"))
            .await;
        handle
            .register_upcaster(AutoUpcaster::<UserCreatedV1, UserCreatedV2>::new())
            .await;

        assert_eq!(registry.migration_count().await, 1);
        assert_eq!(registry.upcaster_count().await, 1);
    }

    #[tokio::test]
    async fn test_auto_upcaster() {
        let upcaster = AutoUpcaster::<UserCreatedV1, UserCreatedV2>::new();

        let v1 = sample_v1();

        let v2 = upcaster.upcast(v1.clone());

        assert_eq!(v2.user_id, v1.user_id);
        assert_eq!(v2.email, v1.email);
        assert_eq!(v2.name, "Unknown");
    }

    #[test]
    fn test_erased_upcaster_only_accepts_source_type() {
        type V1ToV2 = AutoUpcaster<UserCreatedV1, UserCreatedV2>;

        let wrapper: UpcasterWrapper<UserCreatedV1, UserCreatedV2, V1ToV2> = UpcasterWrapper {
            upcaster: Arc::new(V1ToV2::new()),
            _phantom: PhantomData,
        };
        let erased: &dyn ErasedUpcaster<UserCreatedV2> = &wrapper;

        let upcast = erased
            .upcast_erased(Box::new(sample_v1()))
            .expect("a UserCreatedV1 must be accepted by a V1 -> V2 upcaster");
        assert_eq!(upcast.user_id, "123");
        assert_eq!(upcast.name, "Unknown");

        // Any other concrete type is rejected rather than panicking.
        assert!(erased.upcast_erased(Box::new(42_u32)).is_none());
    }

    #[test]
    fn test_migration_path_creation() {
        let path = MigrationPath::new(1, 2, "UserCreated");

        assert_eq!(path.from_version, 1);
        assert_eq!(path.to_version, 2);
        assert_eq!(path.event_type, "UserCreated");
    }
}