1use agent_sdk_foundation::events::AgentEventEnvelope;
17use agent_sdk_foundation::llm;
18use agent_sdk_foundation::types::{AgentState, ThreadId, ToolExecution};
19use anyhow::{Context, Result};
20use async_trait::async_trait;
21use std::collections::{BTreeMap, HashMap};
22use std::sync::Arc;
23use std::sync::RwLock;
24use tokio::sync::RwLock as AsyncRwLock;
25
26#[async_trait]
29pub trait MessageStore: Send + Sync {
30 async fn append(&self, thread_id: &ThreadId, message: llm::Message) -> Result<()>;
35
36 async fn get_history(&self, thread_id: &ThreadId) -> Result<Vec<llm::Message>>;
41
42 async fn clear(&self, thread_id: &ThreadId) -> Result<()>;
47
48 async fn count(&self, thread_id: &ThreadId) -> Result<usize> {
53 Ok(self.get_history(thread_id).await?.len())
54 }
55
56 async fn replace_history(
62 &self,
63 thread_id: &ThreadId,
64 messages: Vec<llm::Message>,
65 ) -> Result<()>;
66}
67
68#[async_trait]
71pub trait StateStore: Send + Sync {
72 async fn save(&self, state: &AgentState) -> Result<()>;
77
78 async fn load(&self, thread_id: &ThreadId) -> Result<Option<AgentState>>;
83
84 async fn delete(&self, thread_id: &ThreadId) -> Result<()>;
89}
90
91#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
93pub struct StoredTurnEvents {
94 pub turn: usize,
96 pub events: Vec<AgentEventEnvelope>,
98 pub finished: bool,
100}
101
102#[async_trait]
108pub trait EventStore: Send + Sync {
109 async fn append(
114 &self,
115 thread_id: &ThreadId,
116 turn: usize,
117 envelope: AgentEventEnvelope,
118 ) -> Result<()>;
119
120 async fn finish_turn(&self, thread_id: &ThreadId, turn: usize) -> Result<()>;
125
126 async fn get_turn(&self, thread_id: &ThreadId, turn: usize)
131 -> Result<Option<StoredTurnEvents>>;
132
133 async fn get_turns(&self, thread_id: &ThreadId) -> Result<Vec<StoredTurnEvents>>;
138
139 async fn get_events(&self, thread_id: &ThreadId) -> Result<Vec<AgentEventEnvelope>> {
144 let turns = self.get_turns(thread_id).await?;
145 Ok(turns
146 .into_iter()
147 .flat_map(|turn| turn.events.into_iter())
148 .collect())
149 }
150
151 async fn event_count(&self, thread_id: &ThreadId) -> Result<usize> {
162 Ok(self.get_events(thread_id).await?.len())
163 }
164
165 async fn get_events_since(
175 &self,
176 thread_id: &ThreadId,
177 offset: usize,
178 ) -> Result<Vec<AgentEventEnvelope>> {
179 Ok(self
180 .get_events(thread_id)
181 .await?
182 .into_iter()
183 .skip(offset)
184 .collect())
185 }
186
187 async fn clear(&self, thread_id: &ThreadId) -> Result<()>;
192}
193
194#[async_trait]
202pub trait ToolExecutionStore: Send + Sync {
203 async fn get_execution(&self, tool_call_id: &str) -> Result<Option<ToolExecution>>;
208
209 async fn record_execution(&self, execution: ToolExecution) -> Result<()>;
214
215 async fn update_execution(&self, execution: ToolExecution) -> Result<()>;
220
221 async fn get_execution_by_operation_id(
226 &self,
227 operation_id: &str,
228 ) -> Result<Option<ToolExecution>>;
229}
230
231#[derive(Default)]
232struct InMemoryStoreInner {
233 messages: RwLock<HashMap<String, Vec<llm::Message>>>,
234 states: RwLock<HashMap<String, AgentState>>,
235}
236
237#[derive(Clone, Default)]
247pub struct InMemoryStore {
248 inner: Arc<InMemoryStoreInner>,
249}
250
251impl InMemoryStore {
252 #[must_use]
253 pub fn new() -> Self {
254 Self::default()
255 }
256}
257
258#[derive(Default)]
259struct InMemoryEventStoreInner {
260 turns: AsyncRwLock<HashMap<String, BTreeMap<usize, StoredTurnEvents>>>,
261}
262
263#[derive(Clone, Default)]
267pub struct InMemoryEventStore {
268 inner: Arc<InMemoryEventStoreInner>,
269}
270
271impl InMemoryEventStore {
272 #[must_use]
273 pub fn new() -> Self {
274 Self::default()
275 }
276
277 async fn update_turn(
278 &self,
279 thread_id: &ThreadId,
280 turn: usize,
281 update: impl FnOnce(&mut StoredTurnEvents) -> Result<()>,
282 ) -> Result<()> {
283 let mut turns = self.inner.turns.write().await;
284 let stored_turn = turns
285 .entry(thread_id.0.clone())
286 .or_default()
287 .entry(turn)
288 .or_insert_with(|| StoredTurnEvents {
289 turn,
290 events: Vec::new(),
291 finished: false,
292 });
293 let result = update(stored_turn);
294 drop(turns);
295 result
296 }
297}
298
299#[async_trait]
300impl MessageStore for InMemoryStore {
301 async fn append(&self, thread_id: &ThreadId, message: llm::Message) -> Result<()> {
302 self.inner
303 .messages
304 .write()
305 .ok()
306 .context("lock poisoned")?
307 .entry(thread_id.0.clone())
308 .or_default()
309 .push(message);
310 Ok(())
311 }
312
313 async fn get_history(&self, thread_id: &ThreadId) -> Result<Vec<llm::Message>> {
314 let messages = self.inner.messages.read().ok().context("lock poisoned")?;
315 Ok(messages.get(&thread_id.0).cloned().unwrap_or_default())
316 }
317
318 async fn clear(&self, thread_id: &ThreadId) -> Result<()> {
319 self.inner
320 .messages
321 .write()
322 .ok()
323 .context("lock poisoned")?
324 .remove(&thread_id.0);
325 Ok(())
326 }
327
328 async fn replace_history(
329 &self,
330 thread_id: &ThreadId,
331 messages: Vec<llm::Message>,
332 ) -> Result<()> {
333 self.inner
334 .messages
335 .write()
336 .ok()
337 .context("lock poisoned")?
338 .insert(thread_id.0.clone(), messages);
339 Ok(())
340 }
341}
342
343#[async_trait]
344impl StateStore for InMemoryStore {
345 async fn save(&self, state: &AgentState) -> Result<()> {
346 self.inner
347 .states
348 .write()
349 .ok()
350 .context("lock poisoned")?
351 .insert(state.thread_id.0.clone(), state.clone());
352 Ok(())
353 }
354
355 async fn load(&self, thread_id: &ThreadId) -> Result<Option<AgentState>> {
356 let states = self.inner.states.read().ok().context("lock poisoned")?;
357 Ok(states.get(&thread_id.0).cloned())
358 }
359
360 async fn delete(&self, thread_id: &ThreadId) -> Result<()> {
361 self.inner
362 .states
363 .write()
364 .ok()
365 .context("lock poisoned")?
366 .remove(&thread_id.0);
367 Ok(())
368 }
369}
370
371#[async_trait]
376impl<T: MessageStore + ?Sized> MessageStore for Arc<T> {
377 async fn append(&self, thread_id: &ThreadId, message: llm::Message) -> Result<()> {
378 (**self).append(thread_id, message).await
379 }
380
381 async fn get_history(&self, thread_id: &ThreadId) -> Result<Vec<llm::Message>> {
382 (**self).get_history(thread_id).await
383 }
384
385 async fn clear(&self, thread_id: &ThreadId) -> Result<()> {
386 (**self).clear(thread_id).await
387 }
388
389 async fn count(&self, thread_id: &ThreadId) -> Result<usize> {
390 (**self).count(thread_id).await
391 }
392
393 async fn replace_history(
394 &self,
395 thread_id: &ThreadId,
396 messages: Vec<llm::Message>,
397 ) -> Result<()> {
398 (**self).replace_history(thread_id, messages).await
399 }
400}
401
402#[async_trait]
403impl<T: StateStore + ?Sized> StateStore for Arc<T> {
404 async fn save(&self, state: &AgentState) -> Result<()> {
405 (**self).save(state).await
406 }
407
408 async fn load(&self, thread_id: &ThreadId) -> Result<Option<AgentState>> {
409 (**self).load(thread_id).await
410 }
411
412 async fn delete(&self, thread_id: &ThreadId) -> Result<()> {
413 (**self).delete(thread_id).await
414 }
415}
416
417#[async_trait]
418impl EventStore for InMemoryEventStore {
419 async fn append(
420 &self,
421 thread_id: &ThreadId,
422 turn: usize,
423 envelope: AgentEventEnvelope,
424 ) -> Result<()> {
425 self.update_turn(thread_id, turn, |stored_turn| {
426 anyhow::ensure!(
427 !stored_turn.finished,
428 "cannot append to finished turn {turn}"
429 );
430 stored_turn.events.push(envelope);
431 Ok(())
432 })
433 .await
434 }
435
436 async fn finish_turn(&self, thread_id: &ThreadId, turn: usize) -> Result<()> {
437 self.update_turn(thread_id, turn, |stored_turn| {
438 anyhow::ensure!(!stored_turn.finished, "turn {turn} is already finished");
439 stored_turn.finished = true;
440 Ok(())
441 })
442 .await
443 }
444
445 async fn get_turn(
446 &self,
447 thread_id: &ThreadId,
448 turn: usize,
449 ) -> Result<Option<StoredTurnEvents>> {
450 let turns = self.inner.turns.read().await;
451 Ok(turns
452 .get(&thread_id.0)
453 .and_then(|thread_turns| thread_turns.get(&turn).cloned()))
454 }
455
456 async fn get_turns(&self, thread_id: &ThreadId) -> Result<Vec<StoredTurnEvents>> {
457 let turns = self.inner.turns.read().await;
458 Ok(turns
459 .get(&thread_id.0)
460 .map(|thread_turns| thread_turns.values().cloned().collect())
461 .unwrap_or_default())
462 }
463
464 async fn event_count(&self, thread_id: &ThreadId) -> Result<usize> {
465 let turns = self.inner.turns.read().await;
467 Ok(turns.get(&thread_id.0).map_or(0, |thread_turns| {
468 thread_turns.values().map(|turn| turn.events.len()).sum()
469 }))
470 }
471
472 async fn get_events_since(
473 &self,
474 thread_id: &ThreadId,
475 offset: usize,
476 ) -> Result<Vec<AgentEventEnvelope>> {
477 let turns = self.inner.turns.read().await;
479 Ok(turns
480 .get(&thread_id.0)
481 .map(|thread_turns| {
482 thread_turns
483 .values()
484 .flat_map(|turn| turn.events.iter())
485 .skip(offset)
486 .cloned()
487 .collect()
488 })
489 .unwrap_or_default())
490 }
491
492 async fn clear(&self, thread_id: &ThreadId) -> Result<()> {
493 {
494 let mut turns = self.inner.turns.write().await;
495 turns.remove(&thread_id.0);
496 }
497 Ok(())
498 }
499}
500
/// [`EventStore`] decorator: every call is delegated to `inner`, and appended
/// envelopes are additionally reported to `observer`.
pub struct ObservingEventStore<S, F> {
    /// The store that actually keeps the events.
    inner: S,
    /// Synchronous callback that is told about appended envelopes; see the
    /// `EventStore` implementation for exactly when it runs.
    observer: F,
}
528
529impl<S, F> ObservingEventStore<S, F>
530where
531 S: EventStore,
532 F: Fn(&AgentEventEnvelope) + Send + Sync,
533{
534 #[must_use]
537 pub const fn new(inner: S, observer: F) -> Self {
538 Self { inner, observer }
539 }
540
541 #[must_use]
543 pub const fn inner(&self) -> &S {
544 &self.inner
545 }
546}
547
548#[async_trait]
549impl<S, F> EventStore for ObservingEventStore<S, F>
550where
551 S: EventStore,
552 F: Fn(&AgentEventEnvelope) + Send + Sync,
553{
554 async fn append(
555 &self,
556 thread_id: &ThreadId,
557 turn: usize,
558 envelope: AgentEventEnvelope,
559 ) -> Result<()> {
560 (self.observer)(&envelope);
561 self.inner.append(thread_id, turn, envelope).await
562 }
563
564 async fn finish_turn(&self, thread_id: &ThreadId, turn: usize) -> Result<()> {
565 self.inner.finish_turn(thread_id, turn).await
566 }
567
568 async fn get_turn(
569 &self,
570 thread_id: &ThreadId,
571 turn: usize,
572 ) -> Result<Option<StoredTurnEvents>> {
573 self.inner.get_turn(thread_id, turn).await
574 }
575
576 async fn get_turns(&self, thread_id: &ThreadId) -> Result<Vec<StoredTurnEvents>> {
577 self.inner.get_turns(thread_id).await
578 }
579
580 async fn get_events(&self, thread_id: &ThreadId) -> Result<Vec<AgentEventEnvelope>> {
581 self.inner.get_events(thread_id).await
582 }
583
584 async fn event_count(&self, thread_id: &ThreadId) -> Result<usize> {
585 self.inner.event_count(thread_id).await
586 }
587
588 async fn get_events_since(
589 &self,
590 thread_id: &ThreadId,
591 offset: usize,
592 ) -> Result<Vec<AgentEventEnvelope>> {
593 self.inner.get_events_since(thread_id, offset).await
594 }
595
596 async fn clear(&self, thread_id: &ThreadId) -> Result<()> {
597 self.inner.clear(thread_id).await
598 }
599}
600
601#[derive(Default)]
606pub struct InMemoryExecutionStore {
607 executions: RwLock<HashMap<String, ToolExecution>>,
609 operation_index: RwLock<HashMap<String, String>>,
611}
612
613impl InMemoryExecutionStore {
614 #[must_use]
615 pub fn new() -> Self {
616 Self::default()
617 }
618}
619
620#[async_trait]
621impl ToolExecutionStore for InMemoryExecutionStore {
622 async fn get_execution(&self, tool_call_id: &str) -> Result<Option<ToolExecution>> {
623 let executions = self.executions.read().ok().context("lock poisoned")?;
624 Ok(executions.get(tool_call_id).cloned())
625 }
626
627 async fn record_execution(&self, execution: ToolExecution) -> Result<()> {
628 let tool_call_id = execution.tool_call_id.clone();
629 let operation_id = execution.operation_id.clone();
630
631 let mut executions = self.executions.write().ok().context("lock poisoned")?;
637 if let Some(op_id) = operation_id {
638 self.operation_index
639 .write()
640 .ok()
641 .context("lock poisoned")?
642 .insert(op_id, tool_call_id.clone());
643 }
644 executions.insert(tool_call_id, execution);
645 drop(executions);
646 Ok(())
647 }
648
649 async fn update_execution(&self, execution: ToolExecution) -> Result<()> {
650 let tool_call_id = execution.tool_call_id.clone();
651 let new_operation_id = execution.operation_id.clone();
652
653 let mut executions = self.executions.write().ok().context("lock poisoned")?;
657
658 let stale_op_id = executions
661 .get(&tool_call_id)
662 .and_then(|prev| prev.operation_id.clone())
663 .filter(|prev| Some(prev) != new_operation_id.as_ref());
664 if stale_op_id.is_some() || new_operation_id.is_some() {
665 let mut op_index = self.operation_index.write().ok().context("lock poisoned")?;
666 if let Some(stale) = stale_op_id {
667 op_index.remove(&stale);
668 }
669 if let Some(op_id) = new_operation_id {
670 op_index.insert(op_id, tool_call_id.clone());
671 }
672 }
673 executions.insert(tool_call_id, execution);
674 drop(executions);
675 Ok(())
676 }
677
678 async fn get_execution_by_operation_id(
679 &self,
680 operation_id: &str,
681 ) -> Result<Option<ToolExecution>> {
682 let executions = self.executions.read().ok().context("lock poisoned")?;
687 let tool_call_id = {
688 let op_index = self.operation_index.read().ok().context("lock poisoned")?;
689 op_index.get(operation_id).cloned()
690 };
691 let Some(tool_call_id) = tool_call_id else {
692 return Ok(None);
693 };
694 Ok(executions.get(&tool_call_id).cloned())
695 }
696}
697
#[cfg(test)]
mod tests {
    use super::*;
    use agent_sdk_foundation::events::{AgentEvent, AgentEventEnvelope, SequenceCounter};
    use agent_sdk_foundation::llm::Message;
    use agent_sdk_foundation::types::ToolResult;

