Skip to main content

camel_core/lifecycle/adapters/
in_memory.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use tokio::sync::{Mutex, RwLock};
6
7use crate::lifecycle::domain::DomainError;
8
9use crate::lifecycle::domain::{RouteRuntimeAggregate, RouteRuntimeState, RuntimeEvent};
10use crate::lifecycle::ports::{
11    CommandDedupPort, EventPublisherPort, ProjectionStorePort, RouteRepositoryPort,
12    RouteStatusProjection, RuntimeEventJournalPort, RuntimeUnitOfWorkPort,
13};
14
15#[derive(Default, Clone)]
16pub struct InMemoryRouteRepository {
17    routes: Arc<RwLock<HashMap<String, RouteRuntimeAggregate>>>,
18}
19
20#[async_trait]
21impl RouteRepositoryPort for InMemoryRouteRepository {
22    async fn load(&self, route_id: &str) -> Result<Option<RouteRuntimeAggregate>, DomainError> {
23        let routes = self.routes.read().await;
24        Ok(routes.get(route_id).cloned())
25    }
26
27    async fn save(&self, aggregate: RouteRuntimeAggregate) -> Result<(), DomainError> {
28        let mut routes = self.routes.write().await;
29        routes.insert(aggregate.route_id().to_string(), aggregate);
30        Ok(())
31    }
32
33    async fn save_if_version(
34        &self,
35        aggregate: RouteRuntimeAggregate,
36        expected_version: u64,
37    ) -> Result<(), DomainError> {
38        let mut routes = self.routes.write().await;
39        let route_id = aggregate.route_id().to_string();
40        let current = routes.get(&route_id).ok_or_else(|| {
41            DomainError::InvalidState(format!(
42                "optimistic lock conflict for route '{route_id}': route not found"
43            ))
44        })?;
45
46        if current.version() != expected_version {
47            return Err(DomainError::InvalidState(format!(
48                "optimistic lock conflict for route '{route_id}': expected version {expected_version}, actual {}",
49                current.version()
50            )));
51        }
52
53        routes.insert(route_id, aggregate);
54        Ok(())
55    }
56
57    async fn delete(&self, route_id: &str) -> Result<(), DomainError> {
58        let mut routes = self.routes.write().await;
59        routes.remove(route_id);
60        Ok(())
61    }
62}
63
64#[derive(Default, Clone)]
65pub struct InMemoryProjectionStore {
66    statuses: Arc<RwLock<HashMap<String, RouteStatusProjection>>>,
67}
68
69#[async_trait]
70impl ProjectionStorePort for InMemoryProjectionStore {
71    async fn upsert_status(&self, status: RouteStatusProjection) -> Result<(), DomainError> {
72        let mut statuses = self.statuses.write().await;
73        statuses.insert(status.route_id.clone(), status);
74        Ok(())
75    }
76
77    async fn get_status(
78        &self,
79        route_id: &str,
80    ) -> Result<Option<RouteStatusProjection>, DomainError> {
81        let statuses = self.statuses.read().await;
82        Ok(statuses.get(route_id).cloned())
83    }
84
85    async fn list_statuses(&self) -> Result<Vec<RouteStatusProjection>, DomainError> {
86        let statuses = self.statuses.read().await;
87        Ok(statuses.values().cloned().collect())
88    }
89
90    async fn remove_status(&self, route_id: &str) -> Result<(), DomainError> {
91        let mut statuses = self.statuses.write().await;
92        statuses.remove(route_id);
93        Ok(())
94    }
95}
96
97#[derive(Default, Clone)]
98pub struct InMemoryEventPublisher {
99    events: Arc<RwLock<Vec<RuntimeEvent>>>,
100}
101
102impl InMemoryEventPublisher {
103    pub async fn snapshot(&self) -> Vec<RuntimeEvent> {
104        self.events.read().await.clone()
105    }
106}
107
108#[async_trait]
109impl EventPublisherPort for InMemoryEventPublisher {
110    async fn publish(&self, events: &[RuntimeEvent]) -> Result<(), DomainError> {
111        let mut stored = self.events.write().await;
112        stored.extend(events.iter().cloned());
113        Ok(())
114    }
115}
116
117#[derive(Default, Clone)]
118pub struct InMemoryCommandDedup {
119    seen: Arc<RwLock<HashSet<String>>>,
120}
121
122#[async_trait]
123impl CommandDedupPort for InMemoryCommandDedup {
124    async fn first_seen(&self, command_id: &str) -> Result<bool, DomainError> {
125        let mut seen = self.seen.write().await;
126        Ok(seen.insert(command_id.to_string()))
127    }
128
129    async fn forget_seen(&self, command_id: &str) -> Result<(), DomainError> {
130        let mut seen = self.seen.write().await;
131        seen.remove(command_id);
132        Ok(())
133    }
134}
135
136#[derive(Clone)]
137pub struct InMemoryRuntimeStore {
138    inner: Arc<Mutex<RuntimeStoreState>>,
139    journal: Option<Arc<dyn RuntimeEventJournalPort>>,
140}
141
142#[derive(Default)]
143struct RuntimeStoreState {
144    routes: HashMap<String, RouteRuntimeAggregate>,
145    statuses: HashMap<String, RouteStatusProjection>,
146    events: Vec<RuntimeEvent>,
147    seen: HashSet<String>,
148}
149
150impl InMemoryRuntimeStore {
151    pub fn with_journal(mut self, journal: Arc<dyn RuntimeEventJournalPort>) -> Self {
152        self.journal = Some(journal);
153        self
154    }
155
156    pub async fn snapshot_events(&self) -> Vec<RuntimeEvent> {
157        self.inner.lock().await.events.clone()
158    }
159}
160
161fn upsert_replayed_route(
162    state: &mut RuntimeStoreState,
163    route_id: &str,
164    next_state: RouteRuntimeState,
165    status: &str,
166    increment_version: bool,
167) {
168    let current_version = state
169        .routes
170        .get(route_id)
171        .map(|agg| agg.version())
172        .unwrap_or(0);
173    let next_version = if increment_version {
174        current_version.saturating_add(1)
175    } else {
176        current_version
177    };
178    state.routes.insert(
179        route_id.to_string(),
180        RouteRuntimeAggregate::from_snapshot(route_id, next_state, next_version),
181    );
182    state.statuses.insert(
183        route_id.to_string(),
184        RouteStatusProjection {
185            route_id: route_id.to_string(),
186            status: status.to_string(),
187        },
188    );
189}
190
191fn state_label(state: &RouteRuntimeState) -> &'static str {
192    match state {
193        RouteRuntimeState::Registered => "Registered",
194        RouteRuntimeState::Starting => "Starting",
195        RouteRuntimeState::Started => "Started",
196        RouteRuntimeState::Suspended => "Suspended",
197        RouteRuntimeState::Stopping => "Stopping",
198        RouteRuntimeState::Stopped => "Stopped",
199        RouteRuntimeState::Failed(_) => "Failed",
200    }
201}
202
203fn apply_replayed_event(state: &mut RuntimeStoreState, event: &RuntimeEvent) {
204    match event {
205        RuntimeEvent::RouteRegistered { route_id } => {
206            state.routes.insert(
207                route_id.clone(),
208                RouteRuntimeAggregate::new(route_id.clone()),
209            );
210            state.statuses.insert(
211                route_id.clone(),
212                RouteStatusProjection {
213                    route_id: route_id.clone(),
214                    status: "Registered".to_string(),
215                },
216            );
217        }
218        RuntimeEvent::RouteRemoved { route_id } => {
219            state.routes.remove(route_id);
220            state.statuses.remove(route_id);
221        }
222        _ => {
223            // Use the domain's state machine to derive the next state from the event.
224            let Some(next_state) = RouteRuntimeAggregate::state_from_event(event) else {
225                return;
226            };
227            let route_id = match event {
228                RuntimeEvent::RouteStartRequested { route_id }
229                | RuntimeEvent::RouteStarted { route_id }
230                | RuntimeEvent::RouteFailed { route_id, .. }
231                | RuntimeEvent::RouteStopped { route_id }
232                | RuntimeEvent::RouteSuspended { route_id }
233                | RuntimeEvent::RouteResumed { route_id }
234                | RuntimeEvent::RouteReloaded { route_id } => route_id,
235                _ => return,
236            };
237            let status = state_label(&next_state);
238            let increment_version = !matches!(
239                (event, state.routes.get(route_id).map(|agg| agg.state())),
240                (
241                    RuntimeEvent::RouteStarted { .. },
242                    Some(RouteRuntimeState::Starting)
243                )
244            );
245            upsert_replayed_route(state, route_id, next_state, status, increment_version);
246        }
247    }
248}
249
250impl Default for InMemoryRuntimeStore {
251    fn default() -> Self {
252        Self {
253            inner: Arc::new(Mutex::new(RuntimeStoreState::default())),
254            journal: None,
255        }
256    }
257}
258
259#[async_trait]
260impl RouteRepositoryPort for InMemoryRuntimeStore {
261    async fn load(&self, route_id: &str) -> Result<Option<RouteRuntimeAggregate>, DomainError> {
262        let guard = self.inner.lock().await;
263        Ok(guard.routes.get(route_id).cloned())
264    }
265
266    async fn save(&self, aggregate: RouteRuntimeAggregate) -> Result<(), DomainError> {
267        let mut guard = self.inner.lock().await;
268        guard
269            .routes
270            .insert(aggregate.route_id().to_string(), aggregate);
271        Ok(())
272    }
273
274    async fn save_if_version(
275        &self,
276        aggregate: RouteRuntimeAggregate,
277        expected_version: u64,
278    ) -> Result<(), DomainError> {
279        let mut guard = self.inner.lock().await;
280        let route_id = aggregate.route_id().to_string();
281        let current = guard.routes.get(&route_id).ok_or_else(|| {
282            DomainError::InvalidState(format!(
283                "optimistic lock conflict for route '{route_id}': route not found"
284            ))
285        })?;
286
287        if current.version() != expected_version {
288            return Err(DomainError::InvalidState(format!(
289                "optimistic lock conflict for route '{route_id}': expected version {expected_version}, actual {}",
290                current.version()
291            )));
292        }
293
294        guard.routes.insert(route_id, aggregate);
295        Ok(())
296    }
297
298    async fn delete(&self, route_id: &str) -> Result<(), DomainError> {
299        let mut guard = self.inner.lock().await;
300        guard.routes.remove(route_id);
301        Ok(())
302    }
303}
304
305#[async_trait]
306impl ProjectionStorePort for InMemoryRuntimeStore {
307    async fn upsert_status(&self, status: RouteStatusProjection) -> Result<(), DomainError> {
308        let mut guard = self.inner.lock().await;
309        guard.statuses.insert(status.route_id.clone(), status);
310        Ok(())
311    }
312
313    async fn get_status(
314        &self,
315        route_id: &str,
316    ) -> Result<Option<RouteStatusProjection>, DomainError> {
317        let guard = self.inner.lock().await;
318        Ok(guard.statuses.get(route_id).cloned())
319    }
320
321    async fn list_statuses(&self) -> Result<Vec<RouteStatusProjection>, DomainError> {
322        let guard = self.inner.lock().await;
323        Ok(guard.statuses.values().cloned().collect())
324    }
325
326    async fn remove_status(&self, route_id: &str) -> Result<(), DomainError> {
327        let mut guard = self.inner.lock().await;
328        guard.statuses.remove(route_id);
329        Ok(())
330    }
331}
332
333#[async_trait]
334impl EventPublisherPort for InMemoryRuntimeStore {
335    async fn publish(&self, events: &[RuntimeEvent]) -> Result<(), DomainError> {
336        let mut guard = self.inner.lock().await;
337        if let Some(journal) = &self.journal {
338            journal.append_batch(events).await?;
339        }
340        guard.events.extend(events.iter().cloned());
341        Ok(())
342    }
343}
344
345#[async_trait]
346impl CommandDedupPort for InMemoryRuntimeStore {
347    async fn first_seen(&self, command_id: &str) -> Result<bool, DomainError> {
348        let mut guard = self.inner.lock().await;
349        if !guard.seen.insert(command_id.to_string()) {
350            return Ok(false);
351        }
352
353        if let Some(journal) = &self.journal
354            && let Err(err) = journal.append_command_id(command_id).await
355        {
356            guard.seen.remove(command_id);
357            return Err(err);
358        }
359
360        Ok(true)
361    }
362
363    async fn forget_seen(&self, command_id: &str) -> Result<(), DomainError> {
364        let mut guard = self.inner.lock().await;
365        let removed = guard.seen.remove(command_id);
366        if removed && let Some(journal) = &self.journal {
367            journal.remove_command_id(command_id).await?;
368        }
369        Ok(())
370    }
371}
372
373#[async_trait]
374impl RuntimeUnitOfWorkPort for InMemoryRuntimeStore {
375    async fn persist_upsert(
376        &self,
377        aggregate: RouteRuntimeAggregate,
378        expected_version: Option<u64>,
379        projection: RouteStatusProjection,
380        events: &[RuntimeEvent],
381    ) -> Result<(), DomainError> {
382        let mut guard = self.inner.lock().await;
383        if let Some(expected) = expected_version {
384            let route_id = aggregate.route_id().to_string();
385            let current = guard.routes.get(&route_id).ok_or_else(|| {
386                DomainError::InvalidState(format!(
387                    "optimistic lock conflict for route '{route_id}': route not found"
388                ))
389            })?;
390            if current.version() != expected {
391                return Err(DomainError::InvalidState(format!(
392                    "optimistic lock conflict for route '{route_id}': expected version {expected}, actual {}",
393                    current.version()
394                )));
395            }
396        }
397
398        if let Some(journal) = &self.journal {
399            journal.append_batch(events).await?;
400        }
401
402        guard
403            .routes
404            .insert(aggregate.route_id().to_string(), aggregate);
405        guard
406            .statuses
407            .insert(projection.route_id.clone(), projection);
408        guard.events.extend(events.iter().cloned());
409        Ok(())
410    }
411
412    async fn persist_delete(
413        &self,
414        route_id: &str,
415        events: &[RuntimeEvent],
416    ) -> Result<(), DomainError> {
417        let mut guard = self.inner.lock().await;
418        if let Some(journal) = &self.journal {
419            journal.append_batch(events).await?;
420        }
421        guard.routes.remove(route_id);
422        guard.statuses.remove(route_id);
423        guard.events.extend(events.iter().cloned());
424        Ok(())
425    }
426
427    async fn recover_from_journal(&self) -> Result<(), DomainError> {
428        let Some(journal) = &self.journal else {
429            return Ok(());
430        };
431
432        let replayed_events = journal.load_all().await?;
433        let replayed_command_ids = journal.load_command_ids().await?;
434
435        let mut guard = self.inner.lock().await;
436        guard.routes.clear();
437        guard.statuses.clear();
438        guard.events.clear();
439        guard.seen.clear();
440
441        for event in &replayed_events {
442            apply_replayed_event(&mut guard, event);
443        }
444        guard.events = replayed_events;
445        for command_id in replayed_command_ids {
446            guard.seen.insert(command_id);
447        }
448        Ok(())
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Journal stub that replays a fixed event list and discards appended batches.
    #[derive(Clone)]
    struct ReplayJournal {
        events: Vec<RuntimeEvent>,
    }

    #[async_trait]
    impl RuntimeEventJournalPort for ReplayJournal {
        async fn append_batch(&self, _events: &[RuntimeEvent]) -> Result<(), DomainError> {
            Ok(())
        }

        async fn load_all(&self) -> Result<Vec<RuntimeEvent>, DomainError> {
            Ok(self.events.clone())
        }
    }

    #[tokio::test]
    async fn repo_roundtrip_works() {
        let repo = InMemoryRouteRepository::default();
        repo.save(RouteRuntimeAggregate::new("r1")).await.unwrap();
        assert!(repo.load("r1").await.unwrap().is_some());

        let updated = RouteRuntimeAggregate::from_snapshot("r1", RouteRuntimeState::Started, 1);
        repo.save_if_version(updated.clone(), 0).await.unwrap();
        let loaded = repo.load("r1").await.unwrap().unwrap();
        assert_eq!(loaded.version(), 1);

        let conflict = repo.save_if_version(updated, 0).await.unwrap_err();
        assert!(
            conflict.to_string().contains("optimistic lock conflict"),
            "unexpected conflict error: {conflict}"
        );

        repo.delete("r1").await.unwrap();
        assert!(repo.load("r1").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn projection_roundtrip_works() {
        let store = InMemoryProjectionStore::default();
        store
            .upsert_status(RouteStatusProjection {
                route_id: "r1".into(),
                status: "Started".into(),
            })
            .await
            .unwrap();

        let status = store.get_status("r1").await.unwrap();
        assert!(status.is_some());
        assert_eq!(status.unwrap().status, "Started");
        store.remove_status("r1").await.unwrap();
        assert!(store.get_status("r1").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn event_publisher_stores_events() {
        let publisher = InMemoryEventPublisher::default();
        publisher
            .publish(&[RuntimeEvent::RouteStarted {
                route_id: "r1".into(),
            }])
            .await
            .unwrap();

        let events = publisher.snapshot().await;
        assert_eq!(events.len(), 1);
    }

    #[tokio::test]
    async fn command_dedup_detects_duplicates() {
        let dedup = InMemoryCommandDedup::default();
        assert!(dedup.first_seen("c1").await.unwrap());
        assert!(!dedup.first_seen("c1").await.unwrap());
        dedup.forget_seen("c1").await.unwrap();
        assert!(dedup.first_seen("c1").await.unwrap());
        assert!(dedup.first_seen("c2").await.unwrap());
    }

    #[tokio::test]
    async fn runtime_store_dedup_detects_duplicates() {
        let store = InMemoryRuntimeStore::default();
        assert!(store.first_seen("c1").await.unwrap());
        assert!(!store.first_seen("c1").await.unwrap());
        store.forget_seen("c1").await.unwrap();
        assert!(store.first_seen("c1").await.unwrap());
        // Forgetting an id that was never recorded is a silent no-op.
        store.forget_seen("never-seen").await.unwrap();
    }

    #[tokio::test]
    async fn runtime_store_uow_persists_all_three_writes() {
        let store = InMemoryRuntimeStore::default();
        let aggregate = RouteRuntimeAggregate::new("uow-r1");
        let projection = RouteStatusProjection {
            route_id: "uow-r1".to_string(),
            status: "Registered".to_string(),
        };
        let events = vec![RuntimeEvent::RouteRegistered {
            route_id: "uow-r1".to_string(),
        }];

        store
            .persist_upsert(aggregate, None, projection.clone(), &events)
            .await
            .unwrap();

        assert!(store.load("uow-r1").await.unwrap().is_some());
        assert_eq!(
            store.get_status("uow-r1").await.unwrap().unwrap(),
            projection
        );
        assert_eq!(store.snapshot_events().await, events);
    }

    #[tokio::test]
    async fn runtime_store_uow_enforces_expected_version() {
        let store = InMemoryRuntimeStore::default();
        let initial = RouteRuntimeAggregate::new("uow-r2");
        let initial_projection = RouteStatusProjection {
            route_id: "uow-r2".to_string(),
            status: "Registered".to_string(),
        };
        store
            .persist_upsert(
                initial,
                None,
                initial_projection,
                &[RuntimeEvent::RouteRegistered {
                    route_id: "uow-r2".to_string(),
                }],
            )
            .await
            .unwrap();

        let started = RouteRuntimeAggregate::from_snapshot("uow-r2", RouteRuntimeState::Started, 1);
        let err = store
            .persist_upsert(
                started,
                Some(99),
                RouteStatusProjection {
                    route_id: "uow-r2".to_string(),
                    status: "Started".to_string(),
                },
                &[RuntimeEvent::RouteStarted {
                    route_id: "uow-r2".to_string(),
                }],
            )
            .await
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("optimistic lock conflict"),
            "unexpected error: {err}"
        );
    }

    #[tokio::test]
    async fn replay_start_requested_only_advances_version_once() {
        let store = InMemoryRuntimeStore::default().with_journal(Arc::new(ReplayJournal {
            events: vec![
                RuntimeEvent::RouteRegistered {
                    route_id: "replay-r1".to_string(),
                },
                RuntimeEvent::RouteStartRequested {
                    route_id: "replay-r1".to_string(),
                },
            ],
        }));

        store.recover_from_journal().await.unwrap();
        let aggregate = store.load("replay-r1").await.unwrap().unwrap();

        assert_eq!(aggregate.state(), &RouteRuntimeState::Starting);
        assert_eq!(aggregate.version(), 1);
    }

    #[tokio::test]
    async fn replay_start_requested_then_started_keeps_single_command_version() {
        let store = InMemoryRuntimeStore::default().with_journal(Arc::new(ReplayJournal {
            events: vec![
                RuntimeEvent::RouteRegistered {
                    route_id: "replay-r2".to_string(),
                },
                RuntimeEvent::RouteStartRequested {
                    route_id: "replay-r2".to_string(),
                },
                RuntimeEvent::RouteStarted {
                    route_id: "replay-r2".to_string(),
                },
            ],
        }));

        store.recover_from_journal().await.unwrap();
        let aggregate = store.load("replay-r2").await.unwrap().unwrap();

        assert_eq!(aggregate.state(), &RouteRuntimeState::Started);
        assert_eq!(aggregate.version(), 1);
        // Replay also rebuilds the status projection from the final state.
        assert_eq!(
            store.get_status("replay-r2").await.unwrap().unwrap().status,
            "Started"
        );
    }

    #[tokio::test]
    async fn replay_route_removed_drops_route_and_status() {
        let store = InMemoryRuntimeStore::default().with_journal(Arc::new(ReplayJournal {
            events: vec![
                RuntimeEvent::RouteRegistered {
                    route_id: "replay-r3".to_string(),
                },
                RuntimeEvent::RouteRemoved {
                    route_id: "replay-r3".to_string(),
                },
            ],
        }));

        store.recover_from_journal().await.unwrap();

        assert!(store.load("replay-r3").await.unwrap().is_none());
        assert!(store.get_status("replay-r3").await.unwrap().is_none());
        // The replayed events themselves are kept as the in-memory event log.
        assert_eq!(store.snapshot_events().await.len(), 2);
    }

    #[tokio::test]
    async fn recover_without_journal_keeps_existing_state() {
        let store = InMemoryRuntimeStore::default();
        store
            .save(RouteRuntimeAggregate::new("keep-r1"))
            .await
            .unwrap();

        store.recover_from_journal().await.unwrap();

        assert!(store.load("keep-r1").await.unwrap().is_some());
    }
}