Skip to main content

agent_sdk_tools/
stores.rs

1//! Storage traits for message history, agent state, and event persistence.
2//!
3//! The SDK uses three storage abstractions:
4//!
5//! - [`MessageStore`] - Stores conversation message history per thread
6//! - [`StateStore`] - Stores agent state checkpoints for recovery
7//! - [`EventStore`] - Stores turn-scoped event envelopes for retrieval
8//!
9//! # Built-in Implementation
10//!
11//! [`InMemoryStore`] implements the message/state traits and is suitable for
12//! testing and single-process deployments. [`InMemoryEventStore`] provides the
13//! corresponding in-memory event journal. For production, implement custom
14//! stores backed by your database (e.g., Postgres, Redis).
15
16use agent_sdk_foundation::events::AgentEventEnvelope;
17use agent_sdk_foundation::llm;
18use agent_sdk_foundation::types::{AgentState, ThreadId, ToolExecution};
19use anyhow::{Context, Result};
20use async_trait::async_trait;
21use std::collections::{BTreeMap, HashMap};
22use std::sync::Arc;
23use std::sync::RwLock;
24use tokio::sync::RwLock as AsyncRwLock;
25
26/// Trait for storing and retrieving conversation messages.
27/// Implement this trait to persist messages to your storage backend.
28#[async_trait]
29pub trait MessageStore: Send + Sync {
30    /// Append a message to the thread's history
31    ///
32    /// # Errors
33    /// Returns an error if the message cannot be stored.
34    async fn append(&self, thread_id: &ThreadId, message: llm::Message) -> Result<()>;
35
36    /// Get all messages for a thread
37    ///
38    /// # Errors
39    /// Returns an error if the history cannot be retrieved.
40    async fn get_history(&self, thread_id: &ThreadId) -> Result<Vec<llm::Message>>;
41
42    /// Clear all messages for a thread
43    ///
44    /// # Errors
45    /// Returns an error if the messages cannot be cleared.
46    async fn clear(&self, thread_id: &ThreadId) -> Result<()>;
47
48    /// Get the message count for a thread
49    ///
50    /// # Errors
51    /// Returns an error if the count cannot be retrieved.
52    async fn count(&self, thread_id: &ThreadId) -> Result<usize> {
53        Ok(self.get_history(thread_id).await?.len())
54    }
55
56    /// Replace the entire message history for a thread.
57    /// Used for context compaction to replace old messages with a summary.
58    ///
59    /// # Errors
60    /// Returns an error if the history cannot be replaced.
61    async fn replace_history(
62        &self,
63        thread_id: &ThreadId,
64        messages: Vec<llm::Message>,
65    ) -> Result<()>;
66}
67
68/// Trait for storing agent state checkpoints.
69/// Implement this to enable conversation recovery and resume.
70#[async_trait]
71pub trait StateStore: Send + Sync {
72    /// Save the current agent state
73    ///
74    /// # Errors
75    /// Returns an error if the state cannot be saved.
76    async fn save(&self, state: &AgentState) -> Result<()>;
77
78    /// Load the most recent state for a thread
79    ///
80    /// # Errors
81    /// Returns an error if the state cannot be loaded.
82    async fn load(&self, thread_id: &ThreadId) -> Result<Option<AgentState>>;
83
84    /// Delete state for a thread
85    ///
86    /// # Errors
87    /// Returns an error if the state cannot be deleted.
88    async fn delete(&self, thread_id: &ThreadId) -> Result<()>;
89}
90
91/// Stored event data for a single turn.
92#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
93pub struct StoredTurnEvents {
94    /// Turn number (1-based once execution starts).
95    pub turn: usize,
96    /// Events emitted for this turn.
97    pub events: Vec<AgentEventEnvelope>,
98    /// Whether `finish_turn()` has completed for this turn.
99    pub finished: bool,
100}
101
102/// Trait for storing and retrieving turn-scoped event streams.
103///
104/// Event writes are split into two phases:
105/// 1. [`append`](EventStore::append) records individual envelopes
106/// 2. [`finish_turn`](EventStore::finish_turn) marks the authoritative close barrier
107#[async_trait]
108pub trait EventStore: Send + Sync {
109    /// Append an event envelope for the given thread and turn.
110    ///
111    /// # Errors
112    /// Returns an error if the event cannot be persisted.
113    async fn append(
114        &self,
115        thread_id: &ThreadId,
116        turn: usize,
117        envelope: AgentEventEnvelope,
118    ) -> Result<()>;
119
120    /// Mark the given turn as finished and flush any buffered writes.
121    ///
122    /// # Errors
123    /// Returns an error if the store cannot durably close the turn.
124    async fn finish_turn(&self, thread_id: &ThreadId, turn: usize) -> Result<()>;
125
126    /// Retrieve the stored data for a single turn.
127    ///
128    /// # Errors
129    /// Returns an error if the turn cannot be retrieved.
130    async fn get_turn(&self, thread_id: &ThreadId, turn: usize)
131    -> Result<Option<StoredTurnEvents>>;
132
133    /// Retrieve all stored turns for the given thread in ascending turn order.
134    ///
135    /// # Errors
136    /// Returns an error if the thread history cannot be retrieved.
137    async fn get_turns(&self, thread_id: &ThreadId) -> Result<Vec<StoredTurnEvents>>;
138
139    /// Retrieve all event envelopes for the given thread across every stored turn.
140    ///
141    /// # Errors
142    /// Returns an error if the thread history cannot be retrieved.
143    async fn get_events(&self, thread_id: &ThreadId) -> Result<Vec<AgentEventEnvelope>> {
144        let turns = self.get_turns(thread_id).await?;
145        Ok(turns
146            .into_iter()
147            .flat_map(|turn| turn.events.into_iter())
148            .collect())
149    }
150
151    /// Count the stored events for `thread_id` without materializing them.
152    ///
153    /// The default falls back to the length of
154    /// [`get_events`](EventStore::get_events) (which clones the whole history);
155    /// stores that can answer cheaply should override this. Callers that only
156    /// need a baseline count — e.g. to read just the new events after a turn —
157    /// should prefer this over `get_events(..).len()`.
158    ///
159    /// # Errors
160    /// Returns an error if the count cannot be retrieved.
161    async fn event_count(&self, thread_id: &ThreadId) -> Result<usize> {
162        Ok(self.get_events(thread_id).await?.len())
163    }
164
165    /// Retrieve event envelopes for `thread_id` from `offset` onward, in overall
166    /// append order, skipping the earlier ones.
167    ///
168    /// Lets incremental readers avoid re-cloning the whole history each call.
169    /// The default slices [`get_events`](EventStore::get_events); stores with a
170    /// cheaper access path should override.
171    ///
172    /// # Errors
173    /// Returns an error if the events cannot be retrieved.
174    async fn get_events_since(
175        &self,
176        thread_id: &ThreadId,
177        offset: usize,
178    ) -> Result<Vec<AgentEventEnvelope>> {
179        Ok(self
180            .get_events(thread_id)
181            .await?
182            .into_iter()
183            .skip(offset)
184            .collect())
185    }
186
187    /// Clear all events for the given thread.
188    ///
189    /// # Errors
190    /// Returns an error if the thread cannot be cleared.
191    async fn clear(&self, thread_id: &ThreadId) -> Result<()>;
192}
193
194/// Store for tracking tool executions (idempotency).
195///
196/// This trait enables write-ahead execution tracking to ensure tool idempotency.
197/// The pattern is:
198/// 1. Record execution intent BEFORE calling the tool (`record_execution`)
199/// 2. Update with result AFTER completion (`update_execution`)
200/// 3. On retry, check if execution exists and return cached result
201#[async_trait]
202pub trait ToolExecutionStore: Send + Sync {
203    /// Get an execution by `tool_call_id`.
204    ///
205    /// # Errors
206    /// Returns an error if the execution cannot be retrieved.
207    async fn get_execution(&self, tool_call_id: &str) -> Result<Option<ToolExecution>>;
208
209    /// Record a new execution (write-ahead, before calling tool).
210    ///
211    /// # Errors
212    /// Returns an error if the execution cannot be recorded.
213    async fn record_execution(&self, execution: ToolExecution) -> Result<()>;
214
215    /// Update an existing execution (after completion or to set `operation_id`).
216    ///
217    /// # Errors
218    /// Returns an error if the execution cannot be updated.
219    async fn update_execution(&self, execution: ToolExecution) -> Result<()>;
220
221    /// Get execution by `operation_id` (for async tool resume).
222    ///
223    /// # Errors
224    /// Returns an error if the execution cannot be retrieved.
225    async fn get_execution_by_operation_id(
226        &self,
227        operation_id: &str,
228    ) -> Result<Option<ToolExecution>>;
229}
230
231#[derive(Default)]
232struct InMemoryStoreInner {
233    messages: RwLock<HashMap<String, Vec<llm::Message>>>,
234    states: RwLock<HashMap<String, AgentState>>,
235}
236
237/// In-memory implementation of `MessageStore` and `StateStore`.
238/// Useful for testing and simple use cases.
239///
240/// Cloning shares the same underlying message/state maps (mirroring
241/// [`InMemoryEventStore`]'s shared-journal semantics). This matters because the
242/// agent builder takes its stores **by value**: hand the builder a clone and
243/// keep the original, and the kept handle still observes everything the agent
244/// records. Without shared handles, history written through the builder's copy
245/// would be permanently unreachable to the caller.
246#[derive(Clone, Default)]
247pub struct InMemoryStore {
248    inner: Arc<InMemoryStoreInner>,
249}
250
251impl InMemoryStore {
252    #[must_use]
253    pub fn new() -> Self {
254        Self::default()
255    }
256}
257
258#[derive(Default)]
259struct InMemoryEventStoreInner {
260    turns: AsyncRwLock<HashMap<String, BTreeMap<usize, StoredTurnEvents>>>,
261}
262
263/// In-memory implementation of [`EventStore`].
264///
265/// Cloning this type shares the same underlying event journal.
266#[derive(Clone, Default)]
267pub struct InMemoryEventStore {
268    inner: Arc<InMemoryEventStoreInner>,
269}
270
271impl InMemoryEventStore {
272    #[must_use]
273    pub fn new() -> Self {
274        Self::default()
275    }
276
277    async fn update_turn(
278        &self,
279        thread_id: &ThreadId,
280        turn: usize,
281        update: impl FnOnce(&mut StoredTurnEvents) -> Result<()>,
282    ) -> Result<()> {
283        let mut turns = self.inner.turns.write().await;
284        let stored_turn = turns
285            .entry(thread_id.0.clone())
286            .or_default()
287            .entry(turn)
288            .or_insert_with(|| StoredTurnEvents {
289                turn,
290                events: Vec::new(),
291                finished: false,
292            });
293        let result = update(stored_turn);
294        drop(turns);
295        result
296    }
297}
298
299#[async_trait]
300impl MessageStore for InMemoryStore {
301    async fn append(&self, thread_id: &ThreadId, message: llm::Message) -> Result<()> {
302        self.inner
303            .messages
304            .write()
305            .ok()
306            .context("lock poisoned")?
307            .entry(thread_id.0.clone())
308            .or_default()
309            .push(message);
310        Ok(())
311    }
312
313    async fn get_history(&self, thread_id: &ThreadId) -> Result<Vec<llm::Message>> {
314        let messages = self.inner.messages.read().ok().context("lock poisoned")?;
315        Ok(messages.get(&thread_id.0).cloned().unwrap_or_default())
316    }
317
318    async fn clear(&self, thread_id: &ThreadId) -> Result<()> {
319        self.inner
320            .messages
321            .write()
322            .ok()
323            .context("lock poisoned")?
324            .remove(&thread_id.0);
325        Ok(())
326    }
327
328    async fn replace_history(
329        &self,
330        thread_id: &ThreadId,
331        messages: Vec<llm::Message>,
332    ) -> Result<()> {
333        self.inner
334            .messages
335            .write()
336            .ok()
337            .context("lock poisoned")?
338            .insert(thread_id.0.clone(), messages);
339        Ok(())
340    }
341}
342
343#[async_trait]
344impl StateStore for InMemoryStore {
345    async fn save(&self, state: &AgentState) -> Result<()> {
346        self.inner
347            .states
348            .write()
349            .ok()
350            .context("lock poisoned")?
351            .insert(state.thread_id.0.clone(), state.clone());
352        Ok(())
353    }
354
355    async fn load(&self, thread_id: &ThreadId) -> Result<Option<AgentState>> {
356        let states = self.inner.states.read().ok().context("lock poisoned")?;
357        Ok(states.get(&thread_id.0).cloned())
358    }
359
360    async fn delete(&self, thread_id: &ThreadId) -> Result<()> {
361        self.inner
362            .states
363            .write()
364            .ok()
365            .context("lock poisoned")?
366            .remove(&thread_id.0);
367        Ok(())
368    }
369}
370
371// Blanket impls so a shared `Arc<Store>` is itself a `MessageStore` /
372// `StateStore`. This lets callers keep a readable handle after handing the
373// store to the agent builder (which takes stores by value), without forcing
374// every store type to be `Clone`.
375#[async_trait]
376impl<T: MessageStore + ?Sized> MessageStore for Arc<T> {
377    async fn append(&self, thread_id: &ThreadId, message: llm::Message) -> Result<()> {
378        (**self).append(thread_id, message).await
379    }
380
381    async fn get_history(&self, thread_id: &ThreadId) -> Result<Vec<llm::Message>> {
382        (**self).get_history(thread_id).await
383    }
384
385    async fn clear(&self, thread_id: &ThreadId) -> Result<()> {
386        (**self).clear(thread_id).await
387    }
388
389    async fn count(&self, thread_id: &ThreadId) -> Result<usize> {
390        (**self).count(thread_id).await
391    }
392
393    async fn replace_history(
394        &self,
395        thread_id: &ThreadId,
396        messages: Vec<llm::Message>,
397    ) -> Result<()> {
398        (**self).replace_history(thread_id, messages).await
399    }
400}
401
402#[async_trait]
403impl<T: StateStore + ?Sized> StateStore for Arc<T> {
404    async fn save(&self, state: &AgentState) -> Result<()> {
405        (**self).save(state).await
406    }
407
408    async fn load(&self, thread_id: &ThreadId) -> Result<Option<AgentState>> {
409        (**self).load(thread_id).await
410    }
411
412    async fn delete(&self, thread_id: &ThreadId) -> Result<()> {
413        (**self).delete(thread_id).await
414    }
415}
416
417#[async_trait]
418impl EventStore for InMemoryEventStore {
419    async fn append(
420        &self,
421        thread_id: &ThreadId,
422        turn: usize,
423        envelope: AgentEventEnvelope,
424    ) -> Result<()> {
425        self.update_turn(thread_id, turn, |stored_turn| {
426            anyhow::ensure!(
427                !stored_turn.finished,
428                "cannot append to finished turn {turn}"
429            );
430            stored_turn.events.push(envelope);
431            Ok(())
432        })
433        .await
434    }
435
436    async fn finish_turn(&self, thread_id: &ThreadId, turn: usize) -> Result<()> {
437        self.update_turn(thread_id, turn, |stored_turn| {
438            anyhow::ensure!(!stored_turn.finished, "turn {turn} is already finished");
439            stored_turn.finished = true;
440            Ok(())
441        })
442        .await
443    }
444
445    async fn get_turn(
446        &self,
447        thread_id: &ThreadId,
448        turn: usize,
449    ) -> Result<Option<StoredTurnEvents>> {
450        let turns = self.inner.turns.read().await;
451        Ok(turns
452            .get(&thread_id.0)
453            .and_then(|thread_turns| thread_turns.get(&turn).cloned()))
454    }
455
456    async fn get_turns(&self, thread_id: &ThreadId) -> Result<Vec<StoredTurnEvents>> {
457        let turns = self.inner.turns.read().await;
458        Ok(turns
459            .get(&thread_id.0)
460            .map(|thread_turns| thread_turns.values().cloned().collect())
461            .unwrap_or_default())
462    }
463
464    async fn event_count(&self, thread_id: &ThreadId) -> Result<usize> {
465        // Sum the per-turn lengths under the read lock — no envelope is cloned.
466        let turns = self.inner.turns.read().await;
467        Ok(turns.get(&thread_id.0).map_or(0, |thread_turns| {
468            thread_turns.values().map(|turn| turn.events.len()).sum()
469        }))
470    }
471
472    async fn get_events_since(
473        &self,
474        thread_id: &ThreadId,
475        offset: usize,
476    ) -> Result<Vec<AgentEventEnvelope>> {
477        // Only the requested tail is cloned, not the whole history.
478        let turns = self.inner.turns.read().await;
479        Ok(turns
480            .get(&thread_id.0)
481            .map(|thread_turns| {
482                thread_turns
483                    .values()
484                    .flat_map(|turn| turn.events.iter())
485                    .skip(offset)
486                    .cloned()
487                    .collect()
488            })
489            .unwrap_or_default())
490    }
491
492    async fn clear(&self, thread_id: &ThreadId) -> Result<()> {
493        {
494            let mut turns = self.inner.turns.write().await;
495            turns.remove(&thread_id.0);
496        }
497        Ok(())
498    }
499}
500
501/// An [`EventStore`] decorator that invokes a callback on every appended
502/// envelope, then delegates all storage to an inner store.
503///
504/// This is the reusable, blessed way to "stream to stdout" (or to any live
505/// observer) with the SDK: the agent loop writes every [`AgentEventEnvelope`]
506/// through the configured event store, so wrapping a store lets you watch
507/// events as they happen — printing `TextDelta`s, forwarding to a UI channel —
508/// without hand-rolling the full five-method [`EventStore`] surface or wiring an
509/// in-process channel. The callback runs before the inner store records the
510/// envelope.
511///
512/// # Example
513///
514/// ```
515/// use agent_sdk_tools::stores::{InMemoryEventStore, ObservingEventStore};
516/// use agent_sdk_foundation::events::AgentEvent;
517///
518/// let _store = ObservingEventStore::new(InMemoryEventStore::new(), |envelope| {
519///     if let AgentEvent::TextDelta { delta, .. } = &envelope.event {
520///         print!("{delta}");
521///     }
522/// });
523/// ```
524pub struct ObservingEventStore<S, F> {
525    inner: S,
526    observer: F,
527}
528
529impl<S, F> ObservingEventStore<S, F>
530where
531    S: EventStore,
532    F: Fn(&AgentEventEnvelope) + Send + Sync,
533{
534    /// Wrap `inner`, calling `observer` on every appended envelope before it is
535    /// persisted.
536    #[must_use]
537    pub const fn new(inner: S, observer: F) -> Self {
538        Self { inner, observer }
539    }
540
541    /// Borrow the wrapped inner store (e.g. to read back persisted history).
542    #[must_use]
543    pub const fn inner(&self) -> &S {
544        &self.inner
545    }
546}
547
548#[async_trait]
549impl<S, F> EventStore for ObservingEventStore<S, F>
550where
551    S: EventStore,
552    F: Fn(&AgentEventEnvelope) + Send + Sync,
553{
554    async fn append(
555        &self,
556        thread_id: &ThreadId,
557        turn: usize,
558        envelope: AgentEventEnvelope,
559    ) -> Result<()> {
560        (self.observer)(&envelope);
561        self.inner.append(thread_id, turn, envelope).await
562    }
563
564    async fn finish_turn(&self, thread_id: &ThreadId, turn: usize) -> Result<()> {
565        self.inner.finish_turn(thread_id, turn).await
566    }
567
568    async fn get_turn(
569        &self,
570        thread_id: &ThreadId,
571        turn: usize,
572    ) -> Result<Option<StoredTurnEvents>> {
573        self.inner.get_turn(thread_id, turn).await
574    }
575
576    async fn get_turns(&self, thread_id: &ThreadId) -> Result<Vec<StoredTurnEvents>> {
577        self.inner.get_turns(thread_id).await
578    }
579
580    async fn get_events(&self, thread_id: &ThreadId) -> Result<Vec<AgentEventEnvelope>> {
581        self.inner.get_events(thread_id).await
582    }
583
584    async fn event_count(&self, thread_id: &ThreadId) -> Result<usize> {
585        self.inner.event_count(thread_id).await
586    }
587
588    async fn get_events_since(
589        &self,
590        thread_id: &ThreadId,
591        offset: usize,
592    ) -> Result<Vec<AgentEventEnvelope>> {
593        self.inner.get_events_since(thread_id, offset).await
594    }
595
596    async fn clear(&self, thread_id: &ThreadId) -> Result<()> {
597        self.inner.clear(thread_id).await
598    }
599}
600
601/// In-memory implementation of `ToolExecutionStore`.
602///
603/// Useful for testing and simple use cases where durability is not required.
604/// For production, implement a custom store backed by a database.
605#[derive(Default)]
606pub struct InMemoryExecutionStore {
607    /// Executions indexed by `tool_call_id`
608    executions: RwLock<HashMap<String, ToolExecution>>,
609    /// Index from `operation_id` to `tool_call_id` for async tool lookup
610    operation_index: RwLock<HashMap<String, String>>,
611}
612
613impl InMemoryExecutionStore {
614    #[must_use]
615    pub fn new() -> Self {
616        Self::default()
617    }
618}
619
620#[async_trait]
621impl ToolExecutionStore for InMemoryExecutionStore {
622    async fn get_execution(&self, tool_call_id: &str) -> Result<Option<ToolExecution>> {
623        let executions = self.executions.read().ok().context("lock poisoned")?;
624        Ok(executions.get(tool_call_id).cloned())
625    }
626
627    async fn record_execution(&self, execution: ToolExecution) -> Result<()> {
628        let tool_call_id = execution.tool_call_id.clone();
629        let operation_id = execution.operation_id.clone();
630
631        // Hold the executions write lock for the whole insert. Readers acquire
632        // executions first (the global executions -> operation_index lock
633        // order), so they cannot observe a half-written record. Indexing the
634        // operation_id here (not only in `update_execution`) means a write-ahead
635        // record is resolvable by `get_execution_by_operation_id` immediately.
636        let mut executions = self.executions.write().ok().context("lock poisoned")?;
637        if let Some(op_id) = operation_id {
638            self.operation_index
639                .write()
640                .ok()
641                .context("lock poisoned")?
642                .insert(op_id, tool_call_id.clone());
643        }
644        executions.insert(tool_call_id, execution);
645        drop(executions);
646        Ok(())
647    }
648
649    async fn update_execution(&self, execution: ToolExecution) -> Result<()> {
650        let tool_call_id = execution.tool_call_id.clone();
651        let new_operation_id = execution.operation_id.clone();
652
653        // Hold the executions write lock across the whole update so a concurrent
654        // reader (which gates on executions first) can never observe the new
655        // index entry against a stale execution.
656        let mut executions = self.executions.write().ok().context("lock poisoned")?;
657
658        // Drop a superseded operation_id index entry when the id changes, so a
659        // stale id stops resolving instead of pointing forever at this call.
660        let stale_op_id = executions
661            .get(&tool_call_id)
662            .and_then(|prev| prev.operation_id.clone())
663            .filter(|prev| Some(prev) != new_operation_id.as_ref());
664        if stale_op_id.is_some() || new_operation_id.is_some() {
665            let mut op_index = self.operation_index.write().ok().context("lock poisoned")?;
666            if let Some(stale) = stale_op_id {
667                op_index.remove(&stale);
668            }
669            if let Some(op_id) = new_operation_id {
670                op_index.insert(op_id, tool_call_id.clone());
671            }
672        }
673        executions.insert(tool_call_id, execution);
674        drop(executions);
675        Ok(())
676    }
677
678    async fn get_execution_by_operation_id(
679        &self,
680        operation_id: &str,
681    ) -> Result<Option<ToolExecution>> {
682        // Acquire executions first (the global executions -> operation_index
683        // lock order) and hold it while resolving the id, so this reader can
684        // neither deadlock against nor observe a partial write from a concurrent
685        // record/update.
686        let executions = self.executions.read().ok().context("lock poisoned")?;
687        let tool_call_id = {
688            let op_index = self.operation_index.read().ok().context("lock poisoned")?;
689            op_index.get(operation_id).cloned()
690        };
691        let Some(tool_call_id) = tool_call_id else {
692            return Ok(None);
693        };
694        Ok(executions.get(&tool_call_id).cloned())
695    }
696}
697
698#[cfg(test)]
699mod tests {
700    use super::*;
701    use agent_sdk_foundation::events::{AgentEvent, AgentEventEnvelope, SequenceCounter};
702    use agent_sdk_foundation::llm::Message;
703    use agent_sdk_foundation::types::ToolResult;
704
705    #[tokio::test]
706    async fn test_in_memory_message_store() -> Result<()> {
707        let store = InMemoryStore::new();
708        let thread_id = ThreadId::new();
709
710        // Initially empty
711        let history = store.get_history(&thread_id).await?;
712        assert!(history.is_empty());
713
714        // Add messages
715        store.append(&thread_id, Message::user("Hello")).await?;
716        store
717            .append(&thread_id, Message::assistant("Hi there!"))
718            .await?;
719
720        // Retrieve messages
721        let history = store.get_history(&thread_id).await?;
722        assert_eq!(history.len(), 2);
723
724        // Count
725        let count = store.count(&thread_id).await?;
726        assert_eq!(count, 2);
727
728        // Clear
729        store.clear(&thread_id).await?;
730        let history = store.get_history(&thread_id).await?;
731        assert!(history.is_empty());
732
733        Ok(())
734    }
735
736    #[tokio::test]
737    async fn test_replace_history() -> Result<()> {
738        let store = InMemoryStore::new();
739        let thread_id = ThreadId::new();
740
741        // Add some messages
742        store.append(&thread_id, Message::user("Hello")).await?;
743        store
744            .append(&thread_id, Message::assistant("Hi there!"))
745            .await?;
746        store
747            .append(&thread_id, Message::user("How are you?"))
748            .await?;
749
750        // Verify original messages
751        let history = store.get_history(&thread_id).await?;
752        assert_eq!(history.len(), 3);
753
754        // Replace with compacted history
755        let new_history = vec![
756            Message::user("[Summary] Previous conversation about greetings"),
757            Message::assistant("I understand the context. Continuing..."),
758        ];
759        store.replace_history(&thread_id, new_history).await?;
760
761        // Verify replaced history
762        let history = store.get_history(&thread_id).await?;
763        assert_eq!(history.len(), 2);
764
765        Ok(())
766    }
767
768    #[tokio::test]
769    async fn test_in_memory_state_store() -> Result<()> {
770        let store = InMemoryStore::new();
771        let thread_id = ThreadId::new();
772
773        // Initially none
774        let state = store.load(&thread_id).await?;
775        assert!(state.is_none());
776
777        // Save state
778        let state = AgentState::new(thread_id.clone());
779        store.save(&state).await?;
780
781        // Load state
782        let loaded = store.load(&thread_id).await?;
783        assert!(loaded.is_some());
784        if let Some(loaded_state) = loaded {
785            assert_eq!(loaded_state.thread_id, thread_id);
786        }
787
788        // Delete state
789        store.delete(&thread_id).await?;
790        let state = store.load(&thread_id).await?;
791        assert!(state.is_none());
792
793        Ok(())
794    }
795
796    #[tokio::test]
797    async fn test_in_memory_event_store_tracks_turns_and_finish_barrier() -> Result<()> {
798        let store = InMemoryEventStore::new();
799        let thread_id = ThreadId::new();
800        let seq = SequenceCounter::new();
801
802        store
803            .append(
804                &thread_id,
805                1,
806                AgentEventEnvelope::wrap(AgentEvent::text("msg_1", "hello"), &seq),
807            )
808            .await?;
809        store
810            .append(
811                &thread_id,
812                2,
813                AgentEventEnvelope::wrap(AgentEvent::text("msg_2", "world"), &seq),
814            )
815            .await?;
816
817        let turn_1 = store
818            .get_turn(&thread_id, 1)
819            .await?
820            .context("missing turn 1")?;
821        assert_eq!(turn_1.turn, 1);
822        assert_eq!(turn_1.events.len(), 1);
823        assert!(!turn_1.finished);
824
825        store.finish_turn(&thread_id, 1).await?;
826        store.finish_turn(&thread_id, 2).await?;
827
828        let turn_1 = store
829            .get_turn(&thread_id, 1)
830            .await?
831            .context("missing finished turn 1")?;
832        let turn_2 = store
833            .get_turn(&thread_id, 2)
834            .await?
835            .context("missing finished turn 2")?;
836        assert!(turn_1.finished);
837        assert!(turn_2.finished);
838
839        let turns = store.get_turns(&thread_id).await?;
840        assert_eq!(turns.len(), 2);
841        assert_eq!(turns[0].turn, 1);
842        assert_eq!(turns[1].turn, 2);
843
844        Ok(())
845    }
846
847    #[tokio::test]
848    async fn test_in_memory_event_store_finish_turn_without_events_creates_finished_turn()
849    -> Result<()> {
850        let store = InMemoryEventStore::new();
851        let thread_id = ThreadId::new();
852
853        store.finish_turn(&thread_id, 3).await?;
854
855        let turn = store
856            .get_turn(&thread_id, 3)
857            .await?
858            .context("missing empty finished turn")?;
859        assert_eq!(turn.turn, 3);
860        assert!(turn.events.is_empty());
861        assert!(turn.finished);
862
863        store.clear(&thread_id).await?;
864        assert!(store.get_turns(&thread_id).await?.is_empty());
865
866        Ok(())
867    }
868
869    #[tokio::test]
870    async fn test_in_memory_event_store_rejects_append_after_finish() -> Result<()> {
871        let store = InMemoryEventStore::new();
872        let thread_id = ThreadId::new();
873        let seq = SequenceCounter::new();
874
875        store.finish_turn(&thread_id, 1).await?;
876
877        let error = store
878            .append(
879                &thread_id,
880                1,
881                AgentEventEnvelope::wrap(AgentEvent::text("msg_1", "late"), &seq),
882            )
883            .await
884            .expect_err("append after finish should fail");
885
886        assert!(error.to_string().contains("cannot append to finished turn"));
887        Ok(())
888    }
889
890    #[tokio::test]
891    async fn test_in_memory_event_store_rejects_duplicate_finish() -> Result<()> {
892        let store = InMemoryEventStore::new();
893        let thread_id = ThreadId::new();
894
895        store.finish_turn(&thread_id, 1).await?;
896
897        let error = store
898            .finish_turn(&thread_id, 1)
899            .await
900            .expect_err("duplicate finish should fail");
901
902        assert!(error.to_string().contains("already finished"));
903        Ok(())
904    }
905
906    #[tokio::test]
907    async fn test_execution_store_basic_operations() -> Result<()> {
908        let store = InMemoryExecutionStore::new();
909        let thread_id = ThreadId::new();
910
911        // Initially none
912        let execution = store.get_execution("tool_call_123").await?;
913        assert!(execution.is_none());
914
915        // Record execution
916        let execution = ToolExecution::new_in_flight(
917            "tool_call_123",
918            thread_id.clone(),
919            "my_tool",
920            "My Tool",
921            serde_json::json!({"param": "value"}),
922            time::OffsetDateTime::now_utc(),
923        );
924        store.record_execution(execution).await?;
925
926        // Retrieve execution
927        let loaded = store.get_execution("tool_call_123").await?;
928        assert!(loaded.is_some());
929        let loaded = loaded.expect("execution should exist");
930        assert_eq!(loaded.tool_call_id, "tool_call_123");
931        assert_eq!(loaded.tool_name, "my_tool");
932        assert!(loaded.is_in_flight());
933
934        Ok(())
935    }
936
937    #[tokio::test]
938    async fn test_execution_store_complete_execution() -> Result<()> {
939        let store = InMemoryExecutionStore::new();
940        let thread_id = ThreadId::new();
941
942        // Record in-flight execution
943        let mut execution = ToolExecution::new_in_flight(
944            "tool_call_456",
945            thread_id.clone(),
946            "my_tool",
947            "My Tool",
948            serde_json::json!({}),
949            time::OffsetDateTime::now_utc(),
950        );
951        store.record_execution(execution.clone()).await?;
952
953        // Complete the execution
954        execution.complete(ToolResult::success("Done!"));
955        store.update_execution(execution).await?;
956
957        // Verify it's completed
958        let loaded = store.get_execution("tool_call_456").await?;
959        let loaded = loaded.expect("execution should exist");
960        assert!(loaded.is_completed());
961        assert!(loaded.result.is_some());
962        assert!(loaded.result.as_ref().is_some_and(|r| r.success));
963
964        Ok(())
965    }
966
967    #[tokio::test]
968    async fn test_execution_store_operation_id_lookup() -> Result<()> {
969        let store = InMemoryExecutionStore::new();
970        let thread_id = ThreadId::new();
971
972        // Record execution with operation_id
973        let mut execution = ToolExecution::new_in_flight(
974            "tool_call_789",
975            thread_id.clone(),
976            "async_tool",
977            "Async Tool",
978            serde_json::json!({}),
979            time::OffsetDateTime::now_utc(),
980        );
981        execution.set_operation_id("op_abc123");
982        store.record_execution(execution.clone()).await?;
983        store.update_execution(execution).await?;
984
985        // Lookup by operation_id
986        let loaded = store.get_execution_by_operation_id("op_abc123").await?;
987        assert!(loaded.is_some());
988        let loaded = loaded.expect("execution should exist");
989        assert_eq!(loaded.tool_call_id, "tool_call_789");
990        assert_eq!(loaded.operation_id, Some("op_abc123".to_string()));
991
992        // Non-existent operation_id
993        let not_found = store.get_execution_by_operation_id("nonexistent").await?;
994        assert!(not_found.is_none());
995
996        Ok(())
997    }
998
999    #[tokio::test]
1000    async fn in_memory_store_clone_shares_history() -> Result<()> {
1001        // A clone handed to the builder shares state with the kept handle, so
1002        // history written by the agent stays reachable to the caller.
1003        let store = InMemoryStore::new();
1004        let handle = store.clone();
1005        let thread_id = ThreadId::new();
1006
1007        store.append(&thread_id, Message::user("hello")).await?;
1008
1009        let history = handle.get_history(&thread_id).await?;
1010        assert_eq!(
1011            history.len(),
1012            1,
1013            "clone must observe appends via the original"
1014        );
1015        Ok(())
1016    }
1017
1018    #[tokio::test]
1019    async fn arc_store_blanket_impls_forward() -> Result<()> {
1020        let store: Arc<InMemoryStore> = Arc::new(InMemoryStore::new());
1021        let thread_id = ThreadId::new();
1022
1023        // `Arc<InMemoryStore>` is itself a `MessageStore` and `StateStore`.
1024        MessageStore::append(&store, &thread_id, Message::user("hi")).await?;
1025        assert_eq!(MessageStore::count(&store, &thread_id).await?, 1);
1026
1027        let state = AgentState::new(thread_id.clone());
1028        StateStore::save(&store, &state).await?;
1029        assert!(StateStore::load(&store, &thread_id).await?.is_some());
1030
1031        // The kept Arc handle still sees everything.
1032        assert_eq!(store.get_history(&thread_id).await?.len(), 1);
1033        Ok(())
1034    }
1035
1036    #[tokio::test]
1037    async fn event_count_and_get_events_since_are_incremental() -> Result<()> {
1038        let store = InMemoryEventStore::new();
1039        let thread_id = ThreadId::new();
1040        let seq = SequenceCounter::new();
1041
1042        assert_eq!(store.event_count(&thread_id).await?, 0);
1043
1044        for (turn, (id, text)) in [(1, ("m1", "a")), (1, ("m2", "b")), (2, ("m3", "c"))] {
1045            store
1046                .append(
1047                    &thread_id,
1048                    turn,
1049                    AgentEventEnvelope::wrap(AgentEvent::text(id, text), &seq),
1050                )
1051                .await?;
1052        }
1053
1054        assert_eq!(store.event_count(&thread_id).await?, 3);
1055
1056        let tail = store.get_events_since(&thread_id, 1).await?;
1057        assert_eq!(tail.len(), 2, "should skip the first event");
1058        // Consistent with the full read.
1059        let all = store.get_events(&thread_id).await?;
1060        assert_eq!(all.len(), 3);
1061        Ok(())
1062    }
1063
1064    #[tokio::test]
1065    async fn record_execution_indexes_operation_id_immediately() -> Result<()> {
1066        let store = InMemoryExecutionStore::new();
1067        let thread_id = ThreadId::new();
1068
1069        let mut execution = ToolExecution::new_in_flight(
1070            "call_1",
1071            thread_id,
1072            "async_tool",
1073            "Async Tool",
1074            serde_json::json!({}),
1075            time::OffsetDateTime::now_utc(),
1076        );
1077        execution.set_operation_id("op_1");
1078        // Write-ahead record only — no `update_execution` call.
1079        store.record_execution(execution).await?;
1080
1081        let loaded = store.get_execution_by_operation_id("op_1").await?;
1082        assert_eq!(
1083            loaded
1084                .context("write-ahead operation_id must resolve")?
1085                .tool_call_id,
1086            "call_1"
1087        );
1088        Ok(())
1089    }
1090
1091    #[tokio::test]
1092    async fn update_execution_removes_stale_operation_id() -> Result<()> {
1093        let store = InMemoryExecutionStore::new();
1094        let thread_id = ThreadId::new();
1095
1096        let mut execution = ToolExecution::new_in_flight(
1097            "call_2",
1098            thread_id,
1099            "async_tool",
1100            "Async Tool",
1101            serde_json::json!({}),
1102            time::OffsetDateTime::now_utc(),
1103        );
1104        execution.set_operation_id("op_old");
1105        store.record_execution(execution.clone()).await?;
1106
1107        // Re-point the execution at a new operation id.
1108        execution.set_operation_id("op_new");
1109        store.update_execution(execution).await?;
1110
1111        assert!(
1112            store
1113                .get_execution_by_operation_id("op_old")
1114                .await?
1115                .is_none(),
1116            "superseded operation_id must stop resolving"
1117        );
1118        let loaded = store.get_execution_by_operation_id("op_new").await?;
1119        assert_eq!(
1120            loaded
1121                .context("new operation_id must resolve")?
1122                .tool_call_id,
1123            "call_2"
1124        );
1125        Ok(())
1126    }
1127
1128    #[tokio::test]
1129    async fn observing_event_store_invokes_callback_and_delegates() -> Result<()> {
1130        use std::sync::atomic::{AtomicUsize, Ordering};
1131
1132        let seen = Arc::new(AtomicUsize::new(0));
1133        let seen_for_cb = Arc::clone(&seen);
1134        let store = ObservingEventStore::new(InMemoryEventStore::new(), move |_envelope| {
1135            seen_for_cb.fetch_add(1, Ordering::SeqCst);
1136        });
1137        let thread_id = ThreadId::new();
1138        let seq = SequenceCounter::new();
1139
1140        store
1141            .append(
1142                &thread_id,
1143                1,
1144                AgentEventEnvelope::wrap(AgentEvent::text("m1", "hi"), &seq),
1145            )
1146            .await?;
1147        store
1148            .append(
1149                &thread_id,
1150                1,
1151                AgentEventEnvelope::wrap(AgentEvent::text("m2", "yo"), &seq),
1152            )
1153            .await?;
1154
1155        assert_eq!(seen.load(Ordering::SeqCst), 2, "observer runs per append");
1156        // Delegation: the inner store actually persisted both events.
1157        assert_eq!(store.get_events(&thread_id).await?.len(), 2);
1158        assert_eq!(store.inner().get_events(&thread_id).await?.len(), 2);
1159        Ok(())
1160    }
1161}