1use async_trait::async_trait;
7use rain_engine_core::{
8 MemoryStore, RelationshipEdge, RetrievalError, RetrievalStore, RetrievedItem,
9 RetrievedItemKind, SessionRecord, WorkingSet,
10};
11use std::collections::{HashSet, VecDeque};
12use std::sync::Arc;
13
14#[derive(Clone)]
15pub struct SessionRetrievalStore {
16 memory: Arc<dyn MemoryStore>,
17}
18
19impl SessionRetrievalStore {
20 pub fn new(memory: Arc<dyn MemoryStore>) -> Self {
21 Self { memory }
22 }
23}
24
25#[async_trait]
26impl RetrievalStore for SessionRetrievalStore {
27 async fn exact_replay(
28 &self,
29 session_id: &str,
30 limit: usize,
31 ) -> Result<Vec<SessionRecord>, RetrievalError> {
32 let snapshot = self
33 .memory
34 .load_session(session_id)
35 .await
36 .map_err(|err| RetrievalError::new(err.message))?;
37 let take = limit.max(1);
38 let len = snapshot.records.len();
39 Ok(snapshot
40 .records
41 .into_iter()
42 .skip(len.saturating_sub(take))
43 .collect())
44 }
45
46 async fn semantic_search(
47 &self,
48 session_id: &str,
49 query: &str,
50 limit: usize,
51 ) -> Result<Vec<RetrievedItem>, RetrievalError> {
52 let snapshot = self
53 .memory
54 .load_session(session_id)
55 .await
56 .map_err(|err| RetrievalError::new(err.message))?;
57 let query = query.to_lowercase();
58 let state = snapshot.agent_state();
59 let mut hits = Vec::new();
60
61 for observation in state.observations {
62 let serialized = serde_json::to_string(&observation.content).unwrap_or_default();
63 if serialized.to_lowercase().contains(&query)
64 || observation.source.to_lowercase().contains(&query)
65 {
66 hits.push(RetrievedItem {
67 kind: RetrievedItemKind::Observation,
68 key: observation.observation_id.0,
69 score: 1.0,
70 snippet: serialized,
71 });
72 }
73 }
74 for task in state.tasks {
75 if text_match(&query, &task.title, task.detail.as_deref()) {
76 hits.push(RetrievedItem {
77 kind: RetrievedItemKind::Task,
78 key: task.task_id.0,
79 score: 0.9,
80 snippet: task.detail.unwrap_or(task.title),
81 });
82 }
83 }
84 for goal in state.goals {
85 if text_match(&query, &goal.title, goal.detail.as_deref()) {
86 hits.push(RetrievedItem {
87 kind: RetrievedItemKind::Goal,
88 key: goal.goal_id.0,
89 score: 0.8,
90 snippet: goal.detail.unwrap_or(goal.title),
91 });
92 }
93 }
94
95 hits.truncate(limit.max(1));
96 Ok(hits)
97 }
98
99 async fn graph_neighbors(
100 &self,
101 session_id: &str,
102 resource_id: &str,
103 max_hops: usize,
104 ) -> Result<Vec<RelationshipEdge>, RetrievalError> {
105 let snapshot = self
106 .memory
107 .load_session(session_id)
108 .await
109 .map_err(|err| RetrievalError::new(err.message))?;
110 let state = snapshot.agent_state();
111 let mut frontier = VecDeque::from([(resource_id.to_string(), 0usize)]);
112 let mut seen = HashSet::from([resource_id.to_string()]);
113 let mut seen_edges = HashSet::<(String, String, String)>::new();
114 let mut edges = Vec::new();
115
116 while let Some((node, depth)) = frontier.pop_front() {
117 if depth >= max_hops {
118 continue;
119 }
120 for edge in &state.relationships {
121 if edge.from_resource_id == node || edge.to_resource_id == node {
122 let edge_key = (
123 edge.from_resource_id.clone(),
124 edge.to_resource_id.clone(),
125 edge.relation.clone(),
126 );
127 if seen_edges.insert(edge_key) {
128 edges.push(edge.clone());
129 }
130 for next in [&edge.from_resource_id, &edge.to_resource_id] {
131 if seen.insert(next.clone()) {
132 frontier.push_back((next.clone(), depth + 1));
133 }
134 }
135 }
136 }
137 }
138
139 Ok(edges)
140 }
141
142 async fn recent_working_set(
143 &self,
144 session_id: &str,
145 limit: usize,
146 ) -> Result<WorkingSet, RetrievalError> {
147 let snapshot = self
148 .memory
149 .load_session(session_id)
150 .await
151 .map_err(|err| RetrievalError::new(err.message))?;
152 let state = snapshot.agent_state();
153 Ok(WorkingSet {
154 observations: take_tail(state.observations, limit),
155 tasks: take_tail(state.tasks, limit),
156 goals: take_tail(state.goals, limit),
157 })
158 }
159}
160
/// Reports whether `query` occurs in `title` or, when present, in `detail`.
///
/// `title` and `detail` are lowercased before comparison. `query` is used as
/// given, so callers must lowercase it themselves for the check to be
/// case-insensitive.
fn text_match(query: &str, title: &str, detail: Option<&str>) -> bool {
    let contains_query = |text: &str| text.to_lowercase().contains(query);
    if contains_query(title) {
        return true;
    }
    match detail {
        Some(text) => contains_query(text),
        None => false,
    }
}
167
/// Returns the last `limit` items of `items`, preserving their order.
///
/// A `limit` of zero is treated as one, so a non-empty input always yields at
/// least one item.
fn take_tail<T>(mut items: Vec<T>, limit: usize) -> Vec<T> {
    let keep = std::cmp::max(limit, 1);
    let start = items.len().saturating_sub(keep);
    items.split_off(start)
}
173
#[cfg(test)]
mod tests {
    use super::*;
    use rain_engine_core::{
        AgentTrigger, InMemoryMemoryStore, KernelEvent, KernelEventRecord, MemoryStoreExt,
        ObservationId, ObservationRecord, RelationshipEdge, ResourceRef,
    };