    /// Wraps a text event `(id, text)` into an envelope stamped by `seq`.
    fn text_event(
        seq: &SequenceCounter,
        id: &'static str,
        text: &'static str,
    ) -> AgentEventEnvelope {
        AgentEventEnvelope::wrap(AgentEvent::text(id, text), seq)
    }

    /// Builds an in-flight execution that starts at the current UTC time.
    fn in_flight(
        tool_call_id: &'static str,
        thread_id: ThreadId,
        tool_name: &'static str,
        label: &'static str,
        input: serde_json::Value,
    ) -> ToolExecution {
        ToolExecution::new_in_flight(
            tool_call_id,
            thread_id,
            tool_name,
            label,
            input,
            time::OffsetDateTime::now_utc(),
        )
    }

    #[tokio::test]
    async fn test_in_memory_message_store() -> Result<()> {
        let store = InMemoryStore::new();
        let thread_id = ThreadId::new();

        assert!(store.get_history(&thread_id).await?.is_empty());

        for message in [Message::user("Hello"), Message::assistant("Hi there!")] {
            store.append(&thread_id, message).await?;
        }

        assert_eq!(store.get_history(&thread_id).await?.len(), 2);
        assert_eq!(store.count(&thread_id).await?, 2);

        store.clear(&thread_id).await?;
        assert!(store.get_history(&thread_id).await?.is_empty());

        Ok(())
    }

