Skip to main content

meerkat_runtime/store/
memory.rs

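//! InMemoryRuntimeStore — in-memory implementation for testing/ephemeral use.
//!
//! Runtime data sits behind an async `Mutex` (`tokio::sync::Mutex`, or the
//! `tokio_with_wasm` alias on wasm32) per the in-memory concurrency rule; the OAuth
//! flow snapshot, touched only by synchronous methods, sits behind a `std::sync::Mutex`.
//! All mutations complete inside one lock acquisition, and no lock is held across an `.await`.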
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::sync::Mutex as StdMutex;
9
10use indexmap::IndexMap;
11use meerkat_core::lifecycle::{InputId, RunBoundaryReceipt, RunId};
12#[cfg(not(target_arch = "wasm32"))]
13use tokio::sync::Mutex;
14#[cfg(target_arch = "wasm32")]
15use tokio_with_wasm::alias::sync::Mutex;
16
17use super::{
18    AuthOAuthFlowSnapshotUpdate, MachineLifecycleCommit, RuntimeStore, RuntimeStoreError,
19    SessionDelta,
20};
21use crate::identifiers::LogicalRuntimeId;
22use crate::input_state::StoredInputState;
23use crate::ops_lifecycle::PersistedOpsSnapshot;
24use crate::runtime_state::RuntimeState;
25
26/// Receipt key: (runtime_id, run_id, sequence).
27#[derive(Debug, Clone, PartialEq, Eq, Hash)]
28struct ReceiptKey {
29    runtime_id: String,
30    run_id: RunId,
31    sequence: u64,
32}
33
34/// Inner state protected by the mutex.
35#[derive(Debug, Default)]
36struct Inner {
37    /// runtime_id → (input_id → StoredInputState). IndexMap for deterministic iteration order.
38    input_states: HashMap<String, IndexMap<InputId, StoredInputState>>,
39    /// Receipt storage.
40    receipts: HashMap<ReceiptKey, RunBoundaryReceipt>,
41    /// Runtime session snapshots keyed by canonical runtime id.
42    sessions: HashMap<String, Vec<u8>>,
43    /// Persisted runtime state.
44    runtime_states: HashMap<String, RuntimeState>,
45    /// Persisted ops lifecycle snapshots.
46    ops_lifecycle_snapshots: HashMap<String, PersistedOpsSnapshot>,
47}
48
49/// In-memory runtime store. Thread-safe via `tokio::sync::Mutex`.
50#[derive(Debug, Clone)]
51pub struct InMemoryRuntimeStore {
52    inner: Arc<Mutex<Inner>>,
53    auth_oauth_flow_snapshot: Arc<StdMutex<Option<Vec<u8>>>>,
54}
55
56impl InMemoryRuntimeStore {
57    pub fn new() -> Self {
58        Self {
59            inner: Arc::new(Mutex::new(Inner::default())),
60            auth_oauth_flow_snapshot: Arc::new(StdMutex::new(None)),
61        }
62    }
63}
64
65impl Default for InMemoryRuntimeStore {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
72#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
73impl RuntimeStore for InMemoryRuntimeStore {
74    fn persist_auth_oauth_flow_snapshot(
75        &self,
76        snapshot_json: &[u8],
77    ) -> Result<(), RuntimeStoreError> {
78        *self
79            .auth_oauth_flow_snapshot
80            .lock()
81            .map_err(|err| RuntimeStoreError::WriteFailed(err.to_string()))? =
82            Some(snapshot_json.to_vec());
83        Ok(())
84    }
85
86    fn load_auth_oauth_flow_snapshot(&self) -> Result<Option<Vec<u8>>, RuntimeStoreError> {
87        self.auth_oauth_flow_snapshot
88            .lock()
89            .map(|snapshot| snapshot.clone())
90            .map_err(|err| RuntimeStoreError::ReadFailed(err.to_string()))
91    }
92
93    fn update_auth_oauth_flow_snapshot(
94        &self,
95        update: &mut AuthOAuthFlowSnapshotUpdate<'_>,
96    ) -> Result<(), RuntimeStoreError> {
97        let mut snapshot = self
98            .auth_oauth_flow_snapshot
99            .lock()
100            .map_err(|err| RuntimeStoreError::WriteFailed(err.to_string()))?;
101        let next = update(snapshot.as_deref())?;
102        *snapshot = Some(next);
103        Ok(())
104    }
105
106    async fn commit_session_snapshot(
107        &self,
108        runtime_id: &LogicalRuntimeId,
109        session_delta: SessionDelta,
110    ) -> Result<(), RuntimeStoreError> {
111        let _: meerkat_core::Session = serde_json::from_slice(&session_delta.session_snapshot)
112            .map_err(|err| RuntimeStoreError::WriteFailed(err.to_string()))?;
113        let mut inner = self.inner.lock().await;
114        inner
115            .sessions
116            .insert(runtime_id.0.clone(), session_delta.session_snapshot);
117        Ok(())
118    }
119
120    async fn atomic_apply(
121        &self,
122        runtime_id: &LogicalRuntimeId,
123        session_delta: Option<SessionDelta>,
124        receipt: RunBoundaryReceipt,
125        input_updates: Vec<StoredInputState>,
126        session_store_key: Option<meerkat_core::types::SessionId>,
127    ) -> Result<(), RuntimeStoreError> {
128        let mut inner = self.inner.lock().await;
129
130        // All writes in one lock acquisition (atomic for in-memory)
131        let rid = runtime_id.0.clone();
132
133        // Session delta
134        if let Some(delta) = session_delta {
135            if let Some(session_store_key) = session_store_key {
136                let session: meerkat_core::Session =
137                    serde_json::from_slice(&delta.session_snapshot)
138                        .map_err(|err| RuntimeStoreError::WriteFailed(err.to_string()))?;
139                if session.id() != &session_store_key {
140                    return Err(RuntimeStoreError::SessionKeyMismatch {
141                        expected: session_store_key,
142                        actual: session.id().clone(),
143                    });
144                }
145            }
146            inner.sessions.insert(rid.clone(), delta.session_snapshot);
147        }
148
149        // Receipt
150        let key = ReceiptKey {
151            runtime_id: rid.clone(),
152            run_id: receipt.run_id.clone(),
153            sequence: receipt.sequence,
154        };
155        inner.receipts.insert(key, receipt);
156
157        // Input states
158        let states = inner.input_states.entry(rid).or_default();
159        for bundle in input_updates {
160            states.insert(bundle.state.input_id.clone(), bundle);
161        }
162
163        Ok(())
164    }
165
166    async fn load_input_states(
167        &self,
168        runtime_id: &LogicalRuntimeId,
169    ) -> Result<Vec<StoredInputState>, RuntimeStoreError> {
170        let inner = self.inner.lock().await;
171        let states = inner
172            .input_states
173            .get(&runtime_id.0)
174            .map(|m| m.values().cloned().collect())
175            .unwrap_or_default();
176        Ok(states)
177    }
178
179    async fn load_boundary_receipt(
180        &self,
181        runtime_id: &LogicalRuntimeId,
182        run_id: &RunId,
183        sequence: u64,
184    ) -> Result<Option<RunBoundaryReceipt>, RuntimeStoreError> {
185        let inner = self.inner.lock().await;
186        let key = ReceiptKey {
187            runtime_id: runtime_id.0.clone(),
188            run_id: run_id.clone(),
189            sequence,
190        };
191        Ok(inner.receipts.get(&key).cloned())
192    }
193
194    async fn load_session_snapshot(
195        &self,
196        runtime_id: &LogicalRuntimeId,
197    ) -> Result<Option<Vec<u8>>, RuntimeStoreError> {
198        let inner = self.inner.lock().await;
199        Ok(inner.sessions.get(&runtime_id.0).cloned())
200    }
201
202    async fn persist_input_state(
203        &self,
204        runtime_id: &LogicalRuntimeId,
205        state: &StoredInputState,
206    ) -> Result<(), RuntimeStoreError> {
207        let mut inner = self.inner.lock().await;
208        let states = inner.input_states.entry(runtime_id.0.clone()).or_default();
209        states.insert(state.state.input_id.clone(), state.clone());
210        Ok(())
211    }
212
213    async fn load_input_state(
214        &self,
215        runtime_id: &LogicalRuntimeId,
216        input_id: &InputId,
217    ) -> Result<Option<StoredInputState>, RuntimeStoreError> {
218        let inner = self.inner.lock().await;
219        let state = inner
220            .input_states
221            .get(&runtime_id.0)
222            .and_then(|m| m.get(input_id).cloned());
223        Ok(state)
224    }
225
226    async fn load_runtime_state(
227        &self,
228        runtime_id: &LogicalRuntimeId,
229    ) -> Result<Option<RuntimeState>, RuntimeStoreError> {
230        let inner = self.inner.lock().await;
231        Ok(inner.runtime_states.get(&runtime_id.0).copied())
232    }
233
234    async fn commit_machine_lifecycle(
235        &self,
236        runtime_id: &LogicalRuntimeId,
237        commit: MachineLifecycleCommit,
238        input_states: &[StoredInputState],
239    ) -> Result<(), RuntimeStoreError> {
240        let mut inner = self.inner.lock().await;
241        let rid = runtime_id.0.clone();
242
243        // Single lock acquisition — atomic for in-memory
244        inner
245            .runtime_states
246            .insert(rid.clone(), commit.runtime_state());
247        let states = inner.input_states.entry(rid).or_default();
248        for bundle in input_states {
249            states.insert(bundle.state.input_id.clone(), bundle.clone());
250        }
251
252        Ok(())
253    }
254
255    async fn persist_ops_lifecycle(
256        &self,
257        runtime_id: &LogicalRuntimeId,
258        snapshot: &PersistedOpsSnapshot,
259    ) -> Result<(), RuntimeStoreError> {
260        let mut inner = self.inner.lock().await;
261        inner
262            .ops_lifecycle_snapshots
263            .insert(runtime_id.0.clone(), snapshot.clone());
264        Ok(())
265    }
266
267    async fn load_ops_lifecycle(
268        &self,
269        runtime_id: &LogicalRuntimeId,
270    ) -> Result<Option<PersistedOpsSnapshot>, RuntimeStoreError> {
271        let inner = self.inner.lock().await;
272        Ok(inner.ops_lifecycle_snapshots.get(&runtime_id.0).cloned())
273    }
274}
275
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use meerkat_core::lifecycle::run_primitive::RunApplyBoundary;

    /// Builds a minimal `RunStart` receipt for `run_id` at sequence `seq`.
    fn make_receipt(run_id: RunId, seq: u64) -> RunBoundaryReceipt {
        RunBoundaryReceipt {
            run_id,
            boundary: RunApplyBoundary::RunStart,
            contributing_input_ids: vec![],
            conversation_digest: None,
            message_count: 0,
            sequence: seq,
        }
    }

    #[tokio::test]
    async fn atomic_apply_roundtrip() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test-runtime");
        let run_id = RunId::new();
        let input_id = InputId::new();

        let bundle = StoredInputState::new_accepted(input_id.clone());
        let receipt = make_receipt(run_id.clone(), 0);

        // Without a session_store_key the snapshot bytes are stored unvalidated.
        store
            .atomic_apply(
                &rid,
                Some(SessionDelta {
                    session_snapshot: b"session-data".to_vec(),
                }),
                receipt.clone(),
                vec![bundle],
                None,
            )
            .await
            .unwrap();

        // Load input states
        let states = store.load_input_states(&rid).await.unwrap();
        assert_eq!(states.len(), 1);
        assert_eq!(states[0].state.input_id, input_id);

        // Load receipt
        let loaded = store.load_boundary_receipt(&rid, &run_id, 0).await.unwrap();
        assert!(loaded.is_some());
    }

    #[tokio::test]
    async fn persist_and_load_single_state() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test");
        let input_id = InputId::new();
        let bundle = StoredInputState::new_accepted(input_id.clone());

        store.persist_input_state(&rid, &bundle).await.unwrap();

        let loaded = store.load_input_state(&rid, &input_id).await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().state.input_id, input_id);
    }

    #[tokio::test]
    async fn load_nonexistent_returns_none() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test");

        let states = store.load_input_states(&rid).await.unwrap();
        assert!(states.is_empty());

        let state = store.load_input_state(&rid, &InputId::new()).await.unwrap();
        assert!(state.is_none());

        let receipt = store
            .load_boundary_receipt(&rid, &RunId::new(), 0)
            .await
            .unwrap();
        assert!(receipt.is_none());
    }

    #[tokio::test]
    async fn atomic_apply_updates_existing() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test");
        let input_id = InputId::new();

        // First write
        let bundle1 = StoredInputState::new_accepted(input_id.clone());
        store
            .atomic_apply(
                &rid,
                None,
                make_receipt(RunId::new(), 0),
                vec![bundle1],
                None,
            )
            .await
            .unwrap();

        // Second write with updated seed phase
        let mut bundle2 = StoredInputState::new_accepted(input_id.clone());
        bundle2.seed.phase = crate::input_state::InputLifecycleState::Queued;
        store
            .atomic_apply(
                &rid,
                None,
                make_receipt(RunId::new(), 1),
                vec![bundle2],
                None,
            )
            .await
            .unwrap();

        let states = store.load_input_states(&rid).await.unwrap();
        assert_eq!(states.len(), 1);
        assert_eq!(
            states[0].seed.phase,
            crate::input_state::InputLifecycleState::Queued
        );
    }

    #[tokio::test]
    async fn atomic_apply_validates_session_store_key_without_aliasing_snapshot() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("runtime-key");
        let session = meerkat_core::Session::new();
        let session_id = session.id().clone();
        let snapshot = serde_json::to_vec(&session).unwrap();

        store
            .atomic_apply(
                &rid,
                Some(SessionDelta {
                    session_snapshot: snapshot.clone(),
                }),
                make_receipt(RunId::new(), 0),
                vec![],
                Some(session_id.clone()),
            )
            .await
            .unwrap();

        assert_eq!(
            store.load_session_snapshot(&rid).await.unwrap(),
            Some(snapshot)
        );
        assert!(
            store
                .load_session_snapshot(&LogicalRuntimeId::legacy_session_uuid_alias(&session_id))
                .await
                .unwrap()
                .is_none(),
            "session_store_key must validate the snapshot identity, not create a raw UUID runtime alias"
        );
    }

    #[tokio::test]
    async fn atomic_apply_rejects_mismatched_session_store_key() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("runtime-key");
        let session = meerkat_core::Session::new();
        let wrong_session_id = meerkat_core::Session::new().id().clone();
        let snapshot = serde_json::to_vec(&session).unwrap();

        let err = store
            .atomic_apply(
                &rid,
                Some(SessionDelta {
                    session_snapshot: snapshot,
                }),
                make_receipt(RunId::new(), 0),
                vec![],
                Some(wrong_session_id),
            )
            .await
            .expect_err("mismatched session_store_key should fail");

        assert!(matches!(err, RuntimeStoreError::SessionKeyMismatch { .. }));
        assert!(store.load_session_snapshot(&rid).await.unwrap().is_none());
    }

    /// A rejected `atomic_apply` must be all-or-nothing: the receipt and input
    /// states passed alongside the rejected snapshot must not be persisted either.
    #[tokio::test]
    async fn atomic_apply_rejection_writes_nothing() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("runtime-key");
        let run_id = RunId::new();
        let input_id = InputId::new();
        let snapshot = serde_json::to_vec(&meerkat_core::Session::new()).unwrap();
        let wrong_session_id = meerkat_core::Session::new().id().clone();

        let err = store
            .atomic_apply(
                &rid,
                Some(SessionDelta {
                    session_snapshot: snapshot,
                }),
                make_receipt(run_id.clone(), 3),
                vec![StoredInputState::new_accepted(input_id.clone())],
                Some(wrong_session_id),
            )
            .await
            .expect_err("mismatched session_store_key should fail");

        assert!(matches!(err, RuntimeStoreError::SessionKeyMismatch { .. }));
        assert!(store.load_session_snapshot(&rid).await.unwrap().is_none());
        assert!(
            store
                .load_boundary_receipt(&rid, &run_id, 3)
                .await
                .unwrap()
                .is_none(),
            "receipt from a rejected apply must not be persisted"
        );
        assert!(
            store
                .load_input_state(&rid, &input_id)
                .await
                .unwrap()
                .is_none(),
            "input state from a rejected apply must not be persisted"
        );
        assert!(store.load_input_states(&rid).await.unwrap().is_empty());
    }

    /// When a session_store_key is supplied the snapshot must deserialize; garbage
    /// bytes are a write failure and leave the store untouched.
    #[tokio::test]
    async fn atomic_apply_rejects_unparseable_snapshot_when_key_given() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("runtime-key");
        let run_id = RunId::new();
        let session_id = meerkat_core::Session::new().id().clone();

        let err = store
            .atomic_apply(
                &rid,
                Some(SessionDelta {
                    session_snapshot: b"not-a-session".to_vec(),
                }),
                make_receipt(run_id.clone(), 0),
                vec![],
                Some(session_id),
            )
            .await
            .expect_err("unparseable snapshot should fail when a key is given");

        assert!(matches!(err, RuntimeStoreError::WriteFailed(_)));
        assert!(store.load_session_snapshot(&rid).await.unwrap().is_none());
        assert!(
            store
                .load_boundary_receipt(&rid, &run_id, 0)
                .await
                .unwrap()
                .is_none()
        );
    }

    #[tokio::test]
    async fn atomic_apply_persists_machine_owned_receipt() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test");
        let run_id = RunId::new();
        let input_id = InputId::new();
        let session = meerkat_core::Session::new();
        let snapshot = serde_json::to_vec(&session).unwrap();
        let receipt = RunBoundaryReceipt {
            run_id: run_id.clone(),
            boundary: RunApplyBoundary::Immediate,
            contributing_input_ids: vec![input_id.clone()],
            conversation_digest: Some("machine-owned-digest".to_string()),
            message_count: 42,
            sequence: 7,
        };

        store
            .atomic_apply(
                &rid,
                Some(SessionDelta {
                    session_snapshot: snapshot,
                }),
                receipt.clone(),
                vec![StoredInputState::new_accepted(input_id)],
                None,
            )
            .await
            .unwrap();

        assert_eq!(receipt.run_id, run_id);
        assert!(receipt.conversation_digest.is_some());
        let loaded = store
            .load_boundary_receipt(&rid, &receipt.run_id, receipt.sequence)
            .await
            .unwrap();
        assert!(loaded.is_some(), "receipt should be persisted");
        let Some(loaded) = loaded else {
            unreachable!("asserted above");
        };
        assert_eq!(loaded, receipt);
    }

    #[tokio::test]
    async fn multiple_runtimes_isolated() {
        let store = InMemoryRuntimeStore::new();
        let rid1 = LogicalRuntimeId::new("runtime-1");
        let rid2 = LogicalRuntimeId::new("runtime-2");

        store
            .persist_input_state(&rid1, &StoredInputState::new_accepted(InputId::new()))
            .await
            .unwrap();
        store
            .persist_input_state(&rid2, &StoredInputState::new_accepted(InputId::new()))
            .await
            .unwrap();
        store
            .persist_input_state(&rid2, &StoredInputState::new_accepted(InputId::new()))
            .await
            .unwrap();

        let s1 = store.load_input_states(&rid1).await.unwrap();
        let s2 = store.load_input_states(&rid2).await.unwrap();
        assert_eq!(s1.len(), 1);
        assert_eq!(s2.len(), 2);
    }

    #[tokio::test]
    async fn load_session_snapshot_roundtrip() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("runtime");
        let snapshot = serde_json::to_vec(&meerkat_core::Session::new()).unwrap();

        store
            .atomic_apply(
                &rid,
                Some(SessionDelta {
                    session_snapshot: snapshot.clone(),
                }),
                make_receipt(RunId::new(), 0),
                vec![],
                None,
            )
            .await
            .unwrap();

        let loaded = store.load_session_snapshot(&rid).await.unwrap();
        assert_eq!(loaded, Some(snapshot));
    }

    /// The OAuth flow snapshot starts empty and round-trips through persist/load;
    /// a second persist replaces the first.
    #[test]
    fn auth_oauth_flow_snapshot_roundtrip() {
        let store = InMemoryRuntimeStore::new();
        assert!(store.load_auth_oauth_flow_snapshot().unwrap().is_none());

        store.persist_auth_oauth_flow_snapshot(b"flow-1").unwrap();
        assert_eq!(
            store.load_auth_oauth_flow_snapshot().unwrap(),
            Some(b"flow-1".to_vec())
        );

        store.persist_auth_oauth_flow_snapshot(b"flow-2").unwrap();
        assert_eq!(
            store.load_auth_oauth_flow_snapshot().unwrap(),
            Some(b"flow-2".to_vec())
        );
    }
}