Skip to main content

forja_memory/
lib.rs

1use async_trait::async_trait;
2use forja_core::error::Result;
3use forja_core::traits::MemoryStore;
4use forja_core::types::MemoryEntry;
5use std::fmt::Display;
6use std::path::{Path, PathBuf};
7use std::sync::{Arc, Mutex};
8
9pub mod compressor;
10pub mod longterm;
11pub mod manager;
12pub mod session;
13pub mod storage;
14
15use manager::{MemoryManager, MemoryStats, memory_entry_to_message};
16use storage::Storage;
17
18pub use compressor::{CompressedEntry, Compressor};
19pub use longterm::{LongTermStore, longterm_path};
20pub use manager::{MemoryCommand, MemoryManager as UnifiedMemoryManager, parse_memory_command};
21pub use session::SessionBuffer;
22
23pub struct MarkdownMemoryStore {
24    storage: Storage,
25}
26
27impl MarkdownMemoryStore {
28    pub async fn new(memory_path: impl AsRef<Path>) -> Result<Self> {
29        let storage = Storage::init(memory_path).await?;
30        Ok(Self { storage })
31    }
32
33    pub async fn flush_and_summarize<F, O>(&self, summarizer: F) -> Result<()>
34    where
35        F: Fn(String) -> O,
36        O: SummarizeOutput,
37    {
38        self.storage
39            .flush_and_summarize(|block| summarizer(block).into_summary_result())
40            .await
41    }
42}
43
44pub struct MemoryManagerStore {
45    manager: Arc<tokio::sync::Mutex<MemoryManager>>,
46    current_query: Arc<Mutex<String>>,
47}
48
49impl MemoryManagerStore {
50    pub async fn new(base_dir: impl AsRef<Path>, agent_name: Option<&str>) -> Result<Self> {
51        let manager = MemoryManager::new(longterm_path(base_dir.as_ref(), agent_name)).await?;
52        Ok(Self {
53            manager: Arc::new(tokio::sync::Mutex::new(manager)),
54            current_query: Arc::new(Mutex::new(String::new())),
55        })
56    }
57
58    pub fn set_current_query(&self, query: impl Into<String>) {
59        if let Ok(mut current_query) = self.current_query.lock() {
60            *current_query = query.into();
61        }
62    }
63
64    pub async fn load(&self) -> Result<()> {
65        self.manager.lock().await.load().await
66    }
67
68    pub async fn stats(&self) -> Result<MemoryStats> {
69        self.manager.lock().await.stats().await
70    }
71
72    pub async fn search(&self, query: &str, limit: usize) -> Result<Vec<CompressedEntry>> {
73        self.manager.lock().await.recall(query, limit).await
74    }
75
76    pub async fn clear_session(&self) {
77        self.manager.lock().await.clear_session().await;
78    }
79
80    pub async fn flush_manager(&self) -> Result<()> {
81        self.manager.lock().await.flush().await
82    }
83}
84
85impl Clone for MemoryManagerStore {
86    fn clone(&self) -> Self {
87        Self {
88            manager: self.manager.clone(),
89            current_query: self.current_query.clone(),
90        }
91    }
92}
93
/// Conversion from a summarizer's return value into a uniform
/// `Result<String, String>`.
pub trait SummarizeOutput {
    /// Consumes the value, yielding either the summary text or an error
    /// message.
    fn into_summary_result(self) -> std::result::Result<String, String>;
}

/// A bare `String` is always treated as a successful summary.
impl SummarizeOutput for String {
    fn into_summary_result(self) -> std::result::Result<String, String> {
        std::result::Result::Ok(self)
    }
}

/// A `Result` keeps its success value; its error is rendered to text through
/// `Display`.
impl<E: Display> SummarizeOutput for std::result::Result<String, E> {
    fn into_summary_result(self) -> std::result::Result<String, String> {
        match self {
            Ok(summary) => Ok(summary),
            Err(error) => Err(error.to_string()),
        }
    }
}
112
113#[async_trait]
114impl MemoryStore for MarkdownMemoryStore {
115    async fn save(&self, entry: &MemoryEntry) -> Result<()> {
116        self.storage.append_entry(entry).await
117    }
118
119    async fn load_all(&self) -> Result<String> {
120        self.storage.read_all().await
121    }
122
123    async fn flush(&self) -> Result<()> {
124        Ok(())
125    }
126}
127
128#[async_trait]
129impl MemoryStore for MemoryManagerStore {
130    async fn save(&self, entry: &MemoryEntry) -> Result<()> {
131        let role_hint = entry
132            .tags
133            .iter()
134            .find(|tag| matches!(tag.as_str(), "assistant" | "system" | "tool" | "user"))
135            .map(|tag| tag.as_str())
136            .unwrap_or("user");
137        let message = memory_entry_to_message(role_hint, &entry.content);
138        self.manager.lock().await.record(message).await
139    }
140
141    async fn load_all(&self) -> Result<String> {
142        let query = self
143            .current_query
144            .lock()
145            .map(|query| query.clone())
146            .unwrap_or_default();
147        self.manager.lock().await.get_context(&query).await
148    }
149
150    async fn flush(&self) -> Result<()> {
151        self.flush_manager().await
152    }
153}
154
155impl From<tokio::sync::Mutex<MemoryManager>> for MemoryManagerStore {
156    fn from(manager: tokio::sync::Mutex<MemoryManager>) -> Self {
157        Self {
158            manager: Arc::new(manager),
159            current_query: Arc::new(Mutex::new(String::new())),
160        }
161    }
162}
163
164pub fn default_memory_base_dir() -> PathBuf {
165    std::env::var("FORJA_MEMORY_DIR")
166        .map(PathBuf::from)
167        .unwrap_or_else(|_| {
168            dirs_next::home_dir()
169                .unwrap_or_default()
170                .join(".forja")
171                .join("memory")
172        })
173}
174
#[cfg(test)]
mod tests {
    use super::MemoryManagerStore;
    use crate::longterm::longterm_path;
    use forja_core::traits::MemoryStore;
    use forja_core::types::MemoryEntry;
    use std::path::PathBuf;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Builds a path under the system temp directory whose name combines
    /// `name` with the current time in nanoseconds since the Unix epoch.
    fn unique_temp_dir(name: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos();
        std::env::temp_dir().join(format!("forja_memory_store_{name}_{nanos}"))
    }

    /// Saving one entry and flushing must leave a long-term file that
    /// contains the text `summary:`.
    #[tokio::test]
    async fn memory_manager_store_flushes_to_longterm_file() {
        let base_dir = unique_temp_dir("flush");
        let store = MemoryManagerStore::new(&base_dir, None).await.unwrap();
        let entry = MemoryEntry {
            id: "user_1".to_string(),
            timestamp: 1,
            tags: vec!["user".to_string()],
            content: "Remember the deploy checklist.".to_string(),
            score: 0.0,
            metadata: Default::default(),
        };
        store.save(&entry).await.unwrap();
        store.flush().await.unwrap();

        let path = longterm_path(base_dir.as_path(), None);
        let contents = tokio::fs::read_to_string(path).await;

        // Best-effort cleanup, done before the assertions so a failing check
        // does not leave the directory behind. Errors are ignored on purpose:
        // a leftover temp directory must not fail the test.
        // NOTE(review): assumes the long-term file lives under `base_dir` — confirm.
        let _ = tokio::fs::remove_dir_all(&base_dir).await;

        let contents = contents.unwrap();
        assert!(contents.contains("summary:"));
    }
}
213}