1use std::collections::HashMap;
7use std::sync::Arc;
8
9use indexmap::IndexMap;
10use meerkat_core::lifecycle::run_primitive::RunApplyBoundary;
11use meerkat_core::lifecycle::{InputId, RunBoundaryReceipt, RunId};
12#[cfg(not(target_arch = "wasm32"))]
13use tokio::sync::Mutex;
14#[cfg(target_arch = "wasm32")]
15use tokio_with_wasm::alias::sync::Mutex;
16
17use super::{RuntimeStore, RuntimeStoreError, SessionDelta, authoritative_receipt};
18use crate::identifiers::LogicalRuntimeId;
19use crate::input_state::InputState;
20use crate::runtime_state::RuntimeState;
21
22#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24struct ReceiptKey {
25 runtime_id: String,
26 run_id: RunId,
27 sequence: u64,
28}
29
30#[derive(Debug, Default)]
32struct Inner {
33 input_states: HashMap<String, IndexMap<InputId, InputState>>,
35 receipts: HashMap<ReceiptKey, RunBoundaryReceipt>,
37 sessions: HashMap<String, Vec<u8>>,
39 runtime_states: HashMap<String, RuntimeState>,
41}
42
43#[derive(Debug, Clone)]
45pub struct InMemoryRuntimeStore {
46 inner: Arc<Mutex<Inner>>,
47}
48
49impl InMemoryRuntimeStore {
50 pub fn new() -> Self {
51 Self {
52 inner: Arc::new(Mutex::new(Inner::default())),
53 }
54 }
55}
56
57impl Default for InMemoryRuntimeStore {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
64#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
65impl RuntimeStore for InMemoryRuntimeStore {
66 async fn commit_session_boundary(
67 &self,
68 runtime_id: &LogicalRuntimeId,
69 session_delta: SessionDelta,
70 run_id: RunId,
71 boundary: RunApplyBoundary,
72 contributing_input_ids: Vec<InputId>,
73 input_updates: Vec<InputState>,
74 ) -> Result<RunBoundaryReceipt, RuntimeStoreError> {
75 let mut inner = self.inner.lock().await;
76 let rid = runtime_id.0.clone();
77 let sequence = inner
78 .receipts
79 .keys()
80 .filter(|key| key.runtime_id == rid && key.run_id == run_id)
81 .map(|key| key.sequence)
82 .max()
83 .map(|seq| seq + 1)
84 .unwrap_or(0);
85 let receipt = authoritative_receipt(
86 Some(&session_delta),
87 run_id,
88 boundary,
89 contributing_input_ids,
90 sequence,
91 )?;
92 let mut input_updates = input_updates;
93 for state in &mut input_updates {
94 state.last_run_id = Some(receipt.run_id.clone());
95 state.last_boundary_sequence = Some(receipt.sequence);
96 }
97
98 inner
99 .sessions
100 .insert(rid.clone(), session_delta.session_snapshot);
101 let key = ReceiptKey {
102 runtime_id: rid.clone(),
103 run_id: receipt.run_id.clone(),
104 sequence: receipt.sequence,
105 };
106 inner.receipts.insert(key, receipt.clone());
107
108 let states = inner.input_states.entry(rid).or_default();
109 for state in input_updates {
110 states.insert(state.input_id.clone(), state);
111 }
112
113 Ok(receipt)
114 }
115
116 async fn atomic_apply(
117 &self,
118 runtime_id: &LogicalRuntimeId,
119 session_delta: Option<SessionDelta>,
120 receipt: RunBoundaryReceipt,
121 input_updates: Vec<InputState>,
122 _session_store_key: Option<meerkat_core::types::SessionId>,
123 ) -> Result<(), RuntimeStoreError> {
124 let mut inner = self.inner.lock().await;
128
129 let rid = runtime_id.0.clone();
131
132 if let Some(delta) = session_delta {
134 inner.sessions.insert(rid.clone(), delta.session_snapshot);
135 }
136
137 let key = ReceiptKey {
139 runtime_id: rid.clone(),
140 run_id: receipt.run_id.clone(),
141 sequence: receipt.sequence,
142 };
143 inner.receipts.insert(key, receipt);
144
145 let states = inner.input_states.entry(rid).or_default();
147 for state in input_updates {
148 states.insert(state.input_id.clone(), state);
149 }
150
151 Ok(())
152 }
153
154 async fn load_input_states(
155 &self,
156 runtime_id: &LogicalRuntimeId,
157 ) -> Result<Vec<InputState>, RuntimeStoreError> {
158 let inner = self.inner.lock().await;
159 let states = inner
160 .input_states
161 .get(&runtime_id.0)
162 .map(|m| m.values().cloned().collect())
163 .unwrap_or_default();
164 Ok(states)
165 }
166
167 async fn load_boundary_receipt(
168 &self,
169 runtime_id: &LogicalRuntimeId,
170 run_id: &RunId,
171 sequence: u64,
172 ) -> Result<Option<RunBoundaryReceipt>, RuntimeStoreError> {
173 let inner = self.inner.lock().await;
174 let key = ReceiptKey {
175 runtime_id: runtime_id.0.clone(),
176 run_id: run_id.clone(),
177 sequence,
178 };
179 Ok(inner.receipts.get(&key).cloned())
180 }
181
182 async fn load_session_snapshot(
183 &self,
184 runtime_id: &LogicalRuntimeId,
185 ) -> Result<Option<Vec<u8>>, RuntimeStoreError> {
186 let inner = self.inner.lock().await;
187 Ok(inner.sessions.get(&runtime_id.0).cloned())
188 }
189
190 async fn persist_input_state(
191 &self,
192 runtime_id: &LogicalRuntimeId,
193 state: &InputState,
194 ) -> Result<(), RuntimeStoreError> {
195 let mut inner = self.inner.lock().await;
196 let states = inner.input_states.entry(runtime_id.0.clone()).or_default();
197 states.insert(state.input_id.clone(), state.clone());
198 Ok(())
199 }
200
201 async fn load_input_state(
202 &self,
203 runtime_id: &LogicalRuntimeId,
204 input_id: &InputId,
205 ) -> Result<Option<InputState>, RuntimeStoreError> {
206 let inner = self.inner.lock().await;
207 let state = inner
208 .input_states
209 .get(&runtime_id.0)
210 .and_then(|m| m.get(input_id).cloned());
211 Ok(state)
212 }
213
214 async fn persist_runtime_state(
215 &self,
216 runtime_id: &LogicalRuntimeId,
217 state: RuntimeState,
218 ) -> Result<(), RuntimeStoreError> {
219 let mut inner = self.inner.lock().await;
220 inner.runtime_states.insert(runtime_id.0.clone(), state);
221 Ok(())
222 }
223
224 async fn load_runtime_state(
225 &self,
226 runtime_id: &LogicalRuntimeId,
227 ) -> Result<Option<RuntimeState>, RuntimeStoreError> {
228 let inner = self.inner.lock().await;
229 Ok(inner.runtime_states.get(&runtime_id.0).copied())
230 }
231
232 async fn atomic_lifecycle_commit(
233 &self,
234 runtime_id: &LogicalRuntimeId,
235 runtime_state: RuntimeState,
236 input_states: &[InputState],
237 ) -> Result<(), RuntimeStoreError> {
238 let mut inner = self.inner.lock().await;
239 let rid = runtime_id.0.clone();
240
241 inner.runtime_states.insert(rid.clone(), runtime_state);
243 let states = inner.input_states.entry(rid).or_default();
244 for state in input_states {
245 states.insert(state.input_id.clone(), state.clone());
246 }
247
248 Ok(())
249 }
250}
251
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use meerkat_core::lifecycle::run_primitive::RunApplyBoundary;

    /// Builds a bare `RunStart` receipt for `run_id` with the given sequence,
    /// no contributing inputs, no digest, and a zero message count.
    fn make_receipt(run_id: RunId, seq: u64) -> RunBoundaryReceipt {
        RunBoundaryReceipt {
            sequence: seq,
            run_id,
            boundary: RunApplyBoundary::RunStart,
            contributing_input_ids: Vec::new(),
            conversation_digest: None,
            message_count: 0,
        }
    }

    /// `atomic_apply` makes both the input state and the receipt loadable.
    #[tokio::test]
    async fn atomic_apply_roundtrip() {
        let store = InMemoryRuntimeStore::new();
        let runtime = LogicalRuntimeId::new("test-runtime");
        let run_id = RunId::new();
        let input_id = InputId::new();

        let delta = SessionDelta {
            session_snapshot: b"session-data".to_vec(),
        };
        let accepted = InputState::new_accepted(input_id.clone());
        store
            .atomic_apply(
                &runtime,
                Some(delta),
                make_receipt(run_id.clone(), 0),
                vec![accepted],
                None,
            )
            .await
            .unwrap();

        let stored_states = store.load_input_states(&runtime).await.unwrap();
        assert_eq!(stored_states.len(), 1);
        assert_eq!(stored_states[0].input_id, input_id);

        let stored_receipt = store
            .load_boundary_receipt(&runtime, &run_id, 0)
            .await
            .unwrap();
        assert!(stored_receipt.is_some());
    }

    /// A state written with `persist_input_state` can be read back by id.
    #[tokio::test]
    async fn persist_and_load_single_state() {
        let store = InMemoryRuntimeStore::new();
        let runtime = LogicalRuntimeId::new("test");
        let input_id = InputId::new();

        store
            .persist_input_state(&runtime, &InputState::new_accepted(input_id.clone()))
            .await
            .unwrap();

        let fetched = store.load_input_state(&runtime, &input_id).await.unwrap();
        // Covers both "is present" and "has the expected id".
        assert_eq!(fetched.map(|state| state.input_id), Some(input_id));
    }

    /// Lookups against an empty store yield empty or `None` results.
    #[tokio::test]
    async fn load_nonexistent_returns_none() {
        let store = InMemoryRuntimeStore::new();
        let runtime = LogicalRuntimeId::new("test");

        let all_states = store.load_input_states(&runtime).await.unwrap();
        assert!(all_states.is_empty());

        let single_state = store
            .load_input_state(&runtime, &InputId::new())
            .await
            .unwrap();
        assert!(single_state.is_none());

        let missing_receipt = store
            .load_boundary_receipt(&runtime, &RunId::new(), 0)
            .await
            .unwrap();
        assert!(missing_receipt.is_none());
    }

    /// Applying a second state for the same input id replaces the first.
    #[tokio::test]
    async fn atomic_apply_updates_existing() {
        use crate::input_state::InputLifecycleState;

        let store = InMemoryRuntimeStore::new();
        let runtime = LogicalRuntimeId::new("test");
        let input_id = InputId::new();

        let initial = InputState::new_accepted(input_id.clone());
        store
            .atomic_apply(
                &runtime,
                None,
                make_receipt(RunId::new(), 0),
                vec![initial],
                None,
            )
            .await
            .unwrap();

        let mut queued = InputState::new_accepted(input_id.clone());
        queued.current_state = InputLifecycleState::Queued;
        store
            .atomic_apply(
                &runtime,
                None,
                make_receipt(RunId::new(), 1),
                vec![queued],
                None,
            )
            .await
            .unwrap();

        let stored_states = store.load_input_states(&runtime).await.unwrap();
        assert_eq!(stored_states.len(), 1);
        assert_eq!(stored_states[0].current_state, InputLifecycleState::Queued);
    }

    /// `commit_session_boundary` returns a receipt that is also persisted.
    #[tokio::test]
    async fn commit_session_boundary_returns_authoritative_receipt() {
        let store = InMemoryRuntimeStore::new();
        let runtime = LogicalRuntimeId::new("test");
        let run_id = RunId::new();
        let input_id = InputId::new();
        let snapshot = serde_json::to_vec(&meerkat_core::Session::new()).unwrap();

        let delta = SessionDelta {
            session_snapshot: snapshot,
        };
        let receipt = store
            .commit_session_boundary(
                &runtime,
                delta,
                run_id.clone(),
                RunApplyBoundary::Immediate,
                vec![input_id.clone()],
                vec![InputState::new_accepted(input_id)],
            )
            .await
            .unwrap();

        assert_eq!(receipt.sequence, 0);
        assert_eq!(receipt.run_id, run_id);
        assert!(receipt.conversation_digest.is_some());

        let persisted = store
            .load_boundary_receipt(&runtime, &receipt.run_id, receipt.sequence)
            .await
            .unwrap();
        assert!(persisted.is_some(), "receipt should be persisted");
        match persisted {
            Some(persisted) => assert_eq!(persisted, receipt),
            None => unreachable!("asserted above"),
        }
    }

    /// States written for one runtime are not visible under another.
    #[tokio::test]
    async fn multiple_runtimes_isolated() {
        let store = InMemoryRuntimeStore::new();
        let first = LogicalRuntimeId::new("runtime-1");
        let second = LogicalRuntimeId::new("runtime-2");

        // One state for the first runtime, two for the second.
        for target in [&first, &second, &second] {
            store
                .persist_input_state(target, &InputState::new_accepted(InputId::new()))
                .await
                .unwrap();
        }

        let first_states = store.load_input_states(&first).await.unwrap();
        let second_states = store.load_input_states(&second).await.unwrap();
        assert_eq!(first_states.len(), 1);
        assert_eq!(second_states.len(), 2);
    }

    /// A snapshot written via `atomic_apply` is returned unchanged.
    #[tokio::test]
    async fn load_session_snapshot_roundtrip() {
        let store = InMemoryRuntimeStore::new();
        let runtime = LogicalRuntimeId::new("runtime");
        let snapshot = serde_json::to_vec(&meerkat_core::Session::new()).unwrap();

        let delta = SessionDelta {
            session_snapshot: snapshot.clone(),
        };
        store
            .atomic_apply(
                &runtime,
                Some(delta),
                make_receipt(RunId::new(), 0),
                Vec::new(),
                None,
            )
            .await
            .unwrap();

        let fetched = store.load_session_snapshot(&runtime).await.unwrap();
        assert_eq!(fetched, Some(snapshot));
    }
}