Skip to main content

meerkat_runtime/store/
memory.rs

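//! InMemoryRuntimeStore — in-memory `RuntimeStore` implementation for tests and ephemeral runtimes.
//!
//! All state sits behind a single async `Mutex` (`tokio::sync::Mutex`, or the
//! `tokio_with_wasm` alias on wasm32), as required by the in-memory concurrency rule.
//! Each operation performs all of its reads and writes within one lock acquisition
//! and never holds the guard across an `.await`.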
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use indexmap::IndexMap;
10use meerkat_core::lifecycle::run_primitive::RunApplyBoundary;
11use meerkat_core::lifecycle::{InputId, RunBoundaryReceipt, RunId};
12#[cfg(not(target_arch = "wasm32"))]
13use tokio::sync::Mutex;
14#[cfg(target_arch = "wasm32")]
15use tokio_with_wasm::alias::sync::Mutex;
16
17use super::{RuntimeStore, RuntimeStoreError, SessionDelta, authoritative_receipt};
18use crate::identifiers::LogicalRuntimeId;
19use crate::input_state::InputState;
20use crate::runtime_state::RuntimeState;
21
22/// Receipt key: (runtime_id, run_id, sequence).
23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24struct ReceiptKey {
25    runtime_id: String,
26    run_id: RunId,
27    sequence: u64,
28}
29
30/// Inner state protected by the mutex.
31#[derive(Debug, Default)]
32struct Inner {
33    /// runtime_id → (input_id → InputState). IndexMap for deterministic iteration order.
34    input_states: HashMap<String, IndexMap<InputId, InputState>>,
35    /// Receipt storage.
36    receipts: HashMap<ReceiptKey, RunBoundaryReceipt>,
37    /// Session snapshots (opaque bytes).
38    sessions: HashMap<String, Vec<u8>>,
39    /// Persisted runtime state.
40    runtime_states: HashMap<String, RuntimeState>,
41}
42
43/// In-memory runtime store. Thread-safe via `tokio::sync::Mutex`.
44#[derive(Debug, Clone)]
45pub struct InMemoryRuntimeStore {
46    inner: Arc<Mutex<Inner>>,
47}
48
49impl InMemoryRuntimeStore {
50    pub fn new() -> Self {
51        Self {
52            inner: Arc::new(Mutex::new(Inner::default())),
53        }
54    }
55}
56
57impl Default for InMemoryRuntimeStore {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
64#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
65impl RuntimeStore for InMemoryRuntimeStore {
66    async fn commit_session_boundary(
67        &self,
68        runtime_id: &LogicalRuntimeId,
69        session_delta: SessionDelta,
70        run_id: RunId,
71        boundary: RunApplyBoundary,
72        contributing_input_ids: Vec<InputId>,
73        input_updates: Vec<InputState>,
74    ) -> Result<RunBoundaryReceipt, RuntimeStoreError> {
75        let mut inner = self.inner.lock().await;
76        let rid = runtime_id.0.clone();
77        let sequence = inner
78            .receipts
79            .keys()
80            .filter(|key| key.runtime_id == rid && key.run_id == run_id)
81            .map(|key| key.sequence)
82            .max()
83            .map(|seq| seq + 1)
84            .unwrap_or(0);
85        let receipt = authoritative_receipt(
86            Some(&session_delta),
87            run_id,
88            boundary,
89            contributing_input_ids,
90            sequence,
91        )?;
92        let mut input_updates = input_updates;
93        for state in &mut input_updates {
94            state
95                .authority_mut()
96                .stamp_receipt_metadata(receipt.run_id.clone(), receipt.sequence);
97        }
98
99        inner
100            .sessions
101            .insert(rid.clone(), session_delta.session_snapshot);
102        let key = ReceiptKey {
103            runtime_id: rid.clone(),
104            run_id: receipt.run_id.clone(),
105            sequence: receipt.sequence,
106        };
107        inner.receipts.insert(key, receipt.clone());
108
109        let states = inner.input_states.entry(rid).or_default();
110        for state in input_updates {
111            states.insert(state.input_id.clone(), state);
112        }
113
114        Ok(receipt)
115    }
116
117    async fn atomic_apply(
118        &self,
119        runtime_id: &LogicalRuntimeId,
120        session_delta: Option<SessionDelta>,
121        receipt: RunBoundaryReceipt,
122        input_updates: Vec<InputState>,
123        _session_store_key: Option<meerkat_core::types::SessionId>,
124    ) -> Result<(), RuntimeStoreError> {
125        // InMemoryRuntimeStore ignores session_store_key — there's no shared
126        // sessions table in memory. The session snapshot is stored in the
127        // runtime's own snapshot map.
128        let mut inner = self.inner.lock().await;
129
130        // All writes in one lock acquisition (atomic for in-memory)
131        let rid = runtime_id.0.clone();
132
133        // Session delta
134        if let Some(delta) = session_delta {
135            inner.sessions.insert(rid.clone(), delta.session_snapshot);
136        }
137
138        // Receipt
139        let key = ReceiptKey {
140            runtime_id: rid.clone(),
141            run_id: receipt.run_id.clone(),
142            sequence: receipt.sequence,
143        };
144        inner.receipts.insert(key, receipt);
145
146        // Input states
147        let states = inner.input_states.entry(rid).or_default();
148        for state in input_updates {
149            states.insert(state.input_id.clone(), state);
150        }
151
152        Ok(())
153    }
154
155    async fn load_input_states(
156        &self,
157        runtime_id: &LogicalRuntimeId,
158    ) -> Result<Vec<InputState>, RuntimeStoreError> {
159        let inner = self.inner.lock().await;
160        let states = inner
161            .input_states
162            .get(&runtime_id.0)
163            .map(|m| m.values().cloned().collect())
164            .unwrap_or_default();
165        Ok(states)
166    }
167
168    async fn load_boundary_receipt(
169        &self,
170        runtime_id: &LogicalRuntimeId,
171        run_id: &RunId,
172        sequence: u64,
173    ) -> Result<Option<RunBoundaryReceipt>, RuntimeStoreError> {
174        let inner = self.inner.lock().await;
175        let key = ReceiptKey {
176            runtime_id: runtime_id.0.clone(),
177            run_id: run_id.clone(),
178            sequence,
179        };
180        Ok(inner.receipts.get(&key).cloned())
181    }
182
183    async fn load_session_snapshot(
184        &self,
185        runtime_id: &LogicalRuntimeId,
186    ) -> Result<Option<Vec<u8>>, RuntimeStoreError> {
187        let inner = self.inner.lock().await;
188        Ok(inner.sessions.get(&runtime_id.0).cloned())
189    }
190
191    async fn persist_input_state(
192        &self,
193        runtime_id: &LogicalRuntimeId,
194        state: &InputState,
195    ) -> Result<(), RuntimeStoreError> {
196        let mut inner = self.inner.lock().await;
197        let states = inner.input_states.entry(runtime_id.0.clone()).or_default();
198        states.insert(state.input_id.clone(), state.clone());
199        Ok(())
200    }
201
202    async fn load_input_state(
203        &self,
204        runtime_id: &LogicalRuntimeId,
205        input_id: &InputId,
206    ) -> Result<Option<InputState>, RuntimeStoreError> {
207        let inner = self.inner.lock().await;
208        let state = inner
209            .input_states
210            .get(&runtime_id.0)
211            .and_then(|m| m.get(input_id).cloned());
212        Ok(state)
213    }
214
215    async fn persist_runtime_state(
216        &self,
217        runtime_id: &LogicalRuntimeId,
218        state: RuntimeState,
219    ) -> Result<(), RuntimeStoreError> {
220        let mut inner = self.inner.lock().await;
221        inner.runtime_states.insert(runtime_id.0.clone(), state);
222        Ok(())
223    }
224
225    async fn load_runtime_state(
226        &self,
227        runtime_id: &LogicalRuntimeId,
228    ) -> Result<Option<RuntimeState>, RuntimeStoreError> {
229        let inner = self.inner.lock().await;
230        Ok(inner.runtime_states.get(&runtime_id.0).copied())
231    }
232
233    async fn atomic_lifecycle_commit(
234        &self,
235        runtime_id: &LogicalRuntimeId,
236        runtime_state: RuntimeState,
237        input_states: &[InputState],
238    ) -> Result<(), RuntimeStoreError> {
239        let mut inner = self.inner.lock().await;
240        let rid = runtime_id.0.clone();
241
242        // Single lock acquisition — atomic for in-memory
243        inner.runtime_states.insert(rid.clone(), runtime_state);
244        let states = inner.input_states.entry(rid).or_default();
245        for state in input_states {
246            states.insert(state.input_id.clone(), state.clone());
247        }
248
249        Ok(())
250    }
251}
252
#[cfg(test)]
// Unit tests: a panic on an unexpected `Err` is the desired failure mode.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use meerkat_core::lifecycle::run_primitive::RunApplyBoundary;

    /// Builds a minimal `RunStart` receipt for `run_id` at sequence `seq`.
    fn make_receipt(run_id: RunId, seq: u64) -> RunBoundaryReceipt {
        RunBoundaryReceipt {
            run_id,
            boundary: RunApplyBoundary::RunStart,
            contributing_input_ids: vec![],
            conversation_digest: None,
            message_count: 0,
            sequence: seq,
        }
    }

    /// Commits one `Immediate` boundary for (`rid`, `run_id`) with a fresh
    /// session snapshot and a single fresh input, returning the receipt.
    async fn commit_one(
        store: &InMemoryRuntimeStore,
        rid: &LogicalRuntimeId,
        run_id: RunId,
    ) -> RunBoundaryReceipt {
        let input_id = InputId::new();
        store
            .commit_session_boundary(
                rid,
                SessionDelta {
                    session_snapshot: serde_json::to_vec(&meerkat_core::Session::new()).unwrap(),
                },
                run_id,
                RunApplyBoundary::Immediate,
                vec![input_id.clone()],
                vec![InputState::new_accepted(input_id)],
            )
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn atomic_apply_roundtrip() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test-runtime");
        let run_id = RunId::new();
        let input_id = InputId::new();

        let state = InputState::new_accepted(input_id.clone());
        let receipt = make_receipt(run_id.clone(), 0);

        store
            .atomic_apply(
                &rid,
                Some(SessionDelta {
                    session_snapshot: b"session-data".to_vec(),
                }),
                receipt.clone(),
                vec![state],
                None,
            )
            .await
            .unwrap();

        // Load input states
        let states = store.load_input_states(&rid).await.unwrap();
        assert_eq!(states.len(), 1);
        assert_eq!(states[0].input_id, input_id);

        // The exact receipt comes back, not merely "some" receipt.
        let loaded = store.load_boundary_receipt(&rid, &run_id, 0).await.unwrap();
        assert_eq!(loaded, Some(receipt));

        // The session snapshot written in the same apply is visible too.
        let snapshot = store.load_session_snapshot(&rid).await.unwrap();
        assert_eq!(snapshot, Some(b"session-data".to_vec()));
    }

    #[tokio::test]
    async fn persist_and_load_single_state() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test");
        let input_id = InputId::new();
        let state = InputState::new_accepted(input_id.clone());

        store.persist_input_state(&rid, &state).await.unwrap();

        let loaded = store.load_input_state(&rid, &input_id).await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().input_id, input_id);
    }

    #[tokio::test]
    async fn load_nonexistent_returns_none() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test");

        let states = store.load_input_states(&rid).await.unwrap();
        assert!(states.is_empty());

        let state = store.load_input_state(&rid, &InputId::new()).await.unwrap();
        assert!(state.is_none());

        let receipt = store
            .load_boundary_receipt(&rid, &RunId::new(), 0)
            .await
            .unwrap();
        assert!(receipt.is_none());
    }

    #[tokio::test]
    async fn atomic_apply_updates_existing() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test");
        let input_id = InputId::new();

        // First write
        let state1 = InputState::new_accepted(input_id.clone());
        store
            .atomic_apply(
                &rid,
                None,
                make_receipt(RunId::new(), 0),
                vec![state1],
                None,
            )
            .await
            .unwrap();

        // Second write with updated state
        let mut state2 = InputState::new_accepted(input_id.clone());
        let _ = state2.apply(crate::input_lifecycle_authority::InputLifecycleInput::QueueAccepted);
        store
            .atomic_apply(
                &rid,
                None,
                make_receipt(RunId::new(), 1),
                vec![state2],
                None,
            )
            .await
            .unwrap();

        let states = store.load_input_states(&rid).await.unwrap();
        assert_eq!(states.len(), 1);
        assert_eq!(
            states[0].current_state(),
            crate::input_state::InputLifecycleState::Queued
        );
    }

    #[tokio::test]
    async fn commit_session_boundary_returns_authoritative_receipt() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test");
        let run_id = RunId::new();
        let input_id = InputId::new();
        let session = meerkat_core::Session::new();
        let snapshot = serde_json::to_vec(&session).unwrap();

        let receipt = store
            .commit_session_boundary(
                &rid,
                SessionDelta {
                    session_snapshot: snapshot,
                },
                run_id.clone(),
                RunApplyBoundary::Immediate,
                vec![input_id.clone()],
                vec![InputState::new_accepted(input_id)],
            )
            .await
            .unwrap();

        assert_eq!(receipt.sequence, 0);
        assert_eq!(receipt.run_id, run_id);
        assert!(receipt.conversation_digest.is_some());
        let loaded = store
            .load_boundary_receipt(&rid, &receipt.run_id, receipt.sequence)
            .await
            .unwrap();
        assert_eq!(loaded, Some(receipt), "receipt should be persisted");
    }

    #[tokio::test]
    async fn commit_session_boundary_sequences_per_runtime_and_run() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test");
        let other_rid = LogicalRuntimeId::new("other");
        let run_id = RunId::new();

        // Successive boundaries of the same run get consecutive sequences.
        assert_eq!(commit_one(&store, &rid, run_id.clone()).await.sequence, 0);
        assert_eq!(commit_one(&store, &rid, run_id.clone()).await.sequence, 1);

        // A different run on the same runtime starts counting from zero.
        assert_eq!(commit_one(&store, &rid, RunId::new()).await.sequence, 0);

        // The same run id on a different runtime is counted independently.
        assert_eq!(
            commit_one(&store, &other_rid, run_id.clone()).await.sequence,
            0
        );

        // Both receipts of the first run remain individually loadable.
        for seq in 0..2 {
            let loaded = store
                .load_boundary_receipt(&rid, &run_id, seq)
                .await
                .unwrap();
            assert!(loaded.is_some(), "receipt {seq} should be persisted");
        }
    }

    #[tokio::test]
    async fn multiple_runtimes_isolated() {
        let store = InMemoryRuntimeStore::new();
        let rid1 = LogicalRuntimeId::new("runtime-1");
        let rid2 = LogicalRuntimeId::new("runtime-2");

        store
            .persist_input_state(&rid1, &InputState::new_accepted(InputId::new()))
            .await
            .unwrap();
        store
            .persist_input_state(&rid2, &InputState::new_accepted(InputId::new()))
            .await
            .unwrap();
        store
            .persist_input_state(&rid2, &InputState::new_accepted(InputId::new()))
            .await
            .unwrap();

        let s1 = store.load_input_states(&rid1).await.unwrap();
        let s2 = store.load_input_states(&rid2).await.unwrap();
        assert_eq!(s1.len(), 1);
        assert_eq!(s2.len(), 2);
    }

    #[tokio::test]
    async fn load_session_snapshot_roundtrip() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("runtime");
        let snapshot = serde_json::to_vec(&meerkat_core::Session::new()).unwrap();

        store
            .atomic_apply(
                &rid,
                Some(SessionDelta {
                    session_snapshot: snapshot.clone(),
                }),
                make_receipt(RunId::new(), 0),
                vec![],
                None,
            )
            .await
            .unwrap();

        let loaded = store.load_session_snapshot(&rid).await.unwrap();
        assert_eq!(loaded, Some(snapshot));
    }
}