1use crate::agent::Agent;
9use crate::harness::Harness;
10use crate::llm_models::LlmProviderType;
11use crate::session::Session;
12use crate::tool_types::{ToolCall, ToolDefinition, ToolResult};
13use crate::traits::ModelWithProvider;
14use crate::typed_id::{AgentId, EventId, HarnessId, MessageId, ModelId, SessionId};
15use async_trait::async_trait;
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19use uuid::Uuid;
20
21use crate::error::Result;
22use crate::message::Message;
23use crate::message_filter::MessageQuery;
24use crate::message_retriever::{InputMessage, MessageRetriever};
25use crate::traits::{AgentStore, HarnessStore, LlmProviderStore, SessionStore, ToolExecutor};
26use chrono::Utc;
27
28#[derive(Debug, Default, Clone)]
40pub struct InMemoryMessageRetriever {
41 messages: Arc<RwLock<HashMap<SessionId, Vec<Message>>>>,
42}
43
44impl InMemoryMessageRetriever {
45 pub fn new() -> Self {
47 Self {
48 messages: Arc::new(RwLock::new(HashMap::new())),
49 }
50 }
51
52 pub async fn sessions(&self) -> Vec<SessionId> {
54 self.messages.read().await.keys().copied().collect()
55 }
56
57 pub async fn clear(&self) {
59 self.messages.write().await.clear();
60 }
61
62 pub async fn clear_session(&self, session_id: SessionId) {
64 self.messages.write().await.remove(&session_id);
65 }
66
67 pub async fn seed(&self, session_id: SessionId, messages: Vec<Message>) {
69 self.messages.write().await.insert(session_id, messages);
70 }
71
72 pub async fn add(&self, session_id: SessionId, input: InputMessage) -> Result<Message> {
77 let message = Message {
78 id: MessageId::new(),
79 role: input.role,
80 content: input.content,
81 phase: None,
82 thinking: None, thinking_signature: None,
84 controls: input.controls,
85 metadata: input.metadata,
86 external_actor: None,
87 created_at: Utc::now(),
88 };
89
90 self.messages
91 .write()
92 .await
93 .entry(session_id)
94 .or_default()
95 .push(message.clone());
96
97 Ok(message)
98 }
99
100 pub async fn store(&self, session_id: SessionId, message: Message) -> Result<()> {
105 self.messages
106 .write()
107 .await
108 .entry(session_id)
109 .or_default()
110 .push(message);
111 Ok(())
112 }
113}
114
115#[async_trait]
116impl MessageRetriever for InMemoryMessageRetriever {
117 async fn get(&self, session_id: SessionId, message_id: MessageId) -> Result<Option<Message>> {
118 Ok(self
119 .messages
120 .read()
121 .await
122 .get(&session_id)
123 .and_then(|messages| messages.iter().find(|m| m.id == message_id).cloned()))
124 }
125
126 async fn load(&self, session_id: SessionId) -> Result<Vec<Message>> {
127 Ok(self
128 .messages
129 .read()
130 .await
131 .get(&session_id)
132 .cloned()
133 .unwrap_or_default())
134 }
135
136 async fn load_filtered(&self, query: MessageQuery) -> Result<Vec<Message>> {
137 use crate::message_filter::MessageFilter;
138
139 let mut messages = self.load(query.session_id).await?;
140
141 for filter in &query.filters {
143 match filter {
144 MessageFilter::TimeRange { from, to } => {
145 messages.retain(|m| {
146 let after_from = from.is_none_or(|t| m.created_at >= t);
147 let before_to = to.is_none_or(|t| m.created_at <= t);
148 after_from && before_to
149 });
150 }
151 MessageFilter::Search(q) => {
152 let q_lower = q.to_lowercase();
153 messages.retain(|m| {
154 m.text()
155 .is_some_and(|t| t.to_lowercase().contains(&q_lower))
156 });
157 }
158 MessageFilter::Custom(predicate) => {
159 messages.retain(|m| predicate(m));
160 }
161 _ => {}
163 }
164 }
165
166 query.apply_windowing(&mut messages);
167
168 if query.has_injections() {
170 query.apply_injections(&mut messages);
171 }
172
173 Ok(messages)
174 }
175
176 async fn count(&self, session_id: SessionId) -> Result<usize> {
177 Ok(self
178 .messages
179 .read()
180 .await
181 .get(&session_id)
182 .map(|m| m.len())
183 .unwrap_or(0))
184 }
185}
186
187#[derive(Debug, Default, Clone)]
196pub struct InMemoryAgentStore {
197 agents: Arc<RwLock<HashMap<AgentId, Agent>>>,
198}
199
200impl InMemoryAgentStore {
201 pub fn new() -> Self {
203 Self {
204 agents: Arc::new(RwLock::new(HashMap::new())),
205 }
206 }
207
208 pub async fn add_agent(&self, agent: Agent) {
210 self.agents.write().await.insert(agent.public_id, agent);
211 }
212
213 pub async fn agent_ids(&self) -> Vec<AgentId> {
215 self.agents.read().await.keys().copied().collect()
216 }
217
218 pub async fn clear(&self) {
220 self.agents.write().await.clear();
221 }
222}
223
224#[async_trait]
225impl AgentStore for InMemoryAgentStore {
226 async fn get_agent(&self, agent_id: AgentId) -> Result<Option<Agent>> {
227 Ok(self.agents.read().await.get(&agent_id).cloned())
228 }
229}
230
231#[derive(Debug, Default, Clone)]
240pub struct InMemoryHarnessStore {
241 harnesses: Arc<RwLock<HashMap<HarnessId, Harness>>>,
242}
243
244impl InMemoryHarnessStore {
245 pub fn new() -> Self {
247 Self {
248 harnesses: Arc::new(RwLock::new(HashMap::new())),
249 }
250 }
251
252 pub async fn add_harness(&self, harness: Harness) {
254 self.harnesses.write().await.insert(harness.id, harness);
255 }
256}
257
258#[async_trait]
259impl HarnessStore for InMemoryHarnessStore {
260 async fn get_harness_chain(&self, harness_id: HarnessId) -> Result<Vec<Harness>> {
261 Ok(self
262 .harnesses
263 .read()
264 .await
265 .get(&harness_id)
266 .cloned()
267 .into_iter()
268 .collect())
269 }
270}
271
272#[derive(Debug, Default, Clone)]
281pub struct InMemorySessionStore {
282 sessions: Arc<RwLock<HashMap<SessionId, Session>>>,
283}
284
285impl InMemorySessionStore {
286 pub fn new() -> Self {
288 Self {
289 sessions: Arc::new(RwLock::new(HashMap::new())),
290 }
291 }
292
293 pub async fn add_session(&self, session: Session) {
295 self.sessions.write().await.insert(session.id, session);
296 }
297
298 pub async fn session_ids(&self) -> Vec<SessionId> {
300 self.sessions.read().await.keys().copied().collect()
301 }
302
303 pub async fn clear(&self) {
305 self.sessions.write().await.clear();
306 }
307}
308
309#[async_trait]
310impl SessionStore for InMemorySessionStore {
311 async fn get_session(&self, session_id: SessionId) -> Result<Option<Session>> {
312 Ok(self.sessions.read().await.get(&session_id).cloned())
313 }
314}
315
316#[derive(Debug, Default, Clone)]
335pub struct InMemoryLlmProviderStore {
336 models: Arc<RwLock<HashMap<ModelId, ModelWithProvider>>>,
337 default_model: Arc<RwLock<Option<ModelWithProvider>>>,
338}
339
340impl InMemoryLlmProviderStore {
341 pub fn new() -> Self {
343 Self {
344 models: Arc::new(RwLock::new(HashMap::new())),
345 default_model: Arc::new(RwLock::new(None)),
346 }
347 }
348
349 pub async fn from_env() -> Self {
354 let store = Self::new();
355
356 if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
358 let model = ModelWithProvider {
359 model: "gpt-5.4".to_string(),
360 provider_type: LlmProviderType::Openai,
361 api_key: Some(api_key),
362 base_url: std::env::var("OPENAI_BASE_URL").ok(),
363 };
364 store.set_default_model(model).await;
365 } else if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
366 let model = ModelWithProvider {
367 model: "claude-sonnet-4-20250514".to_string(),
368 provider_type: LlmProviderType::Anthropic,
369 api_key: Some(api_key),
370 base_url: std::env::var("ANTHROPIC_BASE_URL").ok(),
371 };
372 store.set_default_model(model).await;
373 }
374
375 store
376 }
377
378 pub async fn with_default(model: ModelWithProvider) -> Self {
380 let store = Self::new();
381 store.set_default_model(model).await;
382 store
383 }
384
385 pub async fn add_model(&self, model_id: ModelId, model: ModelWithProvider) {
387 self.models.write().await.insert(model_id, model);
388 }
389
390 pub async fn set_default_model(&self, model: ModelWithProvider) {
392 *self.default_model.write().await = Some(model);
393 }
394
395 pub async fn clear(&self) {
397 self.models.write().await.clear();
398 *self.default_model.write().await = None;
399 }
400}
401
402#[async_trait]
403impl LlmProviderStore for InMemoryLlmProviderStore {
404 async fn get_model_with_provider(
405 &self,
406 model_id: ModelId,
407 ) -> Result<Option<ModelWithProvider>> {
408 Ok(self.models.read().await.get(&model_id).cloned())
409 }
410
411 async fn get_default_model(&self) -> Result<Option<ModelWithProvider>> {
412 Ok(self.default_model.read().await.clone())
413 }
414}
415
416#[derive(Debug, Default)]
424pub struct MockToolExecutor {
425 results: Arc<RwLock<HashMap<String, serde_json::Value>>>,
426 call_log: Arc<RwLock<Vec<ToolCall>>>,
427}
428
429impl MockToolExecutor {
430 pub fn new() -> Self {
432 Self {
433 results: Arc::new(RwLock::new(HashMap::new())),
434 call_log: Arc::new(RwLock::new(Vec::new())),
435 }
436 }
437
438 pub async fn set_result(&self, tool_name: impl Into<String>, result: serde_json::Value) {
440 self.results.write().await.insert(tool_name.into(), result);
441 }
442
443 pub async fn calls(&self) -> Vec<ToolCall> {
445 self.call_log.read().await.clone()
446 }
447
448 pub async fn clear_calls(&self) {
450 self.call_log.write().await.clear();
451 }
452}
453
454#[async_trait]
455impl ToolExecutor for MockToolExecutor {
456 async fn execute(
457 &self,
458 tool_call: &ToolCall,
459 _tool_def: &ToolDefinition,
460 ) -> Result<ToolResult> {
461 self.call_log.write().await.push(tool_call.clone());
463
464 let result = self
466 .results
467 .read()
468 .await
469 .get(&tool_call.name)
470 .cloned()
471 .unwrap_or_else(|| serde_json::json!({"status": "ok"}));
472
473 Ok(ToolResult {
474 tool_call_id: tool_call.id.clone(),
475 result: Some(result),
476 images: None,
477 error: None,
478 connection_required: None,
479 raw_output: None,
480 })
481 }
482}
483
/// Stateless tool executor; its `ToolExecutor` implementation echoes the call back.
#[derive(Clone, Copy, Debug, Default)]
pub struct EchoToolExecutor;