    #[tokio::test]
    async fn test_replace_history() -> Result<()> {
        let store = InMemoryStore::new();
        let thread_id = ThreadId::new();

        for message in [
            Message::user("Hello"),
            Message::assistant("Hi there!"),
            Message::user("How are you?"),
        ] {
            store.append(&thread_id, message).await?;
        }
        assert_eq!(store.get_history(&thread_id).await?.len(), 3);

        let compacted = vec![
            Message::user("[Summary] Previous conversation about greetings"),
            Message::assistant("I understand the context. Continuing..."),
        ];
        store.replace_history(&thread_id, compacted).await?;

        assert_eq!(store.get_history(&thread_id).await?.len(), 2);
        Ok(())
    }

    #[tokio::test]
    async fn test_in_memory_state_store() -> Result<()> {
        let store = InMemoryStore::new();
        let thread_id = ThreadId::new();

        assert!(store.load(&thread_id).await?.is_none());

        store.save(&AgentState::new(thread_id.clone())).await?;

        let loaded = store
            .load(&thread_id)
            .await?
            .context("saved state should load")?;
        assert_eq!(loaded.thread_id, thread_id);

        store.delete(&thread_id).await?;
        assert!(store.load(&thread_id).await?.is_none());

        Ok(())
    }

    #[tokio::test]
    async fn test_in_memory_event_store_tracks_turns_and_finish_barrier() -> Result<()> {
        let store = InMemoryEventStore::new();
        let thread_id = ThreadId::new();
        let seq = SequenceCounter::new();

        store
            .append(&thread_id, 1, text_event(&seq, "msg_1", "hello"))
            .await?;
        store
            .append(&thread_id, 2, text_event(&seq, "msg_2", "world"))
            .await?;

        let first = store
            .get_turn(&thread_id, 1)
            .await?
            .context("missing turn 1")?;
        assert_eq!(first.turn, 1);
        assert_eq!(first.events.len(), 1);
        assert!(!first.finished);

        for turn in [1, 2] {
            store.finish_turn(&thread_id, turn).await?;
        }

        let first = store
            .get_turn(&thread_id, 1)
            .await?
            .context("missing finished turn 1")?;
        let second = store
            .get_turn(&thread_id, 2)
            .await?
            .context("missing finished turn 2")?;
        assert!(first.finished);
        assert!(second.finished);

        let turn_numbers: Vec<usize> = store
            .get_turns(&thread_id)
            .await?
            .iter()
            .map(|stored| stored.turn)
            .collect();
        assert_eq!(turn_numbers, [1, 2]);

        Ok(())
    }

