// meerkat_runtime/store/memory.rs

//! InMemoryRuntimeStore — in-memory implementation for testing and ephemeral use.
//!
//! Uses `tokio::sync::Mutex` (the `tokio_with_wasm` alias on wasm32) per the
//! in-memory concurrency rule. Every mutation completes inside a single lock
//! acquisition; no lock is held across a further `.await`.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use indexmap::IndexMap;
10use meerkat_core::lifecycle::run_primitive::RunApplyBoundary;
11use meerkat_core::lifecycle::{InputId, RunBoundaryReceipt, RunId};
12#[cfg(not(target_arch = "wasm32"))]
13use tokio::sync::Mutex;
14#[cfg(target_arch = "wasm32")]
15use tokio_with_wasm::alias::sync::Mutex;
16
17use super::{RuntimeStore, RuntimeStoreError, SessionDelta, authoritative_receipt};
18use crate::identifiers::LogicalRuntimeId;
19use crate::input_state::InputState;
20use crate::runtime_state::RuntimeState;
21
22/// Receipt key: (runtime_id, run_id, sequence).
23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24struct ReceiptKey {
25    runtime_id: String,
26    run_id: RunId,
27    sequence: u64,
28}
29
30/// Inner state protected by the mutex.
31#[derive(Debug, Default)]
32struct Inner {
33    /// runtime_id → (input_id → InputState). IndexMap for deterministic iteration order.
34    input_states: HashMap<String, IndexMap<InputId, InputState>>,
35    /// Receipt storage.
36    receipts: HashMap<ReceiptKey, RunBoundaryReceipt>,
37    /// Session snapshots (opaque bytes).
38    sessions: HashMap<String, Vec<u8>>,
39    /// Persisted runtime state.
40    runtime_states: HashMap<String, RuntimeState>,
41}
42
43/// In-memory runtime store. Thread-safe via `tokio::sync::Mutex`.
44#[derive(Debug, Clone)]
45pub struct InMemoryRuntimeStore {
46    inner: Arc<Mutex<Inner>>,
47}
48
49impl InMemoryRuntimeStore {
50    pub fn new() -> Self {
51        Self {
52            inner: Arc::new(Mutex::new(Inner::default())),
53        }
54    }
55}
56
57impl Default for InMemoryRuntimeStore {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
64#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
65impl RuntimeStore for InMemoryRuntimeStore {
66    async fn commit_session_boundary(
67        &self,
68        runtime_id: &LogicalRuntimeId,
69        session_delta: SessionDelta,
70        run_id: RunId,
71        boundary: RunApplyBoundary,
72        contributing_input_ids: Vec<InputId>,
73        input_updates: Vec<InputState>,
74    ) -> Result<RunBoundaryReceipt, RuntimeStoreError> {
75        let mut inner = self.inner.lock().await;
76        let rid = runtime_id.0.clone();
77        let sequence = inner
78            .receipts
79            .keys()
80            .filter(|key| key.runtime_id == rid && key.run_id == run_id)
81            .map(|key| key.sequence)
82            .max()
83            .map(|seq| seq + 1)
84            .unwrap_or(0);
85        let receipt = authoritative_receipt(
86            Some(&session_delta),
87            run_id,
88            boundary,
89            contributing_input_ids,
90            sequence,
91        )?;
92        let mut input_updates = input_updates;
93        for state in &mut input_updates {
94            state.last_run_id = Some(receipt.run_id.clone());
95            state.last_boundary_sequence = Some(receipt.sequence);
96        }
97
98        inner
99            .sessions
100            .insert(rid.clone(), session_delta.session_snapshot);
101        let key = ReceiptKey {
102            runtime_id: rid.clone(),
103            run_id: receipt.run_id.clone(),
104            sequence: receipt.sequence,
105        };
106        inner.receipts.insert(key, receipt.clone());
107
108        let states = inner.input_states.entry(rid).or_default();
109        for state in input_updates {
110            states.insert(state.input_id.clone(), state);
111        }
112
113        Ok(receipt)
114    }
115
116    async fn atomic_apply(
117        &self,
118        runtime_id: &LogicalRuntimeId,
119        session_delta: Option<SessionDelta>,
120        receipt: RunBoundaryReceipt,
121        input_updates: Vec<InputState>,
122        _session_store_key: Option<meerkat_core::types::SessionId>,
123    ) -> Result<(), RuntimeStoreError> {
124        // InMemoryRuntimeStore ignores session_store_key — there's no shared
125        // sessions table in memory. The session snapshot is stored in the
126        // runtime's own snapshot map.
127        let mut inner = self.inner.lock().await;
128
129        // All writes in one lock acquisition (atomic for in-memory)
130        let rid = runtime_id.0.clone();
131
132        // Session delta
133        if let Some(delta) = session_delta {
134            inner.sessions.insert(rid.clone(), delta.session_snapshot);
135        }
136
137        // Receipt
138        let key = ReceiptKey {
139            runtime_id: rid.clone(),
140            run_id: receipt.run_id.clone(),
141            sequence: receipt.sequence,
142        };
143        inner.receipts.insert(key, receipt);
144
145        // Input states
146        let states = inner.input_states.entry(rid).or_default();
147        for state in input_updates {
148            states.insert(state.input_id.clone(), state);
149        }
150
151        Ok(())
152    }
153
154    async fn load_input_states(
155        &self,
156        runtime_id: &LogicalRuntimeId,
157    ) -> Result<Vec<InputState>, RuntimeStoreError> {
158        let inner = self.inner.lock().await;
159        let states = inner
160            .input_states
161            .get(&runtime_id.0)
162            .map(|m| m.values().cloned().collect())
163            .unwrap_or_default();
164        Ok(states)
165    }
166
167    async fn load_boundary_receipt(
168        &self,
169        runtime_id: &LogicalRuntimeId,
170        run_id: &RunId,
171        sequence: u64,
172    ) -> Result<Option<RunBoundaryReceipt>, RuntimeStoreError> {
173        let inner = self.inner.lock().await;
174        let key = ReceiptKey {
175            runtime_id: runtime_id.0.clone(),
176            run_id: run_id.clone(),
177            sequence,
178        };
179        Ok(inner.receipts.get(&key).cloned())
180    }
181
182    async fn load_session_snapshot(
183        &self,
184        runtime_id: &LogicalRuntimeId,
185    ) -> Result<Option<Vec<u8>>, RuntimeStoreError> {
186        let inner = self.inner.lock().await;
187        Ok(inner.sessions.get(&runtime_id.0).cloned())
188    }
189
190    async fn persist_input_state(
191        &self,
192        runtime_id: &LogicalRuntimeId,
193        state: &InputState,
194    ) -> Result<(), RuntimeStoreError> {
195        let mut inner = self.inner.lock().await;
196        let states = inner.input_states.entry(runtime_id.0.clone()).or_default();
197        states.insert(state.input_id.clone(), state.clone());
198        Ok(())
199    }
200
201    async fn load_input_state(
202        &self,
203        runtime_id: &LogicalRuntimeId,
204        input_id: &InputId,
205    ) -> Result<Option<InputState>, RuntimeStoreError> {
206        let inner = self.inner.lock().await;
207        let state = inner
208            .input_states
209            .get(&runtime_id.0)
210            .and_then(|m| m.get(input_id).cloned());
211        Ok(state)
212    }
213
214    async fn persist_runtime_state(
215        &self,
216        runtime_id: &LogicalRuntimeId,
217        state: RuntimeState,
218    ) -> Result<(), RuntimeStoreError> {
219        let mut inner = self.inner.lock().await;
220        inner.runtime_states.insert(runtime_id.0.clone(), state);
221        Ok(())
222    }
223
224    async fn load_runtime_state(
225        &self,
226        runtime_id: &LogicalRuntimeId,
227    ) -> Result<Option<RuntimeState>, RuntimeStoreError> {
228        let inner = self.inner.lock().await;
229        Ok(inner.runtime_states.get(&runtime_id.0).copied())
230    }
231
232    async fn atomic_lifecycle_commit(
233        &self,
234        runtime_id: &LogicalRuntimeId,
235        runtime_state: RuntimeState,
236        input_states: &[InputState],
237    ) -> Result<(), RuntimeStoreError> {
238        let mut inner = self.inner.lock().await;
239        let rid = runtime_id.0.clone();
240
241        // Single lock acquisition — atomic for in-memory
242        inner.runtime_states.insert(rid.clone(), runtime_state);
243        let states = inner.input_states.entry(rid).or_default();
244        for state in input_states {
245            states.insert(state.input_id.clone(), state.clone());
246        }
247
248        Ok(())
249    }
250}
251
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use meerkat_core::lifecycle::run_primitive::RunApplyBoundary;

    /// Builds a minimal `RunStart` receipt for `run_id` at sequence `seq`.
    fn make_receipt(run_id: RunId, seq: u64) -> RunBoundaryReceipt {
        RunBoundaryReceipt {
            run_id,
            boundary: RunApplyBoundary::RunStart,
            contributing_input_ids: Vec::new(),
            conversation_digest: None,
            message_count: 0,
            sequence: seq,
        }
    }

    /// Serializes a fresh `Session` to JSON bytes for use as a snapshot.
    fn fresh_session_bytes() -> Vec<u8> {
        serde_json::to_vec(&meerkat_core::Session::new()).unwrap()
    }

    /// `atomic_apply` makes both the input state and the receipt loadable.
    #[tokio::test]
    async fn atomic_apply_roundtrip() {
        let store = InMemoryRuntimeStore::new();
        let runtime = LogicalRuntimeId::new("test-runtime");
        let run_id = RunId::new();
        let input_id = InputId::new();

        let delta = SessionDelta {
            session_snapshot: b"session-data".to_vec(),
        };
        let accepted = InputState::new_accepted(input_id.clone());
        let receipt = make_receipt(run_id.clone(), 0);

        store
            .atomic_apply(&runtime, Some(delta), receipt, vec![accepted], None)
            .await
            .unwrap();

        // The input state is visible through the bulk loader.
        let states = store.load_input_states(&runtime).await.unwrap();
        assert_eq!(states.len(), 1);
        assert_eq!(states[0].input_id, input_id);

        // The receipt is visible under its (runtime, run, sequence) key.
        let stored = store
            .load_boundary_receipt(&runtime, &run_id, 0)
            .await
            .unwrap();
        assert!(stored.is_some());
    }

    /// A state written with `persist_input_state` comes back from `load_input_state`.
    #[tokio::test]
    async fn persist_and_load_single_state() {
        let store = InMemoryRuntimeStore::new();
        let runtime = LogicalRuntimeId::new("test");
        let input_id = InputId::new();
        let accepted = InputState::new_accepted(input_id.clone());

        store.persist_input_state(&runtime, &accepted).await.unwrap();

        let fetched = store.load_input_state(&runtime, &input_id).await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().input_id, input_id);
    }

    /// Every loader reports "nothing" on an empty store.
    #[tokio::test]
    async fn load_nonexistent_returns_none() {
        let store = InMemoryRuntimeStore::new();
        let runtime = LogicalRuntimeId::new("test");

        let all_states = store.load_input_states(&runtime).await.unwrap();
        assert!(all_states.is_empty());

        let single = store
            .load_input_state(&runtime, &InputId::new())
            .await
            .unwrap();
        assert!(single.is_none());

        let missing_receipt = store
            .load_boundary_receipt(&runtime, &RunId::new(), 0)
            .await
            .unwrap();
        assert!(missing_receipt.is_none());
    }

    /// A second `atomic_apply` for the same input id overwrites the first entry.
    #[tokio::test]
    async fn atomic_apply_updates_existing() {
        use crate::input_state::InputLifecycleState;

        let store = InMemoryRuntimeStore::new();
        let runtime = LogicalRuntimeId::new("test");
        let input_id = InputId::new();

        // Initial write: freshly accepted input.
        let initial = InputState::new_accepted(input_id.clone());
        store
            .atomic_apply(
                &runtime,
                None,
                make_receipt(RunId::new(), 0),
                vec![initial],
                None,
            )
            .await
            .unwrap();

        // Follow-up write: same input id, now queued.
        let mut queued = InputState::new_accepted(input_id.clone());
        queued.current_state = InputLifecycleState::Queued;
        store
            .atomic_apply(
                &runtime,
                None,
                make_receipt(RunId::new(), 1),
                vec![queued],
                None,
            )
            .await
            .unwrap();

        let states = store.load_input_states(&runtime).await.unwrap();
        assert_eq!(states.len(), 1);
        assert_eq!(states[0].current_state, InputLifecycleState::Queued);
    }

    /// `commit_session_boundary` returns the receipt it persisted, starting at sequence 0.
    #[tokio::test]
    async fn commit_session_boundary_returns_authoritative_receipt() {
        let store = InMemoryRuntimeStore::new();
        let runtime = LogicalRuntimeId::new("test");
        let run_id = RunId::new();
        let input_id = InputId::new();
        let delta = SessionDelta {
            session_snapshot: fresh_session_bytes(),
        };

        let receipt = store
            .commit_session_boundary(
                &runtime,
                delta,
                run_id.clone(),
                RunApplyBoundary::Immediate,
                vec![input_id.clone()],
                vec![InputState::new_accepted(input_id)],
            )
            .await
            .unwrap();

        assert_eq!(receipt.sequence, 0);
        assert_eq!(receipt.run_id, run_id);
        assert!(receipt.conversation_digest.is_some());

        let persisted = store
            .load_boundary_receipt(&runtime, &receipt.run_id, receipt.sequence)
            .await
            .unwrap();
        assert!(persisted.is_some(), "receipt should be persisted");
        let Some(persisted) = persisted else {
            unreachable!("asserted above");
        };
        assert_eq!(persisted, receipt);
    }

    /// Input states written under one runtime id are not visible under another.
    #[tokio::test]
    async fn multiple_runtimes_isolated() {
        let store = InMemoryRuntimeStore::new();
        let first = LogicalRuntimeId::new("runtime-1");
        let second = LogicalRuntimeId::new("runtime-2");

        // One input for the first runtime, two for the second.
        for target in [&first, &second, &second] {
            store
                .persist_input_state(target, &InputState::new_accepted(InputId::new()))
                .await
                .unwrap();
        }

        let first_states = store.load_input_states(&first).await.unwrap();
        let second_states = store.load_input_states(&second).await.unwrap();
        assert_eq!(first_states.len(), 1);
        assert_eq!(second_states.len(), 2);
    }

    /// A snapshot stored via `atomic_apply` is returned unchanged by `load_session_snapshot`.
    #[tokio::test]
    async fn load_session_snapshot_roundtrip() {
        let store = InMemoryRuntimeStore::new();
        let runtime = LogicalRuntimeId::new("runtime");
        let snapshot = fresh_session_bytes();
        let delta = SessionDelta {
            session_snapshot: snapshot.clone(),
        };

        store
            .atomic_apply(
                &runtime,
                Some(delta),
                make_receipt(RunId::new(), 0),
                Vec::new(),
                None,
            )
            .await
            .unwrap();

        let restored = store.load_session_snapshot(&runtime).await.unwrap();
        assert_eq!(restored, Some(snapshot));
    }
}