Skip to main content

rs_adk/memory/
mod.rs

//! Memory service — session-scoped memory for agents.
//!
//! Mirrors ADK-JS's `BaseMemoryService`. Provides the [`MemoryService`] trait for
//! storing and searching key-value memory entries, plus an in-memory default
//! implementation ([`InMemoryMemoryService`]).
5
6mod in_memory;
7mod vertex_ai_memory_bank;
8mod vertex_ai_rag;
9
10pub use in_memory::InMemoryMemoryService;
11pub use vertex_ai_memory_bank::{VertexAiMemoryBankConfig, VertexAiMemoryBankService};
12pub use vertex_ai_rag::{VertexAiRagMemoryConfig, VertexAiRagMemoryService};
13
14use async_trait::async_trait;
15use serde::{Deserialize, Serialize};
16
17/// A memory entry — a named piece of information stored by an agent.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct MemoryEntry {
20    /// Unique key for this memory.
21    pub key: String,
22    /// The stored value.
23    pub value: serde_json::Value,
24    /// When this entry was created (Unix timestamp seconds).
25    pub created_at: u64,
26    /// When this entry was last updated (Unix timestamp seconds).
27    pub updated_at: u64,
28}
29
30impl MemoryEntry {
31    /// Create a new memory entry.
32    pub fn new(key: impl Into<String>, value: serde_json::Value) -> Self {
33        let now = now_secs();
34        Self {
35            key: key.into(),
36            value,
37            created_at: now,
38            updated_at: now,
39        }
40    }
41}
42
43/// Errors from memory service operations.
44#[derive(Debug, thiserror::Error)]
45pub enum MemoryError {
46    /// The requested memory key was not found.
47    #[error("Memory key not found: {0}")]
48    NotFound(String),
49    /// A storage backend error.
50    #[error("Storage error: {0}")]
51    Storage(String),
52}
53
54/// Trait for session-scoped memory persistence.
55///
56/// Memory is scoped to a session ID. Implementations must be `Send + Sync`.
57#[async_trait]
58pub trait MemoryService: Send + Sync {
59    /// Store a memory entry for a session.
60    async fn store(&self, session_id: &str, entry: MemoryEntry) -> Result<(), MemoryError>;
61
62    /// Retrieve a memory entry by key.
63    async fn get(&self, session_id: &str, key: &str) -> Result<Option<MemoryEntry>, MemoryError>;
64
65    /// List all memory entries for a session.
66    async fn list(&self, session_id: &str) -> Result<Vec<MemoryEntry>, MemoryError>;
67
68    /// Search memory entries by a query string (simple substring match in default impl).
69    async fn search(&self, session_id: &str, query: &str) -> Result<Vec<MemoryEntry>, MemoryError>;
70
71    /// Delete a memory entry.
72    async fn delete(&self, session_id: &str, key: &str) -> Result<(), MemoryError>;
73
74    /// Clear all memory for a session.
75    async fn clear(&self, session_id: &str) -> Result<(), MemoryError>;
76}
77
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this yields `0`
/// instead of failing.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH);
    since_epoch.map_or(0, |elapsed| elapsed.as_secs())
}
84
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn memory_entry_new() {
        let entry = MemoryEntry::new("topic", serde_json::json!("Rust"));
        assert_eq!(entry.key, "topic");
        assert_eq!(entry.value, serde_json::json!("Rust"));
        assert!(entry.created_at > 0);
        // `new` stamps both fields from a single clock read, so they must match.
        assert_eq!(entry.created_at, entry.updated_at);
    }

    #[test]
    fn memory_service_is_object_safe() {
        // Compiles only if `MemoryService` is usable as a trait object.
        fn _assert(_: &dyn MemoryService) {}
    }

    #[tokio::test]
    async fn store_and_get() {
        let svc = InMemoryMemoryService::new();
        let entry = MemoryEntry::new("topic", serde_json::json!("AI"));
        svc.store("s1", entry).await.unwrap();

        let fetched = svc
            .get("s1", "topic")
            .await
            .unwrap()
            .expect("entry stored under `topic` should be retrievable");
        assert_eq!(fetched.key, "topic");
        assert_eq!(fetched.value, serde_json::json!("AI"));
    }

    #[tokio::test]
    async fn get_nonexistent_returns_none() {
        let svc = InMemoryMemoryService::new();
        let fetched = svc.get("s1", "missing").await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn list_entries() {
        let svc = InMemoryMemoryService::new();
        svc.store("s1", MemoryEntry::new("a", serde_json::json!(1)))
            .await
            .unwrap();
        svc.store("s1", MemoryEntry::new("b", serde_json::json!(2)))
            .await
            .unwrap();
        svc.store("s2", MemoryEntry::new("c", serde_json::json!(3)))
            .await
            .unwrap();

        // Listing order is unspecified, so compare sorted keys.
        let mut keys: Vec<String> = svc
            .list("s1")
            .await
            .unwrap()
            .into_iter()
            .map(|entry| entry.key)
            .collect();
        keys.sort();
        assert_eq!(keys, ["a", "b"]);

        // The entry stored under another session must not leak into s1, and
        // must still be visible from its own session.
        assert_eq!(svc.list("s2").await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn search_entries() {
        let svc = InMemoryMemoryService::new();
        svc.store(
            "s1",
            MemoryEntry::new("rust_topic", serde_json::json!("Rust programming")),
        )
        .await
        .unwrap();
        svc.store(
            "s1",
            MemoryEntry::new("python_topic", serde_json::json!("Python scripting")),
        )
        .await
        .unwrap();

        let results = svc.search("s1", "rust").await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].key, "rust_topic");
    }

    #[tokio::test]
    async fn delete_entry() {
        let svc = InMemoryMemoryService::new();
        svc.store("s1", MemoryEntry::new("k", serde_json::json!(1)))
            .await
            .unwrap();
        svc.delete("s1", "k").await.unwrap();
        let fetched = svc.get("s1", "k").await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn clear_session() {
        let svc = InMemoryMemoryService::new();
        svc.store("s1", MemoryEntry::new("a", serde_json::json!(1)))
            .await
            .unwrap();
        svc.store("s1", MemoryEntry::new("b", serde_json::json!(2)))
            .await
            .unwrap();
        svc.clear("s1").await.unwrap();
        let entries = svc.list("s1").await.unwrap();
        assert!(entries.is_empty());
    }
}