1use std::collections::HashMap;
7use std::sync::Arc;
8
9use indexmap::IndexMap;
10use meerkat_core::lifecycle::run_primitive::RunApplyBoundary;
11use meerkat_core::lifecycle::{InputId, RunBoundaryReceipt, RunId};
12#[cfg(not(target_arch = "wasm32"))]
13use tokio::sync::Mutex;
14#[cfg(target_arch = "wasm32")]
15use tokio_with_wasm::alias::sync::Mutex;
16
17use super::{RuntimeStore, RuntimeStoreError, SessionDelta, authoritative_receipt};
18use crate::identifiers::LogicalRuntimeId;
19use crate::input_state::InputState;
20use crate::runtime_state::RuntimeState;
21
22#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24struct ReceiptKey {
25 runtime_id: String,
26 run_id: RunId,
27 sequence: u64,
28}
29
30#[derive(Debug, Default)]
32struct Inner {
33 input_states: HashMap<String, IndexMap<InputId, InputState>>,
35 receipts: HashMap<ReceiptKey, RunBoundaryReceipt>,
37 sessions: HashMap<String, Vec<u8>>,
39 runtime_states: HashMap<String, RuntimeState>,
41}
42
43#[derive(Debug, Clone)]
45pub struct InMemoryRuntimeStore {
46 inner: Arc<Mutex<Inner>>,
47}
48
49impl InMemoryRuntimeStore {
50 pub fn new() -> Self {
51 Self {
52 inner: Arc::new(Mutex::new(Inner::default())),
53 }
54 }
55}
56
57impl Default for InMemoryRuntimeStore {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
64#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
65impl RuntimeStore for InMemoryRuntimeStore {
66 async fn commit_session_boundary(
67 &self,
68 runtime_id: &LogicalRuntimeId,
69 session_delta: SessionDelta,
70 run_id: RunId,
71 boundary: RunApplyBoundary,
72 contributing_input_ids: Vec<InputId>,
73 input_updates: Vec<InputState>,
74 ) -> Result<RunBoundaryReceipt, RuntimeStoreError> {
75 let mut inner = self.inner.lock().await;
76 let rid = runtime_id.0.clone();
77 let sequence = inner
78 .receipts
79 .keys()
80 .filter(|key| key.runtime_id == rid && key.run_id == run_id)
81 .map(|key| key.sequence)
82 .max()
83 .map(|seq| seq + 1)
84 .unwrap_or(0);
85 let receipt = authoritative_receipt(
86 Some(&session_delta),
87 run_id,
88 boundary,
89 contributing_input_ids,
90 sequence,
91 )?;
92 let mut input_updates = input_updates;
93 for state in &mut input_updates {
94 state
95 .authority_mut()
96 .stamp_receipt_metadata(receipt.run_id.clone(), receipt.sequence);
97 }
98
99 inner
100 .sessions
101 .insert(rid.clone(), session_delta.session_snapshot);
102 let key = ReceiptKey {
103 runtime_id: rid.clone(),
104 run_id: receipt.run_id.clone(),
105 sequence: receipt.sequence,
106 };
107 inner.receipts.insert(key, receipt.clone());
108
109 let states = inner.input_states.entry(rid).or_default();
110 for state in input_updates {
111 states.insert(state.input_id.clone(), state);
112 }
113
114 Ok(receipt)
115 }
116
117 async fn atomic_apply(
118 &self,
119 runtime_id: &LogicalRuntimeId,
120 session_delta: Option<SessionDelta>,
121 receipt: RunBoundaryReceipt,
122 input_updates: Vec<InputState>,
123 _session_store_key: Option<meerkat_core::types::SessionId>,
124 ) -> Result<(), RuntimeStoreError> {
125 let mut inner = self.inner.lock().await;
129
130 let rid = runtime_id.0.clone();
132
133 if let Some(delta) = session_delta {
135 inner.sessions.insert(rid.clone(), delta.session_snapshot);
136 }
137
138 let key = ReceiptKey {
140 runtime_id: rid.clone(),
141 run_id: receipt.run_id.clone(),
142 sequence: receipt.sequence,
143 };
144 inner.receipts.insert(key, receipt);
145
146 let states = inner.input_states.entry(rid).or_default();
148 for state in input_updates {
149 states.insert(state.input_id.clone(), state);
150 }
151
152 Ok(())
153 }
154
155 async fn load_input_states(
156 &self,
157 runtime_id: &LogicalRuntimeId,
158 ) -> Result<Vec<InputState>, RuntimeStoreError> {
159 let inner = self.inner.lock().await;
160 let states = inner
161 .input_states
162 .get(&runtime_id.0)
163 .map(|m| m.values().cloned().collect())
164 .unwrap_or_default();
165 Ok(states)
166 }
167
168 async fn load_boundary_receipt(
169 &self,
170 runtime_id: &LogicalRuntimeId,
171 run_id: &RunId,
172 sequence: u64,
173 ) -> Result<Option<RunBoundaryReceipt>, RuntimeStoreError> {
174 let inner = self.inner.lock().await;
175 let key = ReceiptKey {
176 runtime_id: runtime_id.0.clone(),
177 run_id: run_id.clone(),
178 sequence,
179 };
180 Ok(inner.receipts.get(&key).cloned())
181 }
182
183 async fn load_session_snapshot(
184 &self,
185 runtime_id: &LogicalRuntimeId,
186 ) -> Result<Option<Vec<u8>>, RuntimeStoreError> {
187 let inner = self.inner.lock().await;
188 Ok(inner.sessions.get(&runtime_id.0).cloned())
189 }
190
191 async fn persist_input_state(
192 &self,
193 runtime_id: &LogicalRuntimeId,
194 state: &InputState,
195 ) -> Result<(), RuntimeStoreError> {
196 let mut inner = self.inner.lock().await;
197 let states = inner.input_states.entry(runtime_id.0.clone()).or_default();
198 states.insert(state.input_id.clone(), state.clone());
199 Ok(())
200 }
201
202 async fn load_input_state(
203 &self,
204 runtime_id: &LogicalRuntimeId,
205 input_id: &InputId,
206 ) -> Result<Option<InputState>, RuntimeStoreError> {
207 let inner = self.inner.lock().await;
208 let state = inner
209 .input_states
210 .get(&runtime_id.0)
211 .and_then(|m| m.get(input_id).cloned());
212 Ok(state)
213 }
214
215 async fn persist_runtime_state(
216 &self,
217 runtime_id: &LogicalRuntimeId,
218 state: RuntimeState,
219 ) -> Result<(), RuntimeStoreError> {
220 let mut inner = self.inner.lock().await;
221 inner.runtime_states.insert(runtime_id.0.clone(), state);
222 Ok(())
223 }
224
225 async fn load_runtime_state(
226 &self,
227 runtime_id: &LogicalRuntimeId,
228 ) -> Result<Option<RuntimeState>, RuntimeStoreError> {
229 let inner = self.inner.lock().await;
230 Ok(inner.runtime_states.get(&runtime_id.0).copied())
231 }
232
233 async fn atomic_lifecycle_commit(
234 &self,
235 runtime_id: &LogicalRuntimeId,
236 runtime_state: RuntimeState,
237 input_states: &[InputState],
238 ) -> Result<(), RuntimeStoreError> {
239 let mut inner = self.inner.lock().await;
240 let rid = runtime_id.0.clone();
241
242 inner.runtime_states.insert(rid.clone(), runtime_state);
244 let states = inner.input_states.entry(rid).or_default();
245 for state in input_states {
246 states.insert(state.input_id.clone(), state.clone());
247 }
248
249 Ok(())
250 }
251}
252
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal we want here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use meerkat_core::lifecycle::run_primitive::RunApplyBoundary;

    /// Builds a minimal `RunStart` receipt for `run_id` at sequence `seq`,
    /// with no contributing inputs and no conversation digest.
    fn make_receipt(run_id: RunId, seq: u64) -> RunBoundaryReceipt {
        RunBoundaryReceipt {
            run_id,
            boundary: RunApplyBoundary::RunStart,
            contributing_input_ids: vec![],
            conversation_digest: None,
            message_count: 0,
            sequence: seq,
        }
    }

    #[tokio::test]
    async fn atomic_apply_roundtrip() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test-runtime");
        let run_id = RunId::new();
        let input_id = InputId::new();

        let state = InputState::new_accepted(input_id.clone());
        let receipt = make_receipt(run_id.clone(), 0);

        store
            .atomic_apply(
                &rid,
                Some(SessionDelta {
                    session_snapshot: b"session-data".to_vec(),
                }),
                receipt.clone(),
                vec![state],
                None,
            )
            .await
            .unwrap();

        let states = store.load_input_states(&rid).await.unwrap();
        assert_eq!(states.len(), 1);
        assert_eq!(states[0].input_id, input_id);

        let loaded = store.load_boundary_receipt(&rid, &run_id, 0).await.unwrap();
        assert!(loaded.is_some());
    }

    #[tokio::test]
    async fn persist_and_load_single_state() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test");
        let input_id = InputId::new();
        let state = InputState::new_accepted(input_id.clone());

        store.persist_input_state(&rid, &state).await.unwrap();

        let loaded = store.load_input_state(&rid, &input_id).await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().input_id, input_id);
    }

    #[tokio::test]
    async fn load_nonexistent_returns_none() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test");

        let states = store.load_input_states(&rid).await.unwrap();
        assert!(states.is_empty());

        let state = store.load_input_state(&rid, &InputId::new()).await.unwrap();
        assert!(state.is_none());

        let receipt = store
            .load_boundary_receipt(&rid, &RunId::new(), 0)
            .await
            .unwrap();
        assert!(receipt.is_none());
    }

    #[tokio::test]
    async fn atomic_apply_updates_existing() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test");
        let input_id = InputId::new();

        let state1 = InputState::new_accepted(input_id.clone());
        store
            .atomic_apply(
                &rid,
                None,
                make_receipt(RunId::new(), 0),
                vec![state1],
                None,
            )
            .await
            .unwrap();

        // Same input id, advanced to Queued: the second apply must overwrite
        // the first entry rather than add a new one.
        let mut state2 = InputState::new_accepted(input_id.clone());
        let _ = state2.apply(crate::input_lifecycle_authority::InputLifecycleInput::QueueAccepted);
        store
            .atomic_apply(
                &rid,
                None,
                make_receipt(RunId::new(), 1),
                vec![state2],
                None,
            )
            .await
            .unwrap();

        let states = store.load_input_states(&rid).await.unwrap();
        assert_eq!(states.len(), 1);
        assert_eq!(
            states[0].current_state(),
            crate::input_state::InputLifecycleState::Queued
        );
    }

    #[tokio::test]
    async fn commit_session_boundary_returns_authoritative_receipt() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test");
        let run_id = RunId::new();
        let input_id = InputId::new();
        let session = meerkat_core::Session::new();
        let snapshot = serde_json::to_vec(&session).unwrap();

        let receipt = store
            .commit_session_boundary(
                &rid,
                SessionDelta {
                    session_snapshot: snapshot,
                },
                run_id.clone(),
                RunApplyBoundary::Immediate,
                vec![input_id.clone()],
                vec![InputState::new_accepted(input_id)],
            )
            .await
            .unwrap();

        assert_eq!(receipt.sequence, 0);
        assert_eq!(receipt.run_id, run_id);
        assert!(receipt.conversation_digest.is_some());
        let loaded = store
            .load_boundary_receipt(&rid, &receipt.run_id, receipt.sequence)
            .await
            .unwrap();
        assert!(loaded.is_some(), "receipt should be persisted");
        let Some(loaded) = loaded else {
            unreachable!("asserted above");
        };
        assert_eq!(loaded, receipt);
    }

    /// Repeated commits on the same run must get consecutive sequences, and
    /// the counter is scoped per runtime: another runtime reusing the same
    /// run id starts again at 0.
    #[tokio::test]
    async fn commit_session_boundary_increments_sequence_per_run() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("test");
        let other_rid = LogicalRuntimeId::new("other");
        let run_id = RunId::new();
        let snapshot = serde_json::to_vec(&meerkat_core::Session::new()).unwrap();

        let mut sequences = Vec::new();
        for _ in 0..2 {
            let input_id = InputId::new();
            let receipt = store
                .commit_session_boundary(
                    &rid,
                    SessionDelta {
                        session_snapshot: snapshot.clone(),
                    },
                    run_id.clone(),
                    RunApplyBoundary::Immediate,
                    vec![input_id.clone()],
                    vec![InputState::new_accepted(input_id)],
                )
                .await
                .unwrap();
            sequences.push(receipt.sequence);
        }
        assert_eq!(sequences, vec![0, 1]);

        // Both receipts must remain individually addressable.
        for seq in [0, 1] {
            let loaded = store
                .load_boundary_receipt(&rid, &run_id, seq)
                .await
                .unwrap();
            assert!(loaded.is_some(), "receipt {seq} should be persisted");
        }

        let input_id = InputId::new();
        let other = store
            .commit_session_boundary(
                &other_rid,
                SessionDelta {
                    session_snapshot: snapshot,
                },
                run_id.clone(),
                RunApplyBoundary::Immediate,
                vec![input_id.clone()],
                vec![InputState::new_accepted(input_id)],
            )
            .await
            .unwrap();
        assert_eq!(other.sequence, 0);
    }

    #[tokio::test]
    async fn multiple_runtimes_isolated() {
        let store = InMemoryRuntimeStore::new();
        let rid1 = LogicalRuntimeId::new("runtime-1");
        let rid2 = LogicalRuntimeId::new("runtime-2");

        store
            .persist_input_state(&rid1, &InputState::new_accepted(InputId::new()))
            .await
            .unwrap();
        store
            .persist_input_state(&rid2, &InputState::new_accepted(InputId::new()))
            .await
            .unwrap();
        store
            .persist_input_state(&rid2, &InputState::new_accepted(InputId::new()))
            .await
            .unwrap();

        let s1 = store.load_input_states(&rid1).await.unwrap();
        let s2 = store.load_input_states(&rid2).await.unwrap();
        assert_eq!(s1.len(), 1);
        assert_eq!(s2.len(), 2);
    }

    #[tokio::test]
    async fn load_session_snapshot_roundtrip() {
        let store = InMemoryRuntimeStore::new();
        let rid = LogicalRuntimeId::new("runtime");
        let snapshot = serde_json::to_vec(&meerkat_core::Session::new()).unwrap();

        store
            .atomic_apply(
                &rid,
                Some(SessionDelta {
                    session_snapshot: snapshot.clone(),
                }),
                make_receipt(RunId::new(), 0),
                vec![],
                None,
            )
            .await
            .unwrap();

        let loaded = store.load_session_snapshot(&rid).await.unwrap();
        assert_eq!(loaded, Some(snapshot));
    }
}