1use agent_sdk_foundation::events::AgentEventEnvelope;
17use agent_sdk_foundation::llm;
18use agent_sdk_foundation::types::{AgentState, ThreadId, ToolExecution};
19use anyhow::{Context, Result};
20use async_trait::async_trait;
21use std::collections::{BTreeMap, HashMap};
22use std::sync::Arc;
23use std::sync::RwLock;
24use tokio::sync::RwLock as AsyncRwLock;
25
26#[async_trait]
29pub trait MessageStore: Send + Sync {
30 async fn append(&self, thread_id: &ThreadId, message: llm::Message) -> Result<()>;
35
36 async fn get_history(&self, thread_id: &ThreadId) -> Result<Vec<llm::Message>>;
41
42 async fn clear(&self, thread_id: &ThreadId) -> Result<()>;
47
48 async fn count(&self, thread_id: &ThreadId) -> Result<usize> {
53 Ok(self.get_history(thread_id).await?.len())
54 }
55
56 async fn replace_history(
62 &self,
63 thread_id: &ThreadId,
64 messages: Vec<llm::Message>,
65 ) -> Result<()>;
66}
67
68#[async_trait]
71pub trait StateStore: Send + Sync {
72 async fn save(&self, state: &AgentState) -> Result<()>;
77
78 async fn load(&self, thread_id: &ThreadId) -> Result<Option<AgentState>>;
83
84 async fn delete(&self, thread_id: &ThreadId) -> Result<()>;
89}
90
91#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
93pub struct StoredTurnEvents {
94 pub turn: usize,
96 pub events: Vec<AgentEventEnvelope>,
98 pub finished: bool,
100}
101
102#[async_trait]
108pub trait EventStore: Send + Sync {
109 async fn append(
114 &self,
115 thread_id: &ThreadId,
116 turn: usize,
117 envelope: AgentEventEnvelope,
118 ) -> Result<()>;
119
120 async fn finish_turn(&self, thread_id: &ThreadId, turn: usize) -> Result<()>;
125
126 async fn get_turn(&self, thread_id: &ThreadId, turn: usize)
131 -> Result<Option<StoredTurnEvents>>;
132
133 async fn get_turns(&self, thread_id: &ThreadId) -> Result<Vec<StoredTurnEvents>>;
138
139 async fn get_events(&self, thread_id: &ThreadId) -> Result<Vec<AgentEventEnvelope>> {
144 let turns = self.get_turns(thread_id).await?;
145 Ok(turns
146 .into_iter()
147 .flat_map(|turn| turn.events.into_iter())
148 .collect())
149 }
150
151 async fn clear(&self, thread_id: &ThreadId) -> Result<()>;
156}
157
158#[async_trait]
166pub trait ToolExecutionStore: Send + Sync {
167 async fn get_execution(&self, tool_call_id: &str) -> Result<Option<ToolExecution>>;
172
173 async fn record_execution(&self, execution: ToolExecution) -> Result<()>;
178
179 async fn update_execution(&self, execution: ToolExecution) -> Result<()>;
184
185 async fn get_execution_by_operation_id(
190 &self,
191 operation_id: &str,
192 ) -> Result<Option<ToolExecution>>;
193}
194
195#[derive(Default)]
198pub struct InMemoryStore {
199 messages: RwLock<HashMap<String, Vec<llm::Message>>>,
200 states: RwLock<HashMap<String, AgentState>>,
201}
202
203impl InMemoryStore {
204 #[must_use]
205 pub fn new() -> Self {
206 Self::default()
207 }
208}
209
210#[derive(Default)]
211struct InMemoryEventStoreInner {
212 turns: AsyncRwLock<HashMap<String, BTreeMap<usize, StoredTurnEvents>>>,
213}
214
215#[derive(Clone, Default)]
219pub struct InMemoryEventStore {
220 inner: Arc<InMemoryEventStoreInner>,
221}
222
223impl InMemoryEventStore {
224 #[must_use]
225 pub fn new() -> Self {
226 Self::default()
227 }
228
229 async fn update_turn(
230 &self,
231 thread_id: &ThreadId,
232 turn: usize,
233 update: impl FnOnce(&mut StoredTurnEvents) -> Result<()>,
234 ) -> Result<()> {
235 let mut turns = self.inner.turns.write().await;
236 let stored_turn = turns
237 .entry(thread_id.0.clone())
238 .or_default()
239 .entry(turn)
240 .or_insert_with(|| StoredTurnEvents {
241 turn,
242 events: Vec::new(),
243 finished: false,
244 });
245 let result = update(stored_turn);
246 drop(turns);
247 result
248 }
249}
250
251#[async_trait]
252impl MessageStore for InMemoryStore {
253 async fn append(&self, thread_id: &ThreadId, message: llm::Message) -> Result<()> {
254 self.messages
255 .write()
256 .ok()
257 .context("lock poisoned")?
258 .entry(thread_id.0.clone())
259 .or_default()
260 .push(message);
261 Ok(())
262 }
263
264 async fn get_history(&self, thread_id: &ThreadId) -> Result<Vec<llm::Message>> {
265 let messages = self.messages.read().ok().context("lock poisoned")?;
266 Ok(messages.get(&thread_id.0).cloned().unwrap_or_default())
267 }
268
269 async fn clear(&self, thread_id: &ThreadId) -> Result<()> {
270 self.messages
271 .write()
272 .ok()
273 .context("lock poisoned")?
274 .remove(&thread_id.0);
275 Ok(())
276 }
277
278 async fn replace_history(
279 &self,
280 thread_id: &ThreadId,
281 messages: Vec<llm::Message>,
282 ) -> Result<()> {
283 self.messages
284 .write()
285 .ok()
286 .context("lock poisoned")?
287 .insert(thread_id.0.clone(), messages);
288 Ok(())
289 }
290}
291
292#[async_trait]
293impl StateStore for InMemoryStore {
294 async fn save(&self, state: &AgentState) -> Result<()> {
295 self.states
296 .write()
297 .ok()
298 .context("lock poisoned")?
299 .insert(state.thread_id.0.clone(), state.clone());
300 Ok(())
301 }
302
303 async fn load(&self, thread_id: &ThreadId) -> Result<Option<AgentState>> {
304 let states = self.states.read().ok().context("lock poisoned")?;
305 Ok(states.get(&thread_id.0).cloned())
306 }
307
308 async fn delete(&self, thread_id: &ThreadId) -> Result<()> {
309 self.states
310 .write()
311 .ok()
312 .context("lock poisoned")?
313 .remove(&thread_id.0);
314 Ok(())
315 }
316}
317
318#[async_trait]
319impl EventStore for InMemoryEventStore {
320 async fn append(
321 &self,
322 thread_id: &ThreadId,
323 turn: usize,
324 envelope: AgentEventEnvelope,
325 ) -> Result<()> {
326 self.update_turn(thread_id, turn, |stored_turn| {
327 anyhow::ensure!(
328 !stored_turn.finished,
329 "cannot append to finished turn {turn}"
330 );
331 stored_turn.events.push(envelope);
332 Ok(())
333 })
334 .await
335 }
336
337 async fn finish_turn(&self, thread_id: &ThreadId, turn: usize) -> Result<()> {
338 self.update_turn(thread_id, turn, |stored_turn| {
339 anyhow::ensure!(!stored_turn.finished, "turn {turn} is already finished");
340 stored_turn.finished = true;
341 Ok(())
342 })
343 .await
344 }
345
346 async fn get_turn(
347 &self,
348 thread_id: &ThreadId,
349 turn: usize,
350 ) -> Result<Option<StoredTurnEvents>> {
351 let turns = self.inner.turns.read().await;
352 Ok(turns
353 .get(&thread_id.0)
354 .and_then(|thread_turns| thread_turns.get(&turn).cloned()))
355 }
356
357 async fn get_turns(&self, thread_id: &ThreadId) -> Result<Vec<StoredTurnEvents>> {
358 let turns = self.inner.turns.read().await;
359 Ok(turns
360 .get(&thread_id.0)
361 .map(|thread_turns| thread_turns.values().cloned().collect())
362 .unwrap_or_default())
363 }
364
365 async fn clear(&self, thread_id: &ThreadId) -> Result<()> {
366 {
367 let mut turns = self.inner.turns.write().await;
368 turns.remove(&thread_id.0);
369 }
370 Ok(())
371 }
372}
373
374#[derive(Default)]
379pub struct InMemoryExecutionStore {
380 executions: RwLock<HashMap<String, ToolExecution>>,
382 operation_index: RwLock<HashMap<String, String>>,
384}
385
386impl InMemoryExecutionStore {
387 #[must_use]
388 pub fn new() -> Self {
389 Self::default()
390 }
391}
392
393#[async_trait]
394impl ToolExecutionStore for InMemoryExecutionStore {
395 async fn get_execution(&self, tool_call_id: &str) -> Result<Option<ToolExecution>> {
396 let executions = self.executions.read().ok().context("lock poisoned")?;
397 Ok(executions.get(tool_call_id).cloned())
398 }
399
400 async fn record_execution(&self, execution: ToolExecution) -> Result<()> {
401 let tool_call_id = execution.tool_call_id.clone();
402 self.executions
403 .write()
404 .ok()
405 .context("lock poisoned")?
406 .insert(tool_call_id, execution);
407 Ok(())
408 }
409
410 async fn update_execution(&self, execution: ToolExecution) -> Result<()> {
411 let tool_call_id = execution.tool_call_id.clone();
412
413 if let Some(ref op_id) = execution.operation_id {
415 self.operation_index
416 .write()
417 .ok()
418 .context("lock poisoned")?
419 .insert(op_id.clone(), tool_call_id.clone());
420 }
421
422 self.executions
423 .write()
424 .ok()
425 .context("lock poisoned")?
426 .insert(tool_call_id, execution);
427 Ok(())
428 }
429
430 async fn get_execution_by_operation_id(
431 &self,
432 operation_id: &str,
433 ) -> Result<Option<ToolExecution>> {
434 let tool_call_id = {
436 let op_index = self.operation_index.read().ok().context("lock poisoned")?;
437 op_index.get(operation_id).cloned()
438 };
439
440 let Some(tool_call_id) = tool_call_id else {
441 return Ok(None);
442 };
443
444 let executions = self.executions.read().ok().context("lock poisoned")?;
445 Ok(executions.get(&tool_call_id).cloned())
446 }
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    // `AgentEventEnvelope` is already in scope through `super::*`.
    use agent_sdk_foundation::events::{AgentEvent, SequenceCounter};
    use agent_sdk_foundation::llm::Message;
    use agent_sdk_foundation::types::ToolResult;

    #[tokio::test]
    async fn test_in_memory_message_store() -> Result<()> {
        let store = InMemoryStore::new();
        let thread_id = ThreadId::new();

        // A thread that was never written to has an empty history.
        let history = store.get_history(&thread_id).await?;
        assert!(history.is_empty());

        store.append(&thread_id, Message::user("Hello")).await?;
        store
            .append(&thread_id, Message::assistant("Hi there!"))
            .await?;

        let history = store.get_history(&thread_id).await?;
        assert_eq!(history.len(), 2);

        let count = store.count(&thread_id).await?;
        assert_eq!(count, 2);

        store.clear(&thread_id).await?;
        let history = store.get_history(&thread_id).await?;
        assert!(history.is_empty());

        Ok(())
    }

    #[tokio::test]
    async fn test_replace_history() -> Result<()> {
        let store = InMemoryStore::new();
        let thread_id = ThreadId::new();

        store.append(&thread_id, Message::user("Hello")).await?;
        store
            .append(&thread_id, Message::assistant("Hi there!"))
            .await?;
        store
            .append(&thread_id, Message::user("How are you?"))
            .await?;

        let history = store.get_history(&thread_id).await?;
        assert_eq!(history.len(), 3);

        // Replace the three messages with a two-message summary.
        let new_history = vec![
            Message::user("[Summary] Previous conversation about greetings"),
            Message::assistant("I understand the context. Continuing..."),
        ];
        store.replace_history(&thread_id, new_history).await?;

        let history = store.get_history(&thread_id).await?;
        assert_eq!(history.len(), 2);

        Ok(())
    }

    #[tokio::test]
    async fn test_in_memory_state_store() -> Result<()> {
        let store = InMemoryStore::new();
        let thread_id = ThreadId::new();

        let state = store.load(&thread_id).await?;
        assert!(state.is_none());

        let state = AgentState::new(thread_id.clone());
        store.save(&state).await?;

        // Fail loudly (instead of silently skipping the assertion) if the state is missing.
        let loaded = store
            .load(&thread_id)
            .await?
            .context("state should exist after save")?;
        assert_eq!(loaded.thread_id, thread_id);

        store.delete(&thread_id).await?;
        let state = store.load(&thread_id).await?;
        assert!(state.is_none());

        Ok(())
    }

    #[tokio::test]
    async fn test_in_memory_event_store_tracks_turns_and_finish_barrier() -> Result<()> {
        let store = InMemoryEventStore::new();
        let thread_id = ThreadId::new();
        let seq = SequenceCounter::new();

        store
            .append(
                &thread_id,
                1,
                AgentEventEnvelope::wrap(AgentEvent::text("msg_1", "hello"), &seq),
            )
            .await?;
        store
            .append(
                &thread_id,
                2,
                AgentEventEnvelope::wrap(AgentEvent::text("msg_2", "world"), &seq),
            )
            .await?;

        let turn_1 = store
            .get_turn(&thread_id, 1)
            .await?
            .context("missing turn 1")?;
        assert_eq!(turn_1.turn, 1);
        assert_eq!(turn_1.events.len(), 1);
        assert!(!turn_1.finished);

        store.finish_turn(&thread_id, 1).await?;
        store.finish_turn(&thread_id, 2).await?;

        let turn_1 = store
            .get_turn(&thread_id, 1)
            .await?
            .context("missing finished turn 1")?;
        let turn_2 = store
            .get_turn(&thread_id, 2)
            .await?
            .context("missing finished turn 2")?;
        assert!(turn_1.finished);
        assert!(turn_2.finished);

        let turns = store.get_turns(&thread_id).await?;
        assert_eq!(turns.len(), 2);
        assert_eq!(turns[0].turn, 1);
        assert_eq!(turns[1].turn, 2);

        Ok(())
    }

    #[tokio::test]
    async fn test_in_memory_event_store_finish_turn_without_events_creates_finished_turn()
    -> Result<()> {
        let store = InMemoryEventStore::new();
        let thread_id = ThreadId::new();

        store.finish_turn(&thread_id, 3).await?;

        let turn = store
            .get_turn(&thread_id, 3)
            .await?
            .context("missing empty finished turn")?;
        assert_eq!(turn.turn, 3);
        assert!(turn.events.is_empty());
        assert!(turn.finished);

        store.clear(&thread_id).await?;
        assert!(store.get_turns(&thread_id).await?.is_empty());

        Ok(())
    }

    #[tokio::test]
    async fn test_in_memory_event_store_rejects_append_after_finish() -> Result<()> {
        let store = InMemoryEventStore::new();
        let thread_id = ThreadId::new();
        let seq = SequenceCounter::new();

        store.finish_turn(&thread_id, 1).await?;

        let error = store
            .append(
                &thread_id,
                1,
                AgentEventEnvelope::wrap(AgentEvent::text("msg_1", "late"), &seq),
            )
            .await
            .expect_err("append after finish should fail");

        assert!(error.to_string().contains("cannot append to finished turn"));
        Ok(())
    }

    #[tokio::test]
    async fn test_in_memory_event_store_rejects_duplicate_finish() -> Result<()> {
        let store = InMemoryEventStore::new();
        let thread_id = ThreadId::new();

        store.finish_turn(&thread_id, 1).await?;

        let error = store
            .finish_turn(&thread_id, 1)
            .await
            .expect_err("duplicate finish should fail");

        assert!(error.to_string().contains("already finished"));
        Ok(())
    }

    #[tokio::test]
    async fn test_execution_store_basic_operations() -> Result<()> {
        let store = InMemoryExecutionStore::new();
        let thread_id = ThreadId::new();

        let execution = store.get_execution("tool_call_123").await?;
        assert!(execution.is_none());

        let execution = ToolExecution::new_in_flight(
            "tool_call_123",
            thread_id,
            "my_tool",
            "My Tool",
            serde_json::json!({"param": "value"}),
            time::OffsetDateTime::now_utc(),
        );
        store.record_execution(execution).await?;

        let loaded = store
            .get_execution("tool_call_123")
            .await?
            .context("execution should exist")?;
        assert_eq!(loaded.tool_call_id, "tool_call_123");
        assert_eq!(loaded.tool_name, "my_tool");
        assert!(loaded.is_in_flight());

        Ok(())
    }

    #[tokio::test]
    async fn test_execution_store_complete_execution() -> Result<()> {
        let store = InMemoryExecutionStore::new();
        let thread_id = ThreadId::new();

        let mut execution = ToolExecution::new_in_flight(
            "tool_call_456",
            thread_id,
            "my_tool",
            "My Tool",
            serde_json::json!({}),
            time::OffsetDateTime::now_utc(),
        );
        store.record_execution(execution.clone()).await?;

        execution.complete(ToolResult::success("Done!"));
        store.update_execution(execution).await?;

        let loaded = store
            .get_execution("tool_call_456")
            .await?
            .context("execution should exist")?;
        assert!(loaded.is_completed());
        // `is_some_and` also covers the "result is present" check.
        assert!(loaded.result.as_ref().is_some_and(|r| r.success));

        Ok(())
    }

    #[tokio::test]
    async fn test_execution_store_operation_id_lookup() -> Result<()> {
        let store = InMemoryExecutionStore::new();
        let thread_id = ThreadId::new();

        let mut execution = ToolExecution::new_in_flight(
            "tool_call_789",
            thread_id,
            "async_tool",
            "Async Tool",
            serde_json::json!({}),
            time::OffsetDateTime::now_utc(),
        );
        execution.set_operation_id("op_abc123");
        store.record_execution(execution.clone()).await?;
        store.update_execution(execution).await?;

        let loaded = store
            .get_execution_by_operation_id("op_abc123")
            .await?
            .context("execution should exist")?;
        assert_eq!(loaded.tool_call_id, "tool_call_789");
        assert_eq!(loaded.operation_id, Some("op_abc123".to_string()));

        let not_found = store.get_execution_by_operation_id("nonexistent").await?;
        assert!(not_found.is_none());

        Ok(())
    }
}