Skip to main content

agent_sdk_tools/
stores.rs

//! Storage traits for message history, agent state, event persistence, and
//! tool-execution tracking.
//!
//! This module defines four storage abstractions:
//!
//! - [`MessageStore`] - Stores conversation message history per thread
//! - [`StateStore`] - Stores agent state checkpoints for recovery
//! - [`EventStore`] - Stores turn-scoped event envelopes for retrieval
//! - [`ToolExecutionStore`] - Tracks tool executions for idempotent retries
//!
//! # Built-in Implementations
//!
//! [`InMemoryStore`] implements the message/state traits and is suitable for
//! testing and single-process deployments. [`InMemoryEventStore`] provides the
//! corresponding in-memory event journal, and [`InMemoryExecutionStore`] the
//! in-memory tool-execution record. For production, implement custom
//! stores backed by your database (e.g., Postgres, Redis).
15
16use agent_sdk_foundation::events::AgentEventEnvelope;
17use agent_sdk_foundation::llm;
18use agent_sdk_foundation::types::{AgentState, ThreadId, ToolExecution};
19use anyhow::{Context, Result};
20use async_trait::async_trait;
21use std::collections::{BTreeMap, HashMap};
22use std::sync::Arc;
23use std::sync::RwLock;
24use tokio::sync::RwLock as AsyncRwLock;
25
26/// Trait for storing and retrieving conversation messages.
27/// Implement this trait to persist messages to your storage backend.
28#[async_trait]
29pub trait MessageStore: Send + Sync {
30    /// Append a message to the thread's history
31    ///
32    /// # Errors
33    /// Returns an error if the message cannot be stored.
34    async fn append(&self, thread_id: &ThreadId, message: llm::Message) -> Result<()>;
35
36    /// Get all messages for a thread
37    ///
38    /// # Errors
39    /// Returns an error if the history cannot be retrieved.
40    async fn get_history(&self, thread_id: &ThreadId) -> Result<Vec<llm::Message>>;
41
42    /// Clear all messages for a thread
43    ///
44    /// # Errors
45    /// Returns an error if the messages cannot be cleared.
46    async fn clear(&self, thread_id: &ThreadId) -> Result<()>;
47
48    /// Get the message count for a thread
49    ///
50    /// # Errors
51    /// Returns an error if the count cannot be retrieved.
52    async fn count(&self, thread_id: &ThreadId) -> Result<usize> {
53        Ok(self.get_history(thread_id).await?.len())
54    }
55
56    /// Replace the entire message history for a thread.
57    /// Used for context compaction to replace old messages with a summary.
58    ///
59    /// # Errors
60    /// Returns an error if the history cannot be replaced.
61    async fn replace_history(
62        &self,
63        thread_id: &ThreadId,
64        messages: Vec<llm::Message>,
65    ) -> Result<()>;
66}
67
68/// Trait for storing agent state checkpoints.
69/// Implement this to enable conversation recovery and resume.
70#[async_trait]
71pub trait StateStore: Send + Sync {
72    /// Save the current agent state
73    ///
74    /// # Errors
75    /// Returns an error if the state cannot be saved.
76    async fn save(&self, state: &AgentState) -> Result<()>;
77
78    /// Load the most recent state for a thread
79    ///
80    /// # Errors
81    /// Returns an error if the state cannot be loaded.
82    async fn load(&self, thread_id: &ThreadId) -> Result<Option<AgentState>>;
83
84    /// Delete state for a thread
85    ///
86    /// # Errors
87    /// Returns an error if the state cannot be deleted.
88    async fn delete(&self, thread_id: &ThreadId) -> Result<()>;
89}
90
91/// Stored event data for a single turn.
92#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
93pub struct StoredTurnEvents {
94    /// Turn number (1-based once execution starts).
95    pub turn: usize,
96    /// Events emitted for this turn.
97    pub events: Vec<AgentEventEnvelope>,
98    /// Whether `finish_turn()` has completed for this turn.
99    pub finished: bool,
100}
101
102/// Trait for storing and retrieving turn-scoped event streams.
103///
104/// Event writes are split into two phases:
105/// 1. [`append`](EventStore::append) records individual envelopes
106/// 2. [`finish_turn`](EventStore::finish_turn) marks the authoritative close barrier
107#[async_trait]
108pub trait EventStore: Send + Sync {
109    /// Append an event envelope for the given thread and turn.
110    ///
111    /// # Errors
112    /// Returns an error if the event cannot be persisted.
113    async fn append(
114        &self,
115        thread_id: &ThreadId,
116        turn: usize,
117        envelope: AgentEventEnvelope,
118    ) -> Result<()>;
119
120    /// Mark the given turn as finished and flush any buffered writes.
121    ///
122    /// # Errors
123    /// Returns an error if the store cannot durably close the turn.
124    async fn finish_turn(&self, thread_id: &ThreadId, turn: usize) -> Result<()>;
125
126    /// Retrieve the stored data for a single turn.
127    ///
128    /// # Errors
129    /// Returns an error if the turn cannot be retrieved.
130    async fn get_turn(&self, thread_id: &ThreadId, turn: usize)
131    -> Result<Option<StoredTurnEvents>>;
132
133    /// Retrieve all stored turns for the given thread in ascending turn order.
134    ///
135    /// # Errors
136    /// Returns an error if the thread history cannot be retrieved.
137    async fn get_turns(&self, thread_id: &ThreadId) -> Result<Vec<StoredTurnEvents>>;
138
139    /// Retrieve all event envelopes for the given thread across every stored turn.
140    ///
141    /// # Errors
142    /// Returns an error if the thread history cannot be retrieved.
143    async fn get_events(&self, thread_id: &ThreadId) -> Result<Vec<AgentEventEnvelope>> {
144        let turns = self.get_turns(thread_id).await?;
145        Ok(turns
146            .into_iter()
147            .flat_map(|turn| turn.events.into_iter())
148            .collect())
149    }
150
151    /// Clear all events for the given thread.
152    ///
153    /// # Errors
154    /// Returns an error if the thread cannot be cleared.
155    async fn clear(&self, thread_id: &ThreadId) -> Result<()>;
156}
157
158/// Store for tracking tool executions (idempotency).
159///
160/// This trait enables write-ahead execution tracking to ensure tool idempotency.
161/// The pattern is:
162/// 1. Record execution intent BEFORE calling the tool (`record_execution`)
163/// 2. Update with result AFTER completion (`update_execution`)
164/// 3. On retry, check if execution exists and return cached result
165#[async_trait]
166pub trait ToolExecutionStore: Send + Sync {
167    /// Get an execution by `tool_call_id`.
168    ///
169    /// # Errors
170    /// Returns an error if the execution cannot be retrieved.
171    async fn get_execution(&self, tool_call_id: &str) -> Result<Option<ToolExecution>>;
172
173    /// Record a new execution (write-ahead, before calling tool).
174    ///
175    /// # Errors
176    /// Returns an error if the execution cannot be recorded.
177    async fn record_execution(&self, execution: ToolExecution) -> Result<()>;
178
179    /// Update an existing execution (after completion or to set `operation_id`).
180    ///
181    /// # Errors
182    /// Returns an error if the execution cannot be updated.
183    async fn update_execution(&self, execution: ToolExecution) -> Result<()>;
184
185    /// Get execution by `operation_id` (for async tool resume).
186    ///
187    /// # Errors
188    /// Returns an error if the execution cannot be retrieved.
189    async fn get_execution_by_operation_id(
190        &self,
191        operation_id: &str,
192    ) -> Result<Option<ToolExecution>>;
193}
194
195/// In-memory implementation of `MessageStore` and `StateStore`.
196/// Useful for testing and simple use cases.
197#[derive(Default)]
198pub struct InMemoryStore {
199    messages: RwLock<HashMap<String, Vec<llm::Message>>>,
200    states: RwLock<HashMap<String, AgentState>>,
201}
202
203impl InMemoryStore {
204    #[must_use]
205    pub fn new() -> Self {
206        Self::default()
207    }
208}
209
210#[derive(Default)]
211struct InMemoryEventStoreInner {
212    turns: AsyncRwLock<HashMap<String, BTreeMap<usize, StoredTurnEvents>>>,
213}
214
215/// In-memory implementation of [`EventStore`].
216///
217/// Cloning this type shares the same underlying event journal.
218#[derive(Clone, Default)]
219pub struct InMemoryEventStore {
220    inner: Arc<InMemoryEventStoreInner>,
221}
222
223impl InMemoryEventStore {
224    #[must_use]
225    pub fn new() -> Self {
226        Self::default()
227    }
228
229    async fn update_turn(
230        &self,
231        thread_id: &ThreadId,
232        turn: usize,
233        update: impl FnOnce(&mut StoredTurnEvents) -> Result<()>,
234    ) -> Result<()> {
235        let mut turns = self.inner.turns.write().await;
236        let stored_turn = turns
237            .entry(thread_id.0.clone())
238            .or_default()
239            .entry(turn)
240            .or_insert_with(|| StoredTurnEvents {
241                turn,
242                events: Vec::new(),
243                finished: false,
244            });
245        let result = update(stored_turn);
246        drop(turns);
247        result
248    }
249}
250
251#[async_trait]
252impl MessageStore for InMemoryStore {
253    async fn append(&self, thread_id: &ThreadId, message: llm::Message) -> Result<()> {
254        self.messages
255            .write()
256            .ok()
257            .context("lock poisoned")?
258            .entry(thread_id.0.clone())
259            .or_default()
260            .push(message);
261        Ok(())
262    }
263
264    async fn get_history(&self, thread_id: &ThreadId) -> Result<Vec<llm::Message>> {
265        let messages = self.messages.read().ok().context("lock poisoned")?;
266        Ok(messages.get(&thread_id.0).cloned().unwrap_or_default())
267    }
268
269    async fn clear(&self, thread_id: &ThreadId) -> Result<()> {
270        self.messages
271            .write()
272            .ok()
273            .context("lock poisoned")?
274            .remove(&thread_id.0);
275        Ok(())
276    }
277
278    async fn replace_history(
279        &self,
280        thread_id: &ThreadId,
281        messages: Vec<llm::Message>,
282    ) -> Result<()> {
283        self.messages
284            .write()
285            .ok()
286            .context("lock poisoned")?
287            .insert(thread_id.0.clone(), messages);
288        Ok(())
289    }
290}
291
292#[async_trait]
293impl StateStore for InMemoryStore {
294    async fn save(&self, state: &AgentState) -> Result<()> {
295        self.states
296            .write()
297            .ok()
298            .context("lock poisoned")?
299            .insert(state.thread_id.0.clone(), state.clone());
300        Ok(())
301    }
302
303    async fn load(&self, thread_id: &ThreadId) -> Result<Option<AgentState>> {
304        let states = self.states.read().ok().context("lock poisoned")?;
305        Ok(states.get(&thread_id.0).cloned())
306    }
307
308    async fn delete(&self, thread_id: &ThreadId) -> Result<()> {
309        self.states
310            .write()
311            .ok()
312            .context("lock poisoned")?
313            .remove(&thread_id.0);
314        Ok(())
315    }
316}
317
318#[async_trait]
319impl EventStore for InMemoryEventStore {
320    async fn append(
321        &self,
322        thread_id: &ThreadId,
323        turn: usize,
324        envelope: AgentEventEnvelope,
325    ) -> Result<()> {
326        self.update_turn(thread_id, turn, |stored_turn| {
327            anyhow::ensure!(
328                !stored_turn.finished,
329                "cannot append to finished turn {turn}"
330            );
331            stored_turn.events.push(envelope);
332            Ok(())
333        })
334        .await
335    }
336
337    async fn finish_turn(&self, thread_id: &ThreadId, turn: usize) -> Result<()> {
338        self.update_turn(thread_id, turn, |stored_turn| {
339            anyhow::ensure!(!stored_turn.finished, "turn {turn} is already finished");
340            stored_turn.finished = true;
341            Ok(())
342        })
343        .await
344    }
345
346    async fn get_turn(
347        &self,
348        thread_id: &ThreadId,
349        turn: usize,
350    ) -> Result<Option<StoredTurnEvents>> {
351        let turns = self.inner.turns.read().await;
352        Ok(turns
353            .get(&thread_id.0)
354            .and_then(|thread_turns| thread_turns.get(&turn).cloned()))
355    }
356
357    async fn get_turns(&self, thread_id: &ThreadId) -> Result<Vec<StoredTurnEvents>> {
358        let turns = self.inner.turns.read().await;
359        Ok(turns
360            .get(&thread_id.0)
361            .map(|thread_turns| thread_turns.values().cloned().collect())
362            .unwrap_or_default())
363    }
364
365    async fn clear(&self, thread_id: &ThreadId) -> Result<()> {
366        {
367            let mut turns = self.inner.turns.write().await;
368            turns.remove(&thread_id.0);
369        }
370        Ok(())
371    }
372}
373
374/// In-memory implementation of `ToolExecutionStore`.
375///
376/// Useful for testing and simple use cases where durability is not required.
377/// For production, implement a custom store backed by a database.
378#[derive(Default)]
379pub struct InMemoryExecutionStore {
380    /// Executions indexed by `tool_call_id`
381    executions: RwLock<HashMap<String, ToolExecution>>,
382    /// Index from `operation_id` to `tool_call_id` for async tool lookup
383    operation_index: RwLock<HashMap<String, String>>,
384}
385
386impl InMemoryExecutionStore {
387    #[must_use]
388    pub fn new() -> Self {
389        Self::default()
390    }
391}
392
393#[async_trait]
394impl ToolExecutionStore for InMemoryExecutionStore {
395    async fn get_execution(&self, tool_call_id: &str) -> Result<Option<ToolExecution>> {
396        let executions = self.executions.read().ok().context("lock poisoned")?;
397        Ok(executions.get(tool_call_id).cloned())
398    }
399
400    async fn record_execution(&self, execution: ToolExecution) -> Result<()> {
401        let tool_call_id = execution.tool_call_id.clone();
402        self.executions
403            .write()
404            .ok()
405            .context("lock poisoned")?
406            .insert(tool_call_id, execution);
407        Ok(())
408    }
409
410    async fn update_execution(&self, execution: ToolExecution) -> Result<()> {
411        let tool_call_id = execution.tool_call_id.clone();
412
413        // Update operation_id index if present
414        if let Some(ref op_id) = execution.operation_id {
415            self.operation_index
416                .write()
417                .ok()
418                .context("lock poisoned")?
419                .insert(op_id.clone(), tool_call_id.clone());
420        }
421
422        self.executions
423            .write()
424            .ok()
425            .context("lock poisoned")?
426            .insert(tool_call_id, execution);
427        Ok(())
428    }
429
430    async fn get_execution_by_operation_id(
431        &self,
432        operation_id: &str,
433    ) -> Result<Option<ToolExecution>> {
434        // Get tool_call_id and drop lock before acquiring another
435        let tool_call_id = {
436            let op_index = self.operation_index.read().ok().context("lock poisoned")?;
437            op_index.get(operation_id).cloned()
438        };
439
440        let Some(tool_call_id) = tool_call_id else {
441            return Ok(None);
442        };
443
444        let executions = self.executions.read().ok().context("lock poisoned")?;
445        Ok(executions.get(&tool_call_id).cloned())
446    }
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    use agent_sdk_foundation::events::{AgentEvent, AgentEventEnvelope, SequenceCounter};
    use agent_sdk_foundation::llm::Message;
    use agent_sdk_foundation::types::ToolResult;

    /// Append, read back, count, and clear the messages of one thread.
    #[tokio::test]
    async fn test_in_memory_message_store() -> Result<()> {
        let store = InMemoryStore::new();
        let thread_id = ThreadId::new();

        // A thread nobody has written to has no history.
        assert!(store.get_history(&thread_id).await?.is_empty());

        store.append(&thread_id, Message::user("Hello")).await?;
        store
            .append(&thread_id, Message::assistant("Hi there!"))
            .await?;

        // Both the history and the derived count see the two messages.
        assert_eq!(store.get_history(&thread_id).await?.len(), 2);
        assert_eq!(store.count(&thread_id).await?, 2);

        // Clearing takes the thread back to an empty history.
        store.clear(&thread_id).await?;
        assert!(store.get_history(&thread_id).await?.is_empty());

        Ok(())
    }

    /// Replacing the history swaps out every earlier message.
    #[tokio::test]
    async fn test_replace_history() -> Result<()> {
        let store = InMemoryStore::new();
        let thread_id = ThreadId::new();

        let original = [
            Message::user("Hello"),
            Message::assistant("Hi there!"),
            Message::user("How are you?"),
        ];
        for message in original {
            store.append(&thread_id, message).await?;
        }
        assert_eq!(store.get_history(&thread_id).await?.len(), 3);

        // Compaction: a summary pair stands in for the three messages.
        let compacted = vec![
            Message::user("[Summary] Previous conversation about greetings"),
            Message::assistant("I understand the context. Continuing..."),
        ];
        store.replace_history(&thread_id, compacted).await?;

        assert_eq!(store.get_history(&thread_id).await?.len(), 2);

        Ok(())
    }

    /// Save, load, and delete the checkpoint of one thread.
    #[tokio::test]
    async fn test_in_memory_state_store() -> Result<()> {
        let store = InMemoryStore::new();
        let thread_id = ThreadId::new();

        // Nothing saved yet.
        assert!(store.load(&thread_id).await?.is_none());

        let checkpoint = AgentState::new(thread_id.clone());
        store.save(&checkpoint).await?;

        // The checkpoint comes back and belongs to the same thread.
        let restored = store.load(&thread_id).await?;
        assert!(restored.is_some_and(|state| state.thread_id == thread_id));

        // Deleting removes it again.
        store.delete(&thread_id).await?;
        assert!(store.load(&thread_id).await?.is_none());

        Ok(())
    }

    /// Events land in their own turns, and finishing flips the barrier flag.
    #[tokio::test]
    async fn test_in_memory_event_store_tracks_turns_and_finish_barrier() -> Result<()> {
        let store = InMemoryEventStore::new();
        let thread_id = ThreadId::new();
        let seq = SequenceCounter::new();

        let first = AgentEventEnvelope::wrap(AgentEvent::text("msg_1", "hello"), &seq);
        store.append(&thread_id, 1, first).await?;
        let second = AgentEventEnvelope::wrap(AgentEvent::text("msg_2", "world"), &seq);
        store.append(&thread_id, 2, second).await?;

        // Before the barrier: one event, not finished.
        let open_turn = store
            .get_turn(&thread_id, 1)
            .await?
            .context("missing turn 1")?;
        assert_eq!(open_turn.turn, 1);
        assert_eq!(open_turn.events.len(), 1);
        assert!(!open_turn.finished);

        store.finish_turn(&thread_id, 1).await?;
        store.finish_turn(&thread_id, 2).await?;

        // After the barrier: both turns report finished.
        let closed_1 = store
            .get_turn(&thread_id, 1)
            .await?
            .context("missing finished turn 1")?;
        let closed_2 = store
            .get_turn(&thread_id, 2)
            .await?
            .context("missing finished turn 2")?;
        assert!(closed_1.finished);
        assert!(closed_2.finished);

        // Turns are listed in ascending order.
        let all_turns = store.get_turns(&thread_id).await?;
        assert_eq!(all_turns.len(), 2);
        assert_eq!(all_turns[0].turn, 1);
        assert_eq!(all_turns[1].turn, 2);

        Ok(())
    }

    /// Finishing a turn that never received events still creates it.
    #[tokio::test]
    async fn test_in_memory_event_store_finish_turn_without_events_creates_finished_turn()
    -> Result<()> {
        let store = InMemoryEventStore::new();
        let thread_id = ThreadId::new();

        store.finish_turn(&thread_id, 3).await?;

        let empty_turn = store
            .get_turn(&thread_id, 3)
            .await?
            .context("missing empty finished turn")?;
        assert_eq!(empty_turn.turn, 3);
        assert!(empty_turn.events.is_empty());
        assert!(empty_turn.finished);

        // Clearing the thread removes the turn again.
        store.clear(&thread_id).await?;
        assert!(store.get_turns(&thread_id).await?.is_empty());

        Ok(())
    }

    /// A finished turn refuses further events.
    #[tokio::test]
    async fn test_in_memory_event_store_rejects_append_after_finish() -> Result<()> {
        let store = InMemoryEventStore::new();
        let thread_id = ThreadId::new();
        let seq = SequenceCounter::new();

        store.finish_turn(&thread_id, 1).await?;

        let late_event = AgentEventEnvelope::wrap(AgentEvent::text("msg_1", "late"), &seq);
        let error = store
            .append(&thread_id, 1, late_event)
            .await
            .expect_err("append after finish should fail");

        assert!(error.to_string().contains("cannot append to finished turn"));
        Ok(())
    }

    /// A turn can only be finished once.
    #[tokio::test]
    async fn test_in_memory_event_store_rejects_duplicate_finish() -> Result<()> {
        let store = InMemoryEventStore::new();
        let thread_id = ThreadId::new();

        store.finish_turn(&thread_id, 1).await?;

        let error = store
            .finish_turn(&thread_id, 1)
            .await
            .expect_err("duplicate finish should fail");

        assert!(error.to_string().contains("already finished"));
        Ok(())
    }

    /// Record an in-flight execution and read it back by `tool_call_id`.
    #[tokio::test]
    async fn test_execution_store_basic_operations() -> Result<()> {
        let store = InMemoryExecutionStore::new();

        // Unknown ids resolve to nothing.
        assert!(store.get_execution("tool_call_123").await?.is_none());

        let in_flight = ToolExecution::new_in_flight(
            "tool_call_123",
            ThreadId::new(),
            "my_tool",
            "My Tool",
            serde_json::json!({"param": "value"}),
            time::OffsetDateTime::now_utc(),
        );
        store.record_execution(in_flight).await?;

        let loaded = store
            .get_execution("tool_call_123")
            .await?
            .expect("execution should exist");
        assert_eq!(loaded.tool_call_id, "tool_call_123");
        assert_eq!(loaded.tool_name, "my_tool");
        assert!(loaded.is_in_flight());

        Ok(())
    }

    /// Updating a recorded execution with its result marks it completed.
    #[tokio::test]
    async fn test_execution_store_complete_execution() -> Result<()> {
        let store = InMemoryExecutionStore::new();

        let mut execution = ToolExecution::new_in_flight(
            "tool_call_456",
            ThreadId::new(),
            "my_tool",
            "My Tool",
            serde_json::json!({}),
            time::OffsetDateTime::now_utc(),
        );
        store.record_execution(execution.clone()).await?;

        execution.complete(ToolResult::success("Done!"));
        store.update_execution(execution).await?;

        let loaded = store
            .get_execution("tool_call_456")
            .await?
            .expect("execution should exist");
        assert!(loaded.is_completed());
        // A result must be present and must report success.
        assert!(loaded.result.as_ref().is_some_and(|r| r.success));

        Ok(())
    }

    /// An execution can be found through its `operation_id`.
    #[tokio::test]
    async fn test_execution_store_operation_id_lookup() -> Result<()> {
        let store = InMemoryExecutionStore::new();

        let mut execution = ToolExecution::new_in_flight(
            "tool_call_789",
            ThreadId::new(),
            "async_tool",
            "Async Tool",
            serde_json::json!({}),
            time::OffsetDateTime::now_utc(),
        );
        execution.set_operation_id("op_abc123");
        store.record_execution(execution.clone()).await?;
        store.update_execution(execution).await?;

        let loaded = store
            .get_execution_by_operation_id("op_abc123")
            .await?
            .expect("execution should exist");
        assert_eq!(loaded.tool_call_id, "tool_call_789");
        assert_eq!(loaded.operation_id, Some("op_abc123".to_string()));

        // Ids that were never indexed resolve to nothing.
        let not_found = store.get_execution_by_operation_id("nonexistent").await?;
        assert!(not_found.is_none());

        Ok(())
    }
}