    #[tokio::test]
    async fn test_in_memory_event_store_finish_turn_without_events_creates_finished_turn()
    -> Result<()> {
        let store = InMemoryEventStore::new();
        let thread_id = ThreadId::new();

        store.finish_turn(&thread_id, 3).await?;

        let stored = store
            .get_turn(&thread_id, 3)
            .await?
            .context("missing empty finished turn")?;
        assert_eq!(stored.turn, 3);
        assert!(stored.events.is_empty());
        assert!(stored.finished);

        store.clear(&thread_id).await?;
        assert!(store.get_turns(&thread_id).await?.is_empty());

        Ok(())
    }

    #[tokio::test]
    async fn test_in_memory_event_store_rejects_append_after_finish() -> Result<()> {
        let store = InMemoryEventStore::new();
        let thread_id = ThreadId::new();
        let seq = SequenceCounter::new();

        store.finish_turn(&thread_id, 1).await?;

        let error = store
            .append(&thread_id, 1, text_event(&seq, "msg_1", "late"))
            .await
            .expect_err("append after finish should fail");
        assert!(error.to_string().contains("cannot append to finished turn"));

        Ok(())
    }

    #[tokio::test]
    async fn test_in_memory_event_store_rejects_duplicate_finish() -> Result<()> {
        let store = InMemoryEventStore::new();
        let thread_id = ThreadId::new();

        store.finish_turn(&thread_id, 1).await?;

        let error = store
            .finish_turn(&thread_id, 1)
            .await
            .expect_err("duplicate finish should fail");
        assert!(error.to_string().contains("already finished"));

        Ok(())
    }

