1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use tokio::sync::{Mutex, RwLock};
6
7use crate::lifecycle::domain::DomainError;
8
9use crate::lifecycle::domain::{RouteRuntimeAggregate, RouteRuntimeState, RuntimeEvent};
10use crate::lifecycle::ports::{
11 CommandDedupPort, EventPublisherPort, ProjectionStorePort, RouteRepositoryPort,
12 RouteStatusProjection, RuntimeEventJournalPort, RuntimeUnitOfWorkPort,
13};
14
/// In-memory route repository.
///
/// Aggregates live in a `HashMap` guarded by an async `RwLock`; the map is
/// wrapped in an `Arc`, so clones of this struct share the same storage.
#[derive(Default, Clone)]
pub struct InMemoryRouteRepository {
    /// Stored aggregates, keyed by route id.
    routes: Arc<RwLock<HashMap<String, RouteRuntimeAggregate>>>,
}
19
20#[async_trait]
21impl RouteRepositoryPort for InMemoryRouteRepository {
22 async fn load(&self, route_id: &str) -> Result<Option<RouteRuntimeAggregate>, DomainError> {
23 let routes = self.routes.read().await;
24 Ok(routes.get(route_id).cloned())
25 }
26
27 async fn save(&self, aggregate: RouteRuntimeAggregate) -> Result<(), DomainError> {
28 let mut routes = self.routes.write().await;
29 routes.insert(aggregate.route_id().to_string(), aggregate);
30 Ok(())
31 }
32
33 async fn save_if_version(
34 &self,
35 aggregate: RouteRuntimeAggregate,
36 expected_version: u64,
37 ) -> Result<(), DomainError> {
38 let mut routes = self.routes.write().await;
39 let route_id = aggregate.route_id().to_string();
40 let current = routes.get(&route_id).ok_or_else(|| {
41 DomainError::InvalidState(format!(
42 "optimistic lock conflict for route '{route_id}': route not found"
43 ))
44 })?;
45
46 if current.version() != expected_version {
47 return Err(DomainError::InvalidState(format!(
48 "optimistic lock conflict for route '{route_id}': expected version {expected_version}, actual {}",
49 current.version()
50 )));
51 }
52
53 routes.insert(route_id, aggregate);
54 Ok(())
55 }
56
57 async fn delete(&self, route_id: &str) -> Result<(), DomainError> {
58 let mut routes = self.routes.write().await;
59 routes.remove(route_id);
60 Ok(())
61 }
62}
63
/// In-memory store for route status projections.
///
/// Projections live in a `HashMap` guarded by an async `RwLock`; the map is
/// wrapped in an `Arc`, so clones of this struct share the same storage.
#[derive(Default, Clone)]
pub struct InMemoryProjectionStore {
    /// Status projections, keyed by route id.
    statuses: Arc<RwLock<HashMap<String, RouteStatusProjection>>>,
}
68
69#[async_trait]
70impl ProjectionStorePort for InMemoryProjectionStore {
71 async fn upsert_status(&self, status: RouteStatusProjection) -> Result<(), DomainError> {
72 let mut statuses = self.statuses.write().await;
73 statuses.insert(status.route_id.clone(), status);
74 Ok(())
75 }
76
77 async fn get_status(
78 &self,
79 route_id: &str,
80 ) -> Result<Option<RouteStatusProjection>, DomainError> {
81 let statuses = self.statuses.read().await;
82 Ok(statuses.get(route_id).cloned())
83 }
84
85 async fn list_statuses(&self) -> Result<Vec<RouteStatusProjection>, DomainError> {
86 let statuses = self.statuses.read().await;
87 Ok(statuses.values().cloned().collect())
88 }
89
90 async fn remove_status(&self, route_id: &str) -> Result<(), DomainError> {
91 let mut statuses = self.statuses.write().await;
92 statuses.remove(route_id);
93 Ok(())
94 }
95}
96
/// Event publisher that records published events in memory.
///
/// The event list sits behind an `Arc<RwLock<..>>`, so clones of this struct
/// share the same list.
#[derive(Default, Clone)]
pub struct InMemoryEventPublisher {
    /// Every event published so far, in publication order.
    events: Arc<RwLock<Vec<RuntimeEvent>>>,
}
101
102impl InMemoryEventPublisher {
103 pub async fn snapshot(&self) -> Vec<RuntimeEvent> {
104 self.events.read().await.clone()
105 }
106}
107
108#[async_trait]
109impl EventPublisherPort for InMemoryEventPublisher {
110 async fn publish(&self, events: &[RuntimeEvent]) -> Result<(), DomainError> {
111 let mut stored = self.events.write().await;
112 stored.extend(events.iter().cloned());
113 Ok(())
114 }
115}
116
/// In-memory command de-duplication set.
///
/// The set sits behind an `Arc<RwLock<..>>`, so clones of this struct share
/// the same set.
#[derive(Default, Clone)]
pub struct InMemoryCommandDedup {
    /// Command ids that have already been observed.
    seen: Arc<RwLock<HashSet<String>>>,
}
121
122#[async_trait]
123impl CommandDedupPort for InMemoryCommandDedup {
124 async fn first_seen(&self, command_id: &str) -> Result<bool, DomainError> {
125 let mut seen = self.seen.write().await;
126 Ok(seen.insert(command_id.to_string()))
127 }
128
129 async fn forget_seen(&self, command_id: &str) -> Result<(), DomainError> {
130 let mut seen = self.seen.write().await;
131 seen.remove(command_id);
132 Ok(())
133 }
134}
135
/// Combined in-memory store holding routes, status projections, published
/// events and seen command ids behind a single async `Mutex`.
///
/// Clones share the same state (through the `Arc`) and the same optional journal.
#[derive(Clone)]
pub struct InMemoryRuntimeStore {
    /// All mutable store state, guarded by one lock.
    inner: Arc<Mutex<RuntimeStoreState>>,
    /// Optional journal that receives event batches and command ids; `None`
    /// means nothing is journaled.
    journal: Option<Arc<dyn RuntimeEventJournalPort>>,
}
141
/// Mutable state of [`InMemoryRuntimeStore`], kept together so one lock covers
/// all four collections.
#[derive(Default)]
struct RuntimeStoreState {
    /// Route aggregates, keyed by route id.
    routes: HashMap<String, RouteRuntimeAggregate>,
    /// Status projections, keyed by route id.
    statuses: HashMap<String, RouteStatusProjection>,
    /// Events recorded so far, in order.
    events: Vec<RuntimeEvent>,
    /// Command ids that have already been observed.
    seen: HashSet<String>,
}
149
150impl InMemoryRuntimeStore {
151 pub fn with_journal(mut self, journal: Arc<dyn RuntimeEventJournalPort>) -> Self {
152 self.journal = Some(journal);
153 self
154 }
155
156 pub async fn snapshot_events(&self) -> Vec<RuntimeEvent> {
157 self.inner.lock().await.events.clone()
158 }
159}
160
161fn upsert_replayed_route(
162 state: &mut RuntimeStoreState,
163 route_id: &str,
164 next_state: RouteRuntimeState,
165 status: &str,
166 increment_version: bool,
167) {
168 let current_version = state
169 .routes
170 .get(route_id)
171 .map(|agg| agg.version())
172 .unwrap_or(0);
173 let next_version = if increment_version {
174 current_version.saturating_add(1)
175 } else {
176 current_version
177 };
178 state.routes.insert(
179 route_id.to_string(),
180 RouteRuntimeAggregate::from_snapshot(route_id, next_state, next_version),
181 );
182 state.statuses.insert(
183 route_id.to_string(),
184 RouteStatusProjection {
185 route_id: route_id.to_string(),
186 status: status.to_string(),
187 },
188 );
189}
190
191fn state_label(state: &RouteRuntimeState) -> &'static str {
192 match state {
193 RouteRuntimeState::Registered => "Registered",
194 RouteRuntimeState::Starting => "Starting",
195 RouteRuntimeState::Started => "Started",
196 RouteRuntimeState::Suspended => "Suspended",
197 RouteRuntimeState::Stopping => "Stopping",
198 RouteRuntimeState::Stopped => "Stopped",
199 RouteRuntimeState::Failed(_) => "Failed",
200 }
201}
202
203fn apply_replayed_event(state: &mut RuntimeStoreState, event: &RuntimeEvent) {
204 match event {
205 RuntimeEvent::RouteRegistered { route_id } => {
206 state.routes.insert(
207 route_id.clone(),
208 RouteRuntimeAggregate::new(route_id.clone()),
209 );
210 state.statuses.insert(
211 route_id.clone(),
212 RouteStatusProjection {
213 route_id: route_id.clone(),
214 status: "Registered".to_string(),
215 },
216 );
217 }
218 RuntimeEvent::RouteRemoved { route_id } => {
219 state.routes.remove(route_id);
220 state.statuses.remove(route_id);
221 }
222 _ => {
223 let Some(next_state) = RouteRuntimeAggregate::state_from_event(event) else {
225 return;
226 };
227 let route_id = match event {
228 RuntimeEvent::RouteStartRequested { route_id }
229 | RuntimeEvent::RouteStarted { route_id }
230 | RuntimeEvent::RouteFailed { route_id, .. }
231 | RuntimeEvent::RouteStopped { route_id }
232 | RuntimeEvent::RouteSuspended { route_id }
233 | RuntimeEvent::RouteResumed { route_id }
234 | RuntimeEvent::RouteReloaded { route_id } => route_id,
235 _ => return,
236 };
237 let status = state_label(&next_state);
238 let increment_version = !matches!(
239 (event, state.routes.get(route_id).map(|agg| agg.state())),
240 (
241 RuntimeEvent::RouteStarted { .. },
242 Some(RouteRuntimeState::Starting)
243 )
244 );
245 upsert_replayed_route(state, route_id, next_state, status, increment_version);
246 }
247 }
248}
249
250impl Default for InMemoryRuntimeStore {
251 fn default() -> Self {
252 Self {
253 inner: Arc::new(Mutex::new(RuntimeStoreState::default())),
254 journal: None,
255 }
256 }
257}
258
259#[async_trait]
260impl RouteRepositoryPort for InMemoryRuntimeStore {
261 async fn load(&self, route_id: &str) -> Result<Option<RouteRuntimeAggregate>, DomainError> {
262 let guard = self.inner.lock().await;
263 Ok(guard.routes.get(route_id).cloned())
264 }
265
266 async fn save(&self, aggregate: RouteRuntimeAggregate) -> Result<(), DomainError> {
267 let mut guard = self.inner.lock().await;
268 guard
269 .routes
270 .insert(aggregate.route_id().to_string(), aggregate);
271 Ok(())
272 }
273
274 async fn save_if_version(
275 &self,
276 aggregate: RouteRuntimeAggregate,
277 expected_version: u64,
278 ) -> Result<(), DomainError> {
279 let mut guard = self.inner.lock().await;
280 let route_id = aggregate.route_id().to_string();
281 let current = guard.routes.get(&route_id).ok_or_else(|| {
282 DomainError::InvalidState(format!(
283 "optimistic lock conflict for route '{route_id}': route not found"
284 ))
285 })?;
286
287 if current.version() != expected_version {
288 return Err(DomainError::InvalidState(format!(
289 "optimistic lock conflict for route '{route_id}': expected version {expected_version}, actual {}",
290 current.version()
291 )));
292 }
293
294 guard.routes.insert(route_id, aggregate);
295 Ok(())
296 }
297
298 async fn delete(&self, route_id: &str) -> Result<(), DomainError> {
299 let mut guard = self.inner.lock().await;
300 guard.routes.remove(route_id);
301 Ok(())
302 }
303}
304
305#[async_trait]
306impl ProjectionStorePort for InMemoryRuntimeStore {
307 async fn upsert_status(&self, status: RouteStatusProjection) -> Result<(), DomainError> {
308 let mut guard = self.inner.lock().await;
309 guard.statuses.insert(status.route_id.clone(), status);
310 Ok(())
311 }
312
313 async fn get_status(
314 &self,
315 route_id: &str,
316 ) -> Result<Option<RouteStatusProjection>, DomainError> {
317 let guard = self.inner.lock().await;
318 Ok(guard.statuses.get(route_id).cloned())
319 }
320
321 async fn list_statuses(&self) -> Result<Vec<RouteStatusProjection>, DomainError> {
322 let guard = self.inner.lock().await;
323 Ok(guard.statuses.values().cloned().collect())
324 }
325
326 async fn remove_status(&self, route_id: &str) -> Result<(), DomainError> {
327 let mut guard = self.inner.lock().await;
328 guard.statuses.remove(route_id);
329 Ok(())
330 }
331}
332
333#[async_trait]
334impl EventPublisherPort for InMemoryRuntimeStore {
335 async fn publish(&self, events: &[RuntimeEvent]) -> Result<(), DomainError> {
336 let mut guard = self.inner.lock().await;
337 if let Some(journal) = &self.journal {
338 journal.append_batch(events).await?;
339 }
340 guard.events.extend(events.iter().cloned());
341 Ok(())
342 }
343}
344
345#[async_trait]
346impl CommandDedupPort for InMemoryRuntimeStore {
347 async fn first_seen(&self, command_id: &str) -> Result<bool, DomainError> {
348 let mut guard = self.inner.lock().await;
349 if !guard.seen.insert(command_id.to_string()) {
350 return Ok(false);
351 }
352
353 if let Some(journal) = &self.journal
354 && let Err(err) = journal.append_command_id(command_id).await
355 {
356 guard.seen.remove(command_id);
357 return Err(err);
358 }
359
360 Ok(true)
361 }
362
363 async fn forget_seen(&self, command_id: &str) -> Result<(), DomainError> {
364 let mut guard = self.inner.lock().await;
365 let removed = guard.seen.remove(command_id);
366 if removed && let Some(journal) = &self.journal {
367 journal.remove_command_id(command_id).await?;
368 }
369 Ok(())
370 }
371}
372
373#[async_trait]
374impl RuntimeUnitOfWorkPort for InMemoryRuntimeStore {
375 async fn persist_upsert(
376 &self,
377 aggregate: RouteRuntimeAggregate,
378 expected_version: Option<u64>,
379 projection: RouteStatusProjection,
380 events: &[RuntimeEvent],
381 ) -> Result<(), DomainError> {
382 let mut guard = self.inner.lock().await;
383 if let Some(expected) = expected_version {
384 let route_id = aggregate.route_id().to_string();
385 let current = guard.routes.get(&route_id).ok_or_else(|| {
386 DomainError::InvalidState(format!(
387 "optimistic lock conflict for route '{route_id}': route not found"
388 ))
389 })?;
390 if current.version() != expected {
391 return Err(DomainError::InvalidState(format!(
392 "optimistic lock conflict for route '{route_id}': expected version {expected}, actual {}",
393 current.version()
394 )));
395 }
396 }
397
398 if let Some(journal) = &self.journal {
399 journal.append_batch(events).await?;
400 }
401
402 guard
403 .routes
404 .insert(aggregate.route_id().to_string(), aggregate);
405 guard
406 .statuses
407 .insert(projection.route_id.clone(), projection);
408 guard.events.extend(events.iter().cloned());
409 Ok(())
410 }
411
412 async fn persist_delete(
413 &self,
414 route_id: &str,
415 events: &[RuntimeEvent],
416 ) -> Result<(), DomainError> {
417 let mut guard = self.inner.lock().await;
418 if let Some(journal) = &self.journal {
419 journal.append_batch(events).await?;
420 }
421 guard.routes.remove(route_id);
422 guard.statuses.remove(route_id);
423 guard.events.extend(events.iter().cloned());
424 Ok(())
425 }
426
427 async fn recover_from_journal(&self) -> Result<(), DomainError> {
428 let Some(journal) = &self.journal else {
429 return Ok(());
430 };
431
432 let replayed_events = journal.load_all().await?;
433 let replayed_command_ids = journal.load_command_ids().await?;
434
435 let mut guard = self.inner.lock().await;
436 guard.routes.clear();
437 guard.statuses.clear();
438 guard.events.clear();
439 guard.seen.clear();
440
441 for event in &replayed_events {
442 apply_replayed_event(&mut guard, event);
443 }
444 guard.events = replayed_events;
445 for command_id in replayed_command_ids {
446 guard.seen.insert(command_id);
447 }
448 Ok(())
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Journal stub: discards appended batches and replays a fixed event list.
    #[derive(Clone)]
    struct ReplayJournal {
        events: Vec<RuntimeEvent>,
    }

    #[async_trait]
    impl RuntimeEventJournalPort for ReplayJournal {
        async fn append_batch(&self, _events: &[RuntimeEvent]) -> Result<(), DomainError> {
            Ok(())
        }

        async fn load_all(&self) -> Result<Vec<RuntimeEvent>, DomainError> {
            Ok(self.events.clone())
        }
    }

    /// Builds a status projection from string slices.
    fn projection(route_id: &str, status: &str) -> RouteStatusProjection {
        RouteStatusProjection {
            route_id: route_id.to_string(),
            status: status.to_string(),
        }
    }

    /// Recovers a store from a journal holding `events` and loads `route_id` from it.
    async fn recover_and_load(events: Vec<RuntimeEvent>, route_id: &str) -> RouteRuntimeAggregate {
        let journal = Arc::new(ReplayJournal { events });
        let store = InMemoryRuntimeStore::default().with_journal(journal);
        store.recover_from_journal().await.unwrap();
        store.load(route_id).await.unwrap().unwrap()
    }

    #[tokio::test]
    async fn repo_roundtrip_works() {
        let repo = InMemoryRouteRepository::default();
        repo.save(RouteRuntimeAggregate::new("r1")).await.unwrap();
        assert!(repo.load("r1").await.unwrap().is_some());

        let started = RouteRuntimeAggregate::from_snapshot("r1", RouteRuntimeState::Started, 1);
        repo.save_if_version(started.clone(), 0).await.unwrap();
        let reloaded = repo.load("r1").await.unwrap().unwrap();
        assert_eq!(reloaded.version(), 1);

        // The stored version is now 1, so expecting 0 again must conflict.
        let conflict = repo.save_if_version(started, 0).await.unwrap_err();
        assert!(
            conflict.to_string().contains("optimistic lock conflict"),
            "unexpected conflict error: {conflict}"
        );

        repo.delete("r1").await.unwrap();
        assert!(repo.load("r1").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn projection_roundtrip_works() {
        let store = InMemoryProjectionStore::default();
        store
            .upsert_status(projection("r1", "Started"))
            .await
            .unwrap();

        let fetched = store.get_status("r1").await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().status, "Started");

        store.remove_status("r1").await.unwrap();
        assert!(store.get_status("r1").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn event_publisher_stores_events() {
        let publisher = InMemoryEventPublisher::default();
        let batch = [RuntimeEvent::RouteStarted {
            route_id: "r1".into(),
        }];
        publisher.publish(&batch).await.unwrap();

        let recorded = publisher.snapshot().await;
        assert_eq!(recorded.len(), 1);
    }

    #[tokio::test]
    async fn command_dedup_detects_duplicates() {
        let dedup = InMemoryCommandDedup::default();
        assert!(dedup.first_seen("c1").await.unwrap());
        assert!(!dedup.first_seen("c1").await.unwrap());

        dedup.forget_seen("c1").await.unwrap();
        assert!(dedup.first_seen("c1").await.unwrap());
        assert!(dedup.first_seen("c2").await.unwrap());
    }

    #[tokio::test]
    async fn runtime_store_uow_persists_all_three_writes() {
        let store = InMemoryRuntimeStore::default();
        let expected_projection = projection("uow-r1", "Registered");
        let events = vec![RuntimeEvent::RouteRegistered {
            route_id: "uow-r1".to_string(),
        }];

        store
            .persist_upsert(
                RouteRuntimeAggregate::new("uow-r1"),
                None,
                expected_projection.clone(),
                &events,
            )
            .await
            .unwrap();

        assert!(store.load("uow-r1").await.unwrap().is_some());
        assert_eq!(
            store.get_status("uow-r1").await.unwrap().unwrap(),
            expected_projection
        );
        assert_eq!(store.snapshot_events().await, events);
    }

    #[tokio::test]
    async fn runtime_store_uow_enforces_expected_version() {
        let store = InMemoryRuntimeStore::default();
        let registered = [RuntimeEvent::RouteRegistered {
            route_id: "uow-r2".to_string(),
        }];
        store
            .persist_upsert(
                RouteRuntimeAggregate::new("uow-r2"),
                None,
                projection("uow-r2", "Registered"),
                &registered,
            )
            .await
            .unwrap();

        let started = RouteRuntimeAggregate::from_snapshot("uow-r2", RouteRuntimeState::Started, 1);
        let started_events = [RuntimeEvent::RouteStarted {
            route_id: "uow-r2".to_string(),
        }];
        // The stored version is 0, so expecting 99 must be rejected.
        let err = store
            .persist_upsert(
                started,
                Some(99),
                projection("uow-r2", "Started"),
                &started_events,
            )
            .await
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("optimistic lock conflict"),
            "unexpected error: {err}"
        );
    }

    #[tokio::test]
    async fn replay_start_requested_only_advances_version_once() {
        let journaled = vec![
            RuntimeEvent::RouteRegistered {
                route_id: "replay-r1".to_string(),
            },
            RuntimeEvent::RouteStartRequested {
                route_id: "replay-r1".to_string(),
            },
        ];

        let aggregate = recover_and_load(journaled, "replay-r1").await;

        assert_eq!(aggregate.state(), &RouteRuntimeState::Starting);
        assert_eq!(aggregate.version(), 1);
    }

    #[tokio::test]
    async fn replay_start_requested_then_started_keeps_single_command_version() {
        let journaled = vec![
            RuntimeEvent::RouteRegistered {
                route_id: "replay-r2".to_string(),
            },
            RuntimeEvent::RouteStartRequested {
                route_id: "replay-r2".to_string(),
            },
            RuntimeEvent::RouteStarted {
                route_id: "replay-r2".to_string(),
            },
        ];

        let aggregate = recover_and_load(journaled, "replay-r2").await;

        assert_eq!(aggregate.state(), &RouteRuntimeState::Started);
        assert_eq!(aggregate.version(), 1);
    }
}