Skip to main content

shuttle_rs/
memory.rs

1use std::collections::HashSet;
2
3use crate::core::{Event, EventFilter, EventStore, EventType, NewEvent, Result};
4use serde_json::json;
5
/// One recall hit: an event together with the relevance score computed for it
/// and a list of human-readable reasons describing how that score was built.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize)]
pub struct RecallResult {
    /// The recalled event.
    pub event: Event,
    /// Additive relevance score; `ranked_recall` orders results by this, highest first.
    pub score: i64,
    /// One entry per scoring rule that contributed to `score` (e.g. "same repo").
    pub reasons: Vec<String>,
}
12
13pub fn new_memory(
14    workspace_id: String,
15    agent: String,
16    session_id: String,
17    content: String,
18) -> Event {
19    new_typed_memory(EventType::Memory, workspace_id, agent, session_id, content)
20}
21
22pub fn new_typed_memory(
23    event_type: EventType,
24    workspace_id: String,
25    agent: String,
26    session_id: String,
27    content: String,
28) -> Event {
29    Event::new(NewEvent {
30        event_type,
31        workspace_id,
32        repo_id: None,
33        repo_path: None,
34        git_remote: None,
35        bit_repo_id: None,
36        branch: None,
37        commit: None,
38        repo_dirty: None,
39        agent,
40        session_id,
41        title: title_for(event_type),
42        content,
43        tags: Vec::new(),
44        metadata_json: json!({ "kind": event_type.as_str() }),
45    })
46}
47
48pub async fn memories(store: &impl EventStore) -> Result<Vec<Event>> {
49    store
50        .list(EventFilter {
51            event_type: Some(EventType::Memory),
52            ..EventFilter::default()
53        })
54        .await
55}
56
57pub async fn recall(store: &impl EventStore, query: &str) -> Result<Vec<Event>> {
58    recall_by_type(store, query, Some(EventType::Memory), None).await
59}
60
61pub async fn recall_by_type(
62    store: &impl EventStore,
63    query: &str,
64    event_type: Option<EventType>,
65    workspace_id: Option<&str>,
66) -> Result<Vec<Event>> {
67    if let Some(event_type) = event_type {
68        return recall_candidates_for_type(store, query, event_type, workspace_id).await;
69    }
70
71    let mut events = Vec::new();
72    for event_type in memory_event_types() {
73        events.extend(recall_candidates_for_type(store, query, event_type, workspace_id).await?);
74    }
75    dedup_events(&mut events);
76    events.sort_by(|left, right| {
77        right
78            .created_at
79            .cmp(&left.created_at)
80            .then(left.id.cmp(&right.id))
81    });
82    events.truncate(50);
83    Ok(events)
84}
85
86async fn recall_candidates_for_type(
87    store: &impl EventStore,
88    query: &str,
89    event_type: EventType,
90    workspace_id: Option<&str>,
91) -> Result<Vec<Event>> {
92    let mut events = store
93        .list(EventFilter {
94            event_type: Some(event_type),
95            workspace_id: workspace_id.map(ToOwned::to_owned),
96            query: Some(query.to_owned()),
97            limit: Some(50),
98            ..EventFilter::default()
99        })
100        .await?;
101
102    let mut tokens = query
103        .split_whitespace()
104        .filter(|token| !token.is_empty())
105        .collect::<Vec<_>>();
106    tokens.sort_unstable();
107    tokens.dedup();
108    for token in tokens.into_iter().take(8) {
109        events.extend(
110            store
111                .list(EventFilter {
112                    event_type: Some(event_type),
113                    workspace_id: workspace_id.map(ToOwned::to_owned),
114                    query: Some(token.to_owned()),
115                    limit: Some(50),
116                    ..EventFilter::default()
117                })
118                .await?,
119        );
120    }
121
122    dedup_events(&mut events);
123    events.sort_by(|left, right| {
124        right
125            .created_at
126            .cmp(&left.created_at)
127            .then(left.id.cmp(&right.id))
128    });
129    events.truncate(50);
130    Ok(events)
131}
132
133fn dedup_events(events: &mut Vec<Event>) {
134    let mut seen = HashSet::new();
135    events.retain(|event| seen.insert(event.id));
136}
137
138pub async fn ranked_recall(
139    store: &impl EventStore,
140    query: &str,
141    event_type: Option<EventType>,
142    workspace_id: Option<&str>,
143    repo_id: Option<&str>,
144    branch: Option<&str>,
145) -> Result<Vec<RecallResult>> {
146    let events = recall_by_type(store, query, event_type, workspace_id).await?;
147    let mut results = events
148        .into_iter()
149        .map(|event| score_event(event, query, repo_id, branch))
150        .collect::<Vec<_>>();
151    results.sort_by(|left, right| {
152        right
153            .score
154            .cmp(&left.score)
155            .then(right.event.created_at.cmp(&left.event.created_at))
156            .then(left.event.id.cmp(&right.event.id))
157    });
158    Ok(results)
159}
160
161pub fn memory_event_types() -> Vec<EventType> {
162    vec![
163        EventType::Memory,
164        EventType::Decision,
165        EventType::Observation,
166        EventType::Pattern,
167        EventType::Fact,
168        EventType::Bug,
169        EventType::Handoff,
170    ]
171}
172
173fn title_for(event_type: EventType) -> Option<String> {
174    match event_type {
175        EventType::Memory => None,
176        EventType::Decision => Some("decision".to_owned()),
177        EventType::Observation => Some("observation".to_owned()),
178        EventType::Pattern => Some("pattern".to_owned()),
179        EventType::Fact => Some("fact".to_owned()),
180        EventType::Bug => Some("bug".to_owned()),
181        EventType::Handoff => Some("handoff".to_owned()),
182        _ => Some(event_type.as_str().to_owned()),
183    }
184}
185
186fn score_event(
187    event: Event,
188    query: &str,
189    repo_id: Option<&str>,
190    branch: Option<&str>,
191) -> RecallResult {
192    let query = query.to_lowercase();
193    let tokens = query
194        .split_whitespace()
195        .filter(|token| !token.is_empty())
196        .collect::<Vec<_>>();
197    let searchable = format!(
198        "{}\n{}\n{}\n{}",
199        event.title.as_deref().unwrap_or_default(),
200        event.content,
201        event.tags.join(" "),
202        event.metadata_json
203    )
204    .to_lowercase();
205    let mut score = 0;
206    let mut reasons = Vec::new();
207
208    let exact_match = !query.is_empty() && searchable.contains(&query);
209    if exact_match {
210        score += 50;
211        reasons.push("exact text match".to_owned());
212    }
213    if !exact_match {
214        let token_matches = tokens
215            .iter()
216            .filter(|token| searchable.contains(**token))
217            .count();
218        if token_matches > 0 {
219            score += (token_matches as i64) * 10;
220            reasons.push(format!("{token_matches} token match(es)"));
221        }
222    }
223    if matches!(event.event_type, EventType::Decision) {
224        score += 8;
225        reasons.push("decision event".to_owned());
226    } else if event.event_type != EventType::Memory {
227        score += 4;
228        reasons.push(format!("typed {} event", event.event_type.as_str()));
229    }
230    if let (Some(current), Some(event_repo)) = (repo_id, event.repo_id.as_deref()) {
231        if current == event_repo {
232            score += 12;
233            reasons.push("same repo".to_owned());
234        }
235    }
236    if let (Some(current), Some(event_branch)) = (branch, event.branch.as_deref()) {
237        if current == event_branch {
238            score += 6;
239            reasons.push("same branch".to_owned());
240        }
241    }
242
243    RecallResult {
244        event,
245        score,
246        reasons,
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a typed memory using the fixed workspace/agent/session ids shared by these tests.
    fn typed(event_type: EventType, content: &str) -> Event {
        new_typed_memory(
            event_type,
            "workspace".to_owned(),
            "codex".to_owned(),
            "session".to_owned(),
            content.to_owned(),
        )
    }

    #[test]
    fn memory_is_an_event() {
        let memory = new_memory(
            "workspace".to_owned(),
            "codex".to_owned(),
            "session".to_owned(),
            "SQLite chosen".to_owned(),
        );
        assert_eq!(memory.event_type, EventType::Memory);
        assert_eq!(memory.content, "SQLite chosen");
    }

    #[test]
    fn typed_memory_uses_event_type_and_kind_metadata() {
        let decision = typed(EventType::Decision, "SQLite chosen");

        assert_eq!(decision.event_type, EventType::Decision);
        assert_eq!(decision.title.as_deref(), Some("decision"));
        assert_eq!(decision.metadata_json["kind"], "decision");
    }

    #[test]
    fn ranking_prefers_same_repo_and_decisions() {
        let mut decision = typed(EventType::Decision, "SQLite storage decision");
        decision.repo_id = Some("repo".into());
        decision.branch = Some("main".into());

        let ranked = score_event(decision, "SQLite", Some("repo"), Some("main"));

        // 50 (exact match) + 8 (decision) + 12 (same repo) + 6 (same branch).
        assert!(ranked.score >= 76);
        for expected in ["decision event", "same repo", "same branch"] {
            assert!(
                ranked.reasons.iter().any(|reason| reason == expected),
                "missing reason {expected:?}"
            );
        }
    }
}