Skip to main content

rain_engine_memory/
lib.rs

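//! Ledger-backed retrieval utilities for RainEngine.
//!
//! This crate provides exact replay of recent session records, recent working sets,
//! breadth-first graph-neighbor traversal over observed relationships, and a simple
//! case-insensitive substring search (exposed as `semantic_search`) over projected state.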
5
use async_trait::async_trait;
use rain_engine_core::{
    MemoryStore, RelationshipEdge, RetrievalError, RetrievalStore, RetrievedItem,
    RetrievedItemKind, SessionRecord, WorkingSet,
};
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
13
14#[derive(Clone)]
15pub struct SessionRetrievalStore {
16    memory: Arc<dyn MemoryStore>,
17}
18
19impl SessionRetrievalStore {
20    pub fn new(memory: Arc<dyn MemoryStore>) -> Self {
21        Self { memory }
22    }
23}
24
25#[async_trait]
26impl RetrievalStore for SessionRetrievalStore {
27    async fn exact_replay(
28        &self,
29        session_id: &str,
30        limit: usize,
31    ) -> Result<Vec<SessionRecord>, RetrievalError> {
32        let snapshot = self
33            .memory
34            .load_session(session_id)
35            .await
36            .map_err(|err| RetrievalError::new(err.message))?;
37        let take = limit.max(1);
38        let len = snapshot.records.len();
39        Ok(snapshot
40            .records
41            .into_iter()
42            .skip(len.saturating_sub(take))
43            .collect())
44    }
45
46    async fn semantic_search(
47        &self,
48        session_id: &str,
49        query: &str,
50        limit: usize,
51    ) -> Result<Vec<RetrievedItem>, RetrievalError> {
52        let snapshot = self
53            .memory
54            .load_session(session_id)
55            .await
56            .map_err(|err| RetrievalError::new(err.message))?;
57        let query = query.to_lowercase();
58        let state = snapshot.agent_state();
59        let mut hits = Vec::new();
60
61        for observation in state.observations {
62            let serialized = serde_json::to_string(&observation.content).unwrap_or_default();
63            if serialized.to_lowercase().contains(&query)
64                || observation.source.to_lowercase().contains(&query)
65            {
66                hits.push(RetrievedItem {
67                    kind: RetrievedItemKind::Observation,
68                    key: observation.observation_id.0,
69                    score: 1.0,
70                    snippet: serialized,
71                });
72            }
73        }
74        for task in state.tasks {
75            if text_match(&query, &task.title, task.detail.as_deref()) {
76                hits.push(RetrievedItem {
77                    kind: RetrievedItemKind::Task,
78                    key: task.task_id.0,
79                    score: 0.9,
80                    snippet: task.detail.unwrap_or(task.title),
81                });
82            }
83        }
84        for goal in state.goals {
85            if text_match(&query, &goal.title, goal.detail.as_deref()) {
86                hits.push(RetrievedItem {
87                    kind: RetrievedItemKind::Goal,
88                    key: goal.goal_id.0,
89                    score: 0.8,
90                    snippet: goal.detail.unwrap_or(goal.title),
91                });
92            }
93        }
94
95        hits.truncate(limit.max(1));
96        Ok(hits)
97    }
98
99    async fn graph_neighbors(
100        &self,
101        session_id: &str,
102        resource_id: &str,
103        max_hops: usize,
104    ) -> Result<Vec<RelationshipEdge>, RetrievalError> {
105        let snapshot = self
106            .memory
107            .load_session(session_id)
108            .await
109            .map_err(|err| RetrievalError::new(err.message))?;
110        let state = snapshot.agent_state();
111        let mut frontier = VecDeque::from([(resource_id.to_string(), 0usize)]);
112        let mut seen = HashSet::from([resource_id.to_string()]);
113        let mut seen_edges = HashSet::<(String, String, String)>::new();
114        let mut edges = Vec::new();
115
116        while let Some((node, depth)) = frontier.pop_front() {
117            if depth >= max_hops {
118                continue;
119            }
120            for edge in &state.relationships {
121                if edge.from_resource_id == node || edge.to_resource_id == node {
122                    let edge_key = (
123                        edge.from_resource_id.clone(),
124                        edge.to_resource_id.clone(),
125                        edge.relation.clone(),
126                    );
127                    if seen_edges.insert(edge_key) {
128                        edges.push(edge.clone());
129                    }
130                    for next in [&edge.from_resource_id, &edge.to_resource_id] {
131                        if seen.insert(next.clone()) {
132                            frontier.push_back((next.clone(), depth + 1));
133                        }
134                    }
135                }
136            }
137        }
138
139        Ok(edges)
140    }
141
142    async fn recent_working_set(
143        &self,
144        session_id: &str,
145        limit: usize,
146    ) -> Result<WorkingSet, RetrievalError> {
147        let snapshot = self
148            .memory
149            .load_session(session_id)
150            .await
151            .map_err(|err| RetrievalError::new(err.message))?;
152        let state = snapshot.agent_state();
153        Ok(WorkingSet {
154            observations: take_tail(state.observations, limit),
155            tasks: take_tail(state.tasks, limit),
156            goals: take_tail(state.goals, limit),
157        })
158    }
159}
160
/// Reports whether `query` occurs in a lowercased copy of `title` or, when
/// present, of `detail`.
///
/// `query` itself is not lowercased here; callers are expected to pass it
/// already lowercased (as `semantic_search` does).
fn text_match(query: &str, title: &str, detail: Option<&str>) -> bool {
    if title.to_lowercase().contains(query) {
        return true;
    }
    matches!(detail, Some(text) if text.to_lowercase().contains(query))
}
167
/// Keeps only the last `limit` elements of `items`, preserving their order.
///
/// A `limit` of 0 is treated as 1, so a non-empty input always yields at least
/// one element.
fn take_tail<T>(mut items: Vec<T>, limit: usize) -> Vec<T> {
    let keep = std::cmp::max(limit, 1);
    let start = items.len().saturating_sub(keep);
    // Drop the older prefix in place instead of re-collecting the tail.
    items.drain(..start);
    items
}
173
#[cfg(test)]
mod tests {
    use super::*;
    use rain_engine_core::{
        AgentTrigger, InMemoryMemoryStore, KernelEvent, KernelEventRecord, MemoryStoreExt,
        ObservationId, ObservationRecord, RelationshipEdge, ResourceRef,
    };

    /// Builds a `ResourceRegistered` event whose id, type and label are all `resource_id`.
    fn resource_event(event_id: &str, resource_id: &str) -> KernelEventRecord {
        KernelEventRecord {
            event_id: event_id.to_string(),
            occurred_at: std::time::SystemTime::now(),
            event: KernelEvent::ResourceRegistered(ResourceRef {
                resource_id: resource_id.to_string(),
                resource_type: resource_id.to_string(),
                label: resource_id.to_string(),
                external_ref: None,
            }),
        }
    }

    /// Builds a `RelationshipObserved` event for the edge `from -[relation]-> to`.
    fn relationship_event(
        event_id: &str,
        from: &str,
        to: &str,
        relation: &str,
    ) -> KernelEventRecord {
        KernelEventRecord {
            event_id: event_id.to_string(),
            occurred_at: std::time::SystemTime::now(),
            event: KernelEvent::RelationshipObserved(RelationshipEdge {
                from_resource_id: from.to_string(),
                to_resource_id: to.to_string(),
                relation: relation.to_string(),
                observed_at: std::time::SystemTime::now(),
            }),
        }
    }

    #[tokio::test]
    async fn semantic_search_finds_observations() {
        let memory = Arc::new(InMemoryMemoryStore::new());
        memory
            .append_trigger(rain_engine_core::TriggerRecord {
                trigger_id: "t1".to_string(),
                session_id: "s1".to_string(),
                idempotency_key: None,
                recorded_at: std::time::SystemTime::now(),
                trigger: AgentTrigger::Message {
                    user_id: "u1".to_string(),
                    content: "hello".to_string(),
                    attachments: Vec::new(),
                },
                intent: None,
            })
            .await
            .expect("trigger");
        memory
            .append_kernel_event(
                "s1",
                KernelEventRecord {
                    event_id: "e1".to_string(),
                    occurred_at: std::time::SystemTime::now(),
                    event: KernelEvent::ObservationAppended(ObservationRecord {
                        observation_id: ObservationId("obs-1".to_string()),
                        recorded_at: std::time::SystemTime::now(),
                        source: "webhook".to_string(),
                        content: serde_json::json!({"text": "database schema mismatch"}),
                        attachment_ids: Vec::new(),
                        related_resources: Vec::new(),
                    }),
                },
            )
            .await
            .expect("event");

        let store = SessionRetrievalStore::new(memory);
        let hits = store
            .semantic_search("s1", "schema", 5)
            .await
            .expect("hits");
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].key, "obs-1");

        // The query is lowercased before matching, so case does not matter.
        let hits = store
            .semantic_search("s1", "SCHEMA", 5)
            .await
            .expect("hits");
        assert_eq!(hits.len(), 1);

        let misses = store
            .semantic_search("s1", "zzz-not-present", 5)
            .await
            .expect("misses");
        assert!(misses.is_empty());
    }

    #[tokio::test]
    async fn graph_neighbors_follow_relationships() {
        let memory = Arc::new(InMemoryMemoryStore::new());
        memory
            .append_kernel_event("s2", resource_event("e2", "repo"))
            .await
            .expect("resource");
        memory
            .append_kernel_event("s2", relationship_event("e3", "repo", "ticket", "tracks"))
            .await
            .expect("relationship");

        let store = SessionRetrievalStore::new(memory);
        let edges = store.graph_neighbors("s2", "repo", 2).await.expect("edges");
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].from_resource_id, "repo");
        assert_eq!(edges[0].to_resource_id, "ticket");
    }

    #[tokio::test]
    async fn graph_neighbors_respect_hop_limit_and_walk_edges_backwards() {
        let memory = Arc::new(InMemoryMemoryStore::new());
        memory
            .append_kernel_event("s3", resource_event("e4", "repo"))
            .await
            .expect("resource");
        memory
            .append_kernel_event("s3", relationship_event("e5", "repo", "ticket", "tracks"))
            .await
            .expect("relationship");
        memory
            .append_kernel_event("s3", relationship_event("e6", "repo", "bug", "tracks"))
            .await
            .expect("relationship");

        let store = SessionRetrievalStore::new(memory);

        // Zero hops: not even the start node is expanded.
        let edges = store.graph_neighbors("s3", "ticket", 0).await.expect("edges");
        assert!(edges.is_empty());

        // One hop from `ticket` follows its incoming edge back to `repo` only.
        let edges = store.graph_neighbors("s3", "ticket", 1).await.expect("edges");
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].from_resource_id, "repo");
        assert_eq!(edges[0].to_resource_id, "ticket");

        // Two hops also reach `repo`'s other edge; the shared edge is not repeated.
        let edges = store.graph_neighbors("s3", "ticket", 2).await.expect("edges");
        assert_eq!(edges.len(), 2);
        assert_eq!(edges[1].to_resource_id, "bug");
    }
}
267}