//! forja_memory/longterm.rs — markdown-backed long-term memory store.

use crate::compressor::CompressedEntry;
use chrono::{DateTime, Local};
use forja_core::error::{ForjaError, Result};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use tokio::fs;
use tokio::io::AsyncWriteExt;

/// Append-only, markdown-backed store for compressed long-term memory entries.
///
/// Cloning is cheap-ish (one `PathBuf`); all clones operate on the same
/// on-disk file.
#[derive(Debug, Clone)]
pub struct LongTermStore {
    // Location of the backing `longterm.md` file.
    path: PathBuf,
}

14impl LongTermStore {
15    pub async fn new(path: impl AsRef<Path>) -> Result<Self> {
16        let path = path.as_ref().to_path_buf();
17        if let Some(parent) = path.parent() {
18            fs::create_dir_all(parent)
19                .await
20                .map_err(|error| ForjaError::Storage(format!("Failed to create memory dir: {error}")))?;
21        }
22        if !fs::try_exists(&path)
23            .await
24            .map_err(|error| ForjaError::Storage(format!("Failed to inspect long-term store: {error}")))?
25        {
26            fs::write(&path, "")
27                .await
28                .map_err(|error| ForjaError::Storage(format!("Failed to create long-term store: {error}")))?;
29        }
30
31        Ok(Self { path })
32    }
33
34    pub async fn add(&self, entry: &CompressedEntry) -> Result<()> {
35        let mut file = fs::OpenOptions::new()
36            .create(true)
37            .append(true)
38            .open(&self.path)
39            .await
40            .map_err(|error| ForjaError::Storage(format!("Failed to open long-term store: {error}")))?;
41
42        let block = render_entry(entry);
43        file.write_all(block.as_bytes())
44            .await
45            .map_err(|error| ForjaError::Storage(format!("Failed to append long-term entry: {error}")))?;
46        Ok(())
47    }
48
49    pub async fn load(&self) -> Result<Vec<CompressedEntry>> {
50        let raw = fs::read_to_string(&self.path)
51            .await
52            .map_err(|error| ForjaError::Storage(format!("Failed to read long-term store: {error}")))?;
53        Ok(parse_entries(&raw))
54    }
55
56    pub async fn search(&self, query: &str, limit: usize) -> Result<Vec<CompressedEntry>> {
57        let entries = self.load().await?;
58        Ok(rank_entries(&entries, query, limit))
59    }
60
61    pub async fn entry_count(&self) -> Result<usize> {
62        Ok(self.load().await?.len())
63    }
64
65    pub fn path(&self) -> &Path {
66        &self.path
67    }
68}
69
/// Computes the on-disk location of a long-term memory file.
///
/// With an agent name the file lives under
/// `<base>/agents/<agent>/memory/longterm.md`; otherwise it sits directly at
/// `<base>/longterm.md`.
pub fn longterm_path(base_dir: &Path, agent_name: Option<&str>) -> PathBuf {
    if let Some(agent) = agent_name {
        let mut path = base_dir.join("agents");
        path.push(agent);
        path.push("memory");
        path.push("longterm.md");
        path
    } else {
        base_dir.join("longterm.md")
    }
}