    #[tokio::test]
    async fn test_execution_store_basic_operations() -> Result<()> {
        let store = InMemoryExecutionStore::new();
        let thread_id = ThreadId::new();

        assert!(store.get_execution("tool_call_123").await?.is_none());

        store
            .record_execution(in_flight(
                "tool_call_123",
                thread_id,
                "my_tool",
                "My Tool",
                serde_json::json!({"param": "value"}),
            ))
            .await?;

        let loaded = store
            .get_execution("tool_call_123")
            .await?
            .context("execution should exist")?;
        assert_eq!(loaded.tool_call_id, "tool_call_123");
        assert_eq!(loaded.tool_name, "my_tool");
        assert!(loaded.is_in_flight());

        Ok(())
    }

    #[tokio::test]
    async fn test_execution_store_complete_execution() -> Result<()> {
        let store = InMemoryExecutionStore::new();
        let thread_id = ThreadId::new();

        let mut execution = in_flight(
            "tool_call_456",
            thread_id,
            "my_tool",
            "My Tool",
            serde_json::json!({}),
        );
        store.record_execution(execution.clone()).await?;

        execution.complete(ToolResult::success("Done!"));
        store.update_execution(execution).await?;

        let loaded = store
            .get_execution("tool_call_456")
            .await?
            .context("execution should exist")?;
        assert!(loaded.is_completed());
        assert!(loaded.result.as_ref().is_some_and(|result| result.success));

        Ok(())
    }

