1use std::collections::HashMap;
7use std::sync::Arc;
8use std::sync::Mutex as StdMutex;
9
10use indexmap::IndexMap;
11use meerkat_core::lifecycle::{InputId, RunBoundaryReceipt, RunId};
12#[cfg(not(target_arch = "wasm32"))]
13use tokio::sync::Mutex;
14#[cfg(target_arch = "wasm32")]
15use tokio_with_wasm::alias::sync::Mutex;
16
17use super::{
18 AuthOAuthFlowSnapshotUpdate, MachineLifecycleCommit, RuntimeStore, RuntimeStoreError,
19 SessionDelta,
20};
21use crate::identifiers::LogicalRuntimeId;
22use crate::input_state::StoredInputState;
23use crate::ops_lifecycle::PersistedOpsSnapshot;
24use crate::runtime_state::RuntimeState;
25
/// Composite map key identifying a stored [`RunBoundaryReceipt`].
///
/// A receipt is addressed by the runtime that produced it, the run it belongs to, and the
/// receipt's `sequence` value, so several receipts for the same run can coexist.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct ReceiptKey {
    /// String form of the owning [`LogicalRuntimeId`] (its inner `.0` value).
    runtime_id: String,
    /// Run the receipt was recorded for.
    run_id: RunId,
    /// Copied from `RunBoundaryReceipt::sequence`.
    sequence: u64,
}
33
/// All runtime-scoped state of [`InMemoryRuntimeStore`], guarded by a single async mutex.
///
/// Every map is keyed by the logical runtime id string (`LogicalRuntimeId.0`), except
/// `receipts`, which uses a [`ReceiptKey`].
#[derive(Debug, Default)]
struct Inner {
    /// Input states per runtime, keyed by input id. `IndexMap` keeps first-insertion order, so
    /// `load_input_states` returns states in a stable order; re-inserting an existing id
    /// replaces the value without moving it.
    input_states: HashMap<String, IndexMap<InputId, StoredInputState>>,
    /// Boundary receipts recorded by `atomic_apply`.
    receipts: HashMap<ReceiptKey, RunBoundaryReceipt>,
    /// Latest session snapshot bytes per runtime, stored exactly as supplied by the caller.
    sessions: HashMap<String, Vec<u8>>,
    /// Latest runtime state per runtime, recorded by `commit_machine_lifecycle`.
    runtime_states: HashMap<String, RuntimeState>,
    /// Latest ops-lifecycle snapshot per runtime, recorded by `persist_ops_lifecycle`.
    ops_lifecycle_snapshots: HashMap<String, PersistedOpsSnapshot>,
}
48
/// [`RuntimeStore`] implementation that keeps all state in process memory.
///
/// Both fields are `Arc`-wrapped, so cloning the store is cheap and every clone is a handle to
/// the same underlying state.
#[derive(Debug, Clone)]
pub struct InMemoryRuntimeStore {
    /// Runtime-scoped state behind an async mutex, locked with `.await` by the async methods.
    inner: Arc<Mutex<Inner>>,
    /// Serialized auth OAuth flow snapshot. The auth snapshot methods are synchronous
    /// (non-`async`), so this lives outside `inner` behind a blocking `std` mutex.
    auth_oauth_flow_snapshot: Arc<StdMutex<Option<Vec<u8>>>>,
}
55
56impl InMemoryRuntimeStore {
57 pub fn new() -> Self {
58 Self {
59 inner: Arc::new(Mutex::new(Inner::default())),
60 auth_oauth_flow_snapshot: Arc::new(StdMutex::new(None)),
61 }
62 }
63}
64
65impl Default for InMemoryRuntimeStore {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
72#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
73impl RuntimeStore for InMemoryRuntimeStore {
74 fn persist_auth_oauth_flow_snapshot(
75 &self,
76 snapshot_json: &[u8],
77 ) -> Result<(), RuntimeStoreError> {
78 *self
79 .auth_oauth_flow_snapshot
80 .lock()
81 .map_err(|err| RuntimeStoreError::WriteFailed(err.to_string()))? =
82 Some(snapshot_json.to_vec());
83 Ok(())
84 }
85
86 fn load_auth_oauth_flow_snapshot(&self) -> Result<Option<Vec<u8>>, RuntimeStoreError> {
87 self.auth_oauth_flow_snapshot
88 .lock()
89 .map(|snapshot| snapshot.clone())
90 .map_err(|err| RuntimeStoreError::ReadFailed(err.to_string()))
91 }
92
93 fn update_auth_oauth_flow_snapshot(
94 &self,
95 update: &mut AuthOAuthFlowSnapshotUpdate<'_>,
96 ) -> Result<(), RuntimeStoreError> {
97 let mut snapshot = self
98 .auth_oauth_flow_snapshot
99 .lock()
100 .map_err(|err| RuntimeStoreError::WriteFailed(err.to_string()))?;
101 let next = update(snapshot.as_deref())?;
102 *snapshot = Some(next);
103 Ok(())
104 }
105
106 async fn commit_session_snapshot(
107 &self,
108 runtime_id: &LogicalRuntimeId,
109 session_delta: SessionDelta,
110 ) -> Result<(), RuntimeStoreError> {
111 let _: meerkat_core::Session = serde_json::from_slice(&session_delta.session_snapshot)
112 .map_err(|err| RuntimeStoreError::WriteFailed(err.to_string()))?;
113 let mut inner = self.inner.lock().await;
114 inner
115 .sessions
116 .insert(runtime_id.0.clone(), session_delta.session_snapshot);
117 Ok(())
118 }
119
120 async fn atomic_apply(
121 &self,
122 runtime_id: &LogicalRuntimeId,
123 session_delta: Option<SessionDelta>,
124 receipt: RunBoundaryReceipt,
125 input_updates: Vec<StoredInputState>,
126 session_store_key: Option<meerkat_core::types::SessionId>,
127 ) -> Result<(), RuntimeStoreError> {
128 let mut inner = self.inner.lock().await;
129
130 let rid = runtime_id.0.clone();
132
133 if let Some(delta) = session_delta {
135 if let Some(session_store_key) = session_store_key {
136 let session: meerkat_core::Session =
137 serde_json::from_slice(&delta.session_snapshot)
138 .map_err(|err| RuntimeStoreError::WriteFailed(err.to_string()))?;
139 if session.id() != &session_store_key {
140 return Err(RuntimeStoreError::SessionKeyMismatch {
141 expected: session_store_key,
142 actual: session.id().clone(),
143 });
144 }
145 }
146 inner.sessions.insert(rid.clone(), delta.session_snapshot);
147 }
148
149 let key = ReceiptKey {
151 runtime_id: rid.clone(),
152 run_id: receipt.run_id.clone(),
153 sequence: receipt.sequence,
154 };
155 inner.receipts.insert(key, receipt);
156
157 let states = inner.input_states.entry(rid).or_default();
159 for bundle in input_updates {
160 states.insert(bundle.state.input_id.clone(), bundle);
161 }
162
163 Ok(())
164 }
165
166 async fn load_input_states(
167 &self,
168 runtime_id: &LogicalRuntimeId,
169 ) -> Result<Vec<StoredInputState>, RuntimeStoreError> {
170 let inner = self.inner.lock().await;
171 let states = inner
172 .input_states
173 .get(&runtime_id.0)
174 .map(|m| m.values().cloned().collect())
175 .unwrap_or_default();
176 Ok(states)
177 }
178
179 async fn load_boundary_receipt(
180 &self,
181 runtime_id: &LogicalRuntimeId,
182 run_id: &RunId,
183 sequence: u64,
184 ) -> Result<Option<RunBoundaryReceipt>, RuntimeStoreError> {
185 let inner = self.inner.lock().await;
186 let key = ReceiptKey {
187 runtime_id: runtime_id.0.clone(),
188 run_id: run_id.clone(),
189 sequence,
190 };
191 Ok(inner.receipts.get(&key).cloned())
192 }
193
194 async fn load_session_snapshot(
195 &self,
196 runtime_id: &LogicalRuntimeId,
197 ) -> Result<Option<Vec<u8>>, RuntimeStoreError> {
198 let inner = self.inner.lock().await;
199 Ok(inner.sessions.get(&runtime_id.0).cloned())
200 }
201
202 async fn persist_input_state(
203 &self,
204 runtime_id: &LogicalRuntimeId,
205 state: &StoredInputState,
206 ) -> Result<(), RuntimeStoreError> {
207 let mut inner = self.inner.lock().await;
208 let states = inner.input_states.entry(runtime_id.0.clone()).or_default();
209 states.insert(state.state.input_id.clone(), state.clone());
210 Ok(())
211 }
212
213 async fn load_input_state(
214 &self,
215 runtime_id: &LogicalRuntimeId,
216 input_id: &InputId,
217 ) -> Result<Option<StoredInputState>, RuntimeStoreError> {
218 let inner = self.inner.lock().await;
219 let state = inner
220 .input_states
221 .get(&runtime_id.0)
222 .and_then(|m| m.get(input_id).cloned());
223 Ok(state)
224 }
225
226 async fn load_runtime_state(
227 &self,
228 runtime_id: &LogicalRuntimeId,
229 ) -> Result<Option<RuntimeState>, RuntimeStoreError> {
230 let inner = self.inner.lock().await;
231 Ok(inner.runtime_states.get(&runtime_id.0).copied())
232 }
233
234 async fn commit_machine_lifecycle(
235 &self,
236 runtime_id: &LogicalRuntimeId,
237 commit: MachineLifecycleCommit,
238 input_states: &[StoredInputState],
239 ) -> Result<(), RuntimeStoreError> {
240 let mut inner = self.inner.lock().await;
241 let rid = runtime_id.0.clone();
242
243 inner
245 .runtime_states
246 .insert(rid.clone(), commit.runtime_state());
247 let states = inner.input_states.entry(rid).or_default();
248 for bundle in input_states {
249 states.insert(bundle.state.input_id.clone(), bundle.clone());
250 }
251
252 Ok(())
253 }
254
255 async fn persist_ops_lifecycle(
256 &self,
257 runtime_id: &LogicalRuntimeId,
258 snapshot: &PersistedOpsSnapshot,
259 ) -> Result<(), RuntimeStoreError> {
260 let mut inner = self.inner.lock().await;
261 inner
262 .ops_lifecycle_snapshots
263 .insert(runtime_id.0.clone(), snapshot.clone());
264 Ok(())
265 }
266
267 async fn load_ops_lifecycle(
268 &self,
269 runtime_id: &LogicalRuntimeId,
270 ) -> Result<Option<PersistedOpsSnapshot>, RuntimeStoreError> {
271 let inner = self.inner.lock().await;
272 Ok(inner.ops_lifecycle_snapshots.get(&runtime_id.0).cloned())
273 }
274}
275
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use meerkat_core::lifecycle::run_primitive::RunApplyBoundary;

    /// Builds a minimal `RunStart` receipt for `run_id` with sequence number `seq`.
    fn make_receipt(run_id: RunId, seq: u64) -> RunBoundaryReceipt {
        RunBoundaryReceipt {
            run_id,
            boundary: RunApplyBoundary::RunStart,
            contributing_input_ids: vec![],
            conversation_digest: None,
            message_count: 0,
            sequence: seq,
        }
    }

    #[tokio::test]
    async fn atomic_apply_roundtrip() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test-runtime");
        let run_id = RunId::new();
        let input_id = InputId::new();

        let bundle = StoredInputState::new_accepted(input_id.clone());
        let receipt = make_receipt(run_id.clone(), 0);

        // No session_store_key: the snapshot bytes are stored without being decoded.
        store
            .atomic_apply(
                &rid,
                Some(SessionDelta {
                    session_snapshot: b"session-data".to_vec(),
                }),
                receipt.clone(),
                vec![bundle],
                None,
            )
            .await
            .unwrap();

        // The input state was recorded.
        let states = store.load_input_states(&rid).await.unwrap();
        assert_eq!(states.len(), 1);
        assert_eq!(states[0].state.input_id, input_id);

        // The receipt is addressable by (runtime, run, sequence).
        let loaded = store.load_boundary_receipt(&rid, &run_id, 0).await.unwrap();
        assert!(loaded.is_some());
    }

    #[tokio::test]
    async fn persist_and_load_single_state() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test");
        let input_id = InputId::new();
        let bundle = StoredInputState::new_accepted(input_id.clone());

        store.persist_input_state(&rid, &bundle).await.unwrap();

        let loaded = store.load_input_state(&rid, &input_id).await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().state.input_id, input_id);
    }

    #[tokio::test]
    async fn load_nonexistent_returns_none() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test");

        let states = store.load_input_states(&rid).await.unwrap();
        assert!(states.is_empty());

        let state = store.load_input_state(&rid, &InputId::new()).await.unwrap();
        assert!(state.is_none());

        let receipt = store
            .load_boundary_receipt(&rid, &RunId::new(), 0)
            .await
            .unwrap();
        assert!(receipt.is_none());

        // Every other per-runtime loader also reports "nothing stored" for an unknown runtime.
        assert!(store.load_session_snapshot(&rid).await.unwrap().is_none());
        assert!(store.load_runtime_state(&rid).await.unwrap().is_none());
        assert!(store.load_ops_lifecycle(&rid).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn atomic_apply_updates_existing() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test");
        let input_id = InputId::new();

        // First apply records the input as accepted.
        let bundle1 = StoredInputState::new_accepted(input_id.clone());
        store
            .atomic_apply(
                &rid,
                None,
                make_receipt(RunId::new(), 0),
                vec![bundle1],
                None,
            )
            .await
            .unwrap();

        // Second apply for the same input id must replace, not duplicate, the entry.
        let mut bundle2 = StoredInputState::new_accepted(input_id.clone());
        bundle2.seed.phase = crate::input_state::InputLifecycleState::Queued;
        store
            .atomic_apply(
                &rid,
                None,
                make_receipt(RunId::new(), 1),
                vec![bundle2],
                None,
            )
            .await
            .unwrap();

        let states = store.load_input_states(&rid).await.unwrap();
        assert_eq!(states.len(), 1);
        assert_eq!(
            states[0].seed.phase,
            crate::input_state::InputLifecycleState::Queued
        );
    }

    #[tokio::test]
    async fn atomic_apply_validates_session_store_key_without_aliasing_snapshot() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("runtime-key");
        let session = meerkat_core::Session::new();
        let session_id = session.id().clone();
        let snapshot = serde_json::to_vec(&session).unwrap();

        store
            .atomic_apply(
                &rid,
                Some(SessionDelta {
                    session_snapshot: snapshot.clone(),
                }),
                make_receipt(RunId::new(), 0),
                vec![],
                Some(session_id.clone()),
            )
            .await
            .unwrap();

        assert_eq!(
            store.load_session_snapshot(&rid).await.unwrap(),
            Some(snapshot)
        );
        assert!(
            store
                .load_session_snapshot(&LogicalRuntimeId::legacy_session_uuid_alias(&session_id))
                .await
                .unwrap()
                .is_none(),
            "session_store_key must validate the snapshot identity, not create a raw UUID runtime alias"
        );
    }

    #[tokio::test]
    async fn atomic_apply_rejects_mismatched_session_store_key() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("runtime-key");
        let run_id = RunId::new();
        let session = meerkat_core::Session::new();
        let wrong_session_id = meerkat_core::Session::new().id().clone();
        let snapshot = serde_json::to_vec(&session).unwrap();

        let err = store
            .atomic_apply(
                &rid,
                Some(SessionDelta {
                    session_snapshot: snapshot,
                }),
                make_receipt(run_id.clone(), 0),
                vec![StoredInputState::new_accepted(InputId::new())],
                Some(wrong_session_id),
            )
            .await
            .expect_err("mismatched session_store_key should fail");

        assert!(matches!(err, RuntimeStoreError::SessionKeyMismatch { .. }));
        // A rejected apply must be all-or-nothing: no session, receipt, or input state persists.
        assert!(store.load_session_snapshot(&rid).await.unwrap().is_none());
        assert!(store.load_boundary_receipt(&rid, &run_id, 0).await.unwrap().is_none());
        assert!(store.load_input_states(&rid).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn atomic_apply_persists_machine_owned_receipt() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test");
        let run_id = RunId::new();
        let input_id = InputId::new();
        let session = meerkat_core::Session::new();
        let snapshot = serde_json::to_vec(&session).unwrap();
        let receipt = RunBoundaryReceipt {
            run_id: run_id.clone(),
            boundary: RunApplyBoundary::Immediate,
            contributing_input_ids: vec![input_id.clone()],
            conversation_digest: Some("machine-owned-digest".to_string()),
            message_count: 42,
            sequence: 7,
        };

        store
            .atomic_apply(
                &rid,
                Some(SessionDelta {
                    session_snapshot: snapshot,
                }),
                receipt.clone(),
                vec![StoredInputState::new_accepted(input_id)],
                None,
            )
            .await
            .unwrap();

        assert_eq!(receipt.run_id, run_id);
        assert!(receipt.conversation_digest.is_some());
        let loaded = store
            .load_boundary_receipt(&rid, &receipt.run_id, receipt.sequence)
            .await
            .unwrap();
        assert!(loaded.is_some(), "receipt should be persisted");
        let Some(loaded) = loaded else {
            unreachable!("asserted above");
        };
        assert_eq!(loaded, receipt);
    }

    #[tokio::test]
    async fn multiple_runtimes_isolated() {
        let store = InMemoryRuntimeStore::new();
        let rid1 = LogicalRuntimeId::new("runtime-1");
        let rid2 = LogicalRuntimeId::new("runtime-2");

        store
            .persist_input_state(&rid1, &StoredInputState::new_accepted(InputId::new()))
            .await
            .unwrap();
        store
            .persist_input_state(&rid2, &StoredInputState::new_accepted(InputId::new()))
            .await
            .unwrap();
        store
            .persist_input_state(&rid2, &StoredInputState::new_accepted(InputId::new()))
            .await
            .unwrap();

        let s1 = store.load_input_states(&rid1).await.unwrap();
        let s2 = store.load_input_states(&rid2).await.unwrap();
        assert_eq!(s1.len(), 1);
        assert_eq!(s2.len(), 2);
    }

    #[tokio::test]
    async fn load_session_snapshot_roundtrip() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("runtime");
        let snapshot = serde_json::to_vec(&meerkat_core::Session::new()).unwrap();

        store
            .atomic_apply(
                &rid,
                Some(SessionDelta {
                    session_snapshot: snapshot.clone(),
                }),
                make_receipt(RunId::new(), 0),
                vec![],
                None,
            )
            .await
            .unwrap();

        let loaded = store.load_session_snapshot(&rid).await.unwrap();
        assert_eq!(loaded, Some(snapshot));
    }

    #[tokio::test]
    async fn commit_session_snapshot_roundtrip() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("runtime");
        let snapshot = serde_json::to_vec(&meerkat_core::Session::new()).unwrap();

        store
            .commit_session_snapshot(
                &rid,
                SessionDelta {
                    session_snapshot: snapshot.clone(),
                },
            )
            .await
            .unwrap();

        assert_eq!(
            store.load_session_snapshot(&rid).await.unwrap(),
            Some(snapshot)
        );
    }

    #[tokio::test]
    async fn commit_session_snapshot_rejects_non_session_bytes() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("runtime");

        let err = store
            .commit_session_snapshot(
                &rid,
                SessionDelta {
                    session_snapshot: b"not-a-session".to_vec(),
                },
            )
            .await
            .expect_err("non-session bytes should be rejected");

        assert!(matches!(err, RuntimeStoreError::WriteFailed(_)));
        assert!(store.load_session_snapshot(&rid).await.unwrap().is_none());
    }

    #[test]
    fn auth_oauth_flow_snapshot_persist_and_load() {
        let store = InMemoryRuntimeStore::new();
        assert!(store.load_auth_oauth_flow_snapshot().unwrap().is_none());

        store.persist_auth_oauth_flow_snapshot(b"flow-1").unwrap();

        // Clones share state, and a later persist replaces the earlier snapshot.
        let handle = store.clone();
        handle.persist_auth_oauth_flow_snapshot(b"flow-2").unwrap();

        assert_eq!(
            store.load_auth_oauth_flow_snapshot().unwrap(),
            Some(b"flow-2".to_vec())
        );
    }
}