81fn render_entry(entry: &CompressedEntry) -> String {
82    let timestamp = entry.timestamp.to_rfc3339();
83    let tags = if entry.keywords.is_empty() {
84        String::new()
85    } else {
86        entry.keywords
87            .iter()
88            .map(|keyword| format!("#{keyword}"))
89            .collect::<Vec<_>>()
90            .join(" ")
91    };
92    let code_section = if entry.code_snippets.is_empty() {
93        String::new()
94    } else {
95        format!(
96            "\n```text\n{}\n```\n",
97            entry.code_snippets.join("\n---\n")
98        )
99    };
100
101    format!(
102        "## [{timestamp}]\nsummary: {}\nkeywords: {tags}\noriginal_count: {}\n{}\n",
103        entry.summary.trim(),
104        entry.original_count,
105        code_section,
106    )
107}
108
109fn parse_entries(raw: &str) -> Vec<CompressedEntry> {
110    raw.split("## [")
111        .filter_map(|section| {
112            let section = section.trim();
113            if section.is_empty() {
114                return None;
115            }
116            parse_entry_block(section)
117        })
118        .collect()
119}
120
121fn parse_entry_block(section: &str) -> Option<CompressedEntry> {
122    let (timestamp_raw, rest) = section.split_once("]\n")?;
123    let timestamp = DateTime::parse_from_rfc3339(timestamp_raw)
124        .ok()?
125        .with_timezone(&Local);
126
127    let mut summary = String::new();
128    let mut keywords = Vec::new();
129    let mut original_count = 0usize;
130    let mut code_snippets = Vec::new();
131    let mut in_code_block = false;
132    let mut current_code = Vec::new();
133
134    for line in rest.lines() {
135        if line == "```text" {
136            in_code_block = true;
137            current_code.clear();
138            continue;
139        }
140        if line == "```" {
141            in_code_block = false;
142            if !current_code.is_empty() {
143                code_snippets.push(current_code.join("\n"));
144            }
145            current_code.clear();
146            continue;
147        }
148        if in_code_block {
149            current_code.push(line.to_string());
150            continue;
151        }
152
153        if let Some(value) = line.strip_prefix("summary: ") {
154            summary = value.trim().to_string();
155        } else if let Some(value) = line.strip_prefix("keywords: ") {
156            keywords = value
157                .split_whitespace()
158                .map(str::trim)
159                .map(|keyword| keyword.trim_start_matches('#').to_string())
160                .filter(|keyword| !keyword.is_empty())
161                .collect();
162        } else if let Some(value) = line.strip_prefix("original_count: ") {
163            original_count = value.trim().parse().ok()?;
164        }
165    }
166
167    Some(CompressedEntry {
168        timestamp,
169        summary,
170        keywords,
171        original_count,
172        code_snippets,
173    })
174}
175
176fn rank_entries(entries: &[CompressedEntry], query: &str, limit: usize) -> Vec<CompressedEntry> {
177    let query_terms = tokenize(query);
178    if query_terms.is_empty() {
179        return entries.iter().rev().take(limit).cloned().collect();
180    }
181
182    let documents = entries
183        .iter()
184        .map(entry_text)
185        .collect::<Vec<_>>();
186    let average_length = documents
187        .iter()
188        .map(|document| document.len() as f64)
189        .sum::<f64>()
190        / documents.len().max(1) as f64;
191
192    let document_frequency = query_terms
193        .iter()
194        .map(|term| {
195            let count = documents
196                .iter()
197                .filter(|document| document.contains_key(term))
198                .count();
199            (term.clone(), count)
200        })
201        .collect::<HashMap<_, _>>();
202
203    let mut scored = entries
204        .iter()
205        .zip(documents.iter())
206        .filter_map(|(entry, document)| {
207            let score = query_terms.iter().fold(0.0, |acc, term| {
208                let tf = *document.get(term).unwrap_or(&0) as f64;
209                if tf == 0.0 {
210                    return acc;
211                }
212                let df = *document_frequency.get(term).unwrap_or(&0) as f64;
213                let idf = ((documents.len() as f64 - df + 0.5) / (df + 0.5) + 1.0).ln();
214                let length = document.values().sum::<usize>() as f64;
215                let k1 = 1.2;
216                let b = 0.75;
217                let denominator = tf + k1 * (1.0 - b + b * (length / average_length.max(1.0)));
218                acc + idf * (tf * (k1 + 1.0) / denominator)
219            });
220
221            if score > 0.0 {
222                Some((score, entry.clone()))
223            } else {
224                None
225            }
226        })
227        .collect::<Vec<_>>();
228
229    scored.sort_by(|left, right| right.0.total_cmp(&left.0));
230    scored.into_iter().take(limit).map(|(_, entry)| entry).collect()
231}
232
233fn entry_text(entry: &CompressedEntry) -> HashMap<String, usize> {
234    let mut counts = HashMap::new();
235    let mut parts = vec![entry.summary.clone()];
236    parts.push(entry.keywords.join(" "));
237    parts.extend(entry.code_snippets.clone());
238
239    for token in tokenize(&parts.join(" ")) {
240        *counts.entry(token).or_insert(0) += 1;
241    }
242
243    counts
244}
245
/// Lowercased tokens of `text`: maximal runs of alphanumerics, `_`, and `.`
/// (the `.` keeps filenames like `auth.rs` as one token).
fn tokenize(text: &str) -> Vec<String> {
    // Whitespace is not alphanumeric, so it is already a split delimiter and
    // the produced pieces can never carry padding — the old `.map(str::trim)`
    // was a no-op and has been removed.
    text.split(|character: char| !character.is_alphanumeric() && character != '_' && character != '.')
        .filter(|token| !token.is_empty())
        .map(str::to_lowercase)
        .collect()
}

#[cfg(test)]
mod tests {
    use super::{LongTermStore, longterm_path};
    use crate::compressor::CompressedEntry;
    use chrono::Local;
    use std::path::PathBuf;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Builds a per-invocation unique temp directory path so concurrent test
    /// runs never collide on the same store file.
    fn unique_temp_dir(name: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos();
        std::env::temp_dir().join(format!("forja_longterm_{name}_{nanos}"))
    }

    #[tokio::test]
    async fn long_term_store_adds_and_searches_entries() {
        let base_dir = unique_temp_dir("search");
        let store = LongTermStore::new(longterm_path(&base_dir, None)).await.unwrap();
        let deploy_entry = CompressedEntry {
            timestamp: Local::now(),
            summary: "Deploy completed with vercel".to_string(),
            keywords: vec!["deploy".to_string(), "vercel".to_string()],
            original_count: 4,
            code_snippets: vec!["deploy.sh".to_string()],
        };
        let review_entry = CompressedEntry {
            timestamp: Local::now(),
            summary: "Code review covered auth.rs".to_string(),
            keywords: vec!["review".to_string(), "auth.rs".to_string()],
            original_count: 3,
            code_snippets: vec!["auth.rs".to_string()],
        };

        store.add(&deploy_entry).await.unwrap();
        store.add(&review_entry).await.unwrap();

        let results = store.search("deploy vercel", 1).await.unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].summary, deploy_entry.summary);
        assert_eq!(store.entry_count().await.unwrap(), 2);

        // Best-effort cleanup: previously the unique temp dir was leaked on
        // every run, accumulating files in the OS temp directory.
        let _ = std::fs::remove_dir_all(&base_dir);
    }
}
298}