Skip to main content

rain_engine_core/
memory.rs

1use crate::{
2    ApprovalResolutionRecord, CoordinationClaimRecord, DelegationRecord, DeliberationRecord,
3    EngineOutcome, ExecutionPlanRecord, KernelEventRecord, ModelDecisionRecord, NewSessionRecord,
4    OutcomeRecord, PendingApprovalRecord, PolicyTuningRecord, ProfilePatchRecord,
5    ProviderCacheRecord, ProviderUsageRecord, RecordPage, RecordPageQuery, ReflectionRecord,
6    SessionListQuery, SessionRecord, SessionSnapshot, SessionSummary, SkillInputValidationRecord,
7    StoredSessionRecord, StrategyPreferenceRecord, SummaryRecord, ToolCallRecord,
8    ToolExecutionGraph, ToolNodeCheckpointRecord, ToolPerformanceRecord, ToolResultRecord,
9    TriggerIntentRecord, TriggerRecord,
10};
11use async_trait::async_trait;
12use std::collections::HashMap;
13use std::sync::Arc;
14use thiserror::Error;
15use tokio::sync::RwLock;
16
/// Error type returned by every [`MemoryStore`] operation.
///
/// It carries a single human-readable message; the `Display` output produced
/// by the `#[error("{message}")]` attribute is exactly that message.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
#[error("{message}")]
pub struct MemoryError {
    /// Description of what went wrong.
    pub message: String,
}
22
23impl MemoryError {
24    pub fn new(message: impl Into<String>) -> Self {
25        Self {
26            message: message.into(),
27        }
28    }
29}
30
/// Asynchronous storage backend for session records.
///
/// Implementors must be `Send + Sync`. Every operation reports failure as a
/// [`MemoryError`].
#[async_trait]
pub trait MemoryStore: Send + Sync {
    /// Stores `record` and returns its stored form.
    ///
    /// The in-memory implementation in this file fills in `sequence_no` on the
    /// returned [`StoredSessionRecord`]; other backends presumably do the same.
    async fn append_record(
        &self,
        record: NewSessionRecord,
    ) -> Result<StoredSessionRecord, MemoryError>;

    /// Loads a snapshot of the session identified by `session_id`.
    async fn load_session(&self, session_id: &str) -> Result<SessionSnapshot, MemoryError>;

    /// Lists summaries of the sessions selected by `query`.
    async fn list_sessions(
        &self,
        query: SessionListQuery,
    ) -> Result<Vec<SessionSummary>, MemoryError>;

    /// Returns one page of a session's records, as selected by `query`.
    async fn list_records(&self, query: RecordPageQuery) -> Result<RecordPage, MemoryError>;

    /// Looks up an outcome of `session_id` by its `idempotency_key`.
    ///
    /// The in-memory implementation returns `Ok(None)` when no outcome record
    /// carries that key.
    async fn find_outcome_by_idempotency_key(
        &self,
        session_id: &str,
        idempotency_key: &str,
    ) -> Result<Option<EngineOutcome>, MemoryError>;