    /// Appends `event` to `session_id` with a fresh timestamp, panicking on failure.
    async fn record_event(
        memory: &Arc<InMemoryMemoryStore>,
        session_id: &str,
        event_id: &str,
        event: KernelEvent,
    ) {
        memory
            .append_kernel_event(
                session_id,
                KernelEventRecord {
                    event_id: event_id.to_string(),
                    occurred_at: std::time::SystemTime::now(),
                    event,
                },
            )
            .await
            .expect("append kernel event");
    }

    /// Builds a relationship edge stamped with the current time.
    fn edge(from: &str, to: &str, relation: &str) -> RelationshipEdge {
        RelationshipEdge {
            from_resource_id: from.to_string(),
            to_resource_id: to.to_string(),
            relation: relation.to_string(),
            observed_at: std::time::SystemTime::now(),
        }
    }

    #[tokio::test]
    async fn semantic_search_finds_observations() {
        let memory = Arc::new(InMemoryMemoryStore::new());
        memory
            .append_trigger(rain_engine_core::TriggerRecord {
                trigger_id: "t1".to_string(),
                session_id: "s1".to_string(),
                idempotency_key: None,
                recorded_at: std::time::SystemTime::now(),
                trigger: AgentTrigger::Message {
                    user_id: "u1".to_string(),
                    content: "hello".to_string(),
                    attachments: Vec::new(),
                },
                intent: None,
            })
            .await
            .expect("trigger");
        record_event(
            &memory,
            "s1",
            "e1",
            KernelEvent::ObservationAppended(ObservationRecord {
                observation_id: ObservationId("obs-1".to_string()),
                recorded_at: std::time::SystemTime::now(),
                source: "webhook".to_string(),
                content: serde_json::json!({"text": "database schema mismatch"}),
                attachment_ids: Vec::new(),
                related_resources: Vec::new(),
            }),
        )
        .await;

        let store = SessionRetrievalStore::new(memory);
        let hits = store
            .semantic_search("s1", "schema", 5)
            .await
            .expect("hits");
        assert_eq!(hits.len(), 1);
        assert!(matches!(hits[0].kind, RetrievedItemKind::Observation));
        assert_eq!(hits[0].key, "obs-1");

        // Matching lowercases both the query and the searched text.
        let upper = store
            .semantic_search("s1", "SCHEMA", 5)
            .await
            .expect("hits");
        assert_eq!(upper.len(), 1);

        let none = store
            .semantic_search("s1", "no-such-term-xyz", 5)
            .await
            .expect("hits");
        assert!(none.is_empty());
    }

    #[tokio::test]
    async fn graph_neighbors_follow_relationships() {
        let memory = Arc::new(InMemoryMemoryStore::new());
        record_event(
            &memory,
            "s2",
            "e2",
            KernelEvent::ResourceRegistered(ResourceRef {
                resource_id: "repo".to_string(),
                resource_type: "repo".to_string(),
                label: "repo".to_string(),
                external_ref: None,
            }),
        )
        .await;
        record_event(
            &memory,
            "s2",
            "e3",
            KernelEvent::RelationshipObserved(edge("repo", "ticket", "tracks")),
        )
        .await;

        let store = SessionRetrievalStore::new(memory);
        let edges = store.graph_neighbors("s2", "repo", 2).await.expect("edges");
        assert_eq!(edges.len(), 1);
    }

    #[tokio::test]
    async fn graph_neighbors_respect_hop_limit_and_ignore_direction() {
        let memory = Arc::new(InMemoryMemoryStore::new());
        record_event(
            &memory,
            "s3",
            "e1",
            KernelEvent::RelationshipObserved(edge("repo", "ticket", "tracks")),
        )
        .await;
        record_event(
            &memory,
            "s3",
            "e2",
            KernelEvent::RelationshipObserved(edge("ticket", "user", "assigned_to")),
        )
        .await;

        let store = SessionRetrievalStore::new(memory);

        let zero = store.graph_neighbors("s3", "repo", 0).await.expect("edges");
        assert!(zero.is_empty());

        let one = store.graph_neighbors("s3", "repo", 1).await.expect("edges");
        assert_eq!(one.len(), 1);
        assert_eq!(one[0].to_resource_id, "ticket");

        let two = store.graph_neighbors("s3", "repo", 2).await.expect("edges");
        assert_eq!(two.len(), 2);
        assert_eq!(two[1].to_resource_id, "user");

        // Edges are walked from either endpoint, so starting at the far end
        // reaches the same two edges.
        let reverse = store.graph_neighbors("s3", "user", 2).await.expect("edges");
        assert_eq!(reverse.len(), 2);
    }
}