impl EchoToolExecutor {
    /// Creates a new executor. Equivalent to `EchoToolExecutor::default()`.
    pub fn new() -> Self {
        EchoToolExecutor
    }
}
499
500#[async_trait]
501impl ToolExecutor for EchoToolExecutor {
502 async fn execute(
503 &self,
504 tool_call: &ToolCall,
505 _tool_def: &ToolDefinition,
506 ) -> Result<ToolResult> {
507 Ok(ToolResult {
508 tool_call_id: tool_call.id.clone(),
509 result: Some(serde_json::json!({
510 "echoed_tool": tool_call.name,
511 "echoed_arguments": tool_call.arguments
512 })),
513 images: None,
514 error: None,
515 connection_required: None,
516 raw_output: None,
517 })
518 }
519}
520
/// Tool executor whose `ToolExecutor` implementation always reports an error.
#[derive(Clone, Debug)]
pub struct FailingToolExecutor {
    /// Error text placed in every result.
    error_message: String,
}

impl FailingToolExecutor {
    /// Creates an executor that reports `error_message` for every call.
    pub fn new(error_message: impl Into<String>) -> Self {
        let error_message = error_message.into();
        Self { error_message }
    }
}

impl Default for FailingToolExecutor {
    /// Uses the message `"Tool execution failed"`.
    fn default() -> Self {
        Self {
            error_message: String::from("Tool execution failed"),
        }
    }
}
546
547#[async_trait]
548impl ToolExecutor for FailingToolExecutor {
549 async fn execute(
550 &self,
551 tool_call: &ToolCall,
552 _tool_def: &ToolDefinition,
553 ) -> Result<ToolResult> {
554 Ok(ToolResult {
555 tool_call_id: tool_call.id.clone(),
556 result: None,
557 images: None,
558 error: Some(self.error_message.clone()),
559 connection_required: None,
560 raw_output: None,
561 })
562 }
563}
564
565use crate::events::{Event, EventRequest};
570use crate::llm_driver_registry::{
571 LlmCallConfig, LlmDriver, LlmMessage, LlmResponseStream, LlmStreamEvent,
572};
573use crate::traits::EventEmitter;
574use futures::stream;
575
576#[derive(Debug, Default)]
580pub struct MockLlmProvider {
581 responses: Arc<RwLock<Vec<MockLlmResponse>>>,
582 call_index: Arc<RwLock<usize>>,
583 call_log: Arc<RwLock<Vec<Vec<LlmMessage>>>>,
584}
585
586#[derive(Debug, Clone)]
588pub struct MockLlmResponse {
589 pub text: String,
590 pub tool_calls: Option<Vec<ToolCall>>,
591}
592
593impl MockLlmResponse {
594 pub fn text(text: impl Into<String>) -> Self {
596 Self {
597 text: text.into(),
598 tool_calls: None,
599 }
600 }
601
602 pub fn with_tools(text: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
604 Self {
605 text: text.into(),
606 tool_calls: Some(tool_calls),
607 }
608 }
609}
610
611impl MockLlmProvider {
612 pub fn new() -> Self {
614 Self {
615 responses: Arc::new(RwLock::new(Vec::new())),
616 call_index: Arc::new(RwLock::new(0)),
617 call_log: Arc::new(RwLock::new(Vec::new())),
618 }
619 }
620
621 pub async fn add_response(&self, response: MockLlmResponse) {
623 self.responses.write().await.push(response);
624 }
625
626 pub async fn set_responses(&self, responses: Vec<MockLlmResponse>) {
628 *self.responses.write().await = responses;
629 *self.call_index.write().await = 0;
630 }
631
632 pub async fn calls(&self) -> Vec<Vec<LlmMessage>> {
634 self.call_log.read().await.clone()
635 }
636
637 pub async fn reset(&self) {
639 self.responses.write().await.clear();
640 *self.call_index.write().await = 0;
641 self.call_log.write().await.clear();
642 }
643}
644
645#[async_trait]
646impl LlmDriver for MockLlmProvider {
647 async fn chat_completion_stream(
648 &self,
649 messages: Vec<LlmMessage>,
650 _config: &LlmCallConfig,
651 ) -> Result<LlmResponseStream> {
652 self.call_log.write().await.push(messages);
654
655 let mut index = self.call_index.write().await;
657 let responses = self.responses.read().await;
658
659 let response = responses.get(*index).cloned().unwrap_or_else(|| {
660 MockLlmResponse::text("Mock response (no more responses configured)")
661 });
662
663 *index += 1;
664 drop(index);
665 drop(responses);
666
667 let events = vec![
669 Ok(LlmStreamEvent::TextDelta(response.text.clone())),
670 if let Some(tool_calls) = response.tool_calls {
671 Ok(LlmStreamEvent::ToolCalls(tool_calls))
672 } else {
673 Ok(LlmStreamEvent::Done(Box::default()))
674 },
675 Ok(LlmStreamEvent::Done(Box::default())),
676 ];
677
678 Ok(Box::pin(stream::iter(events)))
679 }
680}
681
682#[derive(Debug, Default, Clone)]
705pub struct InMemoryEventEmitter {
706 events: Arc<RwLock<Vec<Event>>>,
707 sequence: Arc<RwLock<i32>>,
708}
709
710impl InMemoryEventEmitter {
711 pub fn new() -> Self {
713 Self {
714 events: Arc::new(RwLock::new(Vec::new())),
715 sequence: Arc::new(RwLock::new(0)),
716 }
717 }
718
719 pub async fn events(&self) -> Vec<Event> {
721 self.events.read().await.clone()
722 }
723
724 pub async fn event_count(&self) -> usize {
726 self.events.read().await.len()
727 }
728
729 pub async fn clear(&self) {
731 self.events.write().await.clear();
732 *self.sequence.write().await = 0;
733 }
734
735 pub async fn events_by_type(&self, event_type: &str) -> Vec<Event> {
737 self.events
738 .read()
739 .await
740 .iter()
741 .filter(|e| e.event_type == event_type)
742 .cloned()
743 .collect()
744 }
745
746 pub async fn events_for_session(&self, session_id: Uuid) -> Vec<Event> {
748 self.events
749 .read()
750 .await
751 .iter()
752 .filter(|e| e.session_uuid() == session_id)
753 .cloned()
754 .collect()
755 }
756}
757
758#[async_trait]
759impl EventEmitter for InMemoryEventEmitter {
760 async fn emit(&self, request: EventRequest) -> Result<Event> {
761 let mut sequence = self.sequence.write().await;
762 *sequence += 1;
763 let seq = *sequence;
764 drop(sequence);
765
766 let event = request.into_event(EventId::new(), seq);
768 self.events.write().await.push(event.clone());
769 Ok(event)
770 }
771}
772
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Produces a fresh session ID from a v7 UUID.
    fn fresh_session_id() -> SessionId {
        Uuid::now_v7().into()
    }

    /// A message written with `store` comes back from `load`.
    #[tokio::test]
    async fn test_in_memory_message_retriever() {
        let retriever = InMemoryMessageRetriever::new();
        let session = fresh_session_id();

        retriever
            .store(session, Message::user("Hello"))
            .await
            .unwrap();

        let loaded = retriever.load(session).await.unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].text(), Some("Hello"));
    }

    /// `add` stores a message that `get` finds by ID; unknown IDs yield `None`.
    #[tokio::test]
    async fn test_in_memory_message_retriever_add_and_get() {
        let retriever = InMemoryMessageRetriever::new();
        let session = fresh_session_id();

        let added = retriever
            .add(session, InputMessage::user("Hello via add"))
            .await
            .unwrap();

        let fetched = retriever.get(session, added.id).await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().text(), Some("Hello via add"));

        let unknown = retriever.get(session, MessageId::new()).await.unwrap();
        assert!(unknown.is_none());
    }

    /// The ID returned by `add` is the ID under which the message is stored.
    #[tokio::test]
    async fn test_message_retriever_add_returns_consistent_id() {
        let retriever = InMemoryMessageRetriever::new();
        let session = fresh_session_id();

        let added = retriever
            .add(session, InputMessage::user("Test consistency"))
            .await
            .unwrap();

        let fetched = retriever.get(session, added.id).await.unwrap();
        assert!(
            fetched.is_some(),
            "Message must be retrievable by the ID returned from add()"
        );

        let fetched = fetched.unwrap();
        assert_eq!(
            fetched.id, added.id,
            "Retrieved message ID must match the ID returned from add()"
        );

        let everything = retriever.load(session).await.unwrap();
        let present = everything.iter().find(|m| m.id == added.id);
        assert!(
            present.is_some(),
            "Message with returned ID must appear in load() results"
        );
    }

    /// A configured result is returned for the matching tool name.
    #[tokio::test]
    async fn test_mock_tool_executor() {
        let executor = MockToolExecutor::new();
        executor
            .set_result("get_weather", serde_json::json!({"temp": 72}))
            .await;

        let call = ToolCall {
            id: "call_1".to_string(),
            name: "get_weather".to_string(),
            arguments: serde_json::json!({"city": "NYC"}),
        };

        let definition = ToolDefinition::Builtin(crate::tool_types::BuiltinTool {
            name: "get_weather".to_string(),
            display_name: None,
            description: "Get weather".to_string(),
            parameters: serde_json::json!({}),
            policy: crate::tool_types::ToolPolicy::Auto,
            category: None,
            deferrable: crate::tool_types::DeferrablePolicy::default(),
            hints: crate::tool_types::ToolHints::default(),
        });

        let outcome = executor.execute(&call, &definition).await.unwrap();

        assert!(outcome.error.is_none());
        assert_eq!(outcome.result, Some(serde_json::json!({"temp": 72})));
    }

    /// Emitted events get sequence numbers 1, 2, ... and are all stored.
    #[tokio::test]
    async fn test_in_memory_event_emitter() {
        use crate::events::{EventContext, EventRequest, InputMessageData};

        let emitter = InMemoryEventEmitter::new();
        let session = fresh_session_id();
        let context = EventContext::empty();

        let first = emitter
            .emit(EventRequest::new(
                session,
                context.clone(),
                InputMessageData::new(Message::user("test1")),
            ))
            .await
            .unwrap();
        assert_eq!(first.sequence, Some(1));

        let second = emitter
            .emit(EventRequest::new(
                session,
                context,
                InputMessageData::new(Message::user("test2")),
            ))
            .await
            .unwrap();
        assert_eq!(second.sequence, Some(2));

        let stored = emitter.events().await;
        assert_eq!(stored.len(), 2);
        assert_eq!(emitter.event_count().await, 2);
    }

    /// `events_by_type` returns only events of the requested type.
    #[tokio::test]
    async fn test_in_memory_event_emitter_filter_by_type() {
        use crate::events::{
            EventContext, EventRequest, INPUT_MESSAGE, InputMessageData, REASON_STARTED,
            ReasonStartedData,
        };

        let emitter = InMemoryEventEmitter::new();
        let session = fresh_session_id();
        let context = EventContext::empty();

        emitter
            .emit(EventRequest::new(
                session,
                context.clone(),
                InputMessageData::new(Message::user("test")),
            ))
            .await
            .unwrap();

        emitter
            .emit(EventRequest::new(
                session,
                context,
                ReasonStartedData {
                    harness_id: HarnessId::from_seed(1),
                    agent_id: Some(AgentId::new()),
                    metadata: None,
                },
            ))
            .await
            .unwrap();

        let input_events = emitter.events_by_type(INPUT_MESSAGE).await;
        assert_eq!(input_events.len(), 1);

        let reason_events = emitter.events_by_type(REASON_STARTED).await;
        assert_eq!(reason_events.len(), 1);
    }

    /// `events_for_session` returns only events of the requested session.
    #[tokio::test]
    async fn test_in_memory_event_emitter_filter_by_session() {
        use crate::events::{EventContext, EventRequest, InputMessageData};

        let emitter = InMemoryEventEmitter::new();
        let first_session = fresh_session_id();
        let second_session = fresh_session_id();

        let context = EventContext::empty();

        emitter
            .emit(EventRequest::new(
                first_session,
                context.clone(),
                InputMessageData::new(Message::user("session1")),
            ))
            .await
            .unwrap();
        emitter
            .emit(EventRequest::new(
                second_session,
                context,
                InputMessageData::new(Message::user("session2")),
            ))
            .await
            .unwrap();

        let for_first = emitter.events_for_session(first_session.uuid()).await;
        assert_eq!(for_first.len(), 1);

        let for_second = emitter.events_for_session(second_session.uuid()).await;
        assert_eq!(for_second.len(), 1);
    }

    /// `clear` removes all stored events.
    #[tokio::test]
    async fn test_in_memory_event_emitter_clear() {
        use crate::events::{EventContext, EventRequest, InputMessageData};

        let emitter = InMemoryEventEmitter::new();
        let session = fresh_session_id();
        let context = EventContext::empty();

        emitter
            .emit(EventRequest::new(
                session,
                context,
                InputMessageData::new(Message::user("test")),
            ))
            .await
            .unwrap();

        assert_eq!(emitter.event_count().await, 1);

        emitter.clear().await;

        assert_eq!(emitter.event_count().await, 0);
    }
}
1020
1021use crate::memory_store::{
1026 Memory, MemoryContentPart, MemoryKind, MemoryQuery, MemoryStoreBackend, MemoryStoreEntity,
1027};
1028use crate::typed_id::{MemoryId, MemoryStoreId, OrgId};
1029
1030#[derive(Debug, Default, Clone)]
1032pub struct InMemoryMemoryStore {
1033 stores: Arc<RwLock<Vec<MemoryStoreEntity>>>,
1034 memories: Arc<RwLock<Vec<Memory>>>,
1035}
1036
1037impl InMemoryMemoryStore {
1038 pub fn new() -> Self {
1039 Self::default()
1040 }
1041}
1042
1043#[async_trait]
1044impl MemoryStoreBackend for InMemoryMemoryStore {
1045 async fn get_or_create_default_store(&self, org_id: OrgId) -> Result<MemoryStoreEntity> {
1046 let mut stores = self.stores.write().await;
1047 if let Some(store) = stores.iter().find(|s| s.org_id == org_id && s.is_default) {
1048 return Ok(store.clone());
1049 }
1050 let store = MemoryStoreEntity {
1051 id: MemoryStoreId::new(),
1052 org_id,
1053 name: "default".to_string(),
1054 is_default: true,
1055 created_at: chrono::Utc::now(),
1056 };
1057 stores.push(store.clone());
1058 Ok(store)
1059 }
1060
1061 async fn get_store(&self, store_id: MemoryStoreId) -> Result<Option<MemoryStoreEntity>> {
1062 Ok(self
1063 .stores
1064 .read()
1065 .await
1066 .iter()
1067 .find(|s| s.id == store_id)
1068 .cloned())
1069 }
1070
1071 async fn create_memory(
1072 &self,
1073 store_id: MemoryStoreId,
1074 content: String,
1075 content_parts: Vec<MemoryContentPart>,
1076 kind: MemoryKind,
1077 importance: u8,
1078 tags: Vec<String>,
1079 ) -> Result<Memory> {
1080 let now = chrono::Utc::now();
1081 let memory = Memory {
1082 id: MemoryId::new(),
1083 store_id,
1084 content,
1085 content_parts,
1086 kind,
1087 importance: importance.clamp(1, 10),
1088 tags,
1089 active: true,
1090 created_at: now,
1091 updated_at: now,
1092 };
1093 self.memories.write().await.push(memory.clone());
1094 Ok(memory)
1095 }
1096
1097 async fn recall(&self, query: MemoryQuery) -> Result<(Vec<Memory>, usize)> {
1098 let memories = self.memories.read().await;
1099 let mut results: Vec<&Memory> = memories
1100 .iter()
1101 .filter(|m| m.active)
1102 .filter(|m| {
1103 if let Some(ref sid) = query.store_id {
1104 m.store_id == *sid
1105 } else {
1106 true
1107 }
1108 })
1109 .filter(|m| {
1110 if let Some(ref kind) = query.kind {
1111 m.kind == *kind
1112 } else {
1113 true
1114 }
1115 })
1116 .filter(|m| {
1117 if let Some(ref tags) = query.tags {
1118 tags.iter().all(|t| m.tags.contains(t))
1119 } else {
1120 true
1121 }
1122 })
1123 .filter(|m| {
1124 if let Some(ref q) = query.query {
1125 let q_lower = q.to_lowercase();
1126 m.content.to_lowercase().contains(&q_lower)
1127 || m.tags.iter().any(|t| t.to_lowercase().contains(&q_lower))
1128 } else {
1129 true
1130 }
1131 })
1132 .collect();
1133
1134 results.sort_by(|a, b| {
1136 b.importance
1137 .cmp(&a.importance)
1138 .then_with(|| b.created_at.cmp(&a.created_at))
1139 });
1140
1141 let total = results.len();
1142 let limit = if query.limit > 0 { query.limit } else { 10 };
1143 let results: Vec<Memory> = results.into_iter().take(limit).cloned().collect();
1144 Ok((results, total))
1145 }
1146
1147 async fn forget(&self, store_id: MemoryStoreId, memory_id: MemoryId) -> Result<bool> {
1148 let mut memories = self.memories.write().await;
1149 if let Some(m) = memories
1150 .iter_mut()
1151 .find(|m| m.id == memory_id && m.store_id == store_id && m.active)
1152 {
1153 m.active = false;
1154 m.updated_at = chrono::Utc::now();
1155 Ok(true)
1156 } else {
1157 Ok(false)
1158 }
1159 }
1160
1161 async fn count_active(&self, store_id: MemoryStoreId) -> Result<usize> {
1162 Ok(self
1163 .memories
1164 .read()
1165 .await
1166 .iter()
1167 .filter(|m| m.store_id == store_id && m.active)
1168 .count())
1169 }
1170}