    /// Looks up the pending approval of `session_id` identified by `resume_token`.
    ///
    /// The in-memory implementation returns `Ok(None)` when no such approval
    /// exists or when a later approval resolution carries the same token.
    async fn find_pending_approval_by_resume_token(
        &self,
        session_id: &str,
        resume_token: &str,
    ) -> Result<Option<PendingApprovalRecord>, MemoryError>;
}
59
60#[async_trait]
61pub trait MemoryStoreExt: MemoryStore {
62    async fn append_trigger(
63        &self,
64        record: TriggerRecord,
65    ) -> Result<StoredSessionRecord, MemoryError> {
66        self.append_record(NewSessionRecord::from_record(
67            record.session_id.clone(),
68            SessionRecord::Trigger(record),
69        ))
70        .await
71    }
72
73    async fn append_model_decision(
74        &self,
75        session_id: &str,
76        record: ModelDecisionRecord,
77    ) -> Result<StoredSessionRecord, MemoryError> {
78        self.append_record(NewSessionRecord::from_record(
79            session_id.to_string(),
80            SessionRecord::ModelDecision(record),
81        ))
82        .await
83    }
84
85    async fn append_trigger_intent(
86        &self,
87        session_id: &str,
88        record: TriggerIntentRecord,
89    ) -> Result<StoredSessionRecord, MemoryError> {
90        self.append_record(NewSessionRecord::from_record(
91            session_id.to_string(),
92            SessionRecord::TriggerIntent(record),
93        ))
94        .await
95    }
96
97    async fn append_deliberation(
98        &self,
99        session_id: &str,
100        record: DeliberationRecord,
101    ) -> Result<StoredSessionRecord, MemoryError> {
102        self.append_record(NewSessionRecord::from_record(
103            session_id.to_string(),
104            SessionRecord::Deliberation(record),
105        ))
106        .await
107    }
108
109    async fn append_tool_execution_graph(
110        &self,
111        session_id: &str,
112        record: ToolExecutionGraph,
113    ) -> Result<StoredSessionRecord, MemoryError> {
114        self.append_record(NewSessionRecord::from_record(
115            session_id.to_string(),
116            SessionRecord::ToolExecutionGraph(record),
117        ))
118        .await
119    }
120
121    async fn append_execution_plan(
122        &self,
123        session_id: &str,
124        record: ExecutionPlanRecord,
125    ) -> Result<StoredSessionRecord, MemoryError> {
126        self.append_record(NewSessionRecord::from_record(
127            session_id.to_string(),
128            SessionRecord::ExecutionPlan(record),
129        ))
130        .await
131    }
132
133    async fn append_summary(
134        &self,
135        session_id: &str,
136        record: SummaryRecord,
137    ) -> Result<StoredSessionRecord, MemoryError> {
138        self.append_record(NewSessionRecord::from_record(
139            session_id.to_string(),
140            SessionRecord::Summary(record),
141        ))
142        .await
143    }
144
145    async fn append_tool_node_checkpoint(
146        &self,
147        session_id: &str,
148        record: ToolNodeCheckpointRecord,
149    ) -> Result<StoredSessionRecord, MemoryError> {
150        self.append_record(NewSessionRecord::from_record(
151            session_id.to_string(),
152            SessionRecord::ToolNodeCheckpoint(record),
153        ))
154        .await
155    }
156
157    async fn append_skill_input_validation(
158        &self,
159        session_id: &str,
160        record: SkillInputValidationRecord,
161    ) -> Result<StoredSessionRecord, MemoryError> {
162        self.append_record(NewSessionRecord::from_record(
163            session_id.to_string(),
164            SessionRecord::SkillInputValidation(record),
165        ))
166        .await
167    }
168
169    async fn append_kernel_event(
170        &self,
171        session_id: &str,
172        record: KernelEventRecord,
173    ) -> Result<StoredSessionRecord, MemoryError> {
174        self.append_record(NewSessionRecord::from_record(
175            session_id.to_string(),
176            SessionRecord::KernelEvent(record),
177        ))
178        .await
179    }
180
181    async fn append_tool_call(
182        &self,
183        session_id: &str,
184        record: ToolCallRecord,
185    ) -> Result<StoredSessionRecord, MemoryError> {
186        self.append_record(NewSessionRecord::from_record(
187            session_id.to_string(),
188            SessionRecord::ToolCall(record),
189        ))
190        .await
191    }
192
193    async fn append_tool_result(
194        &self,
195        session_id: &str,
196        record: ToolResultRecord,
197    ) -> Result<StoredSessionRecord, MemoryError> {
198        self.append_record(NewSessionRecord::from_record(
199            session_id.to_string(),
200            SessionRecord::ToolResult(record),
201        ))
202        .await
203    }
204
205    async fn append_pending_approval(
206        &self,
207        session_id: &str,
208        record: PendingApprovalRecord,
209    ) -> Result<StoredSessionRecord, MemoryError> {
210        self.append_record(NewSessionRecord::from_record(
211            session_id.to_string(),
212            SessionRecord::PendingApproval(record),
213        ))
214        .await
215    }
216
217    async fn append_approval_resolution(
218        &self,
219        session_id: &str,
220        record: ApprovalResolutionRecord,
221    ) -> Result<StoredSessionRecord, MemoryError> {
222        self.append_record(NewSessionRecord::from_record(
223            session_id.to_string(),
224            SessionRecord::ApprovalResolution(record),
225        ))
226        .await
227    }
228
229    async fn append_delegation(
230        &self,
231        session_id: &str,
232        record: DelegationRecord,
233    ) -> Result<StoredSessionRecord, MemoryError> {
234        self.append_record(NewSessionRecord::from_record(
235            session_id.to_string(),
236            SessionRecord::Delegation(record),
237        ))
238        .await
239    }
240
241    async fn append_coordination_claim(
242        &self,
243        session_id: &str,
244        record: CoordinationClaimRecord,
245    ) -> Result<StoredSessionRecord, MemoryError> {
246        self.append_record(NewSessionRecord::from_record(
247            session_id.to_string(),
248            SessionRecord::CoordinationClaim(record),
249        ))
250        .await
251    }
252
253    async fn append_provider_usage(
254        &self,
255        session_id: &str,
256        record: ProviderUsageRecord,
257    ) -> Result<StoredSessionRecord, MemoryError> {
258        self.append_record(NewSessionRecord::from_record(
259            session_id.to_string(),
260            SessionRecord::ProviderUsage(record),
261        ))
262        .await
263    }
264
265    async fn append_provider_cache(
266        &self,
267        session_id: &str,
268        record: ProviderCacheRecord,
269    ) -> Result<StoredSessionRecord, MemoryError> {
270        self.append_record(NewSessionRecord::from_record(
271            session_id.to_string(),
272            SessionRecord::ProviderCache(record),
273        ))
274        .await
275    }
276
277    async fn append_reflection(
278        &self,
279        session_id: &str,
280        record: ReflectionRecord,
281    ) -> Result<StoredSessionRecord, MemoryError> {
282        self.append_record(NewSessionRecord::from_record(
283            session_id.to_string(),
284            SessionRecord::Reflection(record),
285        ))
286        .await
287    }
288
289    async fn append_policy_tuning(
290        &self,
291        session_id: &str,
292        record: PolicyTuningRecord,
293    ) -> Result<StoredSessionRecord, MemoryError> {
294        self.append_record(NewSessionRecord::from_record(
295            session_id.to_string(),
296            SessionRecord::PolicyTuning(record),
297        ))
298        .await
299    }
300
301    async fn append_strategy_preference(
302        &self,
303        session_id: &str,
304        record: StrategyPreferenceRecord,
305    ) -> Result<StoredSessionRecord, MemoryError> {
306        self.append_record(NewSessionRecord::from_record(
307            session_id.to_string(),
308            SessionRecord::StrategyPreference(record),
309        ))
310        .await
311    }
312
313    async fn append_tool_performance(
314        &self,
315        session_id: &str,
316        record: ToolPerformanceRecord,
317    ) -> Result<StoredSessionRecord, MemoryError> {
318        self.append_record(NewSessionRecord::from_record(
319            session_id.to_string(),
320            SessionRecord::ToolPerformance(record),
321        ))
322        .await
323    }
324
325    async fn append_profile_patch(
326        &self,
327        session_id: &str,
328        record: ProfilePatchRecord,
329    ) -> Result<StoredSessionRecord, MemoryError> {
330        self.append_record(NewSessionRecord::from_record(
331            session_id.to_string(),
332            SessionRecord::ProfilePatch(record),
333        ))
334        .await
335    }
336
337    async fn append_outcome(
338        &self,
339        session_id: &str,
340        record: OutcomeRecord,
341    ) -> Result<StoredSessionRecord, MemoryError> {
342        self.append_record(NewSessionRecord::from_record(
343            session_id.to_string(),
344            SessionRecord::Outcome(record),
345        ))
346        .await
347    }
348}
349
350impl<T> MemoryStoreExt for T where T: MemoryStore + ?Sized {}
351
/// [`MemoryStore`] implementation that keeps all records in process memory.
///
/// Both fields live behind an `Arc`, so clones of the store share the same
/// records and the same sequence counter.
#[derive(Debug, Default, Clone)]
pub struct InMemoryMemoryStore {
    /// Stored records grouped by session id, in the order they were pushed.
    inner: Arc<RwLock<HashMap<String, Vec<StoredSessionRecord>>>>,
    /// Most recently issued sequence number, shared by all sessions. It starts
    /// at 0 and is incremented before use, so the first record is numbered 1.
    next_sequence_no: Arc<RwLock<i64>>,
}
357
358impl InMemoryMemoryStore {
359    pub fn new() -> Self {
360        Self::default()
361    }
362}
363
364#[async_trait]
365impl MemoryStore for InMemoryMemoryStore {
366    async fn append_record(
367        &self,
368        record: NewSessionRecord,
369    ) -> Result<StoredSessionRecord, MemoryError> {
370        let mut sequence_guard = self.next_sequence_no.write().await;
371        *sequence_guard += 1;
372        let stored = StoredSessionRecord {
373            session_id: record.session_id.clone(),
374            sequence_no: *sequence_guard,
375            occurred_at_ms: record.occurred_at_ms,
376            record_kind: record.record_kind,
377            trigger_id: record.trigger_id,
378            idempotency_key: record.idempotency_key,
379            record: record.record,
380        };
381        drop(sequence_guard);
382
383        let mut guard = self.inner.write().await;
384        guard
385            .entry(stored.session_id.clone())
386            .or_default()
387            .push(stored.clone());
388        Ok(stored)
389    }
390
391    async fn load_session(&self, session_id: &str) -> Result<SessionSnapshot, MemoryError> {
392        let guard = self.inner.read().await;
393        let records = guard.get(session_id).cloned().unwrap_or_default();
394        let latest_outcome = records
395            .iter()
396            .rev()
397            .find_map(|stored| match &stored.record {
398                SessionRecord::Outcome(outcome) => Some(outcome.clone()),
399                _ => None,
400            });
401
402        Ok(SessionSnapshot {
403            session_id: session_id.to_string(),
404            last_sequence_no: records.last().map(|record| record.sequence_no),
405            latest_outcome,
406            records: records.into_iter().map(|record| record.record).collect(),
407        })
408    }
409
410    async fn list_sessions(
411        &self,
412        query: SessionListQuery,
413    ) -> Result<Vec<SessionSummary>, MemoryError> {
414        let guard = self.inner.read().await;
415        let mut sessions = guard
416            .iter()
417            .filter_map(|(session_id, records)| {
418                let mut filtered = records.iter().filter(|record| {
419                    query
420                        .since_ms
421                        .is_none_or(|since_ms| record.occurred_at_ms >= since_ms)
422                        && query
423                            .until_ms
424                            .is_none_or(|until_ms| record.occurred_at_ms <= until_ms)
425                });
426                let first = filtered.next()?;
427                let mut last = first;
428                let mut count = 1usize;
429                for record in filtered {
430                    last = record;
431                    count += 1;
432                }
433                Some(SessionSummary {
434                    session_id: session_id.clone(),
435                    first_recorded_at_ms: first.occurred_at_ms,
436                    last_recorded_at_ms: last.occurred_at_ms,
437                    record_count: count,
438                })
439            })
440            .collect::<Vec<_>>();
441        sessions.sort_by(|left, right| left.session_id.cmp(&right.session_id));
442        Ok(sessions
443            .into_iter()
444            .skip(query.offset)
445            .take(query.limit)
446            .collect())
447    }
448
449    async fn list_records(&self, query: RecordPageQuery) -> Result<RecordPage, MemoryError> {
450        let guard = self.inner.read().await;
451        let all = guard.get(&query.session_id).cloned().unwrap_or_default();
452        let filtered = all
453            .into_iter()
454            .filter(|record| {
455                query
456                    .since_ms
457                    .is_none_or(|since_ms| record.occurred_at_ms >= since_ms)
458                    && query
459                        .until_ms
460                        .is_none_or(|until_ms| record.occurred_at_ms <= until_ms)
461            })
462            .collect::<Vec<_>>();
463        let total = filtered.len();
464        let records = filtered
465            .into_iter()
466            .skip(query.offset)
467            .take(query.limit)
468            .collect::<Vec<_>>();
469
470        Ok(RecordPage {
471            session_id: query.session_id,
472            next_offset: (query.offset + records.len() < total)
473                .then_some(query.offset + records.len()),
474            records,
475        })
476    }
477
478    async fn find_outcome_by_idempotency_key(
479        &self,
480        session_id: &str,
481        idempotency_key: &str,
482    ) -> Result<Option<EngineOutcome>, MemoryError> {
483        let guard = self.inner.read().await;
484        Ok(guard
485            .get(session_id)
486            .into_iter()
487            .flat_map(|records| records.iter().rev())
488            .find_map(|stored| match &stored.record {
489                SessionRecord::Outcome(outcome)
490                    if outcome.idempotency_key.as_deref() == Some(idempotency_key) =>
491                {
492                    Some(EngineOutcome::from_record(outcome.clone()))
493                }
494                _ => None,
495            }))
496    }
497
498    async fn find_pending_approval_by_resume_token(
499        &self,
500        session_id: &str,
501        resume_token: &str,
502    ) -> Result<Option<PendingApprovalRecord>, MemoryError> {
503        let guard = self.inner.read().await;
504        let records = guard.get(session_id).cloned().unwrap_or_default();
505        let mut pending = None::<PendingApprovalRecord>;
506        for stored in records {
507            match stored.record {
508                SessionRecord::PendingApproval(record)
509                    if record.resume_token.as_str() == resume_token =>
510                {
511                    pending = Some(record);
512                }
513                SessionRecord::ApprovalResolution(record)
514                    if record.resume_token.as_str() == resume_token =>
515                {
516                    pending = None;
517                }
518                _ => {}
519            }
520        }
521        Ok(pending)
522    }
523}