    #[tokio::test]
    async fn test_execution_store_operation_id_lookup() -> Result<()> {
        let store = InMemoryExecutionStore::new();
        let thread_id = ThreadId::new();

        let mut execution = in_flight(
            "tool_call_789",
            thread_id,
            "async_tool",
            "Async Tool",
            serde_json::json!({}),
        );
        execution.set_operation_id("op_abc123");
        store.record_execution(execution.clone()).await?;
        store.update_execution(execution).await?;

        let loaded = store
            .get_execution_by_operation_id("op_abc123")
            .await?
            .context("execution should exist")?;
        assert_eq!(loaded.tool_call_id, "tool_call_789");
        assert_eq!(loaded.operation_id.as_deref(), Some("op_abc123"));

        assert!(
            store
                .get_execution_by_operation_id("nonexistent")
                .await?
                .is_none()
        );

        Ok(())
    }

    #[tokio::test]
    async fn in_memory_store_clone_shares_history() -> Result<()> {
        let original = InMemoryStore::new();
        let shared = original.clone();
        let thread_id = ThreadId::new();

        original.append(&thread_id, Message::user("hello")).await?;

        assert_eq!(
            shared.get_history(&thread_id).await?.len(),
            1,
            "clone must observe appends via the original"
        );
        Ok(())
    }

    #[tokio::test]
    async fn arc_store_blanket_impls_forward() -> Result<()> {
        let store: Arc<InMemoryStore> = Arc::new(InMemoryStore::new());
        let thread_id = ThreadId::new();

        MessageStore::append(&store, &thread_id, Message::user("hi")).await?;
        assert_eq!(MessageStore::count(&store, &thread_id).await?, 1);

        StateStore::save(&store, &AgentState::new(thread_id.clone())).await?;
        assert!(StateStore::load(&store, &thread_id).await?.is_some());

        assert_eq!(store.get_history(&thread_id).await?.len(), 1);
        Ok(())
    }

    #[tokio::test]
    async fn event_count_and_get_events_since_are_incremental() -> Result<()> {
        let store = InMemoryEventStore::new();
        let thread_id = ThreadId::new();
        let seq = SequenceCounter::new();

        assert_eq!(store.event_count(&thread_id).await?, 0);

        for (turn, id, text) in [(1, "m1", "a"), (1, "m2", "b"), (2, "m3", "c")] {
            store
                .append(&thread_id, turn, text_event(&seq, id, text))
                .await?;
        }

        assert_eq!(store.event_count(&thread_id).await?, 3);

        let tail = store.get_events_since(&thread_id, 1).await?;
        assert_eq!(tail.len(), 2, "should skip the first event");
        assert_eq!(store.get_events(&thread_id).await?.len(), 3);

        Ok(())
    }

    #[tokio::test]
    async fn record_execution_indexes_operation_id_immediately() -> Result<()> {
        let store = InMemoryExecutionStore::new();

        let mut execution = in_flight(
            "call_1",
            ThreadId::new(),
            "async_tool",
            "Async Tool",
            serde_json::json!({}),
        );
        execution.set_operation_id("op_1");
        store.record_execution(execution).await?;

        let resolved = store
            .get_execution_by_operation_id("op_1")
            .await?
            .context("write-ahead operation_id must resolve")?;
        assert_eq!(resolved.tool_call_id, "call_1");

        Ok(())
    }

    #[tokio::test]
    async fn update_execution_removes_stale_operation_id() -> Result<()> {
        let store = InMemoryExecutionStore::new();

        let mut execution = in_flight(
            "call_2",
            ThreadId::new(),
            "async_tool",
            "Async Tool",
            serde_json::json!({}),
        );
        execution.set_operation_id("op_old");
        store.record_execution(execution.clone()).await?;

        execution.set_operation_id("op_new");
        store.update_execution(execution).await?;

        assert!(
            store
                .get_execution_by_operation_id("op_old")
                .await?
                .is_none(),
            "superseded operation_id must stop resolving"
        );
        let resolved = store
            .get_execution_by_operation_id("op_new")
            .await?
            .context("new operation_id must resolve")?;
        assert_eq!(resolved.tool_call_id, "call_2");

        Ok(())
    }

    #[tokio::test]
    async fn observing_event_store_invokes_callback_and_delegates() -> Result<()> {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let seen = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&seen);
        let store = ObservingEventStore::new(InMemoryEventStore::new(), move |_envelope| {
            counter.fetch_add(1, Ordering::SeqCst);
        });
        let thread_id = ThreadId::new();
        let seq = SequenceCounter::new();

        for (id, text) in [("m1", "hi"), ("m2", "yo")] {
            store
                .append(&thread_id, 1, text_event(&seq, id, text))
                .await?;
        }

        assert_eq!(seen.load(Ordering::SeqCst), 2, "observer runs per append");
        assert_eq!(store.get_events(&thread_id).await?.len(), 2);
        assert_eq!(store.inner().get_events(&thread_id).await?.len(), 2);

        Ok(())
    }
}